pytorch/torch/testing/_internal/custom_op_db.py
Richard Zou 723f111545 [custom_op] explicit autograd API (#101824)
This PR adds an explicit API for registering a backward formula for a
CustomOp. In the end state, we will likely have both this explicit API and a
magic API (which is sugar on top of the explicit API), since different
groups of users prefer different ones.

Concretely, to define a backward formula for a CustomOp:
- a user must provide us a "save for backward" function that accepts
(inputs, output) and returns exactly what they want saved for backward
- a user must provide us a "backward" function that accepts
(ctx, saved, *grads) and returns the grad_inputs as a dict mapping
each input name (a str) to its gradient.
Please see the changes in custom_op_db.py for examples of the API.
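For example, here is a condensed sketch based on the numpy_mul entry in the
file below (the op definition itself is elided to `...` as in the file, and
the save-for-backward function is simplified relative to the full example):

    @custom_op('_torch_testing::numpy_mul')
    def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
        ...

    @numpy_mul.impl_save_for_backward()
    def numpy_mul_save_for_backward(inputs, output):
        # `inputs` exposes the op's arguments by name; save exactly what backward needs.
        return {'x': inputs.x, 'y': inputs.y}

    @numpy_mul.impl_backward()
    def numpy_mul_backward(ctx, saved, grad_out):
        # grad_inputs are returned as a dict mapping input name -> gradient.
        return {'x': grad_out * saved['y'], 'y': grad_out * saved['x']}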

There are a number of pieces to this PR and I'm happy to split it if it
helps. They are:
- The actual APIs for specifying the two functions
(impl_save_for_backward, impl_backward)
- The autograd kernel: we take the functions the user gives us and
construct an autograd.Function object that we then register to
the Autograd dispatch key (a simplified sketch of this idea appears after this list)
- Indirection for the autograd kernel. We add a layer of indirection so
that one can swap out the autograd kernel. This is necessary because by
default, we register an "autograd not implemented" kernel as the
Autograd implementation but then swap it for the actual kernel when the
user provides it.
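
To make the autograd-kernel piece concrete, here is a rough, simplified sketch
of building an autograd.Function from the user's two callables. This is
illustrative only, not the actual implementation: the helper name, the
SimpleNamespace trick for attribute-style `inputs`, and the direct call to the
forward impl (rather than a redispatch below the Autograd key) are all
simplifications.

    from types import SimpleNamespace
    import torch

    def construct_autograd_kernel(forward_impl, save_for_backward_fn,
                                  backward_fn, input_names):
        class Generated(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                output = forward_impl(*args)
                # Hand (inputs, output) to the user's save-for-backward function.
                inputs = SimpleNamespace(**dict(zip(input_names, args)))
                ctx.saved = save_for_backward_fn(inputs, output)
                return output

            @staticmethod
            def backward(ctx, *grads):
                # The user's backward returns a dict mapping input name -> gradient.
                grad_inputs = backward_fn(ctx, ctx.saved, *grads)
                return tuple(grad_inputs.get(name) for name in input_names)

        return Generated.apply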

Test Plan:
- We apply this API to give backward formulas to the operators in
custom_op_db. We then hook up custom_op_db to the Autograd OpInfo tests
(an illustrative sketch of such a hookup follows this list).
- Various tests in test_python_dispatch.py to check error cases.
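
The class and test names below are made up and this is not the actual test
code; it is only a sketch of how custom_op_db plugs into the standard
@ops / instantiate_device_type_tests machinery:

    import torch
    from torch.testing._internal.common_device_type import (
        instantiate_device_type_tests, ops)
    from torch.testing._internal.common_utils import TestCase, run_tests
    from torch.testing._internal.custom_op_db import custom_op_db

    class TestCustomOpAutogradSketch(TestCase):
        @ops([op for op in custom_op_db if op.supports_autograd],
             allowed_dtypes=(torch.double,))
        def test_backward_runs(self, device, dtype, op):
            for sample in op.sample_inputs(device, dtype, requires_grad=True):
                output = op.op(sample.input, *sample.args, **sample.kwargs)
                outputs = output if isinstance(output, tuple) else (output,)
                # Only floating-point outputs participate in autograd.
                loss = sum(o.sum() for o in outputs if o.is_floating_point())
                loss.backward()

    instantiate_device_type_tests(TestCustomOpAutogradSketch, globals())

    if __name__ == '__main__':
        run_tests()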
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101824
Approved by: https://github.com/ezyang
2023-05-23 18:31:29 +00:00


import torch
import functools
from torch.testing import make_tensor
from torch.testing._internal.opinfo.core import (
OpInfo,
SampleInput,
)
from torch.testing._internal.common_dtype import all_types_and
import numpy as np
from torch._custom_op.impl import custom_op
from torch.testing._internal.autograd_function_db import (
sample_inputs_numpy_cube,
sample_inputs_numpy_mul,
sample_inputs_numpy_sort,
sample_inputs_numpy_take,
)
from torch import Tensor
from torch.types import Number
from typing import Sequence, Tuple
# Note: [custom op db]
#
# This is a collection of custom operator test cases written as OpInfos
# so they can easily be consumed by OpInfo-based tests to check if subsystems
# support them correctly.
def to_numpy(tensor):
return tensor.cpu().numpy()
@custom_op('_torch_testing::numpy_cube')
def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
...
@numpy_cube.impl('cpu')
@numpy_cube.impl('cuda')
def numpy_cube_impl(x):
x_np = to_numpy(x)
dx = torch.tensor(3 * x_np ** 2, device=x.device)
return torch.tensor(x_np ** 3, device=x.device), dx
@numpy_cube.impl_abstract()
def numpy_cube_abstract(x):
return x.clone(), x.clone()
@numpy_cube.impl_save_for_backward()
def numpy_cube_save_for_backward(inputs, output):
return (inputs.x, output[1])
@numpy_cube.impl_backward()
def numpy_cube_backward(ctx, saved, grad_out, grad_dx):
x, dx = saved
grad_x = numpy_mul(grad_out, dx) + 6 * numpy_mul(grad_dx, x)
return {'x': grad_x}
@custom_op('_torch_testing::numpy_mul')
def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
...
@numpy_mul.impl('cpu')
@numpy_mul.impl('cuda')
def numpy_mul_impl(x, y):
return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
@numpy_mul.impl_abstract()
def numpy_mul_abstract(x, y):
assert x.device == y.device
return (x * y).contiguous()
@numpy_mul.impl_save_for_backward()
def numpy_mul_save_for_backward(inputs, output):
saved = {}
saved['x_requires_grad'] = inputs.x.requires_grad
saved['y_requires_grad'] = inputs.y.requires_grad
# Optimization: only save what is necessary
saved['y'] = inputs.y if inputs.x.requires_grad else None
saved['x'] = inputs.x if inputs.y.requires_grad else None
return saved
@numpy_mul.impl_backward()
def numpy_mul_backward(ctx, saved, grad_out):
grad_x = grad_out * saved['y'] if saved['x_requires_grad'] else None
    grad_y = grad_out * saved['x'] if saved['y_requires_grad'] else None
return {'y': grad_y, 'x': grad_x}
@custom_op('_torch_testing::numpy_sort')
def numpy_sort(x: Tensor, dim: int) -> Tuple[Tensor, Tensor, Tensor]:
...
@numpy_sort.impl('cpu')
@numpy_sort.impl('cuda')
def numpy_sort_impl(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(result, device=device),
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@numpy_sort.impl_abstract()
def numpy_sort_abstract(x, dim):
return torch.empty_like(x), torch.empty_like(x, dtype=torch.long), torch.empty_like(x, dtype=torch.long)
@numpy_sort.impl_save_for_backward()
def numpy_sort_save_for_backward(inputs, output):
out, ind, ind_inv = output
return [inputs.dim, ind, ind_inv]
@numpy_sort.impl_backward(output_differentiability=[True, False, False])
def numpy_sort_backward(ctx, saved, grad_out, grad_ind, grad_ind_inv):
dim, ind, ind_inv = saved
return {'x': numpy_take(grad_out, ind_inv, ind, dim)}
@custom_op('_torch_testing::numpy_take')
def numpy_take(x: Tensor, ind: Tensor, ind_inv: Tensor, dim: int) -> Tensor:
...
@numpy_take.impl('cpu')
@numpy_take.impl('cuda')
def numpy_take_impl(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@numpy_take.impl_abstract()
def numpy_take_abstract(x, ind, ind_inv, dim):
assert x.device == ind.device
assert x.device == ind_inv.device
assert ind.dtype == torch.long
assert ind_inv.dtype == torch.long
return torch.empty_like(x)
@numpy_take.impl_save_for_backward()
def numpy_take_save_for_backward(inputs, output):
return {
'dim': inputs.dim,
'ind': inputs.ind,
'ind_inv': inputs.ind_inv,
}
@numpy_take.impl_backward()
def numpy_take_backward(ctx, saved, grad_out):
return {
'x': numpy_take(grad_out, saved['ind_inv'], saved['ind'], saved['dim']),
'ind': None,
'ind_inv': None,
}
@custom_op('_torch_testing::numpy_nonzero')
def numpy_nonzero(x: Tensor) -> Tensor:
...
@numpy_nonzero.impl(['cpu', 'cuda'])
def numpy_nonzero_impl(x):
x_np = to_numpy(x)
res = np.stack(np.nonzero(x_np), axis=1)
if res.shape[0] <= 1:
raise RuntimeError("not supported")
return torch.tensor(res, device=x.device)
@numpy_nonzero.impl_abstract()
def numpy_nonzero_abstract(x):
ctx = torch._custom_op.impl.get_ctx()
i0 = ctx.create_unbacked_symint()
    shape = [i0, x.dim()]
result = x.new_empty(shape, dtype=torch.long)
return result
def sample_inputs_numpy_nonzero(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
shape = 10
result = make_arg(shape, low=0.9, high=2)
mask = make_tensor(shape, low=0, high=2, device=device, dtype=torch.long)
with torch.no_grad():
result *= mask
yield SampleInput(result, args=())
@custom_op('_torch_testing::numpy_view_copy')
def numpy_view_copy(x: Tensor, shape: Sequence[int]) -> Tensor:
...
@numpy_view_copy.impl(['cpu', 'cuda'])
def numpy_view_copy_impl(x, shape) -> Tensor:
return torch.tensor(np.copy(to_numpy(x).reshape(shape)), device=x.device)
@numpy_view_copy.impl_abstract()
def numpy_view_copy_abstract(x, shape) -> Tensor:
return x.clone().view(shape).clone()
@numpy_view_copy.impl_save_for_backward()
def numpy_view_copy_save_for_backward(inputs, output) -> Tensor:
return inputs.x.shape
@numpy_view_copy.impl_backward()
def numpy_view_copy_backward(ctx, x_shape, grad_out) -> Tensor:
return {'x': numpy_view_copy(grad_out, x_shape)}
def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
result = make_arg(2, 3, 4, low=0.9, high=2)
yield SampleInput(result, args=([2, 12],))
@custom_op('_torch_testing::numpy_cat')
def numpy_cat(xs: Sequence[Tensor], dim: int) -> Tensor:
...
@numpy_cat.impl(['cpu', 'cuda'])
def numpy_cat_impl(xs, dim):
assert len(xs) > 0
assert all(x.device == xs[0].device for x in xs)
assert all(x.dtype == xs[0].dtype for x in xs)
np_xs = [to_numpy(x) for x in xs]
np_out = np.concatenate(np_xs, axis=dim)
return torch.tensor(np_out, device=xs[0].device)
@numpy_cat.impl_abstract()
def numpy_cat_abstract(xs, dim):
assert len(xs) > 0
assert all(x.device == xs[0].device for x in xs)
assert all(x.dtype == xs[0].dtype for x in xs)
return torch.cat(xs, dim=dim)
@numpy_cat.impl_save_for_backward()
def numpy_cat_save_for_backward(inputs, output):
dim_sizes = [x.shape[inputs.dim] for x in inputs.xs]
return dim_sizes, inputs.dim
@numpy_cat.impl_backward()
def numpy_cat_backward(ctx, saved, grad_out):
dim_sizes, dim = saved
return {'xs': torch.split(grad_out, dim_sizes, dim)}
def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
r0 = make_arg(2, 3, 4, low=0.9, high=2)
r1 = make_arg(4, 3, 4, low=0.9, high=2)
r2 = make_arg(5, 3, 4, low=0.9, high=2)
yield SampleInput([r0, r1, r2], args=(0,))
@custom_op('_torch_testing::numpy_nms')
def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor:
...
@numpy_nms.impl(['cpu', 'cuda'])
def numpy_nms_impl(boxes, scores, iou_threshold):
# Adapted from Ross Girshick's fast-rcnn implementation at
# https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/nms.py
assert boxes.device == scores.device
device = boxes.device
boxes = to_numpy(boxes)
scores = to_numpy(scores)
N = boxes.shape[0]
assert boxes.shape == (N, 4)
assert scores.shape == (N,)
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= iou_threshold)[0]
order = order[inds + 1]
    result = torch.tensor(np.stack(keep), device=device)
# Needed for data-dependent condition :(
assert result.size(0) >= 2
return result
@numpy_nms.impl_abstract()
def numpy_nms_abstract(boxes, scores, iou_threshold):
assert boxes.device == scores.device
N = boxes.shape[0]
assert boxes.shape == (N, 4)
assert scores.shape == (N,)
ctx = torch._custom_op.impl.get_ctx()
i0 = ctx.create_unbacked_symint()
result = boxes.new_empty([i0, 4])
return result
def sample_inputs_numpy_nms(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype)
N = 64
xs = make_arg([N], low=0, high=28)
dx = make_arg([N], low=0, high=4)
ys = make_arg([N], low=0, high=28)
dy = make_arg([N], low=0, high=4)
boxes = torch.stack([xs, ys, xs + dx, ys + dy], dim=1).requires_grad_(requires_grad)
scores = make_arg([N], low=0, high=1, requires_grad=requires_grad)
iou_threshold = make_arg([], low=0, high=1).item()
yield SampleInput(boxes, args=(scores, iou_threshold))
# CustomOp isn't deepcopy-able, so we wrap in a function that is.
def wrap_for_opinfo(op):
def inner(*args, **kwargs):
return op(*args, **kwargs)
return inner
custom_op_db = [
OpInfo(
'NumpyCubeCustomOp',
op=wrap_for_opinfo(numpy_cube),
sample_inputs_func=sample_inputs_numpy_cube,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyMulCustomOp',
op=wrap_for_opinfo(numpy_mul),
sample_inputs_func=sample_inputs_numpy_mul,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpySortCustomOp',
op=wrap_for_opinfo(numpy_sort),
sample_inputs_func=sample_inputs_numpy_sort,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyTakeCustomOp',
op=wrap_for_opinfo(numpy_take),
sample_inputs_func=sample_inputs_numpy_take,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyNonzeroCustomOp',
op=wrap_for_opinfo(numpy_nonzero),
sample_inputs_func=sample_inputs_numpy_nonzero,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=False,
supports_out=False,
),
OpInfo(
'NumpyNMSCustomOp',
op=wrap_for_opinfo(numpy_nms),
sample_inputs_func=sample_inputs_numpy_nms,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=False,
supports_out=False,
),
OpInfo(
'NumpyViewCopyCustomOp',
op=wrap_for_opinfo(numpy_view_copy),
sample_inputs_func=sample_inputs_numpy_view_copy,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=True,
supports_out=False,
),
OpInfo(
'NumpyCatCustomOp',
op=wrap_for_opinfo(numpy_cat),
sample_inputs_func=sample_inputs_numpy_cat,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=True,
supports_out=False,
),
]