mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[NJT] Fix schema validation error in jagged functions (#165307)
Fixes #161812 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165307 Approved by: https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
70ec464c16
commit
0ce945790e
@ -857,6 +857,22 @@ class TestNestedTensor(NestedTensorTestCase):
|
||||
):
|
||||
torch.cat([x, y], dim=-1)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/161812
|
||||
def test_jagged_with_dim_error(self):
|
||||
x = torch.nested.nested_tensor(
|
||||
[torch.ones(3, 2, 3), torch.ones(4, 2, 3)], layout=torch.jagged
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"not supported for NestedTensor on dim=0",
|
||||
):
|
||||
torch.cat([x, x])
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"not supported for NestedTensor on dim=0",
|
||||
):
|
||||
torch.stack([x, x])
|
||||
|
||||
def test_nested_view_from_buffer_overflow_errors(self):
|
||||
buffer = torch.tensor([1])
|
||||
sizes = torch.tensor([[2**63 - 1], [2**63 - 1], [3]], dtype=torch.int64)
|
||||
|
@ -1232,7 +1232,7 @@ def unsqueeze_default(func, *args, **kwargs):
|
||||
return NestedTensor(func(values, **new_kwargs), **output_kwargs)
|
||||
|
||||
|
||||
@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
|
||||
@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any?")
|
||||
def cat_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
@ -2275,7 +2275,7 @@ def value_selecting_reduction_backward_default(func, *args, **kwargs):
|
||||
return NestedTensor(func(**new_kwargs), **output_kwargs)
|
||||
|
||||
|
||||
@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
|
||||
@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any?")
|
||||
def stack_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
|
Reference in New Issue
Block a user