Add some forward AD formulas (#69384)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69384

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33020602

Pulled By: soulitzer

fbshipit-source-id: a92dd243f2b5b21fe277b0bb17bcd61dfe5a0d67
This commit is contained in:
soulitzer
2021-12-12 00:07:50 -08:00
committed by Facebook GitHub Bot
parent baf92f9d5a
commit 0dcbd73eee
3 changed files with 78 additions and 6 deletions

View File

@ -516,6 +516,9 @@ class TestViewOps(TestCase):
g_expected = torch.stack([gi if j == i else torch.zeros_like(gi)
for j in range(3)], dim=0)
self.assertEqual(g, g_expected)
# Check with gradcheck
stacked = torch.randn(3, 10, 10, dtype=torch.double, requires_grad=True)
gradcheck(lambda x: x.unbind(), (stacked,), check_forward_ad=True)
def test_expand_view(self, device) -> None:
t = torch.ones((5, 1), device=device)

View File

@ -299,6 +299,11 @@
- name: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)
result: auto_linear
- name: as_strided_(Tensor(a!) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a!)
self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)
result: auto_linear
- name: asin(Tensor self) -> Tensor
self: grad * (-self * self + 1).rsqrt().conj()
@ -553,15 +558,19 @@
- name: erf(Tensor self) -> Tensor
self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad
result: auto_element_wise
- name: erfc(Tensor self) -> Tensor
self: -2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad
result: auto_element_wise
- name: special_erfcx(Tensor self) -> Tensor
self: (2.0 * self * result - 2.0 / sqrt(M_PI)) * grad
result: auto_element_wise
- name: erfinv(Tensor self) -> Tensor
self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad
result: auto_element_wise
- name: exp(Tensor self) -> Tensor
self: grad * result.conj()
@ -673,7 +682,7 @@
self: hardsigmoid_backward(grad, self)
- name: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
self: not_implemented("histc")
output_differentiability: [False]
- name: hardswish(Tensor self) -> Tensor
self: hardswish_backward(grad, self)
@ -686,15 +695,19 @@
- name: i0(Tensor self) -> Tensor
self: grad * at::special_i1(self)
result: auto_element_wise
- name: special_i0e(Tensor self) -> Tensor
self: grad * (at::special_i1e(self) - self.sgn() * result)
result: auto_element_wise
- name: special_i1(Tensor self) -> Tensor
self: i1_backward(grad, self, result)
result: auto_element_wise
- name: special_i1e(Tensor self) -> Tensor
self: i1e_backward(grad, self, result)
result: auto_element_wise
- name: igamma(Tensor self, Tensor other) -> Tensor
self: 'not_implemented("igamma: input")'
@ -957,9 +970,11 @@
- name: median(Tensor self) -> Tensor
self: evenly_distribute_backward(grad, self, result)
result: evenly_read_jvp(self_t, self_p, result)
- name: nanmedian(Tensor self) -> Tensor
self: evenly_distribute_backward(grad, self, result)
result: evenly_read_jvp(self_t, self_p, result)
# This is in theory incorrect in the following case:
# sorted list: [..., a, b, b, ..., b, b, c, ...] with median = b and the value
@ -977,9 +992,11 @@
# subgradient on one side.
- name: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
- name: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
- name: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
@ -1166,6 +1183,7 @@
- name: rad2deg(Tensor self) -> Tensor
self: rad2deg_backward(grad)
result: auto_element_wise
- name: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)
self: zeros_like(grad)
@ -1202,6 +1220,7 @@
- name: special_ndtri(Tensor self) -> Tensor
self: grad * std::sqrt(2 * M_PI) * (result.square() / 2).exp()
result: auto_element_wise
# DO NOT define a backward for reshape!
# reshape is special in that it sometimes returns a view, and sometimes not.
@ -1226,10 +1245,12 @@
self: grad.scatter(dim, index, 0)
index: non_differentiable
src: grad.gather(dim, index)
result: self_t.scatter(dim, index, src_t)
- name: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
self: grad.scatter(dim, index, 0)
index: non_differentiable
result: self_t.scatter(dim, index, 0)
- name: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
self: grad
@ -1314,10 +1335,12 @@
- name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true)
output_differentiability: [True, False]
values: gather_with_keepdimed_indices(self_t, dim, indices, true)
- name: sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true)
output_differentiability: [True, False]
values: gather_with_keepdimed_indices(self_t, dim, indices, true)
- name: split.Tensor(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]
self: split_backward(grads, split_size, dim, self.sizes(), self.options())
@ -1447,6 +1470,7 @@
- name: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true)
output_differentiability: [True, False]
values: gather(self_t, dim, indices)
- name: trace(Tensor self) -> Tensor
self: trace_backward(grad, self.sizes())
@ -1685,27 +1709,34 @@
- name: silu(Tensor self) -> Tensor
self: "GradMode::is_enabled() ? infinitely_differentiable_silu_backward(grad, self) : silu_backward(grad, self)"
result: auto_element_wise
- name: mish(Tensor self) -> Tensor
self: "GradMode::is_enabled() ? infinitely_differentiable_mish_backward(grad, self) : mish_backward(grad, self)"
result: auto_element_wise
- name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ false, self)
result: auto_element_wise
- name: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!)
self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ true, result)
- name: celu(Tensor self, Scalar alpha=1.0) -> Tensor
self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ false, self)
result: auto_element_wise
- name: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)
self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result)
- name: gelu(Tensor self) -> Tensor
self: "GradMode::is_enabled() ? infinitely_differentiable_gelu_backward(grad, self) : gelu_backward(grad, self)"
result: auto_element_wise
- name: glu(Tensor self, int dim=-1) -> Tensor
self: glu_backward(grad, self, dim)
# RuntimeError: output with shape [1] doesn't match the broadcast shape [2]
# result: auto_element_wise
- name: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor
self: hardshrink_backward(grad, self, lambd)
@ -2342,6 +2373,7 @@
- name: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]
self: unbind_backward(grads, dim)
result: auto_linear
- name: stack(Tensor[] tensors, int dim=0) -> Tensor
tensors: "grad.defined() ? unbind(grad, dim) : std::vector<Tensor>(tensors.size())"

View File

@ -9618,6 +9618,7 @@ op_db: List[OpInfo] = [
dtypes=all_types_and(torch.bool, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
safe_casts_outputs=True,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_i0_i1),
UnaryUfuncInfo('special.i0e',
aten_name='special_i0e',
@ -9630,6 +9631,7 @@ op_db: List[OpInfo] = [
dtypes=all_types_and(torch.bool, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_i0_i1,
supports_forward_ad=True,
safe_casts_outputs=True),
UnaryUfuncInfo('special.i1',
aten_name='special_i1',
@ -9649,13 +9651,15 @@ op_db: List[OpInfo] = [
"TestUnaryUfuncs",
"test_out_arg_all_dtypes",
device_type='cuda'),
)),
),
supports_forward_ad=True),
UnaryUfuncInfo('special.i1e',
aten_name='special_i1e',
ref=scipy.special.i1e if TEST_SCIPY else _NOTHING,
dtypes=all_types_and(torch.bool),
dtypesIfCUDA=all_types_and(torch.bool),
sample_inputs_func=sample_inputs_i0_i1,
supports_forward_ad=True,
safe_casts_outputs=True),
UnaryUfuncInfo('special.ndtr',
aten_name='special_ndtr',
@ -9664,6 +9668,7 @@ op_db: List[OpInfo] = [
ref=scipy.special.ndtr if TEST_SCIPY else _NOTHING,
dtypes=all_types_and(torch.bool, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.float16),
supports_forward_ad=True,
safe_casts_outputs=True),
BinaryUfuncInfo('floor_divide',
dtypes=all_types_and(torch.half, torch.bfloat16),
@ -10281,12 +10286,14 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=all_types_and(torch.float16),
# TODO: some signatures of median do support out
supports_out=False,
supports_forward_ad=True,
sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
OpInfo('nanmedian',
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.float16),
# TODO: some signatures of nanmedian do support out
supports_out=False,
supports_forward_ad=True,
sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
OpInfo('var_mean',
dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
@ -10381,6 +10388,11 @@ op_db: List[OpInfo] = [
OpInfo('quantile',
dtypes=floating_types(),
sample_inputs_func=sample_inputs_reduction_quantile,
supports_forward_ad=True,
# See https://github.com/pytorch/pytorch/issues/66357
# Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which
# does not have a batching rule in core
check_batched_forward_grad=False,
skips=(
# Pre-existing condition; Needs to be fixed
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_composite_compliance'),
@ -10388,6 +10400,11 @@ op_db: List[OpInfo] = [
OpInfo('nanquantile',
dtypes=floating_types(),
sample_inputs_func=sample_inputs_reduction_quantile,
supports_forward_ad=True,
# See https://github.com/pytorch/pytorch/issues/66357
# Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which
# does not have a batching rule in core
check_batched_forward_grad=False,
skips=(
# Pre-existing condition; Needs to be fixed
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_composite_compliance'),
@ -10674,6 +10691,9 @@ op_db: List[OpInfo] = [
torch.as_strided(x, size, stride, storage_offset=storage_offset),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
# vmap does not support inplace views
check_inplace_batched_forward_grad=False,
sample_inputs_func=sample_inputs_as_strided,
skips=(
# FIXME: AssertionError: False is not true : Tensors failed to compare as equal!
@ -11253,7 +11273,7 @@ op_db: List[OpInfo] = [
np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x) - 1)),
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=False,
supports_forward_ad=True,
supports_autograd=True,
assert_autodiffed=False,
supports_gradgrad=True,
@ -11263,6 +11283,8 @@ op_db: List[OpInfo] = [
inplace_variant=lambda x, alpha=1.0:
torch.nn.functional.elu(x, alpha, inplace=True),
decorators=[
# Not implemented yet
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_inplace_forward_mode_AD'),
DecorateInfo(
toleranceOverride({
torch.float16: tol(atol=1e-03, rtol=1.2e-03),
@ -11299,7 +11321,7 @@ op_db: List[OpInfo] = [
np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x / alpha) - 1)),
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=False,
supports_forward_ad=True,
supports_autograd=True,
assert_autodiffed=False,
supports_gradgrad=True,
@ -11309,6 +11331,8 @@ op_db: List[OpInfo] = [
inplace_variant=lambda x, alpha=1.0:
torch.nn.functional.celu(x, alpha, inplace=True),
decorators=[
# Not implemented yet
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_inplace_forward_mode_AD'),
DecorateInfo(
toleranceOverride({
torch.float16: tol(atol=1e-03, rtol=1.2e-03),
@ -11357,13 +11381,15 @@ op_db: List[OpInfo] = [
),
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=False,
supports_forward_ad=True, # depends on 'elu'
supports_autograd=True,
assert_autodiffed=False,
supports_gradgrad=True,
supports_out=False,
inplace_variant=lambda x: torch.nn.functional.selu(x, inplace=True),
decorators=[
# Not implemented yet (depends on 'elu_')
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_inplace_forward_mode_AD'),
DecorateInfo(
toleranceOverride({
torch.float16: tol(atol=1e-2, rtol=1.8e-2),
@ -11433,7 +11459,7 @@ op_db: List[OpInfo] = [
ref=lambda x: x * np.tanh(reference_softplus(x)),
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=False,
supports_forward_ad=True,
supports_autograd=True,
assert_autodiffed=False,
supports_gradgrad=True,
@ -11506,6 +11532,7 @@ op_db: List[OpInfo] = [
OpInfo('topk',
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
supports_forward_ad=True,
sample_inputs_func=sample_inputs_topk),
# Multiple variants for batch_norm to test with and without cuDNN disabled
# See https://github.com/pytorch/pytorch/pull/63218#discussion_r688549391 for more details
@ -11627,6 +11654,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
supports_gradgrad=True,
supports_out=False,
supports_forward_ad=True,
autodiff_nonfusible_nodes=["aten::gelu"]),
OpInfo('nn.functional.relu6',
aten_name="relu6",
@ -11759,6 +11787,7 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
dtypes=[torch.bfloat16]),
),
supports_forward_ad=True,
safe_casts_outputs=True),
UnaryUfuncInfo('real',
ref=np.real,
@ -12890,6 +12919,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
dtypesIfROCM=all_types_and(torch.float16),
sample_inputs_func=sample_inputs_sort,
supports_forward_ad=True,
skips=(
# sort does not correctly warn when resizing out= inputs
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
@ -12933,6 +12963,7 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_take),
OpInfo('scatter',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
supports_forward_ad=True,
sample_inputs_func=sample_inputs_scatter,),
OpInfo('bfloat16',
op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs),
@ -13435,6 +13466,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
dtypesIfROCM=all_types_and(torch.float16),
check_batched_gradgrad=False,
supports_forward_ad=True,
skips=(
# msort does not correctly warn when resizing out= inputs.
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
@ -13794,6 +13826,7 @@ op_db: List[OpInfo] = [
domain=(0, 1),
aten_name='special_ndtri',
dtypes=all_types_and(torch.bool),
supports_forward_ad=True,
safe_casts_outputs=True),
UnaryUfuncInfo('erf',
ref=scipy.special.erf if TEST_SCIPY else _NOTHING,
@ -13811,6 +13844,7 @@ op_db: List[OpInfo] = [
assert_jit_shape_analysis=True,
supports_sparse=True,
supports_sparse_csr=True,
supports_forward_ad=True,
safe_casts_outputs=True),
UnaryUfuncInfo('erfc',
ref=scipy.special.erfc if TEST_SCIPY else _NOTHING,
@ -13820,6 +13854,7 @@ op_db: List[OpInfo] = [
dtypes=all_types_and(torch.bool, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
supports_forward_ad=True,
safe_casts_outputs=True),
UnaryUfuncInfo('erfinv',
ref=scipy.special.erfinv if TEST_SCIPY else _NOTHING,
@ -13831,6 +13866,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=all_types_and(torch.bool, torch.half),
safe_casts_outputs=True,
supports_sparse_csr=True,
supports_forward_ad=True,
domain=(-1, 1),
skips=(
# Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611
@ -13991,6 +14027,7 @@ op_db: List[OpInfo] = [
aten_name='special_erfcx',
decorators=(toleranceOverride({torch.float32: tol(atol=0, rtol=4e-6), }),),
dtypes=all_types_and(torch.bool),
supports_forward_ad=True,
safe_casts_outputs=True),
OpInfo(
"nn.functional.dropout",