pytorch/torch/utils/mkldnn.py
Michael Suo 341262754f module dedupe (#26666)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26666

Changes:
- Introduce a `ConcreteModuleType` concept. This acts both as the key into the type
  cache, and as the source of truth for `ModuleValue::attr` queries. It needs
  to do both jobs because that's how we ensure correctness (if the types are
  different, it's because `ModuleValue::attr` would return different things).
- Now `recursive_script` will first construct a `ConcreteModuleType` and search for a
  pre-existing type before starting compilation (see the sketch after this list).
- All previous paths to creating a `ScriptModule` (including inheriting from
  `ScriptModule`) are now rewritten to go through `create_script_module`, so
  that we have only a single place where construction happens.
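To illustrate the intended dedupe behavior, here is a minimal sketch (the `Net` class and
the use of `torch.jit.script` are purely illustrative, not part of this PR's code):

```python
import torch

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

# Both instances produce the same ConcreteModuleType, so the second
# torch.jit.script call should hit the type cache instead of recompiling.
a = torch.jit.script(Net())
b = torch.jit.script(Net())
```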

Behavioral changes:
- Big change to `torch.jit.ScriptModule` inheritance: all attributes are now
  recursively scripted if possible, matching recursive scripting semantics
  (see the sketch after this list). This makes it hard to keep something from
  being scripted (for example, a Python submodule); we may need an
  `ignore()`-style mechanism for attributes. In particular, this adds
  `self.training` to *every* ScriptModule, since it's present on every
  `nn.Module`.
- I believe this change is transparent to existing users of the inheritance
  API: if you had an unscriptable attribute that you never used, there is
  still no error. In some cases we will now create new attributes (even if
  they are unused), which increases the serialized model size relative to
  before.
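
A minimal sketch of the new inheritance behavior (hypothetical module, for illustration
only): attributes assigned in `__init__` are recursively scripted where possible, and the
resulting ScriptModule carries a `training` attribute like every `nn.Module`.

```python
import torch

class M(torch.jit.ScriptModule):
    def __init__(self):
        super(M, self).__init__()
        # After this change, this submodule is recursively scripted as part of
        # M's type, and M also gets a `training` attribute.
        self.sub = torch.nn.Linear(2, 2)

    @torch.jit.script_method
    def forward(self, x):
        return self.sub(x)
```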

Test Plan: Imported from OSS

Differential Revision: D17551196

Pulled By: suo

fbshipit-source-id: b476d1c9feb3ddfd63406d90989aaf9dfe890591
2019-10-12 09:51:57 -07:00


from __future__ import absolute_import, division, print_function, unicode_literals
import torch


class MkldnnLinear(torch.jit.ScriptModule):
    def __init__(self, dense_module):
        super(MkldnnLinear, self).__init__()
        # Store the parameters in the opaque MKL-DNN layout.
        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        if dense_module.bias is not None:
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        # Serialize dense copies; they are converted back to MKL-DNN on load.
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]

    @torch.jit.script_method
    def forward(self, x):
        # Accept either a dense or an MKL-DNN input and return the matching layout.
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch._C._nn.mkldnn_linear(x_mkldnn, self.weight, self.bias)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y
class MkldnnConv2d(torch.jit.ScriptModule):
    __constants__ = ['stride', 'padding', 'dilation', 'groups']

    def __init__(self, dense_module):
        super(MkldnnConv2d, self).__init__()
        self.stride = dense_module.stride
        self.padding = dense_module.padding
        self.dilation = dense_module.dilation
        self.groups = dense_module.groups
        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        if dense_module.bias is not None:
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        # Re-pack the dense weight into the layout MKL-DNN conv2d expects.
        self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]

    @torch.jit.script_method
    def forward(self, x):
        return torch.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups)
class MkldnnBatchNorm2d(torch.jit.ScriptModule):
    __constants__ = ['exponential_average_factor', 'eps']

    def __init__(self, dense_module):
        super(MkldnnBatchNorm2d, self).__init__()
        # Only inference-mode batch norm with affine parameters and tracked
        # running statistics is supported.
        assert(not dense_module.training)
        assert(dense_module.track_running_stats)
        assert(dense_module.affine)
        if dense_module.momentum is None:
            self.exponential_average_factor = 0.0
        else:
            self.exponential_average_factor = dense_module.momentum
        self.eps = dense_module.eps
        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        self.register_buffer('bias', dense_module.bias.to_mkldnn())
        self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn())
        self.register_buffer('running_var', dense_module.running_var.to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        weight = self.weight.to_dense()
        bias = self.bias.to_dense()
        running_mean = self.running_mean.to_dense()
        running_var = self.running_var.to_dense()
        return (weight, bias, running_mean, running_var, self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.running_mean = state[2].to_mkldnn()
        self.running_var = state[3].to_mkldnn()
        self.training = state[4]

    @torch.jit.script_method
    def forward(self, x):
        return torch.batch_norm(
            x,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            False,  # training
            self.exponential_average_factor,
            self.eps,
            False,  # cuda_enabled
        )
def to_mkldnn(module):
    """Recursively replace supported Linear/Conv2d/BatchNorm2d submodules with
    their MKL-DNN counterparts defined above; other modules are left as-is."""
    def m_fn(m):
        if isinstance(m, torch.nn.Linear):
            return MkldnnLinear(m)
        elif isinstance(m, torch.nn.Conv2d):
            return MkldnnConv2d(m)
        elif isinstance(m, torch.nn.BatchNorm2d):
            return MkldnnBatchNorm2d(m)
        else:
            return m

    def m_fn_rec(m):
        new_m = m_fn(m)
        for name, sub_m in m.named_children():
            setattr(new_m, name, m_fn_rec(sub_m))
        return new_m

    return m_fn_rec(module)
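
For reference, a usage sketch of `to_mkldnn` (the model and shapes below are made up;
it assumes an eval-mode float32 model running on CPU):

```python
import torch
from torch.utils import mkldnn as mkldnn_utils

dense_model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
).eval()  # MkldnnBatchNorm2d asserts the source module is not in training mode

mkldnn_model = mkldnn_utils.to_mkldnn(dense_model)

x = torch.randn(1, 3, 32, 32).to_mkldnn()  # convert the input to the MKL-DNN layout
y = mkldnn_model(x).to_dense()             # convert the result back to a strided tensor
```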