mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Use flake8 installed with mypy checks so that our linter matches fbcode. Mypy type errors also provide valuable signal Pull Request resolved: https://github.com/pytorch/pytorch/pull/17721 Differential Revision: D14357778 Pulled By: eellison fbshipit-source-id: d8c9ea3fe3b5f550c3b70fe259e0eabf95e4c92d
246 lines
9.9 KiB
Python
246 lines
9.9 KiB
Python
import torch
|
|
import copy
|
|
import numbers
|
|
from typing import Tuple, Optional
|
|
from torch import Tensor
|
|
from torch.jit import ScriptModule
|
|
|
|
from torch.nn.utils.rnn import PackedSequence
|
|
from torch.nn import _VF
|
|
|
|
|
|
class QuantizedLinear(torch.jit.ScriptModule):
|
|
__constants__ = ['scale', 'zero_point']
|
|
|
|
def __init__(self, other):
|
|
super(QuantizedLinear, self).__init__()
|
|
self.in_features = other.in_features
|
|
self.out_features = other.out_features
|
|
# Quantize weight and discard the original
|
|
self.weight, self.col_offsets, self.scale, self.zero_point = torch.fbgemm_linear_quantize_weight(
|
|
other.weight.clone().float())
|
|
self.weight = torch.nn.Parameter(self.weight, requires_grad=False)
|
|
self.col_offsets = torch.nn.Parameter(self.col_offsets, requires_grad=False)
|
|
assert other.bias is not None, 'QuantizedLinear requires a bias'
|
|
self.bias = torch.nn.Parameter(other.bias.clone().float())
|
|
|
|
self.register_buffer(
|
|
'packed_tensor_ptr',
|
|
torch.fbgemm_pack_quantized_matrix(self.weight.clone(), self.weight.size(1), self.weight.size(0)))
|
|
|
|
@torch.jit.script_method
|
|
def _unpack(self):
|
|
self.packed_tensor_ptr.set_(
|
|
torch.fbgemm_pack_quantized_matrix(
|
|
self.weight, self.weight.size(1), self.weight.size(0)))
|
|
|
|
@torch.jit.script_method
|
|
def _pack(self):
|
|
self.packed_tensor_ptr.set_(
|
|
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input):
|
|
out = torch.fbgemm_linear_int8_weight(
|
|
input.float(), self.weight, self.packed_tensor_ptr, self.col_offsets,
|
|
self.scale, self.zero_point, self.bias)
|
|
return out.type_as(input)
|
|
|
|
def extra_repr(self):
|
|
repr = 'in_features={in_features}, out_features={out_features}, ' \
|
|
'scale={scale}, zero_point={zero_point}'.format(**self.__dict__)
|
|
return repr
|
|
|
|
|
|
# Quantized RNN cell implementations
|
|
class QuantizedRNNCellBase(torch.jit.ScriptModule):
|
|
__constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih',
|
|
'zero_point_ih', 'zero_point_hh']
|
|
|
|
def __init__(self, other):
|
|
super(QuantizedRNNCellBase, self).__init__()
|
|
self.input_size = other.input_size
|
|
self.hidden_size = other.hidden_size
|
|
self.bias = other.bias
|
|
if not self.bias:
|
|
raise ValueError("Quantized RNN cells require bias terms")
|
|
|
|
weight_ih, col_offsets_ih, self.scale_ih, self.zero_point_ih = \
|
|
torch.fbgemm_linear_quantize_weight(other.weight_ih.clone().float())
|
|
self.register_buffer('weight_ih', weight_ih)
|
|
self.register_buffer('col_offsets_ih', col_offsets_ih)
|
|
weight_hh, col_offsets_hh, self.scale_hh, self.zero_point_hh = \
|
|
torch.fbgemm_linear_quantize_weight(other.weight_hh.clone().float())
|
|
self.register_buffer('weight_hh', weight_hh)
|
|
self.register_buffer('col_offsets_hh', col_offsets_hh)
|
|
|
|
packed_ih = torch.fbgemm_pack_quantized_matrix(
|
|
self.weight_ih, self.weight_ih.size(1), self.weight_ih.size(0))
|
|
self.register_buffer('packed_ih', packed_ih)
|
|
packed_hh = torch.fbgemm_pack_quantized_matrix(
|
|
self.weight_hh, self.weight_hh.size(1), self.weight_hh.size(0))
|
|
self.register_buffer('packed_hh', packed_hh)
|
|
|
|
self.bias_ih = torch.nn.Parameter(other.bias_ih.clone().float(), requires_grad=False)
|
|
self.bias_hh = torch.nn.Parameter(other.bias_hh.clone().float(), requires_grad=False)
|
|
|
|
def extra_repr(self):
|
|
s = '{input_size}, {hidden_size}'
|
|
if 'bias' in self.__dict__ and self.bias is not True:
|
|
s += ', bias={bias}'
|
|
if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
|
|
s += ', nonlinearity={nonlinearity}'
|
|
return s.format(**self.__dict__)
|
|
|
|
@torch.jit.script_method
|
|
def check_forward_input(self, input):
|
|
if input.size(1) != self.input_size:
|
|
raise RuntimeError(
|
|
"input has inconsistent input_size: got {}, expected {}".format(
|
|
input.size(1), self.input_size))
|
|
|
|
@torch.jit.script_method
|
|
def check_forward_hidden(self, input, hx, hidden_label=''):
|
|
# type: (Tensor, Tensor, str) -> None
|
|
if input.size(0) != hx.size(0):
|
|
raise RuntimeError(
|
|
"Input batch size {} doesn't match hidden{} batch size {}".format(
|
|
input.size(0), hidden_label, hx.size(0)))
|
|
|
|
if hx.size(1) != self.hidden_size:
|
|
raise RuntimeError(
|
|
"hidden{} has inconsistent hidden_size: got {}, expected {}".format(
|
|
hidden_label, hx.size(1), self.hidden_size))
|
|
|
|
# TODO: for some reason weak_script_method causes a destruction of the
|
|
# module to occur, which in turn frees the packed_ih object via its DataPtr
|
|
# deleter. This is bizarre and should probably get fixed.
|
|
# @torch._jit_internal.weak_script_method
|
|
@torch.jit.script_method
|
|
def _unpack(self):
|
|
self.packed_ih.set_(torch.fbgemm_pack_quantized_matrix(
|
|
self.weight_ih, self.weight_ih.size(1), self.weight_ih.size(0)))
|
|
self.packed_hh.set_(
|
|
torch.fbgemm_pack_quantized_matrix(
|
|
self.weight_hh, self.weight_hh.size(1), self.weight_hh.size(0)))
|
|
|
|
# @torch._jit_internal.weak_script_method
|
|
@torch.jit.script_method
|
|
def _pack(self):
|
|
self.packed_ih.set_(
|
|
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
|
|
self.packed_hh.set_(
|
|
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
|
|
|
|
|
|
class QuantizedRNNCell(QuantizedRNNCellBase):
|
|
__constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih',
|
|
'zero_point_ih', 'zero_point_hh', 'nonlinearity']
|
|
|
|
def __init__(self, other):
|
|
super(QuantizedRNNCell, self).__init__(other)
|
|
self.nonlinearity = other.nonlinearity
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input, hx=None):
|
|
# type: (Tensor, Optional[Tensor]) -> Tensor
|
|
self.check_forward_input(input)
|
|
if hx is None:
|
|
hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
|
|
self.check_forward_hidden(input, hx, '')
|
|
if self.nonlinearity == "tanh":
|
|
ret = _VF.quantized_rnn_tanh_cell(
|
|
input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
|
|
self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
|
|
self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
|
|
self.zero_point_hh
|
|
)
|
|
elif self.nonlinearity == "relu":
|
|
ret = _VF.quantized_rnn_relu_cell(
|
|
input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
|
|
self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
|
|
self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
|
|
self.zero_point_hh
|
|
)
|
|
else:
|
|
ret = input # TODO: remove when jit supports exception flow
|
|
raise RuntimeError(
|
|
"Unknown nonlinearity: {}".format(self.nonlinearity))
|
|
return ret
|
|
|
|
|
|
class QuantizedLSTMCell(QuantizedRNNCellBase):
|
|
def __init__(self, other):
|
|
super(QuantizedLSTMCell, self).__init__(other)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input, hx=None):
|
|
# type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
|
|
self.check_forward_input(input)
|
|
if hx is None:
|
|
zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
|
|
hx = (zeros, zeros)
|
|
self.check_forward_hidden(input, hx[0], '[0]')
|
|
self.check_forward_hidden(input, hx[1], '[1]')
|
|
return _VF.quantized_lstm_cell(
|
|
input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
|
|
self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
|
|
self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
|
|
self.zero_point_hh
|
|
)
|
|
|
|
|
|
class QuantizedGRUCell(QuantizedRNNCellBase):
|
|
def __init__(self, other):
|
|
super(QuantizedGRUCell, self).__init__(other)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input, hx=None):
|
|
# type: (Tensor, Optional[Tensor]) -> Tensor
|
|
self.check_forward_input(input)
|
|
if hx is None:
|
|
hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
|
|
self.check_forward_hidden(input, hx, '')
|
|
return _VF.quantized_gru_cell(
|
|
input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
|
|
self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
|
|
self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
|
|
self.zero_point_hh
|
|
)
|
|
|
|
|
|
def quantize_rnn_cell_modules(module):
|
|
reassign = {}
|
|
for name, mod in module.named_modules():
|
|
if mod is module:
|
|
continue
|
|
new_mod = quantize_rnn_cell_modules(mod)
|
|
if new_mod is not mod:
|
|
reassign[name] = new_mod
|
|
for name, mod in reassign.items():
|
|
setattr(module, name, mod)
|
|
if isinstance(module, torch.nn.LSTMCell):
|
|
return QuantizedLSTMCell(mod)
|
|
if isinstance(module, torch.nn.GRUCell):
|
|
return QuantizedGRUCell(mod)
|
|
if isinstance(module, torch.nn.RNNCell):
|
|
return QuantizedRNNCell(mod)
|
|
|
|
return module
|
|
|
|
|
|
def quantize_linear_modules(module):
|
|
reassign = {}
|
|
for name, mod in module.named_modules():
|
|
if mod is module:
|
|
continue
|
|
new_mod = quantize_linear_modules(mod)
|
|
if new_mod is not mod:
|
|
reassign[name] = new_mod
|
|
|
|
for name, mod in reassign.items():
|
|
setattr(module, name, mod)
|
|
if isinstance(mod, torch.nn.Linear):
|
|
return QuantizedLinear(mod)
|
|
return module
|