Extend impl_backward to handle non-Tensor outputs (#106800)

Recall that the user must give us a backward function that accepts
`(ctx, saved, *grads)`, with one grad per output. Previously,
impl_backward only worked for functions that return one or more Tensors.

The new semantics are as follows (a short sketch of the calling
convention follows the list):
- If an output is a TensorList, the user-provided backward function
receives a List[Tensor] of grads for that output.
- If an output is a number, the user-provided backward function
receives None as the grad for that output.
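
For example, for an op that returns `(TensorList, int)`, the backward
receives a `List[Tensor]` for the first output and `None` for the second.
A rough sketch of the calling convention; the `mylib::split_and_count` op
is hypothetical, and the decorators mirror the `torch._custom_ops` API
exercised in the tests and custom_op_db below:

```python
import torch
from torch import Tensor
from typing import List, Tuple
import torch._custom_ops as custom_ops

# Hypothetical example op, not part of this PR.
@custom_ops.custom_op('mylib::split_and_count')
def split_and_count(x: Tensor, sections: int) -> Tuple[List[Tensor], int]:
    raise NotImplementedError()

@custom_ops.impl('mylib::split_and_count')
def split_and_count_impl(x, sections):
    chunks = [c.clone() for c in torch.chunk(x, sections)]
    return chunks, len(chunks)

@custom_ops.impl_save_for_backward('mylib::split_and_count')
def split_and_count_save_for_backward(inputs, output):
    return []

@custom_ops.impl_backward('mylib::split_and_count')
def split_and_count_backward(ctx, saved, grad_chunks, grad_count):
    # grad_chunks: List[Tensor], one grad per Tensor in the TensorList output.
    # grad_count: None, because the second output is an int.
    return {'x': torch.cat(grad_chunks)}
```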

Also recall that impl_backward is implemented by registering an
autograd.Function to the autograd dispatch key.
We needed to make the following changes:
- If an output is a TensorList, autograd.Function will ignore it, so we
need to tree-flatten it before returning it from the autograd.Function.
- This means the autograd.Function receives a flat list of grads during
the backward pass. We need to tree-unflatten it into the original output
structure before passing it to the user-defined backward (see the pytree
sketch after this list).
- We modify the logic of output_differentiability. Only Tensor/TensorList
outputs can be marked as differentiable. If a TensorList is marked as
non-differentiable, this is equivalent to all Tensors in the list being
non-differentiable. There is no finer-grained control over this (to match
derivatives.yaml).
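
Roughly, the flatten/unflatten round trip looks like this (a standalone
sketch using torch.utils._pytree directly, not the exact code from the
diff below):

```python
import torch
import torch.utils._pytree as pytree

# An output containing a TensorList: autograd.Function would ignore the
# list, so forward flattens it into individual Tensors plus a spec.
output = ([torch.randn(3), torch.randn(3)], torch.randn(3))
flat_output, out_spec = pytree.tree_flatten(output)  # 3 Tensors + TreeSpec

# During backward, autograd hands back one flat grad per flattened Tensor.
flat_grads = [torch.ones_like(t) for t in flat_output]

# Unflatten into the shape of the original output before calling the
# user-defined backward: ([grad, grad], grad).
grads = pytree.tree_unflatten(flat_grads, out_spec)
assert isinstance(grads[0], list) and len(grads[0]) == 2
```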

Test Plan:
- There are new `numpy_split_copy` (returns TensorList) and
`numpy_split_copy_with_int` (returns (TensorList, int)) operators in
custom_op_db
- Added tests for output_differentiability to test/test_custom_ops.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106800
Approved by: https://github.com/soulitzer
ghstack dependencies: #106799
Richard Zou, 2023-08-11 11:41:53 -07:00
committed by PyTorch MergeBot
parent 9fcce1baf1
commit db9a0cf689
4 changed files with 175 additions and 18 deletions


@@ -1247,6 +1247,56 @@ class TestCustomOp(CustomOpTestCaseBase):
def foo_backward(ctx, saved, grad):
return {"xs": None}
def test_backward_output_differentiability_tensorlist(self):
@custom_ops.custom_op(f"{self.test_ns}::foo")
def foo(x: Tensor) -> Tuple[List[Tensor], Tensor]:
raise NotImplementedError()
@custom_ops.impl(f"{self.test_ns}::foo")
def foo_impl(x):
return [x.clone(), x.clone()], x.clone()
@custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
def foo_save_for_backward(inputs, output):
return []
@custom_ops.impl_backward(
f"{TestCustomOp.test_ns}::foo", output_differentiability=[False, True]
)
def foo_backward(ctx, saved, grad_lst, grad):
return {"x": grad}
op = self.get_op(f"{self.test_ns}::foo")
x = torch.randn(3, requires_grad=True)
[a, b], c = op(x)
self.assertFalse(a.requires_grad)
self.assertFalse(b.requires_grad)
self.assertTrue(c.requires_grad)
def test_backward_output_differentiability_non_tensor(self):
@custom_ops.custom_op(f"{self.test_ns}::foo")
def foo(x: Tensor) -> Tuple[Tensor, int]:
raise NotImplementedError()
@custom_ops.impl(f"{self.test_ns}::foo")
def foo_impl(x):
return x.clone(), 3
@custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
def foo_save_for_backward(inputs, output):
return []
@custom_ops.impl_backward(
f"{TestCustomOp.test_ns}::foo", output_differentiability=[True, True]
)
def foo_backward(ctx, saved, grad0, grad1):
return {"x": grad0}
op = self.get_op(f"{self.test_ns}::foo")
x = torch.randn(3, requires_grad=True)
with self.assertRaisesRegex(RuntimeError, "is not a Tensor"):
op(x)
@unittest.skipIf(not TEST_CUDA, "requires CUDA")
def test_impl_separate(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")


@@ -56,6 +56,37 @@ def autograd_not_implemented(custom_op):
return kernel
def mark_non_differentiable(ctx, output, output_differentiability):
# Output types are restricted to be:
# - Tensor
# - Tensor[]
# - int, bool, Scalar, float
if output_differentiability is not None:
if not isinstance(output, tuple):
tuple_output = (output,)
else:
tuple_output = output # type: ignore[assignment]
assert len(output_differentiability) == len(tuple_output)
non_differentiable_tensors = []
for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
if isinstance(out, torch.Tensor):
if not differentiable:
non_differentiable_tensors.append(out)
continue
if isinstance(out, list):
if not differentiable:
non_differentiable_tensors.extend(out)
continue
if differentiable:
raise RuntimeError(
f"With output_differentiability={output_differentiability}. "
f"At idx {idx}, we received an object of type {type(out)} that "
f"is not a Tensor, so it cannot be marked as differentiable in "
f"output_differentiability.")
if non_differentiable_tensors:
ctx.mark_non_differentiable(*non_differentiable_tensors)
def construct_autograd_kernel(
schema,
output_differentiability,
@@ -65,6 +96,7 @@ def construct_autograd_kernel(
def apply(*args):
flat_args, spec = pytree.tree_flatten(args)
out_spec = None
def forward(ctx, *flat_args):
ctx.set_materialize_grads(True)
@@ -80,26 +112,21 @@ def construct_autograd_kernel(
to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)
save_pytree_for_backward(ctx, (to_save, args_info))
mark_non_differentiable(ctx, output, output_differentiability)
# Output must be one or more Tensors, no TensorList (yet)
if output_differentiability is not None:
if isinstance(output, tuple):
assert len(output_differentiability) == len(output)
for differentiable, out in zip(output_differentiability, output):
if not differentiable:
ctx.mark_non_differentiable(out)
else:
assert len(output_differentiability) == 1
if not output_differentiability[0]:
ctx.mark_non_differentiable(output)
nonlocal out_spec
flat_output, out_spec = pytree.tree_flatten(output)
return tuple(flat_output)
return output
def backward(ctx, *grads):
def backward(ctx, *flat_grad_output):
assert out_spec is not None
grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
saved, args_info = unpack_saved(ctx)
# There is nothing on the ctx object for now, it is just there so
# that we can add additional things in the future.
inner_ctx = object()
if not isinstance(grads, tuple):
grads = (grads,)
grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)
# Massage the grad_inputs_dict to a form acceptable by
@@ -110,7 +137,9 @@ def construct_autograd_kernel(
generated_cls = gen_autograd_function(
forward_op._opname + '_customop', forward, backward)
return generated_cls.apply(*flat_args)
flat_output = generated_cls.apply(*flat_args)
assert out_spec is not None
return pytree.tree_unflatten(list(flat_output), out_spec)
return apply


@@ -580,7 +580,6 @@ def validate_namespace(ns: str) -> None:
f"please choose something else. "
)
def validate_schema(schema: FunctionSchema) -> None:
# Coming in the future. Requires us to have correct logic for
# the ADInplaceOrView key


@@ -15,7 +15,7 @@ from torch.testing._internal.autograd_function_db import (
)
from torch import Tensor
from torch.types import Number
from typing import Sequence, Tuple
from typing import * # noqa: F403
import torch._custom_ops as custom_ops
# Note: [custom op db]
@@ -229,7 +229,9 @@ def numpy_cat_save_for_backward(inputs, output):
@custom_ops.impl_backward('_torch_testing::numpy_cat')
def numpy_cat_backward(ctx, saved, grad_out):
dim_sizes, dim = saved
return {'xs': torch.split(grad_out, dim_sizes, dim)}
splits = list(np.cumsum(dim_sizes)[:-1])
grad_xs = torch.ops._torch_testing.numpy_split_copy(grad_out, splits, dim)
return {'xs': grad_xs}
def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -238,6 +240,60 @@ def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
r2 = make_arg(5, 3, 4, low=0.9, high=2)
yield SampleInput([r0, r1, r2], args=(0,))
@custom_ops.custom_op('_torch_testing::numpy_split_copy')
def numpy_split_copy(x: Tensor, sections: Sequence[int], dim: int) -> List[Tensor]:
raise NotImplementedError()
@custom_ops.impl('_torch_testing::numpy_split_copy')
def numpy_split_copy_impl(x, splits, dim):
x_np = to_numpy(x)
arrs = np.split(x_np, splits, axis=dim)
return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs]
@custom_ops.impl_abstract('_torch_testing::numpy_split_copy')
def numpy_split_copy_abstract(x, splits, dim):
return [xi.clone() for xi in torch.tensor_split(x, splits, dim)]
@custom_ops.impl_save_for_backward('_torch_testing::numpy_split_copy')
def numpy_split_copy_save_for_backward(inputs, output):
return inputs.dim
@custom_ops.impl_backward('_torch_testing::numpy_split_copy')
def numpy_split_copy_backward(ctx, saved, grad_out):
dim = saved
return {'x': torch.ops._torch_testing.numpy_cat(grad_out, dim=dim)}
def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
x = make_arg(2, 9, low=0.9, high=2)
yield SampleInput(x, args=([1, 3, 6], 1))
@custom_ops.custom_op('_torch_testing::numpy_split_copy_with_int')
def numpy_split_copy_with_int(x: Tensor, sections: Sequence[int], dim: int) -> Tuple[List[Tensor], int]:
raise NotImplementedError()
@custom_ops.impl('_torch_testing::numpy_split_copy_with_int')
def numpy_split_copy_with_int_impl(x, splits, dim):
x_np = to_numpy(x)
arrs = np.split(x_np, splits, axis=dim)
return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs], len(splits)
@custom_ops.impl_abstract('_torch_testing::numpy_split_copy_with_int')
def numpy_split_copy_with_int_abstract(x, splits, dim):
return [xi.clone() for xi in torch.tensor_split(x, splits, dim)], len(splits)
@custom_ops.impl_save_for_backward(
'_torch_testing::numpy_split_copy_with_int')
def numpy_split_copy_with_int_save_for_backward(inputs, output):
return inputs.dim
@custom_ops.impl_backward(
'_torch_testing::numpy_split_copy_with_int',
output_differentiability=[True, False])
def numpy_split_copy_with_int_backward(ctx, saved, grad_out, _):
dim = saved
return {'x': torch.ops._torch_testing.numpy_cat(grad_out, dim=dim)}
@custom_ops.custom_op('_torch_testing::numpy_nms')
def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor:
raise NotImplementedError()
@@ -371,6 +427,29 @@ custom_op_db = [
sample_inputs_func=sample_inputs_numpy_cat,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=True,
check_batched_grad=False,
check_batched_gradgrad=False,
supports_out=False,
),
OpInfo(
'NumpySplitCopyCustomOp',
op=torch.ops._torch_testing.numpy_split_copy,
sample_inputs_func=sample_inputs_numpy_split_copy,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=True,
check_batched_grad=False,
check_batched_gradgrad=False,
supports_out=False,
),
OpInfo(
'NumpySplitCopyWithIntCustomOp',
op=torch.ops._torch_testing.numpy_split_copy_with_int,
sample_inputs_func=sample_inputs_numpy_split_copy,
dtypes=all_types_and(torch.bool, torch.half),
gradcheck_wrapper=lambda op, *args, **kwargs: op(*args, **kwargs)[0],
supports_autograd=True,
check_batched_grad=False,
check_batched_gradgrad=False,
supports_out=False,
),
]