diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 30b82ee79296c..e1ea3389f0504 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -3485,7 +3485,10 @@ def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True): return out def _get_example_tensor_lists( - self, include_list_of_lists=True, include_requires_grad=True + self, + include_list_of_lists=True, + include_requires_grad=True, + include_inner_dim_size_1=False, ): def _make_tensor( *shape, include_requires_grad=include_requires_grad, requires_grad=True @@ -3534,6 +3537,24 @@ def _make_tensor( ] ) + if include_inner_dim_size_1: + example_lists.append( + [ + _make_tensor(2, 1), + _make_tensor(3, 1, requires_grad=False), + _make_tensor(4, 1, requires_grad=False), + _make_tensor(6, 1), + ] # (B, *, 1) + ) + example_lists.append( + [ + _make_tensor(2, 5, 1), + _make_tensor(3, 5, 1, requires_grad=False), + _make_tensor(4, 5, 1, requires_grad=False), + _make_tensor(6, 5, 1), + ] # (B, *, 5, 1) + ) + return example_lists def test_tensor_attributes(self, device): @@ -4125,7 +4146,9 @@ def test_jagged_op_dim_reduce_ragged_idx_1( op_name = get_op_name(func) tensor_lists = self._get_example_tensor_lists( - include_list_of_lists=False, include_requires_grad=components_require_grad + include_list_of_lists=False, + include_requires_grad=components_require_grad, + include_inner_dim_size_1=True, # (B, *, 1) ) reduce_dim = (1,) # ragged