mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[custom_op] explicit autograd API (#101824)
This PR adds an explicit API for registering a backward formula for a CustomOp. In the end state, we will likely have this explicit API and a magic API (which is sugar on top of an explicit API), since different parties of users prefer different ones. Concretely, to define a backward formula for a CustomOp: - a user must provide us a "save for backward" function that accepts (inputs, output) and returns exactly what they want saved for backward - a user must provide us a "backward" function that accepts (ctx, saved, *grads) and returns us the grad_inputs. The grad_inputs are returned as a dict mapping str to a gradient. Please see the changes in custom_op_db.py for examples of the API. There are a number of pieces to this PR and I'm happy to split it if it helps. They are: - The actual APIs for specifying the two functions (impl_save_for_backward, impl_backward) - The autograd kernel: we take the functions the user give us and construct an autograd.Function object that we then register to the Autograd dispatch key - Indirection for the autograd kernel. We add a layer of indirection so that one can swap out the autograd kernel. This is necessary because by default, we register an "autograd not implemented" kernel as the Autograd implementation but then swap it for the actual kernel when the user provides it. Test Plan: - We apply this API to give backward formulas for things in custom_op_db. We then hook up custom_op_db to the Autograd OpInfo tests. - Various tests in test_python_dispatch.py to check error cases. Pull Request resolved: https://github.com/pytorch/pytorch/pull/101824 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
8487105fae
commit
723f111545
@ -6,6 +6,7 @@ import torch
|
||||
from torch.testing._internal.common_utils import TestGradients, run_tests
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
|
||||
from torch.testing._internal.custom_op_db import custom_op_db
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, ops, OpDTypes)
|
||||
|
||||
@ -18,7 +19,7 @@ _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||
|
||||
class TestBwdGradients(TestGradients):
|
||||
# Tests that gradients are computed correctly
|
||||
@_gradcheck_ops(op_db + control_flow_opinfo_db)
|
||||
@_gradcheck_ops(op_db + control_flow_opinfo_db + custom_op_db)
|
||||
def test_fn_grad(self, device, dtype, op):
|
||||
# This is verified by test_dtypes in test_ops.py
|
||||
if dtype not in op.supported_backward_dtypes(torch.device(device).type):
|
||||
@ -33,7 +34,7 @@ class TestBwdGradients(TestGradients):
|
||||
# self._skip_helper(op, device, dtype)
|
||||
# self._grad_test_helper(device, dtype, op, op.get_method())
|
||||
|
||||
@_gradcheck_ops(op_db)
|
||||
@_gradcheck_ops(op_db + custom_op_db)
|
||||
def test_inplace_grad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if not op.inplace_variant:
|
||||
@ -52,7 +53,7 @@ class TestBwdGradients(TestGradients):
|
||||
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
|
||||
|
||||
# Test that gradients of gradients are computed correctly
|
||||
@_gradcheck_ops(op_db + control_flow_opinfo_db)
|
||||
@_gradcheck_ops(op_db + control_flow_opinfo_db + custom_op_db)
|
||||
def test_fn_gradgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if not op.supports_gradgrad:
|
||||
@ -61,7 +62,7 @@ class TestBwdGradients(TestGradients):
|
||||
self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')
|
||||
|
||||
# Test that gradients of gradients are properly raising
|
||||
@_gradcheck_ops(op_db)
|
||||
@_gradcheck_ops(op_db + custom_op_db)
|
||||
def test_fn_fail_gradgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if op.supports_gradgrad:
|
||||
|
@ -730,7 +730,7 @@ class TestCustomOp(TestCase):
|
||||
|
||||
def test_private_ctor(self):
|
||||
with self.assertRaisesRegex(RuntimeError, 'CustomOp constructor is private'):
|
||||
CustomOp(None, None, None, None)
|
||||
CustomOp(None, None, None, None, None)
|
||||
|
||||
def test_lifetime(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
@ -789,6 +789,21 @@ class TestCustomOp(TestCase):
|
||||
foo(y, x)
|
||||
foo._destroy()
|
||||
|
||||
def test_autograd_notimplemented_gradmode(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
...
|
||||
|
||||
@foo.impl(['cpu'])
|
||||
def foo_impl(x, y):
|
||||
return x * y
|
||||
|
||||
x = torch.randn(3, requires_grad=True)
|
||||
y = torch.randn(3)
|
||||
with torch.no_grad():
|
||||
# Shouldn't raise, because we are in no_grad
|
||||
foo(y, x)
|
||||
|
||||
def test_impl_cpu(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(x: torch.Tensor) -> torch.Tensor:
|
||||
@ -824,6 +839,271 @@ class TestCustomOp(TestCase):
|
||||
foo.impl(invalid_type)(foo_impl)
|
||||
foo._destroy()
|
||||
|
||||
def test_backward_partially_registered(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(x: torch.Tensor) -> torch.Tensor:
|
||||
...
|
||||
|
||||
@foo.impl(['cpu', 'cuda'])
|
||||
def foo_impl(x):
|
||||
return x.sin()
|
||||
|
||||
@foo.impl_backward()
|
||||
def foo_backward(ctx, saved, grad):
|
||||
return grad * saved.cos()
|
||||
|
||||
x = torch.randn([], requires_grad=True)
|
||||
with self.assertRaisesRegex(RuntimeError, "unable to find a 'save_for_backward'"):
|
||||
y = foo(x)
|
||||
y.backward()
|
||||
|
||||
def test_save_for_backward_inputs_are_namedtuple(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(x: torch.Tensor) -> torch.Tensor:
|
||||
...
|
||||
|
||||
@foo.impl(['cpu', 'cuda'])
|
||||
def foo_impl(x):
|
||||
return x.sin()
|
||||
|
||||
hit = 0
|
||||
|
||||
@foo.impl_save_for_backward()
|
||||
def foo_save_for_backward(inputs, output):
|
||||
nonlocal hit
|
||||
hit += 1
|
||||
self.assertTrue(isinstance(inputs, tuple))
|
||||
self.assertEqual(list(inputs._asdict().keys()), ['x'])
|
||||
return inputs.x
|
||||
|
||||
@foo.impl_backward()
|
||||
def foo_backward(ctx, saved, grad):
|
||||
return {'x': grad * saved.cos()}
|
||||
|
||||
x = torch.randn([], requires_grad=True)
|
||||
y = foo(x)
|
||||
self.assertEqual(hit, 1)
|
||||
y.backward()
|
||||
self.assertEqual(hit, 1)
|
||||
|
||||
def test_backward_returns_dict(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(x: torch.Tensor) -> torch.Tensor:
|
||||
...
|
||||
|
||||
@foo.impl(['cpu', 'cuda'])
|
||||
def foo_impl(x):
|
||||
return x.sin()
|
||||
|
||||
@foo.impl_save_for_backward()
|
||||
def foo_save_for_backward(inputs, output):
|
||||
return inputs.x
|
||||
|
||||
@foo.impl_backward()
|
||||
def foo_backward(ctx, saved, grad):
|
||||
return grad * saved.cos()
|
||||
|
||||
x = torch.randn([], requires_grad=True)
|
||||
y = foo(x)
|
||||
with self.assertRaisesRegex(RuntimeError, 'to be a dict'):
|
||||
y.backward()
|
||||
|
||||
def test_backward_dict_invalid_keys(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(x: torch.Tensor) -> torch.Tensor:
|
||||
...
|
||||
|
||||
@foo.impl(['cpu', 'cuda'])
|
||||
def foo_impl(x):
|
||||
return x.sin()
|
||||
|
||||
@foo.impl_save_for_backward()
|
||||
def foo_save_for_backward(inputs, output):
|
||||
return inputs.x
|
||||
|
||||
@foo.impl_backward()
|
||||
def foo_backward(ctx, saved, grad):
|
||||
return {'x': grad * saved.cos(), 'y': None}
|
||||
|
||||
x = torch.randn([], requires_grad=True)
|
||||
y = foo(x)
|
||||
with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"):
|
||||
y.backward()
|
||||
|
||||
def test_backward_dict_grad_for_nontensor(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
|
||||
...
|
||||
|
||||
@foo.impl(['cpu', 'cuda'])
|
||||
def foo_impl(x, dim):
|
||||
return x.sin()
|
||||
|
||||
@foo.impl_save_for_backward()
|
||||
def foo_save_for_backward(inputs, output):
|
||||
return inputs.x
|
||||
|
||||
@foo.impl_backward()
|
||||
def foo_backward(ctx, saved, grad):
|
||||
return {'x': grad * saved.cos(), 'dim': None}
|
||||
|
||||
x = torch.randn([], requires_grad=True)
|
||||
y = foo(x, 32)
|
||||
with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"):
|
||||
y.backward()
|
||||
|
||||
def test_backward_dict_requires_keys_for_input_tensors(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
...
|
||||
|
||||
@foo.impl(['cpu', 'cuda'])
|
||||
def foo_impl(x, y):
|
||||
return x.sin()
|
||||
|
||||
@foo.impl_save_for_backward()
|
||||
def foo_save_for_backward(inputs, output):
|
||||
return inputs.x
|
||||
|
||||
@foo.impl_backward()
|
||||
def foo_backward(ctx, saved, grad):
|
||||
return {'x': grad * saved.cos()}
|
||||
|
||||
x = torch.randn([], requires_grad=True)
|
||||
y = foo(x, x)
|
||||
with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
|
||||
y.backward()
|
||||
|
||||
def test_backward_dict_requires_keys_for_input_optional_tensors(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
...
|
||||
|
||||
@foo.impl(['cpu', 'cuda'])
|
||||
def foo_impl(x, y):
|
||||
return x.sin()
|
||||
|
||||
@foo.impl_save_for_backward()
|
||||
def foo_save_for_backward(inputs, output):
|
||||
return inputs.x
|
||||
|
||||
@foo.impl_backward()
|
||||
def foo_backward(ctx, saved, grad):
|
||||
return {'x': grad * saved.cos()}
|
||||
|
||||
x = torch.randn([], requires_grad=True)
|
||||
y = foo(x, None)
|
||||
with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
|
||||
y.backward()
|
||||
|
||||
def test_backward_grads_are_tensor_or_none(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(x: torch.Tensor) -> torch.Tensor:
|
||||
...
|
||||
|
||||
@foo.impl(['cpu', 'cuda'])
|
||||
def foo_impl(x):
|
||||
return x.sin()
|
||||
|
||||
@foo.impl_save_for_backward()
|
||||
def foo_save_for_backward(inputs, output):
|
||||
return inputs.x
|
||||
|
||||
@foo.impl_backward()
|
||||
def foo_backward(ctx, saved, grad):
|
||||
return {'x': (grad * saved.cos(),)}
|
||||
|
||||
x = torch.randn([], requires_grad=True)
|
||||
y = foo(x)
|
||||
with self.assertRaisesRegex(RuntimeError, 'either None or a Tensor'):
|
||||
y.backward()
|
||||
|
||||
def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
|
||||
...
|
||||
|
||||
@foo.impl(['cpu', 'cuda'])
|
||||
def foo_impl(xs):
|
||||
return xs[0].sin()
|
||||
|
||||
@foo.impl_save_for_backward()
|
||||
def foo_save_for_backward(inputs, output):
|
||||
return inputs.xs[0]
|
||||
|
||||
@foo.impl_backward()
|
||||
def foo_backward(ctx, saved, grad):
|
||||
return {'xs': [grad * saved.cos(), None]}
|
||||
|
||||
xs = [torch.randn([], requires_grad=True) for _ in range(3)]
|
||||
y = foo(xs)
|
||||
with self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"):
|
||||
y.backward()
|
||||
|
||||
def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
|
||||
...
|
||||
|
||||
@foo.impl(['cpu', 'cuda'])
|
||||
def foo_impl(xs):
|
||||
return xs[0].sin()
|
||||
|
||||
@foo.impl_save_for_backward()
|
||||
def foo_save_for_backward(inputs, output):
|
||||
return inputs.xs[0]
|
||||
|
||||
@foo.impl_backward()
|
||||
def foo_backward(ctx, saved, grad):
|
||||
return {'xs': [grad * saved.cos(), None, (None,)]}
|
||||
|
||||
xs = [torch.randn([], requires_grad=True) for _ in range(3)]
|
||||
y = foo(xs)
|
||||
with self.assertRaisesRegex(RuntimeError, "None or Tensor"):
|
||||
y.backward()
|
||||
|
||||
def test_backward_tensorlist_input_requires_list_grads(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
|
||||
...
|
||||
|
||||
@foo.impl(['cpu', 'cuda'])
|
||||
def foo_impl(xs):
|
||||
return xs[0].sin()
|
||||
|
||||
@foo.impl_save_for_backward()
|
||||
def foo_save_for_backward(inputs, output):
|
||||
return inputs.xs[0]
|
||||
|
||||
@foo.impl_backward()
|
||||
def foo_backward(ctx, saved, grad):
|
||||
return {'xs': None}
|
||||
|
||||
xs = [torch.randn([], requires_grad=True) for _ in range(3)]
|
||||
y = foo(xs)
|
||||
with self.assertRaisesRegex(RuntimeError, "list of gradients"):
|
||||
y.backward()
|
||||
|
||||
def test_backward_output_differentiability_type(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
|
||||
...
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
|
||||
@foo.impl_backward(output_differentiability=True)
|
||||
def foo_backward(ctx, saved, grad):
|
||||
return {'xs': None}
|
||||
|
||||
def test_backward_output_differentiability_numel(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
...
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
|
||||
@foo.impl_backward(output_differentiability=[True])
|
||||
def foo_backward(ctx, saved, grad):
|
||||
return {'xs': None}
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "requires CUDA")
|
||||
def test_impl_separate(self):
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
|
243
torch/_custom_op/autograd.py
Normal file
243
torch/_custom_op/autograd.py
Normal file
@ -0,0 +1,243 @@
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from collections import namedtuple
|
||||
import functools
|
||||
|
||||
|
||||
# NOTE [CustomOp autograd kernel indirection]
|
||||
# We register `inner` as the autograd kernel for this custom_op.
|
||||
# `inner` either calls the autograd formula registered by the user,
|
||||
# or goes into an `autograd_not_implemented` kernel.
|
||||
#
|
||||
# The reason why this indirection exists is
|
||||
# so that we can swap out the autograd kernel (the PyTorch dispatcher
|
||||
# doesn't actually allow us to do this). By default, we want
|
||||
# the `autograd_not_implemented` behavior, but then the user may come
|
||||
# and register something that is actually a backward formula
|
||||
def autograd_kernel_indirection(custom_op):
|
||||
# TODO(#101191): Use the actual C++ autograd not implemented fallback,
|
||||
# or change the default autograd fallback to the autograd not implemented fallback.
|
||||
def autograd_not_implemented(*args, **kwargs) -> None:
|
||||
if torch.is_grad_enabled() and pytree.tree_any(
|
||||
lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
|
||||
):
|
||||
raise RuntimeError("Autograd has not been implemented for operator")
|
||||
guard = torch._C._AutoDispatchBelowAutograd()
|
||||
try:
|
||||
return custom_op(*args, **kwargs)
|
||||
finally:
|
||||
del guard
|
||||
|
||||
def inner(*args, **kwargs):
|
||||
if custom_op._has_impl('autograd'):
|
||||
kernel = custom_op._get_impl('autograd').func
|
||||
return kernel(*args, **kwargs)
|
||||
# As explained in NOTE ["backward", "save_for_backward", and "autograd"],
|
||||
# after the user gives us "backward" and "save_for_backward", we generate
|
||||
# the "autograd" impl. If the user only provided one, then we tell
|
||||
# the user they've done something wrong.
|
||||
if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'):
|
||||
missing = (
|
||||
'save_for_backward' if custom_op._has_impl('backward')
|
||||
else 'backward'
|
||||
)
|
||||
found = 'save_for_backward' if missing == 'backward' else 'backward'
|
||||
loc = custom_op._get_impl(found).location
|
||||
raise RuntimeError(
|
||||
f"We found a '{found}' registration for {custom_op} at "
|
||||
f"{loc} but were unable to find a '{missing}' registration. "
|
||||
f"To use the CustomOp API to register a backward formula, "
|
||||
f"please provide us both a backward function and a "
|
||||
f"'save for backward' function via `impl_backward` and "
|
||||
f"`impl_save_for_backward` respectively.")
|
||||
return autograd_not_implemented(*args, **kwargs)
|
||||
return inner
|
||||
|
||||
|
||||
def construct_autograd_kernel(
|
||||
schema,
|
||||
output_differentiability,
|
||||
forward_op,
|
||||
save_for_backward_fn,
|
||||
backward_fn):
|
||||
|
||||
def apply(*args):
|
||||
flat_args, spec = pytree.tree_flatten(args)
|
||||
|
||||
def forward(ctx, *flat_args):
|
||||
ctx.set_materialize_grads(True)
|
||||
args = pytree.tree_unflatten(list(flat_args), spec)
|
||||
guard = torch._C._AutoDispatchBelowAutograd()
|
||||
try:
|
||||
output = forward_op(*args)
|
||||
finally:
|
||||
del guard
|
||||
|
||||
# We use the info about args to give better error messages in backward
|
||||
args_info = namedtuple_args(
|
||||
schema, pytree.tree_map(lambda arg: type(arg), args))
|
||||
|
||||
save_for_backward_fn_inputs = namedtuple_args(schema, args)
|
||||
to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)
|
||||
|
||||
save_pytree_for_backward(ctx, (to_save, args_info))
|
||||
|
||||
# Output must be one or more Tensors, no TensorList (yet)
|
||||
if output_differentiability is not None:
|
||||
if isinstance(output, tuple):
|
||||
assert len(output_differentiability) == len(output)
|
||||
for differentiable, out in zip(output_differentiability, output):
|
||||
if not differentiable:
|
||||
ctx.mark_non_differentiable(out)
|
||||
else:
|
||||
assert len(output_differentiability) == 1
|
||||
if not output_differentiability[0]:
|
||||
ctx.mark_non_differentiable(output)
|
||||
|
||||
return output
|
||||
|
||||
def backward(ctx, *grads):
|
||||
saved, args_info = unpack_saved(ctx)
|
||||
# There is nothing on the ctx object for now, it is just there so
|
||||
# that we can add additional things in the future.
|
||||
inner_ctx = object()
|
||||
grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)
|
||||
|
||||
# Massage the grad_inputs_dict to a form acceptable by
|
||||
# autograd.Function.
|
||||
validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info)
|
||||
return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)
|
||||
|
||||
generated_cls = gen_autograd_function(
|
||||
forward_op._opname + '_customop', forward, backward)
|
||||
|
||||
return generated_cls.apply(*flat_args)
|
||||
return apply
|
||||
|
||||
|
||||
def gen_autograd_function(name, forward, backward):
|
||||
generated_cls = type(
|
||||
name,
|
||||
(torch.autograd.Function,),
|
||||
{
|
||||
'forward': staticmethod(forward),
|
||||
'backward': staticmethod(backward),
|
||||
}
|
||||
)
|
||||
return generated_cls
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def namedtuple_args_cls(schema):
|
||||
attribs = [arg.name for arg in schema.arguments.flat_all]
|
||||
name = str(schema.name) + "_args"
|
||||
# mypy doesn't support dynamic namedtuple name
|
||||
tuple_cls = namedtuple(name, attribs) # type: ignore[misc]
|
||||
return tuple_cls
|
||||
|
||||
|
||||
def namedtuple_args(schema, args):
|
||||
assert isinstance(args, tuple)
|
||||
tuple_cls = namedtuple_args_cls(schema)
|
||||
return tuple_cls(*args)
|
||||
|
||||
|
||||
def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
|
||||
def error(what):
|
||||
backward = forward_op._get_impl('backward')
|
||||
raise RuntimeError(
|
||||
f"In the backward function defined for {forward_op} at "
|
||||
f"{backward.location} using the CustomOp API, {what}")
|
||||
|
||||
if not isinstance(grad_inputs_dict, dict):
|
||||
error(f"expected the output of the backward function to be a dict but "
|
||||
f"got {type(grad_inputs_dict)}")
|
||||
|
||||
expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
|
||||
if arg.type.is_tensor_like()}
|
||||
actual_keys = grad_inputs_dict.keys()
|
||||
if expected_keys != actual_keys:
|
||||
error(f"expected the returned grad_input dict to have keys "
|
||||
f"{expected_keys} but got {actual_keys}. The backward "
|
||||
f"function must return a gradient (can be None) for each arg "
|
||||
f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
|
||||
f"Args declared to be non-Tensor-like types should not appear "
|
||||
f"in the grad_input dict")
|
||||
|
||||
for name, grad in grad_inputs_dict.items():
|
||||
arg_info = getattr(args_info, name)
|
||||
|
||||
if isinstance(arg_info, list):
|
||||
if not isinstance(grad, (tuple, list)):
|
||||
error(f"for input '{name}' expected the grad_input dict to "
|
||||
f"hold a list of gradients but got object of type "
|
||||
f"{type(grad)}.")
|
||||
if not len(grad) == len(arg_info):
|
||||
error(f"for input '{name}' expected the grad_input dict to "
|
||||
f"hold a list of {len(arg_info)} gradients but got "
|
||||
f"{len(grad)}")
|
||||
for idx, (g, info) in enumerate(zip(grad, arg_info)):
|
||||
if g is None:
|
||||
continue
|
||||
if not isinstance(g, torch.Tensor):
|
||||
error(f"for input '{name}' expected the grad_input dict to "
|
||||
f"hold a list of None or Tensor gradients but got "
|
||||
f"object of {type(g)} at index {idx}")
|
||||
if info != torch.Tensor:
|
||||
error(f"for input '{name}', got a Tensor as the gradient "
|
||||
f"for the {idx}-th value but expected None because "
|
||||
f"the {idx}-th value was not a Tensor (it was "
|
||||
f"type {arg_info}")
|
||||
continue
|
||||
|
||||
if grad is None:
|
||||
continue
|
||||
if not isinstance(grad, torch.Tensor):
|
||||
error(f"got object of type {type(grad)} as the gradient for input "
|
||||
f"'{name}', "
|
||||
f"but expected the gradient to be either None or a Tensor")
|
||||
if arg_info != torch.Tensor:
|
||||
error(f"got a Tensor as the gradient for input '{name}' but "
|
||||
f"expected None as the gradient because input '{name}' "
|
||||
f"was not a Tensor (it was type {arg_info}).")
|
||||
|
||||
def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
|
||||
result = []
|
||||
for name, arg_info in args_info._asdict().items():
|
||||
if name not in grad_inputs_dict:
|
||||
result.append(pytree.tree_map(lambda x: None, arg_info))
|
||||
continue
|
||||
result.append(grad_inputs_dict[name])
|
||||
return tuple(pytree.tree_flatten(result)[0])
|
||||
|
||||
# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
|
||||
# autograd.Function prefers that users use ctx.save_for_backward to
|
||||
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
|
||||
# ctx object.
|
||||
def save_pytree_for_backward(ctx, stuff):
|
||||
flat_stuff, spec = pytree.tree_flatten(stuff)
|
||||
num_elts = len(flat_stuff)
|
||||
tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
|
||||
if isinstance(thing, torch.Tensor)]
|
||||
non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
|
||||
if not isinstance(thing, torch.Tensor)]
|
||||
tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)]
|
||||
non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)]
|
||||
|
||||
ctx.spec = spec
|
||||
ctx.num_elts = num_elts
|
||||
ctx.save_for_backward(*tensors)
|
||||
ctx.tensor_idxs = tensor_idxs
|
||||
ctx.saved_non_tensors = non_tensors
|
||||
ctx.non_tensor_idxs = non_tensor_idxs
|
||||
|
||||
|
||||
# Inverse operation to save_pytree_for_backward
|
||||
def unpack_saved(ctx):
|
||||
flat_stuff = [None] * ctx.num_elts
|
||||
for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs):
|
||||
flat_stuff[idx] = tensor
|
||||
for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
|
||||
flat_stuff[idx] = non_tensor
|
||||
stuff = pytree.tree_unflatten(flat_stuff, ctx.spec)
|
||||
return stuff
|
@ -10,7 +10,8 @@ from torchgen.model import FunctionSchema, OperatorName, SchemaKind
|
||||
import torch
|
||||
import torch._C as _C
|
||||
import torch.library as library
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
from .autograd import autograd_kernel_indirection, construct_autograd_kernel
|
||||
|
||||
"""
|
||||
There are various APIs for defining custom-operator-like things in PyTorch:
|
||||
@ -143,18 +144,14 @@ def custom_op(
|
||||
lib = library.Library(ns, "FRAGMENT")
|
||||
lib.define(schema_str)
|
||||
ophandle = find_ophandle_or_throw(ns, function_schema.name)
|
||||
result = CustomOp(lib, ns, function_schema.name, ophandle, _private_access=True)
|
||||
result = CustomOp(lib, ns, function_schema, function_schema.name, ophandle, _private_access=True)
|
||||
|
||||
result.__name__ = func.__name__
|
||||
result.__module__ = func.__module__
|
||||
result.__doc__ = func.__doc__
|
||||
|
||||
# NYI: autograd not supported
|
||||
# In the near future we will either directly use the
|
||||
# autograd_not_implemented kernels or make those the default fallback
|
||||
# for the Autograd and ADInplaceOrView keys. Both of those are a bit tricky.
|
||||
library.impl(lib, result._opname, "Autograd")(
|
||||
get_autograd_not_implemented_kernel(weakref.proxy(result))
|
||||
autograd_kernel_indirection(weakref.proxy(result))
|
||||
)
|
||||
|
||||
torch._C._dispatch_set_report_error_callback(
|
||||
@ -185,7 +182,7 @@ class CustomOp:
|
||||
To construct a `CustomOp`, use `custom_op`.
|
||||
"""
|
||||
|
||||
def __init__(self, lib, cpp_ns, operator_name, ophandle, *, _private_access=False):
|
||||
def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
|
||||
super(CustomOp, self).__init__()
|
||||
if not _private_access:
|
||||
raise RuntimeError(
|
||||
@ -193,6 +190,7 @@ class CustomOp:
|
||||
"BC for it. Please use custom_op(...) to create a CustomOp object"
|
||||
)
|
||||
name = f"{cpp_ns}::{str(operator_name.name)}"
|
||||
self._schema = schema
|
||||
self._cpp_ns = cpp_ns
|
||||
self._lib: library.Library = lib
|
||||
self._ophandle: _C._DispatchOperatorHandle = ophandle
|
||||
@ -417,6 +415,90 @@ class CustomOp:
|
||||
|
||||
return inner
|
||||
|
||||
# NOTE ["backward", "save_for_backward", and "autograd"]
|
||||
# As a part of the explicit autograd API, a user must provide us
|
||||
# a "save_for_backward" function and a "backward" function.
|
||||
# When both of these have been provided, then we automatically
|
||||
# construct the "autograd" kernel.
|
||||
def _register_autograd_kernel(self):
|
||||
assert self._has_impl("backward")
|
||||
assert self._has_impl("save_for_backward")
|
||||
kernel = construct_autograd_kernel(
|
||||
self._schema,
|
||||
self._output_differentiability,
|
||||
self,
|
||||
self._get_impl("save_for_backward").func,
|
||||
self._get_impl("backward").func)
|
||||
self._register_impl("autograd", kernel)
|
||||
|
||||
def impl_save_for_backward(self):
|
||||
r"""Register a function that tells us what to save for backward.
|
||||
|
||||
Please see impl_backward for more details.
|
||||
"""
|
||||
def inner(f):
|
||||
self._register_impl("save_for_backward", f)
|
||||
if self._has_impl("backward"):
|
||||
self._register_autograd_kernel()
|
||||
return inner
|
||||
|
||||
def impl_backward(self, output_differentiability=None):
|
||||
r"""Registers a backward formula.
|
||||
|
||||
In order for the CustomOp to work with autograd, you need to register
|
||||
a backward formula. There are two pieces to this:
|
||||
1. You must give us a function to specify what to save for backward.
|
||||
Call this the "save for backward" function.
|
||||
2. You must give us a function that computes gradients. Call this the
|
||||
"backward" function.
|
||||
|
||||
Use `impl_save_for_backward` to define a "save for backward" function
|
||||
that specifies what gets saved for backward. The function should accept
|
||||
two arguments ``(inputs, output)`` and return the quantities to be saved
|
||||
for backward.
|
||||
|
||||
During runtime, when you call the CustomOp, PyTorch will invoke the
|
||||
"save for backward" function with the inputs and output of the CustomOp.
|
||||
|
||||
Use `impl_backward` to define the "backward" function. The backward
|
||||
function must accept ``(ctx, saved, *grads)``:
|
||||
- ``ctx`` is a context object where we may provide information
|
||||
- ``saved`` is exactly what gets returned from the "save for backward"
|
||||
function
|
||||
- ``grads`` is one or more gradients. The number of gradients matches
|
||||
the number of outputs of the CustomOp.
|
||||
|
||||
The backward function must return a dict that maps the name of
|
||||
an input to the CustomOp to its corresponding gradient. All inputs that
|
||||
were declared to be Tensors in the CustomOp definition must be accounted
|
||||
for in the dict. The gradient may be a Tensor or None.
|
||||
|
||||
TODO(rzou): Add example when this PR is closer to landing.
|
||||
|
||||
"""
|
||||
if output_differentiability is not None:
|
||||
def yell():
|
||||
raise RuntimeError(
|
||||
f"impl_backward(output_differentiability): expected "
|
||||
f"output_differentiability to be a list of bools with "
|
||||
f"length equal to the number of outputs of this CustomOp "
|
||||
f"got: {output_differentiability}")
|
||||
|
||||
if not isinstance(output_differentiability, list):
|
||||
yell()
|
||||
for diff in output_differentiability:
|
||||
if not isinstance(diff, bool):
|
||||
yell()
|
||||
if len(self._schema.returns) != len(output_differentiability):
|
||||
yell()
|
||||
|
||||
def inner(f):
|
||||
self._register_impl("backward", f)
|
||||
self._output_differentiability = output_differentiability
|
||||
if self._has_impl("save_for_backward"):
|
||||
self._register_autograd_kernel()
|
||||
return inner
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FuncAndLocation:
|
||||
@ -490,21 +572,6 @@ def validate_device_type(device_type: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_autograd_not_implemented_kernel(custom_op) -> typing.Callable:
|
||||
def autograd_not_implemented(*args, **kwargs) -> None:
|
||||
if pytree.tree_any(
|
||||
lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
|
||||
):
|
||||
raise RuntimeError("Autograd has not been implemented for operator")
|
||||
guard = _C._AutoDispatchBelowAutograd()
|
||||
try:
|
||||
return custom_op(*args, **kwargs)
|
||||
finally:
|
||||
del guard
|
||||
|
||||
return autograd_not_implemented
|
||||
|
||||
|
||||
def supported_param(param: inspect.Parameter) -> bool:
|
||||
return param.kind in (
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
|
@ -42,6 +42,16 @@ def numpy_cube_impl(x):
|
||||
def numpy_cube_abstract(x):
|
||||
return x.clone(), x.clone()
|
||||
|
||||
@numpy_cube.impl_save_for_backward()
|
||||
def numpy_cube_save_for_backward(inputs, output):
|
||||
return (inputs.x, output[1])
|
||||
|
||||
@numpy_cube.impl_backward()
|
||||
def numpy_cube_backward(ctx, saved, grad_out, grad_dx):
|
||||
x, dx = saved
|
||||
grad_x = numpy_mul(grad_out, dx) + 6 * numpy_mul(grad_dx, x)
|
||||
return {'x': grad_x}
|
||||
|
||||
@custom_op('_torch_testing::numpy_mul')
|
||||
def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
|
||||
...
|
||||
@ -56,6 +66,22 @@ def numpy_mul_abstract(x, y):
|
||||
assert x.device == y.device
|
||||
return (x * y).contiguous()
|
||||
|
||||
@numpy_mul.impl_save_for_backward()
|
||||
def numpy_mul_save_for_backward(inputs, output):
|
||||
saved = {}
|
||||
saved['x_requires_grad'] = inputs.x.requires_grad
|
||||
saved['y_requires_grad'] = inputs.y.requires_grad
|
||||
# Optimization: only save what is necessary
|
||||
saved['y'] = inputs.y if inputs.x.requires_grad else None
|
||||
saved['x'] = inputs.x if inputs.y.requires_grad else None
|
||||
return saved
|
||||
|
||||
@numpy_mul.impl_backward()
|
||||
def numpy_mul_backward(ctx, saved, grad_out):
|
||||
grad_x = grad_out * saved['y'] if saved['x_requires_grad'] else None
|
||||
grad_y = grad_out * saved['x'] if saved['x_requires_grad'] else None
|
||||
return {'y': grad_y, 'x': grad_x}
|
||||
|
||||
@custom_op('_torch_testing::numpy_sort')
|
||||
def numpy_sort(x: Tensor, dim: int) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
...
|
||||
@ -69,7 +95,7 @@ def numpy_sort_impl(x, dim):
|
||||
ind_inv = np.argsort(ind, axis=dim)
|
||||
result = np.take_along_axis(x, ind, axis=dim)
|
||||
return (
|
||||
torch.tensor(x, device=device),
|
||||
torch.tensor(result, device=device),
|
||||
torch.tensor(ind, device=device),
|
||||
torch.tensor(ind_inv, device=device),
|
||||
)
|
||||
@ -78,6 +104,16 @@ def numpy_sort_impl(x, dim):
|
||||
def numpy_sort_abstract(x, dim):
|
||||
return torch.empty_like(x), torch.empty_like(x, dtype=torch.long), torch.empty_like(x, dtype=torch.long)
|
||||
|
||||
@numpy_sort.impl_save_for_backward()
|
||||
def numpy_sort_save_for_backward(inputs, output):
|
||||
out, ind, ind_inv = output
|
||||
return [inputs.dim, ind, ind_inv]
|
||||
|
||||
@numpy_sort.impl_backward(output_differentiability=[True, False, False])
|
||||
def numpy_sort_backward(ctx, saved, grad_out, grad_ind, grad_ind_inv):
|
||||
dim, ind, ind_inv = saved
|
||||
return {'x': numpy_take(grad_out, ind_inv, ind, dim)}
|
||||
|
||||
@custom_op('_torch_testing::numpy_take')
|
||||
def numpy_take(x: Tensor, ind: Tensor, ind_inv: Tensor, dim: int) -> Tensor:
|
||||
...
|
||||
@ -94,8 +130,26 @@ def numpy_take_impl(x, ind, ind_inv, dim):
|
||||
def numpy_take_abstract(x, ind, ind_inv, dim):
|
||||
assert x.device == ind.device
|
||||
assert x.device == ind_inv.device
|
||||
assert ind.dtype == torch.long
|
||||
assert ind_inv.dtype == torch.long
|
||||
return torch.empty_like(x)
|
||||
|
||||
@numpy_take.impl_save_for_backward()
|
||||
def numpy_take_save_for_backward(inputs, output):
|
||||
return {
|
||||
'dim': inputs.dim,
|
||||
'ind': inputs.ind,
|
||||
'ind_inv': inputs.ind_inv,
|
||||
}
|
||||
|
||||
@numpy_take.impl_backward()
|
||||
def numpy_take_backward(ctx, saved, grad_out):
|
||||
return {
|
||||
'x': numpy_take(grad_out, saved['ind_inv'], saved['ind'], saved['dim']),
|
||||
'ind': None,
|
||||
'ind_inv': None,
|
||||
}
|
||||
|
||||
@custom_op('_torch_testing::numpy_nonzero')
|
||||
def numpy_nonzero(x: Tensor) -> Tensor:
|
||||
...
|
||||
@ -138,11 +192,56 @@ def numpy_view_copy_impl(x, shape) -> Tensor:
|
||||
def numpy_view_copy_abstract(x, shape) -> Tensor:
|
||||
return x.clone().view(shape).clone()
|
||||
|
||||
@numpy_view_copy.impl_save_for_backward()
|
||||
def numpy_view_copy_save_for_backward(inputs, output) -> Tensor:
|
||||
return inputs.x.shape
|
||||
|
||||
@numpy_view_copy.impl_backward()
|
||||
def numpy_view_copy_backward(ctx, x_shape, grad_out) -> Tensor:
|
||||
return {'x': numpy_view_copy(grad_out, x_shape)}
|
||||
|
||||
def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
result = make_arg(2, 3, 4, low=0.9, high=2)
|
||||
yield SampleInput(result, args=([2, 12],))
|
||||
|
||||
@custom_op('_torch_testing::numpy_cat')
|
||||
def numpy_cat(xs: Sequence[Tensor], dim: int) -> Tensor:
|
||||
...
|
||||
|
||||
@numpy_cat.impl(['cpu', 'cuda'])
|
||||
def numpy_cat_impl(xs, dim):
|
||||
assert len(xs) > 0
|
||||
assert all(x.device == xs[0].device for x in xs)
|
||||
assert all(x.dtype == xs[0].dtype for x in xs)
|
||||
np_xs = [to_numpy(x) for x in xs]
|
||||
np_out = np.concatenate(np_xs, axis=dim)
|
||||
return torch.tensor(np_out, device=xs[0].device)
|
||||
|
||||
@numpy_cat.impl_abstract()
|
||||
def numpy_cat_abstract(xs, dim):
|
||||
assert len(xs) > 0
|
||||
assert all(x.device == xs[0].device for x in xs)
|
||||
assert all(x.dtype == xs[0].dtype for x in xs)
|
||||
return torch.cat(xs, dim=dim)
|
||||
|
||||
@numpy_cat.impl_save_for_backward()
|
||||
def numpy_cat_save_for_backward(inputs, output):
|
||||
dim_sizes = [x.shape[inputs.dim] for x in inputs.xs]
|
||||
return dim_sizes, inputs.dim
|
||||
|
||||
@numpy_cat.impl_backward()
|
||||
def numpy_cat_backward(ctx, saved, grad_out):
|
||||
dim_sizes, dim = saved
|
||||
return {'xs': torch.split(grad_out, dim_sizes, dim)}
|
||||
|
||||
def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
r0 = make_arg(2, 3, 4, low=0.9, high=2)
|
||||
r1 = make_arg(4, 3, 4, low=0.9, high=2)
|
||||
r2 = make_arg(5, 3, 4, low=0.9, high=2)
|
||||
yield SampleInput([r0, r1, r2], args=(0,))
|
||||
|
||||
@custom_op('_torch_testing::numpy_nms')
|
||||
def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor:
|
||||
...
|
||||
@ -257,6 +356,7 @@ custom_op_db = [
|
||||
op=wrap_for_opinfo(numpy_nonzero),
|
||||
sample_inputs_func=sample_inputs_numpy_nonzero,
|
||||
dtypes=all_types_and(torch.bool, torch.half),
|
||||
supports_autograd=False,
|
||||
supports_out=False,
|
||||
),
|
||||
OpInfo(
|
||||
@ -264,6 +364,7 @@ custom_op_db = [
|
||||
op=wrap_for_opinfo(numpy_nms),
|
||||
sample_inputs_func=sample_inputs_numpy_nms,
|
||||
dtypes=all_types_and(torch.bool, torch.half),
|
||||
supports_autograd=False,
|
||||
supports_out=False,
|
||||
),
|
||||
OpInfo(
|
||||
@ -271,6 +372,15 @@ custom_op_db = [
|
||||
op=wrap_for_opinfo(numpy_view_copy),
|
||||
sample_inputs_func=sample_inputs_numpy_view_copy,
|
||||
dtypes=all_types_and(torch.bool, torch.half),
|
||||
supports_autograd=True,
|
||||
supports_out=False,
|
||||
),
|
||||
OpInfo(
|
||||
'NumpyCatCustomOp',
|
||||
op=wrap_for_opinfo(numpy_cat),
|
||||
sample_inputs_func=sample_inputs_numpy_cat,
|
||||
dtypes=all_types_and(torch.bool, torch.half),
|
||||
supports_autograd=True,
|
||||
supports_out=False,
|
||||
),
|
||||
]
|
||||
|
Reference in New Issue
Block a user