mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[custom_ops] expose torch.library.register_torch_dispatch (#130261)"
This reverts commit bb9a73f767526e0d23c60360db5212b6bed0e8bc. Reverted https://github.com/pytorch/pytorch/pull/130261 on behalf of https://github.com/izaitsevfb due to depends on #130064 which needs to be reverted ([comment](https://github.com/pytorch/pytorch/pull/130261#issuecomment-2221569707))
This commit is contained in:
@ -27,7 +27,6 @@ __all__ = [
|
||||
"fallthrough_kernel",
|
||||
"impl_abstract",
|
||||
"register_fake",
|
||||
"register_torch_dispatch",
|
||||
"get_ctx",
|
||||
"custom_op",
|
||||
]
|
||||
@ -871,87 +870,6 @@ def register_autograd(
|
||||
lib.impl(opname, autograd_kernel, "Autograd", with_keyset=True)
|
||||
|
||||
|
||||
def register_torch_dispatch(
|
||||
op: _op_identifier,
|
||||
torch_dispatch_class: Any,
|
||||
func: Optional[Callable] = None,
|
||||
/,
|
||||
*,
|
||||
lib: Optional[Library] = None,
|
||||
):
|
||||
r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.
|
||||
|
||||
This allows for open registration to specify the behavior between the operator
|
||||
and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
|
||||
or the operator directly.
|
||||
|
||||
The ``torch_dispatch_class`` is either a Tensor subclass with ``__torch_dispatch__`` or a
|
||||
TorchDispatchMode.
|
||||
|
||||
If it is a Tensor subclass, we expect ``func`` to have the following signature:
|
||||
``(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any``
|
||||
|
||||
If it is a TorchDispatchMode, we expect ``func`` to have the following signature:
|
||||
``(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any``
|
||||
|
||||
``args`` and ``kwargs`` will have been normalized the same way they are
|
||||
in ``__torch_dispatch__`` (see :ref:`torch-dispatch-calling-convention`).
|
||||
|
||||
Examples:
|
||||
|
||||
>>> import torch
|
||||
>>>
|
||||
>>> @torch.library.custom_op("mylib::foo", mutates_args={})
|
||||
>>> def foo(x: torch.Tensor) -> torch.Tensor:
|
||||
>>> return x.clone()
|
||||
>>>
|
||||
>>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
|
||||
>>> def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
>>> return func(*args, **kwargs)
|
||||
>>>
|
||||
>>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
|
||||
>>> def _(mode, func, types, args, kwargs):
|
||||
>>> x, = args
|
||||
>>> return x + 1
|
||||
>>>
|
||||
>>> x = torch.randn(3)
|
||||
>>> y = foo(x)
|
||||
>>> assert torch.allclose(y, x)
|
||||
>>>
|
||||
>>> with MyMode():
|
||||
>>> y = foo(x)
|
||||
>>> assert torch.allclose(y, x + 1)
|
||||
|
||||
"""
|
||||
if not isinstance(
|
||||
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
|
||||
):
|
||||
raise ValueError(
|
||||
"register_torch_dispatch(op): got unexpected type for op: {type(op)}"
|
||||
)
|
||||
if isinstance(op, torch._ops.OpOverload):
|
||||
op = op._name
|
||||
opdef = _maybe_get_opdef(op)
|
||||
if opdef is not None:
|
||||
return opdef.register_torch_dispatch(torch_dispatch_class, func)
|
||||
assert isinstance(op, str)
|
||||
|
||||
def register(func):
|
||||
namespace, op_name = torch._library.utils.parse_namespace(op)
|
||||
if lib is None:
|
||||
use_lib = Library(namespace, "FRAGMENT")
|
||||
_keep_alive.append(use_lib)
|
||||
else:
|
||||
use_lib = lib
|
||||
use_lib._register_torch_dispatch_rule(op_name, torch_dispatch_class, func)
|
||||
return func
|
||||
|
||||
if func is None:
|
||||
return register
|
||||
else:
|
||||
return register(func)
|
||||
|
||||
|
||||
# If the op was defined in C++, then we want to make sure there was an
|
||||
# m.set_python_module(module, ...) call and that the module is the
|
||||
# same as the module that called torch.library.register_fake.
|
||||
|
Reference in New Issue
Block a user