Skip to content

Commit

Permalink
[Enhancement] Speedup formatting by replacing np.transpose with torch…
Browse files Browse the repository at this point in the history
….permute (open-mmlab#1719)
  • Loading branch information
gaotongxiao authored Feb 16, 2023
1 parent f820470 commit df0be64
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 6 deletions.
39 changes: 33 additions & 6 deletions mmocr/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,17 @@ def transform(self, results: dict) -> dict:
img = results['img']
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
packed_results['inputs'] = to_tensor(img)
# A simple trick to speedup formatting by 3-5 times when
# OMP_NUM_THREADS != 1
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
# for more details
if img.flags.c_contiguous:
img = to_tensor(img)
img = img.permute(2, 0, 1).contiguous()
else:
img = np.ascontiguousarray(img.transpose(2, 0, 1))
img = to_tensor(img)
packed_results['inputs'] = img

data_sample = TextDetDataSample()
instance_data = InstanceData()
Expand Down Expand Up @@ -174,8 +183,17 @@ def transform(self, results: dict) -> dict:
img = results['img']
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
packed_results['inputs'] = to_tensor(img)
# A simple trick to speedup formatting by 3-5 times when
# OMP_NUM_THREADS != 1
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
# for more details
if img.flags.c_contiguous:
img = to_tensor(img)
img = img.permute(2, 0, 1).contiguous()
else:
img = np.ascontiguousarray(img.transpose(2, 0, 1))
img = to_tensor(img)
packed_results['inputs'] = img

data_sample = TextRecogDataSample()
gt_text = LabelData()
Expand Down Expand Up @@ -272,8 +290,17 @@ def transform(self, results: dict) -> dict:
img = results['img']
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
packed_results['inputs'] = to_tensor(img)
# A simple trick to speedup formatting by 3-5 times when
# OMP_NUM_THREADS != 1
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
# for more details
if img.flags.c_contiguous:
img = to_tensor(img)
img = img.permute(2, 0, 1).contiguous()
else:
img = np.ascontiguousarray(img.transpose(2, 0, 1))
img = to_tensor(img)
packed_results['inputs'] = img
else:
packed_results['inputs'] = torch.FloatTensor().reshape(0, 0, 0)

Expand Down
22 changes: 22 additions & 0 deletions tests/test_datasets/test_transforms/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,17 @@ def test_packdetinput(self):
transform = PackTextDetInputs()
results = transform(copy.deepcopy(datainfo))
self.assertIn('inputs', results)
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10))
self.assertIn('data_samples', results)

# test non-contiugous img
nc_datainfo = copy.deepcopy(datainfo)
nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
results = transform(nc_datainfo)
self.assertIn('inputs', results)
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))

data_sample = results['data_samples']
self.assertIn('bboxes', data_sample.gt_instances)
self.assertIsInstance(data_sample.gt_instances.bboxes, torch.Tensor)
Expand Down Expand Up @@ -115,6 +123,13 @@ def test_packrecogtinput(self):
self.assertIn('valid_ratio', data_sample)
self.assertIn('pad_shape', data_sample)

# test non-contiugous img
nc_datainfo = copy.deepcopy(datainfo)
nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
results = transform(nc_datainfo)
self.assertIn('inputs', results)
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))

transform = PackTextRecogInputs(meta_keys=('img_path', ))
results = transform(copy.deepcopy(datainfo))
self.assertIn('inputs', results)
Expand Down Expand Up @@ -174,6 +189,13 @@ def test_transform(self):
torch.int64)
self.assertIsInstance(data_sample.gt_instances.texts, list)

# test non-contiugous img
nc_datainfo = copy.deepcopy(datainfo)
nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
results = self.transform(nc_datainfo)
self.assertIn('inputs', results)
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))

transform = PackKIEInputs(meta_keys=('img_path', ))
results = transform(copy.deepcopy(datainfo))
self.assertIn('inputs', results)
Expand Down

0 comments on commit df0be64

Please sign in to comment.