mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add nonuniform observer class and tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78680 Approved by: https://github.com/dzdang
This commit is contained in:
@ -0,0 +1,23 @@
|
||||
# Owner(s): ["oncall: quantization"]
|
||||
|
||||
from torch.ao.quantization.experimental.observer import APoTObserver
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
class TestNonUniformObserver(unittest.TestCase):
|
||||
def test_calculate_qparams(self):
|
||||
t = torch.Tensor()
|
||||
obs = APoTObserver(t, t, t, 0, 0)
|
||||
|
||||
with self.assertRaises(NotImplementedError):
|
||||
obs.calculate_qparams()
|
||||
|
||||
def test_override_calculate_qparams(self):
|
||||
t = torch.Tensor()
|
||||
obs = APoTObserver(t, t, t, 0, 0)
|
||||
|
||||
with self.assertRaises(NotImplementedError):
|
||||
obs._calculate_qparams()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
47
torch/ao/quantization/experimental/observer.py
Normal file
47
torch/ao/quantization/experimental/observer.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""
|
||||
This module implements nonuniform observers used to collect statistics about
|
||||
the values observed during calibration (PTQ) or training (QAT).
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.ao.quantization.observer import ObserverBase
|
||||
from typing import Tuple
|
||||
|
||||
class NonUniformQuantizationObserverBase(ObserverBase):
|
||||
quant_min = None
|
||||
quant_max = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
min_val: torch.Tensor,
|
||||
max_val: torch.Tensor,
|
||||
level_indices: torch.Tensor,
|
||||
b: int,
|
||||
k: int) -> None:
|
||||
super().__init__
|
||||
|
||||
def _calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
pass
|
||||
|
||||
class APoTObserver(NonUniformQuantizationObserverBase):
|
||||
alpha = 0
|
||||
gamma = 0
|
||||
level_indices = torch.Tensor()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
min_val: torch.Tensor,
|
||||
max_val: torch.Tensor,
|
||||
level_indices: torch.Tensor,
|
||||
b: int,
|
||||
k: int) -> None:
|
||||
super(APoTObserver, self).__init__(min_val, max_val, level_indices, b, k)
|
||||
|
||||
def calculate_qparams(self):
|
||||
return self._calculate_qparams()
|
||||
|
||||
def _calculate_qparams(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, x_orig):
|
||||
pass
|
Reference in New Issue
Block a user