[aotd] capture rrelu_with_noise noise mutation in compile (#141867)

Rebase copy of the long-standing, already-approved PR https://github.com/pytorch/pytorch/pull/138503, which was blocked from landing by XLA build issues.

Opened a new PR with the same content (ghstack checkout was failing due to changed submodules).

Corresponding xla PR:
https://github.com/pytorch/xla/pull/8363
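
In training mode, `aten.rrelu_with_noise` writes the sampled negative slopes into its `noise` argument; before this change that mutation was dropped when the op was traced by AOTDispatch/`torch.compile`. A minimal sketch of the behavior this PR preserves (assumes a build that includes this change; it mirrors the new test added below):

```python
import torch

def f(x):
    noise = torch.ones_like(x)
    # In training mode the op overwrites `noise` with the sampled slopes.
    out = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.8, True)
    return out, noise

x = -torch.abs(torch.randn(4, 4))  # non-positive inputs, so every slope gets sampled
eager_out, eager_noise = f(x)
comp_out, comp_noise = torch.compile(f, fullgraph=True)(x)

# Both runs should observe the mutated noise (values sampled from [0.2, 0.8)).
assert torch.all(eager_noise < 1)
assert torch.all(comp_noise < 1)
```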

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141867
Approved by: https://github.com/bdhirsh
IvanKobzarev
2024-12-03 02:54:06 -08:00
committed by PyTorch MergeBot
parent 61dc5e9c0a
commit f85e238186
14 changed files with 149 additions and 102 deletions

View File

@@ -1 +1 @@
-2ec22641e390cda25ec7c61fcbce07507727d584
+e0684027996c60dcbb99fddf205385c208fb9ed7

View File

@@ -280,6 +280,44 @@ static void fill__Tensor_batch_rule(
  std::get<0>(self_and_other).copy_(std::get<1>(self_and_other));
}

static
std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
rrelu_with_noise_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    Tensor& noise,
    std::optional<int64_t> noise_bdim,
    const at::Scalar& lower,
    const at::Scalar& upper,
    bool training,
    std::optional<at::Generator> generator) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto noise_ = moveBatchDimToFront(self, noise_bdim);
  auto ret = at::rrelu_with_noise(self_, noise_, lower, upper, training, std::move(generator));
  return std::make_tuple(ret, 0, noise_, 0);
}

static Tensor rrelu_with_noise_batch(
    const Tensor& self,
    Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    std::optional<Generator> generator) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
  int64_t cur_level = maybe_layer->layerId();
  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
  auto [noise_value, noise_bdim] = unwrapTensorAtLevel(noise, cur_level);
  TORCH_CHECK(!noise_bdim.has_value(), "vmap: Attempted to vmap over 'noise' in torch.rrelu_with_noise. This is not supported.");
  auto res = rrelu_with_noise_batch_rule(self_value, self_bdim, noise_value, noise_bdim, lower, upper, training, std::move(generator));
  return makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
}

static std::tuple<Tensor, std::optional<int64_t>> log_sigmoid_backward_batch_rule(
    Tensor& grad, std::optional<int64_t> grad_bdim,
    Tensor& self, std::optional<int64_t> self_bdim,
@@ -421,7 +459,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
POINTWISE_BOXED(polygamma);
BINARY_SCALAR_2(sub, Tensor, Scalar);
BINARY_SCALAR_3(remainder, Tensor, Scalar, Scalar_Tensor);
-BINARY_POINTWISE(rrelu_with_noise);
BINARY_SCALAR_2(rsub, Tensor, Scalar);
BINARY_SCALAR_3_Tensor(special_xlog1py, other_scalar, self_scalar);
@@ -509,6 +546,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
VMAP_SUPPORT(masked_select_backward, masked_select_backward_batch_rule);
VMAP_SUPPORT2(fill_, Tensor, fill__Tensor_batch_rule);
+m.impl("rrelu_with_noise", rrelu_with_noise_batch);
}
} // namespace at::functorch
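
As a rough illustration (not part of the PR), the registration above lets `vmap` batch over `self` while the mutated `noise` stays unbatched; batching over `noise` itself is rejected by the `TORCH_CHECK` shown above:

```python
import torch
from torch.func import vmap

x = torch.randn(8, 4)
noise = torch.ones(4)  # unbatched noise buffer; a batch dim here would trip the check above

# eval mode (training=False) to keep the sketch deterministic
out = vmap(torch.ops.aten.rrelu_with_noise, in_dims=(0, None, None, None, None))(
    x, noise, 0.125, 1.0 / 3.0, False
)
print(out.shape)  # torch.Size([8, 4])
```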

View File

@@ -576,7 +576,7 @@ template <typename scalar_t>
inline void _rrelu_with_noise_train(
Tensor& output,
const Tensor& input,
-const Tensor& noise,
+Tensor& noise,
const Scalar& lower_,
const Scalar& upper_,
std::optional<Generator> generator) {
@@ -606,7 +606,7 @@ inline void _rrelu_with_noise_train(
}
Tensor& rrelu_with_noise_out_cpu(const Tensor& self,
-const Tensor& noise,
+Tensor& noise,
const Scalar& lower,
const Scalar& upper,
bool training,
@@ -629,7 +629,7 @@ Tensor& rrelu_with_noise_out_cpu(const Tensor& self,
Tensor rrelu_with_noise_cpu(
const Tensor& self,
-const Tensor& noise,
+Tensor& noise,
const Scalar& lower,
const Scalar& upper,
bool training,
@@ -641,7 +641,7 @@ Tensor rrelu_with_noise_cpu(
Tensor& rrelu_with_noise_cpu_(
Tensor& self,
-const Tensor& noise,
+Tensor& noise,
const Scalar& lower,
const Scalar& upper,
bool training,
@@ -670,12 +670,14 @@ Tensor rrelu_with_noise_backward(
Tensor rrelu(const Tensor & self, const Scalar& lower, const Scalar& upper, bool training, std::optional<Generator> generator) {
TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound")
-return at::rrelu_with_noise(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, std::move(generator));
+auto noise = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+return at::rrelu_with_noise(self, noise, lower, upper, training, std::move(generator));
}
Tensor & rrelu_(Tensor & self, const Scalar& lower, const Scalar& upper, bool training, std::optional<Generator> generator) {
TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound")
-return at::rrelu_with_noise_(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, std::move(generator));
+auto noise = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+return at::rrelu_with_noise_(self, noise, lower, upper, training, std::move(generator));
}
TORCH_IMPL_FUNC(threshold_out)(const Tensor& self, const Scalar& threshold, const Scalar& value, const Tensor& result) {

View File

@@ -71,7 +71,7 @@ template <typename scalar_t>
inline void _rrelu_with_noise_cuda_train(
Tensor& output,
const Tensor& input_,
-const Tensor& noise_,
+Tensor& noise_,
const Scalar& lower_,
const Scalar& upper_,
std::optional<Generator> generator) {
@@ -139,7 +139,7 @@ inline void _rrelu_with_noise_cuda_train(
}
Tensor& rrelu_with_noise_out_cuda(const Tensor& self,
-const Tensor& noise,
+Tensor& noise,
const Scalar& lower,
const Scalar& upper,
bool training,
@@ -173,7 +173,7 @@ Tensor& rrelu_with_noise_out_cuda(const Tensor& self,
Tensor rrelu_with_noise_cuda(
const Tensor& self,
-const Tensor& noise,
+Tensor& noise,
const Scalar& lower,
const Scalar& upper,
bool training,
@@ -184,7 +184,7 @@ Tensor rrelu_with_noise_cuda(
Tensor& rrelu_with_noise_cuda_(
Tensor& self,
-const Tensor& noise,
+Tensor& noise,
const Scalar& lower,
const Scalar& upper,
bool training,

View File

@@ -12018,19 +12018,20 @@
CUDA: log_sigmoid_backward_cuda
MPS: log_sigmoid_backward_mps
-- func: rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+- func: rrelu_with_noise.out(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
tags: nondeterministic_seeded
dispatch:
CPU: rrelu_with_noise_out_cpu
CUDA: rrelu_with_noise_out_cuda
-- func: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
+- func: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
python_module: nn
dispatch:
CPU: rrelu_with_noise_cpu
CUDA: rrelu_with_noise_cuda
tags: nondeterministic_seeded
+autogen: rrelu_with_noise_functional
- func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
python_module: nn
@@ -12038,7 +12039,7 @@
CompositeExplicitAutograd: rrelu_with_noise_backward
autogen: rrelu_with_noise_backward.out
-- func: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
+- func: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
python_module: nn
tags: nondeterministic_seeded
dispatch:
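
The `Tensor(b!)` annotations above declare that `noise` is mutated, and the new `autogen: rrelu_with_noise_functional` entry generates a functional variant that returns the updated noise instead of writing into its argument. A hedged sketch of the resulting calling convention (mirroring the new test further below):

```python
import torch

x = -torch.rand(3)     # non-positive inputs so the slopes are sampled
noise = torch.ones(3)

# Mutating form: `noise` is overwritten in place.
out = torch.ops.aten.rrelu_with_noise(x, noise.clone(), 0.125, 1.0 / 3.0, True)

# Functional form: the input `noise` is untouched; the updated values are returned.
out_f, noise_out = torch.ops.aten.rrelu_with_noise_functional(x, noise, 0.125, 1.0 / 3.0, True)
assert torch.equal(noise, torch.ones(3))
assert torch.all(noise_out < 1)
```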

View File

@@ -1125,6 +1125,10 @@ aten::resize_as_sparse_
aten::row_indices
aten::row_indices_copy
aten::row_indices_copy.out
+aten::rrelu_with_noise
+aten::rrelu_with_noise.out
+aten::rrelu_with_noise_
+aten::rrelu_with_noise_functional
aten::scalar_tensor
aten::scalar_tensor.out
aten::scatter.reduce

View File

@@ -137,6 +137,7 @@ ALLOW_LIST = [
("_quantized::wrapped_linear_prepack", datetime.date(2024, 12, 31)),
("_quantized::wrapped_linear_prepacked", datetime.date(2024, 12, 31)),
("_quantized::wrapped_quantized_linear_prepacked", datetime.date(2024, 12, 31)),
+("aten::rrelu_with_noise", datetime.date(2024, 12, 31)),
]
ALLOW_LIST_COMPILED = [

View File

@@ -6327,6 +6327,42 @@ metadata incorrectly.
        out.sum().backward()
        self.assertEqual(ref_x.grad, x.grad)

    def test_rrelu_with_noise_mutation(self):
        def fn_functional(x):
            noise = torch.ones_like(x)
            result, noise_out = torch.ops.aten.rrelu_with_noise_functional(
                x, noise, 0.2, 0.8, True
            )
            return result, noise_out

        def fn_mutation(x):
            noise = torch.ones_like(x)
            result = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.8, True)
            return result, noise

        def fn_inplace(x):
            noise = torch.ones_like(x, requires_grad=False)
            torch.ops.aten.rrelu_with_noise_(x, noise, 0.2, 0.8, True)
            return x, noise

        def _test_fn(fn, check_backward=True):
            x = -torch.abs(torch.randn(4, 4, dtype=torch.bfloat16, requires_grad=True))
            ref_y, ref_noise = fn(x)
            self.assertTrue(torch.all(ref_noise < torch.ones_like(ref_noise)).item())

            comp_y, comp_noise = torch.compile(fn, backend="inductor", fullgraph=True)(
                x
            )
            if check_backward:
                comp_y.sum().backward()
            self.assertTrue(torch.all(comp_noise < torch.ones_like(comp_noise)).item())

        _test_fn(fn_functional)
        _test_fn(fn_mutation)
        _test_fn(fn_inplace, check_backward=False)
# entries in here don't work and need to be fixed.
# Each one of these is a bug (or needs to be investigated)

View File

@@ -656,50 +656,6 @@ class TestDecomp(TestCase):
for dim in (-1, 0, 1):
self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))
def test_rrelu_with_noise(self, device):
# rrelu_with_noise behavior depends on a) whether elements in the input
# are <= 0, and b) whether we're in training mode. Cover all cases:
dtype = torch.float64
x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device)
lower = 1.0
upper = 4.0
training = False
torch.manual_seed(123)
noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)
torch.manual_seed(123)
noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
res = torch._decomp.decompositions.rrelu_with_noise(
x,
noise_res,
lower,
upper,
training,
)
self.assertEqual(ref, res)
self.assertEqual(noise_ref, noise_res)
# Now with training=True:
training = True
torch.manual_seed(123)
noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)
torch.manual_seed(123)
noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
res = torch._decomp.decompositions.rrelu_with_noise(
x,
noise_res,
lower,
upper,
training,
)
self.assertEqual(ref, res)
self.assertEqual(noise_ref, noise_res)
@suppress_warnings
@tf32_off()
# only tests RNNs since we have py dispsatcher decomps for them

View File

@@ -2182,13 +2182,17 @@
result0: at::where(self_p >= 0, grad_output_t, grad_output_t * weight_p + grad_output_p * weight_t)
result1: at::where(self_p >= 0, at::zeros({}, self_p.options()), grad_output_p * self_t + grad_output_t * self_p)
-- name: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
+- name: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
result: auto_element_wise
-- name: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
+- name: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true)
+- name: rrelu_with_noise_functional(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> (Tensor, Tensor noise_out)
+noise: non_differentiable
+self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
- name: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _softmax_backward_data(grad, result, dim, self.scalar_type())
result: result * (self_t - logsumexp_jvp(self_p, self_t, {dim}, true))
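
The entries above route the backward through `rrelu_with_noise_backward`, which scales the incoming gradient by the sampled slope wherever the input was non-positive. A small sketch with degenerate bounds so the slope is deterministic (hedged, not from the PR):

```python
import torch

x = torch.tensor([-2.0, 3.0], requires_grad=True)
noise = torch.ones(2)

# lower == upper == 0.25, so the "sampled" slope is exactly 0.25
y = torch.ops.aten.rrelu_with_noise(x, noise, 0.25, 0.25, True)
y.sum().backward()

print(noise)   # tensor([0.2500, 1.0000]) -- slope written for the non-positive entry
print(x.grad)  # tensor([0.2500, 1.0000]) -- gradient scaled by that slope
```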

View File

@@ -305,44 +305,6 @@ def _prelu_kernel_backward(
return (input_grad, weight_grad)
@register_decomposition(aten.rrelu_with_noise)
@aten.rrelu_with_noise.default.py_impl(DispatchKey.Autograd)
@out_wrapper()
@pw_cast_for_opmath
def rrelu_with_noise(
self: Tensor,
noise: Tensor,
lower: float = 0.125,
upper: float = 0.3333333333333333,
training: bool = False,
generator: Optional[torch.Generator] = None,
) -> Tensor:
assert generator is None
if training:
not_positive = self <= 0
r = aten.uniform(self, lower, upper)
output = torch.where(not_positive, self * r, self)
noise.copy_(torch.where(not_positive, r, 1))
return output
else:
negative_slope = (lower + upper) / 2
return aten.leaky_relu(self, negative_slope)
@register_decomposition(aten.rrelu_with_noise_)
@aten.rrelu_with_noise_.default.py_impl(DispatchKey.Autograd)
@pw_cast_for_opmath
def rrelu_with_noise_(
self: Tensor,
noise: Tensor,
lower: float = 0.125,
upper: float = 0.3333333333333333,
training: bool = False,
generator: Optional[torch.Generator] = None,
) -> Tensor:
return self.copy_(rrelu_with_noise(self, noise, lower, upper, training, generator))
@register_decomposition(aten.rrelu_with_noise_backward)
@out_wrapper()
@pw_cast_for_opmath

View File

@@ -76,6 +76,7 @@ inductor_decompositions = get_decompositions(
aten.native_layer_norm,
aten.nll_loss2d_backward,
aten.permute_copy,
+aten.rrelu_with_noise_backward,
aten._softmax,
aten.sin_,
aten.sqrt_,
@@ -1022,3 +1023,23 @@ def searchsorted_scalar(
        side=side,
        sorter=sorter,
    )[0]


@register_decomposition(aten.rrelu_with_noise_functional)
def rrelu_with_noise_functional(
    self: torch.Tensor,
    noise: torch.Tensor,
    lower: float = 0.125,
    upper: float = 0.3333333333333333,
    training: bool = False,
    generator: Optional[torch.Generator] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if training:
        not_positive = self <= 0
        r = aten.uniform(self, lower, upper, generator=generator)
        output = torch.where(not_positive, self * r, self)
        noise_out = torch.where(not_positive, r, 1)
        return output, noise_out
    else:
        negative_slope = (lower + upper) / 2
        return aten.leaky_relu(self, negative_slope), torch.Tensor()
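
A quick sanity check of the decomposition above: with `training=False` it ignores `noise` entirely and reduces to `leaky_relu` with slope `(lower + upper) / 2`, returning an empty tensor in place of the untouched noise. A hedged comparison against the eager op:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 1.5])
noise = torch.ones_like(x)

eager = torch.ops.aten.rrelu_with_noise(x, noise.clone(), 0.25, 0.75, False)
expected = F.leaky_relu(x, negative_slope=0.5)  # (0.25 + 0.75) / 2

assert torch.allclose(eager, expected)
```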

View File

@@ -3734,6 +3734,28 @@ def meta__add_relu(self, other, alpha=1) -> Tensor:
    )


@register_meta([aten.rrelu_with_noise])
@out_wrapper()
def meta_rrelu_with_noise(
    self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None
):
    return torch.empty_like(self)


@register_meta([aten.rrelu_with_noise_functional])
def meta_rrelu_with_noise_functional(
    self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None
):
    return torch.empty_like(self), torch.empty_like(noise)


@register_meta([aten.rrelu_with_noise_])
def meta_rrelu_with_noise_(
    self, lower=0.125, upper=0.3333333333333333, training=False, generator=None
):
    return self


@register_meta([aten.index_put.default, aten._unsafe_index_put.default])
def meta_index_put(self, indices, values, accumulate=False):
    return torch.empty_like(self)
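
The meta registrations above only propagate shapes and dtypes, which is what FakeTensor tracing under `torch.compile` needs. A small sketch on the `meta` device (hedged; assumes this build):

```python
import torch

x = torch.empty(5, 5, device="meta")
noise = torch.empty(5, 5, device="meta")

out, noise_out = torch.ops.aten.rrelu_with_noise_functional(x, noise, 0.125, 1.0 / 3.0, True)
print(out.shape, out.device)              # torch.Size([5, 5]) meta
print(noise_out.shape, noise_out.device)  # torch.Size([5, 5]) meta
```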

View File

@@ -448,10 +448,10 @@ def add_generated_native_functions(
continue
base_fn = (
d[SchemaKind.inplace]
if has_inplace
else d[SchemaKind.mutable]
d[SchemaKind.mutable]
if has_mutable
else d[SchemaKind.inplace]
if has_inplace
else d[SchemaKind.out]
if has_out
else d[SchemaKind.functional]