Skip to content

Commit

Permalink
#253 Add dataclass for trainer config (#254)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Jun 20, 2024
1 parent 1fb6b1b commit fd8e9e9
Show file tree
Hide file tree
Showing 15 changed files with 1,574 additions and 55 deletions.
21 changes: 20 additions & 1 deletion netspresso/trainer/augmentations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from netspresso_trainer.cfg.augmentation import (
from netspresso.trainer.augmentations.augmentation import (
AugmentationConfig,
ClassificationAugmentationConfig,
ColorJitter,
DetectionAugmentationConfig,
Inference,
Pad,
RandomCrop,
RandomCutmix,
Expand All @@ -8,9 +12,19 @@
RandomResizedCrop,
RandomVerticalFlip,
Resize,
SegmentationAugmentationConfig,
Train,
Transform,
TrivialAugmentWide,
)

AUGMENTATION_CONFIG_TYPE = {
"classification": ClassificationAugmentationConfig,
"detection": DetectionAugmentationConfig,
"segmentation": SegmentationAugmentationConfig,
}


__all__ = [
"ColorJitter",
"Pad",
Expand All @@ -22,4 +36,9 @@
"TrivialAugmentWide",
"RandomMixup",
"RandomCutmix",
"Inference",
"Train",
"Transform",
"AugmentationConfig",
"AUGMENTATION_CONFIG_TYPE",
]
142 changes: 142 additions & 0 deletions netspresso/trainer/augmentations/augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Union

from omegaconf import MISSING, MissingMandatoryValue

DEFAULT_IMG_SIZE = 256


@dataclass
class Transform:
name: str = MISSING


@dataclass
class Train:
transforms: Optional[List] = None
mix_transforms: Optional[List] = None


@dataclass
class Inference:
transforms: Optional[List] = None


@dataclass
class AugmentationConfig:
img_size: int = DEFAULT_IMG_SIZE
train: Train = field(default_factory=lambda: Train())
inference: Inference = field(default_factory=lambda: Inference())


@dataclass
class ColorJitter(Transform):
name: str = 'colorjitter'
brightness: Optional[float] = 0.25
contrast: Optional[float] = 0.25
saturation: Optional[float] = 0.25
hue: Optional[float] = 0.1
p: Optional[float] = 0.5


@dataclass
class Pad(Transform):
name: str = 'pad'
padding: int = 0
fill: int = 0
padding_mode: str = 'constant'


@dataclass
class RandomCrop(Transform):
name: str = 'randomcrop'
size: int = DEFAULT_IMG_SIZE


@dataclass
class RandomResizedCrop(Transform):
name: str = 'randomresizedcrop'
size: int = DEFAULT_IMG_SIZE
scale: List = field(default_factory=lambda: [0.08, 1.0])
ratio: List = field(default_factory=lambda: [0.75, 1.33])
interpolation: Optional[str] = 'bilinear'


@dataclass
class RandomHorizontalFlip(Transform):
name: str = 'randomhorizontalflip'
p: float = 0.5


@dataclass
class RandomVerticalFlip(Transform):
name: str = 'randomverticalflip'
p: float = 0.5


@dataclass
class Resize(Transform):
name: str = 'resize'
size: List = field(default_factory=lambda: [DEFAULT_IMG_SIZE, DEFAULT_IMG_SIZE])
interpolation: Optional[str] = 'bilinear'
max_size: Optional[int] = None


class TrivialAugmentWide(Transform):
name: str = 'trivialaugmentwide'
num_magnitude_bins: int = 31
interpolation: str = 'bilinear'
fill: Optional[int] = None


@dataclass
class RandomMixup(Transform):
name: str = 'mixup'
alpha: float = 0.2
p: float = 1.0
inplace: bool = False


@dataclass
class RandomCutmix(Transform):
name: str = 'cutmix'
alpha: float = 1.0
p: float = 1.0
inplace: bool = False


@dataclass
class ClassificationAugmentationConfig(AugmentationConfig):
img_size: int = 256
train: Train = field(default_factory=lambda: Train(
transforms=[RandomResizedCrop(size=256), RandomHorizontalFlip()],
mix_transforms=[RandomCutmix()]
))
inference: Inference = field(default_factory=lambda: Inference(
transforms=[Resize(size=[256, 256])]
))


@dataclass
class SegmentationAugmentationConfig(AugmentationConfig):
img_size: int = 512
train: Train = field(default_factory=lambda: Train(
transforms=[RandomResizedCrop(size=512), RandomHorizontalFlip(), ColorJitter()],
mix_transforms=None
))
inference: Inference = field(default_factory=lambda: Inference(
transforms=[Resize(size=[512, 512])]
))


@dataclass
class DetectionAugmentationConfig(AugmentationConfig):
img_size: int = 512
train: Train = field(default_factory=lambda: Train(
transforms=[Resize(size=[512, 512])],
mix_transforms=None
))
inference: Inference = field(default_factory=lambda: Inference(
transforms=[Resize(size=[512, 512])],
))
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from netspresso_trainer.cfg import (
from netspresso.trainer.data.data import (
DatasetConfig,
ImageLabelPathConfig,
LocalClassificationDatasetConfig,
LocalDetectionDatasetConfig,
LocalSegmentationDatasetConfig,
PathConfig,
)

DATA_CONFIG_TYPE = {
"classification": LocalClassificationDatasetConfig,
"detection": LocalDetectionDatasetConfig,
"segmentation": LocalSegmentationDatasetConfig,
}

__all__ = ["ImageLabelPathConfig", "PathConfig", "DATA_CONFIG_TYPE", "DatasetConfig"]
Loading

0 comments on commit fd8e9e9

Please sign in to comment.