Add support for save and load mkldnn modules
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20799 Reviewed By: wanchaol Differential Revision: D15447891 fbshipit-source-id: e34de946c79282fb934a5c52ff1def41c7993c75
Committed by: Facebook Github Bot
Parent: 5f83c5d834
Commit: 63585c3b81
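In brief: `Linear`, `Conv2d`, and `BatchNorm2d` modules converted with `torch.utils.mkldnn.to_mkldnn` become `ScriptModule`s with `__getstate__`/`__setstate__` hooks, so they can round-trip through `torch.jit.save` and `torch.jit.load`. A minimal usage sketch (the file name is illustrative, and an MKL-DNN-enabled build of PyTorch is assumed):

import copy
import torch
from torch.utils import mkldnn as mkldnn_utils

linear = torch.nn.Linear(20, 30).float()
mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))

torch.jit.save(mkldnn_linear, 'mkldnn_linear.pt')  # __getstate__ stores dense tensors
loaded = torch.jit.load('mkldnn_linear.pt')        # __setstate__ restores MKL-DNN layout

x = torch.randn(3, 20)
assert torch.allclose(linear(x), loaded(x.to_mkldnn()).to_dense(), atol=1e-5)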
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -20,10 +20,12 @@ import contextlib
 import socket
 import time
 from collections import OrderedDict
 from contextlib import contextmanager
 from functools import wraps
 from itertools import product
 from copy import deepcopy
 from numbers import Number
+import tempfile

 import __main__
 import errno
@@ -66,6 +68,24 @@ IS_PPC = platform.machine() == "ppc64le"
 # Environment variable `IS_PYTORCH_CI` is set in `.jenkins/common.sh`.
 IS_PYTORCH_CI = bool(os.environ.get('IS_PYTORCH_CI', 0))

+if IS_WINDOWS:
+    @contextmanager
+    def TemporaryFileName():
+        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
+        # opens the file, and it cannot be opened multiple times on Windows. To support Windows,
+        # close the file after creation and try to remove it manually.
+        f = tempfile.NamedTemporaryFile(delete=False)
+        try:
+            f.close()
+            yield f.name
+        finally:
+            os.unlink(f.name)
+else:
+    @contextmanager  # noqa: T484
+    def TemporaryFileName():
+        with tempfile.NamedTemporaryFile() as f:
+            yield f.name
+

 def _check_module_exists(name):
     r"""Returns if a top-level module with :attr:`name` exists *without*
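For reference, a short usage sketch of the helper added above (assuming it is imported from `common_utils`); the Windows branch closes the file up front because a file held open by `NamedTemporaryFile` cannot be opened a second time there:

from common_utils import TemporaryFileName

with TemporaryFileName() as fname:
    with open(fname, 'w') as f:
        f.write('scratch data')
    # fname stays valid until the with-block exits, then the file is removed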
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -17,7 +17,7 @@ from torch.onnx import OperatorExportTypes
 from torch._six import inf, PY2, builtins, StringIO
 from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
     skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
-    freeze_rng_state, set_rng_seed, slowTest
+    freeze_rng_state, set_rng_seed, slowTest, TemporaryFileName
 from common_nn import module_tests, new_module_tests, criterion_tests
 from textwrap import dedent
 from functools import wraps, reduce
@@ -84,25 +84,6 @@ PY35 = sys.version_info >= (3, 5)
 WINDOWS = sys.platform == 'win32'


-if WINDOWS:
-    @contextmanager
-    def TemporaryFileName():
-        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
-        # opens the file, and it cannot be opened multiple times on Windows. To support Windows,
-        # close the file after creation and try to remove it manually.
-        f = tempfile.NamedTemporaryFile(delete=False)
-        try:
-            f.close()
-            yield f.name
-        finally:
-            os.unlink(f.name)
-else:
-    @contextmanager  # noqa: T484
-    def TemporaryFileName():
-        with tempfile.NamedTemporaryFile() as f:
-            yield f.name
-

 def LSTMCellF(input, hx, cx, *params):
     return LSTMCell(input, (hx, cx), *params)
--- a/test/test_mkldnn.py
+++ b/test/test_mkldnn.py
@@ -3,8 +3,10 @@ import copy
 import unittest

 import torch
+import torch.jit
 from torch.utils import mkldnn as mkldnn_utils
-from common_utils import TestCase, run_tests
+from common_utils import TestCase, run_tests, TemporaryFileName

 from torch.autograd.gradcheck import gradgradcheck, gradcheck
@@ -109,6 +111,8 @@ class TestMkldnn(TestCase):
                 conv2d(x),
                 mkldnn_conv2d(x.to_mkldnn()).to_dense())

+            self._test_serialization(mkldnn_conv2d, (x.to_mkldnn(),))
+
     def test_relu(self):
         x = torch.randn((4, 5), dtype=torch.float32) * 10
         self.assertEqual(torch.relu(x), torch.relu(x.to_mkldnn()).to_dense())
@@ -172,6 +176,8 @@ class TestMkldnn(TestCase):
             bn(x),
             mkldnn_bn(x.to_mkldnn()).to_dense())

+        self._test_serialization(mkldnn_bn, (x.to_mkldnn(),))
+
     def test_add(self):
         N = torch.randint(3, 10, (1,)).item()
         C = torch.randint(3, 100, (1,)).item()
@@ -231,12 +237,22 @@ class TestMkldnn(TestCase):
         x = torch.randn(3, in_features, dtype=torch.float32) * 10

-        linear = torch.nn.Linear(in_features, out_features).float()
+        for bias in [True, False]:
+            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
             mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
             self.assertEqual(
                 linear(x),
                 mkldnn_linear(x.to_mkldnn()).to_dense())

+            self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))
+
+    def _test_serialization(self, module, inputs):
+        with TemporaryFileName() as fname:
+            torch.jit.save(module, fname)
+            loaded = torch.jit.load(fname)
+            self.assertEqual(
+                module(*inputs).to_dense(),
+                loaded(*inputs).to_dense())
+

 if __name__ == '__main__':
     run_tests()
--- a/torch/utils/mkldnn.py
+++ b/torch/utils/mkldnn.py
@@ -1,43 +1,149 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
-import functools

 import torch


+class MkldnnLinear(torch.jit.ScriptModule):
+    def __init__(self, dense_module):
+        super(MkldnnLinear, self).__init__()
+        self.register_buffer('weight', dense_module.weight.to_mkldnn())
+        if dense_module.bias is not None:
+            self.register_buffer('bias', dense_module.bias.to_mkldnn())
+        else:
+            # TODO: Remove this once ScriptModule supports registering None buffer
+            self.register_buffer(
+                'bias',
+                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())
+
+    @torch.jit.script_method
+    def __getstate__(self):
+        return (self.weight.to_dense(), self.bias.to_dense())
+
+    @torch.jit.script_method
+    def __setstate__(self, state):
+        # type: (Tuple[Tensor, Tensor]) -> None
+        self.weight = state[0].to_mkldnn()
+        self.bias = state[1].to_mkldnn()
+
+    @torch.jit.script_method
+    def forward(self, x):
+        return torch._C._nn.mkldnn_linear(x, self.weight, self.bias)
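The `__getstate__`/`__setstate__` pair above exists because MKL-DNN tensors use an opaque layout that the pickler cannot serialize; state is converted to dense for saving and back to MKL-DNN on load. A tensor-level sketch of the same round trip (assuming an MKL-DNN-enabled build):

import torch

t = torch.randn(2, 3).to_mkldnn()  # opaque MKL-DNN layout
state = t.to_dense()               # what __getstate__ stores
restored = state.to_mkldnn()       # what __setstate__ rebuilds
assert torch.equal(t.to_dense(), restored.to_dense())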
+
+
+class MkldnnConv2d(torch.jit.ScriptModule):
+    __constants__ = ['stride', 'padding', 'dilation', 'groups']
+
+    def __init__(self, dense_module):
+        super(MkldnnConv2d, self).__init__()
+
+        self.stride = dense_module.stride
+        self.padding = dense_module.padding
+        self.dilation = dense_module.dilation
+        self.groups = dense_module.groups
+
+        self.register_buffer('weight', dense_module.weight.to_mkldnn())
+        if dense_module.bias is not None:
+            self.register_buffer('bias', dense_module.bias.to_mkldnn())
+        else:
+            # TODO: Remove this once ScriptModule supports registering None buffer
+            self.register_buffer(
+                'bias',
+                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())
+
+    @torch.jit.script_method
+    def __getstate__(self):
+        return (self.weight.to_dense(), self.bias.to_dense())
+
+    @torch.jit.script_method
+    def __setstate__(self, state):
+        # type: (Tuple[Tensor, Tensor]) -> None
+        self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
+            state[0].to_mkldnn(),
+            self.padding,
+            self.stride,
+            self.dilation,
+            self.groups)
+        self.bias = state[1].to_mkldnn()
+
+    @torch.jit.script_method
+    def forward(self, x):
+        return torch.conv2d(
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups)
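Note that this `__setstate__` does more than `to_mkldnn()`: it re-packs the dense weight with `mkldnn_reorder_conv2d_weight`, whose preferred layout depends on the conv geometry, which is why the saved `stride`/`padding`/`dilation`/`groups` constants are passed back in. A standalone sketch of that call (`torch._C._nn` is an internal API; the argument order follows the code above, and the shapes are illustrative):

import torch

w = torch.randn(8, 3, 3, 3)  # OIHW weight of a hypothetical Conv2d(3, 8, 3)
packed = torch._C._nn.mkldnn_reorder_conv2d_weight(
    w.to_mkldnn(), [1, 1], [1, 1], [1, 1], 1)  # padding, stride, dilation, groups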
+
+
+class MkldnnBatchNorm2d(torch.jit.ScriptModule):
+    __constants__ = ['exponential_average_factor', 'eps']
+
+    def __init__(self, dense_module):
+        super(MkldnnBatchNorm2d, self).__init__()
+
+        assert(not dense_module.training)
+        assert(dense_module.track_running_stats)
+        assert(dense_module.affine)
+
+        if dense_module.momentum is None:
+            self.exponential_average_factor = 0.0
+        else:
+            self.exponential_average_factor = dense_module.momentum
+        self.eps = dense_module.eps
+
+        self.register_buffer('weight', dense_module.weight.to_mkldnn())
+        self.register_buffer('bias', dense_module.bias.to_mkldnn())
+        self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn())
+        self.register_buffer('running_var', dense_module.running_var.to_mkldnn())
+
+    @torch.jit.script_method
+    def __getstate__(self):
+        weight = self.weight.to_dense()
+        bias = self.bias.to_dense()
+        running_mean = self.running_mean.to_dense()
+        running_var = self.running_var.to_dense()
+        return (weight, bias, running_mean, running_var)
+
+    @torch.jit.script_method
+    def __setstate__(self, state):
+        # type: (Tuple[Tensor, Tensor, Tensor, Tensor]) -> None
+        self.weight = state[0].to_mkldnn()
+        self.bias = state[1].to_mkldnn()
+        self.running_mean = state[2].to_mkldnn()
+        self.running_var = state[3].to_mkldnn()
+
+    @torch.jit.script_method
+    def forward(self, x):
+        return torch.batch_norm(
+            x,
+            self.weight,
+            self.bias,
+            self.running_mean,
+            self.running_var,
+            False,  # training
+            self.exponential_average_factor,
+            self.eps,
+            False,  # cuda_enabled
+        )
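The assertions above limit conversion to inference-mode BatchNorm with affine parameters and tracked running statistics, so a module must be switched to eval mode before conversion; a one-line sketch:

import torch

bn = torch.nn.BatchNorm2d(8).eval()  # .eval() clears `training`, satisfying the asserts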
+
+
 def to_mkldnn(module):
-    def t_fn(t):
-        if t.is_floating_point():
-            return t.to_mkldnn()
-
     def m_fn(m):
-        # TODO: This is a temporary hack to work around the fact that
-        # nn.Linear is decomposed into addmm/matmul. Later we will
-        # change nn.Linear to directly call aten linear and we can
-        # remove this patch
         if isinstance(m, torch.nn.Linear):
-            m.forward = functools.partial(
-                torch._C._nn.linear,
-                weight=m.weight,
-                bias=m.bias)
+            return MkldnnLinear(m)
+        elif isinstance(m, torch.nn.Conv2d):
+            return MkldnnConv2d(m)
+        elif isinstance(m, torch.nn.BatchNorm2d):
+            return MkldnnBatchNorm2d(m)
+        else:
+            return m

-        for param in m._parameters.values():
-            if param is not None:
-                # Tensors stored in modules are graph leaves, and we don't
-                # want to create copy nodes, so we have to unpack the data.
-                param.data = t_fn(param.data)
-                if param._grad is not None:
-                    param._grad.data = t_fn(param._grad.data)
-
-        for key, buf in m._buffers.items():
-            if buf is not None:
-                m._buffers[key] = t_fn(buf)
-
-        if isinstance(m, torch.nn.Conv2d):
-            m.weight.data = torch._C._nn.mkldnn_reorder_conv2d_weight(
-                m.weight.data,
-                m.padding,
-                m.stride,
-                m.dilation,
-                m.groups)
-
-    return module.apply(m_fn)
+    def m_fn_rec(m):
+        new_m = m_fn(m)
+        for name, sub_m in m.named_children():
+            setattr(new_m, name, m_fn_rec(sub_m))
+        return new_m
+
+    return m_fn_rec(module)
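`m_fn_rec` replaces the earlier `module.apply(m_fn)` because `apply` only mutates modules in place and cannot substitute a child with a different object; the recursion rebuilds the tree, swapping each supported layer for its MKL-DNN wrapper. A usage sketch over a nested module (assuming an MKL-DNN-enabled build):

import copy
import torch
from torch.utils import mkldnn as mkldnn_utils

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
).eval()  # BatchNorm conversion asserts inference mode

mkldnn_model = mkldnn_utils.to_mkldnn(copy.deepcopy(model))
x = torch.randn(1, 3, 16, 16)
assert torch.allclose(model(x), mkldnn_model(x.to_mkldnn()).to_dense(), atol=1e-4)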