mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[quant] Implement forward and backward autograd functions for fake quantize (#81438)
### Summary: This PR implements custom autograd functions for forward and backward to be used in APoT fake quantization. The implementation follows this doc about custom autograd functions: https://pytorch.org/tutorials/beginner/examples_autograd/polynomial_custom_function.html ### Test Plan: Run tests with: `python test/quantization/core/experimental/test_fake_quantize.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/81438 Approved by: https://github.com/jerryzh168
This commit is contained in:
3
mypy.ini
3
mypy.ini
@ -70,6 +70,9 @@ ignore_missing_imports = True
|
||||
[mypy-torch.ao.quantization.experimental.APoT_tensor]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-torch.ao.quantization.experimental.fake_quantize_function]
|
||||
ignore_missing_imports = True
|
||||
|
||||
#
|
||||
# Files with various errors. Mostly real errors, possibly some false
|
||||
# positives as well.
|
||||
|
@ -5,6 +5,10 @@ import unittest
|
||||
from torch.ao.quantization.experimental.observer import APoTObserver
|
||||
from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT
|
||||
from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize
|
||||
from torch.ao.quantization.experimental.fake_quantize_function import fake_quantize_function
|
||||
forward_helper = fake_quantize_function.forward
|
||||
backward = fake_quantize_function.backward
|
||||
from torch.autograd import gradcheck
|
||||
|
||||
class TestFakeQuantize(unittest.TestCase):
|
||||
r""" Tests fake quantize calculate_qparams() method
|
||||
@ -72,5 +76,17 @@ class TestFakeQuantize(unittest.TestCase):
|
||||
with self.assertRaises(Exception):
|
||||
apot_fake.forward(torch.clone(X), False)
|
||||
|
||||
r""" Tests fake quantize helper backward() method
|
||||
using torch.autograd.gradcheck function.
|
||||
"""
|
||||
def test_backward(self):
|
||||
input = torch.randn(20, dtype=torch.double, requires_grad=True)
|
||||
|
||||
observer = APoTObserver(b=4, k=2)
|
||||
observer(input)
|
||||
alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(signed=False)
|
||||
|
||||
test = gradcheck(fake_quantize_function.apply, (input, alpha, gamma, quantization_levels, level_indices), atol=1e-4)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -1,8 +1,8 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.ao.quantization.experimental.observer import APoTObserver
|
||||
from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT
|
||||
from torch.ao.quantization.fake_quantize import FakeQuantizeBase
|
||||
from torch.ao.quantization.experimental.fake_quantize_function import fake_quantize_function
|
||||
|
||||
class APoTFakeQuantize(FakeQuantizeBase):
|
||||
alpha: Tensor
|
||||
@ -28,7 +28,6 @@ class APoTFakeQuantize(FakeQuantizeBase):
|
||||
and self.quantization_levels is not None
|
||||
and self.level_indices is not None), "Must set qparams for fake quant"
|
||||
|
||||
X = quantize_APoT(X, self.alpha, self.gamma, self.quantization_levels, self.level_indices)
|
||||
X = dequantize_APoT(X)
|
||||
X = fake_quantize_function.apply(X, self.alpha, self.gamma, self.quantization_levels, self.level_indices)
|
||||
|
||||
return X
|
||||
|
27
torch/ao/quantization/experimental/fake_quantize_function.py
Normal file
27
torch/ao/quantization/experimental/fake_quantize_function.py
Normal file
@ -0,0 +1,27 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT
|
||||
|
||||
class fake_quantize_function(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, # type: ignore[override]
|
||||
x: Tensor,
|
||||
alpha: Tensor,
|
||||
gamma: Tensor,
|
||||
quantization_levels: Tensor,
|
||||
level_indices: Tensor) -> Tensor:
|
||||
quantized_result = quantize_APoT(x, alpha, gamma, quantization_levels, level_indices)
|
||||
|
||||
# calculate mask tensor
|
||||
mask = x.detach().apply_(lambda x: (x <= alpha and x >= -alpha))
|
||||
|
||||
result = dequantize_APoT(quantized_result)
|
||||
|
||||
ctx.save_for_backward(mask)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output: Tensor) -> Tensor: # type: ignore[override]
|
||||
mask = ctx.saved_tensors
|
||||
return grad_output * mask
|
@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import numpy as np
|
||||
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
|
||||
|
||||
# class to store APoT quantizer and
|
||||
@ -33,10 +34,13 @@ class APoTQuantizer():
|
||||
result = torch.tensor([])
|
||||
|
||||
# map float_to_apot over tensor2quantize elements
|
||||
tensor2quantize = tensor2quantize.apply_(lambda x: float_to_apot(x,
|
||||
self.quantization_levels,
|
||||
self.level_indices,
|
||||
self.alpha))
|
||||
tensor2quantize = tensor2quantize.detach().apply_(lambda x: float_to_apot(x,
|
||||
self.quantization_levels,
|
||||
self.level_indices,
|
||||
self.alpha))
|
||||
|
||||
# convert to APoT int representation for dtype
|
||||
tensor2quantize = tensor2quantize.int()
|
||||
|
||||
from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
|
||||
|
||||
@ -56,7 +60,12 @@ class APoTQuantizer():
|
||||
apot_tensor_data = apot_tensor.data
|
||||
|
||||
# map apot_to_float over tensor2quantize elements
|
||||
result = apot_tensor_data.apply_(lambda x: float(apot_to_float(x, self.quantization_levels, self.level_indices)))
|
||||
result_temp = np.empty(apot_tensor_data.size())
|
||||
for ele in apot_tensor_data:
|
||||
new_ele = apot_to_float(ele, self.quantization_levels, self.level_indices)
|
||||
np.append(result_temp, new_ele)
|
||||
|
||||
result = torch.from_numpy(result_temp).int()
|
||||
|
||||
return result
|
||||
|
||||
|
Reference in New Issue
Block a user