[NJT] Fix schema validation error in jagged functions (#165307)

Fixes #161812
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165307
Approved by: https://github.com/soulitzer
This commit is contained in:
can-gaa-hou
2025-10-13 17:59:14 +00:00
committed by PyTorch MergeBot
parent 70ec464c16
commit 0ce945790e
2 changed files with 18 additions and 2 deletions

View File

@ -1232,7 +1232,7 @@ def unsqueeze_default(func, *args, **kwargs):
return NestedTensor(func(values, **new_kwargs), **output_kwargs)
@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any?")
def cat_default(func, *args, **kwargs):
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
@ -2275,7 +2275,7 @@ def value_selecting_reduction_backward_default(func, *args, **kwargs):
return NestedTensor(func(**new_kwargs), **output_kwargs)
@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any?")
def stack_default(func, *args, **kwargs):
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True