CUDA implementation of fakequant (#20252)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20252

Add a CUDA implementation of the fakequant op for quantization-aware training.
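For reference, the per-tensor-affine fake-quant transform implemented by the CUDA kernel below rounds the affine-quantized value, clamps it to the representable range, and immediately dequantizes it back to float. A minimal NumPy sketch of that formula (the `fake_quantize_reference` helper is illustrative only, not part of this PR):

import numpy as np

def fake_quantize_reference(x, scale, zero_point, num_bits=8):
    # Quantization range for the [0, 2^bits - 1] scheme used by this op.
    quant_min, quant_max = 0, 2 ** num_bits - 1
    # Quantize: affine map, round to nearest, clamp to the representable range.
    xq = np.clip(np.round(x / scale + zero_point), quant_min, quant_max)
    # Dequantize back to float so training sees the quantization error.
    return (xq - zero_point) * scale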

Reviewed By: zafartahirov

Differential Revision: D15243386

fbshipit-source-id: 37610ab046786ffc69aaec5235e5df8304c353d6
Jerry Zhang
2019-05-22 14:35:12 -07:00
committed by Facebook Github Bot
parent fdb923996d
commit 05543153dd
2 changed files with 229 additions and 1 deletion


@@ -0,0 +1,165 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/core/op_registration/op_registration.h>
#include <cmath>
/* FakeQuantize Op for PerTensorAffine quantization scheme */
namespace at { namespace native {
namespace {
/* Fake-quantizes the 'inputs' tensor.
Args:
X: Forward input tensor.
scale: scale of per tensor affine quantization
zero_point: zero_point of per tensor affine quantization
num_bits: Number of quantization bits.
quant_delay: Count of global steps for which to delay the quantization.
See note below.
iter: The current quantization iteration used for `quant_delay`.
Returns:
Fake-quantized tensor (float dtype).
Notes:
- quant_delay might be set to non-zero to help weights stabilize in the
beginning of the training.
- quantization range [0, 2^bits - 1]
*/
class FakeQuantizePerTensorAffineOp_forward : public c10::OperatorKernel {
public:
at::Tensor operator()(
at::Tensor X,
double scale,
int64_t zero_point,
int64_t num_bits = 8,
int64_t quant_delay = 0,
int64_t iter = 0
) {
// Sanity checks.
TORCH_CHECK(X.is_cuda());
TORCH_CHECK(X.scalar_type() == ScalarType::Float);
if (num_bits > 32 || num_bits < 1) {
throw std::invalid_argument("`num_bits` should be in the [1, 32] range.");
}
if (zero_point < 0) {
throw std::invalid_argument("`zero_point` must be a non-negative integer.");
}
if (quant_delay < 0) {
throw std::invalid_argument("`quant_delay` must be a non-negative integer.");
}
if (quant_delay != 0 && iter < 0) {
throw std::invalid_argument(
"`iter` must be >=0 for non-zero `quant_delay`");
}
auto Y = at::empty_like(X);
if (quant_delay > 0 && iter <= quant_delay) {
Y.copy_(X); // We might want to just return the input here.
return Y;
}
float inv_scale = 1.0f / scale;
const float quant_min = 0;
const float quant_max = (1 << num_bits) - 1;
at::cuda::CUDA_tensor_apply2<float, float>(
X,
Y,
[=] __device__ (
const float& input_val,
float& result_val) {
result_val = (fminf(quant_max, fmaxf(quant_min, (std::round(input_val * inv_scale + zero_point)))) - zero_point) * scale;
});
return Y;
}
};
/* Backward path to fake-quantize the 'inputs' tensor.
Args:
X: Forward input tensor.
dY: Backward input tensor.
scale: scale of per tensor affine quantization
zero_point: zero_point of per tensor affine quantization
num_bits: Number of quantization bits.
quant_delay: Count of global steps for which to delay the quantization.
See note in forward.
iter: The current quantization iteration used for `quant_delay`.
Returns:
Gradient with respect to X (float dtype).
Notes:
- quant_delay might be set to non-zero to help weights stabilize in the
beginning of the training.
- quantization range [0, 2^bits - 1]
*/
class FakeQuantizePerTensorAffineOp_backward : public c10::OperatorKernel {
public:
at::Tensor operator()(
at::Tensor X,
at::Tensor dY,
double scale,
int64_t zero_point,
int64_t num_bits = 8,
int64_t quant_delay = 0,
int64_t iter = 0) {
// Sanity checks.
TORCH_CHECK(X.is_cuda());
TORCH_CHECK(X.scalar_type() == ScalarType::Float);
if (num_bits > 32 || num_bits < 1) {
throw std::invalid_argument("`num_bits` should be in the [1, 32] range.");
}
if (zero_point < 0) {
throw std::invalid_argument("`zero_point` must be a non-negative integer.");
}
if (quant_delay < 0) {
throw std::invalid_argument("`quant_delay` must be a non-negative integer.");
}
if (X.numel() <= 0) {
return X;
}
if (X.numel() != dY.numel()) {
throw std::invalid_argument("`X` and `dY` are not the same size");
}
if (quant_delay != 0 && iter < 0) {
throw std::invalid_argument(
"`iter` must be >=0 for non-zero `quant_delay`");
}
auto dX = at::zeros_like(dY);
if (quant_delay > 0 && iter <= quant_delay) {
dX.copy_(dY);
return dX;
}
float inv_scale = 1.0f / scale;
const float quant_min = 0;
const float quant_max = (1 << num_bits) - 1;
auto mask = at::empty_like(dY);
at::cuda::CUDA_tensor_apply2<float, float>(
X,
mask,
[=] __device__ (
const float& input_val,
float& result_val) {
float Xq = std::round(input_val * inv_scale + zero_point);
result_val = float(Xq >= quant_min && Xq <= quant_max);
});
dX = mask * dY;
return dX;
}
};
static auto registry =
c10::RegisterOperators()
.op("quantized::fake_quantize_per_tensor_affine_forward(Tensor X, float scale, int zero_point, int num_bits = 8, int quant_delay = 0, int iter = 0) -> Tensor",
c10::RegisterOperators::options()
.kernel<FakeQuantizePerTensorAffineOp_forward>()
.dispatchKey(CUDATensorId()))
.op("quantized::fake_quantize_per_tensor_affine_backward(Tensor X, Tensor dY, float scale, int zero_point, int num_bits=8, int quant_delay=0, int iter = 0) -> Tensor",
c10::RegisterOperators::options()
.kernel<FakeQuantizePerTensorAffineOp_backward>()
.dispatchKey(CUDATensorId()));
} // namespace
}} // namespace at::native
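The backward kernel above is a straight-through estimator: the incoming gradient dY is passed through wherever the rounded quantized value falls inside [quant_min, quant_max] and zeroed wherever it would have been clamped. A minimal NumPy sketch of the same masking logic (the helper name is illustrative, not part of this PR):

import numpy as np

def fake_quantize_grad_reference(x, dy, scale, zero_point, num_bits=8):
    # Mask is 1 where the rounded quantized value stays in range,
    # 0 where it was clamped (straight-through estimator).
    quant_min, quant_max = 0, 2 ** num_bits - 1
    xq = np.round(x / scale + zero_point)
    mask = (xq >= quant_min) & (xq <= quant_max)
    return dy * mask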


@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.cuda
import torch.jit
import numpy as np
import unittest
@@ -66,6 +67,9 @@ class TestFakeQuantizePerTensorAffine(unittest.TestCase):
np.testing.assert_allclose(dX, dX_prime, rtol=tolerance, atol=tolerance)
def test_numerical_consistency(self):
'''
Comparing numerical consistency between CPU quantize/dequantize op and the CPU fake quantize op
'''
np.random.seed(NP_RANDOM_SEED)
fake_quantize_per_tensor_affine_forward = torch.ops.quantized.fake_quantize_per_tensor_affine_forward
@@ -74,13 +78,72 @@ class TestFakeQuantizePerTensorAffine(unittest.TestCase):
num_bits = 8
X = np.random.rand(20, 20) * 125
X_torch = torch.from_numpy(X).float()
-Y = X_torch.quantize_linear(scale, zero_point, torch.qint8).dequantize()
+Y = torch.dequantize(torch.quantize_linear(X_torch, scale, zero_point, torch.qint8))
Y_prime = fake_quantize_per_tensor_affine_forward(
X=X_torch, scale=scale, zero_point=zero_point, num_bits=num_bits,
quant_delay=0, iter=0)
tolerance = 1e-6
np.testing.assert_allclose(Y, Y_prime, rtol=tolerance, atol=tolerance)
"""Tests the forward path of the FakeQuantizePerTensorAffine CUDA op."""
@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
def test_forward_cuda(self):
np.random.seed(NP_RANDOM_SEED)
fake_quantize_per_tensor_affine_forward = torch.ops.quantized.fake_quantize_per_tensor_affine_forward
scale = 3
zero_point = 2
num_bits = 8
X = np.random.rand(20, 20) * 125
X_torch = torch.from_numpy(X).float().cuda()
Y = _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, num_bits)
Y_prime = fake_quantize_per_tensor_affine_forward(
X=X_torch, scale=scale, zero_point=zero_point, num_bits=num_bits,
quant_delay=0, iter=0)
tolerance = 1e-6
np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
"""Tests the backward method. Note that this runs the reference quantization
and thus the errors might be originating there."""
@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
def test_backward_cuda(self):
np.random.seed(NP_RANDOM_SEED)
fake_quantize_per_tensor_affine_backward = torch.ops.quantized.fake_quantize_per_tensor_affine_backward
scale = 3
zero_point = 2
num_bits = 8
X = np.random.rand(20, 20) * 125
Y = _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, num_bits)
dY = Y - X # Fake gradient
dX = _fake_quantize_per_tensor_affine_grad_reference(X, dY, scale, zero_point, num_bits)
X_torch = torch.from_numpy(X).float().cuda()
dY_torch = torch.from_numpy(dY).float().cuda()
dX_prime = fake_quantize_per_tensor_affine_backward(
X=X_torch, dY=dY_torch, scale=scale, zero_point=zero_point,
num_bits=num_bits, quant_delay=0, iter=0)
tolerance = 1e-6
np.testing.assert_allclose(dX, dX_prime.cpu(), rtol=tolerance, atol=tolerance)
@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
def test_numerical_consistency_cuda(self):
'''
Comparing numerical consistency between CPU quantize/dequantize op and the CUDA fake quantize op
'''
np.random.seed(NP_RANDOM_SEED)
fake_quantize_per_tensor_affine_forward = torch.ops.quantized.fake_quantize_per_tensor_affine_forward
scale = 3
zero_point = 2
num_bits = 8
X = np.random.rand(20, 20) * 125
X_torch = torch.from_numpy(X).float()
Y = torch.dequantize(torch.quantize_linear(X_torch, scale, zero_point, torch.qint8))
Y_prime = fake_quantize_per_tensor_affine_forward(
X=X_torch.cuda(), scale=scale, zero_point=zero_point, num_bits=num_bits,
quant_delay=0, iter=0)
tolerance = 1e-6
np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
if __name__ == '__main__':
run_tests()
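For illustration, a short usage sketch (not part of this PR's test file) of the quant_delay behavior noted in the forward op: while iter has not yet passed quant_delay the op simply copies its input, and only afterwards applies fake quantization. It assumes the op is registered as above and a CUDA device is available.

import torch

if torch.cuda.is_available():
    fq_forward = torch.ops.quantized.fake_quantize_per_tensor_affine_forward
    X = torch.rand(4, 4, device='cuda') * 125
    # Still inside the delay window (iter <= quant_delay): output is a copy of X.
    Y = fq_forward(X=X, scale=3.0, zero_point=2, num_bits=8, quant_delay=100, iter=10)
    assert torch.equal(Y, X)
    # Past the delay window: fake quantization is applied.
    Y = fq_forward(X=X, scale=3.0, zero_point=2, num_bits=8, quant_delay=100, iter=200)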