CUDA implementation of fakequant (#20252)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20252

Add CUDA implementation for fakequant op for quantization aware training.

Reviewed By: zafartahirov

Differential Revision: D15243386

fbshipit-source-id: 37610ab046786ffc69aaec5235e5df8304c353d6
Committed by: Facebook Github Bot
Parent: fdb923996d
Commit: 05543153dd
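For reference, the new op quantizes to the integer range [0, 2^num_bits - 1] and immediately dequantizes, so the tensor stays in float space but takes only representable quantized values. A minimal NumPy sketch of the math in the CUDA kernel below; the helper name and the sample value are illustrative, not part of this commit:

import numpy as np

def fake_quantize_reference(x, scale, zero_point, num_bits=8):
    # Quantize: round(x / scale + zero_point), clamped to [0, 2^num_bits - 1].
    quant_min, quant_max = 0, (1 << num_bits) - 1
    xq = np.clip(np.round(x / scale + zero_point), quant_min, quant_max)
    # Dequantize back to float.
    return (xq - zero_point) * scale

# With the values used by the tests below (scale=3, zero_point=2):
# round(10.0 / 3 + 2) = 5, clamped to [0, 255] = 5, (5 - 2) * 3 = 9.0
print(fake_quantize_reference(np.array([10.0]), scale=3, zero_point=2))  # [9.]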
@@ -0,0 +1,165 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/core/op_registration/op_registration.h>
#include <cmath>

/* FakeQuantize Op for PerTensorAffine quantization scheme */
namespace at { namespace native {
namespace {
/* Fake-quantizes the 'inputs' tensor.
Args:
  X: Forward input tensor.
  scale: scale of per tensor affine quantization
  zero_point: zero_point of per tensor affine quantization
  num_bits: Number of quantization bits.
  quant_delay: Count of global steps for which to delay the quantization.
               See note below.
  iter: The current quantization iteration used for `quant_delay`.
Returns:
  Quantized tensor (float dtype).

Notes:
  - quant_delay may be set to a non-zero value to help the weights stabilize
    at the beginning of training.
  - quantization range [0, 2^bits - 1]
*/
class FakeQuantizePerTensorAffineOp_forward : public c10::OperatorKernel {
 public:
  at::Tensor operator()(
      at::Tensor X,
      double scale,
      int64_t zero_point,
      int64_t num_bits = 8,
      int64_t quant_delay = 0,
      int64_t iter = 0) {
    // Sanity checks.
    TORCH_CHECK(X.is_cuda());
    TORCH_CHECK(X.scalar_type() == ScalarType::Float);
    if (num_bits > 32 || num_bits < 1) {
      throw std::invalid_argument("`num_bits` should be in the [1, 32] range.");
    }
    if (zero_point < 0) {
      throw std::invalid_argument("`zero_point` must be a positive integer.");
    }
    if (quant_delay < 0) {
      throw std::invalid_argument("`quant_delay` must be a positive integer.");
    }

    if (quant_delay != 0 && iter < 0) {
      throw std::invalid_argument(
          "`iter` must be >=0 for non-zero `quant_delay`");
    }

    auto Y = at::empty_like(X);

    if (quant_delay > 0 && iter <= quant_delay) {
      Y.copy_(X);  // We might want to just return the input here.
      return Y;
    }

    float inv_scale = 1.0f / scale;
    const float quant_min = 0;
    const float quant_max = (1 << num_bits) - 1;
    at::cuda::CUDA_tensor_apply2<float, float>(
        X,
        Y,
        [=] __device__ (const float& input_val, float& result_val) {
          result_val = (fminf(quant_max, fmaxf(quant_min, (std::round(input_val * inv_scale + zero_point)))) - zero_point) * scale;
        });
    return Y;
  }
};

/* Backward path to fake-quantize the 'inputs' tensor.

Args:
  X: Forward input tensor.
  dY: Backward input tensor (gradient w.r.t. the forward output).
  scale: scale of per tensor affine quantization
  zero_point: zero_point of per tensor affine quantization
  num_bits: Number of quantization bits.
  quant_delay: Count of global steps for which to delay the quantization.
               See note in forward.
  iter: The current quantization iteration used for `quant_delay`.
Returns:
  Gradient of the input tensor (float dtype).

Notes:
  - quant_delay may be set to a non-zero value to help the weights stabilize
    at the beginning of training.
  - quantization range [0, 2^bits - 1]
*/
class FakeQuantizePerTensorAffineOp_backward : public c10::OperatorKernel {
 public:
  at::Tensor operator()(
      at::Tensor X,
      at::Tensor dY,
      double scale,
      int64_t zero_point,
      int64_t num_bits = 8,
      int64_t quant_delay = 0,
      int64_t iter = 0) {
    // Sanity checks.
    TORCH_CHECK(X.is_cuda());
    TORCH_CHECK(X.scalar_type() == ScalarType::Float);
    if (num_bits > 32 || num_bits < 1) {
      throw std::invalid_argument("`num_bits` should be in the [1, 32] range.");
    }
    if (zero_point < 0) {
      throw std::invalid_argument("`zero_point` must be a positive integer.");
    }
    if (quant_delay < 0) {
      throw std::invalid_argument("`quant_delay` must be a positive integer.");
    }
    if (X.numel() <= 0) {
      return X;
    }
    if (X.numel() != dY.numel()) {
      throw std::invalid_argument("`X` and `dY` are not the same size");
    }

    if (quant_delay != 0 && iter < 0) {
      throw std::invalid_argument(
          "`iter` must be >=0 for non-zero `quant_delay`");
    }

    auto dX = at::zeros_like(dY);
    if (quant_delay > 0 && iter <= quant_delay) {
      dX.copy_(dY);
      return dX;
    }

    float inv_scale = 1.0f / scale;
    const float quant_min = 0;
    const float quant_max = (1 << num_bits) - 1;
    auto mask = at::empty_like(dY);
    at::cuda::CUDA_tensor_apply2<float, float>(
        X,
        mask,
        [=] __device__ (const float& input_val, float& result_val) {
          float Xq = std::round(input_val * inv_scale + zero_point);
          result_val = float(Xq >= quant_min && Xq <= quant_max);
        });
    dX = mask * dY;
    return dX;
  }
};

static auto registry =
    c10::RegisterOperators()
        .op("quantized::fake_quantize_per_tensor_affine_forward(Tensor X, float scale, int zero_point, int num_bits = 8, int quant_delay = 0, int iter = 0) -> Tensor",
            c10::RegisterOperators::options()
                .kernel<FakeQuantizePerTensorAffineOp_forward>()
                .dispatchKey(CUDATensorId()))
        .op("quantized::fake_quantize_per_tensor_affine_backward(Tensor X, Tensor dY, float scale, int zero_point, int num_bits=8, int quant_delay=0, int iter = 0) -> Tensor",
            c10::RegisterOperators::options()
                .kernel<FakeQuantizePerTensorAffineOp_backward>()
                .dispatchKey(CUDATensorId()));

} // namespace
}} // namespace at::native
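The backward kernel above is effectively a straight-through estimator: the incoming gradient dY passes through unchanged wherever the rounded (pre-clamp) quantized value falls inside [0, 2^num_bits - 1], and is zeroed where the forward pass would have clamped. A NumPy sketch of that mask logic; the helper name is illustrative, not part of this commit:

import numpy as np

def fake_quantize_grad_reference(x, dy, scale, zero_point, num_bits=8):
    # Mask is 1 where round(x / scale + zero_point) lands in [0, 2^num_bits - 1],
    # i.e. where the forward pass did not clamp; the gradient flows through there.
    quant_min, quant_max = 0, (1 << num_bits) - 1
    xq = np.round(x / scale + zero_point)
    mask = (xq >= quant_min) & (xq <= quant_max)
    return dy * mask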
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import torch.cuda
import torch.jit
import numpy as np
import unittest
@@ -66,6 +67,9 @@ class TestFakeQuantizePerTensorAffine(unittest.TestCase):
        np.testing.assert_allclose(dX, dX_prime, rtol=tolerance, atol=tolerance)

    def test_numerical_consistency(self):
        '''
        Comparing numerical consistency between CPU quantize/dequantize op and the CPU fake quantize op
        '''
        np.random.seed(NP_RANDOM_SEED)
        fake_quantize_per_tensor_affine_forward = torch.ops.quantized.fake_quantize_per_tensor_affine_forward

@@ -74,13 +78,72 @@
        num_bits = 8
        X = np.random.rand(20, 20) * 125
        X_torch = torch.from_numpy(X).float()
-       Y = X_torch.quantize_linear(scale, zero_point, torch.qint8).dequantize()
+       Y = torch.dequantize(torch.quantize_linear(X_torch, scale, zero_point, torch.qint8))
        Y_prime = fake_quantize_per_tensor_affine_forward(
            X=X_torch, scale=scale, zero_point=zero_point, num_bits=num_bits,
            quant_delay=0, iter=0)
        tolerance = 1e-6
        np.testing.assert_allclose(Y, Y_prime, rtol=tolerance, atol=tolerance)

    """Tests the forward path of the FakeQuantizePerTensorAffine CUDA op."""
    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
    def test_forward_cuda(self):
        np.random.seed(NP_RANDOM_SEED)
        fake_quantize_per_tensor_affine_forward = torch.ops.quantized.fake_quantize_per_tensor_affine_forward

        scale = 3
        zero_point = 2
        num_bits = 8
        X = np.random.rand(20, 20) * 125
        X_torch = torch.from_numpy(X).float().cuda()
        Y = _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, num_bits)
        Y_prime = fake_quantize_per_tensor_affine_forward(
            X=X_torch, scale=scale, zero_point=zero_point, num_bits=num_bits,
            quant_delay=0, iter=0)
        tolerance = 1e-6
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    """Tests the backward method. Note that this runs the reference quantization
    and thus the errors might be originating there."""
    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
    def test_backward_cuda(self):
        np.random.seed(NP_RANDOM_SEED)
        fake_quantize_per_tensor_affine_backward = torch.ops.quantized.fake_quantize_per_tensor_affine_backward

        scale = 3
        zero_point = 2
        num_bits = 8
        X = np.random.rand(20, 20) * 125
        Y = _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, num_bits)
        dY = Y - X  # Fake gradient
        dX = _fake_quantize_per_tensor_affine_grad_reference(X, dY, scale, zero_point, num_bits)
        X_torch = torch.from_numpy(X).float().cuda()
        dY_torch = torch.from_numpy(dY).float().cuda()
        dX_prime = fake_quantize_per_tensor_affine_backward(
            X=X_torch, dY=dY_torch, scale=scale, zero_point=zero_point,
            num_bits=num_bits, quant_delay=0, iter=0)
        tolerance = 1e-6
        np.testing.assert_allclose(dX, dX_prime.cpu(), rtol=tolerance, atol=tolerance)

    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
    def test_numerical_consistency_cuda(self):
        '''
        Comparing numerical consistency between CPU quantize/dequantize op and the CUDA fake quantize op
        '''
        np.random.seed(NP_RANDOM_SEED)
        fake_quantize_per_tensor_affine_forward = torch.ops.quantized.fake_quantize_per_tensor_affine_forward

        scale = 3
        zero_point = 2
        num_bits = 8
        X = np.random.rand(20, 20) * 125
        X_torch = torch.from_numpy(X).float()
        Y = torch.dequantize(torch.quantize_linear(X_torch, scale, zero_point, torch.qint8))
        Y_prime = fake_quantize_per_tensor_affine_forward(
            X=X_torch.cuda(), scale=scale, zero_point=zero_point, num_bits=num_bits,
            quant_delay=0, iter=0)
        tolerance = 1e-6
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

if __name__ == '__main__':
    run_tests()