Skip to content

Commit

Permalink
Add retrieving project list to project service
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle committed Dec 7, 2024
1 parent 0f20365 commit d44d6b3
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 19 deletions.
22 changes: 10 additions & 12 deletions app/api/v1/endpoints/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy.orm import Session

from app.api.deps import api_key_header
from app.api.v1.schemas.base import Order
from app.api.v1.schemas.project import (
ExperimentStatus,
ModelSummary,
Expand Down Expand Up @@ -56,20 +57,17 @@ def check_project_duplication(
@router.get("", response_model=ProjectsResponse)
def get_projects(
*,
user_id: str,
skip: Optional[int] = 0,
limit: Optional[int] = 100,
db: Session = Depends(get_db),
api_key: str = Depends(api_key_header),
start: Optional[int] = 0,
size: Optional[int] = 10,
order: Order = Order.DESC.value,
) -> ProjectsResponse:
projects = project_service.get_projects(db=db, start=start, size=size, order=order, api_key=api_key)
projects = [ProjectSummaryPayload.model_validate(project) for project in projects]
total_count = project_service.count_project_by_user_id(db=db, api_key=api_key)

projects = [
ProjectSummaryPayload(
project_id=str(uuid4()),
project_name="project_test_1",
user_id=str(uuid4()),
)
]

return ProjectsResponse(data=projects)
return ProjectsResponse(data=projects, result_count=len(projects), total_count=total_count)


@router.get("/{project_id}", response_model=ProjectDetailResponse)
Expand Down
28 changes: 21 additions & 7 deletions app/services/project.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,45 @@
from typing import List, Optional

from sqlalchemy.orm import Session

from app.services.user import user_service
from netspresso.netspresso import NetsPresso
from netspresso.utils.db.models.project import Project
from netspresso.utils.db.repositories.base import Order
from netspresso.utils.db.repositories.project import project_repository


class ProjectService:
def create_project(self, db: Session, project_name: str, api_key: str) -> Project:
user = user_service.get_user_by_api_key(db=db, api_key=api_key)

netspresso = NetsPresso(email=user.email, password=user.password)
netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key)

project = netspresso.create_project(project_name=project_name)

return project

def check_project_duplication(self, db: Session, project_name: str, api_key: str) -> bool:
user = user_service.get_user_by_api_key(db=db, api_key=api_key)

netspresso = NetsPresso(email=user.email, password=user.password)
netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key)

is_duplicated = project_repository.is_project_name_duplicated(
db=db, project_name=project_name, user_id=netspresso.user_info.user_id
)

return is_duplicated

def get_projects(
self, *, db: Session, start: Optional[int], size: Optional[int], order: Order, api_key: str
) -> List[Project]:
netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key)

projects = project_repository.get_all_by_user_id(
db=db, user_id=netspresso.user_info.user_id, start=start, size=size, order=order
)

return projects

def count_project_by_user_id(self, *, db: Session, api_key: str) -> int:
netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key)

return project_repository.count_by_user_id(db=db, user_id=netspresso.user_info.user_id)


project_service = ProjectService()
8 changes: 8 additions & 0 deletions netspresso/utils/db/repositories/project.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Optional

from sqlalchemy import func
from sqlalchemy.orm import Session

from netspresso.utils.db.models.project import Project
Expand Down Expand Up @@ -66,5 +67,12 @@ def is_project_name_duplicated(self, db: Session, project_name: str, user_id: st
self.model.user_id == user_id,
).first() is not None

def count_by_user_id(self, db: Session, user_id: str) -> int:
return (
db.query(func.count(self.model.user_id))
.filter(self.model.user_id == user_id)
.scalar()
)


project_repository = ProjectRepository(Project)

0 comments on commit d44d6b3

Please sign in to comment.