This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings. I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling the rule to keep it that way. :) Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519 Approved by: https://github.com/ezyang
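For context, RUF017 flags list flattening via `sum(..., [])`, which rebuilds the accumulator list on every addition and is therefore quadratic in the total number of elements. A minimal sketch of the flagged pattern and its linear-time replacements (the variable names are illustrative, not from the codebase):

import itertools

batches = [[1, 2], [3, 4], [5, 6]]

# Flagged by RUF017: each `+` allocates a brand-new list, so this is O(n^2).
flat_slow = sum(batches, [])

# Linear-time alternatives:
flat_chain = list(itertools.chain.from_iterable(batches))
flat_comp = [x for batch in batches for x in batch]

assert flat_slow == flat_chain == flat_comp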
133 lines
5.9 KiB
Python
import torch
import torch.ao.nn.quantized as nnq
from torch.ao.nn.quantized.modules.utils import _quantize_weight
import torch.ao.nn.intrinsic as nni

__all__ = [
    "Linear",
]


class Linear(nnq.Linear):
    r"""
    A dynamic quantized linear module with floating point tensors as inputs and outputs.

    We adopt the same interface as `torch.nn.Linear`; please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to :class:`torch.nn.Linear`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module which are of
                         shape :math:`(\text{out\_features}, \text{in\_features})`.
        bias (Tensor): the non-learnable floating point bias of the module of shape
                       :math:`(\text{out\_features})`. If :attr:`bias` is ``True``,
                       the values are initialized to zero.

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.quantized.dynamic.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    # The version used in this class is different from the parent class nnq.Linear.
    _version = 4

    def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
        super().__init__(in_features, out_features, bias_, dtype=dtype)
        # We don't muck around with buffers or attributes or anything here
        # to keep the module simple. *Everything* is simply a Python attribute.
        # Serialization logic is explicitly handled in the serialization and
        # deserialization methods below.
        self.version = 4

    def forward(self, x):
        # Note that we can handle the self.bias == None case.
        if self._packed_params.dtype == torch.qint8:
            if self.version is None or self.version < 4:
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params)
            else:
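                # (Explanatory comment added here; best-effort, not from the
                # original file.) For the current serialization version,
                # reduce_range=True asks the backend to quantize activations
                # with a reduced range (one fewer bit), avoiding potential
                # accumulator overflow in the fbgemm x86 kernels.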
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params, reduce_range=True)
        elif self._packed_params.dtype == torch.float16:
            Y = torch.ops.quantized.linear_dynamic_fp16(
                x, self._packed_params._packed_params)
        else:
            raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
        return Y.to(x.dtype)

    def _get_name(self):
        return 'DynamicQuantizedLinear'

    def extra_repr(self):
        extra_repr_str = 'in_features={}, out_features={}, dtype={}'.format(
            self.in_features, self.out_features, self._packed_params.dtype
        )
        if self._packed_params.dtype == torch.qint8:
            extra_repr_str += f', qscheme={self.weight().qscheme()}'
        return extra_repr_str

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        self.version = version
        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                      missing_keys, unexpected_keys, error_msgs)

    @classmethod
    def from_float(cls, mod):
        r"""Create a dynamic quantized module from a float module or qparams_dict

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
                         torch.ao.nn.intrinsic.modules.fused.LinearReLU, torch.ao.nn.qat.dynamic.Linear]

        assert type(mod) in float_modules, \
            'nn.quantized.dynamic.Linear.from_float only works for one of ' + \
            str([float_mod.__name__ for float_mod in float_modules])
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        if type(mod) == nni.LinearReLU:
            mod = mod[0]
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We have circular import issues if we import the qconfig at the top of
            # this file: https://github.com/pytorch/pytorch/pull/24231. The current
            # workaround is to postpone the import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig
            weight_observer = default_dynamic_qconfig.weight()
        dtype = weight_observer.dtype
        assert dtype in [torch.qint8, torch.float16], "The only supported dtypes for " \
            f"dynamic quantized linear are qint8 and float16, got: {dtype}"
        weight_observer(mod.weight)
        if dtype == torch.qint8:
            qweight = _quantize_weight(mod.weight.float(), weight_observer)
        elif dtype == torch.float16:
            qweight = mod.weight.float()
        else:
            raise RuntimeError('Unsupported dtype specified for dynamic quantized Linear!')
        qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias)
        return qlinear
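
    # (Illustrative note, not part of the original file.) In typical workflows,
    # `from_float` is invoked indirectly through the eager-mode dynamic
    # quantization API rather than called by hand, e.g.:
    #
    #     float_model = torch.nn.Sequential(torch.nn.Linear(20, 30))
    #     qmodel = torch.ao.quantization.quantize_dynamic(
    #         float_model, {torch.nn.Linear}, dtype=torch.qint8)
    #     # qmodel[0] is now a DynamicQuantizedLinear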

    @classmethod
    def from_reference(cls, ref_qlinear):
        """Create a (fbgemm/qnnpack) dynamic quantized module from a reference
        quantized module

        Args:
            ref_qlinear (Module): a reference quantized module, either produced by
                                  torch.ao.quantization functions or provided by the user
        """
        qlinear = cls(ref_qlinear.in_features, ref_qlinear.out_features, dtype=ref_qlinear.weight_dtype)
        qweight = ref_qlinear.get_quantized_weight()
        bias = ref_qlinear.bias
        qlinear.set_weight_bias(qweight, bias)
        return qlinear
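
To illustrate the versioned serialization that `_version`, `self.version`, and `_load_from_state_dict` manage above, here is a minimal round-trip sketch (assumptions: a standard PyTorch build with a dynamic-quantization backend such as fbgemm available; module sizes and input shapes are illustrative):

import torch
import torch.ao.nn.quantized.dynamic as nnqd

m = nnqd.Linear(4, 2, dtype=torch.qint8)
sd = m.state_dict()  # state_dict metadata carries the class's `_version` (4)

m2 = nnqd.Linear(4, 2, dtype=torch.qint8)
m2.load_state_dict(sd)  # `_load_from_state_dict` records the loaded version

x = torch.randn(3, 4)
# Identical packed weights and input give identical dynamic-quantized outputs.
torch.testing.assert_close(m(x), m2(x))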