[quant] Implement forward and backward autograd functions for fake quantize (#81438)

### Summary:
This PR implements a custom autograd function with forward and backward passes for use in APoT fake quantization. The implementation follows the PyTorch tutorial on custom autograd functions: https://pytorch.org/tutorials/beginner/examples_autograd/polynomial_custom_function.html
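For orientation, here is a minimal, self-contained sketch of the pattern the PR follows: a `torch.autograd.Function` whose forward quantizes and dequantizes, and whose backward passes gradients through a clipping mask. The class name and the uniform round-to-nearest quantizer below are illustrative stand-ins, not the PR's APoT quantizer.

```python
import torch
from torch import Tensor

class FakeQuantSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, scale: float, clip: float) -> Tensor:
        # Remember where the input fell inside the clipping range.
        mask = x.abs() <= clip
        ctx.save_for_backward(mask)
        # Quantize-dequantize: the output snaps to a uniform grid.
        return (x.clamp(-clip, clip) / scale).round() * scale

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        # saved_tensors is always a tuple, even for a single tensor.
        mask, = ctx.saved_tensors
        # Straight-through estimator: identity gradient inside the
        # range, zero outside; non-tensor arguments get no gradient.
        return grad_output * mask, None, None

x = torch.randn(4, requires_grad=True)
FakeQuantSketch.apply(x, 0.1, 1.0).sum().backward()
```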

### Test Plan:
Run tests with: `python test/quantization/core/experimental/test_fake_quantize.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81438
Approved by: https://github.com/jerryzh168
Author: asl3
Date: 2022-07-18 16:09:50 -07:00
Committed by: PyTorch MergeBot
Parent: 4aac42cc98
Commit: 368018530e

5 changed files with 62 additions and 8 deletions

File: mypy.ini

@@ -70,6 +70,9 @@ ignore_missing_imports = True
 [mypy-torch.ao.quantization.experimental.APoT_tensor]
 ignore_missing_imports = True
 
+[mypy-torch.ao.quantization.experimental.fake_quantize_function]
+ignore_missing_imports = True
+
 #
 # Files with various errors. Mostly real errors, possibly some false
 # positives as well.

File: test/quantization/core/experimental/test_fake_quantize.py

@@ -5,6 +5,10 @@ import unittest
 from torch.ao.quantization.experimental.observer import APoTObserver
 from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT
 from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize
+from torch.ao.quantization.experimental.fake_quantize_function import fake_quantize_function
+forward_helper = fake_quantize_function.forward
+backward = fake_quantize_function.backward
+from torch.autograd import gradcheck
 
 class TestFakeQuantize(unittest.TestCase):
     r""" Tests fake quantize calculate_qparams() method
@@ -72,5 +76,17 @@
         with self.assertRaises(Exception):
             apot_fake.forward(torch.clone(X), False)
 
+    r""" Tests fake quantize helper backward() method
+    using torch.autograd.gradcheck function.
+    """
+    def test_backward(self):
+        input = torch.randn(20, dtype=torch.double, requires_grad=True)
+
+        observer = APoTObserver(b=4, k=2)
+        observer(input)
+        alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(signed=False)
+
+        test = gradcheck(fake_quantize_function.apply, (input, alpha, gamma, quantization_levels, level_indices), atol=1e-4)
+
 if __name__ == '__main__':
     unittest.main()
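A note on the `gradcheck` call above: it compares the analytical gradients produced by `backward` against finite-difference estimates, which is why the test builds its input in double precision. A tiny standalone example of the same idiom (`torch.sigmoid` is just a stand-in for the function under test):

```python
import torch
from torch.autograd import gradcheck

# Double precision keeps finite-difference error below the tolerance.
x = torch.randn(5, dtype=torch.double, requires_grad=True)
assert gradcheck(torch.sigmoid, (x,), atol=1e-4)
```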

File: torch/ao/quantization/experimental/fake_quantize.py

@@ -1,8 +1,8 @@
 import torch
 from torch import Tensor
 from torch.ao.quantization.experimental.observer import APoTObserver
-from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT
 from torch.ao.quantization.fake_quantize import FakeQuantizeBase
+from torch.ao.quantization.experimental.fake_quantize_function import fake_quantize_function
 
 class APoTFakeQuantize(FakeQuantizeBase):
     alpha: Tensor
@@ -28,7 +28,6 @@ class APoTFakeQuantize(FakeQuantizeBase):
                     and self.quantization_levels is not None
                     and self.level_indices is not None), "Must set qparams for fake quant"
 
-            X = quantize_APoT(X, self.alpha, self.gamma, self.quantization_levels, self.level_indices)
-            X = dequantize_APoT(X)
+            X = fake_quantize_function.apply(X, self.alpha, self.gamma, self.quantization_levels, self.level_indices)
 
         return X
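The two explicit quantize/dequantize calls are replaced by a single `fake_quantize_function.apply` call because `Function.apply` is what records the op in the autograd graph and routes gradients to the custom `backward`; calling the helpers directly would leave the backward pass unused. A toy illustration of the mechanism (not PR code):

```python
import torch

class Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return 2 * x

    @staticmethod
    def backward(ctx, grad_output):
        return 2 * grad_output

x = torch.ones(3, requires_grad=True)
Double.apply(x).sum().backward()   # the custom backward runs
print(x.grad)                      # tensor([2., 2., 2.])
```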

File: torch/ao/quantization/experimental/fake_quantize_function.py

@@ -0,0 +1,27 @@
+import torch
+from torch import Tensor
+from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT
+
+class fake_quantize_function(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx,  # type: ignore[override]
+                x: Tensor,
+                alpha: Tensor,
+                gamma: Tensor,
+                quantization_levels: Tensor,
+                level_indices: Tensor) -> Tensor:
+        quantized_result = quantize_APoT(x, alpha, gamma, quantization_levels, level_indices)
+
+        # calculate mask tensor
+        mask = x.detach().apply_(lambda x: (x <= alpha and x >= -alpha))
+
+        result = dequantize_APoT(quantized_result)
+
+        ctx.save_for_backward(mask)
+
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor) -> Tensor:  # type: ignore[override]
+        mask = ctx.saved_tensors
+        return grad_output * mask
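Two details of this new file are worth noting. First, the mask implements the usual clipped straight-through estimator, passing gradients only where the input lies in [-alpha, alpha]; a vectorized comparison is equivalent to the per-element `apply_` lambda. Second, `ctx.saved_tensors` always returns a tuple, so the usual idiom in `backward` is to unpack it before multiplying. A sketch of both points:

```python
import torch

x = torch.randn(10)
alpha = torch.tensor(1.0)

# Vectorized equivalent of the per-element apply_ lambda above,
# without mutating x in place:
mask = (x >= -alpha) & (x <= alpha)

# Inside backward, unpack the one-element tuple before using it:
# mask, = ctx.saved_tensors
```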

File: torch/ao/quantization/experimental/quantizer.py

@@ -1,5 +1,6 @@
 import torch
 from torch import Tensor
+import numpy as np
 from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
 
 # class to store APoT quantizer and
@@ -33,10 +34,13 @@ class APoTQuantizer():
         result = torch.tensor([])
 
         # map float_to_apot over tensor2quantize elements
-        tensor2quantize = tensor2quantize.apply_(lambda x: float_to_apot(x,
-                                                                         self.quantization_levels,
-                                                                         self.level_indices,
-                                                                         self.alpha))
+        tensor2quantize = tensor2quantize.detach().apply_(lambda x: float_to_apot(x,
+                                                                                  self.quantization_levels,
+                                                                                  self.level_indices,
+                                                                                  self.alpha))
 
         # convert to APoT int representation for dtype
         tensor2quantize = tensor2quantize.int()
 
         from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
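The `detach()` added in this hunk matters because `Tensor.apply_` is an in-place operation that refuses to run on a tensor that requires grad; detaching first yields a tensor outside the autograd graph that `apply_` accepts. A small sketch:

```python
import torch

t = torch.randn(3, requires_grad=True)
# t.apply_(lambda v: v + 1)         # RuntimeError: apply_() on a tensor that requires grad
t.detach().apply_(lambda v: v + 1)  # allowed: the detached tensor is outside autograd
# Note: detach() shares storage with t, so t's values change too.
```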
@@ -56,7 +60,12 @@
         apot_tensor_data = apot_tensor.data
 
         # map apot_to_float over tensor2quantize elements
-        result = apot_tensor_data.apply_(lambda x: float(apot_to_float(x, self.quantization_levels, self.level_indices)))
+        result_temp = np.empty(apot_tensor_data.size())
+        for ele in apot_tensor_data:
+            new_ele = apot_to_float(ele, self.quantization_levels, self.level_indices)
+            np.append(result_temp, new_ele)
+
+        result = torch.from_numpy(result_temp).int()
 
         return result
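For comparison, a sketch of the same element-wise dequantize loop without numpy, assuming `apot_to_float` returns a scalar; note that `np.append` returns a new array rather than growing its argument in place, so collecting values in a list and building a tensor from it is the more common idiom:

```python
import torch
from torch.ao.quantization.experimental.apot_utils import apot_to_float

# quantization_levels, level_indices, and apot_tensor_data come from
# the surrounding dequantize() method.
values = [float(apot_to_float(ele, quantization_levels, level_indices))
          for ele in apot_tensor_data]
result = torch.tensor(values)
```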