Add nonuniform observer class and tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78680

Approved by: https://github.com/dzdang
This commit is contained in:
asl3
2022-06-02 04:58:29 -07:00
committed by PyTorch MergeBot
parent eb88ea01b5
commit 308d813d45
2 changed files with 70 additions and 0 deletions

View File

@ -0,0 +1,23 @@
# Owner(s): ["oncall: quantization"]
from torch.ao.quantization.experimental.observer import APoTObserver
import unittest
import torch
class TestNonUniformObserver(unittest.TestCase):
def test_calculate_qparams(self):
t = torch.Tensor()
obs = APoTObserver(t, t, t, 0, 0)
with self.assertRaises(NotImplementedError):
obs.calculate_qparams()
def test_override_calculate_qparams(self):
t = torch.Tensor()
obs = APoTObserver(t, t, t, 0, 0)
with self.assertRaises(NotImplementedError):
obs._calculate_qparams()
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,47 @@
"""
This module implements nonuniform observers used to collect statistics about
the values observed during calibration (PTQ) or training (QAT).
"""
import torch
from torch.ao.quantization.observer import ObserverBase
from typing import Tuple
class NonUniformQuantizationObserverBase(ObserverBase):
quant_min = None
quant_max = None
def __init__(
self,
min_val: torch.Tensor,
max_val: torch.Tensor,
level_indices: torch.Tensor,
b: int,
k: int) -> None:
super().__init__
def _calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
pass
class APoTObserver(NonUniformQuantizationObserverBase):
alpha = 0
gamma = 0
level_indices = torch.Tensor()
def __init__(
self,
min_val: torch.Tensor,
max_val: torch.Tensor,
level_indices: torch.Tensor,
b: int,
k: int) -> None:
super(APoTObserver, self).__init__(min_val, max_val, level_indices, b, k)
def calculate_qparams(self):
return self._calculate_qparams()
def _calculate_qparams(self):
raise NotImplementedError
def forward(self, x_orig):
pass