diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt
index 0a78fd8f8d1f..6e1592cb6101 100644
--- a/.github/ci_commit_pins/xla.txt
+++ b/.github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-2ec22641e390cda25ec7c61fcbce07507727d584
+e0684027996c60dcbb99fddf205385c208fb9ed7
diff --git a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp
index 37765dad49cb..b3d7ee40cf9d 100644
--- a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp
+++ b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp
@@ -280,6 +280,44 @@ static void fill__Tensor_batch_rule(
   std::get<0>(self_and_other).copy_(std::get<1>(self_and_other));
 }
 
+static
+std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
+rrelu_with_noise_batch_rule(
+    const Tensor& self,
+    std::optional<int64_t> self_bdim,
+    Tensor& noise,
+    std::optional<int64_t> noise_bdim,
+    const at::Scalar& lower,
+    const at::Scalar& upper,
+    bool training,
+    std::optional<Generator> generator) {
+
+  auto self_ = moveBatchDimToFront(self, self_bdim);
+  auto noise_ = moveBatchDimToFront(noise, noise_bdim);
+
+  auto ret = at::rrelu_with_noise(self_, noise_, lower, upper, training, std::move(generator));
+
+  return std::make_tuple(ret, 0, noise_, 0);
+}
+
+static Tensor rrelu_with_noise_batch(
+    const Tensor& self,
+    Tensor& noise,
+    const Scalar& lower,
+    const Scalar& upper,
+    bool training,
+    std::optional<Generator> generator) {
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
+  int64_t cur_level = maybe_layer->layerId();
+  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
+  auto [noise_value, noise_bdim] = unwrapTensorAtLevel(noise, cur_level);
+  TORCH_CHECK(!noise_bdim.has_value(), "vmap: Attempted to vmap over 'noise' in torch.rrelu_with_noise. This is not supported.");
+  auto res = rrelu_with_noise_batch_rule(self_value, self_bdim, noise_value, noise_bdim, lower, upper, training, std::move(generator));
+  return makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
+}
+
 static std::tuple<Tensor, std::optional<int64_t>> log_sigmoid_backward_batch_rule(
   Tensor& grad, std::optional<int64_t> grad_bdim,
   Tensor& self, std::optional<int64_t> self_bdim,
@@ -421,7 +459,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   POINTWISE_BOXED(polygamma);
   BINARY_SCALAR_2(sub, Tensor, Scalar);
   BINARY_SCALAR_3(remainder, Tensor, Scalar, Scalar_Tensor);
-  BINARY_POINTWISE(rrelu_with_noise);
   BINARY_SCALAR_2(rsub, Tensor, Scalar);
   BINARY_SCALAR_3_Tensor(special_xlog1py, other_scalar, self_scalar);
 
@@ -509,6 +546,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
 
   VMAP_SUPPORT(masked_select_backward, masked_select_backward_batch_rule);
   VMAP_SUPPORT2(fill_, Tensor, fill__Tensor_batch_rule);
+  m.impl("rrelu_with_noise", rrelu_with_noise_batch);
 }
 
 } // namespace at::functorch
diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp
index 404868b737ed..8607e02c436e 100644
--- a/aten/src/ATen/native/Activation.cpp
+++ b/aten/src/ATen/native/Activation.cpp
@@ -576,7 +576,7 @@ template <typename scalar_t>
 inline void _rrelu_with_noise_train(
     Tensor& output,
     const Tensor& input,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower_,
     const Scalar& upper_,
     std::optional<Generator> generator) {
@@ -606,7 +606,7 @@ inline void _rrelu_with_noise_train(
 }
 
 Tensor& rrelu_with_noise_out_cpu(const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
@@ -629,7 +629,7 @@ Tensor& rrelu_with_noise_out_cpu(const Tensor& self,
 
 Tensor rrelu_with_noise_cpu(
     const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
@@ -641,7 +641,7 @@ Tensor rrelu_with_noise_cpu(
 
 Tensor& rrelu_with_noise_cpu_(
     Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
@@ -670,12 +670,14 @@ Tensor rrelu_with_noise_backward(
 
 Tensor rrelu(const Tensor & self, const Scalar& lower, const Scalar& upper, bool training, std::optional<Generator> generator) {
   TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound")
-  return at::rrelu_with_noise(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, std::move(generator));
+  auto noise = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  return at::rrelu_with_noise(self, noise, lower, upper, training, std::move(generator));
 }
 
 Tensor & rrelu_(Tensor & self, const Scalar& lower, const Scalar& upper, bool training, std::optional<Generator> generator) {
   TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound")
-  return at::rrelu_with_noise_(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, std::move(generator));
+  auto noise = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  return at::rrelu_with_noise_(self, noise, lower, upper, training, std::move(generator));
 }
 
 TORCH_IMPL_FUNC(threshold_out)(const Tensor& self, const Scalar& threshold, const Scalar& value, const Tensor& result) {
diff --git a/aten/src/ATen/native/cuda/RreluWithNoise.cu b/aten/src/ATen/native/cuda/RreluWithNoise.cu
index 60285f5f70ee..11127d103889 100644
--- a/aten/src/ATen/native/cuda/RreluWithNoise.cu
+++ b/aten/src/ATen/native/cuda/RreluWithNoise.cu
@@ -71,7 +71,7 @@ template <typename scalar_t>
 inline void _rrelu_with_noise_cuda_train(
     Tensor& output,
     const Tensor& input_,
-    const Tensor& noise_,
+    Tensor& noise_,
     const Scalar& lower_,
     const Scalar& upper_,
     std::optional<Generator> generator) {
@@ -139,7 +139,7 @@ inline void _rrelu_with_noise_cuda_train(
 }
 
 Tensor& rrelu_with_noise_out_cuda(const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
@@ -173,7 +173,7 @@ Tensor& rrelu_with_noise_out_cuda(const Tensor& self,
 
 Tensor rrelu_with_noise_cuda(
     const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
@@ -184,7 +184,7 @@ Tensor rrelu_with_noise_cuda(
 
 Tensor& rrelu_with_noise_cuda_(
     Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 4ced922094c9..3f7df7676f7b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -12018,19 +12018,20 @@
     CUDA: log_sigmoid_backward_cuda
     MPS: log_sigmoid_backward_mps
 
-- func: rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+- func: rrelu_with_noise.out(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
   tags: nondeterministic_seeded
   dispatch:
     CPU: rrelu_with_noise_out_cpu
     CUDA: rrelu_with_noise_out_cuda
 
-- func: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
+- func: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
   python_module: nn
   dispatch:
     CPU: rrelu_with_noise_cpu
     CUDA: rrelu_with_noise_cuda
   tags: nondeterministic_seeded
+  autogen: rrelu_with_noise_functional
 
 - func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
   python_module: nn
@@ -12038,7 +12039,7 @@
     CompositeExplicitAutograd: rrelu_with_noise_backward
   autogen: rrelu_with_noise_backward.out
 
-- func: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
+- func: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
   python_module: nn
   tags: nondeterministic_seeded
   dispatch:
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 36f49ec11c53..883399f855cc 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -1125,6 +1125,10 @@ aten::resize_as_sparse_
 aten::row_indices
 aten::row_indices_copy
 aten::row_indices_copy.out
+aten::rrelu_with_noise
+aten::rrelu_with_noise.out
+aten::rrelu_with_noise_
+aten::rrelu_with_noise_functional
 aten::scalar_tensor
 aten::scalar_tensor.out
 aten::scatter.reduce
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index 7fe56facf5b5..e0ab6a07fe25 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -137,6 +137,7 @@ ALLOW_LIST = [
     ("_quantized::wrapped_linear_prepack", datetime.date(2024, 12, 31)),
     ("_quantized::wrapped_linear_prepacked", datetime.date(2024, 12, 31)),
     ("_quantized::wrapped_quantized_linear_prepacked", datetime.date(2024, 12, 31)),
+    ("aten::rrelu_with_noise", datetime.date(2024, 12, 31)),
 ]
 
 ALLOW_LIST_COMPILED = [
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 791547593e56..64eb0f630a76 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -6327,6 +6327,42 @@ metadata incorrectly.
         out.sum().backward()
         self.assertEqual(ref_x.grad, x.grad)
 
+    def test_rrelu_with_noise_mutation(self):
+        def fn_functional(x):
+            noise = torch.ones_like(x)
+            result, noise_out = torch.ops.aten.rrelu_with_noise_functional(
+                x, noise, 0.2, 0.8, True
+            )
+            return result, noise_out
+
+        def fn_mutation(x):
+            noise = torch.ones_like(x)
+            result = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.8, True)
+            return result, noise
+
+        def fn_inplace(x):
+            noise = torch.ones_like(x, requires_grad=False)
+            torch.ops.aten.rrelu_with_noise_(x, noise, 0.2, 0.8, True)
+            return x, noise
+
+        def _test_fn(fn, check_backward=True):
+            x = -torch.abs(torch.randn(4, 4, dtype=torch.bfloat16, requires_grad=True))
+
+            ref_y, ref_noise = fn(x)
+            self.assertTrue(torch.all(ref_noise < torch.ones_like(ref_noise)).item())
+
+            comp_y, comp_noise = torch.compile(fn, backend="inductor", fullgraph=True)(
+                x
+            )
+
+            if check_backward:
+                comp_y.sum().backward()
+            self.assertTrue(torch.all(comp_noise < torch.ones_like(comp_noise)).item())
+
+        _test_fn(fn_functional)
+        _test_fn(fn_mutation)
+        _test_fn(fn_inplace, check_backward=False)
+
     # entries in here don't work and need to be fixed.
     # Each one of these is a bug (or needs to be investigated)
diff --git a/test/test_decomp.py b/test/test_decomp.py
index 0dc38143fe33..0177c50ca7d8 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -656,50 +656,6 @@ class TestDecomp(TestCase):
         for dim in (-1, 0, 1):
             self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))
 
-    def test_rrelu_with_noise(self, device):
-        # rrelu_with_noise behavior depends on a) whether elements in the input
-        # are <= 0, and b) whether we're in training mode. Cover all cases:
-        dtype = torch.float64
-        x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device)
-        lower = 1.0
-        upper = 4.0
-        training = False
-
-        torch.manual_seed(123)
-        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
-        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)
-
-        torch.manual_seed(123)
-        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
-        res = torch._decomp.decompositions.rrelu_with_noise(
-            x,
-            noise_res,
-            lower,
-            upper,
-            training,
-        )
-        self.assertEqual(ref, res)
-        self.assertEqual(noise_ref, noise_res)
-
-        # Now with training=True:
-        training = True
-
-        torch.manual_seed(123)
-        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
-        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)
-
-        torch.manual_seed(123)
-        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
-        res = torch._decomp.decompositions.rrelu_with_noise(
-            x,
-            noise_res,
-            lower,
-            upper,
-            training,
-        )
-        self.assertEqual(ref, res)
-        self.assertEqual(noise_ref, noise_res)
-
     @suppress_warnings
     @tf32_off()
     # only tests RNNs since we have py dispsatcher decomps for them
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 3f944a6dae3c..fa77b906b1b4 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2182,13 +2182,17 @@
   result0: at::where(self_p >= 0, grad_output_t, grad_output_t * weight_p + grad_output_p * weight_t)
   result1: at::where(self_p >= 0, at::zeros({}, self_p.options()), grad_output_p * self_t + grad_output_t * self_p)
 
-- name: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
+- name: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
   self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
   result: auto_element_wise
 
-- name: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
+- name: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
   self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true)
 
+- name: rrelu_with_noise_functional(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> (Tensor, Tensor noise_out)
+  noise: non_differentiable
+  self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
+
 - name: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
   self: _softmax_backward_data(grad, result, dim, self.scalar_type())
   result: result * (self_t - logsumexp_jvp(self_p, self_t, {dim}, true))
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index c2e622ac1c02..097569035228 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -305,44 +305,6 @@ def _prelu_kernel_backward(
     return (input_grad, weight_grad)
 
 
-@register_decomposition(aten.rrelu_with_noise)
-@aten.rrelu_with_noise.default.py_impl(DispatchKey.Autograd)
-@out_wrapper()
-@pw_cast_for_opmath
-def rrelu_with_noise(
-    self: Tensor,
-    noise: Tensor,
-    lower: float = 0.125,
-    upper: float = 0.3333333333333333,
-    training: bool = False,
-    generator: Optional[torch.Generator] = None,
-) -> Tensor:
-    assert generator is None
-    if training:
-        not_positive = self <= 0
-        r = aten.uniform(self, lower, upper)
-        output = torch.where(not_positive, self * r, self)
-        noise.copy_(torch.where(not_positive, r, 1))
-        return output
-    else:
-        negative_slope = (lower + upper) / 2
-        return aten.leaky_relu(self, negative_slope)
-
-
-@register_decomposition(aten.rrelu_with_noise_)
-@aten.rrelu_with_noise_.default.py_impl(DispatchKey.Autograd)
-@pw_cast_for_opmath
-def rrelu_with_noise_(
-    self: Tensor,
-    noise: Tensor,
-    lower: float = 0.125,
-    upper: float = 0.3333333333333333,
-    training: bool = False,
-    generator: Optional[torch.Generator] = None,
-) -> Tensor:
-    return self.copy_(rrelu_with_noise(self, noise, lower, upper, training, generator))
-
-
 @register_decomposition(aten.rrelu_with_noise_backward)
 @out_wrapper()
 @pw_cast_for_opmath
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index c7f0139d98c9..4701f1bdafb4 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -76,6 +76,7 @@ inductor_decompositions = get_decompositions(
         aten.native_layer_norm,
         aten.nll_loss2d_backward,
         aten.permute_copy,
+        aten.rrelu_with_noise_backward,
         aten._softmax,
         aten.sin_,
         aten.sqrt_,
@@ -1022,3 +1023,23 @@ def searchsorted_scalar(
         side=side,
         sorter=sorter,
     )[0]
+
+
+@register_decomposition(aten.rrelu_with_noise_functional)
+def rrelu_with_noise_functional(
+    self: torch.Tensor,
+    noise: torch.Tensor,
+    lower: float = 0.125,
+    upper: float = 0.3333333333333333,
+    training: bool = False,
+    generator: Optional[torch.Generator] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if training:
+        not_positive = self <= 0
+        r = aten.uniform(self, lower, upper, generator=generator)
+        output = torch.where(not_positive, self * r, self)
+        noise_out = torch.where(not_positive, r, 1)
+        return output, noise_out
+    else:
+        negative_slope = (lower + upper) / 2
+        return aten.leaky_relu(self, negative_slope), torch.Tensor()
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index c915dd4c6b6f..06c9e029d774 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -3734,6 +3734,28 @@ def meta__add_relu(self, other, alpha=1) -> Tensor:
     )
 
 
+@register_meta([aten.rrelu_with_noise])
+@out_wrapper()
+def meta_rrelu_with_noise(
+    self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None
+):
+    return torch.empty_like(self)
+
+
+@register_meta([aten.rrelu_with_noise_functional])
+def meta_rrelu_with_noise_functional(
+    self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None
+):
+    return torch.empty_like(self), torch.empty_like(noise)
+
+
+@register_meta([aten.rrelu_with_noise_])
+def meta_rrelu_with_noise_(
+    self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None
+):
+    return self
+
+
 @register_meta([aten.index_put.default, aten._unsafe_index_put.default])
 def meta_index_put(self, indices, values, accumulate=False):
     return torch.empty_like(self)
diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py
index 98692a704234..f986c77f8faa 100644
--- a/torchgen/native_function_generation.py
+++ b/torchgen/native_function_generation.py
@@ -448,10 +448,10 @@ def add_generated_native_functions(
             continue
 
         base_fn = (
-            d[SchemaKind.inplace]
-            if has_inplace
-            else d[SchemaKind.mutable]
+            d[SchemaKind.mutable]
             if has_mutable
+            else d[SchemaKind.inplace]
+            if has_inplace
             else d[SchemaKind.out]
             if has_out
             else d[SchemaKind.functional]