mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Exclude upsample_bilinear2d.vec from default core ATen decomposition table (#141791)
As upsample_bilinear2d.vec is a core ATen op, it should not be decomposed by default in the export path. Because the operator has CompositeImplicitAutograd dispatch, its decomposition is registered by default. This change adds an override list for CIA decompositions being registered in the default decomp table. In the long-term, we likely will want to exclude decompositions for all core-tagged CIA ops, but this will require all consumers to be ready to handle the remaining three ops: upsample_nearest2d.vec, avg_pool1d, and adaptive_avg_pool1d. Until they are ready, I believe an explicit override list is the safest option. Additionally, I've also removed the ExecuTorch XNNPACK delegate ConvertToUpsampleBilinear2d pass, as the pass breaks (and is not needed), given that the op is not decomposed. The purpose of this pass was originally to pattern match the decomposition and un-decomposite it, but this is no longer necessary. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141791 Approved by: https://github.com/tugsbayasgalan, https://github.com/digantdesai
This commit is contained in:
committed by
PyTorch MergeBot
parent
97f6480cf5
commit
3d604b17d9
@ -13,6 +13,17 @@ from torch._export.utils import (
|
||||
__all__ = ["CustomDecompTable"]
|
||||
|
||||
|
||||
"""
|
||||
Core ATen ops with Composite Implicit Autograd dispatch that should be excluded from decomposition
|
||||
by default. The decomposition logic should eventually exclude all core-tagged CIA ops, but until all
|
||||
backends are ready, this list allows opt-in one at a time.
|
||||
"""
|
||||
PRESERVED_ATEN_CIA_OPS = {
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.upsample_nearest2d.vec,
|
||||
}
|
||||
|
||||
|
||||
class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]):
|
||||
"""
|
||||
This is a custom dictionary that is specifically used for handling decomp_table in export.
|
||||
@ -38,7 +49,8 @@ class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]):
|
||||
self.decomp_table = _core_aten_decompositions_post_autograd()
|
||||
|
||||
for op in _collect_all_valid_cia_ops_for_aten_namespace():
|
||||
self.decomp_table[op] = _get_decomp_for_cia(op)
|
||||
if op not in PRESERVED_ATEN_CIA_OPS:
|
||||
self.decomp_table[op] = _get_decomp_for_cia(op)
|
||||
|
||||
# This is to track the *pending* deleted custom ops that haven't been materialized yet
|
||||
self.deleted_custom_ops = set()
|
||||
|
Reference in New Issue
Block a user