disable leaky_relu_ backward calculation with negative slope (#33639)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33639

Test Plan: Imported from OSS

Differential Revision: D20045735

Pulled By: glaringlee

fbshipit-source-id: b3becf30a8fe9ee178792bd88f6ee10102504ed5
Author: lixinyu
Date: 2020-02-27 18:51:44 -08:00
Committed by: Facebook Github Bot
Parent: 997b5b5797
Commit: d66c320b10
5 changed files with 49 additions and 60 deletions
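For context, a minimal sketch of the user-visible change (mirroring the new test_leaky_relu_inplace_with_neg_slope test added below; illustrative only, not part of the diff):

import torch
import torch.nn.functional as F

a = torch.tensor([-1., 1.], requires_grad=True)

# Out-of-place leaky_relu with a negative slope still differentiates fine.
F.leaky_relu(a, -2).sum().backward()

# In-place leaky_relu_ with a non-positive slope now fails at backward time.
b = F.leaky_relu_(a.clone(), -2)
try:
    b.sum().backward()
except RuntimeError as e:
    print(e)  # "... please call out-of-place version instead ..."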


@@ -165,7 +165,7 @@ inline void _rrelu_with_noise_train(
output.copy_(tmp_tensor);
}
}
Tensor& rrelu_with_noise_out_cpu(
Tensor& output,
const Tensor& self,
@@ -209,33 +209,14 @@ Tensor& rrelu_with_noise_cpu_(
return at::native::rrelu_with_noise_out_cpu(self, self, noise, lower, upper, training, generator);
}
Tensor& rrelu_with_noise_backward_out(
Tensor& grad_input,
const Tensor& grad_output,
const Tensor& self,
const Tensor& noise,
Scalar lower,
Scalar upper,
bool training) {
auto lower_tensor = scalar_to_tensor(lower, grad_output.device());
auto upper_tensor = scalar_to_tensor(upper, grad_output.device());
if (training && (upper_tensor - lower_tensor).item().to<float>() > 1E-6) {
grad_input = grad_output.mul(noise);
} else {
auto negative = (lower_tensor + upper_tensor) / 2;
Scalar negative_slope = negative.item();
grad_input = at::leaky_relu_backward(grad_output, self, negative_slope);
}
return grad_input;
}
Tensor rrelu_with_noise_backward(
const Tensor& grad_output,
const Tensor& self,
const Tensor& self_or_result,
const Tensor& noise,
Scalar lower,
Scalar upper,
bool training) {
bool training,
bool is_result) {
auto lower_tensor = scalar_to_tensor(lower, grad_output.device());
auto upper_tensor = scalar_to_tensor(upper, grad_output.device());
if (training && (upper_tensor - lower_tensor).item().to<float>() > 1E-6) {
@@ -243,8 +224,8 @@ Tensor rrelu_with_noise_backward(
} else {
auto negative = (lower_tensor + upper_tensor) / 2;
Scalar negative_slope = negative.item();
return at::leaky_relu_backward(grad_output, self, negative_slope);
}
return at::leaky_relu_backward(grad_output, self_or_result, negative_slope, is_result);
}
}
Tensor rrelu(const Tensor & self, Scalar lower, Scalar upper, bool training, Generator* generator) {
@@ -663,22 +644,26 @@ Tensor & leaky_relu_(
return at::leaky_relu_out(self, self, neg_val);
}
Tensor& leaky_relu_backward_out(
Tensor& grad_input,
const Tensor& grad_output,
const Tensor& input,
Scalar negval) {
auto iter = TensorIterator::binary_op(grad_input, input, grad_output);
leaky_relu_backward_stub(iter.device_type(), iter, negval);
return grad_input;
}
// Note: leakyReLu backward calculation does not support in-place calls with a non-positive slope.
// The reason is that for an in-place forward call, the forward result is saved into the autograd
// node instead of the input itself. When computing the backward gradient, there is no way to know
// whether the original input for the current node was positive or negative if the slope is
// non-positive. E.g. if the forward result is 2 and the slope is -0.2, the original input could
// have been either 2 or -10, so a correct backward gradient cannot be recovered in this case.
Tensor leaky_relu_backward(
const Tensor& grad_output,
const Tensor& input,
Scalar negval) {
const Tensor& self_or_result,
Scalar negval,
bool is_result) {
TORCH_CHECK(
!is_result || negval.to<double>() > 0.0,
"In-place leakyReLu backward calculation is triggered with a non-positive slope which is not supported. "
"This is caused by calling in-place forward function with a non-positive slope, "
"please call out-of-place version instead. File an issue at https://github.com/pytorch/pytorch if you do "
"require supporting in-place leakRelu backward calculation with non-positive slope");
Tensor result;
auto iter = TensorIterator::binary_op(result, input, grad_output);
auto iter = TensorIterator::binary_op(result, self_or_result, grad_output);
leaky_relu_backward_stub(iter.device_type(), iter, negval);
return iter.output();
}
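To make the note above concrete, a small illustration of the ambiguity (values taken from the comment's own example; purely illustrative, not part of the change):

# With a non-positive slope, the saved forward result does not determine the
# sign of the original input, so the in-place backward cannot pick the right gradient.
slope = -0.2
x_positive, x_negative = 2.0, -10.0
out_positive = x_positive             # x > 0: passed through        -> 2.0
out_negative = x_negative * slope     # x <= 0: scaled by the slope  -> 2.0
assert out_positive == out_negative   # identical saved results,
# yet the correct local gradients differ: 1.0 for x > 0 vs. -0.2 for x <= 0.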


@@ -5610,13 +5610,7 @@
CUDA: leaky_relu
QuantizedCPU: quantized_leaky_relu
- func: leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, *, Tensor(a!) grad_input) -> Tensor(a!)
python_module: nn
dispatch:
CPU: leaky_relu_backward_out
CUDA: leaky_relu_backward_out
- func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope) -> Tensor
- func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor
use_c10_dispatcher: full
python_module: nn
@@ -5671,13 +5665,7 @@
CPU: rrelu_with_noise_cpu
CUDA: legacy::cuda::_thnn_rrelu_with_noise_forward
- func: rrelu_with_noise_backward.grad_input(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, *, Tensor(a!) grad_input) -> Tensor(a!)
python_module: nn
dispatch:
CPU: rrelu_with_noise_backward_out
CUDA: rrelu_with_noise_backward_out
- func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training) -> Tensor
- func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
use_c10_dispatcher: full
python_module: nn
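A sketch of what the new out-of-place schema looks like from Python, assuming the operator is reachable through torch.ops.aten (that accessor is an assumption of this example, not something shown in the diff):

import torch

grad = torch.ones(3)
x = torch.tensor([-1.0, 0.5, 2.0])
# self_is_result=False: x is the original input, so any slope is allowed.
grad_input = torch.ops.aten.leaky_relu_backward(grad, x, -2.0, False)
print(grad_input)  # slope (-2.0) applied where x <= 0, 1.0 elsewhere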


@@ -57,6 +57,8 @@ white_list = [
('quantized::add_(scalar_)?(relu_)?out', datetime.date(2020, 3, 1)),
('quantized::cat_(relu_)?out', datetime.date(2020, 3, 1)),
('quantized::mul_(scalar_)?(relu_)?out', datetime.date(2020, 3, 1)),
('aten::leaky_relu_backward', datetime.date(2020, 3, 6)),
('aten::rrelu_with_noise_backward', datetime.date(2020, 3, 6)),
]


@@ -3886,6 +3886,18 @@ for shape in [(1,), ()]:
with self.assertRaisesRegex(RuntimeError, "must implement the backward"):
BadBw.apply(inp).sum().backward()
def test_leaky_relu_inplace_with_neg_slope(self):
for device in torch.testing.get_all_device_types():
a = torch.tensor([-1., 1.], device=device, requires_grad=True)
b = torch.nn.functional.leaky_relu_(a.clone(), -2)
with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
b.backward(torch.ones(2, device=device))
a = torch.tensor([-1., 1.], device=device, requires_grad=True)
b = torch.nn.functional.rrelu_(a.clone(), -5.0, 1.0)
with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
b.backward(torch.ones(2, device=device))
def index_variable(shape, max_indices):
if not isinstance(shape, tuple):
shape = (shape,)
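Why the rrelu_ case in the test above trips the same check, a sketch of the slope computation in rrelu_with_noise_backward (assuming F.rrelu_ defaults to training=False, so the leaky_relu_backward branch is taken):

lower, upper = -5.0, 1.0                # bounds used in the test
effective_slope = (lower + upper) / 2   # -2.0
assert effective_slope <= 0             # non-positive -> in-place backward is rejected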


@@ -1081,10 +1081,10 @@
self: hardtanh_backward(grad, result, min_val, max_val)
- name: leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor
self: leaky_relu_backward(grad, self, negative_slope)
self: leaky_relu_backward(grad, self, negative_slope, false)
- name: leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!)
self: leaky_relu_backward(grad, result, negative_slope)
self: leaky_relu_backward(grad, result, negative_slope, true)
- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)
self: log_sigmoid_backward(grad, self, buffer)
@@ -1099,10 +1099,10 @@
grad_output, self, weight: prelu_double_backward(grads[0], grads[1], grad_output, self, weight)
- name: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training)
self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
- name: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training)
self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true)
- name: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _softmax_backward_data(grad, result, dim, self)
@@ -1313,8 +1313,9 @@
grad_output: grad.to(output.dtype()) - (grad.to(output.dtype()) * output.exp()).sum(dim, true)
self: log_softmax_double_backward(grad.to(output.dtype()), grad_output, dim, output).to(self.dtype())
- name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope) -> Tensor
grad_output: leaky_relu_backward(grad, self, negative_slope)
- name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor
# self_is_result is always false here since the double backward call is an out-of-place call; self is the input itself
grad_output: leaky_relu_backward(grad, self, negative_slope, false)
self: zeros_like(grad, at::MemoryFormat::Preserve)
- name: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor
@@ -1346,8 +1347,9 @@
self: zeros_like(grad, at::MemoryFormat::Preserve)
target: non_differentiable
- name: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training) -> Tensor
grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training)
- name: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
# self_is_result is always false here since the double backward call is an out-of-place call; self is the input itself
grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
self: zeros_like(grad, at::MemoryFormat::Preserve)
- name: reflection_pad1d_backward(Tensor grad_output, Tensor self, int[2] padding) -> Tensor
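The two double-backward entries above rely on leaky_relu_backward being linear in grad_output, with a zeros_like derivative w.r.t. self. A small sanity sketch of what that implies for out-of-place double backward (illustrative only):

import torch
import torch.nn.functional as F

x = torch.tensor([-1.0, 2.0], requires_grad=True)
y = F.leaky_relu(x, 0.01)                                   # out-of-place forward
g, = torch.autograd.grad(y.sum(), x, create_graph=True)    # [0.01, 1.0]
gg, = torch.autograd.grad(g.sum(), x, allow_unused=True)   # zeros, per the zeros_like formula above
print(g, gg)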