[BE][export] Change custom_op registration style (#145315)

Summary:
`test_unbacked_bindings_for_divisible_u_symint` has been flaky for a while due to

```
Tried to register an operator (mylib::foo(Tensor a, Tensor b) -> Tensor) with the same name and overload name multiple times.
```

This is likely because all variants of this test (non-strict, retrace, serdes) run together: by the time the later variants execute, the operator has already been registered by an earlier one.
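
For illustration, a minimal standalone sketch of the failure mode (a hypothetical repro, not part of the test suite): defining the same operator schema twice in one process raises exactly this error.

```
import torch

# Hypothetical repro: the second define() of the same schema in the same
# process fails with "Tried to register an operator ... multiple times."
lib1 = torch.library.Library("mylib", "FRAGMENT")
lib1.define("foo(Tensor a, Tensor b) -> Tensor")

lib2 = torch.library.Library("mylib", "FRAGMENT")
lib2.define("foo(Tensor a, Tensor b) -> Tensor")  # RuntimeError
```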

In this diff, we change the registration style: instead of `torch.library.define` plus separate `impl`/`register_fake` calls against a scoped library, the test now uses the `torch.library.custom_op` decorator.
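
Condensed from the diff below, the new style looks like this: `custom_op` defines the op and registers its implementation in one step, and the fake impl is attached to the returned op object.

```
import torch

# New style: one decorator both defines mylib::foo_unbacked and registers
# this function as its implementation for all device types.
@torch.library.custom_op("mylib::foo_unbacked", mutates_args=())
def foo_unbacked(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a[b.item()]

# The fake (meta) implementation hangs off the op object itself, so no
# library handle (lib=lib) needs to be threaded through.
@foo_unbacked.register_fake
def foo_unbacked_fake_impl(a, b):
    ctx = torch.library.get_ctx()
    # The fake output's first dim is an unbacked size, constrained to a
    # multiple of 10 and bounded by len(a).
    u = ctx.new_dynamic_size(min=0, max=len(a) // 10) * 10
    return torch.empty(u, a.shape[1], dtype=a.dtype)
```

As far as the current implementation goes, `custom_op` replaces an existing registration of the same op name rather than erroring, which is what makes re-running the test body across variants safe.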

Test Plan:
```
buck2 test mode/dev-nosan caffe2/test:test_export -- -r test_unbacked_bindings_for_divisible_u_symint
```

Differential Revision: D68465258

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145315
Approved by: https://github.com/zou3519
Commit: d0a2e11284
Parent: 4803e20bc7
Author: Yiming Zhou
Date: 2025-01-22 23:46:50 +00:00
Committed-by: PyTorch MergeBot

```
@@ -3714,43 +3714,35 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
     @testing.expectedFailureCppSerDes  # no unbacked bindings after deserialization?
     @testing.expectedFailureSerDerNonStrict
     def test_unbacked_bindings_for_divisible_u_symint(self):
-        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
-            torch.library.define(
-                "mylib::foo",
-                "(Tensor a, Tensor b) -> (Tensor)",
-                tags=torch.Tag.pt2_compliant_tag,
-                lib=lib,
-            )
+        class M(torch.nn.Module):
+            def forward(self, a, b):
+                return torch.ops.mylib.foo_unbacked(a, b)
 
-            class M(torch.nn.Module):
-                def forward(self, a, b):
-                    return torch.ops.mylib.foo(a, b)
+        @torch.library.custom_op("mylib::foo_unbacked", mutates_args={})
+        def foo_unbacked(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            return a[b.item()]
 
-            @torch.library.impl("mylib::foo", "cpu", lib=lib)
-            def foo_impl(a, b):
-                return a[b.item()]
+        @foo_unbacked.register_fake
+        def foo_unbacked_fake_impl(a, b):
+            ctx = torch.library.get_ctx()
+            u = ctx.new_dynamic_size(min=0, max=len(a) // 10) * 10
+            return torch.empty(u, a.shape[1], dtype=a.dtype)
 
-            @torch.library.register_fake("mylib::foo", lib=lib)
-            def foo_fake_impl(a, b):
-                ctx = torch.library.get_ctx()
-                u = ctx.new_dynamic_size(min=0, max=len(a) // 10) * 10
-                return torch.empty(u, a.shape[1], dtype=a.dtype)
-
-            ep = export(
-                M(),
-                (torch.randn(100, 4), torch.tensor(10)),
-            )
-            foo = [node for node in ep.graph.nodes if node.name == "foo"][0]
-            unbacked_bindings = foo.meta["unbacked_bindings"]
-            self.assertEqual(len(unbacked_bindings), 1)  # check binding is {u: path}
-            u = next(iter(unbacked_bindings.keys()))
-            self.assertEqual(
-                type(u).__name__, "Symbol"
-            )  # check binding is symbol, not expr
-            path = unbacked_bindings[u]
-            self.assertEqual(len(path), 3)  # check path is [size, 0, DivideByKey(10)]
-            self.assertEqual(type(path[2]).__name__, "DivideByKey")
-            self.assertEqual(path[2].divisor, 10)
+        ep = export(
+            M(),
+            (torch.randn(100, 4), torch.tensor(10)),
+        )
+        foo = [node for node in ep.graph.nodes if node.name == "foo_unbacked"][0]
+        unbacked_bindings = foo.meta["unbacked_bindings"]
+        self.assertEqual(len(unbacked_bindings), 1)  # check binding is {u: path}
+        u = next(iter(unbacked_bindings.keys()))
+        self.assertEqual(
+            type(u).__name__, "Symbol"
+        )  # check binding is symbol, not expr
+        path = unbacked_bindings[u]
+        self.assertEqual(len(path), 3)  # check path is [size, 0, DivideByKey(10)]
+        self.assertEqual(type(path[2]).__name__, "DivideByKey")
+        self.assertEqual(path[2].divisor, 10)
 
     def test_torch_check_eq_commutativity(self):
         class M1(torch.nn.Module):
```