[NJT] Fix schema validation error in jagged functions (#165307)

Fixes #161812
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165307
Approved by: https://github.com/soulitzer
This commit is contained in:
can-gaa-hou
2025-10-13 17:59:14 +00:00
committed by PyTorch MergeBot
parent 70ec464c16
commit 0ce945790e
2 changed files with 18 additions and 2 deletions

View File

@ -857,6 +857,22 @@ class TestNestedTensor(NestedTensorTestCase):
):
torch.cat([x, y], dim=-1)
# https://github.com/pytorch/pytorch/issues/161812
def test_jagged_with_dim_error(self):
x = torch.nested.nested_tensor(
[torch.ones(3, 2, 3), torch.ones(4, 2, 3)], layout=torch.jagged
)
with self.assertRaisesRegex(
RuntimeError,
"not supported for NestedTensor on dim=0",
):
torch.cat([x, x])
with self.assertRaisesRegex(
RuntimeError,
"not supported for NestedTensor on dim=0",
):
torch.stack([x, x])
def test_nested_view_from_buffer_overflow_errors(self):
buffer = torch.tensor([1])
sizes = torch.tensor([[2**63 - 1], [2**63 - 1], [3]], dtype=torch.int64)

View File

@ -1232,7 +1232,7 @@ def unsqueeze_default(func, *args, **kwargs):
return NestedTensor(func(values, **new_kwargs), **output_kwargs)
@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any?")
def cat_default(func, *args, **kwargs):
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
@ -2275,7 +2275,7 @@ def value_selecting_reduction_backward_default(func, *args, **kwargs):
return NestedTensor(func(**new_kwargs), **output_kwargs)
@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any?")
def stack_default(func, *args, **kwargs):
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True