mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix for normalizing signature for op overloads (#77182)
Previously, we were taking the `.op` from OpOverload/OpOverloadPacket and looking for a mapping in `_jit_builtins` for their signature. Those will only exist for operators on the public api, not the overload packets, e.g. `torch.resize_as_` not `torch.ops.aten.resize_as_` (as least in this case, and im pretty sure generally). The OpOverloads/OpOverloadPackets have schemas stored on them so we can just use those directly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/77182 Approved by: https://github.com/anjali411
This commit is contained in:
committed by
PyTorch MergeBot
parent
02713221e3
commit
023aafbcd7
@ -1678,6 +1678,13 @@ class TestModule(torch.nn.Module):
|
||||
)
|
||||
self.assertEqual(norm_args_and_kwargs.args, tuple())
|
||||
|
||||
def test_normalize_args_op_overload(self):
|
||||
for target in [torch.ops.aten.resize_as_.default, torch.ops.aten.resize_as_]:
|
||||
inp1 = torch.rand([1])
|
||||
inp2 = torch.rand([4])
|
||||
args, kwargs = normalize_function(target, (inp1,), {"the_template": inp2}, normalize_to_only_use_kwargs=True)
|
||||
self.assertIs(kwargs["input"], inp1)
|
||||
self.assertIs(kwargs["the_template"], inp2)
|
||||
|
||||
instantiate_device_type_tests(TestNormalizeOperators, globals())
|
||||
|
||||
|
Reference in New Issue
Block a user