mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[NJT] Fix schema validation error in jagged functions (#165307)
Fixes #161812 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165307 Approved by: https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
70ec464c16
commit
0ce945790e
@ -1232,7 +1232,7 @@ def unsqueeze_default(func, *args, **kwargs):
|
||||
return NestedTensor(func(values, **new_kwargs), **output_kwargs)
|
||||
|
||||
|
||||
@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
|
||||
@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any?")
|
||||
def cat_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
@ -2275,7 +2275,7 @@ def value_selecting_reduction_backward_default(func, *args, **kwargs):
|
||||
return NestedTensor(func(**new_kwargs), **output_kwargs)
|
||||
|
||||
|
||||
@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
|
||||
@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any?")
|
||||
def stack_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
|
Reference in New Issue
Block a user