[quant] Modify APoT nonuniform quantization workflow (#80075)

### Summary:
This PR updates the design of the APoT Observer, Quantizer, and Tensor classes to be more consistent with their uniform counterparts in the PyTorch framework. APoT Observer now calculates alpha as the maximum of the absolute values of the minimum and maximum values observed in the input tensor. APoT Quantizer is modified so that quantization and dequantization are performed through the module-level `quantize_APoT` and `dequantize_APoT` functions, which construct a quantizer and delegate to its instance methods. APoT Tensor is modified to account for the new definition of `quantize_APoT`: it now stores already-quantized data instead of quantizing in its constructor.
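
For context, a minimal sketch of the updated end-to-end workflow, as exercised by the new unit tests below; the `b`/`k` settings and the input values are illustrative only:

```python
import torch
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT

# Observe a float tensor to derive the APoT qparams
# (b = total number of bits, k = base bitwidth of each term).
obs = APoTObserver(b=4, k=2)
x = torch.tensor([0.0, -100.23, -37.18, 3.42, 8.93, 9.21, 87.92])
obs.forward(x)
alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=True)

# Quantize via the module-level helper; it constructs an APoTQuantizer
# internally and returns a TensorAPoT wrapping the quantized data.
qtensor = quantize_APoT(tensor2quantize=x,
                        alpha=alpha,
                        gamma=gamma,
                        quantization_levels=quantization_levels,
                        level_indices=level_indices)
print(qtensor.int_repr().int())  # integer quantization level indices

# Dequantize back to a float tensor with the matching module-level helper.
x_dq = dequantize_APoT(apot_tensor=qtensor,
                       alpha=alpha,
                       gamma=gamma,
                       quantization_levels=quantization_levels,
                       level_indices=level_indices)
```

Note that the observer no longer takes `max_val` in its constructor; the range is recorded by `forward`, or can be passed explicitly to `calculate_qparams`.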

### Test Plan:
Run APoT Observer class unit tests with: `python pytorch/test/quantization/core/experimental/test_nonuniform_observer.py`
Run APoT Quantize class unit tests with: `python pytorch/test/quantization/core/experimental/test_quantizer.py`
Run APoT Tensor class unit tests with: `python pytorch/test/quantization/core/experimental/test_quantized_tensor.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80075
Approved by: https://github.com/jerryzh168
Author: asl3
Date: 2022-06-27 03:51:00 -07:00
Committed by: PyTorch MergeBot
Parent: 71d9592a72
Commit: 777c12f2df

7 changed files with 293 additions and 137 deletions

View File: mypy.ini

@ -67,6 +67,9 @@ ignore_missing_imports = True
[mypy-torch.ao.quantization.experimental.observer]
ignore_missing_imports = True
[mypy-torch.ao.quantization.experimental.APoT_tensor]
ignore_missing_imports = True
#
# Files with various errors. Mostly real errors, possibly some false
# positives as well.

View File: test/quantization/core/experimental/test_nonuniform_observer.py

@ -2,20 +2,23 @@
from torch.ao.quantization.experimental.observer import APoTObserver
import unittest
import torch
class TestNonUniformObserver(unittest.TestCase):
"""
Test case 1
Test case 1: calculate_qparams
Test that error is thrown when k == 0
"""
def test_calculate_qparams_invalid(self):
obs = APoTObserver(max_val=0.0, b=0, k=0)
obs = APoTObserver(b=0, k=0)
with self.assertRaises(AssertionError):
obs_result = obs.calculate_qparams(signed=False)
alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False,
min_val=torch.tensor([0]),
max_val=torch.tensor([0]))
"""
Test case 2
Test case 2: calculate_qparams
APoT paper example: https://arxiv.org/pdf/1909.13144.pdf
Assume hardcoded parameters:
* b = 4 (total number of bits across all terms)
@ -24,8 +27,18 @@ class TestNonUniformObserver(unittest.TestCase):
* note: b = k * n
"""
def test_calculate_qparams_2terms(self):
obs = APoTObserver(max_val=1.0, b=4, k=2)
obs_result = obs.calculate_qparams(signed=False)
obs = APoTObserver(b=4, k=2)
min_val = torch.tensor([0])
max_val = torch.tensor([1])
alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False,
min_val=min_val,
max_val=max_val)
alpha_test = torch.max(-min_val, max_val)
# check alpha value
self.assertEqual(alpha, alpha_test)
# calculate expected gamma value
gamma_test = 0
@ -35,32 +48,41 @@ class TestNonUniformObserver(unittest.TestCase):
gamma_test = 1 / gamma_test
# check gamma value
self.assertEqual(obs_result[0], gamma_test)
self.assertEqual(gamma, gamma_test)
# check quantization levels size
quantlevels_size_test = int(len(obs_result[1]))
quantlevels_size_test = int(len(quantization_levels))
quantlevels_size = 2**4
self.assertEqual(quantlevels_size_test, quantlevels_size)
# check level indices size
levelindices_size_test = int(len(obs_result[2]))
levelindices_size_test = int(len(level_indices))
self.assertEqual(levelindices_size_test, 16)
# check level indices unique values
level_indices_test_list = obs_result[2].tolist()
level_indices_test_list = level_indices.tolist()
self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list)))
"""
Test case 3
Test case 3: calculate_qparams
Assume hardcoded parameters:
* b = 6 (total number of bits across all terms)
* k = 2 (base bitwidth, i.e. bitwidth of every term)
* n = 3 (number of additive terms)
"""
def test_calculate_qparams_3terms(self):
obs = APoTObserver(max_val=1.0, b=6, k=2)
obs = APoTObserver(b=6, k=2)
obs_result = obs.calculate_qparams(signed=False)
min_val = torch.tensor([0])
max_val = torch.tensor([1])
alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False,
min_val=min_val,
max_val=max_val)
alpha_test = torch.max(-min_val, max_val)
# check alpha value
self.assertEqual(alpha, alpha_test)
# calculate expected gamma value
gamma_test = 0
@ -70,23 +92,23 @@ class TestNonUniformObserver(unittest.TestCase):
gamma_test = 1 / gamma_test
# check gamma value
self.assertEqual(obs_result[0], gamma_test)
self.assertEqual(gamma, gamma_test)
# check quantization levels size
quantlevels_size_test = int(len(obs_result[1]))
quantlevels_size_test = int(len(quantization_levels))
quantlevels_size = 2**6
self.assertEqual(quantlevels_size_test, quantlevels_size)
# check level indices size
levelindices_size_test = int(len(obs_result[2]))
levelindices_size_test = int(len(level_indices))
self.assertEqual(levelindices_size_test, 64)
# check level indices unique values
level_indices_test_list = obs_result[2].tolist()
level_indices_test_list = level_indices.tolist()
self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list)))
"""
Test case 4
Test case 4: calculate_qparams
Same as test case 2 but with signed = True
Assume hardcoded parameters:
* b = 4 (total number of bits across all terms)
@ -95,8 +117,17 @@ class TestNonUniformObserver(unittest.TestCase):
* signed = True
"""
def test_calculate_qparams_signed(self):
obs = APoTObserver(max_val=1.0, b=4, k=2)
obs_result = obs.calculate_qparams(signed=True)
obs = APoTObserver(b=4, k=2)
min_val = torch.tensor([0])
max_val = torch.tensor([1])
alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=True,
min_val=min_val,
max_val=max_val)
alpha_test = torch.max(-min_val, max_val)
# check alpha value
self.assertEqual(alpha, alpha_test)
# calculate expected gamma value
gamma_test = 0
@ -106,15 +137,15 @@ class TestNonUniformObserver(unittest.TestCase):
gamma_test = 1 / gamma_test
# check gamma value
self.assertEqual(obs_result[0], gamma_test)
self.assertEqual(gamma, gamma_test)
# check quantization levels size
quantlevels_size_test = int(len(obs_result[1]))
quantlevels_size_test = int(len(quantization_levels))
self.assertEqual(quantlevels_size_test, 49)
# check negatives of each element contained
# in quantization levels
quantlevels_test_list = obs_result[1].tolist()
quantlevels_test_list = quantization_levels.tolist()
negatives_contained = True
for ele in quantlevels_test_list:
if not (-ele) in quantlevels_test_list:
@ -122,15 +153,15 @@ class TestNonUniformObserver(unittest.TestCase):
self.assertTrue(negatives_contained)
# check level indices size
levelindices_size_test = int(len(obs_result[2]))
levelindices_size_test = int(len(level_indices))
self.assertEqual(levelindices_size_test, 49)
# check level indices unique elements
level_indices_test_list = obs_result[2].tolist()
level_indices_test_list = level_indices.tolist()
self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list)))
"""
Test case 5
Test case 5: calculate_qparams
Assume hardcoded parameters:
* b = 6 (total number of bits across all terms)
* k = 1 (base bitwidth, i.e. bitwidth of every term)
@ -165,5 +196,26 @@ class TestNonUniformObserver(unittest.TestCase):
level_indices_test_list = obs_result[2].tolist()
self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list)))
"""
Test forward method on hard-coded tensor with arbitrary values.
Checks that alpha is the max of the absolute values of the min and max values in the tensor.
"""
def test_forward(self):
obs = APoTObserver(b=4, k=2)
X = torch.tensor([0.0, -100.23, -37.18, 3.42, 8.93, 9.21, 87.92])
X = obs.forward(X)
alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(True)
min_val = torch.min(X)
max_val = torch.max(X)
expected_alpha = torch.max(-min_val, max_val)
self.assertEqual(alpha, expected_alpha)
if __name__ == '__main__':
unittest.main()

View File: test/quantization/core/experimental/test_quantized_tensor.py

@ -1,33 +1,41 @@
# Owner(s): ["oncall: quantization"]
import torch
import random
import unittest
from torch.ao.quantization.experimental.quantizer import APoTQuantizer
from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.quantizer import quantize_APoT
class TestQuantizedTensor(unittest.TestCase):
r""" Tests int_repr on APoTQuantizer with random tensor2quantize
and hard-coded values b=4, k=2
and hard-coded values
"""
def test_int_repr(self):
# generate random size of tensor2dequantize between 1 -> 20
size = random.randint(1, 20)
# generate tensor with random fp values
tensor2quantize = torch.tensor([0, 0.0215, 0.1692, 0.385, 1, 0.0391])
# generate tensor with random fp values between 0 -> 1000
tensor2quantize = 1000 * torch.rand(size, dtype=torch.float)
orig_tensor2quantize = torch.clone(tensor2quantize)
observer = APoTObserver(b=4, k=2)
quantizer = APoTQuantizer(4, 2, torch.max(tensor2quantize), False)
observer.forward(tensor2quantize)
qparams = observer.calculate_qparams(signed=False)
# get apot quantized tensor result
qtensor = quantizer.quantize_APoT(tensor2quantize=tensor2quantize)
qtensor = quantize_APoT(tensor2quantize=tensor2quantize,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])
tensor_apot = TensorAPoT(quantizer, orig_tensor2quantize)
qtensor_data = qtensor.int_repr().int()
qtensor_int_rep = tensor_apot.int_repr()
# expected qtensor values calculated based on
# corresponding level_indices to nearest quantization level
# for each fp value in tensor2quantize
# e.g.
# 0.0215 in tensor2quantize nearest 0.0208 in quantization_levels -> 3 in level_indices
expected_qtensor_data = torch.tensor([0, 3, 8, 13, 5, 12], dtype=torch.int32)
self.assertTrue(torch.equal(qtensor, qtensor_int_rep))
self.assertTrue(torch.equal(qtensor_data, expected_qtensor_data))
if __name__ == '__main__':
unittest.main()

View File: test/quantization/core/experimental/test_quantizer.py

@ -2,7 +2,8 @@
import torch
from torch import quantize_per_tensor
from torch.ao.quantization.experimental.quantizer import APoTQuantizer
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.quantizer import APoTQuantizer, quantize_APoT, dequantize_APoT
import unittest
import random
@ -22,16 +23,22 @@ class TestQuantizer(unittest.TestCase):
# generate tensor with random fp values between 0 -> 1000
tensor2quantize = 1000 * torch.rand(size, dtype=torch.float)
quantizer = APoTQuantizer(4, 1, torch.max(tensor2quantize), False)
observer = APoTObserver(b=8, k=1)
observer.forward(tensor2quantize)
qparams = observer.calculate_qparams(signed=False, min_val=torch.tensor(0), max_val=torch.tensor(255))
# get apot quantized tensor result
qtensor = quantizer.quantize_APoT(tensor2quantize=tensor2quantize)
qtensor = quantize_APoT(tensor2quantize=tensor2quantize,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])
# get uniform quantization quantized tensor result
uniform_quantized = quantize_per_tensor(input=tensor2quantize, scale=1.0, zero_point=0, dtype=torch.quint8).int_repr()
qtensor_data = torch.tensor(qtensor).type(torch.uint8)
uniform_quantized_tensor = uniform_quantized.data
qtensor_data = qtensor.data.int()
uniform_quantized_tensor = uniform_quantized.data.int()
self.assertTrue(torch.equal(qtensor_data, uniform_quantized_tensor))
@ -54,21 +61,28 @@ class TestQuantizer(unittest.TestCase):
level_indices = tensor([ 0, 3, 12, 15, 2, 14, 8, 11, 10, 1, 13, 9, 4, 7, 6, 5]))
"""
# generate tensor with random fp values between 0 -> 1000
tensor2quantize = torch.tensor([0.0215, 0.1692, 0.385, 0.0391])
# generate tensor with random fp values
tensor2quantize = torch.tensor([0, 0.0215, 0.1692, 0.385, 1, 0.0391])
quantizer = APoTQuantizer(4, 2, 1.0, False)
observer = APoTObserver(b=4, k=2)
observer.forward(tensor2quantize)
qparams = observer.calculate_qparams(signed=False)
# get apot quantized tensor result
qtensor = quantizer.quantize_APoT(tensor2quantize=tensor2quantize)
qtensor_data = torch.tensor(qtensor).type(torch.uint8)
qtensor = quantize_APoT(tensor2quantize=tensor2quantize,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])
qtensor_data = qtensor.data.int()
# expected qtensor values calculated based on
# corresponding level_indices to nearest quantization level
# for each fp value in tensor2quantize
# e.g.
# 0.0215 in tensor2quantize nearest 0.0208 in quantization_levels -> 3 in level_indices
expected_qtensor = torch.tensor([3, 8, 13, 12], dtype=torch.uint8)
expected_qtensor = torch.tensor([0, 3, 8, 13, 5, 12], dtype=torch.int32)
self.assertTrue(torch.equal(qtensor_data, expected_qtensor))
@ -81,52 +95,94 @@ class TestQuantizer(unittest.TestCase):
* k: 2
"""
def test_dequantize_quantize_rand_b4(self):
# generate random size of float2apot between 1->20
# make observer
observer = APoTObserver(4, 2)
# generate random size of tensor2quantize between 1 -> 20
size = random.randint(1, 20)
# initialize quantize APoT tensor to dequantize:
# generate tensor with random values between 0 -> 2**4 = 16
# because there are 2**b = 2**4 quantization levels total
float2apot = 16 * torch.rand(size)
quantizer = APoTQuantizer(4, 2, 1.0, False)
float2apot = float2apot.int()
orig_input = torch.clone(float2apot)
# make tensor2quantize: random fp values between 0 -> 1000
tensor2quantize = 1000 * torch.rand(size, dtype=torch.float)
dequantized_result = quantizer.dequantize(float2apot)
observer.forward(tensor2quantize)
quantized_result = quantizer.quantize_APoT(tensor2quantize=dequantized_result)
qparams = observer.calculate_qparams(signed=False)
quantized_result = quantized_result.int()
# make mock apot_tensor
original_apot = quantize_APoT(tensor2quantize=tensor2quantize,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])
self.assertTrue(torch.equal(quantized_result, orig_input))
original_input = torch.clone(original_apot.data).int()
# dequantize apot_tensor
dequantize_result = dequantize_APoT(apot_tensor=original_apot,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])
# quantize apot_tensor
final_apot = quantize_APoT(tensor2quantize=dequantize_result,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])
result = final_apot.data.int()
self.assertTrue(torch.equal(original_input, result))
r""" Tests dequantize_apot result on random 1-dim tensor
and hardcoded values for b, k.
Dequant -> quant an input tensor and verify that
result is equivalent to input
* tensor2quantize: Tensor
* b: 6
* k: 2
* b: 12
* k: 4
"""
def test_dequantize_quantize_rand_b6(self):
# generate random size of float2apot
# make observer
observer = APoTObserver(12, 4)
# generate random size of tensor2quantize between 1 -> 20
size = random.randint(1, 20)
# initialize quantize APoT tensor to dequantize:
# generate tensor with random values between 0 -> 2**6 = 64
# because there are 2**b = 2**6 quantization levels total
float2apot = 64 * torch.rand(size)
quantizer = APoTQuantizer(6, 2, 1.0, False)
float2apot = float2apot.int()
orig_input = torch.clone(float2apot)
# make tensor2quantize: random fp values between 0 -> 1000
tensor2quantize = 1000 * torch.rand(size, dtype=torch.float)
dequantized_result = quantizer.dequantize(float2apot)
observer.forward(tensor2quantize)
quantized_result = quantizer.quantize_APoT(tensor2quantize=dequantized_result)
qparams = observer.calculate_qparams(signed=False)
quantized_result = quantized_result.int()
# make mock apot_tensor
original_apot = quantize_APoT(tensor2quantize=tensor2quantize,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])
self.assertTrue(torch.equal(quantized_result, orig_input))
original_input = torch.clone(original_apot.data).int()
# dequantize apot_tensor
dequantize_result = dequantize_APoT(apot_tensor=original_apot,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])
# quantize apot_tensor
final_apot = quantize_APoT(tensor2quantize=dequantize_result,
alpha=qparams[0],
gamma=qparams[1],
quantization_levels=qparams[2],
level_indices=qparams[3])
result = final_apot.data.int()
self.assertTrue(torch.equal(original_input, result))
def test_q_apot_alpha(self):
with self.assertRaises(NotImplementedError):

View File: torch/ao/quantization/experimental/APoT_tensor.py

@ -6,9 +6,9 @@ class TensorAPoT():
quantizer: APoTQuantizer
data: torch.Tensor
def __init__(self, quantizer: APoTQuantizer, tensor2quantize: torch.Tensor):
def __init__(self, quantizer: APoTQuantizer, apot_data: torch.Tensor):
self.quantizer = quantizer
self.data = quantizer.quantize_APoT(tensor2quantize)
self.data = apot_data
def int_repr(self):
return self.data

View File: torch/ao/quantization/experimental/observer.py

@ -13,40 +13,48 @@ from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to
# when more than one non-uniform method is implemented
class APoTObserver(ObserverBase):
max_val: float
b: int
k: int
n: int
alpha: float
gamma: float
level_indices: torch.Tensor
min_val: torch.Tensor
max_val: torch.Tensor
def __init__(
self,
max_val,
b,
k,
dtype=torch.quint8) -> None:
dtype=torch.int32) -> None:
super().__init__(dtype)
self.max_val = max_val
self.b = b
self.k = k
def calculate_qparams(self, signed):
return self._calculate_qparams(signed)
self.min_val = torch.tensor([])
self.max_val = torch.tensor([])
# min_val and max_val are optional args to override
# the min_val and max_val observed by forward
def calculate_qparams(self, signed, min_val=None, max_val=None):
return self._calculate_qparams(signed, min_val, max_val)
r""" Calculates nonuniform quantization parameters according to APoT paper:
https://arxiv.org/pdf/1909.13144.pdf.
Arg:
signed: specifies whether to include signed values in quantization level calculations
min_val: optional arg that can override min_val internal attribute
max_val: optional arg that can override max_val internal attribute
Returns:
gamma: gamma quantization parameter, defined to ensure that alpha is the maximum of the range
quantization_levels: non-uniform quantization levels (fp representation)
level_indices: int representation of quantization_levels indices
"""
def _calculate_qparams(self, signed):
def _calculate_qparams(self, signed, min_val=None, max_val=None):
if min_val is not None:
self.min_val = min_val
if max_val is not None:
self.max_val = max_val
# compute alpha
self.alpha = self.max_val
alpha = torch.max(-self.min_val, self.max_val)
# check for valid inputs of b, k
assert(self.k and self.k != 0)
@ -90,7 +98,7 @@ class APoTObserver(ObserverBase):
p_sum += float(tens[1])
# assign gamma
self.gamma = self.alpha / p_sum
gamma = alpha / p_sum
# calculate cartesian product
cartesian_product = list(itertools.product(*p_all))
@ -104,16 +112,25 @@ class APoTObserver(ObserverBase):
sum += ele
quantization_levels_list.append(sum)
quantization_levels_gamma = [self.gamma * ele for ele in quantization_levels_list]
quantization_levels_gamma = [float(gamma) * ele for ele in quantization_levels_list]
quantization_levels = torch.tensor(quantization_levels_gamma)
level_indices = torch.tensor([])
quantization_levels, self.level_indices = quantization_levels.sort()
quantization_levels, level_indices = quantization_levels.sort()
return (self.gamma, quantization_levels, self.level_indices)
return (alpha, gamma, quantization_levels, level_indices)
r"""Records the running minimum and maximum of ``x``."""
def forward(self, x_orig):
r"""Records the running maximum of ``x``."""
max_val = self.max_val
if x_orig.numel() == 0:
return x_orig
x = x_orig.detach()
min_val, max_val = torch.aminmax(x)
if self.min_val.numel():
min_val = torch.min(min_val, self.min_val)
if self.max_val.numel():
max_val = torch.max(max_val, self.max_val)
self.min_val = min_val
self.max_val = max_val
return x_orig
def quant_levels_visualization(self, obs_result, filename):

View File: torch/ao/quantization/experimental/quantizer.py

@ -1,74 +1,94 @@
import torch
from torch import Tensor
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
# class to store APoT quantizer
# implements quantize and dequantize
# and stores all quantization parameters
# class to store APoT quantizer and
# implement quantize and dequantize
class APoTQuantizer():
b: int
k: int
n: int
signed: bool
alpha: torch.Tensor
gamma: torch.Tensor
quantization_levels: torch.Tensor
level_indices: torch.Tensor
def __init__(
self,
b,
k,
max_val,
signed,
dtype=torch.quint8) -> None:
self.signed = signed
alpha: torch.Tensor,
gamma: torch.Tensor,
quantization_levels: torch.Tensor,
level_indices: torch.Tensor) -> None:
self.alpha = alpha
self.gamma = gamma
self.quantization_levels = quantization_levels
self.level_indices = level_indices
# check for valid inputs of b, k
assert(k and k != 0)
assert(b % k == 0)
self.b = b
self.k = k
self.n = b // k
# make observer, get quantizion levels and level indices
obs = APoTObserver(max_val=max_val, b=b, k=k)
obs_result = obs.calculate_qparams(signed=signed)
self.quantization_levels = obs_result[1]
self.level_indices = obs_result[2]
r""" Quantizes fp Tensor to integer APoT representatio.
Conversion is based on the calculated quantization levels from a specified APoT non-uniform observer.
r""" Quantizes fp Tensor to integer APoT representation.
Conversion is based on the qparams from a specified APoT non-uniform observer.
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
Args:
tensor2quantize: fp Tensor
Returns:
result: integer APoT representation of tensor2quantize
result: APoT Tensor representation of tensor2quantize
"""
def quantize_APoT(self, tensor2quantize: Tensor):
def quantize(self, tensor2quantize: Tensor):
result = torch.tensor([])
# clip tensor2quantize values based on alpha qparam
tensor2quantize = torch.clamp(tensor2quantize, -self.alpha, self.alpha)
# map float_to_apot over tensor2quantize elements
result = tensor2quantize.apply_(lambda x: float_to_apot(x, self.quantization_levels, self.level_indices))
tensor2quantize = tensor2quantize.apply_(lambda x: float_to_apot(x, self.quantization_levels, self.level_indices))
from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
result = TensorAPoT(self, tensor2quantize)
return result
r""" Dequantizes integer Tensor to floating point representation
r""" Dequantizes integer Tensor to floating point (fp) representation
based on the calculated quantization levels from a specified APoT non-uniform observer.
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
Args:
self: APoTQuantizer with attr data to dequantize
apot_tensor: quantized APoT Tensor to dequantize
Returns:
result: floating point representation of input Tensor
result: fp representation of input Tensor
"""
def dequantize(self, float2apot: Tensor): # type: ignore[override]
float2apot = float2apot.float()
quantization_levels = self.quantization_levels
level_indices = self.level_indices
def dequantize(self, apot_tensor) -> Tensor:
apot_tensor_data = apot_tensor.data
# map apot_to_float over tensor2quantize elements
result = float2apot.apply_(lambda x: float(apot_to_float(x, quantization_levels, level_indices)))
result = apot_tensor_data.apply_(lambda x: float(apot_to_float(x, self.quantization_levels, self.level_indices)))
return result
def q_apot_alpha(self) -> float:
raise NotImplementedError
r""" Global method to create quantizer and call quantizer quantize_APoT
Args:
tensor2quantize: fp Tensor to quantize
alpha: Tensor qparam alpha (clipping level)
gamma: Tensor qparam gamma (scale factor for quantization levels)
quantization levels: Tensor with fp quantization levels
level indices: Tensor with integer quantization level indices
Returns:
result: APoT Tensor representation of tensor2quantize
"""
def quantize_APoT(tensor2quantize: Tensor, alpha: Tensor, gamma: Tensor, quantization_levels: Tensor, level_indices: Tensor):
quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices)
result = quantizer.quantize(tensor2quantize)
return result
r""" Global method to create quantizer and call quantizer dequantize_APoT
Args:
apot_tensor: APoT Tensor to dequantize
alpha: Tensor qparam alpha (clipping level)
gamma: Tensor qparam gamma (scale factor for quantization levels)
quantization levels: Tensor with fp quantization levels
level indices: Tensor with integer quantization level indices
Returns:
result: fp Tensor dequantized from apot_tensor
"""
def dequantize_APoT(apot_tensor, alpha: Tensor, gamma: Tensor, quantization_levels: Tensor, level_indices: Tensor) -> Tensor:
quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices)
result = apot_tensor.quantizer.dequantize(apot_tensor)
return result