mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert "Exclude upsample_bilinear2d.vec from default core ATen decomposition table (#141791)"
This reverts commit 3d604b17d91b928c850ded83b2ec25ea066bb3f6. Reverted https://github.com/pytorch/pytorch/pull/141791 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/141791#issuecomment-2649717140))
This commit is contained in:
@ -11881,48 +11881,6 @@ class GraphModule(torch.nn.Module):
|
||||
]
|
||||
self.assertEqual(len(shift_op), 1)
|
||||
|
||||
def test_default_decomposition_core_cia_ops(self):
|
||||
"""
|
||||
Verify that core ATen ops with Composite Implicit Autograd dispatch are not
|
||||
decomposed by default.
|
||||
"""
|
||||
|
||||
# TODO Add avg_pool1d, and adaptive_avg_pool1d when ready.
|
||||
# See issue #116684.
|
||||
core_cia_ops = {
|
||||
"torch.ops.aten.upsample_bilinear2d.vec": (
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
{
|
||||
"align_corners": False,
|
||||
"scale_factors": [2, 2],
|
||||
"output_size": None,
|
||||
},
|
||||
),
|
||||
"torch.ops.aten.upsample_nearest2d.vec": (
|
||||
torch.ops.aten.upsample_nearest2d.vec,
|
||||
{
|
||||
"scale_factors": [2, 2],
|
||||
"output_size": None,
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
for op_name, (op, kwargs) in core_cia_ops.items():
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return op(x, **kwargs)
|
||||
|
||||
ep = export(M(), (torch.randn(2, 3, 4, 5),))
|
||||
FileCheck().check_count(op_name, 1, exactly=True).run(ep.graph_module.code)
|
||||
|
||||
decomp_table = default_decompositions()
|
||||
|
||||
ep = ep.run_decompositions(
|
||||
decomp_table=decomp_table,
|
||||
)
|
||||
FileCheck().check_count(op_name, 1, exactly=True).run(ep.graph_module.code)
|
||||
|
||||
|
||||
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
|
||||
class TestOneOffModelExportResult(TestCase):
|
||||
@ -12528,30 +12486,30 @@ class TestExportCustomClass(TorchTestCase):
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
def test_preserve_cia_op(self):
|
||||
class StaticResizeTrilinear2dModule(torch.nn.Module):
|
||||
class StaticResizeBilinear2dModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
a = torch.nn.functional.interpolate(
|
||||
x,
|
||||
size=(x.shape[2] * 2, x.shape[3] * 3, x.shape[4] * 4),
|
||||
mode="trilinear",
|
||||
size=(x.shape[2] * 2, x.shape[3] * 3),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
antialias=False,
|
||||
)
|
||||
return a
|
||||
|
||||
ep = export(StaticResizeTrilinear2dModule(), (torch.randn(2, 3, 4, 5, 6),))
|
||||
ep = export(StaticResizeBilinear2dModule(), (torch.randn(2, 3, 4, 5),))
|
||||
FileCheck().check_count(
|
||||
"torch.ops.aten.upsample_trilinear3d.vec", 1, exactly=True
|
||||
"torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
|
||||
).run(ep.graph_module.code)
|
||||
|
||||
decomp_table = default_decompositions()
|
||||
del decomp_table[torch.ops.aten.upsample_trilinear3d.vec]
|
||||
del decomp_table[torch.ops.aten.upsample_bilinear2d.vec]
|
||||
ep = ep.run_decompositions(
|
||||
decomp_table=decomp_table,
|
||||
)
|
||||
|
||||
FileCheck().check_count(
|
||||
"torch.ops.aten.upsample_trilinear3d.vec", 1, exactly=True
|
||||
"torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
|
||||
).run(ep.graph_module.code)
|
||||
|
||||
|
||||
|
@ -13,17 +13,6 @@ from torch._export.utils import (
|
||||
__all__ = ["CustomDecompTable"]
|
||||
|
||||
|
||||
"""
|
||||
Core ATen ops with Composite Implicit Autograd dispatch that should be excluded from decomposition
|
||||
by default. The decomposition logic should eventually exclude all core-tagged CIA ops, but until all
|
||||
backends are ready, this list allows opt-in one at a time.
|
||||
"""
|
||||
PRESERVED_ATEN_CIA_OPS = {
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.upsample_nearest2d.vec,
|
||||
}
|
||||
|
||||
|
||||
class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]):
|
||||
"""
|
||||
This is a custom dictionary that is specifically used for handling decomp_table in export.
|
||||
@ -49,8 +38,7 @@ class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]):
|
||||
self.decomp_table = _core_aten_decompositions_post_autograd()
|
||||
|
||||
for op in _collect_all_valid_cia_ops_for_aten_namespace():
|
||||
if op not in PRESERVED_ATEN_CIA_OPS:
|
||||
self.decomp_table[op] = _get_decomp_for_cia(op)
|
||||
self.decomp_table[op] = _get_decomp_for_cia(op)
|
||||
|
||||
# This is to track the *pending* deleted custom ops that haven't been materialized yet
|
||||
self.deleted_custom_ops = set()
|
||||
|
Reference in New Issue
Block a user