Quantized Conv2d operator (#20772)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20772

Copy of D15178352

A conflicting commit that removed support for registering kernels with IntArrayRef arguments landed at the same time as D15178352, so D15178352 was reverted. This version uses std::vector instead.
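
As a minimal, standalone sketch (plain C++, no ATen headers) of what this change amounts to: the reverted version took the list arguments as at::IntArrayRef, while this version takes them as std::vector<int64_t>; callers passing brace-initialized lists (as the int[] schema arguments from Python do) are unaffected, the kernel just receives an owning copy instead of a non-owning view. The conv_like function below is a hypothetical stand-in, not part of the diff.

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for the kernel's list parameters after this change.
int64_t conv_like(const std::vector<int64_t>& stride,
                  const std::vector<int64_t>& padding) {
  return stride[0] * 100 + padding[0];
}

int main() {
  // Brace lists bind to std::vector just as they would to IntArrayRef.
  std::cout << conv_like({1, 1}, {0, 0}) << "\n";  // prints 100
  return 0;
}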

Reviewed By: zafartahirov

Differential Revision: D15437237

fbshipit-source-id: cd2f1caebcc720352b48ce25d716cb1ca49a5197
Author: Daya S Khudia
Date: 2019-05-22 17:47:20 -07:00
Committed by: Facebook Github Bot
Parent: aebcd80ae4
Commit: cde611a66c
4 changed files with 362 additions and 0 deletions


@@ -20,6 +20,14 @@ struct FBGEMM_API PackedFCWeight {
  int w_zp;
};

struct FBGEMM_API PackedConvWeight {
  std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w;
  std::vector<int32_t> col_offsets;
  std::vector<int32_t> kernel;
  float w_scale;
  int32_t w_zp;
};

// Convert the weight from uint8 to int8.
static void convert_uint8_int8(
    int K,


@@ -0,0 +1,156 @@
#include <ATen/ATen.h>
#include <ATen/core/Type.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/quantized/Quantizer.h>

namespace at {
namespace native {
namespace {

class QConv2dInt8 final : public c10::OperatorKernel {
 public:
#ifdef USE_FBGEMM
  Tensor operator()(
      Tensor act,
      Tensor packed_weight,
      Tensor bias,
      const std::vector<int64_t>& stride,
      const std::vector<int64_t>& padding,
      const std::vector<int64_t>& dilation,
      const std::vector<int64_t>& output_padding,
      int64_t groups,
      double output_scale,
      int64_t output_zero_point) {
    TORCH_CHECK(
        fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
    TORCH_CHECK(
        act.ndimension() == 4,
        "Activations are supposed to have 4 dimensions.");
    TORCH_CHECK(stride.size() == 2, "2D convolution only");
    TORCH_CHECK(padding.size() == 2, "2D convolution only");
    TORCH_CHECK(dilation.size() == 2, "2D convolution only");
    TORCH_CHECK(output_padding.size() == 2, "2D convolution only");
    TORCH_CHECK(
        (dilation[0] == 1 && dilation[1] == 1),
        "Currently dilation should be 1");
    TORCH_CHECK(
        (output_padding[0] == 0 && output_padding[1] == 0),
        "Currently output padding should be 0");

    // inputs are in NHWC format
    int N = act.size(0);
    int H = act.size(1);
    int W = act.size(2);
    int C = act.size(3);
    int K = bias.size(0);

    Tensor act_contig = act.contiguous();
    const uint8_t* act_ptr =
        reinterpret_cast<uint8_t*>(act_contig.data<c10::quint8>());

    PackedConvWeight& pack_ptr =
        cpp_custom_type_hack::cast<PackedConvWeight>(packed_weight);
    auto packB = pack_ptr.w.get();
    // packB->printPackedMatrix("PackedB inside QConv2dInt8:");
    auto& col_offsets = pack_ptr.col_offsets;
    auto& kernel = pack_ptr.kernel;

    std::vector<int32_t> row_offset_buf(
        fbgemm::PackAWithIm2Col<uint8_t>::rowOffsetBufferSize());

    int pad_l = padding[0];
    int pad_t = padding[1];
    int stride_h = stride[0];
    int stride_w = stride[1];
    int kernel_h = kernel[0];
    int kernel_w = kernel[1];

    fbgemm::conv_param_t<> conv_p(
        N, // Batch size
        C, // Number of input channels
        K, // Number of output channels
        {H, W},
        groups,
        {kernel_h, kernel_w},
        {stride_h, stride_w},
        {pad_l, pad_t, pad_l, pad_t});

    fbgemm::PackAWithIm2Col<uint8_t> packA(
        conv_p,
        act_ptr,
        nullptr,
        act.q_zero_point().toInt(),
        row_offset_buf.data());

    fbgemm::DoNothing<> NoOpObj{};

    auto bias_contig = bias.contiguous();

    float act_scale = act.q_scale().toFloat();
    int32_t act_zero_point = act.q_zero_point().toInt();

    float weight_scale_float = pack_ptr.w_scale;
    int32_t weight_zero_point_int32 = pack_ptr.w_zp;

    float output_multiplier_float =
        (act_scale * weight_scale_float) / static_cast<float>(output_scale);

    fbgemm::ReQuantizeOutput<false> outputProcObj(
        NoOpObj,
        &output_multiplier_float,
        output_zero_point,
        act_zero_point,
        &weight_zero_point_int32,
        packA.getRowOffsetBuffer(),
        col_offsets.data(),
        bias_contig.data<int32_t>(),
        K,
        groups);

    Tensor output = _empty_affine_quantized(
        {N, H, W, K},
        device(kCPU).dtype(kQUInt8),
        output_scale,
        output_zero_point);
    auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));

    // Do the GEMM
    fbgemm::fbgemmPacked(
        packA,
        *packB,
        reinterpret_cast<uint8_t*>(output.data<c10::quint8>()),
        buffer.data<int32_t>(),
        K,
        outputProcObj,
        0 /* thread_id*/,
        1 /* num_threads */);

    return output;
  }
#else // USE_FBGEMM
  Tensor operator()(
      Tensor /* activation */,
      Tensor /* packed_weight */,
      Tensor /* bias */,
      const std::vector<int64_t>& /* stride */,
      const std::vector<int64_t>& /* padding */,
      const std::vector<int64_t>& /* dilation */,
      const std::vector<int64_t>& /* output padding */,
      int64_t /* groups */,
      double /* output scale */,
      int64_t /* output_zero_point */) {
    TORCH_CHECK(
        false, "This PyTorch installation was not built with FBGEMM operators");
  }
#endif // USE_FBGEMM
};

static auto registry = c10::RegisterOperators().op(
    "quantized::fbgemm_conv2d",
    c10::RegisterOperators::options().kernel<QConv2dInt8>().dispatchKey(
        QuantizedCPUTensorId()));

} // namespace
} // namespace native
} // namespace at
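
The requantization that ReQuantizeOutput performs above boils down to scaling the int32 GEMM accumulator by act_scale * weight_scale / output_scale, adding the output zero point, and clamping to the quint8 range (the same math as the _requantize helper in the test below). A minimal standalone sketch with made-up values, assuming the bias and zero-point corrections have already been folded into the accumulator:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical int32 accumulator for one output value.
  int32_t acc = 100;

  // Quantization parameters matching the test below.
  float act_scale = 1.5f, weight_scale = 2.5f, output_scale = 7.3f;
  int32_t output_zero_point = 5;

  // Same multiplier as output_multiplier_float in the kernel above.
  float multiplier = (act_scale * weight_scale) / output_scale;

  // Scale, round, shift by the output zero point, clamp to the uint8 range.
  int32_t q = static_cast<int32_t>(std::nearbyint(acc * multiplier)) + output_zero_point;
  q = std::min(255, std::max(0, q));

  std::cout << q << "\n";  // 100 * 0.5137 ~= 51, plus 5 -> 56
  return 0;
}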


@@ -0,0 +1,76 @@
#include <ATen/ATen.h>
#include <ATen/core/Type.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/quantized/Quantizer.h>

namespace caffe2 {
#ifdef USE_FBGEMM
// Required for cpp_custom_type_hack to work
CAFFE_KNOWN_TYPE(PackedConvWeight);
#endif
} // namespace caffe2

namespace at {
namespace native {
namespace {

class QConvPackWeightInt8 final : public c10::OperatorKernel {
 public:
#ifdef USE_FBGEMM
  Tensor operator()(Tensor weight, int64_t groups) {
    TORCH_CHECK(
        weight.ndimension() == 4, "Weights are expected to have 4 dimensions");
    TORCH_CHECK(groups == 1, "Groupwise convolutions are not supported yet");

    // weights in RS(C/G)K format
    // matrix dimensions after im2col
    int NDim = weight.size(3) / groups;
    int KDim = weight.size(0) * weight.size(1) * groups * weight.size(2);

    auto weight_config = weight.contiguous();
    int weight_zero_point_int32 = weight.q_zero_point().toInt();
    TORCH_CHECK(
        weight_zero_point_int32 == 0,
        "Only symmetric quantization is supported for weights yet");
    const int8_t* weight_ptr_int8 =
        reinterpret_cast<int8_t*>(weight_config.data<c10::quint8>());

    std::vector<int32_t> col_offsets(NDim * groups);
    std::vector<int32_t> kernel{static_cast<int>(weight.size(0)),
                                static_cast<int>(weight.size(1))};
    std::vector<int8_t> weight_int8(KDim * NDim * groups);

    auto ret_ptr = guts::make_unique<PackedConvWeight>(
        PackedConvWeight{guts::make_unique<fbgemm::PackBMatrix<int8_t>>(
                             fbgemm::matrix_op_t::NoTranspose,
                             KDim,
                             NDim,
                             weight_ptr_int8,
                             NDim,
                             nullptr, // PackBMatrix manages ownership of pmat
                             groups),
                         col_offsets,
                         kernel,
                         weight.q_scale().toFloat(),
                         weight_zero_point_int32});

    // TODO: we will need to replace this with torchscript classes at a later
    // point.
    return cpp_custom_type_hack::create(std::move(ret_ptr), weight.options());
  }
#else // USE_FBGEMM
  Tensor operator()(
      Tensor, /* weight */
      int64_t /* groups */
  ) {
    TORCH_CHECK(
        false, "This PyTorch installation was not built with FBGEMM operators");
  }
#endif // USE_FBGEMM
};

static auto registry = c10::RegisterOperators().op(
    "quantized::fbgemm_conv_prepack",
    c10::RegisterOperators::options().kernel<QConvPackWeightInt8>().dispatchKey(
        QuantizedCPUTensorId()));

} // namespace
} // namespace native
} // namespace at
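
For concreteness, a small standalone sketch of the im2col dimensions computed above, using the shapes from the test below (3x3 kernel, 16 input channels, 8 output channels, groups = 1) and the RS(C/G)K weight layout; the variable names here are illustrative only:

#include <iostream>

int main() {
  // Test shapes: R x S kernel, C input channels, K output channels, G groups.
  int R = 3, S = 3, C = 16, K = 8, G = 1;

  // Weight tensor is laid out as R x S x (C/G) x K, so:
  //   NDim = size(3) / groups                      -> columns of the packed B matrix
  //   KDim = size(0) * size(1) * groups * size(2)  -> rows after im2col
  int NDim = K / G;
  int KDim = R * S * G * (C / G);

  std::cout << "NDim = " << NDim << ", KDim = " << KDim << "\n";  // 8, 144
  return 0;
}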


@@ -25,6 +25,14 @@ def _dequantize(qx, scale, zero_point):
    return x


def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
    """Requantizes a numpy array, i.e., intermediate int32 or int16 values are
    converted back to given type"""
    qx = (x * multiplier).round() + zero_point
    qx = np.clip(qx, qmin, qmax).astype(qtype)
    return qx


# Make sure we won't have overflows from vpmaddubsw instruction used in FBGEMM.
# On the current Intel x86 architecture, we need to utilize vpmaddubsw instruction
# for the 8-bit int multiplication. This instruction vertically multiplies each
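
The comment above (truncated at the hunk boundary) refers to vpmaddubsw, which multiplies unsigned 8-bit activations by signed 8-bit weights and sums adjacent pairs into signed 16-bit lanes; with full-range inputs the pair sum can exceed the int16 maximum and saturate, which is why the test keeps activation and weight values small. A minimal standalone illustration of that worst case:

#include <cstdint>
#include <iostream>

int main() {
  // Worst-case uint8 activation times worst-case positive int8 weight.
  int32_t a = 255, w = 127;

  // vpmaddubsw adds two adjacent products into one signed 16-bit lane.
  int32_t pair_sum = a * w + a * w;  // 64770

  std::cout << pair_sum << " > " << INT16_MAX << "\n";  // 64770 > 32767: would saturate
  return 0;
}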
@@ -369,5 +377,119 @@ class TestQuantizedLinear(unittest.TestCase):
        np.testing.assert_equal(Y_q_ref2.int_repr().numpy(), Y_q.int_repr().numpy())


@unittest.skipIf(
    TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
    " Quantized convolution requires FBGEMM. FBGEMM does not play"
    " well with UBSAN at the moment, so we skip the test if"
    " we are in a UBSAN environment.",
)
class TestQuantizedConv(unittest.TestCase):
    """Tests the correctness of quantized convolution op."""

    def test_qconv(self):
        qconv = torch.ops.quantized.fbgemm_conv2d
        qconv_prepack = torch.ops.quantized.fbgemm_conv_prepack

        # N
        batch_size = 1
        # C
        input_channels = 16
        # H, W
        height = width = 24
        # K
        output_channels = 8

        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        padding_h = padding_w = 1
        dilation_h = dilation_w = 1
        groups = 1

        W_value_min = 0
        W_value_max = 5
        # We use small values to avoid overflow.
        # (the operator expects them in the format (output_channels, input_channels/groups, kernel_h, kernel_w))
        W_init = torch.randint(
            W_value_min,
            W_value_max,
            (output_channels, int(input_channels / groups), kernel_h, kernel_w),
        )
        b_init = torch.randint(0, 10, (output_channels,))

        # Existing floating point conv operator
        conv_op = torch.nn.Conv2d(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (padding_h, padding_w),
            (dilation_h, dilation_w),
            groups,
        )
        # assign the weights
        conv_op.weight = torch.nn.Parameter(
            W_init.to(dtype=torch.float), requires_grad=False
        )
        conv_op.bias = torch.nn.Parameter(
            b_init.to(dtype=torch.float), requires_grad=False
        )

        X_value_min = 0
        X_value_max = 4
        X_init = torch.randint(
            X_value_min, X_value_max, (batch_size, input_channels, height, width)
        )

        # run on an input tensor
        result_ref = conv_op(X_init.to(dtype=torch.float))

        # reformat X_init and W_init in the required format by conv operator
        # NCHW -> NHWC
        X_NHWC = X_init.permute([0, 2, 3, 1]).contiguous()
        # KCRS -> RSCK
        W_RSCK = W_init.permute([2, 3, 1, 0]).contiguous()

        X_scale = 1.5
        # Currently only 0 as zero point is supported.
        X_zero_point = 0
        X = X_scale * (X_NHWC - X_zero_point).to(dtype=torch.float)

        W_scale = 2.5
        W_zero_point = 0
        W = W_scale * (W_RSCK - W_zero_point).to(dtype=torch.float)

        X_q = X.quantize_linear(scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8)
        W_q = W.quantize_linear(scale=W_scale, zero_point=W_zero_point, dtype=torch.quint8)
        b_q = b_init.to(dtype=torch.int32)

        W_prepack = qconv_prepack(W_q, groups)

        Y_scale = 7.3
        Y_zero_point = 5
        Y_q = qconv(
            X_q,
            W_prepack,
            b_q,
            [1, 1],  # stride
            [1, 1],  # padding
            [1, 1],  # dilation
            [0, 0],  # output_padding
            1,  # groups
            Y_scale,
            Y_zero_point,
        )

        result_NHWK = result_ref.permute([0, 2, 3, 1])
        result_q = _requantize(
            result_NHWK.numpy(), X_scale * W_scale / Y_scale, Y_zero_point
        )

        # Make sure the results match
        np.testing.assert_equal(result_q, Y_q.int_repr().numpy())


if __name__ == "__main__":
    run_tests()