Add support for save and load mkldnn modules

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20799

Reviewed By: wanchaol

Differential Revision: D15447891

fbshipit-source-id: e34de946c79282fb934a5c52ff1def41c7993c75
Junjie Bai
2019-05-23 12:46:08 -07:00
committed by Facebook Github Bot
parent 5f83c5d834
commit 63585c3b81
4 changed files with 179 additions and 56 deletions
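The change wraps nn.Linear, nn.Conv2d, and nn.BatchNorm2d in ScriptModules whose __getstate__/__setstate__ go through dense tensors, so an MKL-DNN model can be serialized with torch.jit.save and restored with torch.jit.load. A minimal sketch of the workflow this enables, assuming an inference-only float32 module; it mirrors the new tests rather than quoting the diff, and the file name is illustrative:

```python
import copy
import torch
from torch.utils import mkldnn as mkldnn_utils

conv = torch.nn.Conv2d(3, 8, kernel_size=3).float().eval()
mkldnn_conv = mkldnn_utils.to_mkldnn(copy.deepcopy(conv))  # wrapped as an MkldnnConv2d ScriptModule

torch.jit.save(mkldnn_conv, 'mkldnn_conv.pt')  # __getstate__ stores plain dense tensors
loaded = torch.jit.load('mkldnn_conv.pt')      # __setstate__ rebuilds the mkldnn buffers

x = torch.randn(1, 3, 32, 32, dtype=torch.float32)
# The loaded module should agree with the dense original up to float32 tolerance.
print(torch.allclose(conv(x), loaded(x.to_mkldnn()).to_dense(), atol=1e-5))
```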

test/common_utils.py

@@ -20,10 +20,12 @@ import contextlib
import socket
import time
from collections import OrderedDict
from contextlib import contextmanager
from functools import wraps
from itertools import product
from copy import deepcopy
from numbers import Number
import tempfile
import __main__
import errno
@@ -66,6 +68,24 @@ IS_PPC = platform.machine() == "ppc64le"
# Environment variable `IS_PYTORCH_CI` is set in `.jenkins/common.sh`.
IS_PYTORCH_CI = bool(os.environ.get('IS_PYTORCH_CI', 0))
if IS_WINDOWS:
    @contextmanager
    def TemporaryFileName():
        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
        # close the file after creation and try to remove it manually
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            f.close()
            yield f.name
        finally:
            os.unlink(f.name)
else:
    @contextmanager  # noqa: T484
    def TemporaryFileName():
        with tempfile.NamedTemporaryFile() as f:
            yield f.name

def _check_module_exists(name):
r"""Returns if a top-level module with :attr:`name` exists *without**

test/test_jit.py

@@ -17,7 +17,7 @@ from torch.onnx import OperatorExportTypes
from torch._six import inf, PY2, builtins, StringIO
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
freeze_rng_state, set_rng_seed, slowTest
freeze_rng_state, set_rng_seed, slowTest, TemporaryFileName
from common_nn import module_tests, new_module_tests, criterion_tests
from textwrap import dedent
from functools import wraps, reduce
@@ -84,25 +84,6 @@ PY35 = sys.version_info >= (3, 5)
WINDOWS = sys.platform == 'win32'
if WINDOWS:
    @contextmanager
    def TemporaryFileName():
        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
        # close the file after creation and try to remove it manually
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            f.close()
            yield f.name
        finally:
            os.unlink(f.name)
else:
    @contextmanager  # noqa: T484
    def TemporaryFileName():
        with tempfile.NamedTemporaryFile() as f:
            yield f.name


def LSTMCellF(input, hx, cx, *params):
    return LSTMCell(input, (hx, cx), *params)

test/test_mkldnn.py

@@ -3,8 +3,10 @@ import copy
import unittest
import torch
import torch.jit
from torch.utils import mkldnn as mkldnn_utils
from common_utils import TestCase, run_tests
from common_utils import TestCase, run_tests, TemporaryFileName
from torch.autograd.gradcheck import gradgradcheck, gradcheck
@@ -109,6 +111,8 @@ class TestMkldnn(TestCase):
                conv2d(x),
                mkldnn_conv2d(x.to_mkldnn()).to_dense())

            self._test_serialization(mkldnn_conv2d, (x.to_mkldnn(),))

    def test_relu(self):
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        self.assertEqual(torch.relu(x), torch.relu(x.to_mkldnn()).to_dense())
@@ -172,6 +176,8 @@ class TestMkldnn(TestCase):
            bn(x),
            mkldnn_bn(x.to_mkldnn()).to_dense())

        self._test_serialization(mkldnn_bn, (x.to_mkldnn(),))

    def test_add(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
@@ -231,12 +237,22 @@ class TestMkldnn(TestCase):
        x = torch.randn(3, in_features, dtype=torch.float32) * 10

        for bias in [True, False]:
            linear = torch.nn.Linear(in_features, out_features).float()
            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
            self.assertEqual(
                linear(x),
                mkldnn_linear(x.to_mkldnn()).to_dense())

            self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))

    def _test_serialization(self, module, inputs):
        with TemporaryFileName() as fname:
            torch.jit.save(module, fname)
            loaded = torch.jit.load(fname)
            self.assertEqual(
                module(*inputs).to_dense(),
                loaded(*inputs).to_dense())


if __name__ == '__main__':
    run_tests()
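The serialization helper leans on mkldnn tensors converting losslessly to and from the strided layout: the wrappers' __getstate__ exports dense tensors, presumably because the ScriptModule pickler cannot store the opaque mkldnn format directly, and __setstate__ converts back on load. A tensor-level sketch of that round trip, for illustration only, assuming float32:

```python
import torch

w = torch.randn(4, 4, dtype=torch.float32)
w_mkldnn = w.to_mkldnn()       # opaque layout used at inference time
state = w_mkldnn.to_dense()    # what the wrappers' __getstate__ stores
restored = state.to_mkldnn()   # what __setstate__ rebuilds after loading
print(torch.allclose(w, restored.to_dense()))  # values survive the round trip
```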

torch/utils/mkldnn.py

@@ -1,43 +1,149 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import functools
import torch

class MkldnnLinear(torch.jit.ScriptModule):
    def __init__(self, dense_module):
        super(MkldnnLinear, self).__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        if dense_module.bias is not None:
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense())

    @torch.jit.script_method
    def __setstate__(self, state):
        # type: (Tuple[Tensor, Tensor]) -> None
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()

    @torch.jit.script_method
    def forward(self, x):
        return torch._C._nn.mkldnn_linear(x, self.weight, self.bias)

class MkldnnConv2d(torch.jit.ScriptModule):
    __constants__ = ['stride', 'padding', 'dilation', 'groups']

    def __init__(self, dense_module):
        super(MkldnnConv2d, self).__init__()

        self.stride = dense_module.stride
        self.padding = dense_module.padding
        self.dilation = dense_module.dilation
        self.groups = dense_module.groups

        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        if dense_module.bias is not None:
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense())

    @torch.jit.script_method
    def __setstate__(self, state):
        # type: (Tuple[Tensor, Tensor]) -> None
        self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()

    @torch.jit.script_method
    def forward(self, x):
        return torch.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups)

class MkldnnBatchNorm2d(torch.jit.ScriptModule):
    __constants__ = ['exponential_average_factor', 'eps']

    def __init__(self, dense_module):
        super(MkldnnBatchNorm2d, self).__init__()

        assert(not dense_module.training)
        assert(dense_module.track_running_stats)
        assert(dense_module.affine)

        if dense_module.momentum is None:
            self.exponential_average_factor = 0.0
        else:
            self.exponential_average_factor = dense_module.momentum
        self.eps = dense_module.eps

        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        self.register_buffer('bias', dense_module.bias.to_mkldnn())
        self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn())
        self.register_buffer('running_var', dense_module.running_var.to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        weight = self.weight.to_dense()
        bias = self.bias.to_dense()
        running_mean = self.running_mean.to_dense()
        running_var = self.running_var.to_dense()
        return (weight, bias, running_mean, running_var)

    @torch.jit.script_method
    def __setstate__(self, state):
        # type: (Tuple[Tensor, Tensor, Tensor, Tensor]) -> None
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.running_mean = state[2].to_mkldnn()
        self.running_var = state[3].to_mkldnn()

    @torch.jit.script_method
    def forward(self, x):
        return torch.batch_norm(
            x,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            False,  # training
            self.exponential_average_factor,
            self.eps,
            False,  # cuda_enabled
        )

def to_mkldnn(module):
    def t_fn(t):
        if t.is_floating_point():
            return t.to_mkldnn()

    def m_fn(m):
        # TODO: This is a temporary hack to work around the fact that
        # nn.Linear is decomposed into addmm/matmul. Later we will
        # change nn.Linear to directly call aten linear and we can
        # remove this patch
        if isinstance(m, torch.nn.Linear):
            m.forward = functools.partial(
                torch._C._nn.linear,
                weight=m.weight,
                bias=m.bias)
            return MkldnnLinear(m)
        elif isinstance(m, torch.nn.Conv2d):
            return MkldnnConv2d(m)
        elif isinstance(m, torch.nn.BatchNorm2d):
            return MkldnnBatchNorm2d(m)
        else:
            return m

        for param in m._parameters.values():
            if param is not None:
                # Tensors stored in modules are graph leaves, and we don't
                # want to create copy nodes, so we have to unpack the data.
                param.data = t_fn(param.data)
                if param._grad is not None:
                    param._grad.data = t_fn(param._grad.data)

    def m_fn_rec(m):
        new_m = m_fn(m)
        for name, sub_m in m.named_children():
            setattr(new_m, name, m_fn_rec(sub_m))
        return new_m

        for key, buf in m._buffers.items():
            if buf is not None:
                m._buffers[key] = t_fn(buf)

        if isinstance(m, torch.nn.Conv2d):
            m.weight.data = torch._C._nn.mkldnn_reorder_conv2d_weight(
                m.weight.data,
                m.padding,
                m.stride,
                m.dilation,
                m.groups)

    return module.apply(m_fn)
    return m_fn_rec(module)
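m_fn_rec walks the module tree recursively, so containers keep their structure while supported leaf modules are swapped for the MKL-DNN wrappers; anything else falls through the else branch unchanged. A hedged sketch of converting a small eval-mode model under those assumptions (the model and shapes are illustrative, not from this commit):

```python
import torch
from torch.utils import mkldnn as mkldnn_utils

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),   # must be in eval mode: MkldnnBatchNorm2d asserts not training
    torch.nn.ReLU(),           # unsupported type, left as-is by m_fn
).eval()

mkldnn_model = mkldnn_utils.to_mkldnn(model)  # children replaced one by one via setattr in m_fn_rec

x = torch.randn(1, 3, 32, 32, dtype=torch.float32)
y = mkldnn_model(x.to_mkldnn()).to_dense()    # feed an mkldnn input, read back a dense output
```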