mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Enable torch.jit.trace for mkldnn modules
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20800 Differential Revision: D15447892 fbshipit-source-id: 78e76523c5412c020a2bc22d6998ff7b36356720
This commit is contained in:
committed by
Facebook Github Bot
parent
63585c3b81
commit
8dedb04c26
@ -112,6 +112,7 @@ class TestMkldnn(TestCase):
|
||||
mkldnn_conv2d(x.to_mkldnn()).to_dense())
|
||||
|
||||
self._test_serialization(mkldnn_conv2d, (x.to_mkldnn(),))
|
||||
self._test_tracing(mkldnn_conv2d, (x.to_mkldnn(),))
|
||||
|
||||
def test_relu(self):
|
||||
x = torch.randn((4, 5), dtype=torch.float32) * 10
|
||||
@ -177,6 +178,7 @@ class TestMkldnn(TestCase):
|
||||
mkldnn_bn(x.to_mkldnn()).to_dense())
|
||||
|
||||
self._test_serialization(mkldnn_bn, (x.to_mkldnn(),))
|
||||
self._test_tracing(mkldnn_bn, (x.to_mkldnn(),))
|
||||
|
||||
def test_add(self):
|
||||
N = torch.randint(3, 10, (1,)).item()
|
||||
@ -244,6 +246,7 @@ class TestMkldnn(TestCase):
|
||||
mkldnn_linear(x.to_mkldnn()).to_dense())
|
||||
|
||||
self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))
|
||||
self._test_tracing(mkldnn_linear, (x.to_mkldnn(),))
|
||||
|
||||
def _test_serialization(self, module, inputs):
|
||||
with TemporaryFileName() as fname:
|
||||
@ -253,6 +256,12 @@ class TestMkldnn(TestCase):
|
||||
module(*inputs).to_dense(),
|
||||
loaded(*inputs).to_dense())
|
||||
|
||||
def _test_tracing(self, module, inputs):
|
||||
traced = torch.jit.trace(module, inputs, check_trace=False)
|
||||
self.assertEqual(
|
||||
module(*inputs).to_dense(),
|
||||
traced(*inputs).to_dense())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -653,6 +653,16 @@ void Graph::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) {
|
||||
block()->remapTypes(type_map);
|
||||
}
|
||||
|
||||
void Value::inferTypeFrom(const at::Tensor& output) {
|
||||
if (output.is_mkldnn()) {
|
||||
// mkldnn tensor as opaque tensor doesn't have strides, so we can
|
||||
// not create a CompleteTensorType
|
||||
setType(DimensionedTensorType::create(output));
|
||||
return;
|
||||
}
|
||||
setType(CompleteTensorType::create(output));
|
||||
}
|
||||
|
||||
bool Value::mustBeNone() const {
|
||||
return node_->mustBeNone();
|
||||
}
|
||||
|
@ -162,9 +162,7 @@ struct Value {
|
||||
|
||||
public:
|
||||
Value* setType(TypePtr type);
|
||||
void inferTypeFrom(const at::Tensor& output) {
|
||||
setType(CompleteTensorType::create(output));
|
||||
}
|
||||
TORCH_API void inferTypeFrom(const at::Tensor& output);
|
||||
const TypePtr& type() const {
|
||||
AT_ASSERT(type_ != nullptr);
|
||||
return type_;
|
||||
|
@ -94,6 +94,11 @@ inline TypedIValue toTypedIValue(py::handle input) {
|
||||
if (ten.is_sparse()) {
|
||||
AT_ERROR("sparse tensors not supported");
|
||||
}
|
||||
if (ten.is_mkldnn()) {
|
||||
// mkldnn tensor as opaque tensor doesn't have strides, so we can
|
||||
// not create a CompleteTensorType
|
||||
return TypedIValue(ten, DimensionedTensorType::create(ten));
|
||||
}
|
||||
return TypedIValue(ten, CompleteTensorType::create(ten));
|
||||
} else if (six::isTuple(input)) {
|
||||
py::tuple input_tuple = py::cast<py::tuple>(input);
|
||||
|
Reference in New Issue
Block a user