mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Note: This is a re-land of https://github.com/pytorch/pytorch/pull/141791, which I reverted due to breaking some Meta-internal tests - an internal ET delegate did not handle the non-decomposed upsample_nearest2d, and it was not caught in CI. I've resolved that issue and should be ready to safely re-land. Summary: As upsample_bilinear2d.vec and upsample_nearest2d.vec are core ATen ops, they should not be decomposed by default in the export path. Because the operators have CompositeImplicitAutograd dispatch, their decomposition is registered by default. This change adds an override list for CIA decompositions being registered in the default decomp table. In the long-term, we likely will want to exclude decompositions for all core-tagged CIA ops, but this will require all consumers to be ready to handle the remaining two ops, avg_pool1d, and adaptive_avg_pool1d. Until they are ready, I believe an explicit override list is the safest option. Additionally, I've also removed the ExecuTorch XNNPACK delegate ConvertToUpsampleBilinear2d pass, as the pass breaks (and is not needed), given that the op is not decomposed. The purpose of this pass was originally to pattern match the decomposition and recompose it, but this is no longer necessary. Test Plan: Added a new test (`test_default_decomposition_core_cia_ops`) in test_export.py to verify that upsample_bilinear2d.vec (and in the future, other core-tagged CIA ops) are not decomposed by default. Also, I manually validated end to end with ExecuTorch that the op is not decomposed in to_edge (see N6238522). ``` buck test //caffe2/test:test_export -- test_default_decomposition_core_cia_ops ``` Differential Revision: D69625112 Pull Request resolved: https://github.com/pytorch/pytorch/pull/147153 Approved by: https://github.com/manuelcandales
157 lines
5.4 KiB
Python
157 lines
5.4 KiB
Python
# mypy: allow-untyped-defs
|
|
from typing import Callable
|
|
|
|
import torch
|
|
from torch._export.utils import (
|
|
_collect_all_valid_cia_ops,
|
|
_collect_all_valid_cia_ops_for_aten_namespace,
|
|
_get_decomp_for_cia,
|
|
_is_aten_op,
|
|
)
|
|
|
|
|
|
__all__ = ["CustomDecompTable"]
|
|
|
|
|
|
"""
|
|
Core ATen ops with Composite Implicit Autograd dispatch that should be excluded from decomposition
|
|
by default. The decomposition logic should eventually exclude all core-tagged CIA ops, but until all
|
|
backends are ready, this list allows opt-in one at a time.
|
|
"""
|
|
PRESERVED_ATEN_CIA_OPS = {
|
|
torch.ops.aten.upsample_bilinear2d.vec,
|
|
torch.ops.aten.upsample_nearest2d.vec,
|
|
}
|
|
|
|
|
|
class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]):
    """
    Custom dictionary used as the decomp_table in export.

    In the current export flow, the only way to *preserve* an op is to delete
    it from the decomp table. That is problematic for custom ops, because we
    cannot know when a custom op is actually registered with the dispatcher.
    To cope, this table records operations on custom ops lazily and only
    materializes them when the full table is really needed (i.e. when the
    decomposition pass runs).

    Invariants:
    1. All aten decomps are loaded at init time.
    2. Any read access materializes ALL ops, to maximize the chance that the
       dispatcher has picked up the custom ops by then.
    3. Write accesses do not necessarily materialize.
    4. The table is materialized one final time during export, right before
       run_decompositions() is invoked.
    """

    def __init__(self):
        super().__init__()
        from torch._decomp import _core_aten_decompositions_post_autograd

        # Aten decompositions are loaded eagerly, up front.
        self.decomp_table = _core_aten_decompositions_post_autograd()

        # Register CIA decomps for aten ops, skipping the explicitly
        # preserved core-tagged CIA ops.
        for aten_op in _collect_all_valid_cia_ops_for_aten_namespace():
            if aten_op in PRESERVED_ATEN_CIA_OPS:
                continue
            self.decomp_table[aten_op] = _get_decomp_for_cia(aten_op)

        # Pending deletions of custom ops that have not been materialized yet.
        self.deleted_custom_ops = set()
        # When True, no pending operations remain in the table.
        self.has_materialized = False

    def __getitem__(self, key):
        # Reads force materialization (invariant 2).
        self._materialize_if_needed()
        return self.decomp_table[key]

    def __setitem__(self, key, value) -> None:
        self.decomp_table[key] = value
        # Re-adding a key cancels any pending deletion for it.
        self.deleted_custom_ops.discard(key)

    def keys(self):
        self._materialize_if_needed()
        return self.decomp_table.keys()

    def __delitem__(self, key) -> None:
        # Deletion shares pop()'s bookkeeping for pending custom-op removals.
        self.pop(key)

    def update(self, other_dict):  # type: ignore[override]
        # Bulk write: entries go straight into the underlying table.
        for key, value in other_dict.items():
            self.decomp_table[key] = value

    def __missing__(self, key) -> bool:
        return key not in self

    def __contains__(self, key) -> bool:
        self._materialize_if_needed()
        return key in self.decomp_table

    def __len__(self) -> int:
        self._materialize_if_needed()
        return len(self.decomp_table)

    def __iter__(self):
        self._materialize_if_needed()
        return iter(self.decomp_table)

    def __reversed__(self):
        self._materialize_if_needed()
        return reversed(self.decomp_table)

    def copy(self) -> "CustomDecompTable":
        """Return a shallow copy carrying over table, pending deletions, and
        materialization state."""
        duplicate = CustomDecompTable()
        duplicate.decomp_table = self.decomp_table.copy()
        duplicate.deleted_custom_ops = self.deleted_custom_ops.copy()
        duplicate.has_materialized = self.has_materialized
        return duplicate

    def pop(self, *args):
        def _remove(key):
            # Aten ops are always loaded up front, so pop them directly.
            if _is_aten_op(key):
                return self.decomp_table.pop(key)

            if key in self.decomp_table:
                # Even though it is materialized, record the deletion so the
                # user's intent is respected on the next materialization.
                self.deleted_custom_ops.add(key)
                return self.decomp_table.pop(key)

            if key in self.deleted_custom_ops:
                raise KeyError(f"{key} doesn't exist in the table")

            # The key is neither materialized nor deleted yet: pretend it was
            # present by recording the deletion and returning the decomp it
            # would have mapped to.
            self.deleted_custom_ops.add(key)
            return _get_decomp_for_cia(key)

        if len(args) == 1:
            return _remove(args[0])

        if len(args) == 2:
            # dict.pop(key, default) semantics: swallow the miss.
            try:
                return _remove(args[0])
            except KeyError:
                return args[1]

    def items(self):
        self._materialize_if_needed()
        return self.decomp_table.items()

    def materialize(self) -> dict[torch._ops.OperatorBase, Callable]:
        """Resolve all pending custom-op operations and return a plain dict
        snapshot of the full decomp table."""
        for op in _collect_all_valid_cia_ops():
            # Aten ops were handled at init; existing entries are kept as-is.
            if _is_aten_op(op) or op in self.decomp_table:
                continue
            if op not in self.deleted_custom_ops:
                self.decomp_table[op] = _get_decomp_for_cia(op)

        self.has_materialized = True
        # All pending deletions have now been applied.
        self.deleted_custom_ops = set()
        return dict(self.decomp_table)

    def _materialize_if_needed(self) -> None:
        if self.has_materialized:
            return
        self.materialize()