Enable fp32/bf16 PRelu forward and backward in MkldnnCPU path (#60427)
Enable fp32/bf16 PRelu forward and backward in MkldnnCPU path.
Fixes https://github.com/pytorch/pytorch/issues/58896
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60427
Approved by: https://github.com/VitalyFedyunin, https://github.com/ngimel, https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: 8d4e069e66
Commit: cd33e412a2
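
As context for the change, a minimal, hedged usage sketch of the path this commit enables, written against the public API exercised by the tests below (assumes a PyTorch build with MKL-DNN/oneDNN support):

    import torch
    import torch.utils.mkldnn as mkldnn_utils

    # fp32 PReLU forward and backward through the MkldnnCPU path
    x = torch.randn(16, 64, 112, 112, dtype=torch.float32)
    x_mkldnn = x.clone().to_mkldnn().requires_grad_()

    m = mkldnn_utils.to_mkldnn(torch.nn.PReLU(64))  # converts the weight to mkldnn layout
    y = m(x_mkldnn).to_dense()                      # forward dispatches to mkldnn_prelu
    y.sum().backward()                              # backward dispatches to mkldnn_prelu_backward
    print(x_mkldnn.grad.to_dense().shape)           # torch.Size([16, 64, 112, 112])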
aten/src/ATen/native/mkldnn/Prelu.cpp (new file, 79 lines)
@@ -0,0 +1,79 @@
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>

#if !AT_MKLDNN_ENABLED()

namespace at { namespace native {

Tensor mkldnn_prelu(const Tensor& input, const Tensor& weight) {
  TORCH_CHECK(false, "mkldnn_prelu: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor> mkldnn_prelu_backward(const Tensor& grad_output, const Tensor& input, const Tensor& weight) {
  TORCH_CHECK(false, "mkldnn_prelu_backward: ATen not compiled with MKLDNN support");
}

}}

#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>

namespace at { namespace native {

Tensor mkldnn_prelu(const Tensor& input, const Tensor& weight) {
  if (input.scalar_type() == ScalarType::BFloat16) {
    TORCH_CHECK(mkldnn_bf16_device_check(),
        "mkldnn_prelu: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
  }

  int64_t weight_num = weight.numel();
  if (weight_num != 1) {
    int64_t channel_size = input.dim() > 1 ? input.size(1) : 1;
    TORCH_CHECK(channel_size == weight_num,
      "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
      " and channel size = ", channel_size, ".");
  }

  const ideep::tensor& x = itensor_from_mkldnn(input);
  const ideep::tensor& w = itensor_from_tensor(weight);

  ideep::tensor y;
  ideep::prelu_forward::compute(
      x, w, y, ideep::prop_kind::forward_training);
  return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()),
                                 input.options().device_opt());
}

std::tuple<Tensor, Tensor> mkldnn_prelu_backward(const Tensor& grad_output, const Tensor& input, const Tensor& weight) {
  const ideep::tensor& x = itensor_from_mkldnn(input);
  const ideep::tensor& w = itensor_from_tensor(weight);
  const ideep::tensor grady = itensor_from_mkldnn(grad_output);
  ideep::tensor gradx;
  ideep::tensor gradw;

  ideep::prelu_backward::compute(
      x, w, grady, gradx, gradw, ideep::prop_kind::backward);
  if (weight.is_mkldnn()) {
    return std::make_tuple(
        new_with_itensor_mkldnn(std::move(gradx),
                                optTypeMetaToScalarType(grad_output.options().dtype_opt()),
                                grad_output.options().device_opt()),
        new_with_itensor_mkldnn(std::move(gradw),
                                optTypeMetaToScalarType(weight.options().dtype_opt()),
                                weight.options().device_opt()));
  } else {
    return std::make_tuple(
        new_with_itensor_mkldnn(std::move(gradx),
                                optTypeMetaToScalarType(grad_output.options().dtype_opt()),
                                grad_output.options().device_opt()),
        mkldnn_to_dense(new_with_itensor_mkldnn(std::move(gradw),
                                                optTypeMetaToScalarType(weight.options().dtype_opt()),
                                                weight.options().device_opt())));
  }
}

}}

#endif // AT_MKLDNN_ENABLED
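
The channel check in mkldnn_prelu above mirrors dense PReLU semantics: the weight must be either a single shared parameter or one parameter per channel (input dimension 1). A quick illustration with plain CPU tensors, shown only for the shape rules:

    import torch

    x = torch.randn(8, 64, 32, 32)
    y_shared = torch.nn.PReLU(1)(x)   # weight.numel() == 1: accepted for any input
    y_channel = torch.nn.PReLU(64)(x) # weight.numel() == x.size(1): accepted
    # torch.nn.PReLU(32)(x)           # would raise: mismatch of parameter numbers and channel size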
aten/src/ATen/native/native_functions.yaml
@@ -3790,12 +3790,14 @@
- func: prelu(Tensor self, Tensor weight) -> Tensor
  variants: function, method
  dispatch:
    MkldnnCPU: mkldnn_prelu
    CPU: prelu_cpu
    CUDA: prelu_cuda

- func: prelu_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor)
  variants: function, method
  dispatch:
    MkldnnCPU: mkldnn_prelu_backward
    CPU: prelu_backward_cpu
    CUDA: prelu_backward_cuda
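
With these entries, the dispatcher routes torch.prelu to mkldnn_prelu whenever the input carries the MkldnnCPU key; as the tests below confirm, the weight may remain a plain dense tensor. A hedged sketch (assumes an MKL-DNN build):

    import torch

    x = torch.randn(2, 3).to_mkldnn()  # input with the MkldnnCPU dispatch key
    w = torch.randn(1)                 # weight can stay a dense ATen tensor
    y = torch.prelu(x, w)              # routes to MkldnnCPU: mkldnn_prelu
    print(y.is_mkldnn)                 # True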
test/test_mkldnn.py
@@ -450,6 +450,74 @@ class TestMkldnn(TestCase):
                               msg,
                               lambda: m(x2))

    def _test_prelu_base(self, size, num_channels):
        x = torch.randn(size, dtype=torch.float32)
        x1 = x.clone().requires_grad_()
        x2 = x.clone().to_mkldnn().requires_grad_()
        x3 = x.clone().to_mkldnn().requires_grad_()
        m1 = torch.nn.PReLU(num_channels)
        m2 = mkldnn_utils.to_mkldnn(copy.deepcopy(m1))
        m3 = copy.deepcopy(m1)
        y1 = m1(x1)
        y2 = m2(x2).to_dense()
        y3 = m3(x3).to_dense()  # only the input is mkldnn; the weight stays an ATen tensor
        loss1 = y1.sum()
        loss1.backward()
        loss2 = y2.sum()
        loss2.backward()
        loss3 = y3.sum()
        loss3.backward()
        self.assertEqual(y1, y2)
        self.assertEqual(y1, y3)
        self.assertEqual(x1.grad, x2.grad.to_dense())
        self.assertEqual(x1.grad, x3.grad.to_dense())

    def test_prelu(self):
        self._test_prelu_base(torch.Size([16]), 1)
        self._test_prelu_base(torch.Size([16, 64]), 1)
        self._test_prelu_base(torch.Size([16, 64]), 64)
        self._test_prelu_base(torch.Size([16, 64, 112]), 1)
        self._test_prelu_base(torch.Size([16, 64, 112]), 64)
        self._test_prelu_base(torch.Size([16, 64, 112, 112]), 1)
        self._test_prelu_base(torch.Size([16, 64, 112, 112]), 64)
        self._test_prelu_base(torch.Size([16, 64, 112, 112, 1]), 1)
        self._test_prelu_base(torch.Size([16, 64, 112, 112, 1]), 64)

    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def _test_prelu_bf16_base(self, size, num_channels):
        if has_bf16_support():
            x = torch.randn(size, dtype=torch.float32)
            x_fp32 = x.clone().to_mkldnn().requires_grad_()
            x_bf16 = x.clone().to_mkldnn(torch.bfloat16).requires_grad_()
            m = mkldnn_utils.to_mkldnn(torch.nn.PReLU())
            m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16)

            y = m(x_fp32).to_dense()
            y_bf16 = m_bf16(x_bf16).to_dense()
            self.assertEqual(y, y_bf16.to(torch.float32), atol=1e-1, rtol=1e-3)

            loss = y.sum()
            loss.backward()
            loss_bf16 = y_bf16.sum()
            loss_bf16.backward()
            self.assertEqual(x_fp32.grad.to_dense(), x_bf16.grad.to_dense(torch.float32))
        else:
            x_bf16 = torch.randn(size, dtype=torch.bfloat16).requires_grad_()
            m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16)
            msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
            self.assertRaisesRegex(RuntimeError,
                                   msg,
                                   lambda: m_bf16(x_bf16))

    def test_prelu_bf16(self):
        self._test_prelu_bf16_base(torch.Size([16]), 1)
        self._test_prelu_bf16_base(torch.Size([16, 64]), 1)
        self._test_prelu_bf16_base(torch.Size([16, 64]), 64)
        self._test_prelu_bf16_base(torch.Size([16, 64, 112]), 1)
        self._test_prelu_bf16_base(torch.Size([16, 64, 112]), 64)
        self._test_prelu_bf16_base(torch.Size([16, 64, 112, 112, 1]), 1)
        self._test_prelu_bf16_base(torch.Size([16, 64, 112, 112, 1]), 64)

    def _test_max_pool_base(self, dim, input):
        pool_module = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}
        for stride in [1, 2, 3]:
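
Distilled from the tests above, a hedged sketch of the bf16 path (assumes MKL-DNN support and a CPU with avx512bw, avx512vl, and avx512dq; on other machines the forward raises the RuntimeError checked in the else branch):

    import torch
    import torch.utils.mkldnn as mkldnn_utils

    x_bf16 = torch.randn(16, 64, dtype=torch.float32).to_mkldnn(torch.bfloat16)
    m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16)
    y = m_bf16(x_bf16).to_dense(torch.float32)  # bf16 compute, fp32 readout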
build_variables.bzl
@@ -1079,6 +1079,7 @@ aten_cpu_source_non_codegen_list = [
    "aten/src/ATen/native/mkldnn/MkldnnTensorMath.cpp",
    "aten/src/ATen/native/mkldnn/Normalization.cpp",
    "aten/src/ATen/native/mkldnn/Pooling.cpp",
    "aten/src/ATen/native/mkldnn/Prelu.cpp",
    "aten/src/ATen/native/mkldnn/Relu.cpp",
    "aten/src/ATen/native/mkldnn/SoftMax.cpp",
    "aten/src/ATen/native/mkldnn/TensorFactories.cpp",
torch/_C/_nn.pyi.in
@@ -13,6 +13,9 @@ def mkldnn_linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: ...
def mkldnn_reorder_conv2d_weight(self: Tensor, padding: List, stride: List, dilatation: List, groups: int) -> Tensor: ...
def mkldnn_reorder_conv3d_weight(self: Tensor, padding: List, stride: List, dilatation: List, groups: int) -> Tensor: ...

# Defined in aten/src/ATen/native/mkldnn/Prelu.cpp
def mkldnn_prelu(input: Tensor, weight: Tensor) -> Tensor: ...

# Defined at tools/autograd/templates/python_nn_functions.cpp
@overload
def _parse_to(device: _device, dtype: _dtype, non_blocking: _bool, copy: _bool, *,
torch/utils/mkldnn.py
@@ -180,6 +180,26 @@ class MkldnnBatchNorm(torch.jit.ScriptModule):
            False,  # cuda_enabled
        )

class MkldnnPrelu(torch.jit.ScriptModule):
    def __init__(self, dense_module, dtype):
        super(MkldnnPrelu, self).__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.training = state[1]

    @torch.jit.script_method
    def forward(self, x):
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch.prelu(x_mkldnn, self.weight)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y

def to_mkldnn(module, dtype=torch.float):
    assert dtype in [torch.float, torch.bfloat16], \
@@ -198,6 +218,8 @@ def to_mkldnn(module, dtype=torch.float):
            # For the batchnorm bf16 path, oneDNN requires fp32 weight and bias,
            # so MkldnnBatchNorm takes no dtype argument.
            return MkldnnBatchNorm(m)
        elif isinstance(m, torch.nn.PReLU):
            return MkldnnPrelu(m, d)
        else:
            return m
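
Because MkldnnPrelu defines __getstate__/__setstate__, the converted module should round-trip through TorchScript serialization, densifying the weight on save and re-converting it on load. A small sketch under that assumption (the file name is illustrative):

    import torch
    import torch.utils.mkldnn as mkldnn_utils

    m = mkldnn_utils.to_mkldnn(torch.nn.PReLU(64))  # returns a MkldnnPrelu ScriptModule
    torch.jit.save(m, "prelu_mkldnn.pt")            # weight stored dense via __getstate__
    m2 = torch.jit.load("prelu_mkldnn.pt")          # weight restored via __setstate__
    y = m2(torch.randn(8, 64, 32, 32))              # dense in, dense out (forward converts)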