Skip to content

Commit

Permalink
#255 Update v0.2.2 version of netspresso_trainer (#259)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Jun 27, 2024
1 parent a58a126 commit 485e6ca
Show file tree
Hide file tree
Showing 20 changed files with 1,722 additions and 1,126 deletions.
18 changes: 14 additions & 4 deletions netspresso/trainer/augmentations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
from netspresso.trainer.augmentations.augmentation import (
AugmentationConfig,
CenterCrop,
ClassificationAugmentationConfig,
ColorJitter,
DetectionAugmentationConfig,
Inference,
HSVJitter,
Mixing,
MosaicDetection,
Pad,
PoseTopDownAffine,
RandomCrop,
RandomCutmix,
RandomErasing,
RandomHorizontalFlip,
RandomMixup,
RandomResize,
RandomResizedCrop,
RandomVerticalFlip,
Resize,
SegmentationAugmentationConfig,
Train,
Transform,
TrivialAugmentWide,
)
Expand All @@ -26,6 +31,13 @@


__all__ = [
"CenterCrop",
"HSVJitter",
"Mixing",
"MosaicDetection",
"PoseTopDownAffine",
"RandomErasing",
"RandomResize",
"ColorJitter",
"Pad",
"RandomCrop",
Expand All @@ -36,8 +48,6 @@
"TrivialAugmentWide",
"RandomMixup",
"RandomCutmix",
"Inference",
"Train",
"Transform",
"AugmentationConfig",
"AUGMENTATION_CONFIG_TYPE",
Expand Down
138 changes: 99 additions & 39 deletions netspresso/trainer/augmentations/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,16 @@ class Transform:


@dataclass
class Train:
transforms: Optional[List] = None
mix_transforms: Optional[List] = None


@dataclass
class Inference:
transforms: Optional[List] = None
class AugmentationConfig:
    """Base augmentation settings: an image size plus two transform lists.

    NOTE(review): ``train`` and ``inference`` are untyped lists; they are
    presumably lists of ``Transform`` records -- confirm against callers.
    """

    # Defaults to the module-level DEFAULT_IMG_SIZE constant.
    img_size: int = DEFAULT_IMG_SIZE
    # Both lists default to None.
    train: Optional[List] = None
    inference: Optional[List] = None


@dataclass
class AugmentationConfig:
img_size: int = DEFAULT_IMG_SIZE
train: Train = field(default_factory=lambda: Train())
inference: Inference = field(default_factory=lambda: Inference())
class CenterCrop(Transform):
    """Settings record whose ``name`` defaults to ``centercrop``."""

    name: str = "centercrop"
    # Defaults to the module-level DEFAULT_IMG_SIZE constant.
    size: int = DEFAULT_IMG_SIZE


@dataclass
Expand All @@ -40,6 +35,38 @@ class ColorJitter(Transform):
p: Optional[float] = 0.5


@dataclass
class HSVJitter(Transform):
    """Settings record whose ``name`` defaults to ``hsvjitter``.

    NOTE(review): the ``h_mag`` / ``s_mag`` / ``v_mag`` names suggest
    hue / saturation / value jitter magnitudes; units and valid ranges are
    not visible here -- confirm against the consuming implementation.
    """

    name: str = 'hsvjitter'
    h_mag: int = 5
    s_mag: int = 30
    v_mag: int = 30


@dataclass
class Mixing(Transform):
    """Settings record whose ``name`` defaults to ``mixing``.

    NOTE(review): the layout of the ``mixup`` / ``cutmix`` float lists is
    not defined in this block -- confirm against the consuming code.
    """

    name: str = 'mixing'
    # None is immutable, so a plain default is equivalent to field(default=None).
    mixup: Optional[List[float]] = None
    cutmix: Optional[List[float]] = None
    inplace: bool = False


@dataclass
class MosaicDetection(Transform):
    """Settings record whose ``name`` defaults to ``mosaicdetection``.

    List-valued fields use ``default_factory`` so every instance gets its own
    fresh list rather than sharing one mutable default.

    NOTE(review): the meaning and units of the individual knobs are not
    visible in this block -- confirm against the consuming implementation.
    """

    name: str = 'mosaicdetection'
    # Two-element list, both entries DEFAULT_IMG_SIZE.
    size: List = field(default_factory=lambda: [DEFAULT_IMG_SIZE] * 2)
    mosaic_prob: float = 1.0
    affine_scale: List = field(default_factory=lambda: [0.5, 1.5])
    degrees: float = 10.0
    translate: float = 0.1
    shear: float = 2.0
    enable_mixup: bool = True
    mixup_prob: float = 1.0
    mixup_scale: List = field(default_factory=lambda: [0.5, 1.5])
    fill: int = 114
    mosaic_off_epoch: int = 10


@dataclass
class Pad(Transform):
name: str = 'pad'
Expand All @@ -48,19 +75,32 @@ class Pad(Transform):
padding_mode: str = 'constant'


@dataclass
class PoseTopDownAffine(Transform):
    """Settings record whose ``name`` defaults to ``posetopdownaffine``.

    Each of scale / translate / rotation has a paired ``*_prob`` field; all
    three default to 1.0.

    NOTE(review): units (e.g. whether ``rotation`` is in degrees) are not
    visible in this block -- confirm against the consuming implementation.
    """

    name: str = 'posetopdownaffine'
    scale: List = field(default_factory=lambda: [0.75, 1.25])
    scale_prob: float = 1.0
    translate: float = 0.1
    translate_prob: float = 1.0
    rotation: int = 60
    rotation_prob: float = 1.0
    # Two-element list, both entries DEFAULT_IMG_SIZE.
    size: List = field(default_factory=lambda: [DEFAULT_IMG_SIZE] * 2)


@dataclass
class RandomCrop(Transform):
    """Settings record whose ``name`` defaults to ``randomcrop``."""

    name: str = "randomcrop"
    # Defaults to the module-level DEFAULT_IMG_SIZE constant.
    size: int = DEFAULT_IMG_SIZE


@dataclass
class RandomResizedCrop(Transform):
name: str = 'randomresizedcrop'
size: int = DEFAULT_IMG_SIZE
scale: List = field(default_factory=lambda: [0.08, 1.0])
ratio: List = field(default_factory=lambda: [0.75, 1.33])
interpolation: Optional[str] = 'bilinear'
class RandomErasing(Transform):
    """Settings record whose ``name`` defaults to ``randomerasing``.

    Fields:
        p: float, default 0.5.
        scale: two-element list, default ``[0.02, 0.33]``.
        ratio: two-element list, default ``[0.3, 3.3]``.
        value: optional int, default 0.
        inplace: bool, default False.

    The defaults mirror those of ``torchvision.transforms.RandomErasing``.
    NOTE(review): confirm the consuming trainer reads the key ``ratio``.
    """

    name: str = "randomerasing"
    p: float = 0.5
    scale: List = field(default_factory=lambda: [0.02, 0.33])
    # This field was previously declared as a second ``scale``. A repeated
    # annotation in a class body overwrites the earlier one, so ``scale``
    # ended up defaulting to [0.3, 3.3] and there was no ratio field at all.
    ratio: List = field(default_factory=lambda: [0.3, 3.3])
    value: Optional[int] = 0
    inplace: bool = False


@dataclass
Expand All @@ -69,6 +109,24 @@ class RandomHorizontalFlip(Transform):
p: float = 0.5


@dataclass
class RandomResize(Transform):
    """Settings record whose ``name`` defaults to ``randomresize``.

    NOTE(review): how ``stride`` and ``random_range`` combine with
    ``base_size`` is not visible in this block -- confirm against the
    consuming implementation.
    """

    name: str = 'randomresize'
    # Fresh [256, 256] list per instance.
    base_size: List = field(default_factory=lambda: [256, 256])
    stride: int = 32
    random_range: int = 4
    interpolation: str = 'bilinear'


@dataclass
class RandomResizedCrop(Transform):
    """Settings record whose ``name`` defaults to ``randomresizedcrop``.

    ``scale`` and ``ratio`` are two-element lists built per instance via
    ``default_factory``.
    """

    name: str = "randomresizedcrop"
    # Defaults to the module-level DEFAULT_IMG_SIZE constant.
    size: int = DEFAULT_IMG_SIZE
    scale: List = field(default_factory=lambda: [0.08, 1.0])
    ratio: List = field(default_factory=lambda: [0.75, 1.33])
    interpolation: Optional[str] = "bilinear"


@dataclass
class RandomVerticalFlip(Transform):
name: str = 'randomverticalflip'
Expand All @@ -81,6 +139,7 @@ class Resize(Transform):
size: List = field(default_factory=lambda: [DEFAULT_IMG_SIZE, DEFAULT_IMG_SIZE])
interpolation: Optional[str] = 'bilinear'
max_size: Optional[int] = None
resize_criteria: Optional[int] = None


class TrivialAugmentWide(Transform):
Expand Down Expand Up @@ -109,34 +168,35 @@ class RandomCutmix(Transform):
@dataclass
class ClassificationAugmentationConfig(AugmentationConfig):
img_size: int = 256
train: Train = field(default_factory=lambda: Train(
transforms=[RandomResizedCrop(size=256), RandomHorizontalFlip()],
mix_transforms=[RandomCutmix()]
))
inference: Inference = field(default_factory=lambda: Inference(
transforms=[Resize(size=[256, 256])]
))
train: Optional[List] = field(default_factory=lambda: [
RandomResizedCrop(size=256),
RandomHorizontalFlip(),
Mixing(mixup=[0.25, 1.0])
])
inference: Optional[List] = field(default_factory=lambda: [
Resize(size=[256, 256])
])


@dataclass
class SegmentationAugmentationConfig(AugmentationConfig):
img_size: int = 512
train: Train = field(default_factory=lambda: Train(
transforms=[RandomResizedCrop(size=512), RandomHorizontalFlip(), ColorJitter()],
mix_transforms=None
))
inference: Inference = field(default_factory=lambda: Inference(
transforms=[Resize(size=[512, 512])]
))
train: Optional[List] = field(default_factory=lambda: [
RandomResizedCrop(size=512),
RandomHorizontalFlip(),
ColorJitter()
])
inference: Optional[List] = field(default_factory=lambda: [
Resize(size=[512, 512])
])


@dataclass
class DetectionAugmentationConfig(AugmentationConfig):
img_size: int = 512
train: Train = field(default_factory=lambda: Train(
transforms=[Resize(size=[512, 512])],
mix_transforms=None
))
inference: Inference = field(default_factory=lambda: Inference(
transforms=[Resize(size=[512, 512])],
))
train: Optional[List] = field(default_factory=lambda: [
Resize(size=[512, 512])
])
inference: Optional[List] = field(default_factory=lambda: [
Resize(size=[512, 512])
])
74 changes: 49 additions & 25 deletions netspresso/trainer/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,87 @@
from netspresso.trainer.models.model import (
CheckpointConfig,
from netspresso.trainer.models.base import CheckpointConfig, ModelConfig
from netspresso.trainer.models.efficientformer import (
ClassificationEfficientFormerModelConfig,
DetectionEfficientFormerModelConfig,
SegmentationEfficientFormerModelConfig,
)
from netspresso.trainer.models.mixnet import (
ClassificationMixNetLargeModelConfig,
ClassificationMixNetMediumModelConfig,
ClassificationMixNetSmallModelConfig,
ClassificationMobileNetV3ModelConfig,
ClassificationMobileViTModelConfig,
ClassificationResNetModelConfig,
ClassificationViTModelConfig,
DetectionEfficientFormerModelConfig,
DetectionMixNetLargeModelConfig,
DetectionMixNetMediumModelConfig,
DetectionMixNetSmallModelConfig,
DetectionMobileNetV3ModelConfig,
DetectionResNetModelConfig,
DetectionYoloXModelConfig,
ModelConfig,
PIDNetModelConfig,
SegmentationEfficientFormerModelConfig,
SegmentationMixNetLargeModelConfig,
SegmentationMixNetMediumModelConfig,
SegmentationMixNetSmallModelConfig,
SegmentationMobileNetV3ModelConfig,
SegmentationResNetModelConfig,
SegmentationSegFormerModelConfig,
)
from netspresso.trainer.models.mobilenetv3 import (
ClassificationMobileNetV3LargeModelConfig,
ClassificationMobileNetV3SmallModelConfig,
DetectionMobileNetV3SmallModelConfig,
SegmentationMobileNetV3SmallModelConfig,
)
from netspresso.trainer.models.mobilevit import ClassificationMobileViTModelConfig
from netspresso.trainer.models.pidnet import PIDNetModelConfig
from netspresso.trainer.models.resnet import (
ClassificationResNet18ModelConfig,
ClassificationResNet34ModelConfig,
ClassificationResNet50ModelConfig,
DetectionResNet50ModelConfig,
SegmentationResNet50ModelConfig,
)
from netspresso.trainer.models.rtmpose import PoseEstimationMobileNetV3SmallModelConfig
from netspresso.trainer.models.segformer import SegmentationSegFormerB0ModelConfig
from netspresso.trainer.models.vit import ClassificationViTTinyModelConfig
from netspresso.trainer.models.yolox import (
DetectionYoloXLModelConfig,
DetectionYoloXMModelConfig,
DetectionYoloXSModelConfig,
DetectionYoloXXModelConfig,
)

CLASSIFICATION_MODELS = {
"EfficientFormer": ClassificationEfficientFormerModelConfig,
"MobileNetV3": ClassificationMobileNetV3ModelConfig,
"MobileNetV3_Small": ClassificationMobileNetV3SmallModelConfig,
"MobileNetV3_Large": ClassificationMobileNetV3LargeModelConfig,
"MobileViT": ClassificationMobileViTModelConfig,
"ResNet": ClassificationResNetModelConfig,
"ViT": ClassificationViTModelConfig,
"ResNet18": ClassificationResNet18ModelConfig,
"ResNet34": ClassificationResNet34ModelConfig,
"ResNet50": ClassificationResNet50ModelConfig,
"ViT_Tiny": ClassificationViTTinyModelConfig,
"MixNetS": ClassificationMixNetSmallModelConfig,
"MixNetM": ClassificationMixNetMediumModelConfig,
"MixNetL": ClassificationMixNetLargeModelConfig,
}

DETECTION_MODELS = {
"EfficientFormer": DetectionEfficientFormerModelConfig,
"YOLOX-S": DetectionYoloXModelConfig,
"ResNet": DetectionResNetModelConfig,
"MobileNetV3": DetectionMobileNetV3ModelConfig,
"MobileNetV3_Small": DetectionMobileNetV3SmallModelConfig,
"YOLOX-S": DetectionYoloXSModelConfig,
"YOLOX-M": DetectionYoloXMModelConfig,
"YOLOX-L": DetectionYoloXLModelConfig,
"YOLOX-X": DetectionYoloXXModelConfig,
"ResNet50": DetectionResNet50ModelConfig,
"MixNetL": DetectionMixNetLargeModelConfig,
"MixNetM": DetectionMixNetMediumModelConfig,
"MixNetS": DetectionMixNetSmallModelConfig,
}

SEGMENTATION_MODELS = {
"EfficientFormer": SegmentationEfficientFormerModelConfig,
"MobileNetV3": SegmentationMobileNetV3ModelConfig,
"ResNet": SegmentationResNetModelConfig,
"SegFormer": SegmentationSegFormerModelConfig,
"MobileNetV3_Small": SegmentationMobileNetV3SmallModelConfig,
"ResNet50": SegmentationResNet50ModelConfig,
"SegFormer-B0": SegmentationSegFormerB0ModelConfig,
"MixNetS": SegmentationMixNetSmallModelConfig,
"MixNetM": SegmentationMixNetMediumModelConfig,
"MixNetL": SegmentationMixNetLargeModelConfig,
"PIDNet": PIDNetModelConfig,
}

# Pose-estimation registry: model display name -> model config class.
POSEESTIMATION_MODELS = {"MobileNetV3_Small": PoseEstimationMobileNetV3SmallModelConfig}


__all__ = [
"CLASSIFICATION_MODELS",
Expand Down
35 changes: 35 additions & 0 deletions netspresso/trainer/models/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from omegaconf import MISSING


@dataclass
class ArchitectureConfig:
    """Architecture description given either whole or as separate parts.

    Exactly one of ``full`` and ``backbone`` must be truthy. A value of
    ``None`` or an empty dict counts as "not given". ``neck`` and ``head``
    are optional and are not validated here.

    Raises:
        AssertionError: if ``full`` and ``backbone`` are both given or both
            missing.
    """

    full: Optional[Dict[str, Any]] = None
    backbone: Optional[Dict[str, Any]] = None
    neck: Optional[Dict[str, Any]] = None
    head: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Reject configs where ``full`` and ``backbone`` are both set or both unset."""
        # Raised explicitly rather than with an ``assert`` statement, which
        # ``python -O`` strips and would let invalid configs through.
        # AssertionError is kept so existing handlers see the same type.
        if bool(self.full) == bool(self.backbone):
            raise AssertionError("Only one of full or backbone should be given.")


@dataclass
class CheckpointConfig:
    """Checkpoint-related settings: two boolean switches and three optional paths.

    Each path field accepts a ``Path``, a ``str``, or ``None`` and defaults
    to ``None``.

    NOTE(review): how the flags interact with the paths (e.g. whether
    ``path`` overrides ``use_pretrained``) is not visible in this block --
    confirm against the consuming code.
    """

    use_pretrained: bool = True
    load_head: bool = False
    # Union[Path, str, None] is the same type as Optional[Union[Path, str]].
    path: Union[Path, str, None] = None
    fx_model_path: Union[Path, str, None] = None
    optimizer_path: Union[Path, str, None] = None


@dataclass
class ModelConfig:
    """Top-level model settings: task, name, checkpoint, architecture, losses.

    ``task`` and ``name`` default to omegaconf's ``MISSING`` marker.

    NOTE(review): ``ArchitectureConfig()`` built with no arguments fails its
    own full/backbone check, so callers presumably always supply
    ``architecture`` -- confirm that every subclass overrides it.
    """

    task: str = MISSING
    name: str = MISSING
    # The class itself is the zero-argument factory.
    checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
    freeze_backbone: bool = False
    architecture: ArchitectureConfig = field(default_factory=ArchitectureConfig)
    losses: Optional[List[Dict[str, Any]]] = None
Loading

0 comments on commit 485e6ca

Please sign in to comment.