mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[NJT] Fix schema validation error in jagged functions (#165307)
Fixes #161812 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165307 Approved by: https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
70ec464c16
commit
0ce945790e
@ -857,6 +857,22 @@ class TestNestedTensor(NestedTensorTestCase):
|
|||||||
):
|
):
|
||||||
torch.cat([x, y], dim=-1)
|
torch.cat([x, y], dim=-1)
|
||||||
|
|
||||||
|
# https://github.com/pytorch/pytorch/issues/161812
|
||||||
|
def test_jagged_with_dim_error(self):
|
||||||
|
x = torch.nested.nested_tensor(
|
||||||
|
[torch.ones(3, 2, 3), torch.ones(4, 2, 3)], layout=torch.jagged
|
||||||
|
)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError,
|
||||||
|
"not supported for NestedTensor on dim=0",
|
||||||
|
):
|
||||||
|
torch.cat([x, x])
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError,
|
||||||
|
"not supported for NestedTensor on dim=0",
|
||||||
|
):
|
||||||
|
torch.stack([x, x])
|
||||||
|
|
||||||
def test_nested_view_from_buffer_overflow_errors(self):
|
def test_nested_view_from_buffer_overflow_errors(self):
|
||||||
buffer = torch.tensor([1])
|
buffer = torch.tensor([1])
|
||||||
sizes = torch.tensor([[2**63 - 1], [2**63 - 1], [3]], dtype=torch.int64)
|
sizes = torch.tensor([[2**63 - 1], [2**63 - 1], [3]], dtype=torch.int64)
|
||||||
|
@ -1232,7 +1232,7 @@ def unsqueeze_default(func, *args, **kwargs):
|
|||||||
return NestedTensor(func(values, **new_kwargs), **output_kwargs)
|
return NestedTensor(func(values, **new_kwargs), **output_kwargs)
|
||||||
|
|
||||||
|
|
||||||
@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
|
@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any?")
|
||||||
def cat_default(func, *args, **kwargs):
|
def cat_default(func, *args, **kwargs):
|
||||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||||
@ -2275,7 +2275,7 @@ def value_selecting_reduction_backward_default(func, *args, **kwargs):
|
|||||||
return NestedTensor(func(**new_kwargs), **output_kwargs)
|
return NestedTensor(func(**new_kwargs), **output_kwargs)
|
||||||
|
|
||||||
|
|
||||||
@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
|
@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any?")
|
||||||
def stack_default(func, *args, **kwargs):
|
def stack_default(func, *args, **kwargs):
|
||||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||||
|
Reference in New Issue
Block a user