Fix for normalizing signature for op overloads (#77182)

Previously, we took the `.op` from an OpOverload/OpOverloadPacket and looked up its signature in `_jit_builtins`. Those mappings only exist for operators on the public API, not for the overload packets, e.g. `torch.resize_as_` but not `torch.ops.aten.resize_as_` (at least in this case, and I'm fairly sure in general). OpOverloads/OpOverloadPackets have schemas stored on them, so we can just use those directly.
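For illustration, a minimal sketch of the idea (not the exact patch): it assumes the private attributes `OpOverload._schema` and `OpOverloadPacket._qualified_op_name`, and the JIT binding `torch._C._jit_get_schemas_for_operator`, behave as described.

```python
import torch
from torch._ops import OpOverload, OpOverloadPacket

def schemas_for_target(target):
    # Sketch: resolve schemas directly from the overload objects
    # instead of consulting the _jit_builtins signature table.
    if isinstance(target, OpOverload):
        # An OpOverload carries exactly one schema on itself.
        return [target._schema]
    if isinstance(target, OpOverloadPacket):
        # An OpOverloadPacket groups every overload of an op; look up all
        # schemas registered under its qualified name, e.g. "aten::resize_as_".
        return torch._C._jit_get_schemas_for_operator(target._qualified_op_name)
    raise TypeError(f"expected OpOverload or OpOverloadPacket, got {type(target)}")
```

Each returned `FunctionSchema` can then be converted to an `inspect.Signature` the same way `torch.fx.operator_schemas` already does for JIT builtins.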
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77182
Approved by: https://github.com/anjali411
commit 023aafbcd7 (parent 02713221e3)
Author: Elias Ellison
Date: 2022-05-10 23:36:26 +00:00
Committed by: PyTorch MergeBot

2 changed files with 22 additions and 13 deletions


@@ -1678,6 +1678,13 @@ class TestModule(torch.nn.Module):
             )
         self.assertEqual(norm_args_and_kwargs.args, tuple())
 
+    def test_normalize_args_op_overload(self):
+        for target in [torch.ops.aten.resize_as_.default, torch.ops.aten.resize_as_]:
+            inp1 = torch.rand([1])
+            inp2 = torch.rand([4])
+            args, kwargs = normalize_function(target, (inp1,), {"the_template": inp2}, normalize_to_only_use_kwargs=True)
+            self.assertIs(kwargs["input"], inp1)
+            self.assertIs(kwargs["the_template"], inp2)
 
 instantiate_device_type_tests(TestNormalizeOperators, globals())
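For context, here is roughly how the normalized call from the new test behaves; a hedged sketch, assuming `resize_as_`'s schema names its parameters `self` and `the_template` and that `normalize_function` renames `self` to `input` as it does for builtins:

```python
import torch
from torch.fx.operator_schemas import normalize_function

inp1 = torch.rand([1])
inp2 = torch.rand([4])

norm = normalize_function(
    torch.ops.aten.resize_as_.default,  # an OpOverload target
    (inp1,),
    {"the_template": inp2},
    normalize_to_only_use_kwargs=True,
)
# Positional args are folded into kwargs using the schema's parameter names.
assert norm.args == ()
assert norm.kwargs["input"] is inp1
assert norm.kwargs["the_template"] is inp2
```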