Files
pytorch/torch/ao/quantization/experimental/quantizer.py
asl3 368018530e [quant] Implement forward and backward autograd functions for fake quantize (#81438)
### Summary:
This PR implements custom autograd functions for forward and backward to be used in APoT fake quantization. The implementation follows this doc about custom autograd functions: https://pytorch.org/tutorials/beginner/examples_autograd/polynomial_custom_function.html

### Test Plan:
Run tests with: `python test/quantization/core/experimental/test_fake_quantize.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81438
Approved by: https://github.com/jerryzh168
2022-07-19 02:15:30 +00:00

100 lines
3.8 KiB
Python

import torch
from torch import Tensor
import numpy as np
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
# class to store APoT quantizer and
# implement quantize and dequantize
class APoTQuantizer():
alpha: torch.Tensor
gamma: torch.Tensor
quantization_levels: torch.Tensor
level_indices: torch.Tensor
def __init__(
self,
alpha: torch.Tensor,
gamma: torch.Tensor,
quantization_levels: torch.Tensor,
level_indices: torch.Tensor) -> None:
self.alpha = alpha
self.gamma = gamma
self.quantization_levels = quantization_levels
self.level_indices = level_indices
r""" Quantizes fp Tensor to integer APoT representation.
Conversion is based on the qparams from a specified APoT non-uniform observer.
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
Args:
tensor2quantize: fp Tensor
Returns:
result: APoT Tensor representation of tensor2quantize
"""
def quantize(self, tensor2quantize: Tensor):
result = torch.tensor([])
# map float_to_apot over tensor2quantize elements
tensor2quantize = tensor2quantize.detach().apply_(lambda x: float_to_apot(x,
self.quantization_levels,
self.level_indices,
self.alpha))
# convert to APoT int representation for dtype
tensor2quantize = tensor2quantize.int()
from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
result = TensorAPoT(self, tensor2quantize)
return result
r""" Dequantizes integer Tensor to floating point (fp) representation
based on the calculated quantization levels from a specified APoT non-uniform observer.
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
Args:
apot_tensor: quantized APoT Tensor to dequantize
Returns:
result: fp representation of input Tensor
"""
def dequantize(self, apot_tensor) -> Tensor:
apot_tensor_data = apot_tensor.data
# map apot_to_float over tensor2quantize elements
result_temp = np.empty(apot_tensor_data.size())
for ele in apot_tensor_data:
new_ele = apot_to_float(ele, self.quantization_levels, self.level_indices)
np.append(result_temp, new_ele)
result = torch.from_numpy(result_temp).int()
return result
def q_apot_alpha(self) -> float:
raise NotImplementedError
def quantize_APoT(tensor2quantize: Tensor, alpha: Tensor, gamma: Tensor, quantization_levels: Tensor, level_indices: Tensor):
    r"""Global method to create an APoTQuantizer and call its ``quantize`` method.

    Args:
        tensor2quantize: fp Tensor to quantize
        alpha: Tensor qparam alpha (clipping level)
        gamma: Tensor qparam gamma (scale factor for quantization levels)
        quantization_levels: Tensor with fp quantization levels
        level_indices: Tensor with integer quantization level indices

    Returns:
        result: APoT Tensor representation of tensor2quantize
    """
    quantizer = APoTQuantizer(
        alpha=alpha,
        gamma=gamma,
        quantization_levels=quantization_levels,
        level_indices=level_indices,
    )
    return quantizer.quantize(tensor2quantize)
def dequantize_APoT(apot_tensor) -> Tensor:
    r"""Global method to dequantize an APoT Tensor using its own quantizer.

    Reads ``apot_tensor.quantizer`` and delegates to that quantizer's
    ``dequantize`` method.

    Args:
        apot_tensor: APoT Tensor to dequantize; must have a ``quantizer``
            attribute

    Returns:
        result: fp Tensor dequantized from apot_tensor
    """
    return apot_tensor.quantizer.dequantize(apot_tensor)