Add new op wrapped_quantized_linear (#134024)

Summary:
This diff adds a new operator, wrapped_quantized_linear (torch.ops._quantized.wrapped_quantized_linear), which takes the following input arguments: input (fp32), input_scale, input_zero_point, weight (fp32), weight_scale, weight_zero_point, bias (fp32), output_scale, output_zero_point, and out_channel. It does the following (sketched in code below):

1. Uses quantize_per_tensor(input, input_scale, input_zero_point) to quantize the input tensor to quint8
2. Uses quantize_per_tensor(weight, weight_scale, weight_zero_point) to quantize the weight to qint8, then quantized::linear_prepack to pack the quantized weight and bias
3. Uses quantized::linear to perform the int8 quantized linear computation
4. Dequantizes the result back to fp32
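
A minimal Python sketch of this decomposition, mirroring the unit test added in this PR; the shapes, scales, and zero points are illustrative, and a build with FBGEMM is assumed:

import torch

x = torch.randn(4, 8)    # fp32 input
w = torch.randn(16, 8)   # fp32 weight, out_channel = 16
b = torch.randn(16)      # fp32 bias

qx = torch.quantize_per_tensor(x, 0.1, 0, torch.quint8)  # 1. quantize the input
qw = torch.quantize_per_tensor(w, 0.1, 0, torch.qint8)   #    quantize the weight
packed = torch.ops.quantized.linear_prepack(qw, b)       # 2. pack weight and bias
qy = torch.ops.quantized.linear(qx, packed, 0.1, 0)      # 3. int8 quantized linear
y = qy.dequantize()                                      # 4. dequantize back to fp32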

This new op is essentially a wrapper around multiple existing ops. We add it because torch.export cannot handle models that still use the old quantize APIs directly.
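
For reference, a sketch of calling the new wrapper op directly; values mirror the sketch above, and the scales and zero points are passed as tensors to match the schema added in this diff:

import torch

x = torch.randn(4, 8)
w = torch.randn(16, 8)
b = torch.randn(16)

y = torch.ops._quantized.wrapped_quantized_linear(
    x, torch.tensor(0.1), torch.tensor(0),   # input, input_scale, input_zero_point
    w, torch.tensor(0.1), torch.tensor(0),   # weight, weight_scale, weight_zero_point
    b, torch.tensor(0.1), torch.tensor(0),   # bias, output_scale, output_zero_point
    16,                                      # out_channel
)
# y is an fp32 tensor of shape (4, 16), matching the decomposed path above.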

Reviewed By: jerryzh168

Differential Revision: D61377266

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134024
Approved by: https://github.com/houseroad
Huamin Li
2024-08-21 09:26:58 +00:00
committed by PyTorch MergeBot
parent 022cd7c9aa
commit 3d8db41337
3 changed files with 165 additions and 0 deletions

View File

@@ -18,6 +18,9 @@
#else
#include <ATen/ops/_saturate_weight_to_fp16.h>
#include <ATen/ops/_saturate_weight_to_fp16_native.h>
#include <ATen/ops/dequantize.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/quantize_per_tensor.h>
#include <ATen/ops/zeros.h>
#endif
@@ -316,6 +319,122 @@ at::Tensor _saturate_weight_to_fp16(const Tensor& weight) {
  return weight;
}

// Small helpers for calling a boxed operator through the dispatcher by name.
template <class... Inputs>
inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
  return {std::forward<Inputs>(inputs)...};
}

template <class... Args>
inline std::vector<c10::IValue> callOpByHandle(
    const c10::OperatorHandle& op,
    Args... args) {
  auto stack = makeStack(std::forward<Args>(args)...);
  c10::Dispatcher::singleton().callBoxed(op, &stack);
  return stack;
}

template <class... Args>
inline std::vector<c10::IValue> callOpByName(
    const char* func_name,
    const char* overload_name,
    Args... args) {
  const std::optional<c10::OperatorHandle> op_handle =
      c10::Dispatcher::singleton().findSchema({func_name, overload_name});
  assert(op_handle.has_value());
  return callOpByHandle(op_handle.value(), std::forward<Args>(args)...);
}

at::Tensor wrapped_quantized_linear(
    at::Tensor input,
    const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    const at::Tensor& weight_scale,
    const at::Tensor& weight_zero_point,
    const at::Tensor& bias,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel);

at::Tensor wrapped_quantized_linear(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor input,
    const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    const at::Tensor& weight_scale,
    const at::Tensor& weight_zero_point,
    const at::Tensor& bias,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel) {
  // This op does four things:
  // 1. Use quantize_per_tensor to quantize the input
  // 2. Use quantized::linear_prepack to prepack the weight and bias
  // 3. Use quantized::linear to do the int8 linear quantized computation
  // 4. Use dequantize to dequantize the result of quantized::linear
  // We provide this single wrapper op so that torch.export does not have to
  // deal with the old quantize APIs directly.
#ifdef USE_FBGEMM
  auto qw = at::quantize_per_tensor(
      weight, weight_scale, weight_zero_point, c10::ScalarType::QInt8);
  auto op = Dispatcher::singleton()
                .findSchemaOrThrow("quantized::linear_prepack", "")
                .typed<c10::intrusive_ptr<LinearPackedParamsBase>(
                    at::Tensor, std::optional<at::Tensor>)>();
  auto packed_params = op.call(qw, bias);
  auto qx = at::quantize_per_tensor(
      input, input_scale, input_zero_point, c10::ScalarType::QUInt8);
  const auto scale_val = output_scale.item().toFloat();
  const auto zero_point_val = output_zero_point.item().toLong();
  auto result = callOpByName(
      "quantized::linear", "", qx, packed_params, scale_val, zero_point_val);
  return at::dequantize(result[0].toTensor());
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_quantized_linear_meta(
    at::Tensor input,
    [[maybe_unused]] const at::Tensor& input_scale,
    [[maybe_unused]] const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    [[maybe_unused]] const at::Tensor& weight_scale,
    [[maybe_unused]] const at::Tensor& weight_zero_point,
    [[maybe_unused]] const at::Tensor& bias,
    [[maybe_unused]] const at::Tensor& output_scale,
    [[maybe_unused]] const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel);

at::Tensor wrapped_quantized_linear_meta(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor input,
    [[maybe_unused]] const at::Tensor& input_scale,
    [[maybe_unused]] const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    [[maybe_unused]] const at::Tensor& weight_scale,
    [[maybe_unused]] const at::Tensor& weight_zero_point,
    [[maybe_unused]] const at::Tensor& bias,
    [[maybe_unused]] const at::Tensor& output_scale,
    [[maybe_unused]] const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel) {
#ifdef USE_FBGEMM
  // Shape inference only: the output is an (M, N) fp32 tensor, where M is the
  // leading dimension of the input and N is the number of output channels
  // (rows) of the weight.
  const at::SymInt M = input.sym_size(0);
  const at::SymInt N = weight.sym_size(0);
  auto Y = at::empty_symint({M, N}, input.options().dtype(at::kFloat));
  return Y;
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

namespace {
class QLinearPackWeightInt8 final {
@@ -450,6 +569,11 @@ TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
  register_linear_params();
  m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16::run));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16_legacy"), TORCH_FN(QLinearPackWeightFp16Legacy::run));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear));
}

TORCH_LIBRARY_IMPL(_quantized, Meta, m) {
  m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear_meta));
}

TORCH_LIBRARY_IMPL(onednn, CPU, m) {

View File

@@ -250,6 +250,7 @@ TORCH_LIBRARY(_quantized, m) {
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_fp16_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"));
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16(Tensor W) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_linear_fp16_weight(Tensor X, Tensor W, Tensor B, int out_channel) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_quantized_linear(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y"));
}

TORCH_LIBRARY(onednn, m) {

View File

@@ -4184,6 +4184,46 @@ class TestQuantizedLinear(TestCase):
        np.testing.assert_equal(
            W_q.q_zero_point(), W_q_origin.q_zero_point())

    """Tests the correctness of the _quantized::wrapped_quantized_linear op."""
    @skipIfNoFBGEMM
    @given(
        m=st.integers(2, 6),
        k=st.integers(2, 6),
        n=st.integers(2, 6),
    )
    def test_wrapped_quantized_linear(self, m, n, k):
        input = torch.randn(m, k, dtype=torch.float32)
        input_scale = torch.tensor(0.1)
        input_zero_point = torch.tensor(0)
        weight = torch.randn(n, k, dtype=torch.float32)
        weight_scale = torch.tensor(0.1)
        weight_zero_point = torch.tensor(0)
        bias = torch.randn(n, dtype=torch.float32)
        output_scale = torch.tensor(0.1)
        output_zero_point = torch.tensor(0)
        out_channel = n
        ret = torch.ops._quantized.wrapped_quantized_linear(
            input,
            input_scale,
            input_zero_point,
            weight,
            weight_scale,
            weight_zero_point,
            bias,
            output_scale,
            output_zero_point,
            out_channel,
        )
        qinput = torch.quantize_per_tensor(input, input_scale, input_zero_point, torch.quint8)
        qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, torch.qint8)
        qlinear_prepack = torch.ops.quantized.linear_prepack(qweight, bias)
        qlinear = torch.ops.quantized.linear(qinput, qlinear_prepack, output_scale, output_zero_point)
        ret_ref = qlinear.dequantize()
        self.assertEqual(ret, ret_ref)

    """Tests the correctness of the quantized::linear_unpack after freeing original tensor op."""
    @skipIfNoQNNPACK
    @given(W=hu.tensor(shapes=hu.array_shapes(2, 2,),