Revert "Exclude upsample_bilinear2d.vec from default core ATen decomposition table (#141791)"

This reverts commit 3d604b17d91b928c850ded83b2ec25ea066bb3f6.

Reverted https://github.com/pytorch/pytorch/pull/141791 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/141791#issuecomment-2649717140))
This commit is contained in:
PyTorch MergeBot
2025-02-11 03:17:59 +00:00
parent 30cbf13544
commit fe94ece375
2 changed files with 8 additions and 62 deletions

View File

@ -11881,48 +11881,6 @@ class GraphModule(torch.nn.Module):
]
self.assertEqual(len(shift_op), 1)
def test_default_decomposition_core_cia_ops(self):
"""
Verify that core ATen ops with Composite Implicit Autograd dispatch are not
decomposed by default.
"""
# TODO Add avg_pool1d, and adaptive_avg_pool1d when ready.
# See issue #116684.
core_cia_ops = {
"torch.ops.aten.upsample_bilinear2d.vec": (
torch.ops.aten.upsample_bilinear2d.vec,
{
"align_corners": False,
"scale_factors": [2, 2],
"output_size": None,
},
),
"torch.ops.aten.upsample_nearest2d.vec": (
torch.ops.aten.upsample_nearest2d.vec,
{
"scale_factors": [2, 2],
"output_size": None,
},
),
}
for op_name, (op, kwargs) in core_cia_ops.items():
class M(torch.nn.Module):
def forward(self, x):
return op(x, **kwargs)
ep = export(M(), (torch.randn(2, 3, 4, 5),))
FileCheck().check_count(op_name, 1, exactly=True).run(ep.graph_module.code)
decomp_table = default_decompositions()
ep = ep.run_decompositions(
decomp_table=decomp_table,
)
FileCheck().check_count(op_name, 1, exactly=True).run(ep.graph_module.code)
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestOneOffModelExportResult(TestCase):
@ -12528,30 +12486,30 @@ class TestExportCustomClass(TorchTestCase):
torch.distributed.destroy_process_group()
def test_preserve_cia_op(self):
class StaticResizeTrilinear2dModule(torch.nn.Module):
class StaticResizeBilinear2dModule(torch.nn.Module):
def forward(self, x):
a = torch.nn.functional.interpolate(
x,
size=(x.shape[2] * 2, x.shape[3] * 3, x.shape[4] * 4),
mode="trilinear",
size=(x.shape[2] * 2, x.shape[3] * 3),
mode="bilinear",
align_corners=False,
antialias=False,
)
return a
ep = export(StaticResizeTrilinear2dModule(), (torch.randn(2, 3, 4, 5, 6),))
ep = export(StaticResizeBilinear2dModule(), (torch.randn(2, 3, 4, 5),))
FileCheck().check_count(
"torch.ops.aten.upsample_trilinear3d.vec", 1, exactly=True
"torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
).run(ep.graph_module.code)
decomp_table = default_decompositions()
del decomp_table[torch.ops.aten.upsample_trilinear3d.vec]
del decomp_table[torch.ops.aten.upsample_bilinear2d.vec]
ep = ep.run_decompositions(
decomp_table=decomp_table,
)
FileCheck().check_count(
"torch.ops.aten.upsample_trilinear3d.vec", 1, exactly=True
"torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
).run(ep.graph_module.code)

View File

@ -13,17 +13,6 @@ from torch._export.utils import (
__all__ = ["CustomDecompTable"]
"""
Core ATen ops with Composite Implicit Autograd dispatch that should be excluded from decomposition
by default. The decomposition logic should eventually exclude all core-tagged CIA ops, but until all
backends are ready, this list allows opt-in one at a time.
"""
PRESERVED_ATEN_CIA_OPS = {
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.upsample_nearest2d.vec,
}
class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]):
"""
This is a custom dictionary that is specifically used for handling decomp_table in export.
@ -49,8 +38,7 @@ class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]):
self.decomp_table = _core_aten_decompositions_post_autograd()
for op in _collect_all_valid_cia_ops_for_aten_namespace():
if op not in PRESERVED_ATEN_CIA_OPS:
self.decomp_table[op] = _get_decomp_for_cia(op)
self.decomp_table[op] = _get_decomp_for_cia(op)
# This is to track the *pending* deleted custom ops that haven't been materialized yet
self.deleted_custom_ops = set()