Files
pytorch/torch/ao/quantization/experimental/quantizer.py
Nikita Shulga 634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP 484 violations (i.e. when a default arg is set to None, but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  -
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00

137 lines
5.5 KiB
Python

import torch
from torch import Tensor
import numpy as np
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float, quant_dequant_util
# class to store APoT quantizer and
# implement quantize and dequantize
class APoTQuantizer():
alpha: torch.Tensor
gamma: torch.Tensor
quantization_levels: torch.Tensor
level_indices: torch.Tensor
def __init__(
self,
alpha: torch.Tensor,
gamma: torch.Tensor,
quantization_levels: torch.Tensor,
level_indices: torch.Tensor) -> None:
self.alpha = alpha
self.gamma = gamma
self.quantization_levels = quantization_levels
self.level_indices = level_indices
r""" Quantizes fp Tensor to integer APoT representation.
Conversion is based on the qparams from a specified APoT non-uniform observer.
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
Args:
tensor2quantize: fp Tensor
Returns:
result: APoT Tensor representation of tensor2quantize
"""
def quantize(self, tensor2quantize: Tensor):
result = torch.tensor([])
# map float_to_apot over tensor2quantize elements
tensor2quantize = tensor2quantize.detach().apply_(lambda x: float_to_apot(x,
self.quantization_levels,
self.level_indices,
self.alpha))
# convert to APoT int representation for dtype
tensor2quantize = tensor2quantize.int()
from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
result = TensorAPoT(self, tensor2quantize) # type: ignore[assignment]
return result
r""" Dequantizes integer Tensor to floating point (fp) representation
based on the calculated quantization levels from a specified APoT non-uniform observer.
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
Args:
tensor2quantize: fp Tensor
Returns:
result: fp reduced precision representation of input Tensor
"""
def dequantize(self, apot_tensor) -> Tensor:
orig_size = apot_tensor.data.size()
apot_tensor_data = apot_tensor.data.flatten()
print(apot_tensor_data)
# map apot_to_float over tensor2quantize elements
result_temp = np.empty(shape=apot_tensor_data.size())
for i in range(len(apot_tensor_data)):
new_ele = apot_to_float(apot_tensor_data[i], self.quantization_levels, self.level_indices)
result_temp[i] = new_ele
result = torch.from_numpy(result_temp).reshape(orig_size)
return result
r""" Returns result of quantize -> dequantize on a fp Tensor (reduced precision)
based on the calculated quantization levels from a specified APoT non-uniform observer.
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
Args:
apot_tensor: quantized APoT Tensor to dequantize
Returns:
result: fp representation of input Tensor
"""
def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
levels_lst = list(self.quantization_levels)
result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst)) # type: ignore[call-arg]
return result
def q_apot_alpha(self) -> float:
raise NotImplementedError
def quantize_APoT(tensor2quantize: Tensor, alpha: Tensor, gamma: Tensor, quantization_levels: Tensor, level_indices: Tensor):
    r"""Global method to create an APoTQuantizer and call its quantize method.

    Args:
        tensor2quantize: fp Tensor to quantize
        alpha: Tensor qparam alpha (clipping level)
        gamma: Tensor qparam gamma (scale factor for quantization levels)
        quantization_levels: Tensor with fp quantization levels
        level_indices: Tensor with integer quantization level indices
    Returns:
        result: APoT Tensor representation of tensor2quantize, as returned by
            ``APoTQuantizer.quantize``
    """
    quantizer = APoTQuantizer(
        alpha=alpha,
        gamma=gamma,
        quantization_levels=quantization_levels,
        level_indices=level_indices,
    )
    return quantizer.quantize(tensor2quantize)
def dequantize_APoT(apot_tensor) -> Tensor:
    r"""Global method to dequantize an APoT Tensor with the quantizer attached to it.

    No new quantizer is created: the call is forwarded to
    ``apot_tensor.quantizer.dequantize``.

    Args:
        apot_tensor: APoT Tensor to dequantize; must expose a ``quantizer`` attribute
    Returns:
        result: fp Tensor dequantized from apot_tensor
    """
    return apot_tensor.quantizer.dequantize(apot_tensor)
def quant_dequant_APoT(tensor2quantize: Tensor,
                       alpha: Tensor,
                       gamma: Tensor,
                       quantization_levels: Tensor,
                       level_indices: Tensor) -> Tensor:
    r"""Global method to create an APoTQuantizer and call its quant_dequant method.

    Args:
        tensor2quantize: fp Tensor to quantize; ``APoTQuantizer.quant_dequant``
            applies the conversion to this Tensor in place
        alpha: Tensor qparam alpha (clipping level)
        gamma: Tensor qparam gamma (scale factor for quantization levels)
        quantization_levels: Tensor with fp quantization levels
        level_indices: Tensor with integer quantization level indices
    Returns:
        result: fp reduced precision Tensor from tensor2quantize
    """
    quantizer = APoTQuantizer(
        alpha=alpha,
        gamma=gamma,
        quantization_levels=quantization_levels,
        level_indices=level_indices,
    )
    return quantizer.quant_dequant(tensor2quantize)