Extend impl_backward to handle non-Tensor outputs (#106800)
Recall that the user must give us a backward function that accepts `(ctx, saved, *grads)`, with one grad per output. Previously, impl_backward only worked for functions that return one or more Tensors. The new semantics are that if the output has:
- a TensorList, the backward function provided by the user will receive a List[Tensor] of grads for that output.
- a number, the backward function provided by the user will receive None as the grad.

Also recall that impl_backward is implemented by registering an autograd.Function to the autograd dispatch key. We needed to make the following changes:
- If an output is a TensorList, autograd.Function will ignore it, so we need to tree-flatten it before returning it from the autograd.Function.
- This means that the autograd.Function receives a flat list of grads during the backward pass. We need to tree-unflatten it into the correct shape before passing it to the user-defined backward.
- We modify the logic of output_differentiability. Only Tensor/TensorList outputs can be marked as differentiable. If a TensorList is marked as non-differentiable, this is equivalent to all Tensors in the list being non-differentiable. There is no finer-grained control over this (to match derivatives.yaml).

Test Plan:
- There are new `numpy_split_copy` (returns TensorList) and `numpy_split_copy_with_int` (returns (TensorList, int)) operators in custom_op_db.
- Added tests for output_differentiability to test/test_custom_ops.py.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106800
Approved by: https://github.com/soulitzer
ghstack dependencies: #106799
Committed by: PyTorch MergeBot
Parent: 9fcce1baf1
Commit: db9a0cf689
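A minimal usage sketch of the semantics described in the commit message, adapted from the `numpy_split_copy_with_int` operator that this PR adds to custom_op_db below. The `torch.cat`-based backward is a simplification of the diff's `numpy_cat` helper, and the snippet assumes a fresh process where the `_torch_testing::numpy_split_copy_with_int` name has not already been registered, so treat it as illustrative rather than canonical:

import numpy as np
import torch
import torch._custom_ops as custom_ops
from torch import Tensor
from typing import List, Sequence, Tuple

# One output is a TensorList, the other is a plain int.
@custom_ops.custom_op('_torch_testing::numpy_split_copy_with_int')
def numpy_split_copy_with_int(x: Tensor, sections: Sequence[int], dim: int) -> Tuple[List[Tensor], int]:
    raise NotImplementedError()

@custom_ops.impl('_torch_testing::numpy_split_copy_with_int')
def numpy_split_copy_with_int_impl(x, splits, dim):
    arrs = np.split(x.detach().cpu().numpy(), splits, axis=dim)
    return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs], len(splits)

@custom_ops.impl_save_for_backward('_torch_testing::numpy_split_copy_with_int')
def numpy_split_copy_with_int_save_for_backward(inputs, output):
    return inputs.dim

# One grad argument per output: a List[Tensor] for the TensorList output, and
# None for the int output, which must be marked non-differentiable.
@custom_ops.impl_backward(
    '_torch_testing::numpy_split_copy_with_int',
    output_differentiability=[True, False])
def numpy_split_copy_with_int_backward(ctx, saved, grad_out, _):
    dim = saved
    return {'x': torch.cat(grad_out, dim=dim)}

x = torch.randn(2, 9, requires_grad=True)
pieces, n_splits = torch.ops._torch_testing.numpy_split_copy_with_int(x, [1, 3, 6], 1)
torch.cat(pieces, dim=1).sum().backward()  # gradients flow back through the TensorList output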
test/test_custom_ops.py
@@ -1247,6 +1247,56 @@ class TestCustomOp(CustomOpTestCaseBase):
         def foo_backward(ctx, saved, grad):
             return {"xs": None}
 
+    def test_backward_output_differentiability_tensorlist(self):
+        @custom_ops.custom_op(f"{self.test_ns}::foo")
+        def foo(x: Tensor) -> Tuple[List[Tensor], Tensor]:
+            raise NotImplementedError()
+
+        @custom_ops.impl(f"{self.test_ns}::foo")
+        def foo_impl(x):
+            return [x.clone(), x.clone()], x.clone()
+
+        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
+        def foo_save_for_backward(inputs, output):
+            return []
+
+        @custom_ops.impl_backward(
+            f"{TestCustomOp.test_ns}::foo", output_differentiability=[False, True]
+        )
+        def foo_backward(ctx, saved, grad_lst, grad):
+            return {"x": grad}
+
+        op = self.get_op(f"{self.test_ns}::foo")
+        x = torch.randn(3, requires_grad=True)
+        [a, b], c = op(x)
+        self.assertFalse(a.requires_grad)
+        self.assertFalse(b.requires_grad)
+        self.assertTrue(c.requires_grad)
+
+    def test_backward_output_differentiability_non_tensor(self):
+        @custom_ops.custom_op(f"{self.test_ns}::foo")
+        def foo(x: Tensor) -> Tuple[Tensor, int]:
+            raise NotImplementedError()
+
+        @custom_ops.impl(f"{self.test_ns}::foo")
+        def foo_impl(x):
+            return x.clone(), 3
+
+        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
+        def foo_save_for_backward(inputs, output):
+            return []
+
+        @custom_ops.impl_backward(
+            f"{TestCustomOp.test_ns}::foo", output_differentiability=[True, True]
+        )
+        def foo_backward(ctx, saved, grad0, grad1):
+            return {"x": grad0}
+
+        op = self.get_op(f"{self.test_ns}::foo")
+        x = torch.randn(3, requires_grad=True)
+        with self.assertRaisesRegex(RuntimeError, "is not a Tensor"):
+            op(x)
+
     @unittest.skipIf(not TEST_CUDA, "requires CUDA")
     def test_impl_separate(self):
         @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
torch/_custom_op/autograd.py
@@ -56,6 +56,37 @@ def autograd_not_implemented(custom_op):
     return kernel
 
 
+def mark_non_differentiable(ctx, output, output_differentiability):
+    # Output types are restricted to be:
+    # - Tensor
+    # - Tensor[]
+    # - int, bool, Scalar, float
+    if output_differentiability is not None:
+        if not isinstance(output, tuple):
+            tuple_output = (output,)
+        else:
+            tuple_output = output  # type: ignore[assignment]
+        assert len(output_differentiability) == len(tuple_output)
+        non_differentiable_tensors = []
+        for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
+            if isinstance(out, torch.Tensor):
+                if not differentiable:
+                    non_differentiable_tensors.append(out)
+                continue
+            if isinstance(out, list):
+                if not differentiable:
+                    non_differentiable_tensors.extend(out)
+                continue
+            if differentiable:
+                raise RuntimeError(
+                    f"With output_differentiability={output_differentiability}. "
+                    f"At idx {idx}, we received an object of type {type(out)} that "
+                    f"is not a Tensor, so it cannot have be marked as differentiable in "
+                    f"output_differentiability.")
+        if non_differentiable_tensors:
+            ctx.mark_non_differentiable(*non_differentiable_tensors)
+
+
 def construct_autograd_kernel(
         schema,
         output_differentiability,
@@ -65,6 +96,7 @@ def construct_autograd_kernel(
 
     def apply(*args):
         flat_args, spec = pytree.tree_flatten(args)
+        out_spec = None
 
         def forward(ctx, *flat_args):
             ctx.set_materialize_grads(True)
@@ -80,26 +112,21 @@ def construct_autograd_kernel(
             to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)
 
             save_pytree_for_backward(ctx, (to_save, args_info))
+            mark_non_differentiable(ctx, output, output_differentiability)
+
-            # Output must be one or more Tensors, no TensorList (yet)
-            if output_differentiability is not None:
-                if isinstance(output, tuple):
-                    assert len(output_differentiability) == len(output)
-                    for differentiable, out in zip(output_differentiability, output):
-                        if not differentiable:
-                            ctx.mark_non_differentiable(out)
-                else:
-                    assert len(output_differentiability) == 1
-                    if not output_differentiability[0]:
-                        ctx.mark_non_differentiable(output)
+            nonlocal out_spec
+            flat_output, out_spec = pytree.tree_flatten(output)
+            return tuple(flat_output)
+
-            return output
-
-        def backward(ctx, *grads):
+        def backward(ctx, *flat_grad_output):
+            assert out_spec is not None
+            grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
             saved, args_info = unpack_saved(ctx)
             # There is nothing on the ctx object for now, it is just there so
             # that we can add additional things in the future.
             inner_ctx = object()
             if not isinstance(grads, tuple):
                 grads = (grads,)
             grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)
 
             # Massage the grad_inputs_dict to a form acceptable by
@@ -110,7 +137,9 @@ def construct_autograd_kernel(
         generated_cls = gen_autograd_function(
             forward_op._opname + '_customop', forward, backward)
 
-        return generated_cls.apply(*flat_args)
+        flat_output = generated_cls.apply(*flat_args)
+        assert out_spec is not None
+        return pytree.tree_unflatten(list(flat_output), out_spec)
     return apply
 
 
torch/_custom_op/impl.py
@@ -580,7 +580,6 @@ def validate_namespace(ns: str) -> None:
             f"please choose something else. "
         )
 
-
 def validate_schema(schema: FunctionSchema) -> None:
     # Coming in the future. Requires us to have correct logic for
     # the ADInplaceOrView key
torch/testing/_internal/custom_op_db.py
@@ -15,7 +15,7 @@ from torch.testing._internal.autograd_function_db import (
 )
 from torch import Tensor
 from torch.types import Number
-from typing import Sequence, Tuple
+from typing import *  # noqa: F403
 import torch._custom_ops as custom_ops
 
 # Note: [custom op db]
@@ -229,7 +229,9 @@ def numpy_cat_save_for_backward(inputs, output):
 @custom_ops.impl_backward('_torch_testing::numpy_cat')
 def numpy_cat_backward(ctx, saved, grad_out):
     dim_sizes, dim = saved
-    return {'xs': torch.split(grad_out, dim_sizes, dim)}
+    splits = list(np.cumsum(dim_sizes)[:-1])
+    grad_xs = torch.ops._torch_testing.numpy_split_copy(grad_out, splits, dim)
+    return {'xs': grad_xs}
 
 def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
     make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -238,6 +240,60 @@ def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
     r2 = make_arg(5, 3, 4, low=0.9, high=2)
     yield SampleInput([r0, r1, r2], args=(0,))
 
+@custom_ops.custom_op('_torch_testing::numpy_split_copy')
+def numpy_split_copy(x: Tensor, sections: Sequence[int], dim: int) -> List[Tensor]:
+    raise NotImplementedError()
+
+@custom_ops.impl('_torch_testing::numpy_split_copy')
+def numpy_split_copy_impl(x, splits, dim):
+    x_np = to_numpy(x)
+    arrs = np.split(x_np, splits, axis=dim)
+    return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs]
+
+@custom_ops.impl_abstract('_torch_testing::numpy_split_copy')
+def numpy_split_copy_abstract(x, splits, dim):
+    return [xi.clone() for xi in torch.tensor_split(x, splits, dim)]
+
+@custom_ops.impl_save_for_backward('_torch_testing::numpy_split_copy')
+def numpy_split_copy_save_for_backward(inputs, output):
+    return inputs.dim
+
+@custom_ops.impl_backward('_torch_testing::numpy_split_copy')
+def numpy_split_copy_backward(ctx, saved, grad_out):
+    dim = saved
+    return {'x': torch.ops._torch_testing.numpy_cat(grad_out, dim=dim)}
+
+def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    x = make_arg(2, 9, low=0.9, high=2)
+    yield SampleInput(x, args=([1, 3, 6], 1))
+
+@custom_ops.custom_op('_torch_testing::numpy_split_copy_with_int')
+def numpy_split_copy_with_int(x: Tensor, sections: Sequence[int], dim: int) -> Tuple[List[Tensor], int]:
+    raise NotImplementedError()
+
+@custom_ops.impl('_torch_testing::numpy_split_copy_with_int')
+def numpy_split_copy_with_int_impl(x, splits, dim):
+    x_np = to_numpy(x)
+    arrs = np.split(x_np, splits, axis=dim)
+    return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs], len(splits)
+
+@custom_ops.impl_abstract('_torch_testing::numpy_split_copy_with_int')
+def numpy_split_copy_with_int_abstract(x, splits, dim):
+    return [xi.clone() for xi in torch.tensor_split(x, splits, dim)], len(splits)
+
+@custom_ops.impl_save_for_backward(
+    '_torch_testing::numpy_split_copy_with_int')
+def numpy_split_copy_with_int_save_for_backward(inputs, output):
+    return inputs.dim
+
+@custom_ops.impl_backward(
+    '_torch_testing::numpy_split_copy_with_int',
+    output_differentiability=[True, False])
+def numpy_split_copy_with_int_backward(ctx, saved, grad_out, _):
+    dim = saved
+    return {'x': torch.ops._torch_testing.numpy_cat(grad_out, dim=dim)}
+
 @custom_ops.custom_op('_torch_testing::numpy_nms')
 def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor:
     raise NotImplementedError()
@@ -371,6 +427,29 @@ custom_op_db = [
         sample_inputs_func=sample_inputs_numpy_cat,
         dtypes=all_types_and(torch.bool, torch.half),
         supports_autograd=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
         supports_out=False,
     ),
+    OpInfo(
+        'NumpySplitCopyCustomOp',
+        op=torch.ops._torch_testing.numpy_split_copy,
+        sample_inputs_func=sample_inputs_numpy_split_copy,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_autograd=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpySplitCopyWithIntCustomOp',
+        op=torch.ops._torch_testing.numpy_split_copy_with_int,
+        sample_inputs_func=sample_inputs_numpy_split_copy,
+        dtypes=all_types_and(torch.bool, torch.half),
+        gradcheck_wrapper=lambda op, *args, **kwargs: op(*args, **kwargs)[0],
+        supports_autograd=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_out=False,
+    ),
 ]