From 05543153dd7848debe0b147e44d411f0d60c5669 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 22 May 2019 14:35:12 -0700
Subject: [PATCH] CUDA implementation of fakequant (#20252)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20252

Add CUDA implementation for fakequant op for quantization aware training.

Reviewed By: zafartahirov

Differential Revision: D15243386

fbshipit-source-id: 37610ab046786ffc69aaec5235e5df8304c353d6
---
 .../cuda/fake_quantize_per_tensor_affine.cu | 165 ++++++++++++++++++
 test/test_fake_quant.py                     |  65 ++++++-
 2 files changed, 229 insertions(+), 1 deletion(-)
 create mode 100644 aten/src/ATen/native/quantized/cuda/fake_quantize_per_tensor_affine.cu

diff --git a/aten/src/ATen/native/quantized/cuda/fake_quantize_per_tensor_affine.cu b/aten/src/ATen/native/quantized/cuda/fake_quantize_per_tensor_affine.cu
new file mode 100644
index 000000000000..c50260c972c6
--- /dev/null
+++ b/aten/src/ATen/native/quantized/cuda/fake_quantize_per_tensor_affine.cu
@@ -0,0 +1,165 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+/* FakeQuantize Op for PerTensorAffine quantization scheme */
+namespace at { namespace native {
+namespace {
+/* Fake-quantizes the 'inputs' tensor.
+Args:
+  X: Forward input tensor.
+  scale: scale of per tensor affine quantization
+  zero_point: zero_point of per tensor affine quantization
+  num_bits: Number of quantization bits.
+  quant_delay: Count of global steps for which to delay the quantization.
+               See note below.
+  iter: The current quantization iteration used for `quant_delay`.
+Returns:
+  Fake-quantized tensor (float dtype).
+
+Notes:
+  - quant_delay might be set to non-zero to help weights stabilize at the
+    beginning of training.
+  - quantization range [0, 2^bits - 1]
+*/
+class FakeQuantizePerTensorAffineOp_forward : public c10::OperatorKernel {
+ public:
+  at::Tensor operator()(
+      at::Tensor X,
+      double scale,
+      int64_t zero_point,
+      int64_t num_bits = 8,
+      int64_t quant_delay = 0,
+      int64_t iter = 0) {
+    // Sanity checks.
+    TORCH_CHECK(X.is_cuda());
+    TORCH_CHECK(X.scalar_type() == ScalarType::Float);
+    if (num_bits > 32 || num_bits < 1) {
+      throw std::invalid_argument("`num_bits` should be in the [1, 32] range.");
+    }
+    if (zero_point < 0) {
+      throw std::invalid_argument("`zero_point` must be a non-negative integer.");
+    }
+    if (quant_delay < 0) {
+      throw std::invalid_argument("`quant_delay` must be a non-negative integer.");
+    }
+
+    if (quant_delay != 0 && iter < 0) {
+      throw std::invalid_argument(
+          "`iter` must be >=0 for non-zero `quant_delay`");
+    }
+
+    auto Y = at::empty_like(X);
+
+    if (quant_delay > 0 && iter <= quant_delay) {
+      Y.copy_(X);  // We might want to just return the input here.
+      return Y;
+    }
+
+    float inv_scale = 1.0f / scale;
+    const float quant_min = 0;
+    const float quant_max = (1 << num_bits) - 1;
+    at::cuda::CUDA_tensor_apply2<float, float>(
+        X,
+        Y,
+        [=] __device__ (
+            const float& input_val,
+            float& result_val) {
+          result_val = (fminf(quant_max, fmaxf(quant_min, std::round(input_val * inv_scale + zero_point))) - zero_point) * scale;
+        });
+    return Y;
+  }
+};
+
+/* Backward path to fake-quantize the 'inputs' tensor.
+
+Args:
+  X: Forward input tensor.
+  dY: Backward input tensor.
+  scale: scale of per tensor affine quantization
+  zero_point: zero_point of per tensor affine quantization
+  num_bits: Number of quantization bits.
+  quant_delay: Count of global steps for which to delay the quantization.
+               See note in forward.
+  iter: The current quantization iteration used for `quant_delay`.
+Returns:
+  Gradient with respect to X, `dX` (float dtype).
+
+Notes:
+  - quant_delay might be set to non-zero to help weights stabilize at the
+    beginning of training.
+  - quantization range [0, 2^bits - 1]
+*/
+class FakeQuantizePerTensorAffineOp_backward : public c10::OperatorKernel {
+ public:
+  at::Tensor operator()(
+      at::Tensor X,
+      at::Tensor dY,
+      double scale,
+      int64_t zero_point,
+      int64_t num_bits = 8,
+      int64_t quant_delay = 0,
+      int64_t iter = 0) {
+    // Sanity checks.
+    TORCH_CHECK(X.is_cuda());
+    TORCH_CHECK(X.scalar_type() == ScalarType::Float);
+    if (num_bits > 32 || num_bits < 1) {
+      throw std::invalid_argument("`num_bits` should be in the [1, 32] range.");
+    }
+    if (zero_point < 0) {
+      throw std::invalid_argument("`zero_point` must be a non-negative integer.");
+    }
+    if (quant_delay < 0) {
+      throw std::invalid_argument("`quant_delay` must be a non-negative integer.");
+    }
+    if (X.numel() <= 0) {
+      return X;
+    }
+    if (X.numel() != dY.numel()) {
+      throw std::invalid_argument("`X` and `dY` are not the same size");
+    }
+
+    if (quant_delay != 0 && iter < 0) {
+      throw std::invalid_argument(
+          "`iter` must be >=0 for non-zero `quant_delay`");
+    }
+
+    auto dX = at::zeros_like(dY);
+    if (quant_delay > 0 && iter <= quant_delay) {
+      dX.copy_(dY);
+      return dX;
+    }
+
+    float inv_scale = 1.0f / scale;
+    const float quant_min = 0;
+    const float quant_max = (1 << num_bits) - 1;
+    auto mask = at::empty_like(dY);
+    at::cuda::CUDA_tensor_apply2<float, float>(
+        X,
+        mask,
+        [=] __device__ (
+            const float& input_val,
+            float& result_val) {
+          float Xq = std::round(input_val * inv_scale + zero_point);
+          result_val = float(Xq >= quant_min && Xq <= quant_max);
+        });
+    dX = mask * dY;
+    return dX;
+  }
+};
+
+static auto registry =
+    c10::RegisterOperators()
+        .op("quantized::fake_quantize_per_tensor_affine_forward(Tensor X, float scale, int zero_point, int num_bits = 8, int quant_delay = 0, int iter = 0) -> Tensor",
+            c10::RegisterOperators::options()
+                .kernel<FakeQuantizePerTensorAffineOp_forward>()
+                .dispatchKey(CUDATensorId()))
+        .op("quantized::fake_quantize_per_tensor_affine_backward(Tensor X, Tensor dY, float scale, int zero_point, int num_bits=8, int quant_delay=0, int iter = 0) -> Tensor",
+            c10::RegisterOperators::options()
+                .kernel<FakeQuantizePerTensorAffineOp_backward>()
+                .dispatchKey(CUDATensorId()));
+
+} // namespace
+}} // namespace at::native
diff --git a/test/test_fake_quant.py b/test/test_fake_quant.py
index 36c1a835aa55..7c39ee2b9b8c 100644
--- a/test/test_fake_quant.py
+++ b/test/test_fake_quant.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
 import torch
+import torch.cuda
 import torch.jit
 import numpy as np
 import unittest
@@ -66,6 +67,9 @@ class TestFakeQuantizePerTensorAffine(unittest.TestCase):
         np.testing.assert_allclose(dX, dX_prime, rtol=tolerance, atol=tolerance)
 
     def test_numerical_consistency(self):
+        '''
+        Compares numerical consistency between the CPU quantize/dequantize ops and the CPU fake quantize op.
+        '''
         np.random.seed(NP_RANDOM_SEED)
         fake_quantize_per_tensor_affine_forward = torch.ops.quantized.fake_quantize_per_tensor_affine_forward
 
@@ -74,13 +78,72 @@ class TestFakeQuantizePerTensorAffine(unittest.TestCase):
         num_bits = 8
         X = np.random.rand(20, 20) * 125
         X_torch = torch.from_numpy(X).float()
-        Y = X_torch.quantize_linear(scale, zero_point, torch.qint8).dequantize()
+        Y = torch.dequantize(torch.quantize_linear(X_torch, scale, zero_point, torch.qint8))
         Y_prime = fake_quantize_per_tensor_affine_forward(
             X=X_torch, scale=scale, zero_point=zero_point, num_bits=num_bits,
             quant_delay=0, iter=0)
         tolerance = 1e-6
         np.testing.assert_allclose(Y, Y_prime, rtol=tolerance, atol=tolerance)
+
+    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
+    def test_forward_cuda(self):
+        """Tests the forward path of the FakeQuantizePerTensorAffine CUDA op."""
+        np.random.seed(NP_RANDOM_SEED)
+        fake_quantize_per_tensor_affine_forward = torch.ops.quantized.fake_quantize_per_tensor_affine_forward
+
+        scale = 3
+        zero_point = 2
+        num_bits = 8
+        X = np.random.rand(20, 20) * 125
+        X_torch = torch.from_numpy(X).float().cuda()
+        Y = _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, num_bits)
+        Y_prime = fake_quantize_per_tensor_affine_forward(
+            X=X_torch, scale=scale, zero_point=zero_point, num_bits=num_bits,
+            quant_delay=0, iter=0)
+        tolerance = 1e-6
+        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
+
+    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
+    def test_backward_cuda(self):
+        """Tests the backward path of the FakeQuantizePerTensorAffine CUDA op. Note that
+        this relies on the reference quantization, so any errors may originate there."""
+        np.random.seed(NP_RANDOM_SEED)
+        fake_quantize_per_tensor_affine_backward = torch.ops.quantized.fake_quantize_per_tensor_affine_backward
+
+        scale = 3
+        zero_point = 2
+        num_bits = 8
+        X = np.random.rand(20, 20) * 125
+        Y = _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, num_bits)
+        dY = Y - X  # Fake gradient
+        dX = _fake_quantize_per_tensor_affine_grad_reference(X, dY, scale, zero_point, num_bits)
+        X_torch = torch.from_numpy(X).float().cuda()
+        dY_torch = torch.from_numpy(dY).float().cuda()
+        dX_prime = fake_quantize_per_tensor_affine_backward(
+            X=X_torch, dY=dY_torch, scale=scale, zero_point=zero_point,
+            num_bits=num_bits, quant_delay=0, iter=0)
+        tolerance = 1e-6
+        np.testing.assert_allclose(dX, dX_prime.cpu(), rtol=tolerance, atol=tolerance)
+
+    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
+    def test_numerical_consistency_cuda(self):
+        '''
+        Compares numerical consistency between the CPU quantize/dequantize ops and the CUDA fake quantize op.
+        '''
+        np.random.seed(NP_RANDOM_SEED)
+        fake_quantize_per_tensor_affine_forward = torch.ops.quantized.fake_quantize_per_tensor_affine_forward
+
+        scale = 3
+        zero_point = 2
+        num_bits = 8
+        X = np.random.rand(20, 20) * 125
+        X_torch = torch.from_numpy(X).float()
+        Y = torch.dequantize(torch.quantize_linear(X_torch, scale, zero_point, torch.qint8))
+        Y_prime = fake_quantize_per_tensor_affine_forward(
+            X=X_torch.cuda(), scale=scale, zero_point=zero_point, num_bits=num_bits,
+            quant_delay=0, iter=0)
+        tolerance = 1e-6
+        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
 
 
 if __name__ == '__main__':
     run_tests()
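Note on the reference helpers used by the CUDA tests: `_fake_quantize_per_tensor_affine_reference` and `_fake_quantize_per_tensor_affine_grad_reference` are defined elsewhere in test/test_fake_quant.py and are not part of this patch. The math they are expected to mirror is the same as the kernels above: the forward pass computes (clamp(round(X / scale + zero_point), 0, 2^num_bits - 1) - zero_point) * scale, and the backward pass is a straight-through estimator that passes the incoming gradient only where the rounded value stays inside the quantization range. A minimal NumPy sketch under those assumptions (the names, signatures, and bodies below are illustrative, not taken from this patch):

import numpy as np

def _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, num_bits=8):
    # Assumed reference: quantize with round(X / scale + zero_point),
    # clamp to [0, 2^num_bits - 1], then dequantize (mirrors the forward kernel).
    quant_min, quant_max = 0, 2 ** num_bits - 1
    Xq = np.clip(np.round(X / scale + zero_point), quant_min, quant_max)
    return (Xq - zero_point) * scale

def _fake_quantize_per_tensor_affine_grad_reference(X, dY, scale, zero_point, num_bits=8):
    # Assumed reference: straight-through estimator (mirrors the backward kernel);
    # the gradient flows only where the quantized value lands inside the range.
    quant_min, quant_max = 0, 2 ** num_bits - 1
    Xq = np.round(X / scale + zero_point)
    mask = (Xq >= quant_min) & (Xq <= quant_max)
    return dY * mask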