From e782918b8eeffc465acff40f0c55c8d80bc20ce2 Mon Sep 17 00:00:00 2001 From: Janani Sriram Date: Tue, 23 Jul 2024 13:45:15 -0700 Subject: [PATCH] [NestedTensor] Add example NestedTensor objects with inner dimension of size 1 to tests reducing along jagged dimension for NestedTensor (#131516) Add example `NestedTensor`s with inner dimension of size `1` to `_get_example_tensor_lists` with `include_inner_dim_size_1=True`. This diff creates `NestedTensor`s of sizes `(B, *, 1)` and `(B, *, 5, 1)`, ensuring that the current implementations of jagged reductions for `sum` and `mean` hold for tensors of effective shape `(B, *)` and `(B, *, 5)`. Differential Revision: [D59846023](https://our.internmc.facebook.com/intern/diff/D59846023/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/131516 Approved by: https://github.com/davidberard98 --- test/test_nestedtensor.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 30b82ee79296c..e1ea3389f0504 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -3485,7 +3485,10 @@ def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True): return out def _get_example_tensor_lists( - self, include_list_of_lists=True, include_requires_grad=True + self, + include_list_of_lists=True, + include_requires_grad=True, + include_inner_dim_size_1=False, ): def _make_tensor( *shape, include_requires_grad=include_requires_grad, requires_grad=True @@ -3534,6 +3537,24 @@ def _make_tensor( ] ) + if include_inner_dim_size_1: + example_lists.append( + [ + _make_tensor(2, 1), + _make_tensor(3, 1, requires_grad=False), + _make_tensor(4, 1, requires_grad=False), + _make_tensor(6, 1), + ] # (B, *, 1) + ) + example_lists.append( + [ + _make_tensor(2, 5, 1), + _make_tensor(3, 5, 1, requires_grad=False), + _make_tensor(4, 5, 1, requires_grad=False), + _make_tensor(6, 5, 1), + ] # (B, *, 5, 1) + ) + return example_lists def test_tensor_attributes(self, device): @@ -4125,7 +4146,9 @@ def test_jagged_op_dim_reduce_ragged_idx_1( op_name = get_op_name(func) tensor_lists = self._get_example_tensor_lists( - include_list_of_lists=False, include_requires_grad=components_require_grad + include_list_of_lists=False, + include_requires_grad=components_require_grad, + include_inner_dim_size_1=True, # (B, *, 1) ) reduce_dim = (1,) # ragged