# mypy: allow-untyped-defs
import numbers
import warnings
from typing_extensions import deprecated

import torch
import torch.nn as nn
from torch import Tensor  # noqa: F401
from torch._jit_internal import Dict, List, Optional, Tuple, Union  # noqa: F401
from torch.ao.nn.quantized.modules.utils import _quantize_weight
from torch.nn.utils.rnn import PackedSequence


__all__ = [
    "pack_weight_bias",
    "PackedParameter",
    "RNNBase",
    "LSTM",
    "GRU",
    "RNNCellBase",
    "RNNCell",
    "LSTMCell",
    "GRUCell",
    "apply_permutation",
]


def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return tensor.index_select(dim, permutation)


@deprecated(
    "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead",
    category=FutureWarning,
)
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return _apply_permutation(tensor, permutation, dim)


def pack_weight_bias(qweight, bias, dtype):
    if dtype == torch.qint8:
        # for each layer, for each direction we need to quantize and pack
        # weights and pack parameters in this order:
        #
        #   w_ih, w_hh
        packed_weight = torch.ops.quantized.linear_prepack(qweight, bias)

        return packed_weight
    else:
        # for each layer, for each direction we need to quantize and pack
        # weights and pack parameters in this order:
        #
        #   packed_ih, packed_hh, b_ih, b_hh
        packed_weight = torch.ops.quantized.linear_prepack_fp16(qweight, bias)

        return packed_weight


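# A minimal sketch of how `pack_weight_bias` is typically used (illustrative only;
# the shapes, scale, and zero_point below are made-up values, not ones produced by
# a real observer):
#
#     w = torch.randn(12, 4)
#     qw = torch.quantize_per_tensor(w, scale=0.1, zero_point=0, dtype=torch.qint8)
#     packed = pack_weight_bias(qw, torch.zeros(12), torch.qint8)
#
# The returned object is an opaque prepacked weight that the dynamic quantized RNN
# ops below consume directly.
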
class PackedParameter(torch.nn.Module):
    def __init__(self, param):
        super().__init__()
        self.param = param

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "param"] = self.param

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        self.param = state_dict[prefix + "param"]
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


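# Note: `PackedParameter` exists so that the opaque prepacked cell-params object is
# carried through state_dict save/load; `RNNBase.__repr__` below deliberately skips
# these wrappers when printing the module.
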
class RNNBase(torch.nn.Module):
    _FLOAT_MODULE = nn.RNNBase

    _version = 2

    def __init__(
        self,
        mode,
        input_size,
        hidden_size,
        num_layers=1,
        bias=True,
        batch_first=False,
        dropout=0.0,
        bidirectional=False,
        dtype=torch.qint8,
    ):
        super().__init__()

        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.dtype = dtype
        self.version = 2
        self.training = False
        num_directions = 2 if bidirectional else 1

        # "type: ignore" is required since ints and Numbers are not fully comparable
        # https://github.com/python/mypy/issues/8566
        if (
            not isinstance(dropout, numbers.Number)
            or not 0 <= dropout <= 1  # type: ignore[operator]
            or isinstance(dropout, bool)
        ):
            raise ValueError(
                "dropout should be a number in range [0, 1] "
                "representing the probability of an element being "
                "zeroed"
            )
        if dropout > 0 and num_layers == 1:  # type: ignore[operator]
            warnings.warn(
                "dropout option adds dropout after all but last "
                "recurrent layer, so non-zero dropout expects "
                f"num_layers greater than 1, but got dropout={dropout} and "
                f"num_layers={num_layers}"
            )

        if mode == "LSTM":
            gate_size = 4 * hidden_size
        elif mode == "GRU":
            gate_size = 3 * hidden_size
        else:
            raise ValueError("Unrecognized RNN mode: " + mode)

        _all_weight_values = []
        for layer in range(num_layers):
            for _ in range(num_directions):
                layer_input_size = (
                    input_size if layer == 0 else hidden_size * num_directions
                )

                w_ih = torch.randn(gate_size, layer_input_size).to(torch.float)
                w_hh = torch.randn(gate_size, hidden_size).to(torch.float)
                b_ih = torch.randn(gate_size).to(torch.float)
                b_hh = torch.randn(gate_size).to(torch.float)
                if dtype == torch.qint8:
                    w_ih = torch.quantize_per_tensor(
                        w_ih, scale=0.1, zero_point=0, dtype=torch.qint8
                    )
                    w_hh = torch.quantize_per_tensor(
                        w_hh, scale=0.1, zero_point=0, dtype=torch.qint8
                    )
                    packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh)
                    if self.version is None or self.version < 2:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, b_ih, b_hh
                            )
                        )
                    else:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, b_ih, b_hh, True
                            )
                        )
                else:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh
                    )

                _all_weight_values.append(PackedParameter(cell_params))
        self._all_weight_values = torch.nn.ModuleList(_all_weight_values)

    def _get_name(self):
        return "DynamicQuantizedRNN"

    def extra_repr(self):
        s = "{input_size}, {hidden_size}"
        if self.num_layers != 1:
            s += ", num_layers={num_layers}"
        if self.bias is not True:
            s += ", bias={bias}"
        if self.batch_first is not False:
            s += ", batch_first={batch_first}"
        if self.dropout != 0:
            s += ", dropout={dropout}"
        if self.bidirectional is not False:
            s += ", bidirectional={bidirectional}"
        return s.format(**self.__dict__)

    def __repr__(self):
        # We don't want to show `ModuleList` children, hence custom
        # `__repr__`. This is the same as nn.Module.__repr__, except the check
        # for the `PackedParameter` and `nn.ModuleList`.
        # You should still override `extra_repr` to add more info.
        extra_lines = []
        extra_repr = self.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split("\n")
        child_lines = []
        for key, module in self._modules.items():
            if isinstance(module, (PackedParameter, nn.ModuleList)):
                continue
            mod_str = repr(module)
            mod_str = nn.modules.module._addindent(mod_str, 2)
            child_lines.append("(" + key + "): " + mod_str)
        lines = extra_lines + child_lines

        main_str = self._get_name() + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += "\n  " + "\n  ".join(lines) + "\n"

        main_str += ")"
        return main_str

    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                f"input must have {expected_input_dim} dimensions, got {input.dim()}"
            )
        if self.input_size != input.size(-1):
            raise RuntimeError(
                f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}"
            )

    def get_expected_hidden_size(
        self, input: Tensor, batch_sizes: Optional[Tensor]
    ) -> tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (
            self.num_layers * num_directions,
            mini_batch,
            self.hidden_size,
        )
        return expected_hidden_size

    def check_hidden_size(
        self,
        hx: Tensor,
        expected_hidden_size: tuple[int, int, int],
        msg: str = "Expected hidden size {}, got {}",
    ) -> None:
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))

    def check_forward_args(
        self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
    ) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
        self.check_hidden_size(
            hidden, expected_hidden_size, msg="Expected hidden size {}, got {}"
        )

    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
        if permutation is None:
            return hx
        return _apply_permutation(hx, permutation)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)
        self.version = version
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def set_weight_bias(self, weight_bias_dict):
        def weight_bias_name(ihhh, layer, suffix):
            weight_name = f"weight_{ihhh}_l{layer}{suffix}"
            bias_name = f"bias_{ihhh}_l{layer}{suffix}"
            return weight_name, bias_name

        num_directions = 2 if self.bidirectional else 1
        # TODO: dedup with __init__ of RNNBase
        _all_weight_values = []
        for layer in range(self.num_layers):
            for direction in range(num_directions):
                suffix = "_reverse" if direction == 1 else ""
                w_ih_name, b_ih_name = weight_bias_name("ih", layer, suffix)
                w_hh_name, b_hh_name = weight_bias_name("hh", layer, suffix)
                w_ih = weight_bias_dict[w_ih_name]
                b_ih = weight_bias_dict[b_ih_name]
                w_hh = weight_bias_dict[w_hh_name]
                b_hh = weight_bias_dict[b_hh_name]
                if w_ih.dtype == torch.qint8:
                    packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh)
                    if self.version is None or self.version < 2:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, b_ih, b_hh
                            )
                        )
                    else:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, b_ih, b_hh, True
                            )
                        )
                else:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh
                    )

                _all_weight_values.append(PackedParameter(cell_params))
        self._all_weight_values = torch.nn.ModuleList(_all_weight_values)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        assert type(mod) in {
            torch.nn.LSTM,
            torch.nn.GRU,
        }, "nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU"
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"

        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer_method = mod.qconfig.weight
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig

            weight_observer_method = default_dynamic_qconfig.weight

        dtype = weight_observer_method().dtype
        supported_scalar_types = [torch.qint8, torch.float16]
        if dtype not in supported_scalar_types:
            raise RuntimeError(
                f"Unsupported dtype for dynamic RNN quantization: {dtype}"
            )
        # RNNBase can be either LSTM or GRU
        qRNNBase: Union[LSTM, GRU]
        if mod.mode == "LSTM":
            qRNNBase = LSTM(
                mod.input_size,
                mod.hidden_size,
                mod.num_layers,
                mod.bias,
                mod.batch_first,
                mod.dropout,
                mod.bidirectional,
                dtype,
            )
        elif mod.mode == "GRU":
            qRNNBase = GRU(
                mod.input_size,
                mod.hidden_size,
                mod.num_layers,
                mod.bias,
                mod.batch_first,
                mod.dropout,
                mod.bidirectional,
                dtype,
            )
        else:
            raise NotImplementedError(
                "Only LSTM/GRU is supported for QuantizedRNN for now"
            )

        num_directions = 2 if mod.bidirectional else 1

        assert mod.bias

        _all_weight_values = []
        for layer in range(qRNNBase.num_layers):
            for direction in range(num_directions):
                suffix = "_reverse" if direction == 1 else ""

                def retrieve_weight_bias(ihhh):
                    weight_name = f"weight_{ihhh}_l{layer}{suffix}"
                    bias_name = f"bias_{ihhh}_l{layer}{suffix}"
                    weight = getattr(mod, weight_name)
                    bias = getattr(mod, bias_name)
                    return weight, bias

                weight_ih, bias_ih = retrieve_weight_bias("ih")
                weight_hh, bias_hh = retrieve_weight_bias("hh")

                if dtype == torch.qint8:

                    def quantize_and_pack(w, b):
                        weight_observer = weight_observer_method()
                        weight_observer(w)
                        qweight = _quantize_weight(w.float(), weight_observer)
                        packed_weight = torch.ops.quantized.linear_prepack(qweight, b)
                        return packed_weight

                    packed_ih = quantize_and_pack(weight_ih, bias_ih)
                    packed_hh = quantize_and_pack(weight_hh, bias_hh)
                    if qRNNBase.version is None or qRNNBase.version < 2:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, bias_ih, bias_hh
                            )
                        )
                    else:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, bias_ih, bias_hh, True
                            )
                        )

                elif dtype == torch.float16:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(
                        weight_ih.float(), bias_ih
                    )
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(
                        weight_hh.float(), bias_hh
                    )

                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh
                    )
                else:
                    raise RuntimeError(
                        "Unsupported dtype specified for dynamic quantized LSTM!"
                    )

                _all_weight_values.append(PackedParameter(cell_params))
        qRNNBase._all_weight_values = torch.nn.ModuleList(_all_weight_values)

        return qRNNBase

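    # A minimal conversion sketch for `from_float` (illustrative only; the sizes are
    # made up and `default_dynamic_qconfig` is the stock dynamic qconfig shipped with
    # torch.ao.quantization):
    #
    #     float_lstm = torch.nn.LSTM(10, 20, 2)
    #     float_lstm.qconfig = torch.ao.quantization.default_dynamic_qconfig
    #     q_lstm = LSTM.from_float(float_lstm)  # returns a DynamicQuantizedLSTM
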
    def _weight_bias(self):
        # Returns a dict of weights and biases
        weight_bias_dict: Dict[str, Dict] = {"weight": {}, "bias": {}}
        count = 0
        num_directions = 2 if self.bidirectional else 1
        for layer in range(self.num_layers):
            for direction in range(num_directions):
                suffix = "_reverse" if direction == 1 else ""
                key_name1 = f"weight_ih_l{layer}{suffix}"
                key_name2 = f"weight_hh_l{layer}{suffix}"
                # packed weights are part of torchbind class, CellParamsSerializationType
                # Within the packed weight class, the weight and bias are accessible as Tensors
                packed_weight_bias = self._all_weight_values[  # type: ignore[index]
                    count
                ].param.__getstate__()[0][4]
                weight_bias_dict["weight"][key_name1] = packed_weight_bias[
                    0
                ].__getstate__()[0][0]
                weight_bias_dict["weight"][key_name2] = packed_weight_bias[
                    1
                ].__getstate__()[0][0]
                key_name1 = f"bias_ih_l{layer}{suffix}"
                key_name2 = f"bias_hh_l{layer}{suffix}"
                weight_bias_dict["bias"][key_name1] = packed_weight_bias[
                    0
                ].__getstate__()[0][1]
                weight_bias_dict["bias"][key_name2] = packed_weight_bias[
                    1
                ].__getstate__()[0][1]
                count = count + 1
        return weight_bias_dict

    def get_weight(self):
        return self._weight_bias()["weight"]

    def get_bias(self):
        return self._weight_bias()["bias"]


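# A sketch of what `get_weight()` / `get_bias()` return for a one-layer,
# unidirectional module (`q_lstm` is a hypothetical converted instance; the keys
# mirror the float nn.LSTM/nn.GRU parameter names):
#
#     q_lstm.get_weight()  # {"weight_ih_l0": <quantized weight>, "weight_hh_l0": ...}
#     q_lstm.get_bias()    # {"bias_ih_l0": <bias tensor>, "bias_hh_l0": ...}
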
class LSTM(RNNBase):
    r"""
    A dynamic quantized LSTM module with floating point tensors as inputs and outputs.
    We adopt the same interface as `torch.nn.LSTM`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
    """

    # pyrefly: ignore  # bad-override
    _FLOAT_MODULE = nn.LSTM

    __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

    def __init__(self, *args, **kwargs):
        super().__init__("LSTM", *args, **kwargs)

    def _get_name(self):
        return "DynamicQuantizedLSTM"

    def forward_impl(
        self,
        input: Tensor,
        hx: Optional[tuple[Tensor, Tensor]],
        batch_sizes: Optional[Tensor],
        max_batch_size: int,
        sorted_indices: Optional[Tensor],
    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            hx = (zeros, zeros)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)

        _all_params = [m.param for m in self._all_weight_values]
        if batch_sizes is None:
            result = torch.quantized_lstm(
                input,
                hx,
                _all_params,
                self.bias,
                self.num_layers,
                float(self.dropout),
                self.training,
                self.bidirectional,
                self.batch_first,
                dtype=self.dtype,
                use_dynamic=True,
            )
        else:
            result = torch.quantized_lstm(
                input,
                batch_sizes,
                hx,
                _all_params,
                self.bias,
                self.num_layers,
                float(self.dropout),
                self.training,
                self.bidirectional,
                dtype=self.dtype,
                use_dynamic=True,
            )
        output = result[0]
        hidden = result[1:]

        return output, hidden

    @torch.jit.export
    def forward_tensor(
        self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None
    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices
        )

        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.export
    def forward_packed(
        self, input: PackedSequence, hx: Optional[tuple[Tensor, Tensor]] = None
    ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]:
        input_, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = int(batch_sizes[0])

        output_, hidden = self.forward_impl(
            input_, hx, batch_sizes, max_batch_size, sorted_indices
        )

        output = PackedSequence(output_, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    # "type: ignore" is required due to issue #43072
    def permute_hidden(  # type: ignore[override]
        self,
        hx: tuple[Tensor, Tensor],
        permutation: Optional[Tensor],
    ) -> tuple[Tensor, Tensor]:
        if permutation is None:
            return hx
        return _apply_permutation(hx[0], permutation), _apply_permutation(
            hx[1], permutation
        )

    # "type: ignore" is required due to issue #43072
    def check_forward_args(  # type: ignore[override]
        self,
        input: Tensor,
        hidden: tuple[Tensor, Tensor],
        batch_sizes: Optional[Tensor],
    ) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(
            hidden[0], expected_hidden_size, "Expected hidden[0] size {}, got {}"
        )
        self.check_hidden_size(
            hidden[1], expected_hidden_size, "Expected hidden[1] size {}, got {}"
        )

    @torch.jit.ignore
    def forward(self, input, hx=None):
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_mod):
        assert hasattr(ref_mod, "weight_ih_l0_dtype"), (
            "We are assuming weight_ih_l0 exists in LSTM, may need to relax the "
            "assumption to support the use case"
        )
        qmod = cls(
            ref_mod.input_size,
            ref_mod.hidden_size,
            ref_mod.num_layers,
            ref_mod.bias,
            ref_mod.batch_first,
            ref_mod.dropout,
            ref_mod.bidirectional,
            # assuming there is layer 0, which should be OK
            ref_mod.weight_ih_l0_dtype,
        )
        qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict())
        return qmod


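# A minimal end-to-end sketch of how a float LSTM typically becomes a
# DynamicQuantizedLSTM via the standard dynamic-quantization entry point
# (illustrative only; the wrapper module and sizes below are made up):
#
#     class Seq(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.lstm = torch.nn.LSTM(10, 20, 2)
#
#         def forward(self, x):
#             return self.lstm(x)
#
#     qmodel = torch.ao.quantization.quantize_dynamic(
#         Seq(), {torch.nn.LSTM}, dtype=torch.qint8
#     )
#     out, (hn, cn) = qmodel(torch.randn(5, 3, 10))
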
class GRU(RNNBase):
    r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.

    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll}
            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
            n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)} + b_{hn})) \\
            h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
    at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
    at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
    :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

    In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two GRUs together to form a `stacked GRU`,
            with the second GRU taking in outputs of the first GRU and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            GRU layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``

    Inputs: input, h_0
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence. The input can also be a packed variable length
          sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
          for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          Defaults to zero if not provided. If the RNN is bidirectional,
          num_directions should be 2, else it should be 1.

    Outputs: output, h_n
        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
          containing the output features h_t from the last layer of the GRU,
          for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
          given as the input, the output will also be a packed sequence.
          For the unpacked case, the directions can be separated
          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
          with forward and backward being direction `0` and `1` respectively.

          Similarly, the directions can be separated in the packed case.
        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the hidden state for `t = seq_len`

          Like *output*, the layers can be separated using
          ``h_n.view(num_layers, num_directions, batch, hidden_size)``.

    Shape:
        - Input1: :math:`(L, N, H_{in})` tensor containing input features where
          :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
        - Input2: :math:`(S, N, H_{out})` tensor
          containing the initial hidden state for each element in the batch,
          where :math:`H_{out}=\text{hidden\_size}` and
          :math:`S=\text{num\_layers} * \text{num\_directions}`.
          Defaults to zero if not provided. If the RNN is bidirectional,
          num_directions should be 2, else it should be 1.
        - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}`
        - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
            Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. note::
        The calculation of the new gate :math:`n_t` subtly differs from the original paper and other frameworks.
        In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the
        previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix
        `W` and addition of bias:

        .. math::
            \begin{aligned}
                n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn})
            \end{aligned}

        This is in contrast to the PyTorch implementation, where it is done after :math:`W_{hn} h_{(t-1)}`

        .. math::
            \begin{aligned}
                n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)} + b_{hn}))
            \end{aligned}

        This implementation differs on purpose for efficiency.

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.GRU(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
    """

    # pyrefly: ignore  # bad-override
    _FLOAT_MODULE = nn.GRU

    __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

    def __init__(self, *args, **kwargs):
        super().__init__("GRU", *args, **kwargs)

    def _get_name(self):
        return "DynamicQuantizedGRU"

    def check_forward_args(
        self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
    ) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(
            hidden, expected_hidden_size, "Expected hidden size {}, got {}"
        )

    def forward_impl(
        self,
        input: Tensor,
        hx: Optional[Tensor],
        batch_sizes: Optional[Tensor],
        max_batch_size: int,
        sorted_indices: Optional[Tensor],
    ) -> tuple[Tensor, Tensor]:
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            hx = zeros
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)

        _all_params = [m.param for m in self._all_weight_values]
        if batch_sizes is None:
            result = torch.quantized_gru(
                input,
                hx,
                _all_params,
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
                self.batch_first,
            )
        else:
            result = torch.quantized_gru(
                input,
                batch_sizes,
                hx,
                _all_params,
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
            )
        output = result[0]
        hidden = result[1]

        return output, hidden

    @torch.jit.export
    def forward_tensor(
        self, input: Tensor, hx: Optional[Tensor] = None
    ) -> tuple[Tensor, Tensor]:
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices
        )

        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.export
    def forward_packed(
        self, input: PackedSequence, hx: Optional[Tensor] = None
    ) -> tuple[PackedSequence, Tensor]:
        input_, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = int(batch_sizes[0])
        output_, hidden = self.forward_impl(
            input_, hx, batch_sizes, max_batch_size, sorted_indices
        )

        output = PackedSequence(output_, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
        if permutation is None:
            return hx
        return _apply_permutation(hx, permutation)

    @torch.jit.ignore
    def forward(self, input, hx=None):
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_mod):
        assert hasattr(ref_mod, "weight_ih_l0_dtype"), (
            "We are assuming weight_ih_l0 exists in GRU, may need to relax the "
            "assumption to support the use case"
        )
        qmod = cls(
            ref_mod.input_size,
            ref_mod.hidden_size,
            ref_mod.num_layers,
            ref_mod.bias,
            ref_mod.batch_first,
            ref_mod.dropout,
            ref_mod.bidirectional,
            # assuming there is layer 0, which should be OK
            ref_mod.weight_ih_l0_dtype,
        )
        qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict())
        return qmod


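# As with the LSTM above, a float GRU is usually converted through dynamic
# quantization (illustrative sketch; `my_model_containing_gru` is a placeholder
# for any module that owns a float nn.GRU):
#
#     qmodel = torch.ao.quantization.quantize_dynamic(
#         my_model_containing_gru, {torch.nn.GRU}, dtype=torch.qint8
#     )
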
class RNNCellBase(torch.nn.Module):
    # _FLOAT_MODULE = nn.CellRNNBase
    __constants__ = ["input_size", "hidden_size", "bias"]

    def __init__(
        self, input_size, hidden_size, bias=True, num_chunks=4, dtype=torch.qint8
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.weight_dtype = dtype
        if bias:
            self.bias_ih = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
            self.bias_hh = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
        else:
            self.register_parameter("bias_ih", None)
            self.register_parameter("bias_hh", None)

        weight_ih = torch.randn(num_chunks * hidden_size, input_size).to(torch.float)
        weight_hh = torch.randn(num_chunks * hidden_size, hidden_size).to(torch.float)
        if dtype == torch.qint8:
            weight_ih = torch.quantize_per_tensor(
                weight_ih, scale=1, zero_point=0, dtype=torch.qint8
            )
            weight_hh = torch.quantize_per_tensor(
                weight_hh, scale=1, zero_point=0, dtype=torch.qint8
            )

        if dtype == torch.qint8:
            # for each layer, for each direction we need to quantize and pack
            # weights and pack parameters in this order:
            #
            #   w_ih, w_hh
            packed_weight_ih = torch.ops.quantized.linear_prepack(
                weight_ih, self.bias_ih
            )
            packed_weight_hh = torch.ops.quantized.linear_prepack(
                weight_hh, self.bias_hh
            )
        else:
            # for each layer, for each direction we need to quantize and pack
            # weights and pack parameters in this order:
            #
            #   packed_ih, packed_hh, b_ih, b_hh
            packed_weight_ih = torch.ops.quantized.linear_prepack_fp16(
                weight_ih, self.bias_ih
            )
            packed_weight_hh = torch.ops.quantized.linear_prepack_fp16(
                weight_hh, self.bias_hh
            )

        self._packed_weight_ih = packed_weight_ih
        self._packed_weight_hh = packed_weight_hh

    def _get_name(self):
        return "DynamicQuantizedRNNBase"

    def extra_repr(self):
        s = "{input_size}, {hidden_size}"
        if "bias" in self.__dict__ and self.bias is not True:
            s += ", bias={bias}"
        if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh":
            s += ", nonlinearity={nonlinearity}"
        return s.format(**self.__dict__)

    def check_forward_input(self, input):
        if input.size(1) != self.input_size:
            raise RuntimeError(
                f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}"
            )

    def check_forward_hidden(
        self, input: Tensor, hx: Tensor, hidden_label: str = ""
    ) -> None:
        if input.size(0) != hx.size(0):
            raise RuntimeError(
                f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}"
            )

        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}"
            )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        assert type(mod) in {
            torch.nn.LSTMCell,
            torch.nn.GRUCell,
            torch.nn.RNNCell,
        }, (
            "nn.quantized.dynamic.RNNCellBase.from_float "
            "only works for nn.LSTMCell, nn.GRUCell and nn.RNNCell"
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"

        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer_method = mod.qconfig.weight
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig

            weight_observer_method = default_dynamic_qconfig.weight

        dtype = weight_observer_method().dtype
        supported_scalar_types = [torch.qint8, torch.float16]
        if dtype not in supported_scalar_types:
            raise RuntimeError(
                f"Unsupported dtype for dynamic RNN quantization: {dtype}"
            )

        qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]

        if type(mod) is torch.nn.LSTMCell:
            qRNNCellBase = LSTMCell(
                mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
            )
        elif type(mod) is torch.nn.GRUCell:
            qRNNCellBase = GRUCell(
                mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
            )
        elif type(mod) is torch.nn.RNNCell:
            qRNNCellBase = RNNCell(
                mod.input_size,
                mod.hidden_size,
                bias=mod.bias,
                nonlinearity=mod.nonlinearity,
                dtype=dtype,
            )
        else:
            raise NotImplementedError(
                "Only LSTMCell, GRUCell and RNNCell "
                "are supported for QuantizedRNN for now"
            )

        assert mod.bias

        def _observe_and_quantize_weight(weight):
            if dtype == torch.qint8:
                weight_observer = weight_observer_method()
                weight_observer(weight)
                qweight = _quantize_weight(weight.float(), weight_observer)
                return qweight
            else:
                return weight.float()

        qRNNCellBase._packed_weight_ih = pack_weight_bias(
            _observe_and_quantize_weight(mod.weight_ih), mod.bias_ih, dtype
        )
        qRNNCellBase._packed_weight_hh = pack_weight_bias(
            _observe_and_quantize_weight(mod.weight_hh), mod.bias_hh, dtype
        )
        return qRNNCellBase

    @classmethod
    def from_reference(cls, ref_mod):
        assert hasattr(ref_mod, "weight_ih_dtype"), (
            "We are assuming weight_ih exists in the reference module, may need to "
            "relax the assumption to support the use case"
        )
        if hasattr(ref_mod, "nonlinearity"):
            qmod = cls(
                ref_mod.input_size,
                ref_mod.hidden_size,
                ref_mod.bias,
                ref_mod.nonlinearity,
                dtype=ref_mod.weight_ih_dtype,
            )
        else:
            qmod = cls(
                ref_mod.input_size,
                ref_mod.hidden_size,
                ref_mod.bias,
                dtype=ref_mod.weight_ih_dtype,
            )
        weight_bias_dict = {
            "weight": {
                "weight_ih": ref_mod.get_quantized_weight_ih(),
                "weight_hh": ref_mod.get_quantized_weight_hh(),
            },
            "bias": {
                "bias_ih": ref_mod.bias_ih,
                "bias_hh": ref_mod.bias_hh,
            },
        }
        qmod.set_weight_bias(weight_bias_dict)
        return qmod

    def _weight_bias(self):
        # Returns a dict of weights and biases
        weight_bias_dict: Dict[str, Dict] = {"weight": {}, "bias": {}}
        w1, b1 = self._packed_weight_ih.__getstate__()[0]
        w2, b2 = self._packed_weight_hh.__getstate__()[0]
        # TODO: these can be simplified to one level? e.g. using weight_ih as key
        # directly
        weight_bias_dict["weight"]["weight_ih"] = w1
        weight_bias_dict["weight"]["weight_hh"] = w2
        weight_bias_dict["bias"]["bias_ih"] = b1
        weight_bias_dict["bias"]["bias_hh"] = b2
        return weight_bias_dict

    def get_weight(self):
        return self._weight_bias()["weight"]

    def get_bias(self):
        return self._weight_bias()["bias"]

    def set_weight_bias(self, weight_bias_dict):
        # TODO: these can be simplified to one level? e.g. using weight_ih as key
        # directly
        self._packed_weight_ih = pack_weight_bias(
            weight_bias_dict["weight"]["weight_ih"],
            weight_bias_dict["bias"]["bias_ih"],
            self.weight_dtype,
        )
        self._packed_weight_hh = pack_weight_bias(
            weight_bias_dict["weight"]["weight_hh"],
            weight_bias_dict["bias"]["bias_hh"],
            self.weight_dtype,
        )

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "_packed_weight_ih"] = self._packed_weight_ih
        destination[prefix + "_packed_weight_hh"] = self._packed_weight_hh

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        self._packed_weight_ih = state_dict.pop(prefix + "_packed_weight_ih")
        self._packed_weight_hh = state_dict.pop(prefix + "_packed_weight_hh")
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class RNNCell(RNNCellBase):
    r"""An Elman RNN cell with tanh or ReLU non-linearity.

    A dynamic quantized RNNCell module with floating point tensors as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.RNNCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.RNNCell for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.RNNCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """

    __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"]

    def __init__(
        self, input_size, hidden_size, bias=True, nonlinearity="tanh", dtype=torch.qint8
    ):
        super().__init__(input_size, hidden_size, bias, num_chunks=1, dtype=dtype)
        self.nonlinearity = nonlinearity

    def _get_name(self):
        return "DynamicQuantizedRNNCell"

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        self.check_forward_hidden(input, hx, "")
        if self.nonlinearity == "tanh":
            ret = torch.ops.quantized.quantized_rnn_tanh_cell_dynamic(
                input,
                hx,
                self._packed_weight_ih,
                self._packed_weight_hh,
                self.bias_ih,
                self.bias_hh,
            )
        elif self.nonlinearity == "relu":
            ret = torch.ops.quantized.quantized_rnn_relu_cell_dynamic(
                input,
                hx,
                self._packed_weight_ih,
                self._packed_weight_hh,
                self.bias_ih,
                self.bias_hh,
            )
        else:
            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")
        return ret

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class LSTMCell(RNNCellBase):
    r"""A long short-term memory (LSTM) cell.

    A dynamic quantized LSTMCell module with floating point tensors as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.LSTMCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.LSTMCell for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.LSTMCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, num_chunks=4, **kwargs)  # type: ignore[misc]

    def _get_name(self):
        return "DynamicQuantizedLSTMCell"

    def forward(
        self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None
    ) -> tuple[Tensor, Tensor]:
        self.check_forward_input(input)
        if hx is None:
            zeros = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
            hx = (zeros, zeros)
        self.check_forward_hidden(input, hx[0], "[0]")
        self.check_forward_hidden(input, hx[1], "[1]")
        return torch.ops.quantized.quantized_lstm_cell_dynamic(
            input,
            hx,
            self._packed_weight_ih,
            self._packed_weight_hh,
            self.bias_ih,
            self.bias_hh,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class GRUCell(RNNCellBase):
    r"""A gated recurrent unit (GRU) cell.

    A dynamic quantized GRUCell module with floating point tensors as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.GRUCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.GRUCell for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.GRUCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """

    def __init__(self, input_size, hidden_size, bias=True, dtype=torch.qint8):
        super().__init__(input_size, hidden_size, bias, num_chunks=3, dtype=dtype)

    def _get_name(self):
        return "DynamicQuantizedGRUCell"

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        self.check_forward_hidden(input, hx, "")
        return torch.ops.quantized.quantized_gru_cell_dynamic(
            input,
            hx,
            self._packed_weight_ih,
            self._packed_weight_hh,
            self.bias_ih,
            self.bias_hh,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )
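

# The cell variants are converted the same way as the full RNN modules above
# (illustrative sketch; `model` is assumed to be any module that owns a float
# nn.GRUCell, and the cell types are assumed to be in the dynamic quantization
# module mapping):
#
#     qmodel = torch.ao.quantization.quantize_dynamic(
#         model, {torch.nn.GRUCell}, dtype=torch.qint8
#     )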