Skip to content

Commit

Permalink
[Feature] Support auto import modules from registry. (open-mmlab#1731)
Browse files Browse the repository at this point in the history
* [Feature] Support auto import modules from registry.

* limit mmdet version

* location parrent dir if it not exist
  • Loading branch information
Harold-lkk authored Feb 17, 2023
1 parent df0be64 commit 1127240
Show file tree
Hide file tree
Showing 16 changed files with 98 additions and 61 deletions.
2 changes: 1 addition & 1 deletion docs/en/get_started/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,6 @@ MMOCR has different version requirements on MMEngine, MMCV and MMDetection at ea

| MMOCR | MMEngine | MMCV | MMDetection |
| -------------- | --------------------------- | -------------------------- | --------------------------- |
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc5 \<= mmdet \< 3.1.0 |
| 1.0.0rc\[4-5\] | 0.1.0 \<= mmengine \< 1.0.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
| 1.0.0rc\[0-3\] | 0.0.0 \<= mmengine \< 0.2.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
2 changes: 1 addition & 1 deletion docs/zh_cn/get_started/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,6 @@ docker run --gpus all --shm-size=8g -it -v {实际数据目录}:/mmocr/data mmoc

| MMOCR | MMEngine | MMCV | MMDetection |
| -------------- | --------------------------- | -------------------------- | --------------------------- |
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc5 \<= mmdet \< 3.1.0 |
| 1.0.0rc\[4-5\] | 0.1.0 \<= mmengine \< 1.0.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
| 1.0.0rc\[0-3\] | 0.0.0 \<= mmengine \< 0.2.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
2 changes: 1 addition & 1 deletion mmocr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
f'Please install mmengine>={mmengine_minimum_version}, ' \
f'<{mmengine_maximum_version}.'

mmdet_minimum_version = '3.0.0rc0'
mmdet_minimum_version = '3.0.0rc5'
mmdet_maximum_version = '3.1.0'
mmdet_version = digit_version(mmdet.__version__)

Expand Down
5 changes: 3 additions & 2 deletions mmocr/apis/inferencers/base_mmocr_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import numpy as np
from mmengine.dataset import Compose
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData
from torch import Tensor

from mmocr.utils import ConfigType, register_all_modules
from mmocr.utils import ConfigType

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
Expand Down Expand Up @@ -58,7 +59,7 @@ def __init__(self,
# A global counter tracking the number of images processed, for
# naming of the output images
self.num_visualized_imgs = 0
register_all_modules()
init_default_scope(scope)
super().__init__(
model=model, weights=weights, device=device, scope=scope)

Expand Down
94 changes: 73 additions & 21 deletions mmocr/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,51 +32,103 @@
from mmengine.registry import Registry

# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS)
RUNNERS = Registry(
'runner',
parent=MMENGINE_RUNNERS,
# TODO: update the location when mmocr has its own runner
locations=['mmocr.engine'])
# manage runner constructors that define how to initialize runners
RUNNER_CONSTRUCTORS = Registry(
'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS)
'runner constructor',
parent=MMENGINE_RUNNER_CONSTRUCTORS,
# TODO: update the location when mmocr has its own runner constructor
locations=['mmocr.engine'])
# manage all kinds of loops like `EpochBasedTrainLoop`
LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
LOOPS = Registry(
'loop',
parent=MMENGINE_LOOPS,
# TODO: update the location when mmocr has its own loop
locations=['mmocr.engine'])
# manage all kinds of hooks like `CheckpointHook`
HOOKS = Registry('hook', parent=MMENGINE_HOOKS)
HOOKS = Registry(
'hook', parent=MMENGINE_HOOKS, locations=['mmocr.engine.hooks'])

# manage data-related modules
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
DATASETS = Registry(
'dataset', parent=MMENGINE_DATASETS, locations=['mmocr.datasets'])
DATA_SAMPLERS = Registry(
'data sampler',
parent=MMENGINE_DATA_SAMPLERS,
locations=['mmocr.datasets.samplers'])
TRANSFORMS = Registry(
'transform',
parent=MMENGINE_TRANSFORMS,
locations=['mmocr.datasets.transforms'])

# manage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model', parent=MMENGINE_MODELS)
MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmocr.models'])
# manage all kinds of model wrappers like 'MMDistributedDataParallel'
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
MODEL_WRAPPERS = Registry(
'model wrapper',
parent=MMENGINE_MODEL_WRAPPERS,
locations=['mmocr.models'])
# manage all kinds of weight initialization modules like `Uniform`
WEIGHT_INITIALIZERS = Registry(
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
'weight initializer',
parent=MMENGINE_WEIGHT_INITIALIZERS,
locations=['mmocr.models'])

# manage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
OPTIMIZERS = Registry(
'optimizer',
parent=MMENGINE_OPTIMIZERS,
# TODO: update the location when mmocr has its own optimizer
locations=['mmocr.engine'])
# manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optim wrapper', parent=MMENGINE_OPTIM_WRAPPERS)
OPTIM_WRAPPERS = Registry(
'optimizer wrapper',
parent=MMENGINE_OPTIM_WRAPPERS,
# TODO: update the location when mmocr has its own optimizer wrapper
locations=['mmocr.engine'])
# manage constructors that customize the optimization hyperparameters.
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
'optimizer constructor', parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
'optimizer constructor',
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS,
# TODO: update the location when mmocr has its own optimizer constructor
locations=['mmocr.engine'])
# manage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry(
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)

'parameter scheduler',
parent=MMENGINE_PARAM_SCHEDULERS,
# TODO: update the location when mmocr has its own parameter scheduler
locations=['mmocr.engine'])
# manage all kinds of metrics
METRICS = Registry('metric', parent=MMENGINE_METRICS)
METRICS = Registry(
'metric', parent=MMENGINE_METRICS, locations=['mmocr.evaluation.metrics'])
# manage evaluator
EVALUATOR = Registry('evaluator', parent=MMENGINE_EVALUATOR)
EVALUATOR = Registry(
'evaluator',
parent=MMENGINE_EVALUATOR,
locations=['mmocr.evaluation.evaluator'])

# manage task-specific modules like anchor generators and box coders
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
TASK_UTILS = Registry(
'task util', parent=MMENGINE_TASK_UTILS, locations=['mmocr.models'])

# manage visualizer
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
VISUALIZERS = Registry(
'visualizer',
parent=MMENGINE_VISUALIZERS,
locations=['mmocr.visualization'])
# manage visualizer backend
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)
VISBACKENDS = Registry(
'visualizer backend',
parent=MMENGINE_VISBACKENDS,
locations=['mmocr.visualization'])

# manage logprocessor
LOG_PROCESSORS = Registry('log_processor', parent=MMENGINE_LOG_PROCESSORS)
LOG_PROCESSORS = Registry(
'logger processor',
parent=MMENGINE_LOG_PROCESSORS,
# TODO: update the location when mmocr has its own log processor
locations=['mmocr.engine'])
6 changes: 3 additions & 3 deletions requirements/mminstall.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
mmcv>==2.0.0rc1,<2.1.0
mmdet>=3.0.0rc0,<3.1.0
mmengine>= 0.1.0, <1.0.0
mmcv>==2.0.0rc4,<2.1.0
mmdet>=3.0.0rc5,<3.1.0
mmengine>= 0.5.0, <1.0.0
5 changes: 3 additions & 2 deletions tests/test_datasets/test_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from unittest import TestCase
from unittest.mock import MagicMock

from mmengine.registry import init_default_scope

from mmocr.datasets import ConcatDataset, OCRDataset
from mmocr.registry import TRANSFORMS
from mmocr.utils import register_all_modules


class TestConcatDataset(TestCase):
Expand All @@ -22,7 +23,7 @@ def __call__(self, *args, **kwargs):

def setUp(self):

register_all_modules()
init_default_scope('mmocr')
dataset = OCRDataset

# create dataset_a
Expand Down
5 changes: 3 additions & 2 deletions tests/test_models/test_textdet/test_detectors/test_drrg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@
import numpy as np
import torch
from mmengine.config import Config, ConfigDict
from mmengine.registry import init_default_scope

from mmocr.registry import MODELS
from mmocr.testing.data import create_dummy_textdet_inputs
from mmocr.utils import register_all_modules


class TestDRRG(unittest.TestCase):

def setUp(self):
cfg_path = 'textdet/drrg/drrg_resnet50_fpn-unet_1200e_ctw1500.py'
self.model_cfg = self._get_detector_cfg(cfg_path)
register_all_modules()
cfg = self._get_config_module(cfg_path)
init_default_scope(cfg.get('default_scope', 'mmocr'))
self.model = MODELS.build(self.model_cfg)
self.inputs = create_dummy_textdet_inputs(input_shape=(1, 3, 224, 224))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
from mmdet.structures import DetDataSample
from mmdet.testing import demo_mm_inputs
from mmengine.config import Config
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData

from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from mmocr.utils import register_all_modules


class TestMMDetWrapper(unittest.TestCase):

def setUp(self):
register_all_modules()
init_default_scope('mmocr')
model_cfg_fcos = dict(
type='MMDetWrapper',
cfg=dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
from unittest import TestCase

import torch
from mmengine.registry import init_default_scope

from mmocr.models.textrecog.backbones import ResNet
from mmocr.utils import register_all_modules


class TestResNet(TestCase):

def setUp(self) -> None:
self.img = torch.rand(1, 3, 32, 100)
register_all_modules()
init_default_scope('mmocr')

def test_resnet45_aster(self):
resnet45_aster = ResNet(
Expand Down
5 changes: 2 additions & 3 deletions tools/analysis_tools/browse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import numpy as np
from mmengine.config import Config, DictAction
from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar
from mmengine.visualization import Visualizer

from mmocr.registry import DATASETS, VISUALIZERS
from mmocr.utils import register_all_modules


# TODO: Support for printing the change in key of results
Expand Down Expand Up @@ -331,8 +331,7 @@ def main():
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

# register all modules in mmyolo into the registries
register_all_modules()
init_default_scope(cfg.get('default_scope', 'mmocr'))

dataset_cfg, visualizer_cfg = obtain_dataset_cfg(cfg, args.phase,
args.mode, args.task)
Expand Down
5 changes: 2 additions & 3 deletions tools/analysis_tools/get_flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table
from mmengine import Config
from mmengine.registry import init_default_scope

from mmocr.registry import MODELS
from mmocr.utils import register_all_modules

register_all_modules()


def parse_args():
Expand Down Expand Up @@ -38,6 +36,7 @@ def main():
input_shape = (1, 3, h, w)

cfg = Config.fromfile(args.config)
init_default_scope(cfg.get('default_scope', 'mmocr'))
model = MODELS.build(cfg.model)

flops = FlopCountAnalysis(model, torch.ones(input_shape))
Expand Down
6 changes: 2 additions & 4 deletions tools/analysis_tools/offline_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import mmengine
from mmengine.config import Config, DictAction
from mmengine.evaluator import Evaluator

from mmocr.utils import register_all_modules
from mmengine.registry import init_default_scope


def parse_args():
Expand All @@ -33,10 +32,9 @@ def parse_args():
def main():
args = parse_args()

register_all_modules()

# load config
cfg = Config.fromfile(args.config)
init_default_scope(cfg.get('default_scope', 'mmocr'))
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

Expand Down
2 changes: 0 additions & 2 deletions tools/dataset_converters/prepare_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import warnings

from mmocr.datasets.preparers import DatasetPreparer
from mmocr.utils import register_all_modules


def parse_args():
Expand Down Expand Up @@ -39,7 +38,6 @@ def parse_args():

def main():
args = parse_args()
register_all_modules()
for dataset in args.datasets:
if not osp.isdir(osp.join(args.dataset_zoo_path, dataset)):
warnings.warn(f'{dataset} is not supported yet. Please check '
Expand Down
6 changes: 0 additions & 6 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from mmengine.registry import RUNNERS
from mmengine.runner import Runner

from mmocr.utils import register_all_modules


def parse_args():
parser = argparse.ArgumentParser(description='Test (and eval) a model')
Expand Down Expand Up @@ -80,10 +78,6 @@ def trigger_visualization_hook(cfg, args):
def main():
args = parse_args()

# register all modules in mmocr into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)

# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
Expand Down
6 changes: 0 additions & 6 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from mmengine.registry import RUNNERS
from mmengine.runner import Runner

from mmocr.utils import register_all_modules


def parse_args():
parser = argparse.ArgumentParser(description='Train a model')
Expand Down Expand Up @@ -54,10 +52,6 @@ def parse_args():
def main():
args = parse_args()

# register all modules in mmdet into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)

# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
Expand Down

0 comments on commit 1127240

Please sign in to comment.