Enable torch.jit.trace for mkldnn modules

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20800

Differential Revision: D15447892

fbshipit-source-id: 78e76523c5412c020a2bc22d6998ff7b36356720
Junjie Bai
2019-05-23 12:46:08 -07:00
committed by Facebook Github Bot
parent 63585c3b81
commit 8dedb04c26
4 changed files with 25 additions and 3 deletions
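
For context, a minimal sketch of what this change enables (assuming an MKL-DNN-enabled build and the torch.utils.mkldnn.to_mkldnn helper; the module and input names here are illustrative, not part of the patch):

import torch
from torch.utils import mkldnn as mkldnn_utils

# Convert an eager module so its weights become opaque mkldnn tensors.
model = torch.nn.Linear(20, 30).eval()
mkldnn_model = mkldnn_utils.to_mkldnn(model)

x = torch.randn(5, 20, dtype=torch.float32)

# check_trace=False mirrors the new test helper below; the tracer's output
# re-check is skipped because it does not compare opaque mkldnn tensors directly.
traced = torch.jit.trace(mkldnn_model, (x.to_mkldnn(),), check_trace=False)

# Convert the traced output back to dense to compare against eager execution.
print(torch.allclose(model(x), traced(x.to_mkldnn()).to_dense()))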


@@ -112,6 +112,7 @@ class TestMkldnn(TestCase):
             mkldnn_conv2d(x.to_mkldnn()).to_dense())
 
         self._test_serialization(mkldnn_conv2d, (x.to_mkldnn(),))
+        self._test_tracing(mkldnn_conv2d, (x.to_mkldnn(),))
 
     def test_relu(self):
         x = torch.randn((4, 5), dtype=torch.float32) * 10
@@ -177,6 +178,7 @@ class TestMkldnn(TestCase):
             mkldnn_bn(x.to_mkldnn()).to_dense())
 
         self._test_serialization(mkldnn_bn, (x.to_mkldnn(),))
+        self._test_tracing(mkldnn_bn, (x.to_mkldnn(),))
 
     def test_add(self):
         N = torch.randint(3, 10, (1,)).item()
@@ -244,6 +246,7 @@ class TestMkldnn(TestCase):
             mkldnn_linear(x.to_mkldnn()).to_dense())
 
         self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))
+        self._test_tracing(mkldnn_linear, (x.to_mkldnn(),))
 
     def _test_serialization(self, module, inputs):
         with TemporaryFileName() as fname:
@@ -253,6 +256,12 @@ class TestMkldnn(TestCase):
                 module(*inputs).to_dense(),
                 loaded(*inputs).to_dense())
 
+    def _test_tracing(self, module, inputs):
+        traced = torch.jit.trace(module, inputs, check_trace=False)
+        self.assertEqual(
+            module(*inputs).to_dense(),
+            traced(*inputs).to_dense())
+
 
 if __name__ == '__main__':
     run_tests()


@@ -653,6 +653,16 @@ void Graph::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) {
   block()->remapTypes(type_map);
 }
 
+void Value::inferTypeFrom(const at::Tensor& output) {
+  if (output.is_mkldnn()) {
+    // mkldnn tensor as opaque tensor doesn't have strides, so we can
+    // not create a CompleteTensorType
+    setType(DimensionedTensorType::create(output));
+    return;
+  }
+  setType(CompleteTensorType::create(output));
+}
+
 bool Value::mustBeNone() const {
   return node_->mustBeNone();
 }
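
The comment in the added code is the heart of the change: an mkldnn tensor is an opaque tensor that exposes sizes and a dtype but no strides, so only a DimensionedTensorType can be recorded for it. A small illustration (hypothetical interactive session, assuming an MKL-DNN-enabled build):

import torch

x = torch.randn(2, 3, dtype=torch.float32)
y = x.to_mkldnn()

print(y.layout)   # torch._mkldnn
print(y.size())   # sizes are still available, e.g. torch.Size([2, 3])
# y.stride() is expected to raise here, since the opaque layout carries no
# stride information; that is why the tracer cannot build a CompleteTensorType.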


@@ -162,9 +162,7 @@ struct Value {
  public:
   Value* setType(TypePtr type);
-  void inferTypeFrom(const at::Tensor& output) {
-    setType(CompleteTensorType::create(output));
-  }
+  TORCH_API void inferTypeFrom(const at::Tensor& output);
   const TypePtr& type() const {
     AT_ASSERT(type_ != nullptr);
     return type_;


@@ -94,6 +94,11 @@ inline TypedIValue toTypedIValue(py::handle input) {
     if (ten.is_sparse()) {
       AT_ERROR("sparse tensors not supported");
     }
+    if (ten.is_mkldnn()) {
+      // mkldnn tensor as opaque tensor doesn't have strides, so we can
+      // not create a CompleteTensorType
+      return TypedIValue(ten, DimensionedTensorType::create(ten));
+    }
     return TypedIValue(ten, CompleteTensorType::create(ten));
   } else if (six::isTuple(input)) {
     py::tuple input_tuple = py::cast<py::tuple>(input);
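
A quick way to see the effect of the DimensionedTensorType fallback is to inspect the traced graph: the input corresponding to the mkldnn tensor should be typed only with dtype, device, and number of dimensions rather than concrete sizes and strides. The exact textual form of the type varies by PyTorch version; the snippet below is a sketch under the same assumptions as the example near the top.

import torch
from torch.utils import mkldnn as mkldnn_utils

model = mkldnn_utils.to_mkldnn(torch.nn.Linear(20, 30).eval())
x = torch.randn(5, 20, dtype=torch.float32).to_mkldnn()

traced = torch.jit.trace(model, (x,), check_trace=False)
# The graph input for x carries a dimensioned tensor type, not a complete
# (sized and strided) tensor type.
print(traced.graph)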