Skip to content

Commit

Permalink
#417 Add a function to create a project (#431)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Dec 6, 2024
1 parent a668ef4 commit fda9585
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 16 deletions.
Empty file added netspresso/constant/__init__.py
Empty file.
1 change: 1 addition & 0 deletions netspresso/constant/project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SUB_FOLDERS = ["Trained models", "Compressed models", "Pretrained models"]
38 changes: 25 additions & 13 deletions netspresso/netspresso.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
from netspresso.clients.auth.response_body import UserResponse
from netspresso.clients.tao import TAOTokenHandler
from netspresso.compressor import CompressorV2
from netspresso.constant.project import SUB_FOLDERS
from netspresso.converter import ConverterV2
from netspresso.enums import Task
from netspresso.inferencer.inferencer import CustomInferencer, NPInferencer
from netspresso.quantizer import Quantizer
from netspresso.tao import TAOTrainer
from netspresso.trainer import Trainer
from netspresso.utils.file import FileHandler
from netspresso.utils.db.models.project import Project
from netspresso.utils.db.repositories.project import project_repository
from netspresso.utils.db.session import get_db


class NetsPresso:
Expand Down Expand Up @@ -42,31 +45,40 @@ def get_user(self) -> UserResponse:
)
return user_info

def create_project(self, project_name: str, project_path: str = "./"):
def create_project(self, project_name: str, project_path: str = "./projects") -> Project:
if len(project_name) > 30:
raise ValueError("The project_name can't exceed 30 characters.")

# Create the main project folder
project_folder_path = Path(project_path) / project_name
project_abs_path = project_folder_path.resolve()

# Check if the project folder already exists
if project_folder_path.exists():
logger.info(f"Project '{project_name}' already exists at {project_folder_path.resolve()}.")
logger.warning(f"Project '{project_name}' already exists at {project_abs_path}.")
else:
project_folder_path.mkdir(parents=True, exist_ok=True)

# Subfolder names
subfolders = ["Trainer models", "Compressed models", "Pretrained models"]

# Create subfolders
for folder in subfolders:
for folder in SUB_FOLDERS:
(project_folder_path / folder).mkdir(parents=True, exist_ok=True)

# Create a metadata.json file
metadata_file_path = project_folder_path / "metadata.json"
metadata = {"is_project_folder": True}
logger.info(f"Project '{project_name}' created at {project_abs_path}.")

try:
with get_db() as db:
project = Project(
project_name=project_name,
user_id=self.user_info.user_id,
project_abs_path=project_abs_path.as_posix(),
)
project = project_repository.save(db=db, model=project)

# Write metadata to the json file
FileHandler.save_json(data=metadata, file_path=metadata_file_path)
return project

logger.info(f"Project '{project_name}' created at {project_folder_path.resolve()}.")
except Exception as e:
logger.error(f"Failed to save project '{project_name}' to the database: {e}")
raise

def trainer(
self, task: Optional[Union[str, Task]] = None, yaml_path: Optional[str] = None
Expand Down
2 changes: 2 additions & 0 deletions netspresso/utils/db/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from netspresso.utils.db.models.project import Project
from netspresso.utils.db.models.user import User
from netspresso.utils.db.session import Base, engine

Base.metadata.create_all(engine)


__all__ = [
"Project",
"User",
]
15 changes: 15 additions & 0 deletions netspresso/utils/db/models/project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from sqlalchemy import Column, Integer, String

from netspresso.utils.db.generate_uuid import generate_uuid
from netspresso.utils.db.mixins import TimestampMixin
from netspresso.utils.db.session import Base


class Project(Base, TimestampMixin):
__tablename__ = "project"

id = Column(Integer, primary_key=True, index=True, unique=True, autoincrement=True, nullable=False)
project_id = Column(String(36), index=True, unique=True, nullable=False, default=lambda: generate_uuid(entity="project"))
project_name = Column(String(30), nullable=False)
user_id = Column(String(36), nullable=False)
project_abs_path = Column(String(500), nullable=False)
53 changes: 53 additions & 0 deletions netspresso/utils/db/repositories/project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import List, Optional

from sqlalchemy.orm import Session

from netspresso.utils.db.models.project import Project
from netspresso.utils.db.repositories.base import BaseRepository, Order


class ProjectRepository(BaseRepository[Project]):
def get_by_project_id(self, db: Session, project_id: str) -> Optional[Project]:
project = db.query(self.model).filter(self.model.project_id == project_id)

return project

def _get_projects(
self,
db: Session,
condition,
start: Optional[int] = None,
size: Optional[int] = None,
order: Optional[Order] = None,
) -> Optional[List[Project]]:
ordering_func = self.choose_order_func(order)
query = db.query(self.model).filter(condition)

if order:
query = query.order_by(ordering_func(self.model.created_at))

if start is not None and size is not None:
query = query.offset(start).limit(size)

projects = query.all()

return projects

def get_all_by_user_id(
self,
db: Session,
user_id: str,
start: Optional[int] = None,
size: Optional[int] = None,
order: Optional[Order] = None,
) -> Optional[List[Project]]:
return self._get_projects(
db=db,
condition=self.model.user_id == user_id,
start=start,
size=size,
order=order,
)


project_repository = ProjectRepository(Project)
17 changes: 14 additions & 3 deletions netspresso/utils/db/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from contextlib import contextmanager
from typing import Generator

from loguru import logger
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy_utils import create_database, database_exists
Expand All @@ -21,13 +23,22 @@
Base = declarative_base()


@contextmanager
def get_db() -> Generator:
db = None
try:
db = SessionLocal()
yield db
finally:
db.close()
if db:
db.close()


if not database_exists(engine.url):
create_database(engine.url)
def check_database(engine):
if not database_exists(engine.url):
logger.info("The database did not exist, so it has been created.")
create_database(engine.url)
else:
logger.info("The database has already been created.")

check_database(engine=engine)

0 comments on commit fda9585

Please sign in to comment.