diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index 18c12a99f6e1..c587b6f23633 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -1678,6 +1678,13 @@ class TestModule(torch.nn.Module):
             )
         self.assertEqual(norm_args_and_kwargs.args, tuple())
 
+    def test_normalize_args_op_overload(self):
+        for target in [torch.ops.aten.resize_as_.default, torch.ops.aten.resize_as_]:
+            inp1 = torch.rand([1])
+            inp2 = torch.rand([4])
+            args, kwargs = normalize_function(target, (inp1,), {"the_template": inp2}, normalize_to_only_use_kwargs=True)
+            self.assertIs(kwargs["input"], inp1)
+            self.assertIs(kwargs["the_template"], inp2)
 
 instantiate_device_type_tests(TestNormalizeOperators, globals())
 
diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py
index 0b72afaca2ed..9848baa04c10 100644
--- a/torch/fx/operator_schemas.py
+++ b/torch/fx/operator_schemas.py
@@ -135,20 +135,22 @@ def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
         return_schemas=True, returns a tuple containing the optional Python signatures
         and the optional TorchScript Function signature
     """
-    if isinstance(op, OpOverloadPacket) or isinstance(op, OpOverload):
-        op = op.op
-    override = _manual_overrides.get(op)
-    if override:
-        return (override, None) if return_schemas else None
+    if isinstance(op, OpOverload):
+        schemas = [op._schema]
+    elif isinstance(op, OpOverloadPacket):
+        schemas = [getattr(op, overload)._schema for overload in op.overloads()]
+    else:
+        override = _manual_overrides.get(op)
+        if override:
+            return (override, None) if return_schemas else None
 
-    aten_fn = torch.jit._builtins._find_builtin(op)
+        aten_fn = torch.jit._builtins._find_builtin(op)
 
-    if aten_fn is None:
-        return (None, None) if return_schemas else None
+        if aten_fn is None:
+            return (None, None) if return_schemas else None
+        schemas = torch._C._jit_get_schemas_for_operator(aten_fn)
 
-    schemas = torch._C._jit_get_schemas_for_operator(aten_fn)
     signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
-
     return (signatures, schemas) if return_schemas else signatures
 
 @compatibility(is_backward_compatible=False)
@@ -257,9 +259,9 @@ def normalize_function(
     if kwargs is None:
         kwargs = {}
     new_args_and_kwargs = None
-    if isinstance(target, OpOverloadPacket) or isinstance(target, OpOverload):
-        target = target.op
-    if not isinstance(target, types.BuiltinFunctionType):
+    if not isinstance(target, types.BuiltinFunctionType) and not (
+        isinstance(target, OpOverloadPacket) or isinstance(target, OpOverload)
+    ):
         target_for_analysis = target
         if target in boolean_dispatched:
             # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have