# pytorch/torch/testing/_internal/autograd_function_db.py
# mypy: ignore-errors
import torch
from functools import partial
from torch.testing import make_tensor
from torch.testing._internal.opinfo.core import (
    OpInfo,
    SampleInput,
)
from torch.testing._internal.common_dtype import all_types_and
import numpy as np

# Note: [autograd.Function db]
#
# This is a collection of autograd.Function test cases written as OpInfos
# so they can easily be consumed by OpInfo-based tests to check if a subsystem
# supports autograd.Function.
#
# Axes:
# - saves {output, input, intermediate, non-tensor}
# - {inputs, output} x {single tensor, tensors, arbitrary objects}
# - Uses {mark_dirty, mark_non_differentiable, once_differentiable}
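#
# A rough sketch of how an OpInfo-based test might consume this db (hypothetical
# test name; `ops` and `instantiate_device_type_tests` come from
# torch.testing._internal.common_device_type, TestCase from
# torch.testing._internal.common_utils):
#
#     class TestAutogradFunction(TestCase):
#         @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
#         def test_subsystem(self, device, dtype, op):
#             for sample in op.sample_inputs(device, dtype, requires_grad=True):
#                 result = op.op(sample.input, *sample.args, **sample.kwargs)
#                 ...  # check that the subsystem handled `result` correctly
#
#     instantiate_device_type_tests(TestAutogradFunction, globals())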

def to_numpy(tensor):
    return tensor.cpu().numpy()


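# NumpyCube computes (x ** 3, 3 * x ** 2) by round-tripping through numpy, so
# autograd cannot see through forward(); it saves both an input and an output
# and supplies hand-written backward/vmap/jvp rules.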
class NumpyCube(torch.autograd.Function):
    @staticmethod
    def forward(input):
        input_np = to_numpy(input)
        dinput = torch.tensor(3 * input_np ** 2, device=input.device)
        return torch.tensor(input_np ** 3, device=input.device), dinput

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(inputs[0], output[1])
        ctx.save_for_forward(inputs[0], output[1])

    @staticmethod
    def backward(ctx, grad_output, grad_saved):
        input, dinput = ctx.saved_tensors
        return NumpyMul.apply(grad_output, dinput) + 6 * NumpyMul.apply(grad_saved, input)

    @staticmethod
    def vmap(info, in_dims, input):
        result = NumpyCube.apply(input)
        return result, (in_dims[0], in_dims[0])

    @staticmethod
    def jvp(ctx, input_tangent):
        input, dinput = ctx.saved_tensors
        return NumpyMul.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input)


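# The *GenVmap variants below mirror their Numpy* counterparts but are written
# in terms of PyTorch ops and set generate_vmap_rule = True, so vmap support is
# derived automatically instead of coming from a hand-written vmap staticmethod.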
class CubeGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x ** 3, 3 * x ** 2

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        ctx.save_for_backward(inputs[0], outputs[1])
        ctx.save_for_forward(inputs[0], outputs[1])

    @staticmethod
    def backward(ctx, grad_output, grad_saved):
        _input, dinput = ctx.saved_tensors
        result = grad_output * dinput + 6 * dinput
        return result

    @staticmethod
    def jvp(ctx, input_tangent):
        input, dinput = ctx.saved_tensors
        return MulGenVmap.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input)


def sample_inputs_numpy_cube(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(1, low=0.8, high=2), args=())


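# NumpyCubeNotComposable saves raw numpy data on ctx (the "non-tensor" axis
# above) and wraps backward in once_differentiable; its OpInfo entry below
# disables forward AD and fwgrad_bwgrad accordingly.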
class NumpyCubeNotComposable(torch.autograd.Function):
    @staticmethod
    def forward(input):
        input_np = to_numpy(input)
        return torch.tensor(input_np ** 3, device=input.device), input_np

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, input_np = output
        ctx.input_np = input_np
        ctx.device = inputs[0].device

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output, grad_saved):
        result_np = 3 * (ctx.input_np ** 2)
        return torch.tensor(result_np, device=ctx.device)


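# NumpyMul is a broadcasting multiply routed through numpy, with a hand-written
# batching rule: vmap moves the batch dim to the end, applies the op, then moves
# it back to dim 0.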
class NumpyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(*inputs)
        ctx.save_for_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if ctx.needs_input_grad[0]:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if ctx.needs_input_grad[1]:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0

    @staticmethod
    def jvp(ctx, x_tangent, y_tangent):
        x, y = ctx.saved_tensors
        return x_tangent * y + y_tangent * x


def sample_inputs_numpy_mul(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    # Broadcasting
    yield SampleInput(make_arg(4, low=0.9, high=2), args=(make_arg(3, 4, low=0.9, high=2),))


def sample_inputs_numpy_mul_scalar(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(4, low=0.9, high=2), args=(), kwargs={"scalar": 3.14})


class MulGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, y):
        return x * y

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        ctx.save_for_backward(*inputs)
        ctx.save_for_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if ctx.needs_input_grad[0]:
            gx = MulGenVmap.apply(grad_output, y)
        gy = None
        if ctx.needs_input_grad[1]:
            gy = MulGenVmap.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def jvp(ctx, x_tangent, y_tangent):
        x, y = ctx.saved_tensors
        return x_tangent * y + y_tangent * x


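# NumpyExp_ exponentiates its input in place (through numpy), so setup_context
# must call ctx.mark_dirty on the mutated input; the output is saved for the
# backward and jvp rules.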
class NumpyExp_(torch.autograd.Function):
    @staticmethod
    def forward(x):
        x_np = to_numpy(x)
        np.exp(x_np, x_np)
        return x

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.mark_dirty(x)
        ctx.save_for_backward(output)
        ctx.save_for_forward(output)

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        return NumpyMul.apply(grad_output, output)

    @staticmethod
    def vmap(info, in_dims, x):
        NumpyExp_.apply(x)
        return x, in_dims[0]

    @staticmethod
    def jvp(ctx, x_tangent):
        # Doesn't call numpy operations because I didn't want to write NumpyMul_
        output, = ctx.saved_tensors
        x_tangent.mul_(output)
        return x_tangent


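# NumpySort returns multiple outputs (values plus two index tensors) and marks
# the integer index outputs as non-differentiable; its backward/jvp are
# expressed in terms of NumpyTake.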
class NumpySort(torch.autograd.Function):
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        return (
            torch.tensor(x, device=device),
            torch.tensor(ind, device=device),
            torch.tensor(ind_inv, device=device),
        )

    @staticmethod
    def setup_context(ctx, inputs, output):
        _x, dim = inputs
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

    @staticmethod
    def vmap(info, in_dims, x, dim):
        x_bdim, _ = in_dims
        x = x.movedim(x_bdim, 0)
        # wrap dim
        dim = dim if dim >= 0 else dim + x.dim() - 1
        return NumpySort.apply(x, dim + 1), (0, 0, 0)

    @staticmethod
    def jvp(ctx, x_tangent, _):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim), None, None


class SortGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, dim):
        ind = torch.argsort(x, dim=dim)
        ind_inv = torch.argsort(ind, axis=dim)
        result = torch.take_along_dim(x, ind, dim=dim)
        return result, ind, ind_inv

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, dim = inputs
        _, ind, ind_inv = outputs
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim), None

    @staticmethod
    def jvp(ctx, x_tangent, _):
        ind, ind_inv = ctx.saved_tensors
        return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim), None, None


def sample_inputs_numpy_sort(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(3, 5), args=(1,))


def sample_inputs_numpy_take(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    tensor = make_arg(3, 5)
    dim = 1
    _, ind, ind_inv = NumpySort.apply(tensor, 1)
    yield SampleInput(tensor, args=(ind, ind_inv, dim))


class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        _x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def vmap(info, in_dims, x, ind, ind_inv, dim):
        x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims

        # wrap dim
        logical_dim = x.dim() if x_bdim is None else x_bdim - 1
        dim = dim if dim >= 0 else dim + logical_dim

        def expand_bdim(x, x_bdim):
            if x_bdim is None:
                return x.expand(info.batch_size, *x.shape)
            return x.movedim(x_bdim, 0)

        x = expand_bdim(x, x_bdim)
        ind = expand_bdim(ind, ind_bdim)
        ind_inv = expand_bdim(ind_inv, ind_inv_bdim)

        return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0

    @staticmethod
    def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _):
        assert ind_tangent is None
        assert ind_inv_tangent is None
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim)


class TakeGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, ind, ind_inv, dim):
        return torch.take_along_dim(x, ind, dim)

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        _x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _):
        ind, ind_inv = ctx.saved_tensors
        return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim)


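# Select takes a non-tensor argument (an index) and saves only non-tensor state
# (the input shape and the index) on ctx, so backward reconstructs the gradient
# with new_zeros plus indexing.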
class Select(torch.autograd.Function):
    @staticmethod
    def forward(x, idx):
        return x[idx]

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, idx = inputs
        ctx.x_shape = x.shape
        ctx.idx = idx

    @staticmethod
    def backward(ctx, grad_output):
        result = grad_output.new_zeros(ctx.x_shape)
        result[ctx.idx] = grad_output
        return result, None

    @staticmethod
    def vmap(info, in_dims, x, idx):
        x_bdim, _ = in_dims
        x = x.movedim(x_bdim, 1)
        return Select.apply(x, idx), 0

    @staticmethod
    def jvp(ctx, x_tangent, _):
        return Select.apply(x_tangent, ctx.idx)


class SelectGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, idx):
        return x[idx]

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, idx = inputs
        ctx.x_shape = x.shape
        ctx.idx = idx

    @staticmethod
    def backward(ctx, grad_output):
        result = grad_output.new_zeros(ctx.x_shape)
        result[ctx.idx] = grad_output
        return result, None

    @staticmethod
    def jvp(ctx, x_tangent, _):
        return SelectGenVmap.apply(x_tangent, ctx.idx)


def sample_inputs_select(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(3, 5), args=(2,))


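# ScaleGradGenVmap is an identity in the forward pass but scales both the
# reverse-mode and forward-mode gradients by a class-level constant.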
class ScaleGradGenVmap(torch.autograd.Function):
    generate_vmap_rule = True
    scale = 3.14

    @staticmethod
    def forward(x):
        return x.clone()

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ScaleGradGenVmap.scale

    @staticmethod
    def jvp(ctx, x_tangent):
        return x_tangent * ScaleGradGenVmap.scale


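# ZeroGradientsGenVmap intentionally returns freshly allocated torch.zeros
# gradients (not None, not zeros_like/new_zeros), including a deliberately
# oversized gradient for the first input; see the inline comments below.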
class ZeroGradientsGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, y):
        return x.clone(), y.clone()

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        pass

    @staticmethod
    def backward(ctx, gx, gy):
        # Intentionally returning torch.zeros instead of zeros_like or new_zeros.
        # Also intentionally not None.
        return (
            # Intentionally too-large gradient
            torch.zeros(3, 4, *gx.shape, dtype=gx.dtype, device=gx.device),
            torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
        )

    @staticmethod
    def jvp(ctx, gx, gy):
        # Intentionally returning torch.zeros instead of zeros_like or new_zeros.
        # Also intentionally not None.
        return (
            torch.zeros(gx.shape, dtype=gx.dtype, device=gx.device),
            torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
        )


def sample_inputs_forward_default_args(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(3, 5))


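# ForwardHasDefaultArgs checks that default arguments to forward() are handled;
# its sample inputs omit idx so the default (2,) is actually exercised.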
class ForwardHasDefaultArgs(torch.autograd.Function):
    @staticmethod
    def forward(x, idx=(2,)):
        return x[idx]

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, idx = inputs
        ctx.x_shape = x.shape
        ctx.idx = idx

    @staticmethod
    def backward(ctx, grad_output):
        result = grad_output.new_zeros(ctx.x_shape)
        result[ctx.idx] = grad_output
        return result, None

    @staticmethod
    def vmap(info, in_dims, x, idx):
        x_bdim, _ = in_dims
        x = x.movedim(x_bdim, 1)
        return ForwardHasDefaultArgs.apply(x, idx), 0

    @staticmethod
    def jvp(ctx, x_tangent, _):
        return ForwardHasDefaultArgs.apply(x_tangent, ctx.idx)


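# The OpInfo entries below wrap each autograd.Function's .apply (or a small
# lambda around it) so generic OpInfo-based test suites can iterate over them
# like any other operator.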
autograd_function_db = [
    OpInfo(
        'NumpyCubeAutogradFunction',
        op=NumpyCube.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyExpMarkDirtyAutogradFunction',
        op=lambda x: NumpyExp_.apply(x.clone()),
        inplace_variant=NumpyExp_.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyMulAutogradFunction',
        op=NumpyMul.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_mul,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyCubeNotComposableAutogradFunction',
        op=lambda x: NumpyCubeNotComposable.apply(x)[0],
        supports_forward_ad=False,
        supports_fwgrad_bwgrad=False,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpySortAutogradFunction',
        op=NumpySort.apply,
        supports_forward_ad=False,
        supports_fwgrad_bwgrad=False,
        sample_inputs_func=sample_inputs_numpy_sort,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        gradcheck_wrapper=lambda y, ind: y,
    ),
    OpInfo(
        'NumpyTakeAutogradFunction',
        op=NumpyTake.apply,
        supports_forward_ad=False,
        supports_fwgrad_bwgrad=False,
        sample_inputs_func=sample_inputs_numpy_take,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'SelectAutogradFunction',
        op=Select.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_select,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'CubeGenVmapAutogradFunction',
        op=CubeGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'MulGenVmapAutogradFunction',
        op=MulGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_mul,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'SortGenVmapAutogradFunction',
        op=SortGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_sort,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        gradcheck_wrapper=lambda y, ind: y,
    ),
    OpInfo(
        'SelectGenVmapAutogradFunction',
        op=SelectGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_select,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'ScaleGradGenVmapAutogradFunction',
        op=ScaleGradGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'ZeroGradientsGenVmapAutogradFunction',
        op=ZeroGradientsGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_mul,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'ForwardHasDefaultArgsAutogradFunction',
        op=ForwardHasDefaultArgs.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_forward_default_args,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
]