Skip to content

Commit

Permalink
[NestedTensor] Add example NestedTensor objects with inner dimension …
Browse files Browse the repository at this point in the history
…of size 1 to tests reducing along jagged dimension for NestedTensor (pytorch#131516)

Add example `NestedTensor`s with inner dimension of size `1` to `_get_example_tensor_lists` with `include_inner_dim_size_1=True`. This diff creates `NestedTensor`s of sizes `(B, *, 1)` and `(B, *, 5, 1)`, ensuring that the current implementations of jagged reductions for `sum` and `mean` hold for tensors of effective shape `(B, *)` and `(B, *, 5)`.

Differential Revision: [D59846023](https://our.internmc.facebook.com/intern/diff/D59846023/)
Pull Request resolved: pytorch#131516
Approved by: https://github.com/davidberard98
  • Loading branch information
jananisriram authored and pytorchmergebot committed Jul 24, 2024
1 parent e9db1b0 commit e782918
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions test/test_nestedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3485,7 +3485,10 @@ def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True):
return out

def _get_example_tensor_lists(
self, include_list_of_lists=True, include_requires_grad=True
self,
include_list_of_lists=True,
include_requires_grad=True,
include_inner_dim_size_1=False,
):
def _make_tensor(
*shape, include_requires_grad=include_requires_grad, requires_grad=True
Expand Down Expand Up @@ -3534,6 +3537,24 @@ def _make_tensor(
]
)

if include_inner_dim_size_1:
example_lists.append(
[
_make_tensor(2, 1),
_make_tensor(3, 1, requires_grad=False),
_make_tensor(4, 1, requires_grad=False),
_make_tensor(6, 1),
] # (B, *, 1)
)
example_lists.append(
[
_make_tensor(2, 5, 1),
_make_tensor(3, 5, 1, requires_grad=False),
_make_tensor(4, 5, 1, requires_grad=False),
_make_tensor(6, 5, 1),
] # (B, *, 5, 1)
)

return example_lists

def test_tensor_attributes(self, device):
Expand Down Expand Up @@ -4125,7 +4146,9 @@ def test_jagged_op_dim_reduce_ragged_idx_1(
op_name = get_op_name(func)

tensor_lists = self._get_example_tensor_lists(
include_list_of_lists=False, include_requires_grad=components_require_grad
include_list_of_lists=False,
include_requires_grad=components_require_grad,
include_inner_dim_size_1=True, # (B, *, 1)
)
reduce_dim = (1,) # ragged

Expand Down

0 comments on commit e782918

Please sign in to comment.