diff --git a/test/test_view_ops.py b/test/test_view_ops.py
index 5ad82922d70c..fa76202eaef0 100644
--- a/test/test_view_ops.py
+++ b/test/test_view_ops.py
@@ -516,6 +516,9 @@ class TestViewOps(TestCase):
             g_expected = torch.stack([gi if j == i else torch.zeros_like(gi)
                                       for j in range(3)], dim=0)
             self.assertEqual(g, g_expected)
+        # Check with gradcheck
+        stacked = torch.randn(3, 10, 10, dtype=torch.double, requires_grad=True)
+        gradcheck(lambda x: x.unbind(), (stacked,), check_forward_ad=True)
 
     def test_expand_view(self, device) -> None:
         t = torch.ones((5, 1), device=device)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 16521d722639..74e6e1ebbae6 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -299,6 +299,11 @@
 
 - name: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
   self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)
+  result: auto_linear
+
+- name: as_strided_(Tensor(a!) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a!)
+  self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)
+  result: auto_linear
 
 - name: asin(Tensor self) -> Tensor
   self: grad * (-self * self + 1).rsqrt().conj()
@@ -553,15 +558,19 @@
 
 - name: erf(Tensor self) -> Tensor
   self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad
+  result: auto_element_wise
 
 - name: erfc(Tensor self) -> Tensor
   self: -2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad
+  result: auto_element_wise
 
 - name: special_erfcx(Tensor self) -> Tensor
   self: (2.0 * self * result - 2.0 / sqrt(M_PI)) * grad
+  result: auto_element_wise
 
 - name: erfinv(Tensor self) -> Tensor
   self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad
+  result: auto_element_wise
 
 - name: exp(Tensor self) -> Tensor
   self: grad * result.conj()
@@ -673,7 +682,7 @@
   self: hardsigmoid_backward(grad, self)
 
 - name: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
-  self: not_implemented("histc")
+  output_differentiability: [False]
 
 - name: hardswish(Tensor self) -> Tensor
   self: hardswish_backward(grad, self)
@@ -686,15 +695,19 @@
 
 - name: i0(Tensor self) -> Tensor
   self: grad * at::special_i1(self)
+  result: auto_element_wise
 
 - name: special_i0e(Tensor self) -> Tensor
   self: grad * (at::special_i1e(self) - self.sgn() * result)
+  result: auto_element_wise
 
 - name: special_i1(Tensor self) -> Tensor
   self: i1_backward(grad, self, result)
+  result: auto_element_wise
 
 - name: special_i1e(Tensor self) -> Tensor
   self: i1e_backward(grad, self, result)
+  result: auto_element_wise
 
 - name: igamma(Tensor self, Tensor other) -> Tensor
   self: 'not_implemented("igamma: input")'
@@ -957,9 +970,11 @@
 
 - name: median(Tensor self) -> Tensor
   self: evenly_distribute_backward(grad, self, result)
+  result: evenly_read_jvp(self_t, self_p, result)
 
 - name: nanmedian(Tensor self) -> Tensor
   self: evenly_distribute_backward(grad, self, result)
+  result: evenly_read_jvp(self_t, self_p, result)
 
 # This is in theory incorrect in the following case:
 # sorted list: [..., a, b, b, ..., b, b, c, ...] with median = b and the value
@@ -977,9 +992,11 @@
 # subgradient on one side.
 - name: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
   self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
+  values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
 
 - name: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
   self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
+  values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
 
 - name: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
   self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
@@ -1166,6 +1183,7 @@
 
 - name: rad2deg(Tensor self) -> Tensor
   self: rad2deg_backward(grad)
+  result: auto_element_wise
 
 - name: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)
   self: zeros_like(grad)
@@ -1202,6 +1220,7 @@
 
 - name: special_ndtri(Tensor self) -> Tensor
   self: grad * std::sqrt(2 * M_PI) * (result.square() / 2).exp()
+  result: auto_element_wise
 
 # DO NOT define a backward for reshape!
 # reshape is special in that it sometimes returns a view, and sometimes not.
@@ -1226,10 +1245,12 @@
   self: grad.scatter(dim, index, 0)
   index: non_differentiable
   src: grad.gather(dim, index)
+  result: self_t.scatter(dim, index, src_t)
 
 - name: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
   self: grad.scatter(dim, index, 0)
   index: non_differentiable
+  result: self_t.scatter(dim, index, 0)
 
 - name: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
   self: grad
@@ -1314,10 +1335,12 @@
 - name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
   self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true)
   output_differentiability: [True, False]
+  values: gather_with_keepdimed_indices(self_t, dim, indices, true)
 
 - name: sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
   self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true)
   output_differentiability: [True, False]
+  values: gather_with_keepdimed_indices(self_t, dim, indices, true)
 
 - name: split.Tensor(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]
   self: split_backward(grads, split_size, dim, self.sizes(), self.options())
@@ -1447,6 +1470,7 @@
 - name: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
   self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true)
   output_differentiability: [True, False]
+  values: gather(self_t, dim, indices)
 
 - name: trace(Tensor self) -> Tensor
   self: trace_backward(grad, self.sizes())
@@ -1685,27 +1709,34 @@
 
 - name: silu(Tensor self) -> Tensor
   self: "GradMode::is_enabled() ? infinitely_differentiable_silu_backward(grad, self) : silu_backward(grad, self)"
+  result: auto_element_wise
 
 - name: mish(Tensor self) -> Tensor
   self: "GradMode::is_enabled() ? infinitely_differentiable_mish_backward(grad, self) : mish_backward(grad, self)"
+  result: auto_element_wise
 
 - name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
   self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ false, self)
+  result: auto_element_wise
 
 - name: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!)
   self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ true, result)
 
 - name: celu(Tensor self, Scalar alpha=1.0) -> Tensor
   self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ false, self)
+  result: auto_element_wise
 
 - name: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)
   self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result)
 
 - name: gelu(Tensor self) -> Tensor
   self: "GradMode::is_enabled() ? infinitely_differentiable_gelu_backward(grad, self) : gelu_backward(grad, self)"
+  result: auto_element_wise
 
 - name: glu(Tensor self, int dim=-1) -> Tensor
   self: glu_backward(grad, self, dim)
+  # RuntimeError: output with shape [1] doesn't match the broadcast shape [2]
+  # result: auto_element_wise
 
 - name: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor
   self: hardshrink_backward(grad, self, lambd)
@@ -2342,6 +2373,7 @@
 
 - name: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]
   self: unbind_backward(grads, dim)
+  result: auto_linear
 
 - name: stack(Tensor[] tensors, int dim=0) -> Tensor
   tensors: "grad.defined() ? unbind(grad, dim) : std::vector<Tensor>(tensors.size())"
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index cd46b1394353..cceccce8e42b 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -9618,6 +9618,7 @@ op_db: List[OpInfo] = [
                    dtypes=all_types_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
                    safe_casts_outputs=True,
+                   supports_forward_ad=True,
                    sample_inputs_func=sample_inputs_i0_i1),
     UnaryUfuncInfo('special.i0e',
                    aten_name='special_i0e',
@@ -9630,6 +9631,7 @@ op_db: List[OpInfo] = [
                    dtypes=all_types_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
                    sample_inputs_func=sample_inputs_i0_i1,
+                   supports_forward_ad=True,
                    safe_casts_outputs=True),
     UnaryUfuncInfo('special.i1',
                    aten_name='special_i1',
@@ -9649,13 +9651,15 @@ op_db: List[OpInfo] = [
                                     "TestUnaryUfuncs", "test_out_arg_all_dtypes",
                                     device_type='cuda'),
-                   )),
+                   ),
+                   supports_forward_ad=True),
     UnaryUfuncInfo('special.i1e',
                    aten_name='special_i1e',
                    ref=scipy.special.i1e if TEST_SCIPY else _NOTHING,
                    dtypes=all_types_and(torch.bool),
                    dtypesIfCUDA=all_types_and(torch.bool),
                    sample_inputs_func=sample_inputs_i0_i1,
+                   supports_forward_ad=True,
                    safe_casts_outputs=True),
     UnaryUfuncInfo('special.ndtr',
                    aten_name='special_ndtr',
@@ -9664,6 +9668,7 @@ op_db: List[OpInfo] = [
                    ref=scipy.special.ndtr if TEST_SCIPY else _NOTHING,
                    dtypes=all_types_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.float16),
+                   supports_forward_ad=True,
                    safe_casts_outputs=True),
     BinaryUfuncInfo('floor_divide',
                     dtypes=all_types_and(torch.half, torch.bfloat16),
@@ -10281,12 +10286,14 @@ op_db: List[OpInfo] = [
            dtypesIfCUDA=all_types_and(torch.float16),
            # TODO: some signatures of median do support out
            supports_out=False,
+           supports_forward_ad=True,
            sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
     OpInfo('nanmedian',
            dtypes=all_types_and(torch.bfloat16),
            dtypesIfCUDA=all_types_and(torch.float16),
            # TODO: some signatures of nanmedian do support out
            supports_out=False,
+           supports_forward_ad=True,
            sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
     OpInfo('var_mean',
            dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
@@ -10381,6 +10388,11 @@ op_db: List[OpInfo] = [
     OpInfo('quantile',
            dtypes=floating_types(),
            sample_inputs_func=sample_inputs_reduction_quantile,
+           supports_forward_ad=True,
+           # See https://github.com/pytorch/pytorch/issues/66357
+           # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which
+           # does not have a batching rule in core
+           check_batched_forward_grad=False,
            skips=(
                # Pre-existing condition; Needs to be fixed
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_composite_compliance'),
@@ -10388,6 +10400,11 @@ op_db: List[OpInfo] = [
     OpInfo('nanquantile',
            dtypes=floating_types(),
            sample_inputs_func=sample_inputs_reduction_quantile,
+           supports_forward_ad=True,
+           # See https://github.com/pytorch/pytorch/issues/66357
+           # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which
+           # does not have a batching rule in core
+           check_batched_forward_grad=False,
            skips=(
                # Pre-existing condition; Needs to be fixed
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_composite_compliance'),
@@ -10674,6 +10691,9 @@ op_db: List[OpInfo] = [
                torch.as_strided(x, size, stride, storage_offset=storage_offset),
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_out=False,
+           supports_forward_ad=True,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
            sample_inputs_func=sample_inputs_as_strided,
            skips=(
                # FIXME: AssertionError: False is not true : Tensors failed to compare as equal!
@@ -11253,7 +11273,7 @@ op_db: List[OpInfo] = [
             np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x) - 1)),
         dtypes=floating_types(),
         dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
-        supports_forward_ad=False,
+        supports_forward_ad=True,
         supports_autograd=True,
         assert_autodiffed=False,
         supports_gradgrad=True,
@@ -11263,6 +11283,8 @@ op_db: List[OpInfo] = [
         inplace_variant=lambda x, alpha=1.0:
             torch.nn.functional.elu(x, alpha, inplace=True),
         decorators=[
+            # Not implemented yet
+            DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_inplace_forward_mode_AD'),
             DecorateInfo(
                 toleranceOverride({
                     torch.float16: tol(atol=1e-03, rtol=1.2e-03),
@@ -11299,7 +11321,7 @@ op_db: List[OpInfo] = [
             np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x / alpha) - 1)),
         dtypes=floating_types(),
         dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
-        supports_forward_ad=False,
+        supports_forward_ad=True,
        supports_autograd=True,
         assert_autodiffed=False,
         supports_gradgrad=True,
@@ -11309,6 +11331,8 @@ op_db: List[OpInfo] = [
         inplace_variant=lambda x, alpha=1.0:
             torch.nn.functional.celu(x, alpha, inplace=True),
         decorators=[
+            # Not implemented yet
+            DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_inplace_forward_mode_AD'),
             DecorateInfo(
                 toleranceOverride({
                     torch.float16: tol(atol=1e-03, rtol=1.2e-03),
@@ -11357,13 +11381,15 @@ op_db: List[OpInfo] = [
         ),
         dtypes=floating_types(),
         dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
-        supports_forward_ad=False,
+        supports_forward_ad=True,  # depends on 'elu'
         supports_autograd=True,
         assert_autodiffed=False,
         supports_gradgrad=True,
         supports_out=False,
         inplace_variant=lambda x: torch.nn.functional.selu(x, inplace=True),
         decorators=[
+            # Not implemented yet (depends on 'elu_')
+            DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_inplace_forward_mode_AD'),
             DecorateInfo(
                 toleranceOverride({
                     torch.float16: tol(atol=1e-2, rtol=1.8e-2),
@@ -11433,7 +11459,7 @@ op_db: List[OpInfo] = [
         ref=lambda x: x * np.tanh(reference_softplus(x)),
         dtypes=floating_types(),
         dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
-        supports_forward_ad=False,
+        supports_forward_ad=True,
         supports_autograd=True,
         assert_autodiffed=False,
         supports_gradgrad=True,
@@ -11506,6 +11532,7 @@ op_db: List[OpInfo] = [
     OpInfo('topk',
            dtypes=all_types_and(torch.bfloat16),
            dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
+           supports_forward_ad=True,
           sample_inputs_func=sample_inputs_topk),
     # Multiple variants for batch_norm to test with and without cuDNN disabled
     # See https://github.com/pytorch/pytorch/pull/63218#discussion_r688549391 for more details
@@ -11627,6 +11654,7 @@ op_db: List[OpInfo] = [
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            supports_gradgrad=True,
            supports_out=False,
+           supports_forward_ad=True,
            autodiff_nonfusible_nodes=["aten::gelu"]),
     OpInfo('nn.functional.relu6',
            aten_name="relu6",
@@ -11759,6 +11787,7 @@ op_db: List[OpInfo] = [
                        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
                                     'test_reference_numerics_extremal', dtypes=[torch.bfloat16]),
                    ),
+                   supports_forward_ad=True,
                    safe_casts_outputs=True),
     UnaryUfuncInfo('real',
                    ref=np.real,
@@ -12890,6 +12919,7 @@ op_db: List[OpInfo] = [
            dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
            dtypesIfROCM=all_types_and(torch.float16),
            sample_inputs_func=sample_inputs_sort,
+           supports_forward_ad=True,
            skips=(
                # sort does not correctly warn when resizing out= inputs
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
@@ -12933,6 +12963,7 @@ op_db: List[OpInfo] = [
            sample_inputs_func=sample_inputs_take),
     OpInfo('scatter',
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_scatter,),
     OpInfo('bfloat16',
            op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs),
@@ -13435,6 +13466,7 @@ op_db: List[OpInfo] = [
            dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
            dtypesIfROCM=all_types_and(torch.float16),
            check_batched_gradgrad=False,
+           supports_forward_ad=True,
            skips=(
                # msort does not correctly warn when resizing out= inputs.
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
@@ -13794,6 +13826,7 @@ op_db: List[OpInfo] = [
                    domain=(0, 1),
                    aten_name='special_ndtri',
                    dtypes=all_types_and(torch.bool),
+                   supports_forward_ad=True,
                    safe_casts_outputs=True),
     UnaryUfuncInfo('erf',
                    ref=scipy.special.erf if TEST_SCIPY else _NOTHING,
@@ -13811,6 +13844,7 @@ op_db: List[OpInfo] = [
                    assert_jit_shape_analysis=True,
                    supports_sparse=True,
                    supports_sparse_csr=True,
+                   supports_forward_ad=True,
                    safe_casts_outputs=True),
     UnaryUfuncInfo('erfc',
                    ref=scipy.special.erfc if TEST_SCIPY else _NOTHING,
@@ -13820,6 +13854,7 @@ op_db: List[OpInfo] = [
                    dtypes=all_types_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
+                   supports_forward_ad=True,
                    safe_casts_outputs=True),
     UnaryUfuncInfo('erfinv',
                    ref=scipy.special.erfinv if TEST_SCIPY else _NOTHING,
@@ -13831,6 +13866,7 @@ op_db: List[OpInfo] = [
                    dtypesIfCUDA=all_types_and(torch.bool, torch.half),
                    safe_casts_outputs=True,
                    supports_sparse_csr=True,
+                   supports_forward_ad=True,
                    domain=(-1, 1),
                    skips=(
                        # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611
@@ -13991,6 +14027,7 @@ op_db: List[OpInfo] = [
                    aten_name='special_erfcx',
                    decorators=(toleranceOverride({torch.float32: tol(atol=0, rtol=4e-6), }),),
                    dtypes=all_types_and(torch.bool),
+                   supports_forward_ad=True,
                    safe_casts_outputs=True),
     OpInfo(
         "nn.functional.dropout",
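
A minimal usage sketch (not part of the patch) of the forward-mode AD coverage these changes enable. It assumes only the public torch.autograd.forward_ad API and torch.erf, which gains a `result:` entry in derivatives.yaml above, and it mirrors the gradcheck(check_forward_ad=True) call added to test_view_ops.py.

# Sketch: exercising forward-mode AD for an op covered by this patch (assumes torch.erf).
import torch
import torch.autograd.forward_ad as fwAD
from torch.autograd import gradcheck

x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)

# Compute a JVP with dual numbers: primal value plus a tangent direction.
with fwAD.dual_level():
    dual = fwAD.make_dual(x.detach(), torch.ones_like(x))  # primal + tangent
    out = fwAD.unpack_dual(torch.erf(dual))
    jvp = out.tangent  # directional derivative of erf at x along the all-ones tangent

# Numerically verify the forward-mode formula, as the OpInfo tests do once
# supports_forward_ad=True is set.
gradcheck(torch.erf, (x,), check_forward_ad=True)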