mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
disable leaky_relu_ backward calculation with negative slope (#33639)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33639 Test Plan: Imported from OSS Differential Revision: D20045735 Pulled By: glaringlee fbshipit-source-id: b3becf30a8fe9ee178792bd88f6ee10102504ed5
This commit is contained in:
committed by
Facebook Github Bot
parent
997b5b5797
commit
d66c320b10
@ -165,7 +165,7 @@ inline void _rrelu_with_noise_train(
|
||||
output.copy_(tmp_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Tensor& rrelu_with_noise_out_cpu(
|
||||
Tensor& output,
|
||||
const Tensor& self,
|
||||
@ -209,33 +209,14 @@ Tensor& rrelu_with_noise_cpu_(
|
||||
return at::native::rrelu_with_noise_out_cpu(self, self, noise, lower, upper, training, generator);
|
||||
}
|
||||
|
||||
Tensor& rrelu_with_noise_backward_out(
|
||||
Tensor& grad_input,
|
||||
const Tensor& grad_output,
|
||||
const Tensor& self,
|
||||
const Tensor& noise,
|
||||
Scalar lower,
|
||||
Scalar upper,
|
||||
bool training) {
|
||||
auto lower_tensor = scalar_to_tensor(lower, grad_output.device());
|
||||
auto upper_tensor = scalar_to_tensor(upper, grad_output.device());
|
||||
if (training && (upper_tensor - lower_tensor).item().to<float>() > 1E-6) {
|
||||
grad_input = grad_output.mul(noise);
|
||||
} else {
|
||||
auto negative = (lower_tensor + upper_tensor) / 2;
|
||||
Scalar negative_slope = negative.item();
|
||||
grad_input = at::leaky_relu_backward(grad_output, self, negative_slope);
|
||||
}
|
||||
return grad_input;
|
||||
}
|
||||
|
||||
Tensor rrelu_with_noise_backward(
|
||||
const Tensor& grad_output,
|
||||
const Tensor& self,
|
||||
const Tensor& self_or_result,
|
||||
const Tensor& noise,
|
||||
Scalar lower,
|
||||
Scalar upper,
|
||||
bool training) {
|
||||
bool training,
|
||||
bool is_result) {
|
||||
auto lower_tensor = scalar_to_tensor(lower, grad_output.device());
|
||||
auto upper_tensor = scalar_to_tensor(upper, grad_output.device());
|
||||
if (training && (upper_tensor - lower_tensor).item().to<float>() > 1E-6) {
|
||||
@ -243,8 +224,8 @@ Tensor rrelu_with_noise_backward(
|
||||
} else {
|
||||
auto negative = (lower_tensor + upper_tensor) / 2;
|
||||
Scalar negative_slope = negative.item();
|
||||
return at::leaky_relu_backward(grad_output, self, negative_slope);
|
||||
}
|
||||
return at::leaky_relu_backward(grad_output, self_or_result, negative_slope, is_result);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor rrelu(const Tensor & self, Scalar lower, Scalar upper, bool training, Generator* generator) {
|
||||
@ -663,22 +644,26 @@ Tensor & leaky_relu_(
|
||||
return at::leaky_relu_out(self, self, neg_val);
|
||||
}
|
||||
|
||||
Tensor& leaky_relu_backward_out(
|
||||
Tensor& grad_input,
|
||||
const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
Scalar negval) {
|
||||
auto iter = TensorIterator::binary_op(grad_input, input, grad_output);
|
||||
leaky_relu_backward_stub(iter.device_type(), iter, negval);
|
||||
return grad_input;
|
||||
}
|
||||
|
||||
// Note: leakyReLu backward calculation doesn't support in-place call with non-positive slope.
|
||||
// The reason is that for in-place forward call, the forward result will be saved into autograd
|
||||
// node instead of the input itself, when calculating backward gradient, there is no way to know
|
||||
// whether the original input for current node is positive or not if the input slope is
|
||||
// non-positive. eg. forward is 2, slope is -0.2, the original input for this node could be
|
||||
// either 2, or -10, so no way to get a correct backward gradient in this case.
|
||||
Tensor leaky_relu_backward(
|
||||
const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
Scalar negval) {
|
||||
const Tensor& self_or_result,
|
||||
Scalar negval,
|
||||
bool is_result) {
|
||||
TORCH_CHECK(
|
||||
!is_result || negval.to<double>() > 0.0,
|
||||
"In-place leakyReLu backward calculation is triggered with a non-positive slope which is not supported. "
|
||||
"This is caused by calling in-place forward function with a non-positive slope, "
|
||||
"please call out-of-place version instead. File an issue at https://github.com/pytorch/pytorch if you do "
|
||||
"require supporting in-place leakRelu backward calculation with non-positive slope");
|
||||
|
||||
Tensor result;
|
||||
auto iter = TensorIterator::binary_op(result, input, grad_output);
|
||||
auto iter = TensorIterator::binary_op(result, self_or_result, grad_output);
|
||||
leaky_relu_backward_stub(iter.device_type(), iter, negval);
|
||||
return iter.output();
|
||||
}
|
||||
|
@ -5610,13 +5610,7 @@
|
||||
CUDA: leaky_relu
|
||||
QuantizedCPU: quantized_leaky_relu
|
||||
|
||||
- func: leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, *, Tensor(a!) grad_input) -> Tensor(a!)
|
||||
python_module: nn
|
||||
dispatch:
|
||||
CPU: leaky_relu_backward_out
|
||||
CUDA: leaky_relu_backward_out
|
||||
|
||||
- func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope) -> Tensor
|
||||
- func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
python_module: nn
|
||||
|
||||
@ -5671,13 +5665,7 @@
|
||||
CPU: rrelu_with_noise_cpu
|
||||
CUDA: legacy::cuda::_thnn_rrelu_with_noise_forward
|
||||
|
||||
- func: rrelu_with_noise_backward.grad_input(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, *, Tensor(a!) grad_input) -> Tensor(a!)
|
||||
python_module: nn
|
||||
dispatch:
|
||||
CPU: rrelu_with_noise_backward_out
|
||||
CUDA: rrelu_with_noise_backward_out
|
||||
|
||||
- func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training) -> Tensor
|
||||
- func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
python_module: nn
|
||||
|
||||
|
@ -57,6 +57,8 @@ white_list = [
|
||||
('quantized::add_(scalar_)?(relu_)?out', datetime.date(2020, 3, 1)),
|
||||
('quantized::cat_(relu_)?out', datetime.date(2020, 3, 1)),
|
||||
('quantized::mul_(scalar_)?(relu_)?out', datetime.date(2020, 3, 1)),
|
||||
('aten::leaky_relu_backward', datetime.date(2020, 3, 6)),
|
||||
('aten::rrelu_with_noise_backward', datetime.date(2020, 3, 6)),
|
||||
]
|
||||
|
||||
|
||||
|
@ -3886,6 +3886,18 @@ for shape in [(1,), ()]:
|
||||
with self.assertRaisesRegex(RuntimeError, "must implement the backward"):
|
||||
BadBw.apply(inp).sum().backward()
|
||||
|
||||
def test_leaky_relu_inplace_with_neg_slope(self):
|
||||
for device in torch.testing.get_all_device_types():
|
||||
a = torch.tensor([-1., 1.], device=device, requires_grad=True)
|
||||
b = torch.nn.functional.leaky_relu_(a.clone(), -2)
|
||||
with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
|
||||
b.backward(torch.ones(2, device=device))
|
||||
|
||||
a = torch.tensor([-1., 1.], device=device, requires_grad=True)
|
||||
b = torch.nn.functional.rrelu_(a.clone(), -5.0, 1.0)
|
||||
with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
|
||||
b.backward(torch.ones(2, device=device))
|
||||
|
||||
def index_variable(shape, max_indices):
|
||||
if not isinstance(shape, tuple):
|
||||
shape = (shape,)
|
||||
|
@ -1081,10 +1081,10 @@
|
||||
self: hardtanh_backward(grad, result, min_val, max_val)
|
||||
|
||||
- name: leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor
|
||||
self: leaky_relu_backward(grad, self, negative_slope)
|
||||
self: leaky_relu_backward(grad, self, negative_slope, false)
|
||||
|
||||
- name: leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!)
|
||||
self: leaky_relu_backward(grad, result, negative_slope)
|
||||
self: leaky_relu_backward(grad, result, negative_slope, true)
|
||||
|
||||
- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)
|
||||
self: log_sigmoid_backward(grad, self, buffer)
|
||||
@ -1099,10 +1099,10 @@
|
||||
grad_output, self, weight: prelu_double_backward(grads[0], grads[1], grad_output, self, weight)
|
||||
|
||||
- name: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
|
||||
self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training)
|
||||
self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
|
||||
|
||||
- name: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
|
||||
self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training)
|
||||
self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true)
|
||||
|
||||
- name: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
|
||||
self: _softmax_backward_data(grad, result, dim, self)
|
||||
@ -1313,8 +1313,9 @@
|
||||
grad_output: grad.to(output.dtype()) - (grad.to(output.dtype()) * output.exp()).sum(dim, true)
|
||||
self: log_softmax_double_backward(grad.to(output.dtype()), grad_output, dim, output).to(self.dtype())
|
||||
|
||||
- name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope) -> Tensor
|
||||
grad_output: leaky_relu_backward(grad, self, negative_slope)
|
||||
- name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor
|
||||
# self_is_result is always false here since double backward call is an out-of-place call, self is input itself
|
||||
grad_output: leaky_relu_backward(grad, self, negative_slope, false)
|
||||
self: zeros_like(grad, at::MemoryFormat::Preserve)
|
||||
|
||||
- name: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor
|
||||
@ -1346,8 +1347,9 @@
|
||||
self: zeros_like(grad, at::MemoryFormat::Preserve)
|
||||
target: non_differentiable
|
||||
|
||||
- name: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training) -> Tensor
|
||||
grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training)
|
||||
- name: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
|
||||
# self_is_result is always false here since double backward call is an out-of-place call, self is input itself
|
||||
grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
|
||||
self: zeros_like(grad, at::MemoryFormat::Preserve)
|
||||
|
||||
- name: reflection_pad1d_backward(Tensor grad_output, Tensor self, int[2] padding) -> Tensor
|
||||
|
Reference in New Issue
Block a user