mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit c1cdb1216b97970d903a6d6e9e7d0e2b4ffaef46. Reverted https://github.com/pytorch/pytorch/pull/77477 on behalf of https://github.com/malfet
95 lines
3.6 KiB
Python
95 lines
3.6 KiB
Python
import torch
|
|
import torch._ops
|
|
import torch.library
|
|
from typing import Callable, Union, Dict, Sequence
|
|
from torch.utils._pytree import tree_map
|
|
from collections import defaultdict
|
|
|
|
__all__ = ["decomposition_table", "register_decomposition", "get_decompositions"]
|
|
|
|
# TODO: relax key type here; torch registrations should be possible to; but
|
|
# right now this type is accurate
|
|
decomposition_table: Dict[torch._ops.OpOverload, Callable] = {}
|
|
|
|
|
|
meta_lib = torch.library.Library("aten", "IMPL", "Meta")
|
|
|
|
|
|
def register_decomposition(aten_op, registry=None, *, register_meta: bool = False):
|
|
"""
|
|
A decorator to register a function as a decomposition to the Python
|
|
decomposition table. Use it like this::
|
|
|
|
@register_decomposition(torch.ops.aten.clamp_min)
|
|
def clamp_min(x):
|
|
return torch.clamp(self, min=min)
|
|
|
|
If you are writing a new decomposition, consider contributing it
|
|
directly to PyTorch in torch._decomp.decompositions.
|
|
|
|
This API is experimental; we are almost certainly going to extend
|
|
the API when we make decompositions eligible for use in transforms (e.g.,
|
|
autograd) and not just backend tracing, where we then need to know if a
|
|
decomposition can be used to simulate a transform.
|
|
|
|
If `register_meta` is True, we will also register this function to the
|
|
Meta key in the dispatcher, so that it will be used to compute meta
|
|
tensors.
|
|
"""
|
|
def decomposition_decorator(f):
|
|
nonlocal registry
|
|
if registry is None:
|
|
registry = decomposition_table
|
|
|
|
def add_op_to_table(aten_op):
|
|
overloads = []
|
|
if isinstance(aten_op, torch._ops.OpOverload):
|
|
overloads.append(aten_op)
|
|
else:
|
|
assert isinstance(aten_op, torch._ops.OpOverloadPacket)
|
|
for ol in aten_op.overloads():
|
|
overloads.append(getattr(aten_op, ol))
|
|
for op_overload in overloads:
|
|
if op_overload in registry:
|
|
raise RuntimeError(f"duplicate registrations for {op_overload}")
|
|
registry[op_overload] = f
|
|
if register_meta:
|
|
meta_lib.impl(op_overload, f)
|
|
|
|
# To handle allowing multiple aten_ops at once
|
|
tree_map(add_op_to_table, aten_op)
|
|
return f
|
|
|
|
return decomposition_decorator
|
|
|
|
|
|
def get_decompositions(
|
|
aten_ops: Sequence[Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]]
|
|
) -> Dict[torch._ops.OpOverload, Callable]:
|
|
"""
|
|
Retrieve a dictionary of decompositions corresponding to the list of
|
|
operator overloads and overload packets passed as input. Overload
|
|
packets will include all decomposed overloads in the packet. If there is
|
|
no decomposition for a requested operator, it is silently ignored.
|
|
|
|
This API is experimental; we are almost certainly going to give an alternate,
|
|
more recommended formulation, where a user provides the set of operators
|
|
they know how to implement, and we provide decompositions for everything
|
|
not in this set.
|
|
"""
|
|
packets_to_overloads = defaultdict(list)
|
|
for opo in decomposition_table:
|
|
packets_to_overloads[opo.overloadpacket].append(opo)
|
|
decompositions = {}
|
|
for op in aten_ops:
|
|
if isinstance(op, torch._ops.OpOverloadPacket) and op in packets_to_overloads:
|
|
for op_overload in packets_to_overloads[op]:
|
|
decompositions[op_overload] = decomposition_table[op_overload]
|
|
elif isinstance(op, torch._ops.OpOverload) and op in decomposition_table:
|
|
decompositions[op] = decomposition_table[op]
|
|
return decompositions
|
|
|
|
# populate the table
|
|
import torch._decomp.decompositions
|
|
import torch._refs
|