diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 5a725ccdd40b..3f20e8b6fac5 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -857,6 +857,22 @@ class TestNestedTensor(NestedTensorTestCase): ): torch.cat([x, y], dim=-1) + # https://github.com/pytorch/pytorch/issues/161812 + def test_jagged_with_dim_error(self): + x = torch.nested.nested_tensor( + [torch.ones(3, 2, 3), torch.ones(4, 2, 3)], layout=torch.jagged + ) + with self.assertRaisesRegex( + RuntimeError, + "not supported for NestedTensor on dim=0", + ): + torch.cat([x, x]) + with self.assertRaisesRegex( + RuntimeError, + "not supported for NestedTensor on dim=0", + ): + torch.stack([x, x]) + def test_nested_view_from_buffer_overflow_errors(self): buffer = torch.tensor([1]) sizes = torch.tensor([[2**63 - 1], [2**63 - 1], [3]], dtype=torch.int64) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index e157538ff123..f52bfab2a8b3 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1232,7 +1232,7 @@ def unsqueeze_default(func, *args, **kwargs): return NestedTensor(func(values, **new_kwargs), **output_kwargs) -@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any") +@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any?") def cat_default(func, *args, **kwargs): _, new_kwargs = normalize_function( # type: ignore[misc] func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True @@ -2275,7 +2275,7 @@ def value_selecting_reduction_backward_default(func, *args, **kwargs): return NestedTensor(func(**new_kwargs), **output_kwargs) -@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any") +@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any?") def stack_default(func, *args, **kwargs): _, new_kwargs = normalize_function( # type: ignore[misc] func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True