diff --git a/netspresso/netspresso.py b/netspresso/netspresso.py index 50c08ac1..7d962213 100644 --- a/netspresso/netspresso.py +++ b/netspresso/netspresso.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional, Union +from typing import List, Optional, Union from loguru import logger @@ -51,13 +51,13 @@ def create_project(self, project_name: str, project_path: str = "./projects") -> # Create the main project folder project_folder_path = Path(project_path) / project_name - project_abs_path = project_folder_path.resolve() # Check if the project folder already exists if project_folder_path.exists(): - logger.warning(f"Project '{project_name}' already exists at {project_abs_path}.") + logger.warning(f"Project '{project_name}' already exists at {project_folder_path.resolve()}.") else: project_folder_path.mkdir(parents=True, exist_ok=True) + project_abs_path = project_folder_path.resolve() # Create subfolders for folder in SUB_FOLDERS: @@ -65,19 +65,30 @@ def create_project(self, project_name: str, project_path: str = "./projects") -> logger.info(f"Project '{project_name}' created at {project_abs_path}.") + try: + with get_db() as db: + project = Project( + project_name=project_name, + user_id=self.user_info.user_id, + project_abs_path=project_abs_path.as_posix(), + ) + project = project_repository.save(db=db, model=project) + + return project + + except Exception as e: + logger.error(f"Failed to save project '{project_name}' to the database: {e}") + raise + + def get_projects(self) -> List[Project]: try: with get_db() as db: - project = Project( - project_name=project_name, - user_id=self.user_info.user_id, - project_abs_path=project_abs_path.as_posix(), - ) - project = project_repository.save(db=db, model=project) + projects = project_repository.get_all_by_user_id(db=db, user_id=self.user_info.user_id) - return project + return projects except Exception as e: - logger.error(f"Failed to save project '{project_name}' to the database: {e}") + logger.error(f"Failed to get project list from the database: {e}") raise def trainer( diff --git a/netspresso/utils/db/models/project.py b/netspresso/utils/db/models/project.py index 0233a0dc..68617984 100644 --- a/netspresso/utils/db/models/project.py +++ b/netspresso/utils/db/models/project.py @@ -10,6 +10,6 @@ class Project(Base, TimestampMixin): id = Column(Integer, primary_key=True, index=True, unique=True, autoincrement=True, nullable=False) project_id = Column(String(36), index=True, unique=True, nullable=False, default=lambda: generate_uuid(entity="project")) - project_name = Column(String(30), nullable=False) + project_name = Column(String(30), nullable=False, unique=True) user_id = Column(String(36), nullable=False) project_abs_path = Column(String(500), nullable=False)