diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
index 05692969253d..498ce8aca4c8 100644
--- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
+++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
@@ -20,6 +20,14 @@ struct FBGEMM_API PackedFCWeight {
   int w_zp;
 };
 
+struct FBGEMM_API PackedConvWeight {
+  std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w;
+  std::vector<int32_t> col_offsets;
+  std::vector<int64_t> kernel;
+  float w_scale;
+  int32_t w_zp;
+};
+
 // Convert the weight from uint8 to int8.
 static void convert_uint8_int8(
     int K,
diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp
new file mode 100644
index 000000000000..7675a1cd02e6
--- /dev/null
+++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp
@@ -0,0 +1,156 @@
+#include <ATen/ATen.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/cpp_custom_type_hack.h>
+#include <ATen/native/quantized/cpu/fbgemm_utils.h>
+#include <ATen/quantized/Quantizer.h>
+#include <vector>
+
+namespace at {
+namespace native {
+namespace {
+class QConv2dInt8 final : public c10::OperatorKernel {
+ public:
+#ifdef USE_FBGEMM
+  Tensor operator()(
+      Tensor act,
+      Tensor packed_weight,
+      Tensor bias,
+      const std::vector<int64_t>& stride,
+      const std::vector<int64_t>& padding,
+      const std::vector<int64_t>& dilation,
+      const std::vector<int64_t>& output_padding,
+      int64_t groups,
+      double output_scale,
+      int64_t output_zero_point) {
+    TORCH_CHECK(
+        fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
+    TORCH_CHECK(
+        act.ndimension() == 4,
+        "Activations are supposed to have 4 dimensions.");
+    TORCH_CHECK(stride.size() == 2, "2D convolution only");
+    TORCH_CHECK(padding.size() == 2, "2D convolution only");
+    TORCH_CHECK(dilation.size() == 2, "2D convolution only");
+    TORCH_CHECK(output_padding.size() == 2, "2D convolution only");
+    TORCH_CHECK(
+        (dilation[0] == 1 && dilation[1] == 1),
+        "Currently dilation should be 1");
+    TORCH_CHECK(
+        (output_padding[0] == 0 && output_padding[1] == 0),
+        "Currently output padding should be 0");
+
+    // inputs are in NHWC format
+    int N = act.size(0);
+    int H = act.size(1);
+    int W = act.size(2);
+    int C = act.size(3);
+    int K = bias.size(0);
+
+    Tensor act_contig = act.contiguous();
+    const uint8_t* act_ptr =
+        reinterpret_cast<uint8_t*>(act_contig.data());
+
+    PackedConvWeight& pack_ptr =
+        cpp_custom_type_hack::cast<PackedConvWeight>(packed_weight);
+    auto packB = pack_ptr.w.get();
+    // packB->printPackedMatrix("PackedB inside QConv2dInt8:");
+    auto& col_offsets = pack_ptr.col_offsets;
+    auto& kernel = pack_ptr.kernel;
+
+    std::vector<int32_t> row_offset_buf(
+        fbgemm::PackAWithIm2Col<uint8_t>::rowOffsetBufferSize());
+
+    int pad_l = padding[0];
+    int pad_t = padding[1];
+    int stride_h = stride[0];
+    int stride_w = stride[1];
+    int kernel_h = kernel[0];
+    int kernel_w = kernel[1];
+
+    fbgemm::conv_param_t<> conv_p(
+        N, // Batch size
+        C, // Number of input channels
+        K, // Number of output channels
+        {H, W},
+        groups,
+        {kernel_h, kernel_w},
+        {stride_h, stride_w},
+        {pad_l, pad_t, pad_l, pad_t});
+
+    fbgemm::PackAWithIm2Col<uint8_t> packA(
+        conv_p,
+        act_ptr,
+        nullptr,
+        act.q_zero_point().toInt(),
+        row_offset_buf.data());
+
+    fbgemm::DoNothing<> NoOpObj{};
+
+    auto bias_contig = bias.contiguous();
+
+    float act_scale = act.q_scale().toFloat();
+    int32_t act_zero_point = act.q_zero_point().toInt();
+
+    float weight_scale_float = pack_ptr.w_scale;
+    int32_t weight_zero_point_int32 = pack_ptr.w_zp;
+
+    float output_multiplier_float =
+        (act_scale * weight_scale_float) / static_cast<float>(output_scale);
+
+    fbgemm::ReQuantizeOutput<false /* FUSE_RELU */> outputProcObj(
+        NoOpObj,
+        &output_multiplier_float,
+        output_zero_point,
+        act_zero_point,
+        &weight_zero_point_int32,
+        packA.getRowOffsetBuffer(),
+        col_offsets.data(),
+        bias_contig.data<int32_t>(),
+        K,
+        groups);
+
+    Tensor output = _empty_affine_quantized(
+        {N, H, W, K},
+        device(kCPU).dtype(kQUInt8),
+        output_scale,
+        output_zero_point);
+    auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));
+
+    // Do the GEMM
+    fbgemm::fbgemmPacked(
+        packA,
+        *packB,
+        reinterpret_cast<uint8_t*>(output.data()),
+        buffer.data<int32_t>(),
+        K,
+        outputProcObj,
+        0 /* thread_id */,
+        1 /* num_threads */);
+
+    return output;
+  }
+#else // USE_FBGEMM
+  Tensor operator()(
+      Tensor /* activation */,
+      Tensor /* packed_weight */,
+      Tensor /* bias */,
+      const std::vector<int64_t>& /* stride */,
+      const std::vector<int64_t>& /* padding */,
+      const std::vector<int64_t>& /* dilation */,
+      const std::vector<int64_t>& /* output_padding */,
+      int64_t /* groups */,
+      double /* output_scale */,
+      int64_t /* output_zero_point */) {
+    TORCH_CHECK(
+        false, "This PyTorch installation was not built with FBGEMM operators");
+  }
+#endif // USE_FBGEMM
+};
+
+static auto registry = c10::RegisterOperators().op(
+    "quantized::fbgemm_conv2d",
+    c10::RegisterOperators::options().kernel<QConv2dInt8>().dispatchKey(
+        QuantizedCPUTensorId()));
+
+} // namespace
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp
new file mode 100644
index 000000000000..f009a6aad81b
--- /dev/null
+++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp
@@ -0,0 +1,76 @@
+#include <ATen/ATen.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/cpp_custom_type_hack.h>
+#include <ATen/native/quantized/cpu/fbgemm_utils.h>
+#include <ATen/quantized/Quantizer.h>
+#include <vector>
+
+namespace caffe2 {
+#ifdef USE_FBGEMM
+// Required for cpp_custom_type_hack to work
+CAFFE_KNOWN_TYPE(PackedConvWeight);
+#endif
+} // namespace caffe2
+
+namespace at {
+namespace native {
+namespace {
+class QConvPackWeightInt8 final : public c10::OperatorKernel {
+ public:
+#ifdef USE_FBGEMM
+  Tensor operator()(Tensor weight, int64_t groups) {
+    TORCH_CHECK(
+        weight.ndimension() == 4, "Weights are expected to have 4 dimensions");
+    TORCH_CHECK(groups == 1, "Groupwise convolutions are not supported yet");
+    // weights are in RS(C/G)K format
+    // matrix dimensions after im2col
+    int NDim = weight.size(3) / groups;
+    int KDim = weight.size(0) * weight.size(1) * groups * weight.size(2);
+    auto weight_contig = weight.contiguous();
+    int weight_zero_point_int32 = weight.q_zero_point().toInt();
+    TORCH_CHECK(
+        weight_zero_point_int32 == 0,
+        "Only symmetric quantization is supported for weights yet");
+    const int8_t* weight_ptr_int8 =
+        reinterpret_cast<const int8_t*>(weight_contig.data());
+
+    std::vector<int32_t> col_offsets(NDim * groups);
+    std::vector<int64_t> kernel{static_cast<int64_t>(weight.size(0)),
+                                static_cast<int64_t>(weight.size(1))};
+    std::vector<int8_t> weight_int8(KDim * NDim * groups);
+    auto ret_ptr = guts::make_unique<PackedConvWeight>(
+        PackedConvWeight{guts::make_unique<fbgemm::PackBMatrix<int8_t>>(
+                             fbgemm::matrix_op_t::NoTranspose,
+                             KDim,
+                             NDim,
+                             weight_ptr_int8,
+                             NDim,
+                             nullptr, // PackBMatrix manages ownership of pmat
+                             groups),
+                         col_offsets,
+                         kernel,
+                         weight.q_scale().toFloat(),
+                         weight_zero_point_int32});
+    // TODO: we will need to replace this with torchscript classes at a later
+    // point.
+    return cpp_custom_type_hack::create(std::move(ret_ptr), weight.options());
+  }
+#else // USE_FBGEMM
+  Tensor operator()(
+      Tensor /* weight */,
+      int64_t /* groups */) {
+    TORCH_CHECK(
+        false, "This PyTorch installation was not built with FBGEMM operators");
+  }
+#endif // USE_FBGEMM
+};
+
+static auto registry = c10::RegisterOperators().op(
+    "quantized::fbgemm_conv_prepack",
+    c10::RegisterOperators::options().kernel<QConvPackWeightInt8>().dispatchKey(
+        QuantizedCPUTensorId()));
+
+} // namespace
+} // namespace native
+} // namespace at
diff --git a/test/test_quantized.py b/test/test_quantized.py
index 1d2752180998..8553265d372e 100644
--- a/test/test_quantized.py
+++ b/test/test_quantized.py
@@ -25,6 +25,14 @@ def _dequantize(qx, scale, zero_point):
     return x
 
 
+def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
+    """Requantizes a numpy array, i.e., intermediate int32 or int16 values are
+    converted back to the given type."""
+    qx = (x * multiplier).round() + zero_point
+    qx = np.clip(qx, qmin, qmax).astype(qtype)
+    return qx
+
+
 # Make sure we won't have overflows from vpmaddubsw instruction used in FBGEMM.
 # On the current Intel x86 architecture, we need to utilize vpmaddubsw instruction
 # for the 8-bit int multiplication. This instruction vertically multiplies each
@@ -369,5 +377,119 @@ class TestQuantizedLinear(unittest.TestCase):
         np.testing.assert_equal(Y_q_ref2.int_repr().numpy(), Y_q.int_repr().numpy())
 
 
+@unittest.skipIf(
+    TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
+    "Quantized convolution requires FBGEMM. FBGEMM does not play"
+    " well with UBSAN at the moment, so we skip the test if"
+    " we are in a UBSAN environment.",
+)
+class TestQuantizedConv(unittest.TestCase):
+    """Tests the correctness of the quantized convolution op."""
+
+    def test_qconv(self):
+        qconv = torch.ops.quantized.fbgemm_conv2d
+        qconv_prepack = torch.ops.quantized.fbgemm_conv_prepack
+
+        # N
+        batch_size = 1
+        # C
+        input_channels = 16
+        # H, W
+        height = width = 24
+        # K
+        output_channels = 8
+
+        kernel_h = kernel_w = 3
+        stride_h = stride_w = 1
+        padding_h = padding_w = 1
+        dilation_h = dilation_w = 1
+        groups = 1
+
+        # We use small values to avoid overflow.
+        W_value_min = 0
+        W_value_max = 5
+        # The operator expects the weights in the format
+        # (output_channels, input_channels/groups, kernel_h, kernel_w).
+        W_init = torch.randint(
+            W_value_min,
+            W_value_max,
+            (output_channels, int(input_channels / groups), kernel_h, kernel_w),
+        )
+
+        b_init = torch.randint(0, 10, (output_channels,))
+
+        # Existing floating point conv operator
+        conv_op = torch.nn.Conv2d(
+            input_channels,
+            output_channels,
+            (kernel_h, kernel_w),
+            (stride_h, stride_w),
+            (padding_h, padding_w),
+            (dilation_h, dilation_w),
+            groups,
+        )
+
+        # assign the weights
+        conv_op.weight = torch.nn.Parameter(
+            W_init.to(dtype=torch.float), requires_grad=False
+        )
+        conv_op.bias = torch.nn.Parameter(
+            b_init.to(dtype=torch.float), requires_grad=False
+        )
+
+        X_value_min = 0
+        X_value_max = 4
+        X_init = torch.randint(
+            X_value_min, X_value_max, (batch_size, input_channels, height, width)
+        )
+
+        # run on an input tensor
+        result_ref = conv_op(X_init.to(dtype=torch.float))
+
+        # reformat X_init and W_init into the layouts the quantized conv operator expects
+        # NCHW -> NHWC
+        X_NHWC = X_init.permute([0, 2, 3, 1]).contiguous()
+        # KCRS -> RSCK
+        W_RSCK = W_init.permute([2, 3, 1, 0]).contiguous()
+
+        X_scale = 1.5
+        # Currently only 0 is supported as the zero point.
+        X_zero_point = 0
+        X = X_scale * (X_NHWC - X_zero_point).to(dtype=torch.float)
+
+        W_scale = 2.5
+        W_zero_point = 0
+        W = W_scale * (W_RSCK - W_zero_point).to(dtype=torch.float)
+
+        X_q = X.quantize_linear(scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8)
+        W_q = W.quantize_linear(scale=W_scale, zero_point=W_zero_point, dtype=torch.quint8)
+        b_q = b_init.to(dtype=torch.int32)
+
+        W_prepack = qconv_prepack(W_q, groups)
+        Y_scale = 7.3
+        Y_zero_point = 5
+
+        Y_q = qconv(
+            X_q,
+            W_prepack,
+            b_q,
+            [1, 1],  # stride
+            [1, 1],  # padding
+            [1, 1],  # dilation
+            [0, 0],  # output_padding
+            1,  # groups
+            Y_scale,
+            Y_zero_point,
+        )
+
+        result_NHWK = result_ref.permute([0, 2, 3, 1])
+        result_q = _requantize(
+            result_NHWK.numpy(), X_scale * W_scale / Y_scale, Y_zero_point
+        )
+
+        # Make sure the results match
+        np.testing.assert_equal(result_q, Y_q.int_repr().numpy())
+
+
 if __name__ == "__main__":
     run_tests()
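
For readers skimming the diff, here is a minimal end-to-end usage sketch of the two operators this PR registers, distilled from test_qconv above. It assumes the 2019-era Python API used in the test (Tensor.quantize_linear, torch.ops.quantized.fbgemm_conv_prepack / fbgemm_conv2d) and the layouts the kernel expects (NHWC activations, RSCK weights, int32 bias); the shapes, scales, and zero points are illustrative only, not values mandated by the operators.

# Sketch only: mirrors the call pattern in test_qconv; values are arbitrary.
import torch

qconv_prepack = torch.ops.quantized.fbgemm_conv_prepack
qconv = torch.ops.quantized.fbgemm_conv2d

X = torch.rand(1, 24, 24, 16)               # activations: N, H, W, C
W = torch.rand(3, 3, 16, 8)                 # weights: R, S, C/groups, K
b = torch.zeros(8, dtype=torch.int32)       # int32 bias, one per output channel

X_q = X.quantize_linear(scale=0.01, zero_point=0, dtype=torch.quint8)
# Weight zero_point must be 0: only symmetric weight quantization is supported.
W_q = W.quantize_linear(scale=0.01, zero_point=0, dtype=torch.quint8)

W_packed = qconv_prepack(W_q, 1)            # pack once, reuse across calls
Y_q = qconv(
    X_q, W_packed, b,
    [1, 1],   # stride
    [1, 1],   # padding
    [1, 1],   # dilation
    [0, 0],   # output_padding
    1,        # groups
    0.05,     # output scale
    0,        # output zero_point
)
print(Y_q.int_repr().shape)                 # torch.Size([1, 24, 24, 8]), NHWC output

Splitting prepacking into its own op lets the weight be packed once and reused across forward calls, mirroring the existing quantized linear path (PackedFCWeight) extended in fbgemm_utils.h above.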
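The correctness check at the end of test_qconv relies on the requantization arithmetic implemented by the _requantize helper and by output_multiplier_float in qconv.cpp. A small worked example with the test's scales follows; the accumulator values are made up for illustration.

import numpy as np

# Scales and zero point taken from test_qconv; acc values are hypothetical int32 accumulators.
X_scale, W_scale, Y_scale, Y_zero_point = 1.5, 2.5, 7.3, 5
multiplier = X_scale * W_scale / Y_scale            # == output_multiplier_float in qconv.cpp

acc = np.array([0, 17, 123, 800], dtype=np.int32)
q = np.clip(np.round(acc * multiplier) + Y_zero_point, 0, 255).astype(np.uint8)
# multiplier ~= 0.5137, so acc=123 -> round(63.19) + 5 = 68, and acc=800 saturates to 255.
print(q)  # [  5  14  68 255]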