Skip to content

Commit

Permalink
#418 Add a function to retrieve the list of projects (#432)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Dec 6, 2024
1 parent fda9585 commit fe79b3a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
33 changes: 22 additions & 11 deletions netspresso/netspresso.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Optional, Union
from typing import List, Optional, Union

from loguru import logger

Expand Down Expand Up @@ -51,33 +51,44 @@ def create_project(self, project_name: str, project_path: str = "./projects") ->

# Create the main project folder
project_folder_path = Path(project_path) / project_name
project_abs_path = project_folder_path.resolve()

# Check if the project folder already exists
if project_folder_path.exists():
logger.warning(f"Project '{project_name}' already exists at {project_abs_path}.")
logger.warning(f"Project '{project_name}' already exists at {project_folder_path.resolve()}.")
else:
project_folder_path.mkdir(parents=True, exist_ok=True)
project_abs_path = project_folder_path.resolve()

# Create subfolders
for folder in SUB_FOLDERS:
(project_folder_path / folder).mkdir(parents=True, exist_ok=True)

logger.info(f"Project '{project_name}' created at {project_abs_path}.")

try:
with get_db() as db:
project = Project(
project_name=project_name,
user_id=self.user_info.user_id,
project_abs_path=project_abs_path.as_posix(),
)
project = project_repository.save(db=db, model=project)

return project

except Exception as e:
logger.error(f"Failed to save project '{project_name}' to the database: {e}")
raise

def get_projects(self) -> List[Project]:
try:
with get_db() as db:
project = Project(
project_name=project_name,
user_id=self.user_info.user_id,
project_abs_path=project_abs_path.as_posix(),
)
project = project_repository.save(db=db, model=project)
projects = project_repository.get_all_by_user_id(db=db, user_id=self.user_info.user_id)

return project
return projects

except Exception as e:
logger.error(f"Failed to save project '{project_name}' to the database: {e}")
logger.error(f"Failed to get project list from the database: {e}")
raise

def trainer(
Expand Down
2 changes: 1 addition & 1 deletion netspresso/utils/db/models/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ class Project(Base, TimestampMixin):

id = Column(Integer, primary_key=True, index=True, unique=True, autoincrement=True, nullable=False)
project_id = Column(String(36), index=True, unique=True, nullable=False, default=lambda: generate_uuid(entity="project"))
project_name = Column(String(30), nullable=False)
project_name = Column(String(30), nullable=False, unique=True)
user_id = Column(String(36), nullable=False)
project_abs_path = Column(String(500), nullable=False)

0 comments on commit fe79b3a

Please sign in to comment.