Skip to content

Commit

Permalink
#421 Add API for checking project duplication (#434)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Dec 7, 2024
1 parent 6b229f2 commit d0281d1
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 18 deletions.
23 changes: 11 additions & 12 deletions app/api/v1/endpoints/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,15 @@ def create_project(
@router.post("/duplicate", response_model=ProjectDuplicationCheckResponse)
def check_project_duplication(
*,
db: Session = Depends(get_db),
api_key: str = Depends(api_key_header),
request_body: ProjectCreate,
) -> ProjectDuplicationCheckResponse:
is_duplicated = project_service.check_project_duplication(
db=db, project_name=request_body.project_name, api_key=api_key
)

duplication_status = ProjectDuplicationStatus(is_duplicated=False)
duplication_status = ProjectDuplicationStatus(is_duplicated=is_duplicated)

return ProjectDuplicationCheckResponse(data=duplication_status)

Expand All @@ -68,33 +73,27 @@ def get_projects(


@router.get("/{project_id}", response_model=ProjectDetailResponse)
def get_project(
*,
project_id: str
) -> ProjectDetailResponse:
def get_project(*, project_id: str) -> ProjectDetailResponse:

models = [
ModelSummary(
model_id="9beb6f14-fe8a-4d70-8243-c51f5d7f36f8",
name="yolox_s_test",
type="trained_model",
status="in_progress",
latest_experiments=ExperimentStatus(convert="not_started", benchmark="not_started")
latest_experiments=ExperimentStatus(convert="not_started", benchmark="not_started"),
),
ModelSummary(
model_id="3aab6fb0-9852-4794-b668-676c06246564",
name="yolox_l_test",
type="compressed_model",
status="completed",
latest_experiments=ExperimentStatus(convert="completed", benchmark="completed")
)
latest_experiments=ExperimentStatus(convert="completed", benchmark="completed"),
),
]

project = ProjectDetailPayload(
project_id=str(uuid4()),
project_name="project_test_1",
user_id=str(uuid4()),
models=models
project_id=str(uuid4()), project_name="project_test_1", user_id=str(uuid4()), models=models
)

return ProjectDetailResponse(data=project)
6 changes: 1 addition & 5 deletions app/api/v1/endpoints/system.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from fastapi import APIRouter

from app.api.v1.schemas.system import (
GpusInfoResponse,
ServerInfoPayload,
ServerInfoResponse,
)
from app.api.v1.schemas.system import GpusInfoResponse, ServerInfoPayload, ServerInfoResponse
from app.services.system import system_service

router = APIRouter()
Expand Down
2 changes: 1 addition & 1 deletion app/exceptions/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Config:
{
"data": {
"origin": Origin.ROUTER,
"error_log": "AttributeError(\"module 'np_compressor_core.torch.pruning' has no attribute 'VBMF'\")"
"error_log": "AttributeError(\"module 'np_compressor_core.torch.pruning' has no attribute 'VBMF'\")",
},
"error_code": "CS40020",
"name": "NotFoundMethodClassException",
Expand Down
1 change: 1 addition & 0 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ async def http_exception_handler(request: Request, exc: PyNPException):

return JSONResponse(status_code=status_code, content=exc.detail)


def make_middleware() -> List[Middleware]:
origins = ["*"]
middleware = [
Expand Down
12 changes: 12 additions & 0 deletions app/services/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from app.services.user import user_service
from netspresso.netspresso import NetsPresso
from netspresso.utils.db.models.project import Project
from netspresso.utils.db.repositories.project import project_repository


class ProjectService:
Expand All @@ -15,5 +16,16 @@ def create_project(self, db: Session, project_name: str, api_key: str) -> Projec

return project

def check_project_duplication(self, db: Session, project_name: str, api_key: str) -> bool:
user = user_service.get_user_by_api_key(db=db, api_key=api_key)

netspresso = NetsPresso(email=user.email, password=user.password)

is_duplicated = project_repository.is_project_name_duplicated(
db=db, project_name=project_name, user_id=netspresso.user_info.user_id
)

return is_duplicated


project_service = ProjectService()
17 changes: 17 additions & 0 deletions netspresso/utils/db/repositories/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,22 @@ def get_all_by_user_id(
order=order,
)

def is_project_name_duplicated(self, db: Session, project_name: str, user_id: str) -> bool:
"""
Check if a project with the same name already exists for the given API key.
Args:
db (Session): Database session.
project_name (str): The name of the project to check.
user_id (str): The ID of the user to filter the user's projects.
Returns:
bool: True if the project name exists, False otherwise.
"""
return db.query(self.model).filter(
self.model.project_name == project_name,
self.model.user_id == user_id,
).first() is not None


project_repository = ProjectRepository(Project)

0 comments on commit d0281d1

Please sign in to comment.