diff --git a/netspresso/constant/__init__.py b/netspresso/constant/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/netspresso/constant/project.py b/netspresso/constant/project.py new file mode 100644 index 00000000..aa617277 --- /dev/null +++ b/netspresso/constant/project.py @@ -0,0 +1 @@ +SUB_FOLDERS = ["Trained models", "Compressed models", "Pretrained models"] diff --git a/netspresso/netspresso.py b/netspresso/netspresso.py index 393bd4be..50c08ac1 100644 --- a/netspresso/netspresso.py +++ b/netspresso/netspresso.py @@ -8,13 +8,16 @@ from netspresso.clients.auth.response_body import UserResponse from netspresso.clients.tao import TAOTokenHandler from netspresso.compressor import CompressorV2 +from netspresso.constant.project import SUB_FOLDERS from netspresso.converter import ConverterV2 from netspresso.enums import Task from netspresso.inferencer.inferencer import CustomInferencer, NPInferencer from netspresso.quantizer import Quantizer from netspresso.tao import TAOTrainer from netspresso.trainer import Trainer -from netspresso.utils.file import FileHandler +from netspresso.utils.db.models.project import Project +from netspresso.utils.db.repositories.project import project_repository +from netspresso.utils.db.session import get_db class NetsPresso: @@ -42,31 +45,40 @@ def get_user(self) -> UserResponse: ) return user_info - def create_project(self, project_name: str, project_path: str = "./"): + def create_project(self, project_name: str, project_path: str = "./projects") -> Project: + if len(project_name) > 30: + raise ValueError("The project_name can't exceed 30 characters.") + # Create the main project folder project_folder_path = Path(project_path) / project_name + project_abs_path = project_folder_path.resolve() # Check if the project folder already exists if project_folder_path.exists(): - logger.info(f"Project '{project_name}' already exists at {project_folder_path.resolve()}.") + logger.warning(f"Project '{project_name}' already exists at {project_abs_path}.") else: project_folder_path.mkdir(parents=True, exist_ok=True) - # Subfolder names - subfolders = ["Trainer models", "Compressed models", "Pretrained models"] - # Create subfolders - for folder in subfolders: + for folder in SUB_FOLDERS: (project_folder_path / folder).mkdir(parents=True, exist_ok=True) - # Create a metadata.json file - metadata_file_path = project_folder_path / "metadata.json" - metadata = {"is_project_folder": True} + logger.info(f"Project '{project_name}' created at {project_abs_path}.") + + try: + with get_db() as db: + project = Project( + project_name=project_name, + user_id=self.user_info.user_id, + project_abs_path=project_abs_path.as_posix(), + ) + project = project_repository.save(db=db, model=project) - # Write metadata to the json file - FileHandler.save_json(data=metadata, file_path=metadata_file_path) + return project - logger.info(f"Project '{project_name}' created at {project_folder_path.resolve()}.") + except Exception as e: + logger.error(f"Failed to save project '{project_name}' to the database: {e}") + raise def trainer( self, task: Optional[Union[str, Task]] = None, yaml_path: Optional[str] = None diff --git a/netspresso/utils/db/models/__init__.py b/netspresso/utils/db/models/__init__.py index 9c2d5235..00b2f031 100644 --- a/netspresso/utils/db/models/__init__.py +++ b/netspresso/utils/db/models/__init__.py @@ -1,3 +1,4 @@ +from netspresso.utils.db.models.project import Project from netspresso.utils.db.models.user import User from netspresso.utils.db.session import Base, engine @@ -5,5 +6,6 @@ __all__ = [ + "Project", "User", ] diff --git a/netspresso/utils/db/models/project.py b/netspresso/utils/db/models/project.py new file mode 100644 index 00000000..0233a0dc --- /dev/null +++ b/netspresso/utils/db/models/project.py @@ -0,0 +1,15 @@ +from sqlalchemy import Column, Integer, String + +from netspresso.utils.db.generate_uuid import generate_uuid +from netspresso.utils.db.mixins import TimestampMixin +from netspresso.utils.db.session import Base + + +class Project(Base, TimestampMixin): + __tablename__ = "project" + + id = Column(Integer, primary_key=True, index=True, unique=True, autoincrement=True, nullable=False) + project_id = Column(String(36), index=True, unique=True, nullable=False, default=lambda: generate_uuid(entity="project")) + project_name = Column(String(30), nullable=False) + user_id = Column(String(36), nullable=False) + project_abs_path = Column(String(500), nullable=False) diff --git a/netspresso/utils/db/repositories/project.py b/netspresso/utils/db/repositories/project.py new file mode 100644 index 00000000..0bf35d06 --- /dev/null +++ b/netspresso/utils/db/repositories/project.py @@ -0,0 +1,53 @@ +from typing import List, Optional + +from sqlalchemy.orm import Session + +from netspresso.utils.db.models.project import Project +from netspresso.utils.db.repositories.base import BaseRepository, Order + + +class ProjectRepository(BaseRepository[Project]): + def get_by_project_id(self, db: Session, project_id: str) -> Optional[Project]: + project = db.query(self.model).filter(self.model.project_id == project_id) + + return project + + def _get_projects( + self, + db: Session, + condition, + start: Optional[int] = None, + size: Optional[int] = None, + order: Optional[Order] = None, + ) -> Optional[List[Project]]: + ordering_func = self.choose_order_func(order) + query = db.query(self.model).filter(condition) + + if order: + query = query.order_by(ordering_func(self.model.created_at)) + + if start is not None and size is not None: + query = query.offset(start).limit(size) + + projects = query.all() + + return projects + + def get_all_by_user_id( + self, + db: Session, + user_id: str, + start: Optional[int] = None, + size: Optional[int] = None, + order: Optional[Order] = None, + ) -> Optional[List[Project]]: + return self._get_projects( + db=db, + condition=self.model.user_id == user_id, + start=start, + size=size, + order=order, + ) + + +project_repository = ProjectRepository(Project) diff --git a/netspresso/utils/db/session.py b/netspresso/utils/db/session.py index 3769ac0e..b240a36d 100644 --- a/netspresso/utils/db/session.py +++ b/netspresso/utils/db/session.py @@ -1,5 +1,7 @@ +from contextlib import contextmanager from typing import Generator +from loguru import logger from sqlalchemy import create_engine from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy_utils import create_database, database_exists @@ -21,13 +23,22 @@ Base = declarative_base() +@contextmanager def get_db() -> Generator: + db = None try: db = SessionLocal() yield db finally: - db.close() + if db: + db.close() -if not database_exists(engine.url): - create_database(engine.url) +def check_database(engine): + if not database_exists(engine.url): + logger.info("The database did not exist, so it has been created.") + create_database(engine.url) + else: + logger.info("The database has already been created.") + +check_database(engine=engine)