Quantized Conv2d operator (#20772)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20772

This is a copy of D15178352. A conflicting commit that removed support for registering kernels with IntArrayRef arguments landed at the same time as D15178352, so D15178352 was reverted. This version uses std::vector instead.

Reviewed By: zafartahirov

Differential Revision: D15437237

fbshipit-source-id: cd2f1caebcc720352b48ce25d716cb1ca49a5197
Committed by Facebook Github Bot. Commit cde611a66c, parent aebcd80ae4.
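For context on the change the summary describes, the sketch below (illustrative only, not part of this diff) contrasts the two argument styles: IntArrayRef is a non-owning view over an int64_t array, while std::vector<int64_t> owns its elements, and the vector form is what this PR's kernels accept because kernel registration for IntArrayRef arguments had just been removed.

// Illustrative sketch of the signature difference; neither function exists in the PR.
#include <ATen/ATen.h>
#include <vector>

// Style of the reverted change: the list argument is an IntArrayRef view.
int64_t first_stride_view(at::IntArrayRef stride) {
  return stride[0];
}

// Style used in this PR (see QConv2dInt8 below): the list argument is a std::vector.
int64_t first_stride_vector(const std::vector<int64_t>& stride) {
  return stride[0];
}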
aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
@@ -20,6 +20,14 @@ struct FBGEMM_API PackedFCWeight {
   int w_zp;
 };
 
+struct FBGEMM_API PackedConvWeight {
+  std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w;
+  std::vector<int32_t> col_offsets;
+  std::vector<int32_t> kernel;
+  float w_scale;
+  int32_t w_zp;
+};
+
 // Convert the weight from uint8 to int8.
 static void convert_uint8_int8(
     int K,
aten/src/ATen/native/quantized/cpu/qconv.cpp (new file, 156 lines)
@@ -0,0 +1,156 @@
#include <ATen/ATen.h>
#include <ATen/core/Type.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/quantized/Quantizer.h>

namespace at {
namespace native {
namespace {
class QConv2dInt8 final : public c10::OperatorKernel {
 public:
#ifdef USE_FBGEMM
  Tensor operator()(
      Tensor act,
      Tensor packed_weight,
      Tensor bias,
      const std::vector<int64_t>& stride,
      const std::vector<int64_t>& padding,
      const std::vector<int64_t>& dilation,
      const std::vector<int64_t>& output_padding,
      int64_t groups,
      double output_scale,
      int64_t output_zero_point) {
    TORCH_CHECK(
        fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
    TORCH_CHECK(
        act.ndimension() == 4,
        "Activations are supposed to have 4 dimensions.");
    TORCH_CHECK(stride.size() == 2, "2D convolution only");
    TORCH_CHECK(padding.size() == 2, "2D convolution only");
    TORCH_CHECK(dilation.size() == 2, "2D convolution only");
    TORCH_CHECK(output_padding.size() == 2, "2D convolution only");
    TORCH_CHECK(
        (dilation[0] == 1 && dilation[1] == 1),
        "Currently dilation should be 1");
    TORCH_CHECK(
        (output_padding[0] == 0 && output_padding[1] == 0),
        "Currently output padding should be 0");

    // inputs are in NHWC format
    int N = act.size(0);
    int H = act.size(1);
    int W = act.size(2);
    int C = act.size(3);
    int K = bias.size(0);

    Tensor act_contig = act.contiguous();
    const uint8_t* act_ptr =
        reinterpret_cast<uint8_t*>(act_contig.data<c10::quint8>());

    PackedConvWeight& pack_ptr =
        cpp_custom_type_hack::cast<PackedConvWeight>(packed_weight);
    auto packB = pack_ptr.w.get();
    // packB->printPackedMatrix("PackedB inside QConv2dInt8:");
    auto& col_offsets = pack_ptr.col_offsets;
    auto& kernel = pack_ptr.kernel;

    std::vector<int32_t> row_offset_buf(
        fbgemm::PackAWithIm2Col<uint8_t>::rowOffsetBufferSize());

    int pad_l = padding[0];
    int pad_t = padding[1];
    int stride_h = stride[0];
    int stride_w = stride[1];
    int kernel_h = kernel[0];
    int kernel_w = kernel[1];

    fbgemm::conv_param_t<> conv_p(
        N, // Batch size
        C, // Number of input channels
        K, // Number of output channels
        {H, W},
        groups,
        {kernel_h, kernel_w},
        {stride_h, stride_w},
        {pad_l, pad_t, pad_l, pad_t});

    fbgemm::PackAWithIm2Col<uint8_t> packA(
        conv_p,
        act_ptr,
        nullptr,
        act.q_zero_point().toInt(),
        row_offset_buf.data());

    fbgemm::DoNothing<> NoOpObj{};

    auto bias_contig = bias.contiguous();

    float act_scale = act.q_scale().toFloat();
    int32_t act_zero_point = act.q_zero_point().toInt();

    float weight_scale_float = pack_ptr.w_scale;
    int32_t weight_zero_point_int32 = pack_ptr.w_zp;

    float output_multiplier_float =
        (act_scale * weight_scale_float) / static_cast<float>(output_scale);

    fbgemm::ReQuantizeOutput<false> outputProcObj(
        NoOpObj,
        &output_multiplier_float,
        output_zero_point,
        act_zero_point,
        &weight_zero_point_int32,
        packA.getRowOffsetBuffer(),
        col_offsets.data(),
        bias_contig.data<int32_t>(),
        K,
        groups);

    Tensor output = _empty_affine_quantized(
        {N, H, W, K},
        device(kCPU).dtype(kQUInt8),
        output_scale,
        output_zero_point);
    auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));

    // Do the GEMM
    fbgemm::fbgemmPacked(
        packA,
        *packB,
        reinterpret_cast<uint8_t*>(output.data<c10::quint8>()),
        buffer.data<int32_t>(),
        K,
        outputProcObj,
        0 /* thread_id */,
        1 /* num_threads */);

    return output;
  }
#else // USE_FBGEMM
  Tensor operator()(
      Tensor /* activation */,
      Tensor /* packed_weight */,
      Tensor /* bias */,
      const std::vector<int64_t>& /* stride */,
      const std::vector<int64_t>& /* padding */,
      const std::vector<int64_t>& /* dilation */,
      const std::vector<int64_t>& /* output_padding */,
      int64_t /* groups */,
      double /* output_scale */,
      int64_t /* output_zero_point */) {
    TORCH_CHECK(
        false, "This PyTorch installation was not built with FBGEMM operators");
  }
#endif // USE_FBGEMM
};

static auto registry = c10::RegisterOperators().op(
    "quantized::fbgemm_conv2d",
    c10::RegisterOperators::options().kernel<QConv2dInt8>().dispatchKey(
        QuantizedCPUTensorId()));

} // namespace
} // namespace native
} // namespace at
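The requantization performed by fbgemm::ReQuantizeOutput above boils down to scaling the int32 accumulator by output_multiplier_float and shifting by the output zero point. Below is a simplified scalar sketch with a hypothetical helper name; it ignores the activation/weight zero-point and row/column-offset corrections, which vanish in the test added by this PR because both zero points are 0.

// Simplified, hypothetical scalar view of the per-element requantization above.
#include <algorithm>
#include <cmath>
#include <cstdint>

static uint8_t requantize_scalar(
    int32_t acc, // int32 accumulator: sum(act_int * w_int) + bias
    float act_scale,
    float w_scale,
    float output_scale,
    int32_t output_zero_point) {
  // Same multiplier as output_multiplier_float in QConv2dInt8.
  float multiplier = (act_scale * w_scale) / output_scale;
  int32_t q =
      static_cast<int32_t>(std::nearbyint(acc * multiplier)) + output_zero_point;
  // Clamp to the uint8 range of the quantized output tensor.
  return static_cast<uint8_t>(std::min(255, std::max(0, q)));
}

With the scales used in the test below (act_scale = 1.5, w_scale = 2.5, output_scale = 7.3), the multiplier is 3.75 / 7.3 ≈ 0.5137.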
aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp (new file, 76 lines)
@@ -0,0 +1,76 @@
#include <ATen/ATen.h>
#include <ATen/core/Type.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/quantized/Quantizer.h>

namespace caffe2 {
#ifdef USE_FBGEMM
// Required for cpp_custom_type_hack to work
CAFFE_KNOWN_TYPE(PackedConvWeight);
#endif
} // namespace caffe2

namespace at {
namespace native {
namespace {
class QConvPackWeightInt8 final : public c10::OperatorKernel {
 public:
#ifdef USE_FBGEMM
  Tensor operator()(Tensor weight, int64_t groups) {
    TORCH_CHECK(
        weight.ndimension() == 4, "Weights are expected to have 4 dimensions");
    TORCH_CHECK(groups == 1, "Groupwise convolutions are not supported yet");
    // weights in RS(C/G)K format
    // matrix dimensions after im2col
    int NDim = weight.size(3) / groups;
    int KDim = weight.size(0) * weight.size(1) * groups * weight.size(2);
    auto weight_config = weight.contiguous();
    int weight_zero_point_int32 = weight.q_zero_point().toInt();
    TORCH_CHECK(
        weight_zero_point_int32 == 0,
        "Only symmetric quantization is supported for weights yet");
    const int8_t* weight_ptr_int8 =
        reinterpret_cast<int8_t*>(weight_config.data<c10::quint8>());

    std::vector<int32_t> col_offsets(NDim * groups);
    std::vector<int32_t> kernel{static_cast<int>(weight.size(0)),
                                static_cast<int>(weight.size(1))};
    std::vector<int8_t> weight_int8(KDim * NDim * groups);
    auto ret_ptr = guts::make_unique<PackedConvWeight>(
        PackedConvWeight{guts::make_unique<fbgemm::PackBMatrix<int8_t>>(
                             fbgemm::matrix_op_t::NoTranspose,
                             KDim,
                             NDim,
                             weight_ptr_int8,
                             NDim,
                             nullptr, // PackBMatrix manages ownership of pmat
                             groups),
                         col_offsets,
                         kernel,
                         weight.q_scale().toFloat(),
                         weight_zero_point_int32});
    // TODO: we will need to replace this with torchscript classes at a later
    // point.
    return cpp_custom_type_hack::create(std::move(ret_ptr), weight.options());
  }
#else // USE_FBGEMM
  Tensor operator()(
      Tensor, /* weight */
      int64_t /* groups */
  ) {
    TORCH_CHECK(
        false, "This PyTorch installation was not built with FBGEMM operators");
  }
#endif // USE_FBGEMM
};

static auto registry = c10::RegisterOperators().op(
    "quantized::fbgemm_conv_prepack",
    c10::RegisterOperators::options().kernel<QConvPackWeightInt8>().dispatchKey(
        QuantizedCPUTensorId()));

} // namespace
} // namespace native
} // namespace at
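To make the im2col dimensions above concrete, take the shapes used in the test added by this PR (3x3 kernel, input_channels = 16, output_channels = 8, groups = 1): with the weight already permuted to RS(C/G)K layout, NDim = weight.size(3) / groups = 8 / 1 = 8 and KDim = weight.size(0) * weight.size(1) * groups * weight.size(2) = 3 * 3 * 1 * 16 = 144, so the packed B matrix handed to fbgemm::PackBMatrix is 144 x 8.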
@@ -25,6 +25,14 @@ def _dequantize(qx, scale, zero_point):
     return x
 
 
+def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
+    """Requantizes a numpy array, i.e., intermediate int32 or int16 values are
+    converted back to given type"""
+    qx = (x * multiplier).round() + zero_point
+    qx = np.clip(qx, qmin, qmax).astype(qtype)
+    return qx
+
+
 # Make sure we won't have overflows from vpmaddubsw instruction used in FBGEMM.
 # On the current Intel x86 architecture, we need to utilize vpmaddubsw instruction
 # for the 8-bit int multiplication. This instruction vertically multiplies each
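For reference, vpmaddubsw multiplies packed unsigned 8-bit values by signed 8-bit values and horizontally adds each adjacent pair of products into a signed 16-bit lane, so in the worst case 255 * 127 + 255 * 127 = 64770 exceeds the int16 maximum of 32767 and the result saturates. The test below therefore draws activations from [0, 4) and weights from [0, 5), where any adjacent pair of products sums to at most 3 * 4 + 3 * 4 = 24.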
@@ -369,5 +377,119 @@ class TestQuantizedLinear(unittest.TestCase):
     np.testing.assert_equal(Y_q_ref2.int_repr().numpy(), Y_q.int_repr().numpy())
 
 
+@unittest.skipIf(
+    TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
+    " Quantized convolution requires FBGEMM. FBGEMM does not play"
+    " well with UBSAN at the moment, so we skip the test if"
+    " we are in a UBSAN environment.",
+)
+class TestQuantizedConv(unittest.TestCase):
+    """Tests the correctness of quantized convolution op."""
+    def test_qconv(self):
+
+        qconv = torch.ops.quantized.fbgemm_conv2d
+        qconv_prepack = torch.ops.quantized.fbgemm_conv_prepack
+
+        # N
+        batch_size = 1
+        # C
+        input_channels = 16
+        # H, W
+        height = width = 24
+        # K
+        output_channels = 8
+
+        kernel_h = kernel_w = 3
+        stride_h = stride_w = 1
+        padding_h = padding_w = 1
+        dilation_h = dilation_w = 1
+        groups = 1
+
+        W_value_min = 0
+        W_value_max = 5
+        # We use small values to avoid overflow.
+        # (the operator expects them in the format (output_channels, input_channels/groups, kernel_h, kernel_w))
+
+        W_init = torch.randint(
+            W_value_min,
+            W_value_max,
+            (output_channels, int(input_channels / groups), kernel_h, kernel_w),
+        )
+
+        b_init = torch.randint(0, 10, (output_channels,))
+
+        # Existing floating point conv operator
+        conv_op = torch.nn.Conv2d(
+            input_channels,
+            output_channels,
+            (kernel_h, kernel_w),
+            (stride_h, stride_w),
+            (padding_h, padding_w),
+            (dilation_h, dilation_w),
+            groups,
+        )
+
+        # assign the weights
+        conv_op.weight = torch.nn.Parameter(
+            W_init.to(dtype=torch.float), requires_grad=False
+        )
+        conv_op.bias = torch.nn.Parameter(
+            b_init.to(dtype=torch.float), requires_grad=False
+        )
+
+        X_value_min = 0
+        X_value_max = 4
+        X_init = torch.randint(
+            X_value_min, X_value_max, (batch_size, input_channels, height, width)
+        )
+
+        # run on an input tensor
+        result_ref = conv_op(X_init.to(dtype=torch.float))
+
+        # reformat X_init and W_init in the required format by conv operator
+        # NCHW -> NHWC
+        X_NHWC = X_init.permute([0, 2, 3, 1]).contiguous()
+        # KCRS -> RSCK
+        W_RSCK = W_init.permute([2, 3, 1, 0]).contiguous()
+
+        X_scale = 1.5
+        # Currently only 0 as zero point is supported.
+        X_zero_point = 0
+        X = X_scale * (X_NHWC - X_zero_point).to(dtype=torch.float)
+
+        W_scale = 2.5
+        W_zero_point = 0
+        W = W_scale * (W_RSCK - W_zero_point).to(dtype=torch.float)
+
+        X_q = X.quantize_linear(scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8)
+        W_q = W.quantize_linear(scale=W_scale, zero_point=W_zero_point, dtype=torch.quint8)
+        b_q = b_init.to(dtype=torch.int32)
+
+        W_prepack = qconv_prepack(W_q, groups)
+        Y_scale = 7.3
+        Y_zero_point = 5
+
+        Y_q = qconv(
+            X_q,
+            W_prepack,
+            b_q,
+            [1, 1],  # stride
+            [1, 1],  # padding
+            [1, 1],  # dilation
+            [0, 0],  # output_padding
+            1,  # groups
+            Y_scale,
+            Y_zero_point,
+        )
+
+        result_NHWK = result_ref.permute([0, 2, 3, 1])
+        result_q = _requantize(
+            result_NHWK.numpy(), X_scale * W_scale / Y_scale, Y_zero_point
+        )
+
+        # Make sure the results match
+        np.testing.assert_equal(result_q, Y_q.int_repr().numpy())
+
+
 if __name__ == "__main__":
     run_tests()