Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[aotd] capture rrelu_with_noise noise mutation in compile (#141867)
Rebase copy of the long-standing, already-approved PR https://github.com/pytorch/pytorch/pull/138503, which was blocked from landing by XLA build issues. A new PR with the same content was needed because ghstack checkout was failing due to changed submodules. Corresponding XLA PR: https://github.com/pytorch/xla/pull/8363

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141867
Approved by: https://github.com/bdhirsh
Committed by: PyTorch MergeBot
Parent: 61dc5e9c0a
Commit: f85e238186
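
For context, a minimal sketch (mine, not part of the commit) of the behavior the commit fixes, mirroring the test added below: in training mode rrelu_with_noise writes the sampled negative slopes into `noise`, and torch.compile previously did not capture that mutation.

import torch

# With all-negative inputs, the training path writes a slope r in
# (lower, upper) into every element of `noise`. Before this change, a
# compiled caller could observe a stale all-ones `noise` buffer.
def fn(x):
    noise = torch.ones_like(x)
    y = torch.ops.aten.rrelu_with_noise(x, noise, 0.125, 1.0 / 3.0, True)
    return y, noise

x = -torch.rand(4) - 0.1                   # strictly negative inputs
y, noise = torch.compile(fn, fullgraph=True)(x)
assert bool((noise < 1).all())             # noise reflects the sampled slopes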
.github/ci_commit_pins/xla.txt (vendored, 2 lines changed):
@@ -1 +1 @@
-2ec22641e390cda25ec7c61fcbce07507727d584
+e0684027996c60dcbb99fddf205385c208fb9ed7

functorch batch rules:
@@ -280,6 +280,44 @@ static void fill__Tensor_batch_rule(
   std::get<0>(self_and_other).copy_(std::get<1>(self_and_other));
 }
 
+static
+std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
+rrelu_with_noise_batch_rule(
+    const Tensor& self,
+    std::optional<int64_t> self_bdim,
+    Tensor& noise,
+    std::optional<int64_t> noise_bdim,
+    const at::Scalar& lower,
+    const at::Scalar& upper,
+    bool training,
+    std::optional<at::Generator> generator) {
+
+  auto self_ = moveBatchDimToFront(self, self_bdim);
+  auto noise_ = moveBatchDimToFront(self, noise_bdim);
+
+  auto ret = at::rrelu_with_noise(self_, noise_, lower, upper, training, std::move(generator));
+
+  return std::make_tuple(ret, 0, noise_, 0);
+}
+
+static Tensor rrelu_with_noise_batch(
+    const Tensor& self,
+    Tensor& noise,
+    const Scalar& lower,
+    const Scalar& upper,
+    bool training,
+    std::optional<Generator> generator) {
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
+  int64_t cur_level = maybe_layer->layerId();
+  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
+  auto [noise_value, noise_bdim] = unwrapTensorAtLevel(noise, cur_level);
+  TORCH_CHECK(!noise_bdim.has_value(), "vmap: Attempted to vmap over 'noise' in torch.rrelu_with_noise. This is not supported.");
+  auto res = rrelu_with_noise_batch_rule(self_value, self_bdim, noise_value, noise_bdim, lower, upper, training, std::move(generator));
+  return makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
+}
+
 static std::tuple<Tensor, std::optional<int64_t>> log_sigmoid_backward_batch_rule(
     Tensor& grad, std::optional<int64_t> grad_bdim,
     Tensor& self, std::optional<int64_t> self_bdim,

@@ -421,7 +459,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   POINTWISE_BOXED(polygamma);
   BINARY_SCALAR_2(sub, Tensor, Scalar);
   BINARY_SCALAR_3(remainder, Tensor, Scalar, Scalar_Tensor);
-  BINARY_POINTWISE(rrelu_with_noise);
   BINARY_SCALAR_2(rsub, Tensor, Scalar);
 
   BINARY_SCALAR_3_Tensor(special_xlog1py, other_scalar, self_scalar);

@@ -509,6 +546,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   VMAP_SUPPORT(masked_select_backward, masked_select_backward_batch_rule);
 
   VMAP_SUPPORT2(fill_, Tensor, fill__Tensor_batch_rule);
+  m.impl("rrelu_with_noise", rrelu_with_noise_batch);
 }
 
 } // namespace at::functorch
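
A hedged sketch (not part of the commit) of what the new plumbing enables: vmapping over the batched input, while vmapping over `noise` itself stays rejected by the TORCH_CHECK above. training=False is used to sidestep RNG-under-vmap questions.

import torch
from torch.func import vmap

x = -torch.rand(3, 5)      # batch of 3 inputs
noise = torch.empty(5)     # shared, unbatched noise buffer

# The batch rule moves the batch dim to the front and redispatches to
# at::rrelu_with_noise for the whole batch.
out = vmap(
    lambda xi: torch.ops.aten.rrelu_with_noise(xi, noise, 0.125, 1.0 / 3.0, False)
)(x)
print(out.shape)           # torch.Size([3, 5])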

CPU activation kernels:
@@ -576,7 +576,7 @@ template <typename scalar_t>
 inline void _rrelu_with_noise_train(
     Tensor& output,
     const Tensor& input,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower_,
     const Scalar& upper_,
     std::optional<Generator> generator) {

@@ -606,7 +606,7 @@ inline void _rrelu_with_noise_train(
 }
 
 Tensor& rrelu_with_noise_out_cpu(const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,

@@ -629,7 +629,7 @@ Tensor& rrelu_with_noise_out_cpu(const Tensor& self,
 
 Tensor rrelu_with_noise_cpu(
     const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,

@@ -641,7 +641,7 @@ Tensor rrelu_with_noise_cpu(
 
 Tensor& rrelu_with_noise_cpu_(
     Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,

@@ -670,12 +670,14 @@ Tensor rrelu_with_noise_backward(
 
 Tensor rrelu(const Tensor & self, const Scalar& lower, const Scalar& upper, bool training, std::optional<Generator> generator) {
   TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound")
-  return at::rrelu_with_noise(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, std::move(generator));
+  auto noise = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  return at::rrelu_with_noise(self, noise, lower, upper, training, std::move(generator));
 }
 
 Tensor & rrelu_(Tensor & self, const Scalar& lower, const Scalar& upper, bool training, std::optional<Generator> generator) {
   TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound")
-  return at::rrelu_with_noise_(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, std::move(generator));
+  auto noise = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  return at::rrelu_with_noise_(self, noise, lower, upper, training, std::move(generator));
 }
 
 TORCH_IMPL_FUNC(threshold_out)(const Tensor& self, const Scalar& threshold, const Scalar& value, const Tensor& result) {

CUDA activation kernels:
@@ -71,7 +71,7 @@ template <typename scalar_t>
 inline void _rrelu_with_noise_cuda_train(
     Tensor& output,
     const Tensor& input_,
-    const Tensor& noise_,
+    Tensor& noise_,
     const Scalar& lower_,
     const Scalar& upper_,
     std::optional<Generator> generator) {

@@ -139,7 +139,7 @@ inline void _rrelu_with_noise_cuda_train(
 }
 
 Tensor& rrelu_with_noise_out_cuda(const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,

@@ -173,7 +173,7 @@ Tensor& rrelu_with_noise_out_cuda(const Tensor& self,
 
 Tensor rrelu_with_noise_cuda(
     const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,

@@ -184,7 +184,7 @@ Tensor rrelu_with_noise_cuda(
 
 Tensor& rrelu_with_noise_cuda_(
     Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,

aten/src/ATen/native/native_functions.yaml:
@@ -12018,19 +12018,20 @@
     CUDA: log_sigmoid_backward_cuda
     MPS: log_sigmoid_backward_mps
 
-- func: rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+- func: rrelu_with_noise.out(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
   tags: nondeterministic_seeded
   dispatch:
     CPU: rrelu_with_noise_out_cpu
     CUDA: rrelu_with_noise_out_cuda
 
-- func: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
+- func: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
   python_module: nn
   dispatch:
     CPU: rrelu_with_noise_cpu
     CUDA: rrelu_with_noise_cuda
   tags: nondeterministic_seeded
+  autogen: rrelu_with_noise_functional
 
 - func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
   python_module: nn

@@ -12038,7 +12039,7 @@
     CompositeExplicitAutograd: rrelu_with_noise_backward
   autogen: rrelu_with_noise_backward.out
 
-- func: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
+- func: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
   python_module: nn
   tags: nondeterministic_seeded
   dispatch:
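
The `(b!)` annotation marks `noise` as written-to, and `autogen: rrelu_with_noise_functional` asks codegen to generate a functional variant that returns the updated noise instead of mutating it. One way to inspect the annotation at runtime (a sketch, not part of the commit):

import torch

# The schema's alias info reflects the (b!) annotation on `noise`.
schema = torch.ops.aten.rrelu_with_noise.default._schema
noise_arg = schema.arguments[1]
print(noise_arg.name, noise_arg.alias_info)  # e.g. noise, AliasInfo with is_write=True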

aten op list:
@@ -1125,6 +1125,10 @@ aten::resize_as_sparse_
 aten::row_indices
 aten::row_indices_copy
 aten::row_indices_copy.out
+aten::rrelu_with_noise
+aten::rrelu_with_noise.out
+aten::rrelu_with_noise_
+aten::rrelu_with_noise_functional
 aten::scalar_tensor
 aten::scalar_tensor.out
 aten::scatter.reduce

Forward/backward compatibility allow list:
@@ -137,6 +137,7 @@ ALLOW_LIST = [
     ("_quantized::wrapped_linear_prepack", datetime.date(2024, 12, 31)),
     ("_quantized::wrapped_linear_prepacked", datetime.date(2024, 12, 31)),
     ("_quantized::wrapped_quantized_linear_prepacked", datetime.date(2024, 12, 31)),
+    ("aten::rrelu_with_noise", datetime.date(2024, 12, 31)),
 ]
 
 ALLOW_LIST_COMPILED = [

Tests:
@@ -6327,6 +6327,42 @@ metadata incorrectly.
         out.sum().backward()
         self.assertEqual(ref_x.grad, x.grad)
 
+
+    def test_rrelu_with_noise_mutation(self):
+        def fn_functional(x):
+            noise = torch.ones_like(x)
+            result, noise_out = torch.ops.aten.rrelu_with_noise_functional(
+                x, noise, 0.2, 0.8, True
+            )
+            return result, noise_out
+
+        def fn_mutation(x):
+            noise = torch.ones_like(x)
+            result = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.8, True)
+            return result, noise
+
+        def fn_inplace(x):
+            noise = torch.ones_like(x, requires_grad=False)
+            torch.ops.aten.rrelu_with_noise_(x, noise, 0.2, 0.8, True)
+            return x, noise
+
+        def _test_fn(fn, check_backward=True):
+            x = -torch.abs(torch.randn(4, 4, dtype=torch.bfloat16, requires_grad=True))
+
+            ref_y, ref_noise = fn(x)
+            self.assertTrue(torch.all(ref_noise < torch.ones_like(ref_noise)).item())
+
+            comp_y, comp_noise = torch.compile(fn, backend="inductor", fullgraph=True)(
+                x
+            )
+
+            if check_backward:
+                comp_y.sum().backward()
+            self.assertTrue(torch.all(comp_noise < torch.ones_like(comp_noise)).item())
+
+        _test_fn(fn_functional)
+        _test_fn(fn_mutation)
+        _test_fn(fn_inplace, check_backward=False)
 
 # entries in here don't work and need to be fixed.
 # Each one of these is a bug (or needs to be investigated)

TestDecomp tests:
@@ -656,50 +656,6 @@ class TestDecomp(TestCase):
         for dim in (-1, 0, 1):
             self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))
 
-    def test_rrelu_with_noise(self, device):
-        # rrelu_with_noise behavior depends on a) whether elements in the input
-        # are <= 0, and b) whether we're in training mode. Cover all cases:
-        dtype = torch.float64
-        x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device)
-        lower = 1.0
-        upper = 4.0
-        training = False
-
-        torch.manual_seed(123)
-        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
-        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)
-
-        torch.manual_seed(123)
-        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
-        res = torch._decomp.decompositions.rrelu_with_noise(
-            x,
-            noise_res,
-            lower,
-            upper,
-            training,
-        )
-        self.assertEqual(ref, res)
-        self.assertEqual(noise_ref, noise_res)
-
-        # Now with training=True:
-        training = True
-
-        torch.manual_seed(123)
-        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
-        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)
-
-        torch.manual_seed(123)
-        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
-        res = torch._decomp.decompositions.rrelu_with_noise(
-            x,
-            noise_res,
-            lower,
-            upper,
-            training,
-        )
-        self.assertEqual(ref, res)
-        self.assertEqual(noise_ref, noise_res)
-
     @suppress_warnings
     @tf32_off()
     # only tests RNNs since we have py dispatcher decomps for them

tools/autograd/derivatives.yaml:
@@ -2182,13 +2182,17 @@
   result0: at::where(self_p >= 0, grad_output_t, grad_output_t * weight_p + grad_output_p * weight_t)
   result1: at::where(self_p >= 0, at::zeros({}, self_p.options()), grad_output_p * self_t + grad_output_t * self_p)
 
-- name: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
+- name: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
   self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
   result: auto_element_wise
 
-- name: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
+- name: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
   self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true)
 
+- name: rrelu_with_noise_functional(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> (Tensor, Tensor noise_out)
+  noise: non_differentiable
+  self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
+
 - name: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
   self: _softmax_backward_data(grad, result, dim, self.scalar_type())
   result: result * (self_t - logsumexp_jvp(self_p, self_t, {dim}, true))
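
Per the new entry, gradients for the functional variant flow to `self` through rrelu_with_noise_backward, and `noise` is non-differentiable. A sketch (not part of the commit) of the intended autograd behavior:

import torch

x = torch.randn(4, dtype=torch.double, requires_grad=True)
noise = torch.empty_like(x)

# Functional variant returns the noise instead of mutating the argument.
out, noise_out = torch.ops.aten.rrelu_with_noise_functional(
    x, noise, 0.125, 1.0 / 3.0, True
)
out.sum().backward()            # populates x.grad via rrelu_with_noise_backward
print(noise_out.requires_grad)  # False: noise is non-differentiable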

torch/_decomp/decompositions.py:
@@ -305,44 +305,6 @@ def _prelu_kernel_backward(
     return (input_grad, weight_grad)
 
 
-@register_decomposition(aten.rrelu_with_noise)
-@aten.rrelu_with_noise.default.py_impl(DispatchKey.Autograd)
-@out_wrapper()
-@pw_cast_for_opmath
-def rrelu_with_noise(
-    self: Tensor,
-    noise: Tensor,
-    lower: float = 0.125,
-    upper: float = 0.3333333333333333,
-    training: bool = False,
-    generator: Optional[torch.Generator] = None,
-) -> Tensor:
-    assert generator is None
-    if training:
-        not_positive = self <= 0
-        r = aten.uniform(self, lower, upper)
-        output = torch.where(not_positive, self * r, self)
-        noise.copy_(torch.where(not_positive, r, 1))
-        return output
-    else:
-        negative_slope = (lower + upper) / 2
-        return aten.leaky_relu(self, negative_slope)
-
-
-@register_decomposition(aten.rrelu_with_noise_)
-@aten.rrelu_with_noise_.default.py_impl(DispatchKey.Autograd)
-@pw_cast_for_opmath
-def rrelu_with_noise_(
-    self: Tensor,
-    noise: Tensor,
-    lower: float = 0.125,
-    upper: float = 0.3333333333333333,
-    training: bool = False,
-    generator: Optional[torch.Generator] = None,
-) -> Tensor:
-    return self.copy_(rrelu_with_noise(self, noise, lower, upper, training, generator))
-
-
 @register_decomposition(aten.rrelu_with_noise_backward)
 @out_wrapper()
 @pw_cast_for_opmath

Inductor decompositions:
@@ -76,6 +76,7 @@ inductor_decompositions = get_decompositions(
         aten.native_layer_norm,
         aten.nll_loss2d_backward,
         aten.permute_copy,
+        aten.rrelu_with_noise_backward,
         aten._softmax,
         aten.sin_,
         aten.sqrt_,

@@ -1022,3 +1023,23 @@ def searchsorted_scalar(
         side=side,
         sorter=sorter,
     )[0]
+
+
+@register_decomposition(aten.rrelu_with_noise_functional)
+def rrelu_with_noise_functional(
+    self: torch.Tensor,
+    noise: torch.Tensor,
+    lower: float = 0.125,
+    upper: float = 0.3333333333333333,
+    training: bool = False,
+    generator: Optional[torch.Generator] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if training:
+        not_positive = self <= 0
+        r = aten.uniform(self, lower, upper, generator=generator)
+        output = torch.where(not_positive, self * r, self)
+        noise_out = torch.where(not_positive, r, 1)
+        return output, noise_out
+    else:
+        negative_slope = (lower + upper) / 2
+        return aten.leaky_relu(self, negative_slope), torch.Tensor()
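
To make the decomposition's arithmetic concrete, a standalone sketch (not part of the commit) with a fixed slope standing in for the sampled one:

import torch

x = torch.tensor([-2.0, -0.5, 0.0, 1.0])
lower, upper = 0.125, 1.0 / 3.0
r = torch.full_like(x, 0.25)   # stand-in for aten.uniform(self, lower, upper)

# Training path: non-positive elements are scaled by r, and r is recorded
# in noise_out; positive elements pass through with noise 1.
not_positive = x <= 0
output = torch.where(not_positive, x * r, x)
noise_out = torch.where(not_positive, r, torch.ones_like(x))

# Eval path (training=False) collapses to leaky_relu with the mean slope:
eval_out = torch.nn.functional.leaky_relu(x, (lower + upper) / 2)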

Meta registrations:
@@ -3734,6 +3734,28 @@ def meta__add_relu(self, other, alpha=1) -> Tensor:
     )
 
 
+@register_meta([aten.rrelu_with_noise])
+@out_wrapper()
+def meta_rrelu_with_noise(
+    self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None
+):
+    return torch.empty_like(self)
+
+
+@register_meta([aten.rrelu_with_noise_functional])
+def meta_rrelu_with_noise_functional(
+    self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None
+):
+    return torch.empty_like(self), torch.empty_like(noise)
+
+
+@register_meta([aten.rrelu_with_noise_])
+def meta_rrelu_with_noise_(
+    self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None
+):
+    return self
+
+
 @register_meta([aten.index_put.default, aten._unsafe_index_put.default])
 def meta_index_put(self, indices, values, accumulate=False):
     return torch.empty_like(self)
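
Meta kernels propagate only shapes and dtypes, which is what compile-time tracing needs. A sketch (not part of the commit) of what the new registrations allow on the meta device:

import torch

x = torch.empty(2, 3, device="meta")
noise = torch.empty(2, 3, device="meta")

# No data is computed; only output metadata is produced.
out, noise_out = torch.ops.aten.rrelu_with_noise_functional(x, noise)
print(out.shape, noise_out.shape)  # torch.Size([2, 3]) torch.Size([2, 3])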

torchgen (add_generated_native_functions):
@@ -448,10 +448,10 @@ def add_generated_native_functions(
             continue
 
         base_fn = (
-            d[SchemaKind.inplace]
-            if has_inplace
-            else d[SchemaKind.mutable]
+            d[SchemaKind.mutable]
             if has_mutable
+            else d[SchemaKind.inplace]
+            if has_inplace
             else d[SchemaKind.out]
             if has_out
             else d[SchemaKind.functional]
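
Equivalent selection logic written as plain Python (a sketch, not part of the commit), to make the reordered nested conditional easier to read; the change prefers the mutable schema (e.g. rrelu_with_noise, which now mutates `noise`) as the base for generated functional ops:

def pick_base(d, has_mutable, has_inplace, has_out):
    # Preference order after this change: mutable > inplace > out > functional.
    if has_mutable:
        return d["mutable"]
    if has_inplace:
        return d["inplace"]
    if has_out:
        return d["out"]
    return d["functional"]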