diff --git a/netspresso/trainer/augmentations/__init__.py b/netspresso/trainer/augmentations/__init__.py
index add266be..9fa351cb 100644
--- a/netspresso/trainer/augmentations/__init__.py
+++ b/netspresso/trainer/augmentations/__init__.py
@@ -1,5 +1,9 @@
-from netspresso_trainer.cfg.augmentation import (
+from netspresso.trainer.augmentations.augmentation import (
+    AugmentationConfig,
+    ClassificationAugmentationConfig,
     ColorJitter,
+    DetectionAugmentationConfig,
+    Inference,
     Pad,
     RandomCrop,
     RandomCutmix,
@@ -8,9 +12,19 @@
     RandomResizedCrop,
     RandomVerticalFlip,
     Resize,
+    SegmentationAugmentationConfig,
+    Train,
+    Transform,
     TrivialAugmentWide,
 )
 
+AUGMENTATION_CONFIG_TYPE = {
+    "classification": ClassificationAugmentationConfig,
+    "detection": DetectionAugmentationConfig,
+    "segmentation": SegmentationAugmentationConfig,
+}
+
+
 __all__ = [
     "ColorJitter",
     "Pad",
@@ -22,4 +36,9 @@
     "TrivialAugmentWide",
     "RandomMixup",
     "RandomCutmix",
+    "Inference",
+    "Train",
+    "Transform",
+    "AugmentationConfig",
+    "AUGMENTATION_CONFIG_TYPE",
 ]
diff --git a/netspresso/trainer/augmentations/augmentation.py b/netspresso/trainer/augmentations/augmentation.py
new file mode 100644
index 00000000..85a1ae21
--- /dev/null
+++ b/netspresso/trainer/augmentations/augmentation.py
@@ -0,0 +1,142 @@
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+from omegaconf import MISSING
+
+DEFAULT_IMG_SIZE = 256
+
+
+@dataclass
+class Transform:
+    name: str = MISSING
+
+
+@dataclass
+class Train:
+    transforms: Optional[List] = None
+    mix_transforms: Optional[List] = None
+
+
+@dataclass
+class Inference:
+    transforms: Optional[List] = None
+
+
+@dataclass
+class AugmentationConfig:
+    img_size: int = DEFAULT_IMG_SIZE
+    train: Train = field(default_factory=lambda: Train())
+    inference: Inference = field(default_factory=lambda: Inference())
+
+
+@dataclass
+class ColorJitter(Transform):
+    name: str = 'colorjitter'
+    brightness: Optional[float] = 0.25
+    contrast: Optional[float] = 0.25
+    saturation: Optional[float] = 0.25
+    hue: Optional[float] = 0.1
+    p: Optional[float] = 0.5
+
+
+@dataclass
+class Pad(Transform):
+    name: str = 'pad'
+    padding: int = 0
+    fill: int = 0
+    padding_mode: str = 'constant'
+
+
+@dataclass
+class RandomCrop(Transform):
+    name: str = 'randomcrop'
+    size: int = DEFAULT_IMG_SIZE
+
+
+@dataclass
+class RandomResizedCrop(Transform):
+    name: str = 'randomresizedcrop'
+    size: int = DEFAULT_IMG_SIZE
+    scale: List = field(default_factory=lambda: [0.08, 1.0])
+    ratio: List = field(default_factory=lambda: [0.75, 1.33])
+    interpolation: Optional[str] = 'bilinear'
+
+
+@dataclass
+class RandomHorizontalFlip(Transform):
+    name: str = 'randomhorizontalflip'
+    p: float = 0.5
+
+
+@dataclass
+class RandomVerticalFlip(Transform):
+    name: str = 'randomverticalflip'
+    p: float = 0.5
+
+
+@dataclass
+class Resize(Transform):
+    name: str = 'resize'
+    size: List = field(default_factory=lambda: [DEFAULT_IMG_SIZE, DEFAULT_IMG_SIZE])
+    interpolation: Optional[str] = 'bilinear'
+    max_size: Optional[int] = None
+
+
+@dataclass
+class TrivialAugmentWide(Transform):
+    name: str = 'trivialaugmentwide'
+    num_magnitude_bins: int = 31
+    interpolation: str = 'bilinear'
+    fill: Optional[int] = None
+
+
+@dataclass
+class RandomMixup(Transform):
+    name: str = 'mixup'
+    alpha: float = 0.2
+    p: float = 1.0
+    inplace: bool = False
+
+
+@dataclass
+class RandomCutmix(Transform):
+    name: str = 'cutmix'
+    alpha: float = 1.0
+    p: float = 1.0
+    inplace: bool = False
+
+
+@dataclass
+class ClassificationAugmentationConfig(AugmentationConfig):
+    img_size: int = 256
+    train: Train = field(default_factory=lambda: Train(
+        transforms=[RandomResizedCrop(size=256), RandomHorizontalFlip()],
+        mix_transforms=[RandomCutmix()]
+    ))
+    inference: Inference = field(default_factory=lambda: Inference(
+        transforms=[Resize(size=[256, 256])]
+    ))
+
+
+@dataclass
+class SegmentationAugmentationConfig(AugmentationConfig):
+    img_size: int = 512
+    train: Train = field(default_factory=lambda: Train(
+        transforms=[RandomResizedCrop(size=512), RandomHorizontalFlip(), ColorJitter()],
+        mix_transforms=None
+    ))
+    inference: Inference = field(default_factory=lambda: Inference(
+        transforms=[Resize(size=[512, 512])]
+    ))
+
+
+@dataclass
+class DetectionAugmentationConfig(AugmentationConfig):
+    img_size: int = 512
+    train: Train = field(default_factory=lambda: Train(
+        transforms=[Resize(size=[512, 512])],
+        mix_transforms=None
+    ))
+    inference: Inference = field(default_factory=lambda: Inference(
+        transforms=[Resize(size=[512, 512])],
+    ))
diff --git a/netspresso/trainer/registries/data.py b/netspresso/trainer/data/__init__.py
similarity index 60%
rename from netspresso/trainer/registries/data.py
rename to netspresso/trainer/data/__init__.py
index 339692ba..cafc9dfe 100644
--- a/netspresso/trainer/registries/data.py
+++ b/netspresso/trainer/data/__init__.py
@@ -1,7 +1,10 @@
-from netspresso_trainer.cfg import (
+from netspresso.trainer.data.data import (
+    DatasetConfig,
+    ImageLabelPathConfig,
     LocalClassificationDatasetConfig,
     LocalDetectionDatasetConfig,
     LocalSegmentationDatasetConfig,
+    PathConfig,
 )
 
 DATA_CONFIG_TYPE = {
@@ -9,3 +12,5 @@
     "detection": LocalDetectionDatasetConfig,
     "segmentation": LocalSegmentationDatasetConfig,
 }
+
+__all__ = ["ImageLabelPathConfig", "PathConfig", "DATA_CONFIG_TYPE", "DatasetConfig"]
diff --git a/netspresso/trainer/data/data.py b/netspresso/trainer/data/data.py
new file mode 100644
index 00000000..4a96f6fb
--- /dev/null
+++ b/netspresso/trainer/data/data.py
@@ -0,0 +1,264 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+from omegaconf import MISSING
+
+__all__ = [
+    "DatasetConfig",
+    "LocalClassificationDatasetConfig",
+    "LocalSegmentationDatasetConfig",
+    "LocalDetectionDatasetConfig",
+    "HuggingFaceClassificationDatasetConfig",
+    "HuggingFaceSegmentationDatasetConfig",
+    "ExampleBeansDataset",
+    "ExampleChessDataset",
+    "ExampleCocoyoloDataset",
+    "ExampleSidewalkDataset",
+    "ExampleXrayDataset",
+    "ExampleSkincancerDataset",
+    "ExampleTrafficsignDataset",
+    "ExampleVoc12Dataset",
+    "ExampleVoc12CustomDataset",
+    "ExampleWikiartDataset",
+]
+
+
+@dataclass
+class DatasetConfig:
+    name: str = MISSING
+    task: str = MISSING
+    format: str = MISSING  # Literal['huggingface', 'local']
+
+
+@dataclass
+class ImageLabelPathConfig:
+    image: Optional[Union[Path, str]] = None
+    label: Optional[Union[Path, str]] = None
+
+
+@dataclass
+class PathPatternConfig:
+    image: Optional[str] = None
+    label: Optional[str] = None
+
+
+@dataclass
+class PathConfig:
+    root: Union[Path, str] = MISSING
+    train: ImageLabelPathConfig = field(default_factory=lambda: ImageLabelPathConfig(image=MISSING))
+    valid: ImageLabelPathConfig = field(default_factory=lambda: ImageLabelPathConfig())
+    test: ImageLabelPathConfig = field(default_factory=lambda: ImageLabelPathConfig())
+    pattern: PathPatternConfig = field(default_factory=lambda: PathPatternConfig())
+
+
+@dataclass
+class HuggingFaceConfig:
+    custom_cache_dir: Optional[Union[Path, str]] = None  # If None, it follows the HF datasets default (~/.cache/huggingface/datasets)
+    repo: str = MISSING
+    subset: Optional[str] = None
+    features: Dict[str, str] = field(default_factory=lambda: {
+        "image": "image", "label": "labels"
+    })
+
+
+@dataclass
+class LocalClassificationDatasetConfig(DatasetConfig):
+    task: str = "classification"
+    format: str = "local"
+    path: PathConfig = field(default_factory=lambda: PathConfig())
+    id_mapping: Optional[List[str]] = None
+
+
+@dataclass
+class LocalSegmentationDatasetConfig(DatasetConfig):
+    task: str = "segmentation"
+    format: str = "local"
+    path: PathConfig = field(default_factory=lambda: PathConfig())
+    label_image_mode: str = "L"
+    id_mapping: Any = None
+    pallete: Optional[List[List[int]]] = None
+
+
+@dataclass
+class LocalDetectionDatasetConfig(DatasetConfig):
+    task: str = "detection"
+    format: str = "local"
+    path: PathConfig = field(default_factory=lambda: PathConfig())
+    id_mapping: Any = None
+    pallete: Optional[List[List[int]]] = None
+
+
+@dataclass
+class HuggingFaceClassificationDatasetConfig(DatasetConfig):
+    task: str = "classification"
+    format: str = "huggingface"
+    metadata: HuggingFaceConfig = field(default_factory=lambda: HuggingFaceConfig(
+        features={"image": "image", "label": "labels"}
+    ))
+    id_mapping: Optional[List[str]] = None
+
+
+@dataclass
+class HuggingFaceSegmentationDatasetConfig(DatasetConfig):
+    task: str = "segmentation"
+    format: str = "huggingface"
+    metadata: HuggingFaceConfig = field(default_factory=lambda: HuggingFaceConfig(
+        features={"image": "pixel_values", "label": "label"}
+    ))
+    label_image_mode: str = "L"
+    id_mapping: Any = None
+    pallete: Optional[List[List[int]]] = None
+
+
+ExampleBeansDataset = HuggingFaceClassificationDatasetConfig(
+    name="beans",
+    metadata=HuggingFaceConfig(
+        custom_cache_dir=None,
+        repo="beans",
+        features={"image": "image", "label": "labels"}
+    )
+)
+
+ExampleChessDataset = LocalClassificationDatasetConfig(
+    name="chess",
+    path=PathConfig(
+        root="/DATA/classification-example",
+        train=ImageLabelPathConfig(image="train"),
+        valid=ImageLabelPathConfig(image="val"),
+    )
+)
+
+ExampleXrayDataset = HuggingFaceClassificationDatasetConfig(
+    name="chest_xray_classification",
+    metadata=HuggingFaceConfig(
+        custom_cache_dir=None,
+        repo="keremberke/chest-xray-classification",
+        subset="full",
+        features={"image": "image", "label": "labels"}
+    )
+)
+
+ExampleCocoyoloDataset = LocalDetectionDatasetConfig(
+    name="coco_for_yolo_model",
+    path=PathConfig(
+        root="/DATA/coco",
+        train=ImageLabelPathConfig(image="images/train2017", label="labels/train2017"),
+        valid=ImageLabelPathConfig(image="images/train2017", label="labels/train2017"),
+        pattern=PathPatternConfig(image="([0-9]{12})\\.jpg", label="([0-9]{12})\\.txt"),
+    ),
+    id_mapping=[
+        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
+        'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+        'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
+        'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
+        'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+        'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
+        'hair drier', 'toothbrush'
+    ]
+)
+
+ExampleSidewalkDataset = HuggingFaceSegmentationDatasetConfig(
+    name="sidewalk_semantic",
+    metadata=HuggingFaceConfig(
+        custom_cache_dir=None,
+        repo="segments/sidewalk-semantic",
+        features={"image": "pixel_values", "label": "label"}
+    ),
+    label_image_mode="L",
+    id_mapping=[
+        'unlabeled', 'flat-road', 'flat-sidewalk', 'flat-crosswalk', 'flat-cyclinglane', 'flat-parkingdriveway',
+        'flat-railtrack', 'flat-curb', 'human-person', 'human-rider', 'vehicle-car', 'vehicle-truck', 'vehicle-bus',
+        'vehicle-tramtrain', 'vehicle-motorcycle', 'vehicle-bicycle', 'vehicle-caravan', 'vehicle-cartrailer',
+        'construction-building', 'construction-door', 'construction-wall', 'construction-fenceguardrail',
+        'construction-bridge', 'construction-tunnel', 'construction-stairs', 'object-pole', 'object-trafficsign',
+        'object-trafficlight', 'nature-vegetation', 'nature-terrain', 'sky', 'void-ground', 'void-dynamic',
+        'void-static', 'void-unclear'
+    ]
+)
+
+ExampleSkincancerDataset = HuggingFaceClassificationDatasetConfig(
+    name="skin_cancer",
+    metadata=HuggingFaceConfig(
+        custom_cache_dir=None,
+        repo="marmal88/skin_cancer",
+        features={"image": "image", "label": "dx"}
+    )
+)
+
+ExampleTrafficsignDataset = LocalDetectionDatasetConfig(
+    name="traffic_sign_yolo",
+    path=PathConfig(
+        root="../../data/traffic-sign",
+        train=ImageLabelPathConfig(image="images/train", label="labels/train"),
+        valid=ImageLabelPathConfig(image="images/val", label="labels/val"),
+    ),
+    id_mapping=['prohibitory', 'danger', 'mandatory', 'other']  # class names
+)
+
+ExampleVoc12Dataset = LocalSegmentationDatasetConfig(
+    name="voc2012",
+    path=PathConfig(
+        root="/DATA/VOC12Dataset",
+        train=ImageLabelPathConfig(image="image/train", label="mask/train"),
+        valid=ImageLabelPathConfig(image="image/val", label="mask/val"),
+    ),
+    label_image_mode="RGB",
+    id_mapping={
+        "(0, 0, 0)": "background",
+        "(128, 0, 0)": "aeroplane",
+        "(0, 128, 0)": "bicycle",
+        "(128, 128, 0)": "bird",
+        "(0, 0, 128)": "boat",
+        "(128, 0, 128)": "bottle",
+        "(0, 128, 128)": "bus",
+        "(128, 128, 128)": "car",
+        "(64, 0, 0)": "cat",
+        "(192, 0, 0)": "chair",
+        "(64, 128, 0)": "cow",
+        "(192, 128, 0)": "diningtable",
+        "(64, 0, 128)": "dog",
+        "(192, 0, 128)": "horse",
+        "(64, 128, 128)": "motorbike",
+        "(192, 128, 128)": "person",
+        "(0, 64, 0)": "pottedplant",
+        "(128, 64, 0)": "sheep",
+        "(0, 192, 0)": "sofa",
+        "(128, 192, 0)": "train",
+        "(0, 64, 128)": "tvmonitor",
+        "(128, 64, 128)": "void"
+    }
+)
+
+ExampleVoc12CustomDataset = LocalSegmentationDatasetConfig(
+    name="voc2012",
+    path=PathConfig(
+        root="../../data/VOC12Dataset",
+        train=ImageLabelPathConfig(image="image/train", label="mask/train"),
+        valid=ImageLabelPathConfig(image="image/val", label="mask/val"),
+    ),
+    label_image_mode="L",
+    id_mapping=[
+        'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
+        'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
+        'train', 'tvmonitor'
+    ],
+    pallete=[
+        [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 128, 128],
+        [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
+        [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0],
+        [128, 192, 0], [0, 64, 128]
+    ]
+)
+
+ExampleWikiartDataset = HuggingFaceClassificationDatasetConfig(
+    name="wikiart_artist",
+    metadata=HuggingFaceConfig(
+        custom_cache_dir=None,
+        repo="huggan/wikiart",
+        subset="full",
+        features={"image": "image", "label": "artist"}
+    )
+)
diff --git a/netspresso/trainer/registries/model.py b/netspresso/trainer/models/__init__.py
similarity index 90%
rename from netspresso/trainer/registries/model.py
rename to netspresso/trainer/models/__init__.py
index 2f3101a2..581fc87d 100644
--- a/netspresso/trainer/registries/model.py
+++ b/netspresso/trainer/models/__init__.py
@@ -1,4 +1,5 @@
-from netspresso_trainer.cfg.model import (
+from netspresso.trainer.models.model import (
+    CheckpointConfig,
     ClassificationEfficientFormerModelConfig,
     ClassificationMixNetLargeModelConfig,
     ClassificationMixNetMediumModelConfig,
@@ -14,6 +15,7 @@
     DetectionMobileNetV3ModelConfig,
     DetectionResNetModelConfig,
     DetectionYoloXModelConfig,
+    ModelConfig,
     PIDNetModelConfig,
     SegmentationEfficientFormerModelConfig,
     SegmentationMixNetLargeModelConfig,
@@ -55,3 +57,12 @@
     "MixNetL": SegmentationMixNetLargeModelConfig,
     "PIDNet": PIDNetModelConfig,
 }
+
+
+__all__ = [
+    "CLASSIFICATION_MODELS",
+    "DETECTION_MODELS",
+    "SEGMENTATION_MODELS",
+    "CheckpointConfig",
+    "ModelConfig",
+]
diff --git a/netspresso/trainer/models/model.py b/netspresso/trainer/models/model.py
new file mode 100644
index 00000000..f30d7982
--- /dev/null
+++ b/netspresso/trainer/models/model.py
@@ -0,0 +1,1029 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+from omegaconf import MISSING
+
+__all__ = [
+    "ModelConfig",
+    "ClassificationEfficientFormerModelConfig",
+    "SegmentationEfficientFormerModelConfig",
+    "DetectionEfficientFormerModelConfig",
+    "ClassificationMobileNetV3ModelConfig",
+    "SegmentationMobileNetV3ModelConfig",
+    "DetectionMobileNetV3ModelConfig",
+    "ClassificationMobileViTModelConfig",
+    "PIDNetModelConfig",
+    "ClassificationResNetModelConfig",
+    "SegmentationResNetModelConfig",
+    "DetectionResNetModelConfig",
+    "SegmentationSegFormerModelConfig",
+    "ClassificationViTModelConfig",
+    "DetectionYoloXModelConfig",
+    "ClassificationMixNetSmallModelConfig",
+    "ClassificationMixNetMediumModelConfig",
+    "ClassificationMixNetLargeModelConfig",
+    "SegmentationMixNetSmallModelConfig",
+    "SegmentationMixNetMediumModelConfig",
+    "SegmentationMixNetLargeModelConfig",
+    "DetectionMixNetSmallModelConfig",
+    "DetectionMixNetMediumModelConfig",
+    "DetectionMixNetLargeModelConfig",
+]
+
+
+@dataclass
+class ArchitectureConfig:
+    full: Optional[Dict[str, Any]] = None
+    backbone: Optional[Dict[str, Any]] = None
+    neck: Optional[Dict[str, Any]] = None
+    head: Optional[Dict[str, Any]] = None
+
+    def __post_init__(self):
+        assert bool(self.full) != bool(self.backbone), "Only one of full or backbone should be given."
+
+
+@dataclass
+class CheckpointConfig:
+    use_pretrained: bool = True
+    load_head: bool = False
+    path: Optional[Union[Path, str]] = None
+    fx_model_path: Optional[Union[Path, str]] = None
+    optimizer_path: Optional[Union[Path, str]] = None
+
+
+@dataclass
+class ModelConfig:
+    task: str = MISSING
+    name: str = MISSING
+    checkpoint: CheckpointConfig = field(default_factory=lambda: CheckpointConfig())
+    load_checkpoint_head: bool = False
+    fx_model_checkpoint: Optional[Union[Path, str]] = None
+    resume_optimizer_checkpoint: Optional[Union[Path, str]] = None
+    freeze_backbone: bool = False
+    architecture: ArchitectureConfig = field(default_factory=lambda: ArchitectureConfig())
+    losses: Optional[List[Dict[str, Any]]] = None
+
+
+@dataclass
+class EfficientFormerArchitectureConfig(ArchitectureConfig):
+    backbone: Dict[str, Any] = field(default_factory=lambda: {
+        "name": "efficientformer",
+        "params": {
+            "num_attention_heads": 8,
+            "attention_channels": 256,
+            "attention_dropout_prob": 0.,
+            "attention_value_expansion_ratio": 4,
+            "ffn_intermediate_ratio": 4,
+            "ffn_dropout_prob": 0.,
+            "ffn_act_type": 'gelu',
+            "vit_num": 1,
+        },
+        "stage_params": [
+            {"num_blocks": 3, "channels": 48},
+            {"num_blocks": 2, "channels": 96},
+            {"num_blocks": 6, "channels": 224},
+            {"num_blocks": 4, "channels": 448},
+        ],
+    })
+
+
+@dataclass
+class MobileNetV3ArchitectureConfig(ArchitectureConfig):
+    backbone: Dict[str, Any] = field(default_factory=lambda: {
+        "name": "mobilenetv3",
+        "params": None,
+        "stage_params": [
+            {
+                "in_channels": [16],
+                "kernel_sizes": [3],
+                "expanded_channels": [16],
+                "out_channels": [16],
+                "use_se": [True],
+                "act_type": ["relu"],
+                "stride": [2],
+            },
+            {
+                "in_channels": [16, 24],
+                "kernel_sizes": [3, 3],
+                "expanded_channels": [72, 88],
+                "out_channels": [24, 24],
+                "use_se": [False, False],
+                "act_type": ["relu", "relu"],
+                "stride": [2, 1],
+            },
+            {
+                "in_channels": [24, 40, 40, 40, 48],
+                "kernel_sizes": [5, 5, 5, 5, 5],
+                "expanded_channels": [96, 240, 240, 120, 144],
+                "out_channels": [40, 40, 40, 48, 48],
+                "use_se": [True, True, True, True, True],
+                "act_type": ["hard_swish", "hard_swish", "hard_swish", "hard_swish", "hard_swish"],
+                "stride": [2, 1, 1, 1, 1],
+            },
+            {
+                "in_channels": [48, 96, 96],
+                "kernel_sizes": [5, 5, 5],
+                "expanded_channels": [288, 576, 576],
+                "out_channels": [96, 96, 96],
+                "use_se": [True, True, True],
+                "act_type": ["hard_swish", "hard_swish", "hard_swish"],
+                "stride": [2, 1, 1],
+            },
+        ],
+    })
+
+
+@dataclass
+class MobileViTArchitectureConfig(ArchitectureConfig):
+    backbone: Dict[str, Any] = field(default_factory=lambda: {
+        "name": "mobilevit",
+        "params": {
+            "patch_size": 2,
+            "num_attention_heads": 4,
+            "attention_dropout_prob": 0.1,
+            "ffn_dropout_prob": 0.0,
+            "output_expansion_ratio": 4,
+            "use_fusion_layer": True,
+        },
+        "stage_params": [
+            {
+                "out_channels": 32,
+                "block_type": "mv2",
+                "num_blocks": 1,
+                "stride": 1,
+                "ir_expansion_ratio": 4,
+            },
+            {
+                "out_channels": 64,
+                "block_type": "mv2",
+                "num_blocks": 3,
+                "stride": 2,
+                "ir_expansion_ratio": 4,
+            },
+            {
+                "out_channels": 96,
+                "block_type": "mobilevit",
+                "num_blocks": 2,
+                "stride": 2,
+                "hidden_size": 144,
+                "intermediate_size": 288,
+                "dilate": False,
+                "ir_expansion_ratio": 4,
+            },
+            {
+                "out_channels": 128,
+                "block_type": "mobilevit",
+                "num_blocks": 4,
+                "stride": 2,
+                "hidden_size": 192,
+                "intermediate_size": 384,
+                "dilate": False,
+                "ir_expansion_ratio": 4,
+            },
+            {
+                "out_channels": 160,
+                "block_type": "mobilevit",
+                "num_blocks": 3,
+                "stride": 2,
"hidden_size": 240, + "intermediate_size": 480, + "dilate": False, + "ir_expansion_ratio": 4, + }, + ] + }) + + +@dataclass +class PIDNetArchitectureConfig(ArchitectureConfig): + full: Dict[str, Any] = field(default_factory=lambda: { + "name": "pidnet", + "m": 2, + "n": 3, + "channels": 32, + "ppm_channels": 96, + "head_channels": 128, + }) + + +@dataclass +class ResNetArchitectureConfig(ArchitectureConfig): + backbone: Dict[str, Any] = field(default_factory=lambda: { + "name": "resnet", + "params": { + "block_type": "bottleneck", + "norm_type": "batch_norm", + }, + "stage_params": [ + {"channels": 64, "num_blocks": 3}, + {"channels": 128, "num_blocks": 4, "replace_stride_with_dilation": False}, + {"channels": 256, "num_blocks": 6, "replace_stride_with_dilation": False}, + {"channels": 512, "num_blocks": 3, "replace_stride_with_dilation": False}, + ], + }) + + +@dataclass +class SegFormerArchitectureConfig(ArchitectureConfig): + backbone: Dict[str, Any] = field(default_factory=lambda: { + "name": "mixtransformer", + "params": { + "ffn_intermediate_expansion_ratio": 4, + "ffn_act_type": "gelu", + "ffn_dropout_prob": 0.0, + "attention_dropout_prob": 0.0, + }, + "stage_params": [ + { + "num_blocks": 2, + "sequence_reduction_ratio": 8, + "attention_chananels": 32, + "embedding_patch_sizes": 7, + "embedding_strides": 4, + "num_attention_heads": 1, + }, + { + "num_blocks": 2, + "sequence_reduction_ratio": 4, + "attention_chananels": 64, + "embedding_patch_sizes": 3, + "num_attention_heads": 2, + }, + { + "num_blocks": 2, + "sequence_reduction_ratio": 2, + "attention_chananels": 160, + "embedding_patch_sizes": 3, + "embedding_strides": 2, + "num_attention_heads": 5, + }, + { + "num_blocks": 2, + "sequence_reduction_ratio": 1, + "attention_chananels": 256, + "embedding_patch_sizes": 3, + "embedding_strides": 2, + "num_attention_heads": 8, + }, + ], + }) + + +@dataclass +class ViTArchitectureConfig(ArchitectureConfig): + backbone: Dict[str, Any] = field(default_factory=lambda: { + "name": "vit", + "params": { + "patch_size": 16, + "attention_channels": 192, + "num_blocks": 12, + "num_attention_heads": 3, + "attention_dropout_prob": 0.0, + "ffn_intermediate_channels": 768, + "ffn_dropout_prob": 0.1, + "use_cls_token": True, + "vocab_size": 1000, + }, + "stage_params": None, + }) + + +@dataclass +class MixNetSmallArchitectureConfig(ArchitectureConfig): + backbone: Dict[str, Any] = field(default_factory=lambda: { + "name": "mixnet", + "params": { + "stem_channels": 16, + "wid_mul": 1.0, + "dep_mul": 1.0, + "dropout_rate": 0., + }, + "stage_params": [ + { + "expansion_ratio": [1, 6, 3], + "out_channels": [16, 24, 24], + "num_blocks": [1, 1, 1], + "kernel_sizes": [[3], [3], [3]], + "num_exp_groups": [1, 2, 2], + "num_poi_groups": [1, 2, 2], + "stride": [1, 2, 1], + "act_type": ["relu", "relu", "relu"], + "se_reduction_ratio": [None, None, None], + }, + { + "expansion_ratio": [6, 6], + "out_channels": [40, 40], + "num_blocks": [1, 3], + "kernel_sizes": [[3, 5, 7], [3, 5]], + "num_exp_groups": [1, 2], + "num_poi_groups": [1, 2], + "stride": [2, 1], + "act_type": ["swish", "swish"], + "se_reduction_ratio": [2, 2], + }, + { + "expansion_ratio": [6, 6, 6, 3], + "out_channels": [80, 80, 120, 120], + "num_blocks": [1, 2, 1, 2], + "kernel_sizes": [[3, 5, 7], [3, 5], [3, 5, 7], [3, 5, 7, 9]], + "num_exp_groups": [1, 1, 2, 2], + "num_poi_groups": [2, 2, 2, 2], + "stride": [2, 1, 1, 1], + "act_type": ["swish", "swish", "swish", "swish"], + "se_reduction_ratio": [4, 4, 2, 2], + }, + { + "expansion_ratio": [6, 6], + 
"out_channels": [200, 200], + "num_blocks": [1, 2], + "kernel_sizes": [[3, 5, 7, 9, 11], [3, 5, 7, 9]], + "num_exp_groups": [1, 1], + "num_poi_groups": [1, 2], + "stride": [2, 1], + "act_type": ["swish", "swish"], + "se_reduction_ratio": [2, 2], + }, + ], + }) + + +@dataclass +class MixNetMediumArchitectureConfig(ArchitectureConfig): + backbone: Dict[str, Any] = field(default_factory=lambda: { + "name": "mixnet", + "params": { + "stem_channels": 24, + "wid_mul": 1.0, + "dep_mul": 1.0, + "dropout_rate": 0., + }, + "stage_params": [ + { + "expansion_ratio": [1, 6, 3], + "out_channels": [24, 32, 32], + "num_blocks": [1, 1, 1], + "kernel_sizes": [[3], [3, 5, 7], [3]], + "num_exp_groups": [1, 2, 2], + "num_poi_groups": [1, 2, 2], + "stride": [1, 2, 1], + "act_type": ["relu", "relu", "relu"], + "se_reduction_ratio": [None, None, None], + }, + { + "expansion_ratio": [6, 6], + "out_channels": [40, 40], + "num_blocks": [1, 3], + "kernel_sizes": [[3, 5, 7, 9], [3, 5]], + "num_exp_groups": [1, 2], + "num_poi_groups": [1, 2], + "stride": [2, 1], + "act_type": ["swish", "swish"], + "se_reduction_ratio": [2, 2], + }, + { + "expansion_ratio": [6, 6, 6, 3], + "out_channels": [80, 80, 120, 120], + "num_blocks": [1, 3, 1, 3], + "kernel_sizes": [[3, 5, 7], [3, 5, 7, 9], [3], [3, 5, 7, 9]], + "num_exp_groups": [1, 2, 1, 2], + "num_poi_groups": [1, 2, 1, 2], + "stride": [2, 1, 1, 1], + "act_type": ["swish", "swish", "swish", "swish"], + "se_reduction_ratio": [4, 4, 2, 2], + }, + { + "expansion_ratio": [6, 6], + "out_channels": [200, 200], + "num_blocks": [1, 3], + "kernel_sizes": [[3, 5, 7, 9], [3, 5, 7, 9]], + "num_exp_groups": [1, 1], + "num_poi_groups": [1, 2], + "stride": [2, 1], + "act_type": ["swish", "swish"], + "se_reduction_ratio": [2, 2], + }, + ], + }) + + +@dataclass +class MixNetLargeArchitectureConfig(ArchitectureConfig): + backbone: Dict[str, Any] = field(default_factory=lambda: { + "name": "mixnet", + "params": { + "stem_channels": 24, + "wid_mul": 1.3, + "dep_mul": 1.0, + "dropout_rate": 0., + }, + "stage_params": [ + { + "expansion_ratio": [1, 6, 3], + "out_channels": [24, 32, 32], + "num_blocks": [1, 1, 1], + "kernel_sizes": [[3], [3, 5, 7], [3]], + "num_exp_groups": [1, 2, 2], + "num_poi_groups": [1, 2, 2], + "stride": [1, 2, 1], + "act_type": ["relu", "relu", "relu"], + "se_reduction_ratio": [None, None, None], + }, + { + "expansion_ratio": [6, 6], + "out_channels": [40, 40], + "num_blocks": [1, 3], + "kernel_sizes": [[3, 5, 7, 9], [3, 5]], + "num_exp_groups": [1, 2], + "num_poi_groups": [1, 2], + "stride": [2, 1], + "act_type": ["swish", "swish"], + "se_reduction_ratio": [2, 2], + }, + { + "expansion_ratio": [6, 6, 6, 3], + "out_channels": [80, 80, 120, 120], + "num_blocks": [1, 3, 1, 3], + "kernel_sizes": [[3, 5, 7], [3, 5, 7, 9], [3], [3, 5, 7, 9]], + "num_exp_groups": [1, 2, 1, 2], + "num_poi_groups": [1, 2, 1, 2], + "stride": [2, 1, 1, 1], + "act_type": ["swish", "swish", "swish", "swish"], + "se_reduction_ratio": [4, 4, 2, 2], + }, + { + "expansion_ratio": [6, 6], + "out_channels": [200, 200], + "num_blocks": [1, 3], + "kernel_sizes": [[3, 5, 7, 9], [3, 5, 7, 9]], + "num_exp_groups": [1, 1], + "num_poi_groups": [1, 2], + "stride": [2, 1], + "act_type": ["swish", "swish"], + "se_reduction_ratio": [2, 2], + }, + ], + }) + + +@dataclass +class CSPDarkNetSmallArchitectureConfig(ArchitectureConfig): + backbone: Dict[str, Any] = field(default_factory=lambda: { + "name": "cspdarknet", + "params": { + "dep_mul": 0.33, + "wid_mul": 0.5, + "act_type": "silu", + }, + "stage_params": None, + }) 
+
+
+@dataclass
+class ClassificationEfficientFormerModelConfig(ModelConfig):
+    task: str = "classification"
+    name: str = "efficientformer_l1"
+    architecture: ArchitectureConfig = field(default_factory=lambda: EfficientFormerArchitectureConfig(
+        head={
+            "name": "fc",
+            "params": {
+                "intermediate_channels": 1024,
+                "num_layers": 1,
+            }
+        }
+    ))
+    losses: List[Dict[str, Any]] = field(default_factory=lambda: [
+        {"criterion": "cross_entropy", "label_smoothing": 0.1, "weight": None}
+    ])
+
+
+@dataclass
+class SegmentationEfficientFormerModelConfig(ModelConfig):
+    task: str = "segmentation"
+    name: str = "efficientformer_l1"
+    architecture: ArchitectureConfig = field(default_factory=lambda: EfficientFormerArchitectureConfig(
+        head={
+            "name": "all_mlp_decoder",
+            "params": {
+                "intermediate_channels": 256,
+                "classifier_dropout_prob": 0.,
+            }
+        }
+    ))
+    losses: List[Dict[str, Any]] = field(default_factory=lambda: [
+        {"criterion": "cross_entropy", "ignore_index": 255, "weight": None}
+    ])
+
+
+@dataclass
+class DetectionEfficientFormerModelConfig(ModelConfig):
+    task: str = "detection"
+    name: str = "efficientformer_l1"
+    checkpoint: CheckpointConfig = field(default_factory=lambda: CheckpointConfig(
+        load_head=True
+    ))
+    architecture: ArchitectureConfig = field(default_factory=lambda: EfficientFormerArchitectureConfig(
+        neck={
+            "name": "fpn",
+            "params": {
+                "num_outs": 4,
+                "start_level": 0,
+                "end_level": -1,
+                "add_extra_convs": False,
+                "relu_before_extra_convs": False,
+            },
+        },
+        head={
+            "name": "anchor_decoupled_head",
+            "params": {
+                # Anchor parameters
+                "anchor_sizes": [[32,], [64,], [128,], [256,]],
+                "aspect_ratios": [0.5, 1.0, 2.0],
+                "num_layers": 1,
+                "norm_type": "batch_norm",
+                # postprocessor - decode
+                "topk_candidates": 1000,
+                "score_thresh": 0.05,
+                # postprocessor - nms
+                "nms_thresh": 0.45,
+                "class_agnostic": False,
+            }
+        }
+    ))
+    losses: List[Dict[str, Any]] = field(default_factory=lambda: [
+        {"criterion": "retinanet_loss", "weight": None},
+    ])
+
+
+@dataclass
+class ClassificationMobileNetV3ModelConfig(ModelConfig):
+    task: str = "classification"
+    name: str = "mobilenet_v3_small"
+    architecture: ArchitectureConfig = field(default_factory=lambda: MobileNetV3ArchitectureConfig(
+        head={
+            "name": "fc",
+            "params": {
+                "intermediate_channels": 1024,
+                "num_layers": 1,
+            }
+        }
+    ))
+    losses: List[Dict[str, Any]] = field(default_factory=lambda: [
+        {"criterion": "cross_entropy", "label_smoothing": 0.1, "weight": None}
+    ])
+
+
+@dataclass
+class SegmentationMobileNetV3ModelConfig(ModelConfig):
+    task: str = "segmentation"
+    name: str = "mobilenet_v3_small"
+    architecture: ArchitectureConfig = field(default_factory=lambda: MobileNetV3ArchitectureConfig(
+        head={
+            "name": "all_mlp_decoder",
+            "params": {
+                "intermediate_channels": 256,
+                "classifier_dropout_prob": 0.,
+            }
+        }
+    ))
+    losses: List[Dict[str, Any]] = field(default_factory=lambda: [
+        {"criterion": "cross_entropy", "ignore_index": 255, "weight": None}
+    ])
+
+
+@dataclass
+class DetectionMobileNetV3ModelConfig(ModelConfig):
+    task: str = "detection"
+    name: str = "mobilenet_v3_small"
+    checkpoint: CheckpointConfig = field(default_factory=lambda: CheckpointConfig(
+        load_head=True
+    ))
+    architecture: ArchitectureConfig = field(default_factory=lambda: MobileNetV3ArchitectureConfig(
+        neck={
+            "name": "fpn",
+            "params": {
+                "num_outs": 4,
+                "start_level": 0,
+                "end_level": -1,
+                "add_extra_convs": False,
+                "relu_before_extra_convs": False,
+            },
+        },
+        head={
+            "name": "anchor_decoupled_head",
"params": { + # Anchor parameters + "anchor_sizes": [[32,], [64,], [128,], [256,]], + "aspect_ratios": [0.5, 1.0, 2.0], + "num_layers": 1, + "norm_type": "batch_norm", + # postprocessor - decode + "topk_candidates": 1000, + "score_thresh": 0.05, + # postprocessor - nms + "nms_thresh": 0.45, + "class_agnostic": False, + } + } + )) + losses: List[Dict[str, Any]] = field(default_factory=lambda: [ + {"criterion": "retinanet_loss", "weight": None}, + ]) + + +@dataclass +class ClassificationMobileViTModelConfig(ModelConfig): + task: str = "classification" + name: str = "mobilevit_s" + architecture: ArchitectureConfig = field(default_factory=lambda: MobileViTArchitectureConfig( + head={ + "name": "fc", + "params": { + "intermediate_channels": 1024, + "num_layers": 1, + } + } + )) + losses: List[Dict[str, Any]] = field(default_factory=lambda: [ + {"criterion": "cross_entropy", "label_smoothing": 0.1, "weight": None} + ]) + + +@dataclass +class PIDNetModelConfig(ModelConfig): + task: str = "segmentation" + name: str = "pidnet_s" + architecture: ArchitectureConfig = field(default_factory=lambda: PIDNetArchitectureConfig()) + losses: List[Dict[str, Any]] = field(default_factory=lambda: [ + {"criterion": "pidnet_loss", "ignore_index": 255, "weight": None}, + ]) + + +@dataclass +class ClassificationResNetModelConfig(ModelConfig): + task: str = "classification" + name: str = "resnet50" + architecture: ArchitectureConfig = field(default_factory=lambda: ResNetArchitectureConfig( + head={ + "name": "fc", + "params": { + "intermediate_channels": 1024, + "num_layers": 1, + } + } + )) + losses: List[Dict[str, Any]] = field(default_factory=lambda: [ + {"criterion": "cross_entropy", "label_smoothing": 0.1, "weight": None} + ]) + + +@dataclass +class SegmentationResNetModelConfig(ModelConfig): + task: str = "segmentation" + name: str = "resnet50" + architecture: ArchitectureConfig = field(default_factory=lambda: ResNetArchitectureConfig( + head={ + "name": "all_mlp_decoder", + "params": { + "intermediate_channels": 256, + "classifier_dropout_prob": 0., + } + } + )) + losses: List[Dict[str, Any]] = field(default_factory=lambda: [ + {"criterion": "cross_entropy", "ignore_index": 255, "weight": None} + ]) + + +@dataclass +class DetectionResNetModelConfig(ModelConfig): + task: str = "detection" + name: str = "resnet50" + checkpoint: CheckpointConfig = field(default_factory=lambda: CheckpointConfig( + load_head=True + )) + architecture: ArchitectureConfig = field(default_factory=lambda: ResNetArchitectureConfig( + neck={ + "name": "fpn", + "params": { + "num_outs": 4, + "start_level": 0, + "end_level": -1, + "add_extra_convs": False, + "relu_before_extra_convs": False, + }, + }, + head={ + "name": "anchor_decoupled_head", + "params": { + # Anchor parameters + "anchor_sizes": [[32,], [64,], [128,], [256,]], + "aspect_ratios": [0.5, 1.0, 2.0], + "num_layers": 1, + "norm_type": "batch_norm", + # postprocessor - decode + "topk_candidates": 1000, + "score_thresh": 0.05, + # postprocessor - nms + "nms_thresh": 0.45, + "class_agnostic": False, + } + } + )) + losses: List[Dict[str, Any]] = field(default_factory=lambda: [ + {"criterion": "retinanet_loss", "weight": None}, + ]) + + +@dataclass +class SegmentationSegFormerModelConfig(ModelConfig): + task: str = "segmentation" + name: str = "segformer" + architecture: ArchitectureConfig = field(default_factory=lambda: SegFormerArchitectureConfig( + head={ + "name": "all_mlp_decoder", + "params": { + "intermediate_channels": 256, + "classifier_dropout_prob": 0., + } + } + )) + 
+    losses: List[Dict[str, Any]] = field(default_factory=lambda: [
+        {"criterion": "cross_entropy", "ignore_index": 255, "weight": None}
+    ])
+
+
+@dataclass
+class ClassificationViTModelConfig(ModelConfig):
+    task: str = "classification"
+    name: str = "vit_tiny"
+    architecture: ArchitectureConfig = field(default_factory=lambda: ViTArchitectureConfig(
+        head={
+            "name": "fc",
+            "params": {
+                "intermediate_channels": 1024,
+                "num_layers": 1,
+            }
+        }
+    ))
+    losses: List[Dict[str, Any]] = field(default_factory=lambda: [
+        {"criterion": "cross_entropy", "label_smoothing": 0.1, "weight": None}
+    ])
+
+
+@dataclass
+class DetectionYoloXModelConfig(ModelConfig):
+    task: str = "detection"
+    name: str = "yolox_s"
+    checkpoint: CheckpointConfig = field(default_factory=lambda: CheckpointConfig(
+        load_head=True
+    ))
+    architecture: ArchitectureConfig = field(default_factory=lambda: CSPDarkNetSmallArchitectureConfig(
+        neck={
+            "name": "yolopafpn",
+            "params": {
+                "dep_mul": 0.33,
+                "act_type": "silu",
+            },
+        },
+        head={
+            "name": "anchor_free_decoupled_head",
+            "params": {
+                "act_type": "silu",
+                # postprocessor - decode
+                "score_thresh": 0.7,
+                # postprocessor - nms
+                "nms_thresh": 0.45,
+                "class_agnostic": False,
+            }
+        }
+    ))
+    losses: List[Dict[str, Any]] = field(default_factory=lambda: [
+        {"criterion": "yolox_loss", "weight": None}
+    ])
+
+
+@dataclass
+class ClassificationMixNetSmallModelConfig(ModelConfig):
+    task: str = "classification"
+    name: str = "mixnet_s"
+    architecture: ArchitectureConfig = field(default_factory=lambda: MixNetSmallArchitectureConfig(
+        head={
+            "name": "fc",
+            "params": {
+                "intermediate_channels": 1024,
+                "num_layers": 1,
+            }
+        }
+    ))
+    losses: List[Dict[str, Any]] = field(default_factory=lambda: [
+        {"criterion": "cross_entropy", "label_smoothing": 0.1, "weight": None}
+    ])
+
+
+@dataclass
+class SegmentationMixNetSmallModelConfig(ModelConfig):
+    task: str = "segmentation"
+    name: str = "mixnet_s"
+    architecture: ArchitectureConfig = field(default_factory=lambda: MixNetSmallArchitectureConfig(
+        head={
+            "name": "all_mlp_decoder",
+            "params": {
+                "intermediate_channels": 256,
+                "classifier_dropout_prob": 0.,
+            }
+        }
+    ))
+    losses: List[Dict[str, Any]] = field(default_factory=lambda: [
+        {"criterion": "cross_entropy", "ignore_index": 255, "weight": None}
+    ])
+
+
+@dataclass
+class DetectionMixNetSmallModelConfig(ModelConfig):
+    task: str = "detection"
+    name: str = "mixnet_s"
+    checkpoint: CheckpointConfig = field(default_factory=lambda: CheckpointConfig(
+        load_head=True
+    ))
+    architecture: ArchitectureConfig = field(default_factory=lambda: MixNetSmallArchitectureConfig(
+        neck={
+            "name": "fpn",
+            "params": {
+                "num_outs": 4,
+                "start_level": 0,
+                "end_level": -1,
+                "add_extra_convs": False,
+                "relu_before_extra_convs": False,
+            },
+        },
+        head={
+            "name": "anchor_decoupled_head",
+            "params": {
+                # Anchor parameters
+                "anchor_sizes": [[32,], [64,], [128,], [256,]],
+                "aspect_ratios": [0.5, 1.0, 2.0],
+                "num_layers": 1,
+                "norm_type": "batch_norm",
+                # postprocessor - decode
+                "topk_candidates": 1000,
+                "score_thresh": 0.05,
+                # postprocessor - nms
+                "nms_thresh": 0.45,
+                "class_agnostic": False,
+            }
+        }
+    ))
+    losses: List[Dict[str, Any]] = field(default_factory=lambda: [
+        {"criterion": "retinanet_loss", "weight": None},
+    ])
+
+
+@dataclass
+class ClassificationMixNetMediumModelConfig(ModelConfig):
+    task: str = "classification"
+    name: str = "mixnet_m"
+    architecture: ArchitectureConfig = field(default_factory=lambda: MixNetMediumArchitectureConfig(
+        head={
"name": "fc", + "params": { + "intermediate_channels": 1024, + "num_layers": 1, + } + } + )) + losses: List[Dict[str, Any]] = field(default_factory=lambda: [ + {"criterion": "cross_entropy", "label_smoothing": 0.1, "weight": None} + ]) + + +@dataclass +class SegmentationMixNetMediumModelConfig(ModelConfig): + task: str = "segmentation" + name: str = "mixnet_m" + architecture: ArchitectureConfig = field(default_factory=lambda: MixNetMediumArchitectureConfig( + head={ + "name": "all_mlp_decoder", + "params": { + "intermediate_channels": 256, + "classifier_dropout_prob": 0., + } + } + )) + losses: List[Dict[str, Any]] = field(default_factory=lambda: [ + {"criterion": "cross_entropy", "ignore_index": 255, "weight": None} + ]) + + +@dataclass +class DetectionMixNetMediumModelConfig(ModelConfig): + task: str = "detection" + name: str = "mixnet_m" + checkpoint: CheckpointConfig = field(default_factory=lambda: CheckpointConfig( + load_head=True + )) + architecture: ArchitectureConfig = field(default_factory=lambda: MixNetMediumArchitectureConfig( + neck={ + "name": "fpn", + "params": { + "num_outs": 4, + "start_level": 0, + "end_level": -1, + "add_extra_convs": False, + "relu_before_extra_convs": False, + }, + }, + head={ + "name": "anchor_decoupled_head", + "params": { + # Anchor parameters + "anchor_sizes": [[32,], [64,], [128,], [256,]], + "aspect_ratios": [0.5, 1.0, 2.0], + "num_layers": 1, + "norm_type": "batch_norm", + # postprocessor - decode + "topk_candidates": 1000, + "score_thresh": 0.05, + # postprocessor - nms + "nms_thresh": 0.45, + "class_agnostic": False, + } + } + )) + losses: List[Dict[str, Any]] = field(default_factory=lambda: [ + {"criterion": "retinanet_loss", "weight": None}, + ]) + + +@dataclass +class ClassificationMixNetLargeModelConfig(ModelConfig): + task: str = "classification" + name: str = "mixnet_l" + architecture: ArchitectureConfig = field(default_factory=lambda: MixNetLargeArchitectureConfig( + head={ + "name": "fc", + "params": { + "intermediate_channels": 1024, + "num_layers": 1, + } + } + )) + losses: List[Dict[str, Any]] = field(default_factory=lambda: [ + {"criterion": "cross_entropy", "label_smoothing": 0.1, "weight": None} + ]) + + +@dataclass +class SegmentationMixNetLargeModelConfig(ModelConfig): + task: str = "segmentation" + name: str = "mixnet_l" + architecture: ArchitectureConfig = field(default_factory=lambda: MixNetLargeArchitectureConfig( + head={ + "name": "all_mlp_decoder", + "params": { + "intermediate_channels": 256, + "classifier_dropout_prob": 0., + } + } + )) + losses: List[Dict[str, Any]] = field(default_factory=lambda: [ + {"criterion": "cross_entropy", "ignore_index": 255, "weight": None} + ]) + + +@dataclass +class DetectionMixNetLargeModelConfig(ModelConfig): + task: str = "detection" + name: str = "mixnet_l" + checkpoint: CheckpointConfig = field(default_factory=lambda: CheckpointConfig( + load_head=True + )) + architecture: ArchitectureConfig = field(default_factory=lambda: MixNetLargeArchitectureConfig( + neck={ + "name": "fpn", + "params": { + "num_outs": 4, + "start_level": 0, + "end_level": -1, + "add_extra_convs": False, + "relu_before_extra_convs": False, + }, + }, + head={ + "name": "anchor_decoupled_head", + "params": { + # Anchor parameters + "anchor_sizes": [[32,], [64,], [128,], [256,]], + "aspect_ratios": [0.5, 1.0, 2.0], + "num_layers": 1, + "norm_type": "batch_norm", + # postprocessor - decode + "topk_candidates": 1000, + "score_thresh": 0.05, + # postprocessor - nms + "nms_thresh": 0.45, + "class_agnostic": False, + } + 
+        }
+    ))
+    losses: List[Dict[str, Any]] = field(default_factory=lambda: [
+        {"criterion": "retinanet_loss", "weight": None},
+    ])
diff --git a/netspresso/trainer/registries/__init__.py b/netspresso/trainer/registries/__init__.py
deleted file mode 100644
index cf65c2d4..00000000
--- a/netspresso/trainer/registries/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from .augmentation import AUGMENTATION_CONFIG_TYPE
-from .data import DATA_CONFIG_TYPE
-from .model import CLASSIFICATION_MODELS, DETECTION_MODELS, SEGMENTATION_MODELS
-from .training import TRAINING_CONFIG_TYPE
-
-__all__ = [
-    "DATA_CONFIG_TYPE",
-    "CLASSIFICATION_MODELS",
-    "DETECTION_MODELS",
-    "SEGMENTATION_MODELS",
-    "TRAINING_CONFIG_TYPE",
-    "AUGMENTATION_CONFIG_TYPE",
-]
diff --git a/netspresso/trainer/registries/augmentation.py b/netspresso/trainer/registries/augmentation.py
deleted file mode 100644
index 01fe0c68..00000000
--- a/netspresso/trainer/registries/augmentation.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from netspresso_trainer.cfg import (
-    ClassificationAugmentationConfig,
-    DetectionAugmentationConfig,
-    SegmentationAugmentationConfig,
-)
-
-AUGMENTATION_CONFIG_TYPE = {
-    "classification": ClassificationAugmentationConfig,
-    "detection": DetectionAugmentationConfig,
-    "segmentation": SegmentationAugmentationConfig,
-}
diff --git a/netspresso/trainer/registries/training.py b/netspresso/trainer/registries/training.py
deleted file mode 100644
index c9d9d5e4..00000000
--- a/netspresso/trainer/registries/training.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from netspresso_trainer.cfg import ClassificationScheduleConfig, DetectionScheduleConfig, SegmentationScheduleConfig
-
-TRAINING_CONFIG_TYPE = {
-    "classification": ClassificationScheduleConfig,
-    "detection": DetectionScheduleConfig,
-    "segmentation": SegmentationScheduleConfig,
-}
diff --git a/netspresso/trainer/trainer.py b/netspresso/trainer/trainer.py
index 781357e3..26ed988a 100644
--- a/netspresso/trainer/trainer.py
+++ b/netspresso/trainer/trainer.py
@@ -3,26 +3,23 @@
 from loguru import logger
 from netspresso_trainer import train_with_yaml
-from netspresso_trainer.cfg import AugmentationConfig, EnvironmentConfig, LoggingConfig, ModelConfig, ScheduleConfig
-from netspresso_trainer.cfg.augmentation import Inference, Train, Transform
-from netspresso_trainer.cfg.data import ImageLabelPathConfig, PathConfig
-from netspresso_trainer.cfg.model import CheckpointConfig
 from omegaconf import OmegaConf
 
 from netspresso.enums import Status, Task, TaskType
-
-from ..utils import FileHandler
-from ..utils.metadata import MetadataHandler
-from ..utils.metadata.default.trainer import InputShape
-from .registries import (
-    AUGMENTATION_CONFIG_TYPE,
+from netspresso.trainer.augmentations import AUGMENTATION_CONFIG_TYPE, AugmentationConfig, Inference, Train, Transform
+from netspresso.trainer.data import DATA_CONFIG_TYPE, ImageLabelPathConfig, PathConfig
+from netspresso.trainer.models import (
     CLASSIFICATION_MODELS,
-    DATA_CONFIG_TYPE,
     DETECTION_MODELS,
     SEGMENTATION_MODELS,
-    TRAINING_CONFIG_TYPE,
+    CheckpointConfig,
+    ModelConfig,
 )
-from .trainer_configs import TrainerConfigs
+from netspresso.trainer.trainer_configs import TrainerConfigs
+from netspresso.trainer.training import TRAINING_CONFIG_TYPE, EnvironmentConfig, LoggingConfig, ScheduleConfig
+from netspresso.utils import FileHandler
+from netspresso.utils.metadata import MetadataHandler
+from netspresso.utils.metadata.default.trainer import InputShape
 
 
 class Trainer:
diff --git a/netspresso/trainer/trainer_configs.py b/netspresso/trainer/trainer_configs.py
index c142d0fb..ff89d76e 100644
--- a/netspresso/trainer/trainer_configs.py
+++ b/netspresso/trainer/trainer_configs.py
@@ -3,16 +3,13 @@
 from pathlib import Path
 from typing import Union
 
-from netspresso_trainer.cfg import (
-    AugmentationConfig,
-    DatasetConfig,
-    EnvironmentConfig,
-    LoggingConfig,
-    ModelConfig,
-    ScheduleConfig,
-)
 from omegaconf import OmegaConf
 
+from netspresso.trainer.augmentations import AugmentationConfig
+from netspresso.trainer.data import DatasetConfig
+from netspresso.trainer.models import ModelConfig
+from netspresso.trainer.training import EnvironmentConfig, LoggingConfig, ScheduleConfig
+
 
 class TrainerConfigs:
     def __init__(
diff --git a/netspresso/trainer/training/__init__.py b/netspresso/trainer/training/__init__.py
new file mode 100644
index 00000000..8917ab44
--- /dev/null
+++ b/netspresso/trainer/training/__init__.py
@@ -0,0 +1,18 @@
+from netspresso.trainer.training.environment import EnvironmentConfig
+from netspresso.trainer.training.logging import LoggingConfig
+from netspresso.trainer.training.training import (
+    ClassificationScheduleConfig,
+    DetectionScheduleConfig,
+    ScheduleConfig,
+    SegmentationScheduleConfig,
+)
+
+TRAINING_CONFIG_TYPE = {
+    "classification": ClassificationScheduleConfig,
+    "detection": DetectionScheduleConfig,
+    "segmentation": SegmentationScheduleConfig,
+}
+
+
+__all__ = ["ScheduleConfig", "TRAINING_CONFIG_TYPE", "EnvironmentConfig", "LoggingConfig"]
+
diff --git a/netspresso/trainer/training/environment.py b/netspresso/trainer/training/environment.py
new file mode 100644
index 00000000..04ec85be
--- /dev/null
+++ b/netspresso/trainer/training/environment.py
@@ -0,0 +1,8 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class EnvironmentConfig:
+    seed: int = 1
+    num_workers: int = 4
+    gpus: str = "0"
diff --git a/netspresso/trainer/training/logging.py b/netspresso/trainer/training/logging.py
new file mode 100644
index 00000000..ef7643d1
--- /dev/null
+++ b/netspresso/trainer/training/logging.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, Union
+
+
+@dataclass
+class LoggingConfig:
+    project_id: Optional[str] = None
+    output_dir: Union[Path, str] = "./outputs"
+    tensorboard: bool = True
+    csv: bool = False
+    image: bool = True
+    stdout: bool = True
+    save_optimizer_state: bool = True
+    validation_epoch: int = 10
+    save_checkpoint_epoch: Optional[int] = None
+
+    def __post_init__(self):
+        if self.save_checkpoint_epoch is None:
+            self.save_checkpoint_epoch = self.validation_epoch
diff --git a/netspresso/trainer/training/training.py b/netspresso/trainer/training/training.py
new file mode 100644
index 00000000..9f2e8411
--- /dev/null
+++ b/netspresso/trainer/training/training.py
@@ -0,0 +1,35 @@
+from dataclasses import dataclass, field
+from typing import Dict
+
+
+@dataclass
+class ScheduleConfig:
+    epochs: int = 3
+    batch_size: int = 8
+    optimizer: Dict = field(default_factory=lambda: {
+        "name": "adamw",
+        "lr": 6e-5,
+        "betas": [0.9, 0.999],
+        "weight_decay": 0.0005,
+    })
+    scheduler: Dict = field(default_factory=lambda: {
+        "name": "cosine_no_sgdr",
+        "warmup_epochs": 5,
+        "warmup_bias_lr": 1e-5,
+        "min_lr": 0.,
+    })
+
+
+@dataclass
+class ClassificationScheduleConfig(ScheduleConfig):
+    batch_size: int = 32
+
+
+@dataclass
+class SegmentationScheduleConfig(ScheduleConfig):
+    pass
+
+
+@dataclass
+class DetectionScheduleConfig(ScheduleConfig):
+    pass
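
Reviewer note (not part of the patch): a minimal sketch of how the relocated modules compose after this refactor. The registry lookups mirror what Trainer/TrainerConfigs do; the dataset name and paths are placeholders, and flattening with dataclasses.asdict before OmegaConf.create is just one safe way to produce the YAML that netspresso-trainer's train_with_yaml consumes.

    from dataclasses import asdict

    from omegaconf import OmegaConf

    from netspresso.trainer.augmentations import AUGMENTATION_CONFIG_TYPE
    from netspresso.trainer.data import DATA_CONFIG_TYPE, ImageLabelPathConfig, PathConfig
    from netspresso.trainer.models.model import ClassificationResNetModelConfig
    from netspresso.trainer.training import TRAINING_CONFIG_TYPE

    task = "classification"

    # Per-task defaults come from the registries introduced in this patch.
    data = DATA_CONFIG_TYPE[task](                    # LocalClassificationDatasetConfig
        name="my_dataset",                            # placeholder
        path=PathConfig(
            root="/DATA/classification-example",      # placeholder
            train=ImageLabelPathConfig(image="train"),
            valid=ImageLabelPathConfig(image="val"),
        ),
    )
    model = ClassificationResNetModelConfig()
    augmentation = AUGMENTATION_CONFIG_TYPE[task]()   # ClassificationAugmentationConfig
    training = TRAINING_CONFIG_TYPE[task](epochs=10)  # ClassificationScheduleConfig

    # Every config is a plain dataclass, so it can be flattened to primitives
    # and dumped as a YAML document.
    conf = OmegaConf.create({
        "data": asdict(data),
        "model": asdict(model),
        "augmentation": asdict(augmentation),
        "training": asdict(training),
    })
    print(OmegaConf.to_yaml(conf))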