Skip to content

Commit

Permalink
#260 Adding error schema in metadata (#266)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Jul 4, 2024
1 parent b96e999 commit e3a4c5e
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 49 deletions.
1 change: 1 addition & 0 deletions netspresso/benchmarker/v2/benchmarker.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def benchmark_model(
except Exception as e:
logger.error(f"Benchmark failed. Error: {e}")
benchmarker_metadata.status = Status.ERROR
benchmarker_metadata.update_message(exception_detail=e.args[0])
metadatas[-1] = asdict(benchmarker_metadata)
MetadataHandler.save_json(
data=metadatas,
Expand Down
2 changes: 2 additions & 0 deletions netspresso/compressor/v2/compressor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Dict, List, Optional
from urllib import request
Expand Down Expand Up @@ -626,6 +627,7 @@ def automatic_compression(
except Exception as e:
logger.error(f"Automatic compression failed. Error: {e}")
metadata.update_status(status=Status.ERROR)
metadata.update_message(exception_detail=e.args[0])
MetadataHandler.save_json(data=metadata.asdict(), folder_path=output_dir)
raise e

Expand Down
1 change: 1 addition & 0 deletions netspresso/converter/v2/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def convert_model(
except Exception as e:
logger.error(f"Convert failed. Error: {e}")
converter_metadata.status = Status.ERROR
converter_metadata.update_message(exception_detail=e.args[0])
MetadataHandler.save_json(
data=asdict(converter_metadata), folder_path=output_dir
)
Expand Down
6 changes: 2 additions & 4 deletions netspresso/metadata/benchmarker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
DeviceName,
HardwareType,
SoftwareVersion,
Status,
TaskType,
)
from netspresso.metadata.common import BaseMetadata


@dataclass
Expand Down Expand Up @@ -44,9 +44,7 @@ class BenchmarkEnvironment:


@dataclass
class BenchmarkerMetadata:
status: Status = Status.IN_PROGRESS
message: str = ""
class BenchmarkerMetadata(BaseMetadata):
task_type: TaskType = TaskType.BENCHMARK
input_model_path: str = ""
benchmark_task_info: BenchmarkTaskInfo = field(default_factory=BenchmarkTaskInfo)
Expand Down
51 changes: 45 additions & 6 deletions netspresso/metadata/common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from dataclasses import dataclass, field
from typing import List
import json
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional

from netspresso.enums.device import DeviceName, HardwareType, SoftwareVersion
from netspresso.enums.model import (
DataType,
Framework,
)
from netspresso.enums.metadata import Status
from netspresso.enums.model import DataType, Framework


@dataclass
Expand Down Expand Up @@ -43,3 +42,43 @@ class AvailableOption:
framework: Framework = ""
display_framework: str = ""
devices: List[DeviceInfo] = field(default_factory=list)


@dataclass
class LinkInfo:
type: str
value: str


@dataclass
class AdditionalData:
origin: Optional[str] = ""
error_log: Optional[str] = ""
link: Optional[LinkInfo] = None


@dataclass
class ExceptionDetail:
data: Optional[AdditionalData] = field(default_factory=AdditionalData)
error_code: Optional[str] = ""
name: Optional[str] = ""
message: Optional[str] = ""


@dataclass
class BaseMetadata:
status: Status = Status.IN_PROGRESS
message: ExceptionDetail = field(default_factory=ExceptionDetail)

def asdict(self) -> Dict:
_dict = json.loads(json.dumps(asdict(self)))
return _dict

def update_message(self, exception_detail):
if isinstance(exception_detail, str):
self.message.message = exception_detail
else:
self.message = ExceptionDetail(**exception_detail)

def update_status(self, status: Status):
self.status = status
18 changes: 4 additions & 14 deletions netspresso/metadata/compressor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import json
from dataclasses import asdict, dataclass, field
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from netspresso.enums.metadata import Status, TaskType
from netspresso.enums.metadata import TaskType
from netspresso.enums.model import DataType, Framework
from netspresso.metadata.common import AvailableOption, InputShape
from netspresso.metadata.common import AvailableOption, BaseMetadata, InputShape
from netspresso.metadata.trainer import TrainingInfo


Expand Down Expand Up @@ -44,9 +43,7 @@ class Results:


@dataclass
class CompressorMetadata:
status: Status = Status.IN_PROGRESS
message: str = ""
class CompressorMetadata(BaseMetadata):
task_type: TaskType = TaskType.COMPRESS
input_model_path: str = ""
compressed_model_path: str = ""
Expand All @@ -58,13 +55,6 @@ class CompressorMetadata:
results: Results = field(default_factory=Results)
available_options: List[AvailableOption] = field(default_factory=list)

def asdict(self) -> Dict:
_dict = json.loads(json.dumps(asdict(self)))
return _dict

def update_status(self, status: Status):
self.status = status

def update_is_retrainable(self, is_retrainable):
self.is_retrainable = is_retrainable

Expand Down
7 changes: 2 additions & 5 deletions netspresso/metadata/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
DeviceName,
Framework,
SoftwareVersion,
Status,
TaskType,
)
from netspresso.metadata.common import AvailableOption, ModelInfo
from netspresso.metadata.common import AvailableOption, BaseMetadata, ModelInfo


@dataclass
Expand All @@ -29,9 +28,7 @@ class ConvertInfo:


@dataclass
class ConverterMetadata:
status: Status = Status.IN_PROGRESS
message: str = ""
class ConverterMetadata(BaseMetadata):
task_type: TaskType = TaskType.CONVERT
input_model_path: str = ""
converted_model_path: str = ""
Expand Down
18 changes: 4 additions & 14 deletions netspresso/metadata/trainer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import json
from dataclasses import asdict, dataclass, field
from dataclasses import dataclass, field
from typing import Dict, List

from netspresso.enums.metadata import Status, TaskType
from netspresso.metadata.common import AvailableOption, InputShape
from netspresso.enums.metadata import TaskType
from netspresso.metadata.common import AvailableOption, BaseMetadata, InputShape


@dataclass
Expand All @@ -23,9 +22,7 @@ class TrainingInfo:


@dataclass
class TrainerMetadata:
status: Status = Status.IN_PROGRESS
message: str = ""
class TrainerMetadata(BaseMetadata):
task_type: TaskType = TaskType.TRAIN
logging_dir: str = ""
best_fx_model_path: str = ""
Expand All @@ -36,13 +33,6 @@ class TrainerMetadata:
traning_result: Dict = field(default_factory=dict)
available_options: List[AvailableOption] = field(default_factory=list)

def asdict(self) -> Dict:
_dict = json.loads(json.dumps(asdict(self)))
return _dict

def update_status(self, status: Status):
self.status = status

def update_model_info(self, task, model, dataset, input_shapes):
self.model_info.task = task
self.model_info.model = model
Expand Down
19 changes: 13 additions & 6 deletions netspresso/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def _change_transforms(self, transforms: Transform):

if field_type == List:
transform.size = [self.img_size, self.img_size]
elif field_type == int:
elif isinstance(field_type, int):
transform.size = self.img_size

return transforms
Expand All @@ -389,6 +389,14 @@ def _get_available_options(self):

return available_options

def _check_status(self, training_summary):
if training_summary.get("success"):
status = Status.COMPLETED
else:
status = Status.STOPPED if training_summary.get("error_stat", None) is None else Status.ERROR

return status

def train(self, gpus: str, project_name: str, output_dir: Optional[str] = "./outputs") -> TrainerMetadata:
"""Train the model with the specified configuration.
Expand Down Expand Up @@ -444,11 +452,7 @@ def train(self, gpus: str, project_name: str, output_dir: Optional[str] = "./out
logging=configs.logging,
environment=configs.environment,
)
training_summary_path = logging_dir / "training_summary.json"
training_summary = FileHandler.load_json(file_path=training_summary_path)
is_success = training_summary["success"]
status = Status.COMPLETED if is_success else Status.STOPPED

training_summary = FileHandler.load_json(file_path=logging_dir / "training_summary.json")
FileHandler.remove_folder(configs.temp_folder)
logger.info(f"Removed {configs.temp_folder} folder.")

Expand All @@ -459,6 +463,8 @@ def train(self, gpus: str, project_name: str, output_dir: Optional[str] = "./out
best_fx_paths = list(Path(destination_folder).glob("*best_fx.pt"))
best_onnx_paths = list(Path(destination_folder).glob("*best.onnx"))
hparams_path = destination_folder / "hparams.yaml"
status = self._check_status(training_summary)
error_stat = training_summary.get("error_stat", "")

available_options = self._get_available_options()

Expand All @@ -469,6 +475,7 @@ def train(self, gpus: str, project_name: str, output_dir: Optional[str] = "./out
metadata.update_training_result(training_summary=training_summary)
metadata.update_hparams(hparams=hparams_path.resolve().as_posix())
metadata.update_status(status=status)
metadata.update_message(exception_detail=error_stat)
metadata.update_available_options(available_options)

MetadataHandler.save_json(data=metadata.asdict(), folder_path=destination_folder)
Expand Down

0 comments on commit e3a4c5e

Please sign in to comment.