-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
- Loading branch information
1 parent
d6211be
commit ffa2b56
Showing
29 changed files
with
812 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
from fastapi import APIRouter | ||
|
||
from app.api.v1.endpoints import project, system, user | ||
from app.api.v1.endpoints import model, project, system, user | ||
|
||
api_router = APIRouter() | ||
api_router.include_router(user.router, prefix="/users", tags=["user"]) | ||
api_router.include_router(project.router, prefix="/projects", tags=["project"]) | ||
api_router.include_router(model.router, prefix="/models", tags=["model"]) | ||
api_router.include_router(system.router, prefix="/system", tags=["system"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from fastapi import APIRouter, Depends | ||
from sqlalchemy.orm import Session | ||
|
||
from app.api.deps import api_key_header | ||
from app.api.v1.schemas.model import ExperimentStatus, ExperimentStatusResponse, ModelDetailResponse, ModelsResponse | ||
from app.services.model import model_service | ||
from netspresso.utils.db.session import get_db | ||
|
||
router = APIRouter() | ||
|
||
|
||
@router.get("", response_model=ModelsResponse) | ||
def get_models( | ||
*, | ||
db: Session = Depends(get_db), | ||
api_key: str = Depends(api_key_header), | ||
) -> ModelsResponse: | ||
models = model_service.get_models(db=db, api_key=api_key) | ||
|
||
return ModelsResponse(data=models) | ||
|
||
|
||
@router.get("/{model_id}", response_model=ModelDetailResponse) | ||
def get_model( | ||
*, | ||
model_id: str, | ||
db: Session = Depends(get_db), | ||
api_key: str = Depends(api_key_header), | ||
) -> ModelDetailResponse: | ||
model = model_service.get_model(db=db, model_id=model_id, api_key=api_key) | ||
|
||
return ModelDetailResponse(data=model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from datetime import datetime | ||
from typing import List, Optional | ||
|
||
from pydantic import BaseModel, ConfigDict, Field, model_validator | ||
|
||
from app.api.v1.schemas.base import ResponseItem, ResponsePaginationItems | ||
from app.api.v1.schemas.train_task import TrainTaskSchema | ||
from netspresso.enums import Status | ||
|
||
|
||
class ModelPayload(BaseModel): | ||
model_config = ConfigDict(from_attributes=True) | ||
|
||
model_id: str = Field(..., description="The unique identifier for the model.") | ||
name: str = Field(..., description="The name of the model.") | ||
type: str = Field(..., description="The type of the model (e.g., trained_model, compressed_model).") | ||
is_retrainable: bool | ||
status: Status = Field(default=Status.NOT_STARTED, description="The current status of the model.") | ||
train_task_id: str | ||
project_id: str | ||
user_id: str | ||
compress_tasks: Optional[List] = [] | ||
convert_tasks: Optional[List] = [] | ||
benchmark_tasks: Optional[List] = [] | ||
created_at: Optional[datetime] = None | ||
updated_at: Optional[datetime] = None | ||
train_task: TrainTaskSchema = Field(exclude=True) | ||
|
||
@model_validator(mode="after") | ||
def set_status(cls, values): | ||
values.status = values.train_task.status | ||
|
||
return values | ||
|
||
|
||
class ExperimentStatus(BaseModel): | ||
convert: Status = Field(default=Status.NOT_STARTED, description="The status of the conversion experiment.") | ||
benchmark: Status = Field(default=Status.NOT_STARTED, description="The status of the benchmark experiment.") | ||
|
||
|
||
class ExperimentStatusResponse(ResponseItem): | ||
data: ExperimentStatus | ||
|
||
|
||
class ModelDetailResponse(ResponseItem): | ||
data: ModelPayload | ||
|
||
|
||
class ModelsResponse(ResponsePaginationItems): | ||
data: List[ModelPayload] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from datetime import datetime | ||
from typing import Dict, List, Optional | ||
|
||
from pydantic import BaseModel, ConfigDict | ||
from app.api.v1.schemas.base import ResponseItem | ||
|
||
|
||
class AugmentationSchema(BaseModel): | ||
model_config = ConfigDict(from_attributes=True) | ||
|
||
name: str | ||
parameters: dict | ||
phase: str | ||
hyperparameter_id: int | ||
|
||
|
||
class DatasetSchema(BaseModel): | ||
model_config = ConfigDict(from_attributes=True) | ||
|
||
name: str | ||
format: str | ||
root_path: str | ||
train_path: str | ||
valid_path: Optional[str] | ||
test_path: Optional[str] | ||
storage_location: str | ||
train_valid_split_ratio: float | ||
id_mapping: Optional[List] | ||
palette: Optional[dict] | ||
|
||
|
||
class HyperparameterSchema(BaseModel): | ||
model_config = ConfigDict(from_attributes=True) | ||
|
||
epochs: int | ||
batch_size: int | ||
learning_rate: float | ||
optimizer_name: str | ||
optimizer_params: Optional[dict] | ||
scheduler_name: str | ||
scheduler_params: Optional[dict] | ||
augmentations: List[AugmentationSchema] = [] | ||
|
||
|
||
class EnvironmentSchema(BaseModel): | ||
model_config = ConfigDict(from_attributes=True) | ||
|
||
seed: int | ||
num_workers: int | ||
gpus: str | ||
|
||
|
||
class PerformanceSchema(BaseModel): | ||
model_config = ConfigDict(from_attributes=True) | ||
|
||
train_losses: dict | ||
valid_losses: dict | ||
train_metrics: dict | ||
valid_metrics: dict | ||
metrics_list: List[str] | ||
primary_metric: str | ||
flops: int | ||
params: int | ||
total_train_time: float | ||
best_epoch: int | ||
last_epoch: int | ||
total_epoch: int | ||
status: str | ||
|
||
|
||
class TrainTaskSchema(BaseModel): | ||
model_config = ConfigDict(from_attributes=True) | ||
|
||
task_id: str | ||
pretrained_model_name: str | ||
task: str | ||
framework: str | ||
input_shapes: List[Dict] | ||
status: str | ||
error_detail: Optional[Dict] = None | ||
dataset: Optional[DatasetSchema] | ||
hyperparameter: Optional[HyperparameterSchema] | ||
environment: Optional[EnvironmentSchema] | ||
performance: Optional[PerformanceSchema] | ||
created_at: Optional[datetime] = None | ||
updated_at: Optional[datetime] = None | ||
|
||
|
||
class TrainTaskDetailResponse(ResponseItem): | ||
data: TrainTaskSchema |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from typing import List | ||
from sqlalchemy.orm import Session | ||
|
||
from app.services.user import user_service | ||
from app.api.v1.schemas.model import ModelPayload | ||
from netspresso.utils.db.repositories.model import trained_model_repository | ||
|
||
|
||
class ModelService: | ||
def get_models(self, db: Session, api_key: str) -> List[ModelPayload]: | ||
netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) | ||
|
||
models = trained_model_repository.get_all_by_user_id(db=db, user_id=netspresso.user_info.user_id) | ||
models = [ModelPayload.model_validate(model) for model in models] | ||
|
||
return models | ||
|
||
def get_model(self, db: Session, model_id: str, api_key: str) -> ModelPayload: | ||
netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) | ||
|
||
model = trained_model_repository.get_by_model_id(db=db, model_id=model_id, user_id=netspresso.user_info.user_id) | ||
model = ModelPayload.model_validate(model) | ||
|
||
return model | ||
|
||
|
||
model_service = ModelService() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
version: "3" | ||
|
||
services: | ||
maria : | ||
image: mariadb:10.6 | ||
container_name: netspresso-mariadb | ||
ports: | ||
- 3306:3306 | ||
volumes: | ||
- ./mariadb/sql-scripts:/docker-entrypoint-initdb.d | ||
- ./mariadb/data:/var/lib/mysql | ||
environment: | ||
MYSQL_ROOT_PASSWORD: netspresso1234! | ||
networks: | ||
- backend | ||
restart: always | ||
|
||
backend: | ||
image: netspresso-backend:latest | ||
container_name: netspresso-backend | ||
ports: | ||
- 80:80 | ||
env_file: | ||
- .env-dev | ||
networks: | ||
- backend | ||
stdin_open: true | ||
tty: true | ||
#command: bash | ||
|
||
networks: | ||
backend: | ||
driver: bridge |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
version: "3.9" | ||
|
||
# docker compose run --service-ports --name netspresso-dev netspresso bash | ||
|
||
services: | ||
netspresso: | ||
build: | ||
context: . | ||
dockerfile: Dockerfile | ||
image: netspresso-backend:latest | ||
container_name: netspresso | ||
ipc: host | ||
ports: | ||
- 50001:80 # backend | ||
- 50002:3306 # db | ||
deploy: | ||
resources: | ||
reservations: | ||
devices: | ||
- driver: nvidia | ||
device_ids: ["0", "1"] # your GPU id(s) | ||
capabilities: [gpu] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from enum import Enum | ||
|
||
|
||
class SubFolder(str, Enum): | ||
TRAINED_MODELS = "Trained models" | ||
COMPRESSED_MODELS = "Compressed models" | ||
PRETRAINED_MODELS = "Pretrained models" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.