mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[aotd] capture rrelu_with_noise noise mutation in compile (#141867)
Rebase-copy of long standing already approved PR https://github.com/pytorch/pytorch/pull/138503 that was blocked on landing by xla build issues. Got a new PR with the same content (ghstack checkout was failing due to changed submodules) Corresponding xla PR: https://github.com/pytorch/xla/pull/8363 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141867 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
61dc5e9c0a
commit
f85e238186
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
@ -1 +1 @@
|
||||
2ec22641e390cda25ec7c61fcbce07507727d584
|
||||
e0684027996c60dcbb99fddf205385c208fb9ed7
|
||||
|
@ -280,6 +280,44 @@ static void fill__Tensor_batch_rule(
|
||||
std::get<0>(self_and_other).copy_(std::get<1>(self_and_other));
|
||||
}
|
||||
|
||||
static
|
||||
std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
|
||||
rrelu_with_noise_batch_rule(
|
||||
const Tensor& self,
|
||||
std::optional<int64_t> self_bdim,
|
||||
Tensor& noise,
|
||||
std::optional<int64_t> noise_bdim,
|
||||
const at::Scalar& lower,
|
||||
const at::Scalar& upper,
|
||||
bool training,
|
||||
std::optional<at::Generator> generator) {
|
||||
|
||||
auto self_ = moveBatchDimToFront(self, self_bdim);
|
||||
auto noise_ = moveBatchDimToFront(self, noise_bdim);
|
||||
|
||||
auto ret = at::rrelu_with_noise(self_, noise_, lower, upper, training, std::move(generator));
|
||||
|
||||
return std::make_tuple(ret, 0, noise_, 0);
|
||||
}
|
||||
|
||||
static Tensor rrelu_with_noise_batch(
|
||||
const Tensor& self,
|
||||
Tensor& noise,
|
||||
const Scalar& lower,
|
||||
const Scalar& upper,
|
||||
bool training,
|
||||
std::optional<Generator> generator) {
|
||||
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
|
||||
auto maybe_layer = maybeCurrentDynamicLayer();
|
||||
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
|
||||
int64_t cur_level = maybe_layer->layerId();
|
||||
auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
|
||||
auto [noise_value, noise_bdim] = unwrapTensorAtLevel(noise, cur_level);
|
||||
TORCH_CHECK(!noise_bdim.has_value(), "vmap: Attempted to vmap over 'noise' in torch.rrelu_with_noise. This is not supported.");
|
||||
auto res = rrelu_with_noise_batch_rule(self_value, self_bdim, noise_value, noise_bdim, lower, upper, training, std::move(generator));
|
||||
return makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
|
||||
}
|
||||
|
||||
static std::tuple<Tensor, std::optional<int64_t>> log_sigmoid_backward_batch_rule(
|
||||
Tensor& grad, std::optional<int64_t> grad_bdim,
|
||||
Tensor& self, std::optional<int64_t> self_bdim,
|
||||
@ -421,7 +459,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
|
||||
POINTWISE_BOXED(polygamma);
|
||||
BINARY_SCALAR_2(sub, Tensor, Scalar);
|
||||
BINARY_SCALAR_3(remainder, Tensor, Scalar, Scalar_Tensor);
|
||||
BINARY_POINTWISE(rrelu_with_noise);
|
||||
BINARY_SCALAR_2(rsub, Tensor, Scalar);
|
||||
|
||||
BINARY_SCALAR_3_Tensor(special_xlog1py, other_scalar, self_scalar);
|
||||
@ -509,6 +546,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
|
||||
VMAP_SUPPORT(masked_select_backward, masked_select_backward_batch_rule);
|
||||
|
||||
VMAP_SUPPORT2(fill_, Tensor, fill__Tensor_batch_rule);
|
||||
m.impl("rrelu_with_noise", rrelu_with_noise_batch);
|
||||
}
|
||||
|
||||
} // namespace at::functorch
|
||||
|
@ -576,7 +576,7 @@ template <typename scalar_t>
|
||||
inline void _rrelu_with_noise_train(
|
||||
Tensor& output,
|
||||
const Tensor& input,
|
||||
const Tensor& noise,
|
||||
Tensor& noise,
|
||||
const Scalar& lower_,
|
||||
const Scalar& upper_,
|
||||
std::optional<Generator> generator) {
|
||||
@ -606,7 +606,7 @@ inline void _rrelu_with_noise_train(
|
||||
}
|
||||
|
||||
Tensor& rrelu_with_noise_out_cpu(const Tensor& self,
|
||||
const Tensor& noise,
|
||||
Tensor& noise,
|
||||
const Scalar& lower,
|
||||
const Scalar& upper,
|
||||
bool training,
|
||||
@ -629,7 +629,7 @@ Tensor& rrelu_with_noise_out_cpu(const Tensor& self,
|
||||
|
||||
Tensor rrelu_with_noise_cpu(
|
||||
const Tensor& self,
|
||||
const Tensor& noise,
|
||||
Tensor& noise,
|
||||
const Scalar& lower,
|
||||
const Scalar& upper,
|
||||
bool training,
|
||||
@ -641,7 +641,7 @@ Tensor rrelu_with_noise_cpu(
|
||||
|
||||
Tensor& rrelu_with_noise_cpu_(
|
||||
Tensor& self,
|
||||
const Tensor& noise,
|
||||
Tensor& noise,
|
||||
const Scalar& lower,
|
||||
const Scalar& upper,
|
||||
bool training,
|
||||
@ -670,12 +670,14 @@ Tensor rrelu_with_noise_backward(
|
||||
|
||||
Tensor rrelu(const Tensor & self, const Scalar& lower, const Scalar& upper, bool training, std::optional<Generator> generator) {
|
||||
TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound")
|
||||
return at::rrelu_with_noise(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, std::move(generator));
|
||||
auto noise = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
return at::rrelu_with_noise(self, noise, lower, upper, training, std::move(generator));
|
||||
}
|
||||
|
||||
Tensor & rrelu_(Tensor & self, const Scalar& lower, const Scalar& upper, bool training, std::optional<Generator> generator) {
|
||||
TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound")
|
||||
return at::rrelu_with_noise_(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, std::move(generator));
|
||||
auto noise = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
return at::rrelu_with_noise_(self, noise, lower, upper, training, std::move(generator));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(threshold_out)(const Tensor& self, const Scalar& threshold, const Scalar& value, const Tensor& result) {
|
||||
|
@ -71,7 +71,7 @@ template <typename scalar_t>
|
||||
inline void _rrelu_with_noise_cuda_train(
|
||||
Tensor& output,
|
||||
const Tensor& input_,
|
||||
const Tensor& noise_,
|
||||
Tensor& noise_,
|
||||
const Scalar& lower_,
|
||||
const Scalar& upper_,
|
||||
std::optional<Generator> generator) {
|
||||
@ -139,7 +139,7 @@ inline void _rrelu_with_noise_cuda_train(
|
||||
}
|
||||
|
||||
Tensor& rrelu_with_noise_out_cuda(const Tensor& self,
|
||||
const Tensor& noise,
|
||||
Tensor& noise,
|
||||
const Scalar& lower,
|
||||
const Scalar& upper,
|
||||
bool training,
|
||||
@ -173,7 +173,7 @@ Tensor& rrelu_with_noise_out_cuda(const Tensor& self,
|
||||
|
||||
Tensor rrelu_with_noise_cuda(
|
||||
const Tensor& self,
|
||||
const Tensor& noise,
|
||||
Tensor& noise,
|
||||
const Scalar& lower,
|
||||
const Scalar& upper,
|
||||
bool training,
|
||||
@ -184,7 +184,7 @@ Tensor rrelu_with_noise_cuda(
|
||||
|
||||
Tensor& rrelu_with_noise_cuda_(
|
||||
Tensor& self,
|
||||
const Tensor& noise,
|
||||
Tensor& noise,
|
||||
const Scalar& lower,
|
||||
const Scalar& upper,
|
||||
bool training,
|
||||
|
@ -12018,19 +12018,20 @@
|
||||
CUDA: log_sigmoid_backward_cuda
|
||||
MPS: log_sigmoid_backward_mps
|
||||
|
||||
- func: rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
|
||||
- func: rrelu_with_noise.out(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: nn
|
||||
tags: nondeterministic_seeded
|
||||
dispatch:
|
||||
CPU: rrelu_with_noise_out_cpu
|
||||
CUDA: rrelu_with_noise_out_cuda
|
||||
|
||||
- func: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
|
||||
- func: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
|
||||
python_module: nn
|
||||
dispatch:
|
||||
CPU: rrelu_with_noise_cpu
|
||||
CUDA: rrelu_with_noise_cuda
|
||||
tags: nondeterministic_seeded
|
||||
autogen: rrelu_with_noise_functional
|
||||
|
||||
- func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
|
||||
python_module: nn
|
||||
@ -12038,7 +12039,7 @@
|
||||
CompositeExplicitAutograd: rrelu_with_noise_backward
|
||||
autogen: rrelu_with_noise_backward.out
|
||||
|
||||
- func: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
|
||||
- func: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
|
||||
python_module: nn
|
||||
tags: nondeterministic_seeded
|
||||
dispatch:
|
||||
|
@ -1125,6 +1125,10 @@ aten::resize_as_sparse_
|
||||
aten::row_indices
|
||||
aten::row_indices_copy
|
||||
aten::row_indices_copy.out
|
||||
aten::rrelu_with_noise
|
||||
aten::rrelu_with_noise.out
|
||||
aten::rrelu_with_noise_
|
||||
aten::rrelu_with_noise_functional
|
||||
aten::scalar_tensor
|
||||
aten::scalar_tensor.out
|
||||
aten::scatter.reduce
|
||||
|
@ -137,6 +137,7 @@ ALLOW_LIST = [
|
||||
("_quantized::wrapped_linear_prepack", datetime.date(2024, 12, 31)),
|
||||
("_quantized::wrapped_linear_prepacked", datetime.date(2024, 12, 31)),
|
||||
("_quantized::wrapped_quantized_linear_prepacked", datetime.date(2024, 12, 31)),
|
||||
("aten::rrelu_with_noise", datetime.date(2024, 12, 31)),
|
||||
]
|
||||
|
||||
ALLOW_LIST_COMPILED = [
|
||||
|
@ -6327,6 +6327,42 @@ metadata incorrectly.
|
||||
out.sum().backward()
|
||||
self.assertEqual(ref_x.grad, x.grad)
|
||||
|
||||
def test_rrelu_with_noise_mutation(self):
|
||||
def fn_functional(x):
|
||||
noise = torch.ones_like(x)
|
||||
result, noise_out = torch.ops.aten.rrelu_with_noise_functional(
|
||||
x, noise, 0.2, 0.8, True
|
||||
)
|
||||
return result, noise_out
|
||||
|
||||
def fn_mutation(x):
|
||||
noise = torch.ones_like(x)
|
||||
result = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.8, True)
|
||||
return result, noise
|
||||
|
||||
def fn_inplace(x):
|
||||
noise = torch.ones_like(x, requires_grad=False)
|
||||
torch.ops.aten.rrelu_with_noise_(x, noise, 0.2, 0.8, True)
|
||||
return x, noise
|
||||
|
||||
def _test_fn(fn, check_backward=True):
|
||||
x = -torch.abs(torch.randn(4, 4, dtype=torch.bfloat16, requires_grad=True))
|
||||
|
||||
ref_y, ref_noise = fn(x)
|
||||
self.assertTrue(torch.all(ref_noise < torch.ones_like(ref_noise)).item())
|
||||
|
||||
comp_y, comp_noise = torch.compile(fn, backend="inductor", fullgraph=True)(
|
||||
x
|
||||
)
|
||||
|
||||
if check_backward:
|
||||
comp_y.sum().backward()
|
||||
self.assertTrue(torch.all(comp_noise < torch.ones_like(comp_noise)).item())
|
||||
|
||||
_test_fn(fn_functional)
|
||||
_test_fn(fn_mutation)
|
||||
_test_fn(fn_inplace, check_backward=False)
|
||||
|
||||
|
||||
# entries in here don't work and need to be fixed.
|
||||
# Each one of these is a bug (or needs to be investigated)
|
||||
|
@ -656,50 +656,6 @@ class TestDecomp(TestCase):
|
||||
for dim in (-1, 0, 1):
|
||||
self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))
|
||||
|
||||
def test_rrelu_with_noise(self, device):
|
||||
# rrelu_with_noise behavior depends on a) whether elements in the input
|
||||
# are <= 0, and b) whether we're in training mode. Cover all cases:
|
||||
dtype = torch.float64
|
||||
x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device)
|
||||
lower = 1.0
|
||||
upper = 4.0
|
||||
training = False
|
||||
|
||||
torch.manual_seed(123)
|
||||
noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
|
||||
ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)
|
||||
|
||||
torch.manual_seed(123)
|
||||
noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
|
||||
res = torch._decomp.decompositions.rrelu_with_noise(
|
||||
x,
|
||||
noise_res,
|
||||
lower,
|
||||
upper,
|
||||
training,
|
||||
)
|
||||
self.assertEqual(ref, res)
|
||||
self.assertEqual(noise_ref, noise_res)
|
||||
|
||||
# Now with training=True:
|
||||
training = True
|
||||
|
||||
torch.manual_seed(123)
|
||||
noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
|
||||
ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)
|
||||
|
||||
torch.manual_seed(123)
|
||||
noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
|
||||
res = torch._decomp.decompositions.rrelu_with_noise(
|
||||
x,
|
||||
noise_res,
|
||||
lower,
|
||||
upper,
|
||||
training,
|
||||
)
|
||||
self.assertEqual(ref, res)
|
||||
self.assertEqual(noise_ref, noise_res)
|
||||
|
||||
@suppress_warnings
|
||||
@tf32_off()
|
||||
# only tests RNNs since we have py dispsatcher decomps for them
|
||||
|
@ -2182,13 +2182,17 @@
|
||||
result0: at::where(self_p >= 0, grad_output_t, grad_output_t * weight_p + grad_output_p * weight_t)
|
||||
result1: at::where(self_p >= 0, at::zeros({}, self_p.options()), grad_output_p * self_t + grad_output_t * self_p)
|
||||
|
||||
- name: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
|
||||
- name: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
|
||||
self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
|
||||
result: auto_element_wise
|
||||
|
||||
- name: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
|
||||
- name: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
|
||||
self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true)
|
||||
|
||||
- name: rrelu_with_noise_functional(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> (Tensor, Tensor noise_out)
|
||||
noise: non_differentiable
|
||||
self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
|
||||
|
||||
- name: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
|
||||
self: _softmax_backward_data(grad, result, dim, self.scalar_type())
|
||||
result: result * (self_t - logsumexp_jvp(self_p, self_t, {dim}, true))
|
||||
|
@ -305,44 +305,6 @@ def _prelu_kernel_backward(
|
||||
return (input_grad, weight_grad)
|
||||
|
||||
|
||||
@register_decomposition(aten.rrelu_with_noise)
|
||||
@aten.rrelu_with_noise.default.py_impl(DispatchKey.Autograd)
|
||||
@out_wrapper()
|
||||
@pw_cast_for_opmath
|
||||
def rrelu_with_noise(
|
||||
self: Tensor,
|
||||
noise: Tensor,
|
||||
lower: float = 0.125,
|
||||
upper: float = 0.3333333333333333,
|
||||
training: bool = False,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Tensor:
|
||||
assert generator is None
|
||||
if training:
|
||||
not_positive = self <= 0
|
||||
r = aten.uniform(self, lower, upper)
|
||||
output = torch.where(not_positive, self * r, self)
|
||||
noise.copy_(torch.where(not_positive, r, 1))
|
||||
return output
|
||||
else:
|
||||
negative_slope = (lower + upper) / 2
|
||||
return aten.leaky_relu(self, negative_slope)
|
||||
|
||||
|
||||
@register_decomposition(aten.rrelu_with_noise_)
|
||||
@aten.rrelu_with_noise_.default.py_impl(DispatchKey.Autograd)
|
||||
@pw_cast_for_opmath
|
||||
def rrelu_with_noise_(
|
||||
self: Tensor,
|
||||
noise: Tensor,
|
||||
lower: float = 0.125,
|
||||
upper: float = 0.3333333333333333,
|
||||
training: bool = False,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Tensor:
|
||||
return self.copy_(rrelu_with_noise(self, noise, lower, upper, training, generator))
|
||||
|
||||
|
||||
@register_decomposition(aten.rrelu_with_noise_backward)
|
||||
@out_wrapper()
|
||||
@pw_cast_for_opmath
|
||||
|
@ -76,6 +76,7 @@ inductor_decompositions = get_decompositions(
|
||||
aten.native_layer_norm,
|
||||
aten.nll_loss2d_backward,
|
||||
aten.permute_copy,
|
||||
aten.rrelu_with_noise_backward,
|
||||
aten._softmax,
|
||||
aten.sin_,
|
||||
aten.sqrt_,
|
||||
@ -1022,3 +1023,23 @@ def searchsorted_scalar(
|
||||
side=side,
|
||||
sorter=sorter,
|
||||
)[0]
|
||||
|
||||
|
||||
@register_decomposition(aten.rrelu_with_noise_functional)
|
||||
def rrelu_with_noise_functional(
|
||||
self: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
lower: float = 0.125,
|
||||
upper: float = 0.3333333333333333,
|
||||
training: bool = False,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if training:
|
||||
not_positive = self <= 0
|
||||
r = aten.uniform(self, lower, upper, generator=generator)
|
||||
output = torch.where(not_positive, self * r, self)
|
||||
noise_out = torch.where(not_positive, r, 1)
|
||||
return output, noise_out
|
||||
else:
|
||||
negative_slope = (lower + upper) / 2
|
||||
return aten.leaky_relu(self, negative_slope), torch.Tensor()
|
||||
|
@ -3734,6 +3734,28 @@ def meta__add_relu(self, other, alpha=1) -> Tensor:
|
||||
)
|
||||
|
||||
|
||||
@register_meta([aten.rrelu_with_noise])
|
||||
@out_wrapper()
|
||||
def meta_rrelu_with_noise(
|
||||
self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None
|
||||
):
|
||||
return torch.empty_like(self)
|
||||
|
||||
|
||||
@register_meta([aten.rrelu_with_noise_functional])
|
||||
def meta_rrelu_with_noise_functional(
|
||||
self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None
|
||||
):
|
||||
return torch.empty_like(self), torch.empty_like(noise)
|
||||
|
||||
|
||||
@register_meta([aten.rrelu_with_noise_])
|
||||
def meta_rrelu_with_noise_(
|
||||
self, lower=0.125, upper=0.3333333333333333, training=False, generator=None
|
||||
):
|
||||
return self
|
||||
|
||||
|
||||
@register_meta([aten.index_put.default, aten._unsafe_index_put.default])
|
||||
def meta_index_put(self, indices, values, accumulate=False):
|
||||
return torch.empty_like(self)
|
||||
|
@ -448,10 +448,10 @@ def add_generated_native_functions(
|
||||
continue
|
||||
|
||||
base_fn = (
|
||||
d[SchemaKind.inplace]
|
||||
if has_inplace
|
||||
else d[SchemaKind.mutable]
|
||||
d[SchemaKind.mutable]
|
||||
if has_mutable
|
||||
else d[SchemaKind.inplace]
|
||||
if has_inplace
|
||||
else d[SchemaKind.out]
|
||||
if has_out
|
||||
else d[SchemaKind.functional]
|
||||
|
Reference in New Issue
Block a user