mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
[NestedTensor] Add example NestedTensor objects with inner dimension of size 1 to tests reducing along jagged dimension for NestedTensor (#131516)
Add example `NestedTensor`s with inner dimension of size `1` to `_get_example_tensor_lists` with `include_inner_dim_size_1=True`. This diff creates `NestedTensor`s of sizes `(B, *, 1)` and `(B, *, 5, 1)`, ensuring that the current implementations of jagged reductions for `sum` and `mean` hold for tensors of effective shape `(B, *)` and `(B, *, 5)`. Differential Revision: [D59846023](https://our.internmc.facebook.com/intern/diff/D59846023/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/131516 Approved by: https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
e9db1b0597
commit
e782918b8e
@ -3485,7 +3485,10 @@ class TestNestedTensorSubclass(TestCase):
|
||||
return out
|
||||
|
||||
def _get_example_tensor_lists(
|
||||
self, include_list_of_lists=True, include_requires_grad=True
|
||||
self,
|
||||
include_list_of_lists=True,
|
||||
include_requires_grad=True,
|
||||
include_inner_dim_size_1=False,
|
||||
):
|
||||
def _make_tensor(
|
||||
*shape, include_requires_grad=include_requires_grad, requires_grad=True
|
||||
@ -3534,6 +3537,24 @@ class TestNestedTensorSubclass(TestCase):
|
||||
]
|
||||
)
|
||||
|
||||
if include_inner_dim_size_1:
|
||||
example_lists.append(
|
||||
[
|
||||
_make_tensor(2, 1),
|
||||
_make_tensor(3, 1, requires_grad=False),
|
||||
_make_tensor(4, 1, requires_grad=False),
|
||||
_make_tensor(6, 1),
|
||||
] # (B, *, 1)
|
||||
)
|
||||
example_lists.append(
|
||||
[
|
||||
_make_tensor(2, 5, 1),
|
||||
_make_tensor(3, 5, 1, requires_grad=False),
|
||||
_make_tensor(4, 5, 1, requires_grad=False),
|
||||
_make_tensor(6, 5, 1),
|
||||
] # (B, *, 5, 1)
|
||||
)
|
||||
|
||||
return example_lists
|
||||
|
||||
def test_tensor_attributes(self, device):
|
||||
@ -4125,7 +4146,9 @@ class TestNestedTensorSubclass(TestCase):
|
||||
op_name = get_op_name(func)
|
||||
|
||||
tensor_lists = self._get_example_tensor_lists(
|
||||
include_list_of_lists=False, include_requires_grad=components_require_grad
|
||||
include_list_of_lists=False,
|
||||
include_requires_grad=components_require_grad,
|
||||
include_inner_dim_size_1=True, # (B, *, 1)
|
||||
)
|
||||
reduce_dim = (1,) # ragged
|
||||
|
||||
|
||||
Reference in New Issue
Block a user