[quant] Modify APoT nonuniform quantization workflow (#80075)
### Summary
This PR updates the design of APoT Observer, Quantizer, and Tensor to be more consistent with their uniform counterparts in the PyTorch framework. APoT Observer now calculates alpha as the maximum of the absolute values of the min and max values in the input tensor. APoT Quantizer is modified so that its instance methods quantize_APoT and dequantize_APoT are called through their global method counterparts. APoT Tensor is modified to account for the new definition of `quantize_APoT` in APoT Quantizer.

### Test Plan
Run the APoT Observer unit tests with:
`python pytorch/test/quantization/core/experimental/test_nonuniform_observer.py`

Run the APoT Quantizer unit tests with:
`python pytorch/test/quantization/core/experimental/test_quantizer.py`

Run the APoT Tensor unit tests with:
`python pytorch/test/quantization/core/experimental/test_quantized_tensor.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80075
Approved by: https://github.com/jerryzh168
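For reference, the end-to-end flow exercised by the updated tests looks roughly like the minimal sketch below. The call signatures (`APoTObserver(b, k)`, `calculate_qparams` returning `(alpha, gamma, quantization_levels, level_indices)`, and the global `quantize_APoT`/`dequantize_APoT` wrappers) come from this PR; the tensor values and variable names are illustrative only.

```python
import torch
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT

# Observe a float tensor; forward() records the running min/max (values are arbitrary).
x = torch.tensor([0.0, -100.23, -37.18, 3.42, 8.93, 9.21, 87.92])
obs = APoTObserver(b=4, k=2)
obs.forward(x)

# calculate_qparams now returns (alpha, gamma, quantization_levels, level_indices),
# where alpha = max(|min_val|, |max_val|) over the observed range.
alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False)

# The global wrappers construct an APoTQuantizer from the qparams and delegate to
# its instance methods; clone the input since quantization maps values in place.
qtensor = quantize_APoT(tensor2quantize=x.clone(),
                        alpha=alpha,
                        gamma=gamma,
                        quantization_levels=quantization_levels,
                        level_indices=level_indices)

xdq = dequantize_APoT(apot_tensor=qtensor,
                      alpha=alpha,
                      gamma=gamma,
                      quantization_levels=quantization_levels,
                      level_indices=level_indices)
```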
mypy.ini
@@ -67,6 +67,9 @@ ignore_missing_imports = True
[mypy-torch.ao.quantization.experimental.observer]
ignore_missing_imports = True

[mypy-torch.ao.quantization.experimental.APoT_tensor]
ignore_missing_imports = True

#
# Files with various errors. Mostly real errors, possibly some false
# positives as well.
test/quantization/core/experimental/test_nonuniform_observer.py
@@ -2,20 +2,23 @@

from torch.ao.quantization.experimental.observer import APoTObserver
import unittest
import torch

class TestNonUniformObserver(unittest.TestCase):
"""
Test case 1
Test case 1: calculate_qparams
Test that error is thrown when k == 0
"""
def test_calculate_qparams_invalid(self):
obs = APoTObserver(max_val=0.0, b=0, k=0)
obs = APoTObserver(b=0, k=0)

with self.assertRaises(AssertionError):
obs_result = obs.calculate_qparams(signed=False)
alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False,
min_val=torch.tensor([0]),
max_val=torch.tensor([0]))

"""
Test case 2
Test case 2: calculate_qparams
APoT paper example: https://arxiv.org/pdf/1909.13144.pdf
Assume hardcoded parameters:
* b = 4 (total number of bits across all terms)
@@ -24,8 +27,18 @@ class TestNonUniformObserver(unittest.TestCase):
* note: b = k * n
"""
def test_calculate_qparams_2terms(self):
obs = APoTObserver(max_val=1.0, b=4, k=2)
obs_result = obs.calculate_qparams(signed=False)
obs = APoTObserver(b=4, k=2)

min_val = torch.tensor([0])
max_val = torch.tensor([1])
alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False,
min_val=min_val,
max_val=max_val)

alpha_test = torch.max(-min_val, max_val)

# check alpha value
self.assertEqual(alpha, alpha_test)

# calculate expected gamma value
gamma_test = 0
@@ -35,32 +48,41 @@ class TestNonUniformObserver(unittest.TestCase):
gamma_test = 1 / gamma_test

# check gamma value
self.assertEqual(obs_result[0], gamma_test)
self.assertEqual(gamma, gamma_test)

# check quantization levels size
quantlevels_size_test = int(len(obs_result[1]))
quantlevels_size_test = int(len(quantization_levels))
quantlevels_size = 2**4
self.assertEqual(quantlevels_size_test, quantlevels_size)

# check level indices size
levelindices_size_test = int(len(obs_result[2]))
levelindices_size_test = int(len(level_indices))
self.assertEqual(levelindices_size_test, 16)

# check level indices unique values
level_indices_test_list = obs_result[2].tolist()
level_indices_test_list = level_indices.tolist()
self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list)))

"""
Test case 3
Test case 3: calculate_qparams
Assume hardcoded parameters:
* b = 6 (total number of bits across all terms)
* k = 2 (base bitwidth, i.e. bitwidth of every term)
* n = 3 (number of additive terms)
"""
def test_calculate_qparams_3terms(self):
obs = APoTObserver(max_val=1.0, b=6, k=2)
obs = APoTObserver(b=6, k=2)

obs_result = obs.calculate_qparams(signed=False)
min_val = torch.tensor([0])
max_val = torch.tensor([1])
alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False,
min_val=min_val,
max_val=max_val)

alpha_test = torch.max(-min_val, max_val)

# check alpha value
self.assertEqual(alpha, alpha_test)

# calculate expected gamma value
gamma_test = 0
@@ -70,23 +92,23 @@ class TestNonUniformObserver(unittest.TestCase):
gamma_test = 1 / gamma_test

# check gamma value
self.assertEqual(obs_result[0], gamma_test)
self.assertEqual(gamma, gamma_test)

# check quantization levels size
quantlevels_size_test = int(len(obs_result[1]))
quantlevels_size_test = int(len(quantization_levels))
quantlevels_size = 2**6
self.assertEqual(quantlevels_size_test, quantlevels_size)

# check level indices size
levelindices_size_test = int(len(obs_result[2]))
levelindices_size_test = int(len(level_indices))
self.assertEqual(levelindices_size_test, 64)

# check level indices unique values
level_indices_test_list = obs_result[2].tolist()
level_indices_test_list = level_indices.tolist()
self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list)))

"""
Test case 4
Test case 4: calculate_qparams
Same as test case 2 but with signed = True
Assume hardcoded parameters:
* b = 4 (total number of bits across all terms)
@@ -95,8 +117,17 @@ class TestNonUniformObserver(unittest.TestCase):
* signed = True
"""
def test_calculate_qparams_signed(self):
obs = APoTObserver(max_val=1.0, b=4, k=2)
obs_result = obs.calculate_qparams(signed=True)
obs = APoTObserver(b=4, k=2)

min_val = torch.tensor([0])
max_val = torch.tensor([1])
alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=True,
min_val=min_val,
max_val=max_val)
alpha_test = torch.max(-min_val, max_val)

# check alpha value
self.assertEqual(alpha, alpha_test)

# calculate expected gamma value
gamma_test = 0
@@ -106,15 +137,15 @@ class TestNonUniformObserver(unittest.TestCase):
gamma_test = 1 / gamma_test

# check gamma value
self.assertEqual(obs_result[0], gamma_test)
self.assertEqual(gamma, gamma_test)

# check quantization levels size
quantlevels_size_test = int(len(obs_result[1]))
quantlevels_size_test = int(len(quantization_levels))
self.assertEqual(quantlevels_size_test, 49)

# check negatives of each element contained
# in quantization levels
quantlevels_test_list = obs_result[1].tolist()
quantlevels_test_list = quantization_levels.tolist()
negatives_contained = True
for ele in quantlevels_test_list:
if not (-ele) in quantlevels_test_list:
@@ -122,15 +153,15 @@ class TestNonUniformObserver(unittest.TestCase):
self.assertTrue(negatives_contained)

# check level indices size
levelindices_size_test = int(len(obs_result[2]))
levelindices_size_test = int(len(level_indices))
self.assertEqual(levelindices_size_test, 49)

# check level indices unique elements
level_indices_test_list = obs_result[2].tolist()
level_indices_test_list = level_indices.tolist()
self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list)))

"""
Test case 5
Test case 5: calculate_qparams
Assume hardcoded parameters:
* b = 6 (total number of bits across all terms)
* k = 1 (base bitwidth, i.e. bitwidth of every term)
@@ -165,5 +196,26 @@ class TestNonUniformObserver(unittest.TestCase):
level_indices_test_list = obs_result[2].tolist()
self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list)))

"""
Test forward method on hard-coded tensor with arbitrary values.
Checks that alpha is max of abs value of max and min values in tensor.
"""
def test_forward(self):
obs = APoTObserver(b=4, k=2)

X = torch.tensor([0.0, -100.23, -37.18, 3.42, 8.93, 9.21, 87.92])

X = obs.forward(X)

alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(True)

min_val = torch.min(X)
max_val = torch.max(X)

expected_alpha = torch.max(-min_val, max_val)

self.assertEqual(alpha, expected_alpha)


if __name__ == '__main__':
unittest.main()
test/quantization/core/experimental/test_quantized_tensor.py
@@ -1,33 +1,41 @@
# Owner(s): ["oncall: quantization"]

import torch
import random
import unittest
from torch.ao.quantization.experimental.quantizer import APoTQuantizer
from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.quantizer import quantize_APoT

class TestQuantizedTensor(unittest.TestCase):
r""" Tests int_repr on APoTQuantizer with random tensor2quantize
and hard-coded values b=4, k=2
and hard-coded values
"""
def test_int_repr(self):
# generate random size of tensor2dequantize between 1 -> 20
size = random.randint(1, 20)
# generate tensor with random fp values
tensor2quantize = tensor2quantize = torch.tensor([0, 0.0215, 0.1692, 0.385, 1, 0.0391])

# generate tensor with random fp values between 0 -> 1000
tensor2quantize = 1000 * torch.rand(size, dtype=torch.float)
orig_tensor2quantize = torch.clone(tensor2quantize)
observer = APoTObserver(b=4, k=2)

quantizer = APoTQuantizer(4, 2, torch.max(tensor2quantize), False)
observer.forward(tensor2quantize)

qparams = observer.calculate_qparams(signed=False)

# get apot quantized tensor result
qtensor = quantizer.quantize_APoT(tensor2quantize=tensor2quantize)
qtensor = quantize_APoT(tensor2quantize=tensor2quantize,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])

tensor_apot = TensorAPoT(quantizer, orig_tensor2quantize)
qtensor_data = qtensor.int_repr().int()

qtensor_int_rep = tensor_apot.int_repr()
# expected qtensor values calculated based on
# corresponding level_indices to nearest quantization level
# for each fp value in tensor2quantize
# e.g.
# 0.0215 in tensor2quantize nearest 0.0208 in quantization_levels -> 3 in level_indices
expected_qtensor_data = torch.tensor([0, 3, 8, 13, 5, 12], dtype=torch.int32)

self.assertTrue(torch.equal(qtensor, qtensor_int_rep))
self.assertTrue(torch.equal(qtensor_data, expected_qtensor_data))

if __name__ == '__main__':
unittest.main()
test/quantization/core/experimental/test_quantizer.py
@@ -2,7 +2,8 @@

import torch
from torch import quantize_per_tensor
from torch.ao.quantization.experimental.quantizer import APoTQuantizer
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.quantizer import APoTQuantizer, quantize_APoT, dequantize_APoT
import unittest
import random

@@ -22,16 +23,22 @@ class TestQuantizer(unittest.TestCase):
# generate tensor with random fp values between 0 -> 1000
tensor2quantize = 1000 * torch.rand(size, dtype=torch.float)

quantizer = APoTQuantizer(4, 1, torch.max(tensor2quantize), False)
observer = APoTObserver(b=8, k=1)
observer.forward(tensor2quantize)
qparams = observer.calculate_qparams(signed=False, min_val=torch.tensor(0), max_val=torch.tensor(255))

# get apot quantized tensor result
qtensor = quantizer.quantize_APoT(tensor2quantize=tensor2quantize)
qtensor = quantize_APoT(tensor2quantize=tensor2quantize,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])

# get uniform quantization quantized tensor result
uniform_quantized = quantize_per_tensor(input=tensor2quantize, scale=1.0, zero_point=0, dtype=torch.quint8).int_repr()

qtensor_data = torch.tensor(qtensor).type(torch.uint8)
uniform_quantized_tensor = uniform_quantized.data
qtensor_data = qtensor.data.int()
uniform_quantized_tensor = uniform_quantized.data.int()

self.assertTrue(torch.equal(qtensor_data, uniform_quantized_tensor))
@@ -54,21 +61,28 @@ class TestQuantizer(unittest.TestCase):
level_indices = tensor([ 0, 3, 12, 15, 2, 14, 8, 11, 10, 1, 13, 9, 4, 7, 6, 5]))
"""

# generate tensor with random fp values between 0 -> 1000
tensor2quantize = torch.tensor([0.0215, 0.1692, 0.385, 0.0391])
# generate tensor with random fp values
tensor2quantize = torch.tensor([0, 0.0215, 0.1692, 0.385, 1, 0.0391])

quantizer = APoTQuantizer(4, 2, 1.0, False)
observer = APoTObserver(b=4, k=2)
observer.forward(tensor2quantize)
qparams = observer.calculate_qparams(signed=False)

# get apot quantized tensor result
qtensor = quantizer.quantize_APoT(tensor2quantize=tensor2quantize)
qtensor_data = torch.tensor(qtensor).type(torch.uint8)
qtensor = quantize_APoT(tensor2quantize=tensor2quantize,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])

qtensor_data = qtensor.data.int()

# expected qtensor values calculated based on
# corresponding level_indices to nearest quantization level
# for each fp value in tensor2quantize
# e.g.
# 0.0215 in tensor2quantize nearest 0.0208 in quantization_levels -> 3 in level_indices
expected_qtensor = torch.tensor([3, 8, 13, 12], dtype=torch.uint8)
expected_qtensor = torch.tensor([0, 3, 8, 13, 5, 12], dtype=torch.int32)

self.assertTrue(torch.equal(qtensor_data, expected_qtensor))
@@ -81,52 +95,94 @@ class TestQuantizer(unittest.TestCase):
* k: 2
"""
def test_dequantize_quantize_rand_b4(self):
# generate random size of float2apot between 1->20
# make observer
observer = APoTObserver(4, 2)

# generate random size of tensor2quantize between 1 -> 20
size = random.randint(1, 20)

# initialize quantize APoT tensor to dequantize:
# generate tensor with random values between 0 -> 2**4 = 16
# because there are 2**b = 2**4 quantization levels total
float2apot = 16 * torch.rand(size)
quantizer = APoTQuantizer(4, 2, 1.0, False)
float2apot = float2apot.int()
orig_input = torch.clone(float2apot)
# make tensor2quantize: random fp values between 0 -> 1000
tensor2quantize = 1000 * torch.rand(size, dtype=torch.float)

dequantized_result = quantizer.dequantize(float2apot)
observer.forward(tensor2quantize)

quantized_result = quantizer.quantize_APoT(tensor2quantize=dequantized_result)
qparams = observer.calculate_qparams(signed=False)

quantized_result = quantized_result.int()
# make mock apot_tensor
original_apot = quantize_APoT(tensor2quantize=tensor2quantize,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])

self.assertTrue(torch.equal(quantized_result, orig_input))
original_input = torch.clone(original_apot.data).int()

# dequantize apot_tensor
dequantize_result = dequantize_APoT(apot_tensor=original_apot,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])

# quantize apot_tensor
final_apot = quantize_APoT(tensor2quantize=dequantize_result,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])

result = final_apot.data.int()

self.assertTrue(torch.equal(original_input, result))

r""" Tests dequantize_apot result on random 1-dim tensor
|
||||
and hardcoded values for b, k.
|
||||
Dequant -> quant an input tensor and verify that
|
||||
result is equivalent to input
|
||||
* tensor2quantize: Tensor
|
||||
* b: 6
|
||||
* k: 2
|
||||
* b: 12
|
||||
* k: 4
|
||||
"""
|
||||
def test_dequantize_quantize_rand_b6(self):
|
||||
# generate random size of float2apot
|
||||
# make observer
|
||||
observer = APoTObserver(12, 4)
|
||||
|
||||
# generate random size of tensor2quantize between 1 -> 20
|
||||
size = random.randint(1, 20)
|
||||
|
||||
# initialize quantize APoT tensor to dequantize:
|
||||
# generate tensor with random values between 0 -> 2**6 = 64
|
||||
# because there are 2**b = 2**6 quantization levels total
|
||||
float2apot = 64 * torch.rand(size)
|
||||
quantizer = APoTQuantizer(6, 2, 1.0, False)
|
||||
float2apot = float2apot.int()
|
||||
orig_input = torch.clone(float2apot)
|
||||
# make tensor2quantize: random fp values between 0 -> 1000
|
||||
tensor2quantize = 1000 * torch.rand(size, dtype=torch.float)
|
||||
|
||||
dequantized_result = quantizer.dequantize(float2apot)
|
||||
observer.forward(tensor2quantize)
|
||||
|
||||
quantized_result = quantizer.quantize_APoT(tensor2quantize=dequantized_result)
|
||||
qparams = observer.calculate_qparams(signed=False)
|
||||
|
||||
quantized_result = quantized_result.int()
|
||||
# make mock apot_tensor
|
||||
original_apot = quantize_APoT(tensor2quantize=tensor2quantize,
|
||||
alpha=qparams[0],
|
||||
gamma=qparams[1],
|
||||
quantization_levels=qparams[2],
|
||||
level_indices=qparams[3])
|
||||
|
||||
self.assertTrue(torch.equal(quantized_result, orig_input))
|
||||
original_input = torch.clone(original_apot.data).int()
|
||||
|
||||
# dequantize apot_tensor
|
||||
dequantize_result = dequantize_APoT(apot_tensor=original_apot,
|
||||
alpha=qparams[0],
|
||||
gamma=qparams[1],
|
||||
quantization_levels=qparams[2],
|
||||
level_indices=qparams[3])
|
||||
|
||||
# quantize apot_tensor
|
||||
final_apot = quantize_APoT(tensor2quantize=dequantize_result,
|
||||
alpha=qparams[0],
|
||||
gamma=qparams[1],
|
||||
quantization_levels=qparams[2],
|
||||
level_indices=qparams[3])
|
||||
|
||||
result = final_apot.data.int()
|
||||
|
||||
self.assertTrue(torch.equal(original_input, result))
|
||||
|
||||
def test_q_apot_alpha(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
|
torch/ao/quantization/experimental/APoT_tensor.py
@@ -6,9 +6,9 @@ class TensorAPoT():
quantizer: APoTQuantizer
data: torch.Tensor

def __init__(self, quantizer: APoTQuantizer, tensor2quantize: torch.Tensor):
def __init__(self, quantizer: APoTQuantizer, apot_data: torch.Tensor):
self.quantizer = quantizer
self.data = quantizer.quantize_APoT(tensor2quantize)
self.data = apot_data

def int_repr(self):
return self.data
torch/ao/quantization/experimental/observer.py
@@ -13,40 +13,48 @@ from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to
# when more than one non-uniform method is implemented

class APoTObserver(ObserverBase):
max_val: float
b: int
k: int
n: int
alpha: float
gamma: float
level_indices: torch.Tensor
min_val: torch.Tensor
max_val: torch.Tensor

def __init__(
self,
max_val,
b,
k,
dtype=torch.quint8) -> None:
dtype=torch.int32) -> None:
super().__init__(dtype)
self.max_val = max_val
self.b = b
self.k = k

def calculate_qparams(self, signed):
return self._calculate_qparams(signed)
self.min_val = torch.tensor([])
self.max_val = torch.tensor([])

# min_val and max_val are optional args to override
# the min_val and max_val observed by forward
def calculate_qparams(self, signed, min_val=None, max_val=None):
return self._calculate_qparams(signed, min_val, max_val)

r""" Calculates nonuniform quantization parameters according to APoT paper:
https://arxiv.org/pdf/1909.13144.pdf.
Arg:
signed: specifies whether to include signed values in quantization level calculations
min_val: optional arg that can override min_val internal attribute
max_val: optional arg that can override max_val internal attribute
Returns:
gamma: gamma quantization parameter, defined to ensure that alpha is the maximum of the range
quantization_levels: non-uniform quantization levels (fp representation)
level_indices: int representation of quantization_levels indices
"""
def _calculate_qparams(self, signed):
def _calculate_qparams(self, signed, min_val=None, max_val=None):
if min_val is not None:
self.min_val = min_val
if max_val is not None:
self.max_val = max_val

# compute alpha
self.alpha = self.max_val
alpha = torch.max(-self.min_val, self.max_val)

# check for valid inputs of b, k
assert(self.k and self.k != 0)
@@ -90,7 +98,7 @@ class APoTObserver(ObserverBase):
p_sum += float(tens[1])

# assign gamma
self.gamma = self.alpha / p_sum
gamma = alpha / p_sum

# calculate cartesian product
cartesian_product = list(itertools.product(*p_all))
@@ -104,16 +112,25 @@ class APoTObserver(ObserverBase):
sum += ele
quantization_levels_list.append(sum)

quantization_levels_gamma = [self.gamma * ele for ele in quantization_levels_list]
quantization_levels_gamma = [float(gamma) * ele for ele in quantization_levels_list]
quantization_levels = torch.tensor(quantization_levels_gamma)
level_indices = torch.tensor([])
quantization_levels, self.level_indices = quantization_levels.sort()
quantization_levels, level_indices = quantization_levels.sort()

return (self.gamma, quantization_levels, self.level_indices)
return (alpha, gamma, quantization_levels, level_indices)

r"""Records the running minimum and maximum of ``x``."""
def forward(self, x_orig):
r"""Records the running maximum of ``x``."""
max_val = self.max_val
if x_orig.numel() == 0:
return x_orig
x = x_orig.detach()
min_val, max_val = torch.aminmax(x)
if self.min_val.numel():
min_val = torch.min(min_val, self.min_val)
if self.max_val.numel():
max_val = torch.max(max_val, self.max_val)
self.min_val = min_val
self.max_val = max_val
return x_orig

def quant_levels_visualization(self, obs_result, filename):
torch/ao/quantization/experimental/quantizer.py
@@ -1,74 +1,94 @@
import torch
from torch import Tensor
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float

# class to store APoT quantizer
# implements quantize and dequantize
# and stores all quantization parameters
# class to store APoT quantizer and
# implement quantize and dequantize
class APoTQuantizer():
b: int
k: int
n: int
signed: bool
alpha: torch.Tensor
gamma: torch.Tensor
quantization_levels: torch.Tensor
level_indices: torch.Tensor

def __init__(
self,
b,
k,
max_val,
signed,
dtype=torch.quint8) -> None:
self.signed = signed
alpha: torch.Tensor,
gamma: torch.Tensor,
quantization_levels: torch.Tensor,
level_indices: torch.Tensor) -> None:
self.alpha = alpha
self.gamma = gamma
self.quantization_levels = quantization_levels
self.level_indices = level_indices

# check for valid inputs of b, k
assert(k and k != 0)
assert(b % k == 0)
self.b = b
self.k = k
self.n = b // k

# make observer, get quantizion levels and level indices
obs = APoTObserver(max_val=max_val, b=b, k=k)
obs_result = obs.calculate_qparams(signed=signed)
self.quantization_levels = obs_result[1]
self.level_indices = obs_result[2]

r""" Quantizes fp Tensor to integer APoT representatio.
Conversion is based on the calculated quantization levels from a specified APoT non-uniform observer.
r""" Quantizes fp Tensor to integer APoT representation.
Conversion is based on the qparams from a specified APoT non-uniform observer.
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
Args:
tensor2quantize: fp Tensor
Returns:
result: integer APoT representation of tensor2quantize
result: APoT Tensor representation of tensor2quantize
"""
def quantize_APoT(self, tensor2quantize: Tensor):
def quantize(self, tensor2quantize: Tensor):
result = torch.tensor([])

# clip tensor2quantize values based on alpha qparam
tensor2quantize = torch.clamp(tensor2quantize, -self.alpha, self.alpha)

# map float_to_apot over tensor2quantize elements
result = tensor2quantize.apply_(lambda x: float_to_apot(x, self.quantization_levels, self.level_indices))
tensor2quantize = tensor2quantize.apply_(lambda x: float_to_apot(x, self.quantization_levels, self.level_indices))

from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT

result = TensorAPoT(self, tensor2quantize)

return result
r""" Dequantizes integer Tensor to floating point representation
|
||||
r""" Dequantizes integer Tensor to floating point (fp) representation
|
||||
based on the calculated quantization levels from a specified APoT non-uniform observer.
|
||||
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
|
||||
Args:
|
||||
self: APoTQuantizer with attr data to dequantize
|
||||
apot_tensor: quantized APoT Tensor to dequantize
|
||||
Returns:
|
||||
result: floating point representation of input Tensor
|
||||
result: fp representation of input Tensor
|
||||
"""
|
||||
def dequantize(self, float2apot: Tensor): # type: ignore[override]
|
||||
float2apot = float2apot.float()
|
||||
|
||||
quantization_levels = self.quantization_levels
|
||||
level_indices = self.level_indices
|
||||
def dequantize(self, apot_tensor) -> Tensor:
|
||||
apot_tensor_data = apot_tensor.data
|
||||
|
||||
# map apot_to_float over tensor2quantize elements
|
||||
result = float2apot.apply_(lambda x: float(apot_to_float(x, quantization_levels, level_indices)))
|
||||
result = apot_tensor_data.apply_(lambda x: float(apot_to_float(x, self.quantization_levels, self.level_indices)))
|
||||
|
||||
return result
|
||||
|
||||
def q_apot_alpha(self) -> float:
|
||||
raise NotImplementedError
|
||||
|
||||
r""" Global method to create quantizer and call quantizer quantize_APoT
|
||||
Args:
|
||||
tensor2quantize: fp Tensor to quantize
|
||||
alpha: Tensor qparam alpha (clipping level)
|
||||
gamma: Tensor qparam gamma (scale factor for quantization levels)
|
||||
quantization levels: Tensor with fp quantization levels
|
||||
level indices: Tensor with integer quantization level indices
|
||||
Returns:
|
||||
result: ApoT Tensor representation of tensor2quantize
|
||||
"""
|
||||
def quantize_APoT(tensor2quantize: Tensor, alpha: Tensor, gamma: Tensor, quantization_levels: Tensor, level_indices: Tensor):
|
||||
quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices)
|
||||
result = quantizer.quantize(tensor2quantize)
|
||||
return result
|
||||
|
||||
r""" Global method to create quantizer and call quantizer dequantize_APoT
|
||||
Args:
|
||||
apot_tensor: APoT Tensor to dequantize
|
||||
alpha: Tensor qparam alpha (clipping level)
|
||||
gamma: Tensor qparam gamma (scale factor for quantization levels)
|
||||
quantization levels: Tensor with fp quantization levels
|
||||
level indices: Tensor with integer quantization level indices
|
||||
Returns:
|
||||
result: fp Tensor dequantized from apot_tensor
|
||||
"""
|
||||
def dequantize_APoT(apot_tensor, alpha: Tensor, gamma: Tensor, quantization_levels: Tensor, level_indices: Tensor) -> Tensor:
|
||||
quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices)
|
||||
result = apot_tensor.quantizer.dequantize(apot_tensor)
|
||||
return result
|
||||
|