From d0281d15a54d13dcdc61a4221356e7285c207f2a Mon Sep 17 00:00:00 2001 From: Byeongman Lee Date: Sat, 7 Dec 2024 16:57:59 +0900 Subject: [PATCH] #421 Add API for checking project duplication (#434) --- app/api/v1/endpoints/project.py | 23 ++++++++++----------- app/api/v1/endpoints/system.py | 6 +----- app/exceptions/schema.py | 2 +- app/main.py | 1 + app/services/project.py | 12 +++++++++++ netspresso/utils/db/repositories/project.py | 17 +++++++++++++++ 6 files changed, 43 insertions(+), 18 deletions(-) diff --git a/app/api/v1/endpoints/project.py b/app/api/v1/endpoints/project.py index 5ccfd2b3..270a3313 100644 --- a/app/api/v1/endpoints/project.py +++ b/app/api/v1/endpoints/project.py @@ -40,10 +40,15 @@ def create_project( @router.post("/duplicate", response_model=ProjectDuplicationCheckResponse) def check_project_duplication( *, + db: Session = Depends(get_db), + api_key: str = Depends(api_key_header), request_body: ProjectCreate, ) -> ProjectDuplicationCheckResponse: + is_duplicated = project_service.check_project_duplication( + db=db, project_name=request_body.project_name, api_key=api_key + ) - duplication_status = ProjectDuplicationStatus(is_duplicated=False) + duplication_status = ProjectDuplicationStatus(is_duplicated=is_duplicated) return ProjectDuplicationCheckResponse(data=duplication_status) @@ -68,10 +73,7 @@ def get_projects( @router.get("/{project_id}", response_model=ProjectDetailResponse) -def get_project( - *, - project_id: str -) -> ProjectDetailResponse: +def get_project(*, project_id: str) -> ProjectDetailResponse: models = [ ModelSummary( @@ -79,22 +81,19 @@ def get_project( name="yolox_s_test", type="trained_model", status="in_progress", - latest_experiments=ExperimentStatus(convert="not_started", benchmark="not_started") + latest_experiments=ExperimentStatus(convert="not_started", benchmark="not_started"), ), ModelSummary( model_id="3aab6fb0-9852-4794-b668-676c06246564", name="yolox_l_test", type="compressed_model", status="completed", - latest_experiments=ExperimentStatus(convert="completed", benchmark="completed") - ) + latest_experiments=ExperimentStatus(convert="completed", benchmark="completed"), + ), ] project = ProjectDetailPayload( - project_id=str(uuid4()), - project_name="project_test_1", - user_id=str(uuid4()), - models=models + project_id=str(uuid4()), project_name="project_test_1", user_id=str(uuid4()), models=models ) return ProjectDetailResponse(data=project) diff --git a/app/api/v1/endpoints/system.py b/app/api/v1/endpoints/system.py index 78865cf0..88d94d00 100644 --- a/app/api/v1/endpoints/system.py +++ b/app/api/v1/endpoints/system.py @@ -1,10 +1,6 @@ from fastapi import APIRouter -from app.api.v1.schemas.system import ( - GpusInfoResponse, - ServerInfoPayload, - ServerInfoResponse, -) +from app.api.v1.schemas.system import GpusInfoResponse, ServerInfoPayload, ServerInfoResponse from app.services.system import system_service router = APIRouter() diff --git a/app/exceptions/schema.py b/app/exceptions/schema.py index 2d55e977..3bd846c3 100644 --- a/app/exceptions/schema.py +++ b/app/exceptions/schema.py @@ -29,7 +29,7 @@ class Config: { "data": { "origin": Origin.ROUTER, - "error_log": "AttributeError(\"module 'np_compressor_core.torch.pruning' has no attribute 'VBMF'\")" + "error_log": "AttributeError(\"module 'np_compressor_core.torch.pruning' has no attribute 'VBMF'\")", }, "error_code": "CS40020", "name": "NotFoundMethodClassException", diff --git a/app/main.py b/app/main.py index 92011135..e041f547 100644 --- a/app/main.py +++ b/app/main.py @@ -22,6 +22,7 @@ async def http_exception_handler(request: Request, exc: PyNPException): return JSONResponse(status_code=status_code, content=exc.detail) + def make_middleware() -> List[Middleware]: origins = ["*"] middleware = [ diff --git a/app/services/project.py b/app/services/project.py index 79f0014c..5ffe55a7 100644 --- a/app/services/project.py +++ b/app/services/project.py @@ -3,6 +3,7 @@ from app.services.user import user_service from netspresso.netspresso import NetsPresso from netspresso.utils.db.models.project import Project +from netspresso.utils.db.repositories.project import project_repository class ProjectService: @@ -15,5 +16,16 @@ def create_project(self, db: Session, project_name: str, api_key: str) -> Projec return project + def check_project_duplication(self, db: Session, project_name: str, api_key: str) -> bool: + user = user_service.get_user_by_api_key(db=db, api_key=api_key) + + netspresso = NetsPresso(email=user.email, password=user.password) + + is_duplicated = project_repository.is_project_name_duplicated( + db=db, project_name=project_name, user_id=netspresso.user_info.user_id + ) + + return is_duplicated + project_service = ProjectService() diff --git a/netspresso/utils/db/repositories/project.py b/netspresso/utils/db/repositories/project.py index 0bf35d06..fd237289 100644 --- a/netspresso/utils/db/repositories/project.py +++ b/netspresso/utils/db/repositories/project.py @@ -49,5 +49,22 @@ def get_all_by_user_id( order=order, ) + def is_project_name_duplicated(self, db: Session, project_name: str, user_id: str) -> bool: + """ + Check if a project with the same name already exists for the given API key. + + Args: + db (Session): Database session. + project_name (str): The name of the project to check. + user_id (str): The ID of the user to filter the user's projects. + + Returns: + bool: True if the project name exists, False otherwise. + """ + return db.query(self.model).filter( + self.model.project_name == project_name, + self.model.user_id == user_id, + ).first() is not None + project_repository = ProjectRepository(Project)