[ONNX] Add test for decomp_table update (#153671)

Added a test to strengthen the case for cherry-picking #153168. The original PR didn’t include this test since the fix for decomp_table and the registry was already covered by existing tests. However, it's reasonable to include a dedicated test for the specific issue (https://github.com/pytorch/pytorch/issues/150367 ) when considering the cherry-pick.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153671
Approved by: https://github.com/justinchuby
This commit is contained in:
Ti-Tai Wang
2025-05-16 08:00:13 +00:00
committed by PyTorch MergeBot
parent 3fe42d4d5d
commit 658d17dfb5

View File

@ -374,6 +374,51 @@ class TestCustomTranslationTable(common_utils.TestCase):
self.assertIn("Sub", all_nodes)
self.assertNotIn("Add", all_nodes)
def test_custom_translation_table_supports_custom_op_with_its_decomp(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
torch.library.define(
"mylib::foo",
"(Tensor a, Tensor b) -> Tensor",
tags=torch.Tag.pt2_compliant_tag,
lib=lib,
)
@torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
@torch.library.register_fake("mylib::foo")
def foo_impl(a, b):
return a + b
class M(torch.nn.Module):
def forward(self, x, y):
return torch.ops.mylib.foo(x, y)
def onnx_add(self: FLOAT, other: FLOAT) -> FLOAT:
# Replace add with Sub
return op.Sub(self, other)
# With the custom op defined, we can use it in the model
# and replace it with a custom translation table
custom_translation_table = {
torch.ops.mylib.foo.default: onnx_add,
}
onnx_program = torch.onnx.export(
M(),
(torch.ones(3, 3), torch.ones(3, 3)),
custom_translation_table=custom_translation_table,
dynamo=True,
)
all_nodes = [n.op_type for n in onnx_program.model.graph]
self.assertIn("Sub", all_nodes)
self.assertNotIn("Add", all_nodes)
# Without the custom op defined, it's going to be decomposed
onnx_program_decomp = torch.onnx.export(
M(), (torch.ones(3, 3), torch.ones(3, 3)), dynamo=True
)
all_nodes_decomp = [n.op_type for n in onnx_program_decomp.model.graph]
self.assertIn("Add", all_nodes_decomp)
self.assertNotIn("Sub", all_nodes_decomp)
class TestFakeTensorExport(common_utils.TestCase):
"""Test exporting in fake mode."""