diff --git a/mmocr/datasets/transforms/formatting.py b/mmocr/datasets/transforms/formatting.py
index 1649850e6..b9b71437a 100644
--- a/mmocr/datasets/transforms/formatting.py
+++ b/mmocr/datasets/transforms/formatting.py
@@ -90,8 +90,17 @@ def transform(self, results: dict) -> dict:
             img = results['img']
             if len(img.shape) < 3:
                 img = np.expand_dims(img, -1)
-            img = np.ascontiguousarray(img.transpose(2, 0, 1))
-            packed_results['inputs'] = to_tensor(img)
+            # A simple trick to speed up formatting by 3-5 times when
+            # OMP_NUM_THREADS != 1
+            # Refer to https://github.com/open-mmlab/mmdetection/pull/9533
+            # for more details
+            if img.flags.c_contiguous:
+                img = to_tensor(img)
+                img = img.permute(2, 0, 1).contiguous()
+            else:
+                img = np.ascontiguousarray(img.transpose(2, 0, 1))
+                img = to_tensor(img)
+            packed_results['inputs'] = img
 
         data_sample = TextDetDataSample()
         instance_data = InstanceData()
@@ -174,8 +183,17 @@ def transform(self, results: dict) -> dict:
             img = results['img']
             if len(img.shape) < 3:
                 img = np.expand_dims(img, -1)
-            img = np.ascontiguousarray(img.transpose(2, 0, 1))
-            packed_results['inputs'] = to_tensor(img)
+            # A simple trick to speed up formatting by 3-5 times when
+            # OMP_NUM_THREADS != 1
+            # Refer to https://github.com/open-mmlab/mmdetection/pull/9533
+            # for more details
+            if img.flags.c_contiguous:
+                img = to_tensor(img)
+                img = img.permute(2, 0, 1).contiguous()
+            else:
+                img = np.ascontiguousarray(img.transpose(2, 0, 1))
+                img = to_tensor(img)
+            packed_results['inputs'] = img
 
         data_sample = TextRecogDataSample()
         gt_text = LabelData()
@@ -272,8 +290,17 @@ def transform(self, results: dict) -> dict:
             img = results['img']
             if len(img.shape) < 3:
                 img = np.expand_dims(img, -1)
-            img = np.ascontiguousarray(img.transpose(2, 0, 1))
-            packed_results['inputs'] = to_tensor(img)
+            # A simple trick to speed up formatting by 3-5 times when
+            # OMP_NUM_THREADS != 1
+            # Refer to https://github.com/open-mmlab/mmdetection/pull/9533
+            # for more details
+            if img.flags.c_contiguous:
+                img = to_tensor(img)
+                img = img.permute(2, 0, 1).contiguous()
+            else:
+                img = np.ascontiguousarray(img.transpose(2, 0, 1))
+                img = to_tensor(img)
+            packed_results['inputs'] = img
         else:
             packed_results['inputs'] = torch.FloatTensor().reshape(0, 0, 0)
 
diff --git a/tests/test_datasets/test_transforms/test_formatting.py b/tests/test_datasets/test_transforms/test_formatting.py
index a29eecf11..21e9d10f2 100644
--- a/tests/test_datasets/test_transforms/test_formatting.py
+++ b/tests/test_datasets/test_transforms/test_formatting.py
@@ -36,9 +36,17 @@ def test_packdetinput(self):
         transform = PackTextDetInputs()
         results = transform(copy.deepcopy(datainfo))
         self.assertIn('inputs', results)
+        self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
         self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10))
         self.assertIn('data_samples', results)
 
+        # test non-contiguous img
+        nc_datainfo = copy.deepcopy(datainfo)
+        nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
+        results = transform(nc_datainfo)
+        self.assertIn('inputs', results)
+        self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
+
         data_sample = results['data_samples']
         self.assertIn('bboxes', data_sample.gt_instances)
         self.assertIsInstance(data_sample.gt_instances.bboxes, torch.Tensor)
@@ -115,6 +123,13 @@ def test_packrecogtinput(self):
         self.assertIn('valid_ratio', data_sample)
         self.assertIn('pad_shape', data_sample)
 
+        # test non-contiguous img
+        nc_datainfo = copy.deepcopy(datainfo)
+        nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
+        results = transform(nc_datainfo)
+        self.assertIn('inputs', results)
+        self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
+
         transform = PackTextRecogInputs(meta_keys=('img_path', ))
         results = transform(copy.deepcopy(datainfo))
         self.assertIn('inputs', results)
@@ -174,6 +189,13 @@ def test_transform(self):
                          torch.int64)
         self.assertIsInstance(data_sample.gt_instances.texts, list)
 
+        # test non-contiguous img
+        nc_datainfo = copy.deepcopy(datainfo)
+        nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
+        results = self.transform(nc_datainfo)
+        self.assertIn('inputs', results)
+        self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
+
         transform = PackKIEInputs(meta_keys=('img_path', ))
         results = transform(copy.deepcopy(datainfo))
         self.assertIn('inputs', results)
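
For reference, below is a minimal standalone sketch of the trick the three formatting hunks apply. The function name `pack_image` and the asserts at the bottom are illustrative only, and `torch.from_numpy` stands in for mmcv's `to_tensor`; neither is part of this patch. The idea: when the HWC array is already C-contiguous, convert to a tensor first and let torch's `.contiguous()` perform the HWC-to-CHW copy, which can use multiple threads when OMP_NUM_THREADS != 1, whereas `np.ascontiguousarray` copies single-threaded.

```python
import numpy as np
import torch


def pack_image(img: np.ndarray) -> torch.Tensor:
    """Format an HWC image array as a contiguous CHW tensor (sketch)."""
    if len(img.shape) < 3:
        img = np.expand_dims(img, -1)
    if img.flags.c_contiguous:
        # Fast path: convert first, then permute and copy in torch.
        # torch's .contiguous() copy is multithreaded, while
        # np.ascontiguousarray runs on a single thread.
        return torch.from_numpy(img).permute(2, 0, 1).contiguous()
    # Slow path: non-contiguous input (torch.from_numpy would reject a
    # view with negative or transposed strides), so copy in numpy first.
    return torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))


# Both paths yield the same layout; a transposed view is the simplest
# non-contiguous input, which is why the new unit tests use it.
img = np.random.randint(0, 255, (10, 10, 3), dtype=np.uint8)
assert pack_image(img).shape == (3, 10, 10)
gray = np.random.randint(0, 255, (10, 10), dtype=np.uint8)
assert pack_image(gray.transpose(1, 0)).shape == (1, 10, 10)
```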