mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE][export] Change custom_op registeration style (#145315)
Summary: `test_unbacked_bindings_for_divisible_u_symint` has been flaky for a while due to ``` Tried to register an operator (mylib::foo(Tensor a, Tensor b) -> Tensor) with the same name and overload name multiple times. ``` It is likely due to when all variants of this test are being run (non-strict, retrace, serdes) simultaneously. In later tests, the operator has already been registered. In this diff, we change registration style. Test Plan: ``` buck2 test mode/dev-nosan caffe2/test:test_export -- -r test_unbacked_bindings_for_divisible_u_symint ``` Differential Revision: D68465258 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145315 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
4803e20bc7
commit
d0a2e11284
@ -3714,43 +3714,35 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||||||
@testing.expectedFailureCppSerDes # no unbacked bindings after deserialization?
|
@testing.expectedFailureCppSerDes # no unbacked bindings after deserialization?
|
||||||
@testing.expectedFailureSerDerNonStrict
|
@testing.expectedFailureSerDerNonStrict
|
||||||
def test_unbacked_bindings_for_divisible_u_symint(self):
|
def test_unbacked_bindings_for_divisible_u_symint(self):
|
||||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
class M(torch.nn.Module):
|
||||||
torch.library.define(
|
def forward(self, a, b):
|
||||||
"mylib::foo",
|
return torch.ops.mylib.foo_unbacked(a, b)
|
||||||
"(Tensor a, Tensor b) -> (Tensor)",
|
|
||||||
tags=torch.Tag.pt2_compliant_tag,
|
|
||||||
lib=lib,
|
|
||||||
)
|
|
||||||
|
|
||||||
class M(torch.nn.Module):
|
@torch.library.custom_op("mylib::foo_unbacked", mutates_args={})
|
||||||
def forward(self, a, b):
|
def foo_unbacked(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||||
return torch.ops.mylib.foo(a, b)
|
return a[b.item()]
|
||||||
|
|
||||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
@foo_unbacked.register_fake
|
||||||
def foo_impl(a, b):
|
def foo_unbacked_fake_impl(a, b):
|
||||||
return a[b.item()]
|
ctx = torch.library.get_ctx()
|
||||||
|
u = ctx.new_dynamic_size(min=0, max=len(a) // 10) * 10
|
||||||
|
return torch.empty(u, a.shape[1], dtype=a.dtype)
|
||||||
|
|
||||||
@torch.library.register_fake("mylib::foo", lib=lib)
|
ep = export(
|
||||||
def foo_fake_impl(a, b):
|
M(),
|
||||||
ctx = torch.library.get_ctx()
|
(torch.randn(100, 4), torch.tensor(10)),
|
||||||
u = ctx.new_dynamic_size(min=0, max=len(a) // 10) * 10
|
)
|
||||||
return torch.empty(u, a.shape[1], dtype=a.dtype)
|
foo = [node for node in ep.graph.nodes if node.name == "foo_unbacked"][0]
|
||||||
|
unbacked_bindings = foo.meta["unbacked_bindings"]
|
||||||
ep = export(
|
self.assertEqual(len(unbacked_bindings), 1) # check binding is {u: path}
|
||||||
M(),
|
u = next(iter(unbacked_bindings.keys()))
|
||||||
(torch.randn(100, 4), torch.tensor(10)),
|
self.assertEqual(
|
||||||
)
|
type(u).__name__, "Symbol"
|
||||||
foo = [node for node in ep.graph.nodes if node.name == "foo"][0]
|
) # check binding is symbol, not expr
|
||||||
unbacked_bindings = foo.meta["unbacked_bindings"]
|
path = unbacked_bindings[u]
|
||||||
self.assertEqual(len(unbacked_bindings), 1) # check binding is {u: path}
|
self.assertEqual(len(path), 3) # check path is [size, 0, DivideByKey(10)]
|
||||||
u = next(iter(unbacked_bindings.keys()))
|
self.assertEqual(type(path[2]).__name__, "DivideByKey")
|
||||||
self.assertEqual(
|
self.assertEqual(path[2].divisor, 10)
|
||||||
type(u).__name__, "Symbol"
|
|
||||||
) # check binding is symbol, not expr
|
|
||||||
path = unbacked_bindings[u]
|
|
||||||
self.assertEqual(len(path), 3) # check path is [size, 0, DivideByKey(10)]
|
|
||||||
self.assertEqual(type(path[2]).__name__, "DivideByKey")
|
|
||||||
self.assertEqual(path[2].divisor, 10)
|
|
||||||
|
|
||||||
def test_torch_check_eq_commutativity(self):
|
def test_torch_check_eq_commutativity(self):
|
||||||
class M1(torch.nn.Module):
|
class M1(torch.nn.Module):
|
||||||
|
|||||||
Reference in New Issue
Block a user