mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26666

Changes:
- Introduce a `ConcreteModuleType` concept. This acts both as the key into the type cache and as the source of truth for `ModuleValue::attr` queries. It needs to do both jobs because that is how we ensure correctness (if the types are different, it is because `ModuleValue::attr` would return different things).
- `recursive_script` now first constructs a `ConcreteModuleType` and searches for a pre-existing type before starting compilation.
- All previous paths to creating a `ScriptModule` (including inheriting from `ScriptModule`) are rewritten to go through `create_script_module`, so that construction happens in a single place.

Behavioral changes:
- Big change to `torch.jit.ScriptModule` inheritance: all attributes are now recursively scripted if possible, matching recursive scripting semantics. This makes it hard to keep something from being scripted (for example, a Python submodule). Possibly we'll need an `ignore()`-style mechanism for attributes. In particular, this adds `self.training` to *every* ScriptModule, since it's present on every `nn.Module`.
- I believe this change is transparent to existing users of the inheritance API: if you had an unscriptable attribute that you never used, there is no error. In some cases we will create new attributes (even if they are unused), which will increase serialized model size compared to before.

Test Plan: Imported from OSS

Differential Revision: D17551196

Pulled By: suo

fbshipit-source-id: b476d1c9feb3ddfd63406d90989aaf9dfe890591
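As a hedged illustration of the behavioral change described above (the module name and shapes below are made up, not part of the PR), an old-style `torch.jit.ScriptModule` subclass now has its attributes recursively scripted where possible, and `self.training` is an attribute of every ScriptModule because it exists on every `nn.Module`:

    import torch

    class ExampleModule(torch.jit.ScriptModule):
        def __init__(self):
            super(ExampleModule, self).__init__()
            # An nn.Module attribute is recursively scripted when possible.
            self.linear = torch.nn.Linear(4, 4)

        @torch.jit.script_method
        def forward(self, x):
            # `self.training` is available here as a regular attribute.
            if self.training:
                return self.linear(x)
            return self.linear(x) * 2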
153 lines
5.0 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals

import torch


class MkldnnLinear(torch.jit.ScriptModule):
    def __init__(self, dense_module):
        super(MkldnnLinear, self).__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        if dense_module.bias is not None:
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]

    @torch.jit.script_method
    def forward(self, x):
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch._C._nn.mkldnn_linear(x_mkldnn, self.weight, self.bias)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y


class MkldnnConv2d(torch.jit.ScriptModule):
    __constants__ = ['stride', 'padding', 'dilation', 'groups']

    def __init__(self, dense_module):
        super(MkldnnConv2d, self).__init__()

        self.stride = dense_module.stride
        self.padding = dense_module.padding
        self.dilation = dense_module.dilation
        self.groups = dense_module.groups

        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        if dense_module.bias is not None:
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]

    @torch.jit.script_method
    def forward(self, x):
        return torch.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups)


class MkldnnBatchNorm2d(torch.jit.ScriptModule):
    __constants__ = ['exponential_average_factor', 'eps']

    def __init__(self, dense_module):
        super(MkldnnBatchNorm2d, self).__init__()

        assert(not dense_module.training)
        assert(dense_module.track_running_stats)
        assert(dense_module.affine)

        if dense_module.momentum is None:
            self.exponential_average_factor = 0.0
        else:
            self.exponential_average_factor = dense_module.momentum
        self.eps = dense_module.eps

        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        self.register_buffer('bias', dense_module.bias.to_mkldnn())
        self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn())
        self.register_buffer('running_var', dense_module.running_var.to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        weight = self.weight.to_dense()
        bias = self.bias.to_dense()
        running_mean = self.running_mean.to_dense()
        running_var = self.running_var.to_dense()
        return (weight, bias, running_mean, running_var, self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.running_mean = state[2].to_mkldnn()
        self.running_var = state[3].to_mkldnn()
        self.training = state[4]

    @torch.jit.script_method
    def forward(self, x):
        return torch.batch_norm(
            x,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            False,  # training
            self.exponential_average_factor,
            self.eps,
            False,  # cuda_enabled
        )


def to_mkldnn(module):
    def m_fn(m):
        if isinstance(m, torch.nn.Linear):
            return MkldnnLinear(m)
        elif isinstance(m, torch.nn.Conv2d):
            return MkldnnConv2d(m)
        elif isinstance(m, torch.nn.BatchNorm2d):
            return MkldnnBatchNorm2d(m)
        else:
            return m

    def m_fn_rec(m):
        new_m = m_fn(m)
        for name, sub_m in m.named_children():
            setattr(new_m, name, m_fn_rec(sub_m))
        return new_m

    return m_fn_rec(module)
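# Usage sketch (illustrative only, not part of the original file): `to_mkldnn`
# recursively swaps supported layers for their MKL-DNN wrappers. The sketch
# assumes an eval-mode model and a float32 CPU input; the caller converts the
# input to the MKL-DNN layout and the output back to dense.
#
#     model = torch.nn.Sequential(
#         torch.nn.Conv2d(3, 8, kernel_size=3),
#         torch.nn.BatchNorm2d(8),
#     ).eval()
#     mkldnn_model = to_mkldnn(model)
#     y = mkldnn_model(torch.randn(1, 3, 32, 32).to_mkldnn()).to_dense()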