mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE][export] Change custom_op registeration style (#145315)
Summary: `test_unbacked_bindings_for_divisible_u_symint` has been flaky for a while due to ``` Tried to register an operator (mylib::foo(Tensor a, Tensor b) -> Tensor) with the same name and overload name multiple times. ``` It is likely due to when all variants of this test are being run (non-strict, retrace, serdes) simultaneously. In later tests, the operator has already been registered. In this diff, we change registration style. Test Plan: ``` buck2 test mode/dev-nosan caffe2/test:test_export -- -r test_unbacked_bindings_for_divisible_u_symint ``` Differential Revision: D68465258 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145315 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
4803e20bc7
commit
d0a2e11284
@ -3714,43 +3714,35 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
@testing.expectedFailureCppSerDes # no unbacked bindings after deserialization?
|
||||
@testing.expectedFailureSerDerNonStrict
|
||||
def test_unbacked_bindings_for_divisible_u_symint(self):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
"(Tensor a, Tensor b) -> (Tensor)",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=lib,
|
||||
)
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, a, b):
|
||||
return torch.ops.mylib.foo_unbacked(a, b)
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, a, b):
|
||||
return torch.ops.mylib.foo(a, b)
|
||||
@torch.library.custom_op("mylib::foo_unbacked", mutates_args={})
|
||||
def foo_unbacked(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
return a[b.item()]
|
||||
|
||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
||||
def foo_impl(a, b):
|
||||
return a[b.item()]
|
||||
@foo_unbacked.register_fake
|
||||
def foo_unbacked_fake_impl(a, b):
|
||||
ctx = torch.library.get_ctx()
|
||||
u = ctx.new_dynamic_size(min=0, max=len(a) // 10) * 10
|
||||
return torch.empty(u, a.shape[1], dtype=a.dtype)
|
||||
|
||||
@torch.library.register_fake("mylib::foo", lib=lib)
|
||||
def foo_fake_impl(a, b):
|
||||
ctx = torch.library.get_ctx()
|
||||
u = ctx.new_dynamic_size(min=0, max=len(a) // 10) * 10
|
||||
return torch.empty(u, a.shape[1], dtype=a.dtype)
|
||||
|
||||
ep = export(
|
||||
M(),
|
||||
(torch.randn(100, 4), torch.tensor(10)),
|
||||
)
|
||||
foo = [node for node in ep.graph.nodes if node.name == "foo"][0]
|
||||
unbacked_bindings = foo.meta["unbacked_bindings"]
|
||||
self.assertEqual(len(unbacked_bindings), 1) # check binding is {u: path}
|
||||
u = next(iter(unbacked_bindings.keys()))
|
||||
self.assertEqual(
|
||||
type(u).__name__, "Symbol"
|
||||
) # check binding is symbol, not expr
|
||||
path = unbacked_bindings[u]
|
||||
self.assertEqual(len(path), 3) # check path is [size, 0, DivideByKey(10)]
|
||||
self.assertEqual(type(path[2]).__name__, "DivideByKey")
|
||||
self.assertEqual(path[2].divisor, 10)
|
||||
ep = export(
|
||||
M(),
|
||||
(torch.randn(100, 4), torch.tensor(10)),
|
||||
)
|
||||
foo = [node for node in ep.graph.nodes if node.name == "foo_unbacked"][0]
|
||||
unbacked_bindings = foo.meta["unbacked_bindings"]
|
||||
self.assertEqual(len(unbacked_bindings), 1) # check binding is {u: path}
|
||||
u = next(iter(unbacked_bindings.keys()))
|
||||
self.assertEqual(
|
||||
type(u).__name__, "Symbol"
|
||||
) # check binding is symbol, not expr
|
||||
path = unbacked_bindings[u]
|
||||
self.assertEqual(len(path), 3) # check path is [size, 0, DivideByKey(10)]
|
||||
self.assertEqual(type(path[2]).__name__, "DivideByKey")
|
||||
self.assertEqual(path[2].divisor, 10)
|
||||
|
||||
def test_torch_check_eq_commutativity(self):
|
||||
class M1(torch.nn.Module):
|
||||
|
Reference in New Issue
Block a user