[custom_op] explicit autograd API (#101824)

This PR adds an explicit API for registering a backward formula for a
CustomOp. In the end state, we will likely have both this explicit API and a
magic API (which is sugar on top of the explicit API), since different
groups of users prefer different styles.

Concretely, to define a backward formula for a CustomOp:
- the user must provide us a "save for backward" function that accepts
(inputs, output) and returns exactly what they want saved for backward
- the user must provide us a "backward" function that accepts
(ctx, saved, *grads) and returns the grad_inputs as a dict mapping each
input name (a str) to its gradient.
Please see the changes in custom_op_db.py for examples of the API; a
minimal end-to-end sketch follows.
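For illustration, here is a minimal end-to-end sketch of the explicit API,
written in the same style as the custom_op_db.py examples and the new tests.
The operator name `mysin`, the `_my_ns` namespace, and the import path are
placeholders for this sketch, not part of the PR:

import torch
from torch._custom_op import custom_op  # import path assumed for illustration

# Hypothetical operator used only to demonstrate the API.
@custom_op('_my_ns::mysin')
def mysin(x: torch.Tensor) -> torch.Tensor:
    ...

@mysin.impl(['cpu'])
def mysin_impl(x):
    return x.sin()

# "save for backward": receives (inputs, output) and returns exactly
# what the backward function needs.
@mysin.impl_save_for_backward()
def mysin_save_for_backward(inputs, output):
    return inputs.x

# "backward": receives (ctx, saved, *grads) and returns a dict mapping
# each Tensor input's name to its gradient (or None).
@mysin.impl_backward()
def mysin_backward(ctx, saved, grad):
    return {'x': grad * saved.cos()}

x = torch.randn([], requires_grad=True)
y = mysin(x)
y.backward()  # uses the registered formula: d/dx sin(x) = cos(x)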

There are a number of pieces to this PR and I'm happy to split it if it
helps. They are:
- The actual APIs for specifying the two functions
(impl_save_for_backward, impl_backward)
- The autograd kernel: we take the functions the user gives us and
construct an autograd.Function object that we then register to
the Autograd dispatch key
- Indirection for the autograd kernel: we add a layer of indirection so
that the autograd kernel can be swapped out. This is necessary because, by
default, we register an "autograd not implemented" kernel as the
Autograd implementation and then swap it for the actual kernel once the
user provides one (a condensed sketch follows).
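A condensed sketch of that indirection, trimmed from the new autograd helper
file in this PR (the error raised for a partially-registered backward is
omitted here):

import torch
import torch.utils._pytree as pytree

def autograd_kernel_indirection(custom_op):
    def autograd_not_implemented(*args, **kwargs):
        # Only complain if autograd is actually needed for this call.
        if torch.is_grad_enabled() and pytree.tree_any(
            lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
        ):
            raise RuntimeError("Autograd has not been implemented for operator")
        guard = torch._C._AutoDispatchBelowAutograd()
        try:
            return custom_op(*args, **kwargs)
        finally:
            del guard

    def inner(*args, **kwargs):
        # Once the user registers both save_for_backward and backward, a
        # generated 'autograd' kernel exists; dispatch to it.
        if custom_op._has_impl('autograd'):
            return custom_op._get_impl('autograd').func(*args, **kwargs)
        # Otherwise, behave as if autograd were not implemented.
        return autograd_not_implemented(*args, **kwargs)

    # `inner` is what gets registered to the Autograd dispatch key, which lets
    # the behavior change later without re-registering a kernel.
    return inner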

Test Plan:
- We apply this API to give backward formulas to the operators in
custom_op_db, then hook custom_op_db up to the Autograd OpInfo tests.
- Various tests in test_python_dispatch.py check the error cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101824
Approved by: https://github.com/ezyang
Author: Richard Zou
Date: 2023-05-23 07:02:58 -07:00
Committed by: PyTorch MergeBot
Parent: 8487105fae
Commit: 723f111545
5 changed files with 730 additions and 29 deletions

View File

@@ -6,6 +6,7 @@ import torch
from torch.testing._internal.common_utils import TestGradients, run_tests
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, OpDTypes)
@@ -18,7 +19,7 @@ _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
class TestBwdGradients(TestGradients):
# Tests that gradients are computed correctly
@_gradcheck_ops(op_db + control_flow_opinfo_db)
@_gradcheck_ops(op_db + control_flow_opinfo_db + custom_op_db)
def test_fn_grad(self, device, dtype, op):
# This is verified by test_dtypes in test_ops.py
if dtype not in op.supported_backward_dtypes(torch.device(device).type):
@@ -33,7 +34,7 @@ class TestBwdGradients(TestGradients):
# self._skip_helper(op, device, dtype)
# self._grad_test_helper(device, dtype, op, op.get_method())
@_gradcheck_ops(op_db)
@_gradcheck_ops(op_db + custom_op_db)
def test_inplace_grad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
if not op.inplace_variant:
@@ -52,7 +53,7 @@ class TestBwdGradients(TestGradients):
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
# Test that gradients of gradients are computed correctly
@_gradcheck_ops(op_db + control_flow_opinfo_db)
@_gradcheck_ops(op_db + control_flow_opinfo_db + custom_op_db)
def test_fn_gradgrad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
if not op.supports_gradgrad:
@@ -61,7 +62,7 @@ class TestBwdGradients(TestGradients):
self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')
# Test that gradients of gradients are properly raising
@_gradcheck_ops(op_db)
@_gradcheck_ops(op_db + custom_op_db)
def test_fn_fail_gradgrad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
if op.supports_gradgrad:

View File

@@ -730,7 +730,7 @@ class TestCustomOp(TestCase):
def test_private_ctor(self):
with self.assertRaisesRegex(RuntimeError, 'CustomOp constructor is private'):
CustomOp(None, None, None, None)
CustomOp(None, None, None, None, None)
def test_lifetime(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
@@ -789,6 +789,21 @@ class TestCustomOp(TestCase):
foo(y, x)
foo._destroy()
def test_autograd_notimplemented_gradmode(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
...
@foo.impl(['cpu'])
def foo_impl(x, y):
return x * y
x = torch.randn(3, requires_grad=True)
y = torch.randn(3)
with torch.no_grad():
# Shouldn't raise, because we are in no_grad
foo(y, x)
def test_impl_cpu(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
@@ -824,6 +839,271 @@ class TestCustomOp(TestCase):
foo.impl(invalid_type)(foo_impl)
foo._destroy()
def test_backward_partially_registered(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
@foo.impl(['cpu', 'cuda'])
def foo_impl(x):
return x.sin()
@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return grad * saved.cos()
x = torch.randn([], requires_grad=True)
with self.assertRaisesRegex(RuntimeError, "unable to find a 'save_for_backward'"):
y = foo(x)
y.backward()
def test_save_for_backward_inputs_are_namedtuple(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
@foo.impl(['cpu', 'cuda'])
def foo_impl(x):
return x.sin()
hit = 0
@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
nonlocal hit
hit += 1
self.assertTrue(isinstance(inputs, tuple))
self.assertEqual(list(inputs._asdict().keys()), ['x'])
return inputs.x
@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'x': grad * saved.cos()}
x = torch.randn([], requires_grad=True)
y = foo(x)
self.assertEqual(hit, 1)
y.backward()
self.assertEqual(hit, 1)
def test_backward_returns_dict(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
@foo.impl(['cpu', 'cuda'])
def foo_impl(x):
return x.sin()
@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.x
@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return grad * saved.cos()
x = torch.randn([], requires_grad=True)
y = foo(x)
with self.assertRaisesRegex(RuntimeError, 'to be a dict'):
y.backward()
def test_backward_dict_invalid_keys(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
@foo.impl(['cpu', 'cuda'])
def foo_impl(x):
return x.sin()
@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.x
@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'x': grad * saved.cos(), 'y': None}
x = torch.randn([], requires_grad=True)
y = foo(x)
with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"):
y.backward()
def test_backward_dict_grad_for_nontensor(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
...
@foo.impl(['cpu', 'cuda'])
def foo_impl(x, dim):
return x.sin()
@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.x
@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'x': grad * saved.cos(), 'dim': None}
x = torch.randn([], requires_grad=True)
y = foo(x, 32)
with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"):
y.backward()
def test_backward_dict_requires_keys_for_input_tensors(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
...
@foo.impl(['cpu', 'cuda'])
def foo_impl(x, y):
return x.sin()
@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.x
@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'x': grad * saved.cos()}
x = torch.randn([], requires_grad=True)
y = foo(x, x)
with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
y.backward()
def test_backward_dict_requires_keys_for_input_optional_tensors(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
...
@foo.impl(['cpu', 'cuda'])
def foo_impl(x, y):
return x.sin()
@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.x
@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'x': grad * saved.cos()}
x = torch.randn([], requires_grad=True)
y = foo(x, None)
with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
y.backward()
def test_backward_grads_are_tensor_or_none(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
@foo.impl(['cpu', 'cuda'])
def foo_impl(x):
return x.sin()
@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.x
@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'x': (grad * saved.cos(),)}
x = torch.randn([], requires_grad=True)
y = foo(x)
with self.assertRaisesRegex(RuntimeError, 'either None or a Tensor'):
y.backward()
def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
...
@foo.impl(['cpu', 'cuda'])
def foo_impl(xs):
return xs[0].sin()
@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.xs[0]
@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'xs': [grad * saved.cos(), None]}
xs = [torch.randn([], requires_grad=True) for _ in range(3)]
y = foo(xs)
with self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"):
y.backward()
def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
...
@foo.impl(['cpu', 'cuda'])
def foo_impl(xs):
return xs[0].sin()
@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.xs[0]
@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'xs': [grad * saved.cos(), None, (None,)]}
xs = [torch.randn([], requires_grad=True) for _ in range(3)]
y = foo(xs)
with self.assertRaisesRegex(RuntimeError, "None or Tensor"):
y.backward()
def test_backward_tensorlist_input_requires_list_grads(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
...
@foo.impl(['cpu', 'cuda'])
def foo_impl(xs):
return xs[0].sin()
@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.xs[0]
@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'xs': None}
xs = [torch.randn([], requires_grad=True) for _ in range(3)]
y = foo(xs)
with self.assertRaisesRegex(RuntimeError, "list of gradients"):
y.backward()
def test_backward_output_differentiability_type(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
...
with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
@foo.impl_backward(output_differentiability=True)
def foo_backward(ctx, saved, grad):
return {'xs': None}
def test_backward_output_differentiability_numel(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
...
with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
@foo.impl_backward(output_differentiability=[True])
def foo_backward(ctx, saved, grad):
return {'xs': None}
@unittest.skipIf(not TEST_CUDA, "requires CUDA")
def test_impl_separate(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')

View File

@@ -0,0 +1,243 @@
import torch
import torch.utils._pytree as pytree
from collections import namedtuple
import functools
# NOTE [CustomOp autograd kernel indirection]
# We register `inner` as the autograd kernel for this custom_op.
# `inner` either calls the autograd formula registered by the user,
# or goes into an `autograd_not_implemented` kernel.
#
# The reason this indirection exists is so that we can swap out the
# autograd kernel (the PyTorch dispatcher doesn't let us replace a kernel
# once it has been registered). By default we want the
# `autograd_not_implemented` behavior, but the user may later come along
# and register an actual backward formula.
def autograd_kernel_indirection(custom_op):
# TODO(#101191): Use the actual C++ autograd not implemented fallback,
# or change the default autograd fallback to the autograd not implemented fallback.
def autograd_not_implemented(*args, **kwargs) -> None:
if torch.is_grad_enabled() and pytree.tree_any(
lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
):
raise RuntimeError("Autograd has not been implemented for operator")
guard = torch._C._AutoDispatchBelowAutograd()
try:
return custom_op(*args, **kwargs)
finally:
del guard
def inner(*args, **kwargs):
if custom_op._has_impl('autograd'):
kernel = custom_op._get_impl('autograd').func
return kernel(*args, **kwargs)
# As explained in NOTE ["backward", "save_for_backward", and "autograd"],
# after the user gives us "backward" and "save_for_backward", we generate
# the "autograd" impl. If the user only provided one, then we tell
# the user they've done something wrong.
if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'):
missing = (
'save_for_backward' if custom_op._has_impl('backward')
else 'backward'
)
found = 'save_for_backward' if missing == 'backward' else 'backward'
loc = custom_op._get_impl(found).location
raise RuntimeError(
f"We found a '{found}' registration for {custom_op} at "
f"{loc} but were unable to find a '{missing}' registration. "
f"To use the CustomOp API to register a backward formula, "
f"please provide us both a backward function and a "
f"'save for backward' function via `impl_backward` and "
f"`impl_save_for_backward` respectively.")
return autograd_not_implemented(*args, **kwargs)
return inner
def construct_autograd_kernel(
schema,
output_differentiability,
forward_op,
save_for_backward_fn,
backward_fn):
def apply(*args):
flat_args, spec = pytree.tree_flatten(args)
def forward(ctx, *flat_args):
ctx.set_materialize_grads(True)
args = pytree.tree_unflatten(list(flat_args), spec)
guard = torch._C._AutoDispatchBelowAutograd()
try:
output = forward_op(*args)
finally:
del guard
# We use the info about args to give better error messages in backward
args_info = namedtuple_args(
schema, pytree.tree_map(lambda arg: type(arg), args))
save_for_backward_fn_inputs = namedtuple_args(schema, args)
to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)
save_pytree_for_backward(ctx, (to_save, args_info))
# Output must be one or more Tensors, no TensorList (yet)
if output_differentiability is not None:
if isinstance(output, tuple):
assert len(output_differentiability) == len(output)
for differentiable, out in zip(output_differentiability, output):
if not differentiable:
ctx.mark_non_differentiable(out)
else:
assert len(output_differentiability) == 1
if not output_differentiability[0]:
ctx.mark_non_differentiable(output)
return output
def backward(ctx, *grads):
saved, args_info = unpack_saved(ctx)
# There is nothing on the ctx object for now; it exists so that we
# can add additional things to it in the future.
inner_ctx = object()
grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)
# Massage the grad_inputs_dict to a form acceptable by
# autograd.Function.
validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info)
return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)
generated_cls = gen_autograd_function(
forward_op._opname + '_customop', forward, backward)
return generated_cls.apply(*flat_args)
return apply
def gen_autograd_function(name, forward, backward):
generated_cls = type(
name,
(torch.autograd.Function,),
{
'forward': staticmethod(forward),
'backward': staticmethod(backward),
}
)
return generated_cls
@functools.lru_cache
def namedtuple_args_cls(schema):
attribs = [arg.name for arg in schema.arguments.flat_all]
name = str(schema.name) + "_args"
# mypy doesn't support dynamic namedtuple name
tuple_cls = namedtuple(name, attribs) # type: ignore[misc]
return tuple_cls
def namedtuple_args(schema, args):
assert isinstance(args, tuple)
tuple_cls = namedtuple_args_cls(schema)
return tuple_cls(*args)
def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
def error(what):
backward = forward_op._get_impl('backward')
raise RuntimeError(
f"In the backward function defined for {forward_op} at "
f"{backward.location} using the CustomOp API, {what}")
if not isinstance(grad_inputs_dict, dict):
error(f"expected the output of the backward function to be a dict but "
f"got {type(grad_inputs_dict)}")
expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
if arg.type.is_tensor_like()}
actual_keys = grad_inputs_dict.keys()
if expected_keys != actual_keys:
error(f"expected the returned grad_input dict to have keys "
f"{expected_keys} but got {actual_keys}. The backward "
f"function must return a gradient (can be None) for each arg "
f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
f"Args declared to be non-Tensor-like types should not appear "
f"in the grad_input dict")
for name, grad in grad_inputs_dict.items():
arg_info = getattr(args_info, name)
if isinstance(arg_info, list):
if not isinstance(grad, (tuple, list)):
error(f"for input '{name}' expected the grad_input dict to "
f"hold a list of gradients but got object of type "
f"{type(grad)}.")
if not len(grad) == len(arg_info):
error(f"for input '{name}' expected the grad_input dict to "
f"hold a list of {len(arg_info)} gradients but got "
f"{len(grad)}")
for idx, (g, info) in enumerate(zip(grad, arg_info)):
if g is None:
continue
if not isinstance(g, torch.Tensor):
error(f"for input '{name}' expected the grad_input dict to "
f"hold a list of None or Tensor gradients but got "
f"object of {type(g)} at index {idx}")
if info != torch.Tensor:
error(f"for input '{name}', got a Tensor as the gradient "
f"for the {idx}-th value but expected None because "
f"the {idx}-th value was not a Tensor (it was "
f"type {arg_info}")
continue
if grad is None:
continue
if not isinstance(grad, torch.Tensor):
error(f"got object of type {type(grad)} as the gradient for input "
f"'{name}', "
f"but expected the gradient to be either None or a Tensor")
if arg_info != torch.Tensor:
error(f"got a Tensor as the gradient for input '{name}' but "
f"expected None as the gradient because input '{name}' "
f"was not a Tensor (it was type {arg_info}).")
def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
result = []
for name, arg_info in args_info._asdict().items():
if name not in grad_inputs_dict:
result.append(pytree.tree_map(lambda x: None, arg_info))
continue
result.append(grad_inputs_dict[name])
return tuple(pytree.tree_flatten(result)[0])
# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
# autograd.Function prefers that users save Tensors via ctx.save_for_backward
# (to avoid reference cycles) and attach non-Tensors directly to the
# ctx object.
def save_pytree_for_backward(ctx, stuff):
flat_stuff, spec = pytree.tree_flatten(stuff)
num_elts = len(flat_stuff)
tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
if isinstance(thing, torch.Tensor)]
non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
if not isinstance(thing, torch.Tensor)]
tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)]
non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)]
ctx.spec = spec
ctx.num_elts = num_elts
ctx.save_for_backward(*tensors)
ctx.tensor_idxs = tensor_idxs
ctx.saved_non_tensors = non_tensors
ctx.non_tensor_idxs = non_tensor_idxs
# Inverse operation to save_pytree_for_backward
def unpack_saved(ctx):
flat_stuff = [None] * ctx.num_elts
for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs):
flat_stuff[idx] = tensor
for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
flat_stuff[idx] = non_tensor
stuff = pytree.tree_unflatten(flat_stuff, ctx.spec)
return stuff

View File

@@ -10,7 +10,8 @@ from torchgen.model import FunctionSchema, OperatorName, SchemaKind
import torch
import torch._C as _C
import torch.library as library
import torch.utils._pytree as pytree
from .autograd import autograd_kernel_indirection, construct_autograd_kernel
"""
There are various APIs for defining custom-operator-like things in PyTorch:
@@ -143,18 +144,14 @@ def custom_op(
lib = library.Library(ns, "FRAGMENT")
lib.define(schema_str)
ophandle = find_ophandle_or_throw(ns, function_schema.name)
result = CustomOp(lib, ns, function_schema.name, ophandle, _private_access=True)
result = CustomOp(lib, ns, function_schema, function_schema.name, ophandle, _private_access=True)
result.__name__ = func.__name__
result.__module__ = func.__module__
result.__doc__ = func.__doc__
# NYI: autograd not supported
# In the near future we will either directly use the
# autograd_not_implemented kernels or make those the default fallback
# for the Autograd and ADInplaceOrView keys. Both of those are a bit tricky.
library.impl(lib, result._opname, "Autograd")(
get_autograd_not_implemented_kernel(weakref.proxy(result))
autograd_kernel_indirection(weakref.proxy(result))
)
torch._C._dispatch_set_report_error_callback(
@@ -185,7 +182,7 @@ class CustomOp:
To construct a `CustomOp`, use `custom_op`.
"""
def __init__(self, lib, cpp_ns, operator_name, ophandle, *, _private_access=False):
def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
super(CustomOp, self).__init__()
if not _private_access:
raise RuntimeError(
@@ -193,6 +190,7 @@ class CustomOp:
"BC for it. Please use custom_op(...) to create a CustomOp object"
)
name = f"{cpp_ns}::{str(operator_name.name)}"
self._schema = schema
self._cpp_ns = cpp_ns
self._lib: library.Library = lib
self._ophandle: _C._DispatchOperatorHandle = ophandle
@@ -417,6 +415,90 @@ class CustomOp:
return inner
# NOTE ["backward", "save_for_backward", and "autograd"]
# As a part of the explicit autograd API, a user must provide us
# a "save_for_backward" function and a "backward" function.
# When both of these have been provided, then we automatically
# construct the "autograd" kernel.
def _register_autograd_kernel(self):
assert self._has_impl("backward")
assert self._has_impl("save_for_backward")
kernel = construct_autograd_kernel(
self._schema,
self._output_differentiability,
self,
self._get_impl("save_for_backward").func,
self._get_impl("backward").func)
self._register_impl("autograd", kernel)
def impl_save_for_backward(self):
r"""Register a function that tells us what to save for backward.
Please see impl_backward for more details.
"""
def inner(f):
self._register_impl("save_for_backward", f)
if self._has_impl("backward"):
self._register_autograd_kernel()
return inner
def impl_backward(self, output_differentiability=None):
r"""Registers a backward formula.
In order for the CustomOp to work with autograd, you need to register
a backward formula. There are two pieces to this:
1. You must give us a function to specify what to save for backward.
Call this the "save for backward" function.
2. You must give us a function that computes gradients. Call this the
"backward" function.
Use `impl_save_for_backward` to define a "save for backward" function
that specifies what gets saved for backward. The function should accept
two arguments ``(inputs, output)`` and return the quantities to be saved
for backward.
During runtime, when you call the CustomOp, PyTorch will invoke the
"save for backward" function with the inputs and output of the CustomOp.
Use `impl_backward` to define the "backward" function. The backward
function must accept ``(ctx, saved, *grads)``:
- ``ctx`` is a context object where we may provide information
- ``saved`` is exactly what gets returned from the "save for backward"
function
- ``grads`` is one or more gradients. The number of gradients matches
the number of outputs of the CustomOp.
The backward function must return a dict that maps the name of
an input to the CustomOp to its corresponding gradient. All inputs that
were declared to be Tensors in the CustomOp definition must be accounted
for in the dict. The gradient may be a Tensor or None.
TODO(rzou): Add example when this PR is closer to landing.
"""
if output_differentiability is not None:
def yell():
raise RuntimeError(
f"impl_backward(output_differentiability): expected "
f"output_differentiability to be a list of bools with "
f"length equal to the number of outputs of this CustomOp "
f"got: {output_differentiability}")
if not isinstance(output_differentiability, list):
yell()
for diff in output_differentiability:
if not isinstance(diff, bool):
yell()
if len(self._schema.returns) != len(output_differentiability):
yell()
def inner(f):
self._register_impl("backward", f)
self._output_differentiability = output_differentiability
if self._has_impl("save_for_backward"):
self._register_autograd_kernel()
return inner
@dataclasses.dataclass
class FuncAndLocation:
@@ -490,21 +572,6 @@ def validate_device_type(device_type: str) -> None:
)
def get_autograd_not_implemented_kernel(custom_op) -> typing.Callable:
def autograd_not_implemented(*args, **kwargs) -> None:
if pytree.tree_any(
lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
):
raise RuntimeError("Autograd has not been implemented for operator")
guard = _C._AutoDispatchBelowAutograd()
try:
return custom_op(*args, **kwargs)
finally:
del guard
return autograd_not_implemented
def supported_param(param: inspect.Parameter) -> bool:
return param.kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,

View File

@@ -42,6 +42,16 @@ def numpy_cube_impl(x):
def numpy_cube_abstract(x):
return x.clone(), x.clone()
@numpy_cube.impl_save_for_backward()
def numpy_cube_save_for_backward(inputs, output):
return (inputs.x, output[1])
@numpy_cube.impl_backward()
def numpy_cube_backward(ctx, saved, grad_out, grad_dx):
x, dx = saved
grad_x = numpy_mul(grad_out, dx) + 6 * numpy_mul(grad_dx, x)
return {'x': grad_x}
@custom_op('_torch_testing::numpy_mul')
def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
...
@@ -56,6 +66,22 @@ def numpy_mul_abstract(x, y):
assert x.device == y.device
return (x * y).contiguous()
@numpy_mul.impl_save_for_backward()
def numpy_mul_save_for_backward(inputs, output):
saved = {}
saved['x_requires_grad'] = inputs.x.requires_grad
saved['y_requires_grad'] = inputs.y.requires_grad
# Optimization: only save what is necessary
saved['y'] = inputs.y if inputs.x.requires_grad else None
saved['x'] = inputs.x if inputs.y.requires_grad else None
return saved
@numpy_mul.impl_backward()
def numpy_mul_backward(ctx, saved, grad_out):
grad_x = grad_out * saved['y'] if saved['x_requires_grad'] else None
grad_y = grad_out * saved['x'] if saved['y_requires_grad'] else None
return {'y': grad_y, 'x': grad_x}
@custom_op('_torch_testing::numpy_sort')
def numpy_sort(x: Tensor, dim: int) -> Tuple[Tensor, Tensor, Tensor]:
...
@@ -69,7 +95,7 @@ def numpy_sort_impl(x, dim):
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(x, device=device),
torch.tensor(result, device=device),
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@@ -78,6 +104,16 @@ def numpy_sort_impl(x, dim):
def numpy_sort_abstract(x, dim):
return torch.empty_like(x), torch.empty_like(x, dtype=torch.long), torch.empty_like(x, dtype=torch.long)
@numpy_sort.impl_save_for_backward()
def numpy_sort_save_for_backward(inputs, output):
out, ind, ind_inv = output
return [inputs.dim, ind, ind_inv]
@numpy_sort.impl_backward(output_differentiability=[True, False, False])
def numpy_sort_backward(ctx, saved, grad_out, grad_ind, grad_ind_inv):
dim, ind, ind_inv = saved
return {'x': numpy_take(grad_out, ind_inv, ind, dim)}
@custom_op('_torch_testing::numpy_take')
def numpy_take(x: Tensor, ind: Tensor, ind_inv: Tensor, dim: int) -> Tensor:
...
@@ -94,8 +130,26 @@ def numpy_take_impl(x, ind, ind_inv, dim):
def numpy_take_abstract(x, ind, ind_inv, dim):
assert x.device == ind.device
assert x.device == ind_inv.device
assert ind.dtype == torch.long
assert ind_inv.dtype == torch.long
return torch.empty_like(x)
@numpy_take.impl_save_for_backward()
def numpy_take_save_for_backward(inputs, output):
return {
'dim': inputs.dim,
'ind': inputs.ind,
'ind_inv': inputs.ind_inv,
}
@numpy_take.impl_backward()
def numpy_take_backward(ctx, saved, grad_out):
return {
'x': numpy_take(grad_out, saved['ind_inv'], saved['ind'], saved['dim']),
'ind': None,
'ind_inv': None,
}
@custom_op('_torch_testing::numpy_nonzero')
def numpy_nonzero(x: Tensor) -> Tensor:
...
@@ -138,11 +192,56 @@ def numpy_view_copy_impl(x, shape) -> Tensor:
def numpy_view_copy_abstract(x, shape) -> Tensor:
return x.clone().view(shape).clone()
@numpy_view_copy.impl_save_for_backward()
def numpy_view_copy_save_for_backward(inputs, output) -> Tensor:
return inputs.x.shape
@numpy_view_copy.impl_backward()
def numpy_view_copy_backward(ctx, x_shape, grad_out) -> Tensor:
return {'x': numpy_view_copy(grad_out, x_shape)}
def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
result = make_arg(2, 3, 4, low=0.9, high=2)
yield SampleInput(result, args=([2, 12],))
@custom_op('_torch_testing::numpy_cat')
def numpy_cat(xs: Sequence[Tensor], dim: int) -> Tensor:
...
@numpy_cat.impl(['cpu', 'cuda'])
def numpy_cat_impl(xs, dim):
assert len(xs) > 0
assert all(x.device == xs[0].device for x in xs)
assert all(x.dtype == xs[0].dtype for x in xs)
np_xs = [to_numpy(x) for x in xs]
np_out = np.concatenate(np_xs, axis=dim)
return torch.tensor(np_out, device=xs[0].device)
@numpy_cat.impl_abstract()
def numpy_cat_abstract(xs, dim):
assert len(xs) > 0
assert all(x.device == xs[0].device for x in xs)
assert all(x.dtype == xs[0].dtype for x in xs)
return torch.cat(xs, dim=dim)
@numpy_cat.impl_save_for_backward()
def numpy_cat_save_for_backward(inputs, output):
dim_sizes = [x.shape[inputs.dim] for x in inputs.xs]
return dim_sizes, inputs.dim
@numpy_cat.impl_backward()
def numpy_cat_backward(ctx, saved, grad_out):
dim_sizes, dim = saved
return {'xs': torch.split(grad_out, dim_sizes, dim)}
def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
r0 = make_arg(2, 3, 4, low=0.9, high=2)
r1 = make_arg(4, 3, 4, low=0.9, high=2)
r2 = make_arg(5, 3, 4, low=0.9, high=2)
yield SampleInput([r0, r1, r2], args=(0,))
@custom_op('_torch_testing::numpy_nms')
def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor:
...
@@ -257,6 +356,7 @@ custom_op_db = [
op=wrap_for_opinfo(numpy_nonzero),
sample_inputs_func=sample_inputs_numpy_nonzero,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=False,
supports_out=False,
),
OpInfo(
@@ -264,6 +364,7 @@ custom_op_db = [
op=wrap_for_opinfo(numpy_nms),
sample_inputs_func=sample_inputs_numpy_nms,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=False,
supports_out=False,
),
OpInfo(
@@ -271,6 +372,15 @@ custom_op_db = [
op=wrap_for_opinfo(numpy_view_copy),
sample_inputs_func=sample_inputs_numpy_view_copy,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=True,
supports_out=False,
),
OpInfo(
'NumpyCatCustomOp',
op=wrap_for_opinfo(numpy_cat),
sample_inputs_func=sample_inputs_numpy_cat,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=True,
supports_out=False,
),
]