mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-25 08:11:06 +08:00 
			
		
		
		
	See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127838 Approved by: https://github.com/oulgen
		
			
				
	
	
		
			874 lines
		
	
	
		
			36 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			874 lines
		
	
	
		
			36 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # mypy: allow-untyped-defs
 | |
| import dataclasses
 | |
| import functools
 | |
| import inspect
 | |
| import sys
 | |
| import typing
 | |
| import weakref
 | |
| 
 | |
| from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy
 | |
| 
 | |
| import torch
 | |
| import torch._C as _C
 | |
| import torch.library as library
 | |
| from torch._library.abstract_impl import AbstractImplCtx
 | |
| from torch.library import get_ctx
 | |
| 
 | |
| from .autograd import autograd_kernel_indirection, construct_autograd_kernel
 | |
| import torch._library.infer_schema
 | |
| from torch._library.infer_schema import infer_schema
 | |
| 
 | |
| """
 | |
| For a detailed guide on custom ops, please see
 | |
| https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
 | |
| 
 | |
| This file includes pieces of the implementation of our custom operator API.
 | |
| """
 | |
| 
 | |
| __all__ = ["custom_op", "CustomOp", "get_ctx", "AbstractImplCtx"]
 | |
| 
 | |
| 
 | |
| SUPPORTED_DEVICE_TYPE_TO_KEY = {
 | |
|     "cpu": "CPU",
 | |
|     "cuda": "CUDA",
 | |
| }
 | |
| 
 | |
| # We will not let users register CustomOps with anything that could look like
 | |
| # PyTorch internals to avoid confusion.
 | |
| RESERVED_NS = {
 | |
|     "prim",
 | |
|     "prims",
 | |
|     "aten",
 | |
|     "at",
 | |
|     "torch",
 | |
|     "pytorch",
 | |
| }
 | |
| 
 | |
| 
 | |
| def custom_op(
 | |
|     qualname: str, manual_schema: typing.Optional[str] = None
 | |
| ) -> typing.Callable:
 | |
|     r"""Creates a new CustomOp object.
 | |
| 
 | |
|     WARNING: if you're a user, please do not use this directly
 | |
|     (instead use the torch._custom_ops APIs).
 | |
|     Also please see the following for a detailed guide on custom ops.
 | |
|     https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
 | |
| 
 | |
|     In PyTorch, defining an op (short for "operator") is a two step-process:
 | |
|     - we need to define (create) the op
 | |
|     - we need to implement behavior for how the operator interacts with
 | |
|       various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
 | |
| 
 | |
|     This entrypoint defines the CustomOp object (the first step);
 | |
|     you must then perform the second step by calling various methods on
 | |
|     the CustomOp object.
 | |
| 
 | |
|     This API is used as a decorator (see examples).
 | |
| 
 | |
|     Arguments:
 | |
|         qualname (str): Should be a string that looks like
 | |
|             "namespace::operator_name". Operators in PyTorch need a namespace to
 | |
|             avoid name collisions; a given operator may only be created once.
 | |
|             If you are writing a Python library, we recommend the namespace to
 | |
|             be the name of your top-level module. The operator_name must be
 | |
|             the same as the name of the function you pass to custom_op
 | |
|             (see examples).
 | |
|         manual_schema (Optional[str]): Each PyTorch operator needs a schema that
 | |
|             tells PyTorch the types of the inputs/outputs. If None (default),
 | |
|             we will infer the schema from the type annotations on the function
 | |
|             (see examples). Otherwise, if you don't want to use type annotations,
 | |
|             you may provide us the schema string.
 | |
| 
 | |
|     Example::
 | |
|         >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
 | |
|         >>> import numpy as np
 | |
|         >>> from torch import Tensor
 | |
|         >>>
 | |
|         >>> # Step 1: define the CustomOp.
 | |
|         >>> # We need to provide the decorator a "prototype function"
 | |
|         >>> # (a function with Python ellipses as the body).
 | |
|         >>> @custom_op("my_library::numpy_sin")
 | |
|         >>> def numpy_sin(x: Tensor) -> Tensor:
 | |
|         >>>     ...
 | |
|         >>>
 | |
|         >>> # numpy_sin is now an instance of class CustomOp
 | |
|         >>> print(type(numpy_sin))
 | |
|         >>>
 | |
|         >>> # Step 2: Register an implementation for various PyTorch subsystems
 | |
|         >>>
 | |
|         >>> # Register an implementation for CPU tensors
 | |
|         >>> @numpy_sin.impl('cpu')
 | |
|         >>> def numpy_sin_impl_cpu(x):
 | |
|         >>>     return torch.from_numpy(np.sin(x.numpy()))
 | |
|         >>>
 | |
|         >>> # Register an implementation for CUDA tensors
 | |
|         >>> @numpy_sin.impl('cuda')
 | |
|         >>> def numpy_sin_impl_cuda(x):
 | |
|         >>>     return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
 | |
|         >>>
 | |
|         >>> x = torch.randn(3)
 | |
|         >>> numpy_sin(x)  # calls numpy_sin_impl_cpu
 | |
|         >>>
 | |
|         >>> x_cuda = x.cuda()
 | |
|         >>> numpy_sin(x)  # calls numpy_sin_impl_cuda
 | |
| 
 | |
|     """
 | |
| 
 | |
|     def inner(func):
 | |
|         if not inspect.isfunction(func):
 | |
|             raise ValueError(
 | |
|                 f"custom_op(...)(func): Expected `func` to be a Python "
 | |
|                 f"function, got: {type(func)}"
 | |
|             )
 | |
| 
 | |
|         ns, name = parse_qualname(qualname)
 | |
|         validate_namespace(ns)
 | |
|         if func.__name__ != name:
 | |
|             raise ValueError(
 | |
|                 f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
 | |
|                 f"to have name '{name}' but got '{func.__name__}'. "
 | |
|                 f"Please either change the name of `func` or the qualname that "
 | |
|                 f"is passed to `custom_op`"
 | |
|             )
 | |
| 
 | |
|         schema = infer_schema(func) if manual_schema is None else manual_schema
 | |
|         schema_str = f"{name}{schema}"
 | |
|         function_schema = FunctionSchema.parse(schema_str)
 | |
|         validate_schema(function_schema)
 | |
|         if manual_schema is not None:
 | |
|             validate_function_matches_schema(function_schema, func)
 | |
| 
 | |
|         lib = library.Library(ns, "FRAGMENT")
 | |
|         lib.define(schema_str)
 | |
|         ophandle = find_ophandle_or_throw(ns, function_schema.name)
 | |
|         result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
 | |
| 
 | |
|         result.__name__ = func.__name__
 | |
|         result.__module__ = func.__module__
 | |
|         result.__doc__ = func.__doc__
 | |
| 
 | |
|         library.impl(lib, result._opname, "Autograd")(
 | |
|             autograd_kernel_indirection(weakref.proxy(result))
 | |
|         )
 | |
| 
 | |
|         torch._C._dispatch_set_report_error_callback(
 | |
|             ophandle, functools.partial(report_error_callback, weakref.proxy(result))
 | |
|         )
 | |
| 
 | |
|         return result
 | |
| 
 | |
|     return inner
 | |
| 
 | |
| 
 | |
| # Global dictionary holding references to all CustomOp objects
 | |
| # Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
 | |
| # Used to query the CustomOp associated with a specific C++ dispatcher operator.
 | |
| # An example usage is FakeTensor: FakeTensor checks if a specific operator
 | |
| # has an implementation registered via the CustomOp API.
 | |
| # Indexed by qualname (e.g. aten::foo)
 | |
| global_registry: typing.Dict[str, "CustomOp"] = {}
 | |
| 
 | |
| 
 | |
| class CustomOp:
 | |
|     r"""Class for custom operators in PyTorch.
 | |
| 
 | |
|     Use the CustomOp API to create user-defined custom operators that behave
 | |
|     just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it
 | |
|     comes to various PyTorch subsystems (like torch.compile).
 | |
| 
 | |
|     To construct a `CustomOp`, use `custom_op`.
 | |
|     """
 | |
| 
 | |
|     def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
 | |
|         super().__init__()
 | |
|         if not _private_access:
 | |
|             raise RuntimeError(
 | |
|                 "The CustomOp constructor is private and we do not guarantee "
 | |
|                 "BC for it. Please use custom_op(...) to create a CustomOp object"
 | |
|             )
 | |
|         name = f"{cpp_ns}::{operator_name}"
 | |
|         self._schema = schema
 | |
|         self._cpp_ns = cpp_ns
 | |
|         self._lib: library.Library = lib
 | |
|         self._ophandle: _C._DispatchOperatorHandle = ophandle
 | |
|         # Has the name of the op, e.g. "foo". We cache here for convenience.
 | |
|         self._opname: str = operator_name
 | |
|         # this is _opname but with namespace. e.g. "custom::foo"
 | |
|         self._qualname: str = name
 | |
|         self.__name__ = None  # mypy requires this
 | |
|         # NB: Some of these impls are registered as kernels to DispatchKeys.
 | |
|         # Modifying the _impls dict directly won't do anything in that case.
 | |
|         self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
 | |
|         # See NOTE [CustomOp autograd kernel indirection]
 | |
|         self._registered_autograd_kernel_indirection = False
 | |
| 
 | |
|         global_registry[self._qualname] = self
 | |
| 
 | |
|     def _register_autograd_kernel_indirection(self):
 | |
|         assert not self._registered_autograd_kernel_indirection
 | |
|         self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
 | |
|         self._registered_autograd_kernel_indirection = True
 | |
| 
 | |
|     # Records the impl and the source location in self._impls
 | |
|     # Note that this doesn't cause torch.library to use the impl, that
 | |
|     # needs to be done in a separate self._lib.impl call.
 | |
|     def _register_impl(self, kind, func, stacklevel=2):
 | |
|         if self._has_impl(kind):
 | |
|             func_and_location = self._impls[kind]
 | |
|             assert func_and_location is not None  # Pacify mypy
 | |
|             location = func_and_location.location
 | |
|             raise RuntimeError(
 | |
|                 f"Attempting to register a {kind} impl for operator {self._qualname} "
 | |
|                 f"that already has a {kind} impl registered from Python at "
 | |
|                 f"{location}. This is not supported."
 | |
|             )
 | |
|         frame = inspect.getframeinfo(sys._getframe(stacklevel))
 | |
|         location = f"{frame.filename}:{frame.lineno}"
 | |
|         self._impls[kind] = FuncAndLocation(func, location)
 | |
| 
 | |
|     def _get_impl(self, kind):
 | |
|         return self._impls[kind]
 | |
| 
 | |
|     def _has_impl(self, kind):
 | |
|         return kind in self._impls
 | |
| 
 | |
|     def _destroy(self):
 | |
|         # NOTE: [CustomOp lifetime]
 | |
|         # A CustomOp, once created, lives forever. The mechanism is that the
 | |
|         # global registry holds a reference to it. However, to make testing
 | |
|         # easier, we want to be able to destroy CustomOp objects.
 | |
|         # CustomOp._destroy does the job, though it leaves the CustomOp
 | |
|         # in a garbage state.
 | |
|         del self._lib
 | |
| 
 | |
|         opnamespace = getattr(torch.ops, self._cpp_ns)
 | |
|         if hasattr(opnamespace, self._opname):
 | |
|             delattr(opnamespace, self._opname)
 | |
| 
 | |
|         del global_registry[self._qualname]
 | |
| 
 | |
|     def __repr__(self):
 | |
|         return f'<CustomOp(op="{self._qualname}")>'
 | |
| 
 | |
|     def __call__(self, *args, **kwargs):
 | |
|         # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
 | |
|         # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
 | |
|         # issues from caching operators that make testing CustomOp difficult).
 | |
|         result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
 | |
|         return result
 | |
| 
 | |
|     def impl(
 | |
|         self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
 | |
|     ) -> typing.Callable:
 | |
|         r"""Register an implementation for a device type for this CustomOp object.
 | |
| 
 | |
|         WARNING: if you're a user, please do not use this directly
 | |
|         (instead use the torch._custom_ops APIs).
 | |
|         Also please see the following for a detailed guide on custom ops.
 | |
|         https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
 | |
| 
 | |
|         If the CustomOp is passed multiple Tensor inputs with different device
 | |
|         types, it will dispatch to the registered implementation for the highest
 | |
|         priority device type among those present.
 | |
|         The supported device types, in order of priority, are {'cuda', 'cpu'}.
 | |
| 
 | |
|         This API is used as a decorator (see examples).
 | |
| 
 | |
|         Arguments:
 | |
|             device_types (str or Iterable[str]): the device type(s) to register the function for.
 | |
| 
 | |
|         Examples::
 | |
|             >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
 | |
|             >>> import numpy as np
 | |
|             >>> from torch import Tensor
 | |
|             >>>
 | |
|             >>> @custom_op("my_library::numpy_cos")
 | |
|             >>> def numpy_cos(x: Tensor) -> Tensor:
 | |
|             >>>     ...
 | |
|             >>>
 | |
|             >>> # Register an implementation for CPU Tensors
 | |
|             >>> @numpy_cos.impl('cpu')
 | |
|             >>> def numpy_cos_impl_cpu(x):
 | |
|             >>>     return torch.from_numpy(np.cos(x.numpy()))
 | |
|             >>>
 | |
|             >>> # Register an implementation for CUDA Tensors
 | |
|             >>> @numpy_cos.impl('cuda')
 | |
|             >>> def numpy_cos_impl_cuda(x):
 | |
|             >>>     return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
 | |
|             >>>
 | |
|             >>> x = torch.randn(3)
 | |
|             >>> numpy_cos(x)  # calls numpy_cos_impl_cpu
 | |
|             >>>
 | |
|             >>> x_cuda = x.cuda()
 | |
|             >>> numpy_cos(x)  # calls numpy_cos_impl_cuda
 | |
| 
 | |
|         """
 | |
|         if isinstance(device_types, str):
 | |
|             device_types = [device_types]
 | |
|         for device_type in device_types:
 | |
|             validate_device_type(device_type)
 | |
| 
 | |
|         def inner(f):
 | |
|             for device_type in set(device_types):
 | |
|                 self._check_doesnt_have_library_impl(device_type)
 | |
|                 self._register_impl(device_type, f, stacklevel=_stacklevel)
 | |
|                 dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
 | |
|                 library.impl(self._lib, self._opname, dispatch_key)(f)
 | |
|             return f
 | |
| 
 | |
|         return inner
 | |
| 
 | |
|     def _check_doesnt_have_library_impl(self, device_type):
 | |
|         if self._has_impl(device_type):
 | |
|             return
 | |
|         key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
 | |
|         if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
 | |
|             raise RuntimeError(
 | |
|                 f"impl(..., device_types={device_type}): the operator {self._qualname} "
 | |
|                 f"already has an implementation for this device type via a "
 | |
|                 f"pre-existing torch.library or TORCH_LIBRARY registration.")
 | |
| 
 | |
|     def impl_factory(self) -> typing.Callable:
 | |
|         r"""Register an implementation for a factory function."""
 | |
| 
 | |
|         def inner(f):
 | |
|             self._register_impl("factory", f)
 | |
|             library.impl(self._lib, self._opname, "BackendSelect")(f)
 | |
|             return f
 | |
| 
 | |
|         return inner
 | |
| 
 | |
|     def impl_abstract(self, _stacklevel=2) -> typing.Callable:
 | |
|         r"""Register an abstract implementation for this operator.
 | |
| 
 | |
|         WARNING: please do not use this directly (and instead use the torch._custom_ops
 | |
|         APIs). Also please see the following for a detailed guide on custom ops.
 | |
|         https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
 | |
| 
 | |
|         An "abstract implementation" specifies the behavior of this operator on
 | |
|         Tensors that carry no data. Given some input Tensors with certain properties
 | |
|         (sizes/strides/storage_offset/device), it specifies what the properties of
 | |
|         the output Tensors are.
 | |
| 
 | |
|         The abstract implementation has the same signature as the operator.
 | |
|         It is run for both FakeTensors and meta tensors. To write an abstract
 | |
|         implementation, assume that all Tensor inputs to the operator are
 | |
|         regular CPU/CUDA/Meta tensors, but they do not have storage, and
 | |
|         you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
 | |
|         The abstract implementation must consist of only PyTorch operations
 | |
|         (and may not directly access the storage or data of any input or
 | |
|         intermediate Tensors).
 | |
| 
 | |
|         This API is used as a decorator (see examples).
 | |
| 
 | |
|         Examples::
 | |
|             >>> import numpy as np
 | |
|             >>> from torch import Tensor
 | |
|             >>>
 | |
|             >>> # Example 1: an operator without data-dependent output shape
 | |
|             >>> @custom_op('my_library::custom_linear')
 | |
|             >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
 | |
|             >>>     ...
 | |
|             >>>
 | |
|             >>> @custom_linear.impl_abstract()
 | |
|             >>> def custom_linear_abstract(x, weight):
 | |
|             >>>     assert x.dim() == 2
 | |
|             >>>     assert weight.dim() == 2
 | |
|             >>>     assert bias.dim() == 1
 | |
|             >>>     assert x.shape[1] == weight.shape[1]
 | |
|             >>>     assert weight.shape[0] == bias.shape[0]
 | |
|             >>>     assert x.device == weight.device
 | |
|             >>>
 | |
|             >>>     return (x @ weight.t()) + bias
 | |
|             >>>
 | |
|             >>> # Example 2: an operator with data-dependent output shape
 | |
|             >>> @custom_op('my_library::custom_nonzero')
 | |
|             >>> def custom_nonzero(x: Tensor) -> Tensor:
 | |
|             >>>     ...
 | |
|             >>>
 | |
|             >>> @custom_nonzero.impl_abstract()
 | |
|             >>> def custom_nonzero_abstract(x):
 | |
|             >>>     # Number of nonzero-elements is data-dependent.
 | |
|             >>>     # Since we cannot peek at the data in an abstract impl,
 | |
|             >>>     # we use the ctx object to construct a new symint that
 | |
|             >>>     # represents the data-dependent size.
 | |
|             >>>     ctx = torch._custom_op.get_ctx()
 | |
|             >>>     nnz = ctx.create_unbacked_symint()
 | |
|             >>>     shape = [x.dim(), nnz]
 | |
|             >>>     result = x.new_empty(shape, dtype=torch.long)
 | |
|             >>>     return result
 | |
|             >>>
 | |
|             >>> @custom_nonzero.impl(['cpu', 'cuda'])
 | |
|             >>> def custom_nonzero_impl(x):
 | |
|             >>>     x_np = to_numpy(x)
 | |
|             >>>     res = np.stack(np.nonzero(x_np), axis=1)
 | |
|             >>>     # unbacked symbolic ints in PyTorch must be >= 2, so we
 | |
|             >>>     # constrain the range to at least 2
 | |
|             >>>     if res.shape[0] <= 1:
 | |
|             >>>         raise RuntimeError("not supported")
 | |
|             >>>     return torch.tensor(res, device=x.device)
 | |
| 
 | |
|         """
 | |
| 
 | |
|         def inner(f):
 | |
|             self._check_doesnt_have_library_meta_impl()
 | |
|             self._register_impl("abstract", f, stacklevel=_stacklevel)
 | |
|             location = self._get_impl("abstract").location
 | |
| 
 | |
|             qualname = self._qualname
 | |
| 
 | |
|             # Handle DispatchKey.Meta registration
 | |
|             @functools.wraps(f)
 | |
|             def f_with_ctx(*args, **kwargs):
 | |
|                 def error_on_ctx():
 | |
|                     raise RuntimeError(
 | |
|                         f"Attempted to call get_ctx() for the meta implementation "
 | |
|                         f"for {qualname}."
 | |
|                         f"You have presumably called get_ctx() because the operator "
 | |
|                         f"has a data-dependent output shape; if so, there is no "
 | |
|                         f"such meta implementation and this error is the correct "
 | |
|                         f"behavior. Otherwise, please remove the call to get_ctx() "
 | |
|                         f"in the implementation registered with impl_abstract "
 | |
|                         f"at {location}"
 | |
|                     )
 | |
| 
 | |
|                 with torch._library.abstract_impl.set_ctx_getter(error_on_ctx):
 | |
|                     return f(*args, **kwargs)
 | |
| 
 | |
|             self._lib.impl(self._opname, f_with_ctx, "Meta")
 | |
|             return f
 | |
| 
 | |
|         return inner
 | |
| 
 | |
|     def _check_can_register_backward(self):
 | |
|         def error(detail):
 | |
|             raise RuntimeError(
 | |
|                 f"Cannot use torch._custom_ops APIs to register backward "
 | |
|                 f"formula for {detail}. Got operator "
 | |
|                 f"{self._qualname} with schema: {schema}"
 | |
|             )
 | |
| 
 | |
|         schema = self._schema
 | |
|         if schema.kind() != SchemaKind.functional:
 | |
|             error("non-functional operator")
 | |
| 
 | |
|         rets = schema.returns
 | |
|         if not schema.returns:
 | |
|             error("operator with no returns")
 | |
| 
 | |
|         assert len(rets) > 0
 | |
|         is_non_mutating_view = any(
 | |
|             r.annotation is not None and not r.annotation.is_write for r in rets
 | |
|         )
 | |
|         if is_non_mutating_view:
 | |
|             error("operator that returns views")
 | |
| 
 | |
|         # We make assumptions about the schema's return types.
 | |
|         allowed_return_types = {
 | |
|             BaseType(BaseTy.int): "int",
 | |
|             BaseType(BaseTy.SymInt): "SymInt",
 | |
|             BaseType(BaseTy.bool): "bool",
 | |
|             BaseType(BaseTy.float): "float",
 | |
|             BaseType(BaseTy.Tensor): "Tensor",
 | |
|             ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
 | |
|         }
 | |
|         for ret in schema.returns:
 | |
|             if ret.type in allowed_return_types:
 | |
|                 continue
 | |
|             error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")
 | |
| 
 | |
|     def _check_doesnt_have_library_autograd_impl(self):
 | |
|         if self._registered_autograd_kernel_indirection:
 | |
|             return
 | |
| 
 | |
|         if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
 | |
|             raise RuntimeError(
 | |
|                 f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
 | |
|                 f"already has an implementation for this device type via a "
 | |
|                 f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
 | |
|                 f"CompositeImplicitAutograd operators do not need an autograd formula; "
 | |
|                 f"instead, the operator will decompose into its constituents and those "
 | |
|                 f"can have autograd formulas defined on them.")
 | |
| 
 | |
|         # We can improve this by adding "all Autograd<BACKEND> keys", but
 | |
|         # realistically people will just be using this API for CPU/CUDA for now.
 | |
|         for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
 | |
|             if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
 | |
|                 raise RuntimeError(
 | |
|                     f"impl_backward/impl_save_for_backward: "
 | |
|                     f"the operator {self._qualname} already has an Autograd kernel "
 | |
|                     f"registered to DispatchKey::{key} vi a pre-existing "
 | |
|                     f"torch.library or TORCH_LIBRARY registration. Please either "
 | |
|                     f"remove those registrations or don't use the torch._custom_ops APIs")
 | |
| 
 | |
|     def _check_doesnt_have_library_meta_impl(self):
 | |
|         if self._has_impl("abstract"):
 | |
|             return
 | |
| 
 | |
|         # If the user's operator is CompositeExplicitAutograd,
 | |
|         # allow them to impl_abstract. This is being pragmatic
 | |
|         # (existing custom ops may have CompositeExplicitAutograd
 | |
|         # registration that don't work with Meta kernels, so this
 | |
|         # gives them an escape hatch).
 | |
|         if (
 | |
|             _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd")
 | |
|             and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta")
 | |
|         ):
 | |
|             return
 | |
| 
 | |
|         # Otherwise, if the user's already has a Meta kernel or their
 | |
|         # op is CompositeImplicitAutograd or some other alias dispatch key,
 | |
|         # raise.
 | |
| 
 | |
|         # Special case for CompositeImplicitAutograd
 | |
|         if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
 | |
|             raise RuntimeError(
 | |
|                 f"impl_abstract(...): the operator {self._qualname} "
 | |
|                 f"already has an implementation for this device type via a "
 | |
|                 f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
 | |
|                 f"CompositeImplicitAutograd operators do not need an abstract impl; "
 | |
|                 f"instead, the operator will decompose into its constituents and those "
 | |
|                 f"can have abstract impls defined on them.")
 | |
| 
 | |
|         if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
 | |
|             raise RuntimeError(
 | |
|                 f"impl_abstract(...): the operator {self._qualname} "
 | |
|                 f"already has an DispatchKey::Meta implementation via a "
 | |
|                 f"pre-existing torch.library or TORCH_LIBRARY registration. "
 | |
|                 f"Please either remove that registration or don't call impl_abstract.")
 | |
| 
 | |
|     # NOTE ["backward", "save_for_backward", and "autograd"]
 | |
|     # As a part of the explicit autograd API, a user must provide us
 | |
|     # a "save_for_backward" function and a "backward" function.
 | |
|     # When both of these have been provided, then we automatically
 | |
|     # construct the "autograd" kernel.
 | |
|     def _register_autograd_kernel(self):
 | |
|         assert self._has_impl("backward")
 | |
|         assert self._has_impl("save_for_backward")
 | |
|         kernel = construct_autograd_kernel(
 | |
|             self._schema,
 | |
|             self._output_differentiability,
 | |
|             self,
 | |
|             get_op(self._qualname),
 | |
|             self._get_impl("save_for_backward").func,
 | |
|             self._get_impl("backward").func)
 | |
|         self._register_impl("autograd", kernel)
 | |
| 
 | |
|     def impl_save_for_backward(self, _stacklevel=2):
 | |
|         r"""Register a function that tells us what to save for backward.
 | |
| 
 | |
|         Please see impl_backward for more details.
 | |
|         """
 | |
|         def inner(f):
 | |
|             self._check_can_register_backward()
 | |
|             self._check_doesnt_have_library_autograd_impl()
 | |
|             if not self._registered_autograd_kernel_indirection:
 | |
|                 self._register_autograd_kernel_indirection()
 | |
|             self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
 | |
|             if self._has_impl("backward"):
 | |
|                 self._register_autograd_kernel()
 | |
|         return inner
 | |
| 
 | |
|     def impl_backward(self, output_differentiability=None, _stacklevel=2):
 | |
|         r"""Registers a backward formula.
 | |
| 
 | |
|         WARNING: if you're a user, please do not use this directly
 | |
|         (instead use the torch._custom_ops APIs).
 | |
|         Also please see the following for a detailed guide on custom ops.
 | |
|         https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
 | |
| 
 | |
|         In order for the CustomOp to work with autograd, you need to register
 | |
|         a backward formula. There are two pieces to this:
 | |
|         1. You must give us a function to specify what to save for backward.
 | |
|            Call this the "save for backward" function.
 | |
|         2. You must give us a function that computes gradients. Call this the
 | |
|            "backward" function.
 | |
| 
 | |
|         Use `impl_save_for_backward` to define a "save for backward" function
 | |
|         that specifies what gets saved for backward. The function should accept
 | |
|         two arguments ``(inputs, output)`` and return the quantities to be saved
 | |
|         for backward.
 | |
| 
 | |
|         During runtime, when you call the CustomOp, PyTorch will invoke the
 | |
|         "save for backward" function with the inputs and output of the CustomOp.
 | |
| 
 | |
|         Use `impl_backward` to define the "backward" function. The backward
 | |
|         function must accept ``(ctx, saved, *grads)``:
 | |
|         - ``ctx`` is a context object where we may provide information
 | |
|         - ``saved`` is exactly what gets returned from the "save for backward"
 | |
|           function
 | |
|         - ``grads`` is one or more gradients. The number of gradients matches
 | |
|           the number of outputs of the CustomOp.
 | |
| 
 | |
|         The backward function must return a dict that maps the name of
 | |
|         an input to the CustomOp to its corresponding gradient. All inputs that
 | |
|         were declared to be Tensors in the CustomOp definition must be accounted
 | |
|         for in the dict. The gradient may be a Tensor or None.
 | |
| 
 | |
|         """
 | |
|         if output_differentiability is not None:
 | |
|             def yell():
 | |
|                 raise RuntimeError(
 | |
|                     f"impl_backward(output_differentiability): expected "
 | |
|                     f"output_differentiability to be a list of bools with "
 | |
|                     f"length equal to the number of outputs of this CustomOp "
 | |
|                     f"got: {output_differentiability}")
 | |
| 
 | |
|             if not isinstance(output_differentiability, list):
 | |
|                 yell()
 | |
|             for diff in output_differentiability:
 | |
|                 if not isinstance(diff, bool):
 | |
|                     yell()
 | |
|             if len(self._schema.returns) != len(output_differentiability):
 | |
|                 yell()
 | |
| 
 | |
|         def inner(f):
 | |
|             self._check_can_register_backward()
 | |
|             self._check_doesnt_have_library_autograd_impl()
 | |
|             if not self._registered_autograd_kernel_indirection:
 | |
|                 self._register_autograd_kernel_indirection()
 | |
|             self._register_impl("backward", f, stacklevel=_stacklevel)
 | |
|             self._output_differentiability = output_differentiability
 | |
|             if self._has_impl("save_for_backward"):
 | |
|                 self._register_autograd_kernel()
 | |
|         return inner
 | |
| 
 | |
| 
 | |
| @dataclasses.dataclass
 | |
| class FuncAndLocation:
 | |
|     func: typing.Callable
 | |
|     location: str
 | |
| 
 | |
| 
 | |
| def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
 | |
|     overload_name = (
 | |
|         "" if operator_name.overload_name is None else operator_name.overload_name
 | |
|     )
 | |
|     return _C._dispatch_find_schema_or_throw(
 | |
|         f"{cpp_ns}::{str(operator_name.name)}", overload_name
 | |
|     )
 | |
| 
 | |
| 
 | |
| def validate_namespace(ns: str) -> None:
 | |
|     if "." in ns:
 | |
|         raise ValueError(
 | |
|             f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
 | |
|             f"valid variable name)"
 | |
|         )
 | |
|     if ns in RESERVED_NS:
 | |
|         raise ValueError(
 | |
|             f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
 | |
|             f"please choose something else. "
 | |
|         )
 | |
| 
 | |
| def validate_schema(schema: FunctionSchema) -> None:
 | |
|     if not torch._library.utils.is_functional_schema(schema):
 | |
|         raise ValueError(
 | |
|             f"custom_op only supports functional operators "
 | |
|             f"(ops that do not mutate any inputs, do not return "
 | |
|             f"views of the inputs, and has at least one return). "
 | |
|             f"Got the following non-functional schema: {schema}"
 | |
|         )
 | |
| 
 | |
|     # For simplicity: don't allow self arguments
 | |
|     if schema.arguments.self_arg is not None:
 | |
|         raise ValueError(
 | |
|             f"custom_op does not support arguments named 'self'. Please "
 | |
|             f"rename your argument. Got: {schema}"
 | |
|         )
 | |
| 
 | |
| 
 | |
| def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
 | |
|     names = qualname.split("::", 1)
 | |
|     if len(names) != 2:
 | |
|         raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The "
 | |
|                          f"operator name should look something like ns::foo")
 | |
|     if '.' in names[1]:
 | |
|         raise ValueError(f"The torch.custom_ops APIs do not handle overloads, "
 | |
|                          f"i.e. operator names with '.' in them. "
 | |
|                          f"Please name your operator something like ns::foo. "
 | |
|                          f"Got: {qualname}")
 | |
|     return names[0], names[1]
 | |
| 
 | |
| 
 | |
| def validate_device_type(device_type: str) -> None:
 | |
|     if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
 | |
|         raise ValueError(
 | |
|             f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
 | |
|             f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
 | |
|         )
 | |
| 
 | |
| 
 | |
| def supported_param(param: inspect.Parameter) -> bool:
 | |
|     return param.kind in (
 | |
|         inspect.Parameter.POSITIONAL_OR_KEYWORD,
 | |
|         inspect.Parameter.KEYWORD_ONLY,
 | |
|     )
 | |
| 
 | |
| 
 | |
| def validate_function_matches_schema(
 | |
|     schema: FunctionSchema, func: typing.Callable
 | |
| ) -> None:
 | |
|     sig = inspect.signature(func)
 | |
| 
 | |
|     if not all(supported_param(p) for _, p in sig.parameters.items()):
 | |
|         raise ValueError(
 | |
|             f"custom_op(..., manual_schema)(func): positional-only args, "
 | |
|             f"varargs, and kwargs are not supported. Please rewrite `func` "
 | |
|             f"to not have them. Got `func` with signature: {sig}"
 | |
|         )
 | |
| 
 | |
|     if (
 | |
|         any(
 | |
|             p.annotation is not inspect.Parameter.empty
 | |
|             for _, p in sig.parameters.items()
 | |
|         )
 | |
|         or sig.return_annotation is not inspect.Signature.empty
 | |
|     ):
 | |
|         raise ValueError(
 | |
|             f"custom_op(..., manual_schema)(func): When passing in a manual "
 | |
|             f"schema, we expect `func` to have no type annotations to avoid "
 | |
|             f"ambiguity. Got `func` with signature: {sig}"
 | |
|         )
 | |
| 
 | |
|     positional = [
 | |
|         (name, param)
 | |
|         for name, param in sig.parameters.items()
 | |
|         if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
 | |
|     ]
 | |
|     kwargonly = [
 | |
|         (name, param)
 | |
|         for name, param in sig.parameters.items()
 | |
|         if param.kind == inspect.Parameter.KEYWORD_ONLY
 | |
|     ]
 | |
| 
 | |
|     def error():
 | |
|         raise ValueError(
 | |
|             f"custom_op(..., manual_schema)(func): When passing in a manual "
 | |
|             f"schema, we expect `func`'s signature to match `manual_schema` "
 | |
|             f"(aside from type annotations). "
 | |
|             f"func's signature: {sig}, manual_schema: {schema}"
 | |
|         )
 | |
| 
 | |
|     def error_default_args():
 | |
|         raise ValueError(
 | |
|             f"custom_op(..., manual_schema)(func): "
 | |
|             f"neither func nor manual_schema should have default "
 | |
|             f"arguments. Got "
 | |
|             f"func's signature: {sig}, manual_schema: {schema}"
 | |
|         )
 | |
| 
 | |
|     def compare(sig_args, schema_args):
 | |
|         if len(sig_args) != len(schema_args):
 | |
|             error()
 | |
|         for (name, param), arg in zip(sig_args, schema_args):
 | |
|             if name != arg.name:
 | |
|                 error()
 | |
|             if param.default is not inspect.Parameter.empty or arg.default is not None:
 | |
|                 error_default_args()
 | |
| 
 | |
|     compare(positional, schema.arguments.flat_positional)
 | |
|     compare(kwargonly, schema.arguments.flat_kwarg_only)
 | |
| 
 | |
| 
 | |
| def report_error_callback(custom_op: typing.Any, key: str) -> None:
 | |
|     if key == "Undefined":
 | |
|         raise NotImplementedError(
 | |
|             f"{custom_op}: There were no Tensor inputs to this operator "
 | |
|             f"(e.g. you passed an empty list of Tensors). If your operator is a "
 | |
|             f"factory function (that is, it takes no Tensors and constructs "
 | |
|             f"a new one), then please use CustomOp.impl_factory to register "
 | |
|             f"an implementation for it"
 | |
|         )
 | |
|     if key == "Meta":
 | |
|         raise NotImplementedError(
 | |
|             f"{custom_op}: when running with device='Meta' tensors: there is no "
 | |
|             f"abstract impl registered for this CustomOp. Please register one via "
 | |
|             f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
 | |
|         )
 | |
|     if key in ("CPU", "CUDA"):
 | |
|         device = key.lower()
 | |
|         raise NotImplementedError(
 | |
|             f"{custom_op}: when running with device='{device}' tensors: there is no "
 | |
|             f"{device} impl registered for this CustomOp. Please register one via "
 | |
|             f"CustomOp.impl(device_type='{device}')"
 | |
|         )
 | |
|     raise NotImplementedError(
 | |
|         f"{custom_op}: No implementation for dispatch key {key}. It is likely "
 | |
|         f"that we have not added this functionality yet, please either open an "
 | |
|         f"issue or if you're feeling adventurous, use the low-level "
 | |
|         f"torch.library API"
 | |
|     )
 | |
| 
 | |
| 
 | |
| def custom_op_from_existing(op):
 | |
|     ns = op.namespace
 | |
|     lib = torch.library.Library(ns, "FRAGMENT")
 | |
|     name = op.name().split("::")[-1]
 | |
|     schema_str = str(op._schema)
 | |
|     # CustomOp expects the schema string without the namespace
 | |
|     schema_str = schema_str.split("::")[-1]
 | |
|     schema = FunctionSchema.parse(schema_str)
 | |
|     return CustomOp(lib, ns, schema, name, op, _private_access=True)
 | |
| 
 | |
| 
 | |
| def get_op(qualname):
 | |
|     def error_not_found():
 | |
|         raise ValueError(
 | |
|             f"Could not find the operator {qualname}. Please make sure you have "
 | |
|             f"already registered the operator and (if registered from C++) "
 | |
|             f"loaded it via torch.ops.load_library.")
 | |
| 
 | |
|     ns, name = parse_qualname(qualname)
 | |
|     if not hasattr(torch.ops, ns):
 | |
|         error_not_found()
 | |
|     opnamespace = getattr(torch.ops, ns)
 | |
|     if not hasattr(opnamespace, name):
 | |
|         error_not_found()
 | |
|     packet = getattr(opnamespace, name)
 | |
|     if not hasattr(packet, 'default'):
 | |
|         error_not_found()
 | |
|     return packet.default
 | |
| 
 | |
| 
 | |
| def _find_custom_op(qualname, also_check_torch_library=False):
 | |
|     if qualname in global_registry:
 | |
|         return global_registry[qualname]
 | |
|     if not also_check_torch_library:
 | |
|         raise RuntimeError(
 | |
|             f'Could not find custom op "{qualname}". Did you register it via '
 | |
|             f"the torch._custom_ops API?")
 | |
|     overload = get_op(qualname)
 | |
|     result = custom_op_from_existing(overload)
 | |
|     return result
 | |
| 
 | |
| 
 | |
| def get_abstract_impl(qualname):
 | |
|     if qualname not in torch._custom_op.impl.global_registry:
 | |
|         return None
 | |
|     custom_op = torch._custom_op.impl.global_registry[qualname]
 | |
|     if custom_op is None:
 | |
|         return None
 | |
|     if not custom_op._has_impl("abstract"):
 | |
|         return None
 | |
|     return custom_op._get_impl("abstract").func
 | |
| 
 | |
| 
 | |
| def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
 | |
|     ns, name = qualname.split("::")
 | |
|     schema_str = f"{name}{schema}"
 | |
|     function_schema = FunctionSchema.parse(schema_str)
 | |
|     validate_schema(function_schema)
 | |
|     tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
 | |
|     lib = library.Library(ns, "FRAGMENT")
 | |
|     lib.define(schema_str, tags=tags)
 | |
|     ophandle = find_ophandle_or_throw(ns, function_schema.name)
 | |
|     result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
 | |
|     result._register_autograd_kernel_indirection()
 | |
| 
 | |
|     torch._C._dispatch_set_report_error_callback(
 | |
|         ophandle, functools.partial(report_error_callback, weakref.proxy(result))
 | |
|     )
 | |
|     return get_op(qualname)
 |