From d0a2e11284ed3b51b96af9be2d8bfa0130c8e769 Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Wed, 22 Jan 2025 23:46:50 +0000 Subject: [PATCH] [BE][export] Change custom_op registeration style (#145315) Summary: `test_unbacked_bindings_for_divisible_u_symint` has been flaky for a while due to ``` Tried to register an operator (mylib::foo(Tensor a, Tensor b) -> Tensor) with the same name and overload name multiple times. ``` It is likely due to when all variants of this test are being run (non-strict, retrace, serdes) simultaneously. In later tests, the operator has already been registered. In this diff, we change registration style. Test Plan: ``` buck2 test mode/dev-nosan caffe2/test:test_export -- -r test_unbacked_bindings_for_divisible_u_symint ``` Differential Revision: D68465258 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145315 Approved by: https://github.com/zou3519 --- test/export/test_export.py | 60 +++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 34 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 3bdf65eb7928..2139d3f7e96b 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3714,43 +3714,35 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): @testing.expectedFailureCppSerDes # no unbacked bindings after deserialization? @testing.expectedFailureSerDerNonStrict def test_unbacked_bindings_for_divisible_u_symint(self): - with torch.library._scoped_library("mylib", "FRAGMENT") as lib: - torch.library.define( - "mylib::foo", - "(Tensor a, Tensor b) -> (Tensor)", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) + class M(torch.nn.Module): + def forward(self, a, b): + return torch.ops.mylib.foo_unbacked(a, b) - class M(torch.nn.Module): - def forward(self, a, b): - return torch.ops.mylib.foo(a, b) + @torch.library.custom_op("mylib::foo_unbacked", mutates_args={}) + def foo_unbacked(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a[b.item()] - @torch.library.impl("mylib::foo", "cpu", lib=lib) - def foo_impl(a, b): - return a[b.item()] + @foo_unbacked.register_fake + def foo_unbacked_fake_impl(a, b): + ctx = torch.library.get_ctx() + u = ctx.new_dynamic_size(min=0, max=len(a) // 10) * 10 + return torch.empty(u, a.shape[1], dtype=a.dtype) - @torch.library.register_fake("mylib::foo", lib=lib) - def foo_fake_impl(a, b): - ctx = torch.library.get_ctx() - u = ctx.new_dynamic_size(min=0, max=len(a) // 10) * 10 - return torch.empty(u, a.shape[1], dtype=a.dtype) - - ep = export( - M(), - (torch.randn(100, 4), torch.tensor(10)), - ) - foo = [node for node in ep.graph.nodes if node.name == "foo"][0] - unbacked_bindings = foo.meta["unbacked_bindings"] - self.assertEqual(len(unbacked_bindings), 1) # check binding is {u: path} - u = next(iter(unbacked_bindings.keys())) - self.assertEqual( - type(u).__name__, "Symbol" - ) # check binding is symbol, not expr - path = unbacked_bindings[u] - self.assertEqual(len(path), 3) # check path is [size, 0, DivideByKey(10)] - self.assertEqual(type(path[2]).__name__, "DivideByKey") - self.assertEqual(path[2].divisor, 10) + ep = export( + M(), + (torch.randn(100, 4), torch.tensor(10)), + ) + foo = [node for node in ep.graph.nodes if node.name == "foo_unbacked"][0] + unbacked_bindings = foo.meta["unbacked_bindings"] + self.assertEqual(len(unbacked_bindings), 1) # check binding is {u: path} + u = next(iter(unbacked_bindings.keys())) + self.assertEqual( + type(u).__name__, "Symbol" + ) # check binding is symbol, not expr + path = unbacked_bindings[u] + self.assertEqual(len(path), 3) # check path is [size, 0, DivideByKey(10)] + self.assertEqual(type(path[2]).__name__, "DivideByKey") + self.assertEqual(path[2].divisor, 10) def test_torch_check_eq_commutativity(self): class M1(torch.nn.Module):