mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Add some forward AD formulas (#69384)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69384 Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D33020602 Pulled By: soulitzer fbshipit-source-id: a92dd243f2b5b21fe277b0bb17bcd61dfe5a0d67
This commit is contained in:
		
				
					committed by
					
						 Facebook GitHub Bot
						Facebook GitHub Bot
					
				
			
			
				
	
			
			
			
						parent
						
							baf92f9d5a
						
					
				
				
					commit
					0dcbd73eee
				
			| @ -516,6 +516,9 @@ class TestViewOps(TestCase): | ||||
|             g_expected = torch.stack([gi if j == i else torch.zeros_like(gi) | ||||
|                                       for j in range(3)], dim=0) | ||||
|             self.assertEqual(g, g_expected) | ||||
|         # Check with gradcheck | ||||
|         stacked = torch.randn(3, 10, 10, dtype=torch.double, requires_grad=True) | ||||
|         gradcheck(lambda x: x.unbind(), (stacked,), check_forward_ad=True) | ||||
|  | ||||
|     def test_expand_view(self, device) -> None: | ||||
|         t = torch.ones((5, 1), device=device) | ||||
|  | ||||
| @ -299,6 +299,11 @@ | ||||
|  | ||||
| - name: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a) | ||||
|   self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset) | ||||
|   result: auto_linear | ||||
|  | ||||
| - name: as_strided_(Tensor(a!) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a!) | ||||
|   self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset) | ||||
|   result: auto_linear | ||||
|  | ||||
| - name: asin(Tensor self) -> Tensor | ||||
|   self: grad * (-self * self + 1).rsqrt().conj() | ||||
| @ -553,15 +558,19 @@ | ||||
|  | ||||
| - name: erf(Tensor self) -> Tensor | ||||
|   self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad | ||||
|   result: auto_element_wise | ||||
|  | ||||
| - name: erfc(Tensor self) -> Tensor | ||||
|   self: -2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad | ||||
|   result: auto_element_wise | ||||
|  | ||||
| - name: special_erfcx(Tensor self) -> Tensor | ||||
|   self: (2.0 * self * result - 2.0 / sqrt(M_PI)) * grad | ||||
|   result: auto_element_wise | ||||
|  | ||||
| - name: erfinv(Tensor self) -> Tensor | ||||
|   self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad | ||||
|   result: auto_element_wise | ||||
|  | ||||
| - name: exp(Tensor self) -> Tensor | ||||
|   self: grad * result.conj() | ||||
| @ -673,7 +682,7 @@ | ||||
|   self: hardsigmoid_backward(grad, self) | ||||
|  | ||||
| - name: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor | ||||
|   self: not_implemented("histc") | ||||
|   output_differentiability: [False] | ||||
|  | ||||
| - name: hardswish(Tensor self) -> Tensor | ||||
|   self: hardswish_backward(grad, self) | ||||
| @ -686,15 +695,19 @@ | ||||
|  | ||||
| - name: i0(Tensor self) -> Tensor | ||||
|   self: grad * at::special_i1(self) | ||||
|   result: auto_element_wise | ||||
|  | ||||
| - name: special_i0e(Tensor self) -> Tensor | ||||
|   self: grad * (at::special_i1e(self) - self.sgn() * result) | ||||
|   result: auto_element_wise | ||||
|  | ||||
| - name: special_i1(Tensor self) -> Tensor | ||||
|   self: i1_backward(grad, self, result) | ||||
|   result: auto_element_wise | ||||
|  | ||||
| - name: special_i1e(Tensor self) -> Tensor | ||||
|   self: i1e_backward(grad, self, result) | ||||
|   result: auto_element_wise | ||||
|  | ||||
| - name: igamma(Tensor self, Tensor other) -> Tensor | ||||
|   self: 'not_implemented("igamma: input")' | ||||
| @ -957,9 +970,11 @@ | ||||
|  | ||||
| - name: median(Tensor self) -> Tensor | ||||
|   self: evenly_distribute_backward(grad, self, result) | ||||
|   result: evenly_read_jvp(self_t, self_p, result) | ||||
|  | ||||
| - name: nanmedian(Tensor self) -> Tensor | ||||
|   self: evenly_distribute_backward(grad, self, result) | ||||
|   result: evenly_read_jvp(self_t, self_p, result) | ||||
|  | ||||
| # This is in theory incorrect in the following case: | ||||
| #   sorted list: [..., a, b, b, ..., b, b, c, ...] with median = b and the value | ||||
| @ -977,9 +992,11 @@ | ||||
| # subgradient on one side. | ||||
| - name: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) | ||||
|   self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim) | ||||
|   values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) | ||||
|  | ||||
| - name: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) | ||||
|   self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim) | ||||
|   values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) | ||||
|  | ||||
| - name: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) | ||||
|   self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim) | ||||
| @ -1166,6 +1183,7 @@ | ||||
|  | ||||
| - name: rad2deg(Tensor self) -> Tensor | ||||
|   self: rad2deg_backward(grad) | ||||
|   result: auto_element_wise | ||||
|  | ||||
| - name: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) | ||||
|   self: zeros_like(grad) | ||||
| @ -1202,6 +1220,7 @@ | ||||
|  | ||||
| - name: special_ndtri(Tensor self) -> Tensor | ||||
|   self: grad * std::sqrt(2 * M_PI) * (result.square() / 2).exp() | ||||
|   result: auto_element_wise | ||||
|  | ||||
| # DO NOT define a backward for reshape! | ||||
| # reshape is special in that it sometimes returns a view, and sometimes not. | ||||
| @ -1226,10 +1245,12 @@ | ||||
|   self: grad.scatter(dim, index, 0) | ||||
|   index: non_differentiable | ||||
|   src: grad.gather(dim, index) | ||||
|   result: self_t.scatter(dim, index, src_t) | ||||
|  | ||||
| - name: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor | ||||
|   self: grad.scatter(dim, index, 0) | ||||
|   index: non_differentiable | ||||
|   result: self_t.scatter(dim, index, 0) | ||||
|  | ||||
| - name: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor | ||||
|   self: grad | ||||
| @ -1314,10 +1335,12 @@ | ||||
| - name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) | ||||
|   self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true) | ||||
|   output_differentiability: [True, False] | ||||
|   values: gather_with_keepdimed_indices(self_t, dim, indices, true) | ||||
|  | ||||
| - name: sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) | ||||
|   self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true) | ||||
|   output_differentiability: [True, False] | ||||
|   values: gather_with_keepdimed_indices(self_t, dim, indices, true) | ||||
|  | ||||
| - name: split.Tensor(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[] | ||||
|   self: split_backward(grads, split_size, dim, self.sizes(), self.options()) | ||||
| @ -1447,6 +1470,7 @@ | ||||
| - name: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) | ||||
|   self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true) | ||||
|   output_differentiability: [True, False] | ||||
|   values: gather(self_t, dim, indices) | ||||
|  | ||||
| - name: trace(Tensor self) -> Tensor | ||||
|   self: trace_backward(grad, self.sizes()) | ||||
| @ -1685,27 +1709,34 @@ | ||||
|  | ||||
| - name: silu(Tensor self) -> Tensor | ||||
|   self: "GradMode::is_enabled() ? infinitely_differentiable_silu_backward(grad, self) : silu_backward(grad, self)" | ||||
|   result: auto_element_wise | ||||
|  | ||||
| - name: mish(Tensor self) -> Tensor | ||||
|   self: "GradMode::is_enabled() ? infinitely_differentiable_mish_backward(grad, self) : mish_backward(grad, self)" | ||||
|   result: auto_element_wise | ||||
|  | ||||
| - name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor | ||||
|   self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ false, self) | ||||
|   result: auto_element_wise | ||||
|  | ||||
| - name: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!) | ||||
|   self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ true, result) | ||||
|  | ||||
| - name: celu(Tensor self, Scalar alpha=1.0) -> Tensor | ||||
|   self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ false, self) | ||||
|   result: auto_element_wise | ||||
|  | ||||
| - name: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!) | ||||
|   self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result) | ||||
|  | ||||
| - name: gelu(Tensor self) -> Tensor | ||||
|   self: "GradMode::is_enabled() ? infinitely_differentiable_gelu_backward(grad, self) : gelu_backward(grad, self)" | ||||
|   result: auto_element_wise | ||||
|  | ||||
| - name: glu(Tensor self, int dim=-1) -> Tensor | ||||
|   self: glu_backward(grad, self, dim) | ||||
|   # RuntimeError: output with shape [1] doesn't match the broadcast shape [2] | ||||
|   # result: auto_element_wise | ||||
|  | ||||
| - name: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor | ||||
|   self: hardshrink_backward(grad, self, lambd) | ||||
| @ -2342,6 +2373,7 @@ | ||||
|  | ||||
| - name: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[] | ||||
|   self: unbind_backward(grads, dim) | ||||
|   result: auto_linear | ||||
|  | ||||
| - name: stack(Tensor[] tensors, int dim=0) -> Tensor | ||||
|   tensors: "grad.defined() ? unbind(grad, dim) : std::vector<Tensor>(tensors.size())" | ||||
|  | ||||
| @ -9618,6 +9618,7 @@ op_db: List[OpInfo] = [ | ||||
|                    dtypes=all_types_and(torch.bool, torch.bfloat16), | ||||
|                    dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), | ||||
|                    safe_casts_outputs=True, | ||||
|                    supports_forward_ad=True, | ||||
|                    sample_inputs_func=sample_inputs_i0_i1), | ||||
|     UnaryUfuncInfo('special.i0e', | ||||
|                    aten_name='special_i0e', | ||||
| @ -9630,6 +9631,7 @@ op_db: List[OpInfo] = [ | ||||
|                    dtypes=all_types_and(torch.bool, torch.bfloat16), | ||||
|                    dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), | ||||
|                    sample_inputs_func=sample_inputs_i0_i1, | ||||
|                    supports_forward_ad=True, | ||||
|                    safe_casts_outputs=True), | ||||
|     UnaryUfuncInfo('special.i1', | ||||
|                    aten_name='special_i1', | ||||
| @ -9649,13 +9651,15 @@ op_db: List[OpInfo] = [ | ||||
|                                     "TestUnaryUfuncs", | ||||
|                                     "test_out_arg_all_dtypes", | ||||
|                                     device_type='cuda'), | ||||
|                    )), | ||||
|                    ), | ||||
|                    supports_forward_ad=True), | ||||
|     UnaryUfuncInfo('special.i1e', | ||||
|                    aten_name='special_i1e', | ||||
|                    ref=scipy.special.i1e if TEST_SCIPY else _NOTHING, | ||||
|                    dtypes=all_types_and(torch.bool), | ||||
|                    dtypesIfCUDA=all_types_and(torch.bool), | ||||
|                    sample_inputs_func=sample_inputs_i0_i1, | ||||
|                    supports_forward_ad=True, | ||||
|                    safe_casts_outputs=True), | ||||
|     UnaryUfuncInfo('special.ndtr', | ||||
|                    aten_name='special_ndtr', | ||||
| @ -9664,6 +9668,7 @@ op_db: List[OpInfo] = [ | ||||
|                    ref=scipy.special.ndtr if TEST_SCIPY else _NOTHING, | ||||
|                    dtypes=all_types_and(torch.bool, torch.bfloat16), | ||||
|                    dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.float16), | ||||
|                    supports_forward_ad=True, | ||||
|                    safe_casts_outputs=True), | ||||
|     BinaryUfuncInfo('floor_divide', | ||||
|                     dtypes=all_types_and(torch.half, torch.bfloat16), | ||||
| @ -10281,12 +10286,14 @@ op_db: List[OpInfo] = [ | ||||
|            dtypesIfCUDA=all_types_and(torch.float16), | ||||
|            # TODO: some signatures of median do support out | ||||
|            supports_out=False, | ||||
|            supports_forward_ad=True, | ||||
|            sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)), | ||||
|     OpInfo('nanmedian', | ||||
|            dtypes=all_types_and(torch.bfloat16), | ||||
|            dtypesIfCUDA=all_types_and(torch.float16), | ||||
|            # TODO: some signatures of nanmedian do support out | ||||
|            supports_out=False, | ||||
|            supports_forward_ad=True, | ||||
|            sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)), | ||||
|     OpInfo('var_mean', | ||||
|            dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), | ||||
| @ -10381,6 +10388,11 @@ op_db: List[OpInfo] = [ | ||||
|     OpInfo('quantile', | ||||
|            dtypes=floating_types(), | ||||
|            sample_inputs_func=sample_inputs_reduction_quantile, | ||||
|            supports_forward_ad=True, | ||||
|            # See https://github.com/pytorch/pytorch/issues/66357 | ||||
|            # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which | ||||
|            # does not have a batching rule in core | ||||
|            check_batched_forward_grad=False, | ||||
|            skips=( | ||||
|                # Pre-existing condition; Needs to be fixed | ||||
|                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_composite_compliance'), | ||||
| @ -10388,6 +10400,11 @@ op_db: List[OpInfo] = [ | ||||
|     OpInfo('nanquantile', | ||||
|            dtypes=floating_types(), | ||||
|            sample_inputs_func=sample_inputs_reduction_quantile, | ||||
|            supports_forward_ad=True, | ||||
|            # See https://github.com/pytorch/pytorch/issues/66357 | ||||
|            # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which | ||||
|            # does not have a batching rule in core | ||||
|            check_batched_forward_grad=False, | ||||
|            skips=( | ||||
|                # Pre-existing condition; Needs to be fixed | ||||
|                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_composite_compliance'), | ||||
| @ -10674,6 +10691,9 @@ op_db: List[OpInfo] = [ | ||||
|                torch.as_strided(x, size, stride, storage_offset=storage_offset), | ||||
|            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), | ||||
|            supports_out=False, | ||||
|            supports_forward_ad=True, | ||||
|            # vmap does not support inplace views | ||||
|            check_inplace_batched_forward_grad=False, | ||||
|            sample_inputs_func=sample_inputs_as_strided, | ||||
|            skips=( | ||||
|                # FIXME: AssertionError: False is not true : Tensors failed to compare as equal! | ||||
| @ -11253,7 +11273,7 @@ op_db: List[OpInfo] = [ | ||||
|             np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x) - 1)), | ||||
|         dtypes=floating_types(), | ||||
|         dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), | ||||
|         supports_forward_ad=False, | ||||
|         supports_forward_ad=True, | ||||
|         supports_autograd=True, | ||||
|         assert_autodiffed=False, | ||||
|         supports_gradgrad=True, | ||||
| @ -11263,6 +11283,8 @@ op_db: List[OpInfo] = [ | ||||
|         inplace_variant=lambda x, alpha=1.0: | ||||
|             torch.nn.functional.elu(x, alpha, inplace=True), | ||||
|         decorators=[ | ||||
|             # Not implemented yet | ||||
|             DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_inplace_forward_mode_AD'), | ||||
|             DecorateInfo( | ||||
|                 toleranceOverride({ | ||||
|                     torch.float16: tol(atol=1e-03, rtol=1.2e-03), | ||||
| @ -11299,7 +11321,7 @@ op_db: List[OpInfo] = [ | ||||
|             np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x / alpha) - 1)), | ||||
|         dtypes=floating_types(), | ||||
|         dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), | ||||
|         supports_forward_ad=False, | ||||
|         supports_forward_ad=True, | ||||
|         supports_autograd=True, | ||||
|         assert_autodiffed=False, | ||||
|         supports_gradgrad=True, | ||||
| @ -11309,6 +11331,8 @@ op_db: List[OpInfo] = [ | ||||
|         inplace_variant=lambda x, alpha=1.0: | ||||
|             torch.nn.functional.celu(x, alpha, inplace=True), | ||||
|         decorators=[ | ||||
|             # Not implemented yet | ||||
|             DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_inplace_forward_mode_AD'), | ||||
|             DecorateInfo( | ||||
|                 toleranceOverride({ | ||||
|                     torch.float16: tol(atol=1e-03, rtol=1.2e-03), | ||||
| @ -11357,13 +11381,15 @@ op_db: List[OpInfo] = [ | ||||
|             ), | ||||
|         dtypes=floating_types(), | ||||
|         dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), | ||||
|         supports_forward_ad=False, | ||||
|         supports_forward_ad=True,  # depends on 'elu' | ||||
|         supports_autograd=True, | ||||
|         assert_autodiffed=False, | ||||
|         supports_gradgrad=True, | ||||
|         supports_out=False, | ||||
|         inplace_variant=lambda x: torch.nn.functional.selu(x, inplace=True), | ||||
|         decorators=[ | ||||
|             # Not implemented yet (depends on 'elu_') | ||||
|             DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_inplace_forward_mode_AD'), | ||||
|             DecorateInfo( | ||||
|                 toleranceOverride({ | ||||
|                     torch.float16: tol(atol=1e-2, rtol=1.8e-2), | ||||
| @ -11433,7 +11459,7 @@ op_db: List[OpInfo] = [ | ||||
|         ref=lambda x: x * np.tanh(reference_softplus(x)), | ||||
|         dtypes=floating_types(), | ||||
|         dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), | ||||
|         supports_forward_ad=False, | ||||
|         supports_forward_ad=True, | ||||
|         supports_autograd=True, | ||||
|         assert_autodiffed=False, | ||||
|         supports_gradgrad=True, | ||||
| @ -11506,6 +11532,7 @@ op_db: List[OpInfo] = [ | ||||
|     OpInfo('topk', | ||||
|            dtypes=all_types_and(torch.bfloat16), | ||||
|            dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16), | ||||
|            supports_forward_ad=True, | ||||
|            sample_inputs_func=sample_inputs_topk), | ||||
|     # Multiple variants for batch_norm to test with and without cuDNN disabled | ||||
|     # See https://github.com/pytorch/pytorch/pull/63218#discussion_r688549391 for more details | ||||
| @ -11627,6 +11654,7 @@ op_db: List[OpInfo] = [ | ||||
|            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), | ||||
|            supports_gradgrad=True, | ||||
|            supports_out=False, | ||||
|            supports_forward_ad=True, | ||||
|            autodiff_nonfusible_nodes=["aten::gelu"]), | ||||
|     OpInfo('nn.functional.relu6', | ||||
|            aten_name="relu6", | ||||
| @ -11759,6 +11787,7 @@ op_db: List[OpInfo] = [ | ||||
|                        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', | ||||
|                                     dtypes=[torch.bfloat16]), | ||||
|                    ), | ||||
|                    supports_forward_ad=True, | ||||
|                    safe_casts_outputs=True), | ||||
|     UnaryUfuncInfo('real', | ||||
|                    ref=np.real, | ||||
| @ -12890,6 +12919,7 @@ op_db: List[OpInfo] = [ | ||||
|            dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), | ||||
|            dtypesIfROCM=all_types_and(torch.float16), | ||||
|            sample_inputs_func=sample_inputs_sort, | ||||
|            supports_forward_ad=True, | ||||
|            skips=( | ||||
|                # sort does not correctly warn when resizing out= inputs | ||||
|                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), | ||||
| @ -12933,6 +12963,7 @@ op_db: List[OpInfo] = [ | ||||
|            sample_inputs_func=sample_inputs_take), | ||||
|     OpInfo('scatter', | ||||
|            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), | ||||
|            supports_forward_ad=True, | ||||
|            sample_inputs_func=sample_inputs_scatter,), | ||||
|     OpInfo('bfloat16', | ||||
|            op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs), | ||||
| @ -13435,6 +13466,7 @@ op_db: List[OpInfo] = [ | ||||
|            dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), | ||||
|            dtypesIfROCM=all_types_and(torch.float16), | ||||
|            check_batched_gradgrad=False, | ||||
|            supports_forward_ad=True, | ||||
|            skips=( | ||||
|                # msort does not correctly warn when resizing out= inputs. | ||||
|                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), | ||||
| @ -13794,6 +13826,7 @@ op_db: List[OpInfo] = [ | ||||
|                    domain=(0, 1), | ||||
|                    aten_name='special_ndtri', | ||||
|                    dtypes=all_types_and(torch.bool), | ||||
|                    supports_forward_ad=True, | ||||
|                    safe_casts_outputs=True), | ||||
|     UnaryUfuncInfo('erf', | ||||
|                    ref=scipy.special.erf if TEST_SCIPY else _NOTHING, | ||||
| @ -13811,6 +13844,7 @@ op_db: List[OpInfo] = [ | ||||
|                    assert_jit_shape_analysis=True, | ||||
|                    supports_sparse=True, | ||||
|                    supports_sparse_csr=True, | ||||
|                    supports_forward_ad=True, | ||||
|                    safe_casts_outputs=True), | ||||
|     UnaryUfuncInfo('erfc', | ||||
|                    ref=scipy.special.erfc if TEST_SCIPY else _NOTHING, | ||||
| @ -13820,6 +13854,7 @@ op_db: List[OpInfo] = [ | ||||
|                    dtypes=all_types_and(torch.bool, torch.bfloat16), | ||||
|                    dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), | ||||
|                    assert_autodiffed=True, | ||||
|                    supports_forward_ad=True, | ||||
|                    safe_casts_outputs=True), | ||||
|     UnaryUfuncInfo('erfinv', | ||||
|                    ref=scipy.special.erfinv if TEST_SCIPY else _NOTHING, | ||||
| @ -13831,6 +13866,7 @@ op_db: List[OpInfo] = [ | ||||
|                    dtypesIfCUDA=all_types_and(torch.bool, torch.half), | ||||
|                    safe_casts_outputs=True, | ||||
|                    supports_sparse_csr=True, | ||||
|                    supports_forward_ad=True, | ||||
|                    domain=(-1, 1), | ||||
|                    skips=( | ||||
|                        # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611 | ||||
| @ -13991,6 +14027,7 @@ op_db: List[OpInfo] = [ | ||||
|                    aten_name='special_erfcx', | ||||
|                    decorators=(toleranceOverride({torch.float32: tol(atol=0, rtol=4e-6), }),), | ||||
|                    dtypes=all_types_and(torch.bool), | ||||
|                    supports_forward_ad=True, | ||||
|                    safe_casts_outputs=True), | ||||
|     OpInfo( | ||||
|         "nn.functional.dropout", | ||||
|  | ||||
		Reference in New Issue
	
	Block a user