diff --git a/netspresso/netspresso.py b/netspresso/netspresso.py index 3cce27e2..7d962213 100644 --- a/netspresso/netspresso.py +++ b/netspresso/netspresso.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional, Union +from typing import List, Optional, Union from loguru import logger @@ -80,6 +80,17 @@ def create_project(self, project_name: str, project_path: str = "./projects") -> logger.error(f"Failed to save project '{project_name}' to the database: {e}") raise + def get_projects(self) -> List[Project]: + try: + with get_db() as db: + projects = project_repository.get_all_by_user_id(db=db, user_id=self.user_info.user_id) + + return projects + + except Exception as e: + logger.error(f"Failed to get project list from the database: {e}") + raise + def trainer( self, task: Optional[Union[str, Task]] = None, yaml_path: Optional[str] = None ) -> Trainer: