Skip to content

Commit

Permalink
#438 define the schema for model information (#439)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Jan 10, 2025
1 parent d6211be commit ffa2b56
Show file tree
Hide file tree
Showing 29 changed files with 812 additions and 32 deletions.
15 changes: 12 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,18 @@ RUN apt-get update && \
RUN python -m pip install --upgrade pip
RUN python -m pip install --no-cache-dir tensorflow-gpu==${TENSORFLOW_VERSION} protobuf==${PROTOBUF_VERSION} && rm -rf /root/.cache/pip

RUN mkdir -p /home/appuser/netspresso
WORKDIR /home/appuser/netspresso
# set environment variables
ENV HOME=/app
ENV APP_PATH=$HOME/pynetspresso

COPY . /home/appuser/netspresso
# locale settings are needed for python uvicorn compatibility
ENV LC_ALL C.UTF-8
ENV LANG C.UTF-8
ENV PYTHONPATH $APP_PATH
ENV DEBIAN_FRONTEND=noninteractive
WORKDIR $APP_PATH

# copy files to docker internal
COPY . $APP_PATH/

RUN pip install -r requirements.txt && rm -rf /root/.cache/pip
3 changes: 2 additions & 1 deletion app/api/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from fastapi import APIRouter

from app.api.v1.endpoints import project, system, user
from app.api.v1.endpoints import model, project, system, user

api_router = APIRouter()
api_router.include_router(user.router, prefix="/users", tags=["user"])
api_router.include_router(project.router, prefix="/projects", tags=["project"])
api_router.include_router(model.router, prefix="/models", tags=["model"])
api_router.include_router(system.router, prefix="/system", tags=["system"])
32 changes: 32 additions & 0 deletions app/api/v1/endpoints/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from app.api.deps import api_key_header
from app.api.v1.schemas.model import ExperimentStatus, ExperimentStatusResponse, ModelDetailResponse, ModelsResponse
from app.services.model import model_service
from netspresso.utils.db.session import get_db

router = APIRouter()


@router.get("", response_model=ModelsResponse)
def get_models(
*,
db: Session = Depends(get_db),
api_key: str = Depends(api_key_header),
) -> ModelsResponse:
models = model_service.get_models(db=db, api_key=api_key)

return ModelsResponse(data=models)


@router.get("/{model_id}", response_model=ModelDetailResponse)
def get_model(
*,
model_id: str,
db: Session = Depends(get_db),
api_key: str = Depends(api_key_header),
) -> ModelDetailResponse:
model = model_service.get_model(db=db, model_id=model_id, api_key=api_key)

return ModelDetailResponse(data=model)
50 changes: 50 additions & 0 deletions app/api/v1/schemas/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from datetime import datetime
from typing import List, Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator

from app.api.v1.schemas.base import ResponseItem, ResponsePaginationItems
from app.api.v1.schemas.train_task import TrainTaskSchema
from netspresso.enums import Status


class ModelPayload(BaseModel):
model_config = ConfigDict(from_attributes=True)

model_id: str = Field(..., description="The unique identifier for the model.")
name: str = Field(..., description="The name of the model.")
type: str = Field(..., description="The type of the model (e.g., trained_model, compressed_model).")
is_retrainable: bool
status: Status = Field(default=Status.NOT_STARTED, description="The current status of the model.")
train_task_id: str
project_id: str
user_id: str
compress_tasks: Optional[List] = []
convert_tasks: Optional[List] = []
benchmark_tasks: Optional[List] = []
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
train_task: TrainTaskSchema = Field(exclude=True)

@model_validator(mode="after")
def set_status(cls, values):
values.status = values.train_task.status

return values


class ExperimentStatus(BaseModel):
convert: Status = Field(default=Status.NOT_STARTED, description="The status of the conversion experiment.")
benchmark: Status = Field(default=Status.NOT_STARTED, description="The status of the benchmark experiment.")


class ExperimentStatusResponse(ResponseItem):
data: ExperimentStatus


class ModelDetailResponse(ResponseItem):
data: ModelPayload


class ModelsResponse(ResponsePaginationItems):
data: List[ModelPayload]
90 changes: 90 additions & 0 deletions app/api/v1/schemas/train_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from datetime import datetime
from typing import Dict, List, Optional

from pydantic import BaseModel, ConfigDict
from app.api.v1.schemas.base import ResponseItem


class AugmentationSchema(BaseModel):
model_config = ConfigDict(from_attributes=True)

name: str
parameters: dict
phase: str
hyperparameter_id: int


class DatasetSchema(BaseModel):
model_config = ConfigDict(from_attributes=True)

name: str
format: str
root_path: str
train_path: str
valid_path: Optional[str]
test_path: Optional[str]
storage_location: str
train_valid_split_ratio: float
id_mapping: Optional[List]
palette: Optional[dict]


class HyperparameterSchema(BaseModel):
model_config = ConfigDict(from_attributes=True)

epochs: int
batch_size: int
learning_rate: float
optimizer_name: str
optimizer_params: Optional[dict]
scheduler_name: str
scheduler_params: Optional[dict]
augmentations: List[AugmentationSchema] = []


class EnvironmentSchema(BaseModel):
model_config = ConfigDict(from_attributes=True)

seed: int
num_workers: int
gpus: str


class PerformanceSchema(BaseModel):
model_config = ConfigDict(from_attributes=True)

train_losses: dict
valid_losses: dict
train_metrics: dict
valid_metrics: dict
metrics_list: List[str]
primary_metric: str
flops: int
params: int
total_train_time: float
best_epoch: int
last_epoch: int
total_epoch: int
status: str


class TrainTaskSchema(BaseModel):
model_config = ConfigDict(from_attributes=True)

task_id: str
pretrained_model_name: str
task: str
framework: str
input_shapes: List[Dict]
status: str
error_detail: Optional[Dict] = None
dataset: Optional[DatasetSchema]
hyperparameter: Optional[HyperparameterSchema]
environment: Optional[EnvironmentSchema]
performance: Optional[PerformanceSchema]
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None


class TrainTaskDetailResponse(ResponseItem):
data: TrainTaskSchema
27 changes: 27 additions & 0 deletions app/services/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import List
from sqlalchemy.orm import Session

from app.services.user import user_service
from app.api.v1.schemas.model import ModelPayload
from netspresso.utils.db.repositories.model import trained_model_repository


class ModelService:
def get_models(self, db: Session, api_key: str) -> List[ModelPayload]:
netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key)

models = trained_model_repository.get_all_by_user_id(db=db, user_id=netspresso.user_info.user_id)
models = [ModelPayload.model_validate(model) for model in models]

return models

def get_model(self, db: Session, model_id: str, api_key: str) -> ModelPayload:
netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key)

model = trained_model_repository.get_by_model_id(db=db, model_id=model_id, user_id=netspresso.user_info.user_id)
model = ModelPayload.model_validate(model)

return model


model_service = ModelService()
5 changes: 4 additions & 1 deletion app/services/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@


class UserService:
def create_user(self, db: Session, email: str, password: str, api_key: str) -> User:
def create_user(self, db: Session, email: str, password: str, api_key: str, user_id: str) -> User:
user = User(
email=email,
password=password,
api_key=api_key,
user_id=user_id,
)
user = user_repository.save(db=db, model=user)

Expand All @@ -29,11 +30,13 @@ def generate_api_key(self, db: Session, email: str, password: str) -> ApiKeyPayl
user.api_key = generated_id
user = user_repository.save(db=db, model=user)
else:
netspresso = NetsPresso(email=email, password=password)
user = self.create_user(
db=db,
email=email,
password=password,
api_key=generated_id,
user_id=netspresso.user_info.user_id,
)

api_key = ApiKeyPayload(api_key=user.api_key)
Expand Down
33 changes: 33 additions & 0 deletions docker-compose-backend.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
version: "3"

services:
maria :
image: mariadb:10.6
container_name: netspresso-mariadb
ports:
- 3306:3306
volumes:
- ./mariadb/sql-scripts:/docker-entrypoint-initdb.d
- ./mariadb/data:/var/lib/mysql
environment:
MYSQL_ROOT_PASSWORD: netspresso1234!
networks:
- backend
restart: always

backend:
image: netspresso-backend:latest
container_name: netspresso-backend
ports:
- 80:80
env_file:
- .env-dev
networks:
- backend
stdin_open: true
tty: true
#command: bash

networks:
backend:
driver: bridge
22 changes: 22 additions & 0 deletions docker-compose-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
version: "3.9"

# docker compose run --service-ports --name netspresso-dev netspresso bash

services:
netspresso:
build:
context: .
dockerfile: Dockerfile
image: netspresso-backend:latest
container_name: netspresso
ipc: host
ports:
- 50001:80 # backend
- 50002:3306 # db
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ["0", "1"] # your GPU id(s)
capabilities: [gpu]
8 changes: 8 additions & 0 deletions netspresso/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from netspresso.enums import ServiceCredit, ServiceTask, Status
from netspresso.exceptions.common import NotEnoughCreditException
from netspresso.metadata.common import BaseMetadata
from netspresso.utils.db.repositories.project import project_repository
from netspresso.utils.db.session import get_db_session


class NetsPressoBase:
Expand Down Expand Up @@ -49,3 +51,9 @@ def handle_stop(self, metadata: BaseMetadata, task_name: ServiceTask):
logger.error(f"{task_name} task was interrupted by the user.")

return metadata

def get_project(self, project_id):
with get_db_session() as db:
project = project_repository.get_by_project_id(db=db, project_id=project_id)

return project
7 changes: 7 additions & 0 deletions netspresso/enums/project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from enum import Enum


class SubFolder(str, Enum):
TRAINED_MODELS = "Trained models"
COMPRESSED_MODELS = "Compressed models"
PRETRAINED_MODELS = "Pretrained models"
10 changes: 10 additions & 0 deletions netspresso/enums/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,13 @@ def to_display_name(cls, name: str) -> str:
"cosine": cls.COSINE_ANNEALING_WARM_RESTARTS_WITH_CUSTOM_WARM_UP,
}
return name_map[name.lower()].value


class StorageLocation(str, Enum):
LOCAL = "local"
STORAGE = "storage"


class AugmentationType(str, Enum):
train = "train"
inference = "inference"
8 changes: 7 additions & 1 deletion netspresso/trainer/augmentations/augmentation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Union

Expand All @@ -11,6 +11,12 @@
class Transform:
name: str = MISSING

def to_parameters(self) -> Dict:
"""
Extract all fields except 'name' as parameters.
"""
return {k: v for k, v in asdict(self).items() if k != 'name'}


@dataclass
class AugmentationConfig:
Expand Down
6 changes: 6 additions & 0 deletions netspresso/trainer/optimizers/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ class BaseOptimizer:
def asdict(self) -> Dict:
return asdict(self)

def to_parameters(self) -> Dict:
"""
Extract all fields except 'name' as parameters.
"""
return {k: v for k, v in asdict(self).items() if k != 'name'}


@dataclass
class Adadelta(BaseOptimizer):
Expand Down
Loading

0 comments on commit ffa2b56

Please sign in to comment.