Enable fp32/bf16 PRelu forward and backward in MkldnnCPU path (#60427)

Enable fp32/bf16 PRelu forward and backward in MkldnnCPU path.

Fixes https://github.com/pytorch/pytorch/issues/58896

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60427
Approved by: https://github.com/VitalyFedyunin, https://github.com/ngimel, https://github.com/malfet
This commit is contained in:
yanbing-j
2022-05-10 17:29:11 +00:00
committed by PyTorch MergeBot
parent 8d4e069e66
commit cd33e412a2
6 changed files with 175 additions and 0 deletions

View File

@ -0,0 +1,79 @@
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>
#if !AT_MKLDNN_ENABLED()
namespace at { namespace native {
Tensor mkldnn_prelu(const Tensor& input, const Tensor& weight) {
TORCH_CHECK(false, "mkldnn_prelu: ATen not compiled with MKLDNN support");
}
std::tuple<Tensor, Tensor> mkldnn_prelu_backward(const Tensor& grad_output, const Tensor& input, const Tensor& weight) {
TORCH_CHECK(false, "mkldnn_prelu_backward: ATen not compiled with MKLDNN support");
}
}}
#else // AT_MKLDNN_EBABLED
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>
namespace at { namespace native {
Tensor mkldnn_prelu(const Tensor& input, const Tensor& weight) {
if (input.scalar_type() == ScalarType::BFloat16) {
TORCH_CHECK(mkldnn_bf16_device_check(),
"mkldnn_relu: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
}
int64_t weight_num = weight.numel();
if (weight_num != 1) {
int64_t channel_size = input.dim() > 1 ? input.size(1) : 1;
TORCH_CHECK(channel_size == weight_num,
"Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
" and channel size = ", channel_size, ".");
}
const ideep::tensor& x = itensor_from_mkldnn(input);
const ideep::tensor& w = itensor_from_tensor(weight);
ideep::tensor y;
ideep::prelu_forward::compute(
x, w, y, ideep::prop_kind::forward_training);
return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()),
input.options().device_opt());
}
std::tuple<Tensor, Tensor> mkldnn_prelu_backward(const Tensor& grad_output, const Tensor& input, const Tensor& weight) {
const ideep::tensor& x = itensor_from_mkldnn(input);
const ideep::tensor& w = itensor_from_tensor(weight);
const ideep::tensor grady = itensor_from_mkldnn(grad_output);
ideep::tensor gradx;
ideep::tensor gradw;
ideep::prelu_backward::compute(
x, w, grady, gradx, gradw, ideep::prop_kind::backward);
if (weight.is_mkldnn()) {
return std::make_tuple(
new_with_itensor_mkldnn(std::move(gradx),
optTypeMetaToScalarType(grad_output.options().dtype_opt()),
grad_output.options().device_opt()),
new_with_itensor_mkldnn(std::move(gradw),
optTypeMetaToScalarType(weight.options().dtype_opt()),
weight.options().device_opt()));
} else {
return std::make_tuple(
new_with_itensor_mkldnn(std::move(gradx),
optTypeMetaToScalarType(grad_output.options().dtype_opt()),
grad_output.options().device_opt()),
mkldnn_to_dense(new_with_itensor_mkldnn(std::move(gradw),
optTypeMetaToScalarType(weight.options().dtype_opt()),
weight.options().device_opt())));
}
}
}}
#endif // AT_MKLDNN_EBABLED

View File

@ -3790,12 +3790,14 @@
- func: prelu(Tensor self, Tensor weight) -> Tensor
variants: function, method
dispatch:
MkldnnCPU: mkldnn_prelu
CPU: prelu_cpu
CUDA: prelu_cuda
- func: prelu_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor)
variants: function, method
dispatch:
MkldnnCPU: mkldnn_prelu_backward
CPU: prelu_backward_cpu
CUDA: prelu_backward_cuda

View File

@ -450,6 +450,74 @@ class TestMkldnn(TestCase):
msg,
lambda: m(x2))
def _test_prelu_base(self, size, num_channels):
x = torch.randn(size, dtype=torch.float32)
x1 = x.clone().requires_grad_()
x2 = x.clone().to_mkldnn().requires_grad_()
x3 = x.clone().to_mkldnn().requires_grad_()
m1 = torch.nn.PReLU(num_channels)
m2 = mkldnn_utils.to_mkldnn(copy.deepcopy(m1))
m3 = copy.deepcopy(m1)
y1 = m1(x1)
y2 = m2(x2).to_dense()
y3 = m3(x3).to_dense() # Only convert data to mkldnn, weight is Aten tensor
loss1 = y1.sum()
loss1.backward()
loss2 = y2.sum()
loss2.backward()
loss3 = y3.sum()
loss3.backward()
self.assertEqual(y1, y2)
self.assertEqual(y1, y3)
self.assertEqual(x1.grad, x2.grad.to_dense())
self.assertEqual(x1.grad, x3.grad.to_dense())
def test_prelu(self):
self._test_prelu_base(torch.Size([16]), 1)
self._test_prelu_base(torch.Size([16, 64]), 1)
self._test_prelu_base(torch.Size([16, 64]), 64)
self._test_prelu_base(torch.Size([16, 64, 112]), 1)
self._test_prelu_base(torch.Size([16, 64, 112]), 64)
self._test_prelu_base(torch.Size([16, 64, 112, 112]), 1)
self._test_prelu_base(torch.Size([16, 64, 112, 112]), 64)
self._test_prelu_base(torch.Size([16, 64, 112, 112, 1]), 1)
self._test_prelu_base(torch.Size([16, 64, 112, 112, 1]), 64)
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
def _test_prelu_bf16_base(self, size, num_channels):
if has_bf16_support():
x = torch.randn(size, dtype=torch.float32)
x_fp32 = x.clone().to_mkldnn().requires_grad_()
x_bf16 = x.clone().to_mkldnn(torch.bfloat16).requires_grad_()
m = mkldnn_utils.to_mkldnn(torch.nn.PReLU())
m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16)
y = m(x_fp32).to_dense()
y_bf16 = m_bf16(x_bf16).to_dense()
self.assertEqual(y, y_bf16.to(torch.float32), atol=1e-1, rtol=1e-3)
loss = y.sum()
loss.backward()
loss_bf16 = y_bf16.sum()
loss_bf16.backward()
self.assertEqual(x_fp32.grad.to_dense(), x_bf16.grad.to_dense(torch.float32))
else:
x_bf16 = torch.randn(size, dtype=torch.bfloat16).requires_grad_()
m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16)
msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
self.assertRaisesRegex(RuntimeError,
msg,
lambda: m_bf16(x_bf16))
def test_prelu_bf16(self):
self._test_prelu_bf16_base(torch.Size([16]), 1)
self._test_prelu_bf16_base(torch.Size([16, 64]), 1)
self._test_prelu_bf16_base(torch.Size([16, 64]), 64)
self._test_prelu_bf16_base(torch.Size([16, 64, 112]), 1)
self._test_prelu_bf16_base(torch.Size([16, 64, 112]), 64)
self._test_prelu_bf16_base(torch.Size([16, 64, 112, 112, 1]), 1)
self._test_prelu_bf16_base(torch.Size([16, 64, 112, 112, 1]), 64)
def _test_max_pool_base(self, dim, input):
pool_module = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}
for stride in [1, 2, 3]:

View File

@ -1079,6 +1079,7 @@ aten_cpu_source_non_codegen_list = [
"aten/src/ATen/native/mkldnn/MkldnnTensorMath.cpp",
"aten/src/ATen/native/mkldnn/Normalization.cpp",
"aten/src/ATen/native/mkldnn/Pooling.cpp",
"aten/src/ATen/native/mkldnn/Prelu.cpp",
"aten/src/ATen/native/mkldnn/Relu.cpp",
"aten/src/ATen/native/mkldnn/SoftMax.cpp",
"aten/src/ATen/native/mkldnn/TensorFactories.cpp",

View File

@ -13,6 +13,9 @@ def mkldnn_linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tens
def mkldnn_reorder_conv2d_weight(self: Tensor, padding: List, stride: List, dilatation: List, groups: int) -> Tensor: ...
def mkldnn_reorder_conv3d_weight(self: Tensor, padding: List, stride: List, dilatation: List, groups: int) -> Tensor: ...
# Defined in aten/src/ATen/native/mkldnn/Prelu.cpp
def mkldnn_prelu(input: Tensor, weight: Tensor) -> Tensor: ...
# Defined at tools/autograd/templates/python_nn_functions.cpp
@overload
def _parse_to(device: _device, dtype: _dtype, non_blocking: _bool, copy: _bool, *,

View File

@ -180,6 +180,26 @@ class MkldnnBatchNorm(torch.jit.ScriptModule):
False, # cuda_enabled
)
class MkldnnPrelu(torch.jit.ScriptModule):
def __init__(self, dense_module, dtype):
super(MkldnnPrelu, self).__init__()
self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))
@torch.jit.script_method
def __getstate__(self):
return (self.weight.to_dense(), self.training)
@torch.jit.script_method
def __setstate__(self, state):
self.weight = state[0].to_mkldnn()
self.training = state[1]
@torch.jit.script_method
def forward(self, x):
x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
y_mkldnn = torch.prelu(x_mkldnn, self.weight)
y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
return y
def to_mkldnn(module, dtype=torch.float):
assert dtype in [torch.float, torch.bfloat16], \
@ -198,6 +218,8 @@ def to_mkldnn(module, dtype=torch.float):
# For batchnorm bf16 path, OneDNN requires weight and bias need fp32 dtype.
# so it doesn't need dtype argument.
return MkldnnBatchNorm(m)
elif isinstance(m, torch.nn.PReLU):
return MkldnnPrelu(m, d)
else:
return m