mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Add test for decomp_table update (#153671)
Added a test to strengthen the case for cherry-picking #153168. The original PR didn’t include this test since the fix for decomp_table and the registry was already covered by existing tests. However, it's reasonable to include a dedicated test for the specific issue (https://github.com/pytorch/pytorch/issues/150367) when considering the cherry-pick. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153671 Approved by: https://github.com/justinchuby
This commit is contained in:
committed by
PyTorch MergeBot
parent
3fe42d4d5d
commit
658d17dfb5
@ -374,6 +374,51 @@ class TestCustomTranslationTable(common_utils.TestCase):
|
||||
self.assertIn("Sub", all_nodes)
|
||||
self.assertNotIn("Add", all_nodes)
|
||||
|
||||
def test_custom_translation_table_supports_custom_op_with_its_decomp(self):
    """A custom op that has a CompositeImplicitAutograd decomposition still honors the translation table.

    Regression test for https://github.com/pytorch/pytorch/issues/150367: when a
    custom op appears in ``custom_translation_table``, the exporter must use the
    supplied ONNX translation instead of decomposing the op; when it is absent,
    the op decomposes as usual.
    """
    # Register a temporary op library so the custom op only exists for this test.
    with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
        torch.library.define(
            "mylib::foo",
            "(Tensor a, Tensor b) -> Tensor",
            tags=torch.Tag.pt2_compliant_tag,
            lib=lib,
        )

        # CompositeImplicitAutograd gives the op a decomposition (a + b), so
        # without a translation-table entry the exporter lowers it to Add.
        # register_fake supplies the fake/meta implementation needed for export.
        @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
        @torch.library.register_fake("mylib::foo")
        def foo_impl(a, b):
            return a + b

        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.ops.mylib.foo(x, y)

        # Custom ONNX translation: deliberately emit Sub (not Add) so the graph
        # reveals whether the translation table or the decomposition was used.
        def onnx_add(self: FLOAT, other: FLOAT) -> FLOAT:
            # Replace add with Sub
            return op.Sub(self, other)

        # With the custom op defined, we can use it in the model
        # and replace it with a custom translation table
        custom_translation_table = {
            torch.ops.mylib.foo.default: onnx_add,
        }
        onnx_program = torch.onnx.export(
            M(),
            (torch.ones(3, 3), torch.ones(3, 3)),
            custom_translation_table=custom_translation_table,
            dynamo=True,
        )
        # NOTE(review): iterating onnx_program.model.graph appears to yield
        # graph nodes with an `op_type` attribute — matches onnxscript IR usage
        # elsewhere in this file.
        all_nodes = [n.op_type for n in onnx_program.model.graph]
        self.assertIn("Sub", all_nodes)
        self.assertNotIn("Add", all_nodes)

        # Without the custom op defined, it's going to be decomposed
        onnx_program_decomp = torch.onnx.export(
            M(), (torch.ones(3, 3), torch.ones(3, 3)), dynamo=True
        )
        all_nodes_decomp = [n.op_type for n in onnx_program_decomp.model.graph]
        self.assertIn("Add", all_nodes_decomp)
        self.assertNotIn("Sub", all_nodes_decomp)
|
||||
|
||||
|
||||
class TestFakeTensorExport(common_utils.TestCase):
|
||||
"""Test exporting in fake mode."""
|
||||
|
Reference in New Issue
Block a user