# mypy: allow-untyped-defs
import numbers
from typing import Optional, Tuple
import warnings

import torch
from torch import Tensor
"""
|
|
We will recreate all the RNN modules as we require the modules to be decomposed
|
|
into its building blocks to be able to observe.
|
|
"""

__all__ = [
    "LSTMCell",
    "LSTM",
]


class LSTMCell(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM) cell.

    For the description and the argument types, please refer to :class:`~torch.nn.LSTMCell`.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTMCell(10, 20)  # (input_size, hidden_size)
        >>> input = torch.randn(6, 3, 10)  # (time_steps, batch, input_size)
        >>> hx = torch.randn(3, 20)  # (batch, hidden_size)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """
    _FLOAT_MODULE = torch.nn.LSTMCell

    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.input_size = input_dim
        self.hidden_size = hidden_dim
        self.bias = bias

        self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
        self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
        self.gates = torch.ao.nn.quantized.FloatFunctional()

        self.input_gate = torch.nn.Sigmoid()
        self.forget_gate = torch.nn.Sigmoid()
        self.cell_gate = torch.nn.Tanh()
        self.output_gate = torch.nn.Sigmoid()

        self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
        self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
        self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()

        self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()

        self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0)
        self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0)
        self.hidden_state_dtype: torch.dtype = torch.quint8
        self.cell_state_dtype: torch.dtype = torch.quint8
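
    # For reference, `forward` below computes the standard LSTM cell update,
    # decomposed into the observable ops created in `__init__`:
    #
    #     i_t = sigmoid(W_ii x_t + b_ii + W_hi h_{t-1} + b_hi)   -> input_gate
    #     f_t = sigmoid(W_if x_t + b_if + W_hf h_{t-1} + b_hf)   -> forget_gate
    #     g_t = tanh(W_ig x_t + b_ig + W_hg h_{t-1} + b_hg)      -> cell_gate
    #     o_t = sigmoid(W_io x_t + b_io + W_ho h_{t-1} + b_ho)   -> output_gate
    #     c_t = f_t * c_{t-1} + i_t * g_t
    #     h_t = o_t * tanh(c_t)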
    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
        if hidden is None or hidden[0] is None or hidden[1] is None:
            hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
        hx, cx = hidden

        igates = self.igates(x)
        hgates = self.hgates(hx)
        gates = self.gates.add(igates, hgates)

        input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)

        input_gate = self.input_gate(input_gate)
        forget_gate = self.forget_gate(forget_gate)
        cell_gate = self.cell_gate(cell_gate)
        out_gate = self.output_gate(out_gate)

        fgate_cx = self.fgate_cx.mul(forget_gate, cx)
        igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
        fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
        cy = fgate_cx_igate_cgate

        # TODO: make this tanh a member of the module so its qparams can be configured
        tanh_cy = torch.tanh(cy)
        hy = self.ogate_cy.mul(out_gate, tanh_cy)
        return hy, cy

    def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]:
        h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size))
        if is_quantized:
            (h_scale, h_zp) = self.initial_hidden_state_qparams
            (c_scale, c_zp) = self.initial_cell_state_qparams
            h = torch.quantize_per_tensor(h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype)
            c = torch.quantize_per_tensor(c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype)
        return h, c

    def _get_name(self):
        return 'QuantizableLSTMCell'

    @classmethod
    def from_params(cls, wi, wh, bi=None, bh=None):
        """Uses the weights and biases to create a new LSTM cell.

        Args:
            wi, wh: Weights for the input and hidden layers
            bi, bh: Biases for the input and hidden layers
        """
        assert (bi is None) == (bh is None)  # Either both None or both have values
        input_size = wi.shape[1]
        hidden_size = wh.shape[1]
        cell = cls(input_dim=input_size, hidden_dim=hidden_size,
                   bias=(bi is not None))
        cell.igates.weight = torch.nn.Parameter(wi)
        if bi is not None:
            cell.igates.bias = torch.nn.Parameter(bi)
        cell.hgates.weight = torch.nn.Parameter(wh)
        if bh is not None:
            cell.hgates.bias = torch.nn.Parameter(bh)
        return cell
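
    # Illustrative only; the shapes below assume the torch.nn.LSTMCell weight
    # layout of (4 * hidden_size, input_size) and (4 * hidden_size, hidden_size):
    #
    #     wi = torch.randn(4 * 20, 10)
    #     wh = torch.randn(4 * 20, 20)
    #     cell = LSTMCell.from_params(wi, wh)  # bias-free cell built from raw weights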

    @classmethod
    def from_float(cls, other, use_precomputed_fake_quant=False):
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
        observed = cls.from_params(other.weight_ih, other.weight_hh,
                                   other.bias_ih, other.bias_hh)
        observed.qconfig = other.qconfig
        observed.igates.qconfig = other.qconfig
        observed.hgates.qconfig = other.qconfig
        return observed


class _LSTMSingleLayer(torch.nn.Module):
    r"""A single one-directional LSTM layer.

    The difference between a layer and a cell is that the layer can process a
    sequence, while the cell only expects an instantaneous value.
    """
    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)
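
    # Shape sketch: `x` is expected as (seq_len, batch, input_size); the cell's
    # hidden state for each step is stacked along dim 0, so the returned tensor
    # is (seq_len, batch, hidden_size).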
    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        result = []
        seq_len = x.shape[0]
        for i in range(seq_len):
            hidden = self.cell(x[i], hidden)
            result.append(hidden[0])  # type: ignore[index]
        result_tensor = torch.stack(result, 0)
        return result_tensor, hidden

    @classmethod
    def from_params(cls, *args, **kwargs):
        cell = LSTMCell.from_params(*args, **kwargs)
        layer = cls(cell.input_size, cell.hidden_size, cell.bias)
        layer.cell = cell
        return layer


class _LSTMLayer(torch.nn.Module):
    r"""A single (optionally bidirectional) LSTM layer."""
    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 batch_first: bool = False, bidirectional: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
        if self.bidirectional:
            self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
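
    # In the bidirectional case, the incoming `hidden` follows the nn.LSTM
    # layout: hx and cx each carry a leading direction dimension, so index 0
    # feeds the forward pass and index 1 feeds the backward pass (unpacked below).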
    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        if self.batch_first:
            x = x.transpose(0, 1)
        if hidden is None:
            hx_fw, cx_fw = (None, None)
        else:
            hx_fw, cx_fw = hidden
        hidden_bw: Optional[Tuple[Tensor, Tensor]] = None
        if self.bidirectional:
            if hx_fw is None:
                hx_bw = None
            else:
                hx_bw = hx_fw[1]
                hx_fw = hx_fw[0]
            if cx_fw is None:
                cx_bw = None
            else:
                cx_bw = cx_fw[1]
                cx_fw = cx_fw[0]
            if hx_bw is not None and cx_bw is not None:
                hidden_bw = hx_bw, cx_bw
        if hx_fw is None and cx_fw is None:
            hidden_fw = None
        else:
            hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
        result_fw, hidden_fw = self.layer_fw(x, hidden_fw)

        if hasattr(self, 'layer_bw') and self.bidirectional:
            x_reversed = x.flip(0)
            result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
            result_bw = result_bw.flip(0)

            result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
            if hidden_fw is None and hidden_bw is None:
                h = None
                c = None
            elif hidden_fw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_bw)
            elif hidden_bw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_fw)
            else:
                h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)  # type: ignore[list-item]
                c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
        else:
            result = result_fw
            h, c = torch.jit._unwrap_optional(hidden_fw)  # type: ignore[assignment]

        if self.batch_first:
            result.transpose_(0, 1)

        return result, (h, c)

    @classmethod
    def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
        r"""
        There is no FP equivalent of this class. This function is here just to
        mimic the behavior of `prepare` within the `torch.ao.quantization` flow.
        """
        assert hasattr(other, 'qconfig') or (qconfig is not None)

        input_size = kwargs.get('input_size', other.input_size)
        hidden_size = kwargs.get('hidden_size', other.hidden_size)
        bias = kwargs.get('bias', other.bias)
        batch_first = kwargs.get('batch_first', other.batch_first)
        bidirectional = kwargs.get('bidirectional', other.bidirectional)

        layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
        layer.qconfig = getattr(other, 'qconfig', qconfig)
        wi = getattr(other, f'weight_ih_l{layer_idx}')
        wh = getattr(other, f'weight_hh_l{layer_idx}')
        bi = getattr(other, f'bias_ih_l{layer_idx}', None)
        bh = getattr(other, f'bias_hh_l{layer_idx}', None)

        layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)

        if other.bidirectional:
            wi = getattr(other, f'weight_ih_l{layer_idx}_reverse')
            wh = getattr(other, f'weight_hh_l{layer_idx}_reverse')
            bi = getattr(other, f'bias_ih_l{layer_idx}_reverse', None)
            bh = getattr(other, f'bias_hh_l{layer_idx}_reverse', None)
            layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
        return layer
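
    # Illustrative (hypothetical float module): for `float_lstm = torch.nn.LSTM(10, 20, 2)`
    # with `float_lstm.qconfig` set, layer 1 would be rebuilt via
    # `_LSTMLayer.from_float(float_lstm, layer_idx=1)`, which pulls `weight_ih_l1` /
    # `weight_hh_l1` (and the `_reverse` variants when bidirectional).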


class LSTM(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM).

    For the description and the argument types, please refer to :class:`~torch.nn.LSTM`.

    Attributes:
        layers: instances of the `_LSTMLayer`

    .. note::
        To access the weights and biases, you need to access them per layer.
        See examples below.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
        >>> # To get the weights:
        >>> # xdoctest: +SKIP
        >>> print(rnn.layers[0].weight_ih)
        tensor([[...]])
        >>> print(rnn.layers[0].weight_hh)
        AssertionError: There is no reverse path in the non-bidirectional layer
    """
    _FLOAT_MODULE = torch.nn.LSTM

    def __init__(self, input_size: int, hidden_size: int,
                 num_layers: int = 1, bias: bool = True,
                 batch_first: bool = False, dropout: float = 0.,
                 bidirectional: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.training = False  # Default to eval mode. If we want to train, we will explicitly set it to training mode.
        num_directions = 2 if bidirectional else 1

        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
                isinstance(dropout, bool):
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        if dropout > 0:
            warnings.warn("dropout option for quantizable LSTM is ignored. "
                          "If you are training, please use the nn.LSTM version "
                          "followed by the `prepare` step.")
            if num_layers == 1:
                warnings.warn("dropout option adds dropout after all but last "
                              "recurrent layer, so non-zero dropout expects "
                              f"num_layers greater than 1, but got dropout={dropout} "
                              f"and num_layers={num_layers}")

        layers = [_LSTMLayer(self.input_size, self.hidden_size,
                             self.bias, batch_first=False,
                             bidirectional=self.bidirectional, **factory_kwargs)]
        for layer in range(1, num_layers):
            layers.append(_LSTMLayer(self.hidden_size, self.hidden_size,
                                     self.bias, batch_first=False,
                                     bidirectional=self.bidirectional,
                                     **factory_kwargs))
        self.layers = torch.nn.ModuleList(layers)
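
    # `hidden` may arrive either in the flat nn.LSTM layout, i.e. tensors of shape
    # (num_layers * num_directions, batch, hidden_size), or as a per-layer list of
    # (h, c) pairs; `forward` normalizes it to the latter before looping over layers.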
    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        if self.batch_first:
            x = x.transpose(0, 1)

        max_batch_size = x.size(1)
        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            zeros = torch.zeros(num_directions, max_batch_size,
                                self.hidden_size, dtype=torch.float,
                                device=x.device)
            zeros.squeeze_(0)
            if x.is_quantized:
                zeros = torch.quantize_per_tensor(zeros, scale=1.0,
                                                  zero_point=0, dtype=x.dtype)
            hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
        else:
            hidden_non_opt = torch.jit._unwrap_optional(hidden)
            if isinstance(hidden_non_opt[0], Tensor):
                hx = hidden_non_opt[0].reshape(self.num_layers, num_directions,
                                               max_batch_size,
                                               self.hidden_size)
                cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
                                               max_batch_size,
                                               self.hidden_size)
                hxcx = [(hx[idx].squeeze(0), cx[idx].squeeze(0)) for idx in range(self.num_layers)]
            else:
                hxcx = hidden_non_opt

        hx_list = []
        cx_list = []
        for idx, layer in enumerate(self.layers):
            x, (h, c) = layer(x, hxcx[idx])
            hx_list.append(torch.jit._unwrap_optional(h))
            cx_list.append(torch.jit._unwrap_optional(c))
        hx_tensor = torch.stack(hx_list)
        cx_tensor = torch.stack(cx_list)

        # We are creating another dimension for the bidirectional case;
        # we need to collapse it here.
        hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
        cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])

        if self.batch_first:
            x = x.transpose(0, 1)

        return x, (hx_tensor, cx_tensor)

    def _get_name(self):
        return 'QuantizableLSTM'

    @classmethod
    def from_float(cls, other, qconfig=None):
        assert isinstance(other, cls._FLOAT_MODULE)
        assert (hasattr(other, 'qconfig') or qconfig)
        observed = cls(other.input_size, other.hidden_size, other.num_layers,
                       other.bias, other.batch_first, other.dropout,
                       other.bidirectional)
        observed.qconfig = getattr(other, 'qconfig', qconfig)
        for idx in range(other.num_layers):
            observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig,
                                                         batch_first=False)

        # Prepare the model
        if other.training:
            observed.train()
            observed = torch.ao.quantization.prepare_qat(observed, inplace=True)
        else:
            observed.eval()
            observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed
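
    # Sketch of a typical post-training calibration pass over the observed module
    # (illustrative; `calib_data` is a hypothetical iterable of input sequences):
    #
    #     observed = LSTM.from_float(float_lstm)
    #     observed.eval()
    #     with torch.no_grad():
    #         for seq in calib_data:
    #             observed(seq)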

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        raise NotImplementedError("It looks like you are trying to convert a "
                                  "non-quantizable LSTM module. Please see "
                                  "the examples on quantizable LSTMs.")