Add mkldnn sigmoid operator

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20820

Reviewed By: dzhulgakov

Differential Revision: D15455866

fbshipit-source-id: 712b06dfbd441051dc284a1acdf94926df09bc1d
This commit is contained in:
Junjie Bai
2019-05-23 12:46:08 -07:00
committed by Facebook Github Bot
parent 8dedb04c26
commit 70caa2efe2
3 changed files with 63 additions and 0 deletions

View File

@ -0,0 +1,46 @@
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/NativeFunctions.h>
#if !AT_MKLDNN_ENABLED()
namespace at {
namespace native {
Tensor mkldnn_sigmoid(const Tensor& self) {
AT_ERROR("mkldnn_sigmoid: ATen not compiled with MKLDNN support");
}
Tensor& mkldnn_sigmoid_(Tensor& self) {
AT_ERROR("mkldnn_sigmoid_: ATen not compiled with MKLDNN support");
}
} // namespace native
} // namespace at
#else // AT_MKLDNN_EBABLED
#include <ATen/native/mkldnn/MKLDNNCommon.h>
namespace at {
namespace native {
Tensor mkldnn_sigmoid(const Tensor& self) {
ideep::tensor& x = itensor_from_mkldnn(self);
ideep::tensor y;
ideep::eltwise_forward::compute(
x, y, ideep::algorithm::eltwise_logistic, ideep::prop_kind::forward);
return new_with_itensor_mkldnn(std::move(y), self.options());
}
Tensor& mkldnn_sigmoid_(Tensor& self) {
ideep::tensor& x = itensor_from_mkldnn(self);
ideep::eltwise_forward::compute(
x, x, ideep::algorithm::eltwise_logistic, ideep::prop_kind::forward);
return self;
}
} // namespace native
} // namespace at
#endif // AT_MKLDNN_EBABLED

View File

@ -1593,12 +1593,17 @@
- func: sigmoid(Tensor self) -> Tensor
variants: function, method
dispatch:
CPU: sigmoid
CUDA: sigmoid
MkldnnCPU: mkldnn_sigmoid
- func: sigmoid_(Tensor(a!) self) -> Tensor(a!)
variants: function, method
dispatch:
CPU: _sigmoid__cpu
CUDA: _sigmoid__cuda
MkldnnCPU: mkldnn_sigmoid_
- func: sigmoid(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:

View File

@ -248,6 +248,18 @@ class TestMkldnn(TestCase):
self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))
self._test_tracing(mkldnn_linear, (x.to_mkldnn(),))
def test_sigmoid(self):
x = torch.randn(4, 5, dtype=torch.float32) * 10
mkldnn_x = x.to_mkldnn()
self.assertEqual(
torch.sigmoid(x),
torch.sigmoid(mkldnn_x).to_dense(),
)
# inplace
torch.sigmoid_(x)
torch.sigmoid_(mkldnn_x)
self.assertEqual(x, mkldnn_x.to_dense())
def _test_serialization(self, module, inputs):
with TemporaryFileName() as fname:
torch.jit.save(module, fname)