mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Closes https://github.com/pytorch/pytorch/issues/24015 (see description of that issue for more details). For a toy example, see the `DiagonalTensor` and `SubDiagonalTensor` class in test/test_overrides.py. This PR currently contains: * tests for `__torch_function__` behavior * modification to `gen_python_functions` and `parse` function signatures and dispatched to correct overloaded argument. This feature is inspired by and analogous to NumPy's `__array_function__` protocol ([see NumPy Enhancement Proposal 18](https://numpy.org/neps/nep-0018-array-function-protocol.html#trying-array-function-methods-until-the-right-one-works)). ### Benchmarks: See Nathan's comment below: https://github.com/pytorch/pytorch/pull/27064#issuecomment-554601189 Pull Request resolved: https://github.com/pytorch/pytorch/pull/27064 Differential Revision: D18645954 Pulled By: ezyang fbshipit-source-id: 54b5e4344d7afdbcf996bb57191b0bdadc7b1767
281 lines
11 KiB
Python
281 lines
11 KiB
Python
"""
|
|
Python implementation of __torch_function__
|
|
|
|
While most of the torch API and handling for __torch_function__ happens
|
|
at the C++ level, some of the torch API is written in Python so we need
|
|
python-level handling for __torch_function__ overrides as well. The main
|
|
developer-facing functionality in this file is the
|
|
torch_function_dispatch decorator. This function can be applied to
|
|
python functions in the torch.functional module to enable
|
|
__torch_function__ overrides for that function. See the examples in the
|
|
docstrings for torch_function_dispatch for details.
|
|
|
|
NOTE: heavily inspired by NumPy's ``__array_function__`` (see:
|
|
https://github.com/pytorch/pytorch/issues/24015 and
|
|
https://www.numpy.org/neps/nep-0018-array-function-protocol.html
|
|
)
|
|
|
|
"""
|
|
|
|
import functools
|
|
import textwrap
|
|
from . import _six
|
|
if _six.PY3:
|
|
from inspect import getfullargspec
|
|
import collections
|
|
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
|
|
|
|
def getargspec(func):
|
|
spec = getfullargspec(func)
|
|
return ArgSpec(spec.args, spec.varargs, spec.varkw, spec.defaults)
|
|
else:
|
|
from inspect import getargspec
|
|
|
|
from .tensor import Tensor
|
|
|
|
|
|
_TENSOR_ONLY = [Tensor]
|
|
|
|
def _get_overloaded_types_and_args(relevant_args):
|
|
"""Returns a list of arguments on which to call __torch_function__.
|
|
|
|
Checks arguments in relevant_args for __torch_function__ implementations,
|
|
storing references to the arguments and their types in overloaded_args and
|
|
overloaded_types in order of calling precedence. Only distinct types are
|
|
considered. If a type is a subclass of another type it will have higher
|
|
precedence, otherwise the precedence order is the same as the order of
|
|
arguments in relevant_args, that is, from left-to-right in the argument list.
|
|
|
|
The precedence-determining algorithm implemented in this function is
|
|
described in `NEP-0018`_.
|
|
|
|
See torch::append_overloaded_arg for the equivalent function in the C++
|
|
implementation.
|
|
|
|
Parameters
|
|
----------
|
|
relevant_args : iterable of array-like
|
|
Iterable of array-like arguments to check for __torch_function__
|
|
methods.
|
|
|
|
Returns
|
|
-------
|
|
overloaded_types : collection of types
|
|
Types of arguments from relevant_args with __torch_function__ methods.
|
|
overloaded_args : list
|
|
Arguments from relevant_args on which to call __torch_function__
|
|
methods, in the order in which they should be called.
|
|
|
|
.. _NEP-0018:
|
|
https://numpy.org/neps/nep-0018-array-function-protocol.html
|
|
|
|
"""
|
|
# Runtime is O(num_arguments * num_unique_types)
|
|
overloaded_types = []
|
|
overloaded_args = []
|
|
for arg in relevant_args:
|
|
arg_type = type(arg)
|
|
# We only collect arguments if they have a unique type, which ensures
|
|
# reasonable performance even with a long list of possibly overloaded
|
|
# arguments.
|
|
if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__')):
|
|
# Create lists explicitly for the first type (usually the only one
|
|
# done) to avoid setting up the iterator for overloaded_args.
|
|
if overloaded_types:
|
|
overloaded_types.append(arg_type)
|
|
# By default, insert argument at the end, but if it is
|
|
# subclass of another argument, insert it before that argument.
|
|
# This ensures "subclasses before superclasses".
|
|
index = len(overloaded_args)
|
|
for i, old_arg in enumerate(overloaded_args):
|
|
if issubclass(arg_type, type(old_arg)):
|
|
index = i
|
|
break
|
|
overloaded_args.insert(index, arg)
|
|
else:
|
|
overloaded_types = [arg_type]
|
|
overloaded_args = [arg]
|
|
|
|
return overloaded_types, overloaded_args
|
|
|
|
|
|
def _implement_torch_function(
|
|
implementation, public_api, relevant_args, args, kwargs):
|
|
"""Implement a function with checks for __torch_function__ overrides.
|
|
|
|
See torch::autograd::handle_torch_function for the equivalent of this
|
|
function in the C++ implementation.
|
|
|
|
Arguments
|
|
---------
|
|
implementation : function
|
|
Function that implements the operation on ``torch.Tensor`` without
|
|
overrides when called like ``implementation(*args, **kwargs)``.
|
|
public_api : function
|
|
Function exposed by the public torch API originally called like
|
|
``public_api(*args, **kwargs)`` on which arguments are now being
|
|
checked.
|
|
relevant_args : iterable
|
|
Iterable of arguments to check for __torch_function__ methods.
|
|
args : tuple
|
|
Arbitrary positional arguments originally passed into ``public_api``.
|
|
kwargs : tuple
|
|
Arbitrary keyword arguments originally passed into ``public_api``.
|
|
|
|
Returns
|
|
-------
|
|
Result from calling `implementation()` or an `__torch_function__`
|
|
method, as appropriate.
|
|
|
|
Raises
|
|
------
|
|
TypeError : if no implementation is found.
|
|
|
|
"""
|
|
# Check for __torch_function__ methods.
|
|
types, overloaded_args = _get_overloaded_types_and_args(relevant_args)
|
|
# Short-cut for common cases: no overload or only Tensor overload
|
|
# (directly or with subclasses that do not override __torch_function__).
|
|
if not overloaded_args or types == _TENSOR_ONLY:
|
|
return implementation(*args, **kwargs)
|
|
|
|
# Call overrides
|
|
for overloaded_arg in overloaded_args:
|
|
# Use `public_api` instead of `implementation` so __torch_function__
|
|
# implementations can do equality/identity comparisons.
|
|
result = overloaded_arg.__torch_function__(public_api, args, kwargs)
|
|
|
|
if result is not NotImplemented:
|
|
return result
|
|
|
|
func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
|
|
raise TypeError("no implementation found for '{}' on types that implement "
|
|
'__torch_function__: {}'
|
|
.format(func_name, list(map(type, overloaded_args))))
|
|
|
|
|
|
def _verify_matching_signatures(implementation, dispatcher):
|
|
"""Verify that a dispatcher function has the right signature."""
|
|
implementation_spec = getargspec(implementation)
|
|
dispatcher_spec = getargspec(dispatcher)
|
|
|
|
if (implementation_spec.args != dispatcher_spec.args or
|
|
implementation_spec.varargs != dispatcher_spec.varargs or
|
|
implementation_spec.keywords != dispatcher_spec.keywords or
|
|
(bool(implementation_spec.defaults) !=
|
|
bool(dispatcher_spec.defaults)) or
|
|
(implementation_spec.defaults is not None and
|
|
len(implementation_spec.defaults) !=
|
|
len(dispatcher_spec.defaults))):
|
|
raise RuntimeError('implementation and dispatcher for %s have '
|
|
'different function signatures' % implementation)
|
|
|
|
|
|
_wrapped_func_source = textwrap.dedent("""
|
|
@functools.wraps(implementation)
|
|
def {name}(*args, **kwargs):
|
|
relevant_args = dispatcher(*args, **kwargs)
|
|
return implement_torch_function(
|
|
implementation, {name}, relevant_args, args, kwargs)
|
|
""")
|
|
|
|
def torch_function_dispatch(dispatcher, module=None, verify=True):
|
|
"""Decorator for adding dispatch with the __torch_function__ protocol.
|
|
|
|
If you define a function in Python and would like to permit user-defined
|
|
tensor-like types to override it using __torch_function__, please apply this
|
|
decorator on this function together with a custom dispatcher that indicates
|
|
which arguments should be checked for the presence of __torch_function__.
|
|
|
|
Suppose we'd like to apply this function to torch.frob, which has the
|
|
following definition:
|
|
|
|
def frob(input, bias, option=None):
|
|
return input + bias
|
|
|
|
We'd need to define a dispatcher for frob that has the same signature and
|
|
returns the elements of the signature that should be checked for
|
|
`__torch_function__`. If any of the arguments has a `__torch_function__`
|
|
attribute, that function will be called to handle custom dispatch. Assuming
|
|
that `bias` can be a tensor-like, our dispatcher would look like:
|
|
|
|
def _frob_dispatcher(input, bias, option=None):
|
|
return (input, bias)
|
|
|
|
The dispatcher must return an iterable, so return a single-element tuple if
|
|
only one argument should be checked. We would then modify the original
|
|
definition for torch.frob to look like:
|
|
|
|
@torch_function_dispatch(_frob_dispatcher)
|
|
def frob(input, bias, option=None):
|
|
return input + bias
|
|
|
|
See ``torch/functional.py`` for more usage examples.
|
|
|
|
Parameters
|
|
----------
|
|
dispatcher : callable
|
|
Function that when called like ``dispatcher(*args, **kwargs)`` with
|
|
arguments from the NumPy function call returns an iterable of
|
|
array-like arguments to check for ``__torch_function__``.
|
|
module : str, optional
|
|
``__module__`` attribute to set on new function, e.g.,
|
|
``module='torch'``. By default, module is copied from the decorated
|
|
function.
|
|
verify : bool, optional
|
|
If True, verify the that the signature of the dispatcher and decorated
|
|
function signatures match exactly: all required and optional arguments
|
|
should appear in order with the same names, but the default values for
|
|
all optional arguments should be ``None``. Only disable verification
|
|
if the dispatcher's signature needs to deviate for some particular
|
|
reason, e.g., because the function has a signature like
|
|
``func(*args, **kwargs)``.
|
|
|
|
Returns
|
|
-------
|
|
dispatcher : callable
|
|
Function suitable for decorating the implementation of a NumPy
|
|
function.
|
|
|
|
Notes
|
|
-----
|
|
The dispatcher should normally return a tuple containing all input
|
|
arguments that may have a ``__torch_function__`` attribute.
|
|
|
|
In some cases where that's not easily possible, e.g. ``torch.cat``, it is
|
|
also valid (if a little slower) to make the dispatcher function a generator
|
|
(i.e. use ``yield`` to return arguments one by one).
|
|
|
|
"""
|
|
def decorator(implementation):
|
|
if verify:
|
|
_verify_matching_signatures(implementation, dispatcher)
|
|
|
|
# Equivalently, we could define this function directly instead of using
|
|
# exec. This version has the advantage of giving the helper function a
|
|
# more interpretable name. Otherwise, the original function does not
|
|
# show up at all in many cases, e.g., if it's written in C++ or if the
|
|
# dispatcher gets an invalid keyword argument.
|
|
source = _wrapped_func_source.format(name=implementation.__name__)
|
|
|
|
source_object = compile(
|
|
source, filename='<__torch_function__ internals>', mode='exec')
|
|
scope = {
|
|
'implementation': implementation,
|
|
'dispatcher': dispatcher,
|
|
'functools': functools,
|
|
'implement_torch_function': _implement_torch_function,
|
|
}
|
|
_six.exec_(source_object, scope)
|
|
|
|
public_api = scope[implementation.__name__]
|
|
|
|
if module is not None:
|
|
public_api.__module__ = module
|
|
|
|
public_api._implementation = implementation
|
|
|
|
return public_api
|
|
|
|
return decorator
|