mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add new op wrapped_quantized_linear (#134024)
Summary: This diff adds a new operator, wrapped_quantized_linear (torch.ops._quantized.wrapped_quantized_linear). It takes the following arguments: input (fp32), input_scale, input_zero_point, weight (fp32), weight_scale, weight_zero_point, bias (fp32), output_scale, output_zero_point, and out_channel. It does the following:
1. Use quantize_per_tensor(input, input_scale, input_zero_point) to quantize the input tensor to int8
2. Use quantized::linear_prepack(weight, weight_scale, weight_zero_point, bias) to pack the weight and bias
3. Use quantized::linear to perform the int8 quantized linear computation
4. Dequantize the result back to fp32

This new op is essentially a wrapper around multiple ops. We add it because torch.export cannot handle models that use the old quantize APIs.

Reviewed By: jerryzh168
Differential Revision: D61377266
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134024
Approved by: https://github.com/houseroad
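A minimal Python sketch (not part of this diff) of what the wrapper performs, assuming an FBGEMM-enabled CPU build of PyTorch; the shapes and quantization parameters below are arbitrary illustrations. It calls the new op once and compares the result against the step-by-step sequence described above, mirroring the test added in this PR:

# Illustrative only: arbitrary shapes and qparams; requires an FBGEMM build.
import torch

m, k, n = 4, 8, 16
x = torch.randn(m, k, dtype=torch.float32)
w = torch.randn(n, k, dtype=torch.float32)
b = torch.randn(n, dtype=torch.float32)
x_scale, x_zp = torch.tensor(0.1), torch.tensor(0)
w_scale, w_zp = torch.tensor(0.1), torch.tensor(0)
y_scale, y_zp = torch.tensor(0.1), torch.tensor(0)

# Single call to the new wrapper op.
y = torch.ops._quantized.wrapped_quantized_linear(
    x, x_scale, x_zp, w, w_scale, w_zp, b, y_scale, y_zp, n)

# The eager-mode sequence it wraps:
# quantize -> linear_prepack -> quantized linear -> dequantize.
qx = torch.quantize_per_tensor(x, x_scale, x_zp, torch.quint8)
qw = torch.quantize_per_tensor(w, w_scale, w_zp, torch.qint8)
packed = torch.ops.quantized.linear_prepack(qw, b)
y_ref = torch.ops.quantized.linear(qx, packed, y_scale, y_zp).dequantize()
torch.testing.assert_close(y, y_ref)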
committed by PyTorch MergeBot
parent 022cd7c9aa
commit 3d8db41337
@@ -18,6 +18,9 @@
#else
#include <ATen/ops/_saturate_weight_to_fp16.h>
#include <ATen/ops/_saturate_weight_to_fp16_native.h>
#include <ATen/ops/dequantize.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/quantize_per_tensor.h>
#include <ATen/ops/zeros.h>
#endif

@@ -316,6 +319,122 @@ at::Tensor _saturate_weight_to_fp16(const Tensor& weight) {
  return weight;
}

// Helpers for calling an operator through the dispatcher's boxed API.
template <class... Inputs>
inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
  return {std::forward<Inputs>(inputs)...};
}

template <class... Args>
inline std::vector<c10::IValue> callOpByHandle(
    const c10::OperatorHandle& op,
    Args... args) {
  auto stack = makeStack(std::forward<Args>(args)...);
  c10::Dispatcher::singleton().callBoxed(op, &stack);
  return stack;
}

template <class... Args>
inline std::vector<c10::IValue> callOpByName(
    const char* func_name,
    const char* overload_name,
    Args... args) {
  const std::optional<c10::OperatorHandle> op_handle =
      c10::Dispatcher::singleton().findSchema({func_name, overload_name});
  assert(op_handle.has_value());
  return callOpByHandle(op_handle.value(), std::forward<Args>(args)...);
}

at::Tensor wrapped_quantized_linear(
    at::Tensor input,
    const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    const at::Tensor& weight_scale,
    const at::Tensor& weight_zero_point,
    const at::Tensor& bias,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel);

at::Tensor wrapped_quantized_linear(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor input,
    const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    const at::Tensor& weight_scale,
    const at::Tensor& weight_zero_point,
    const at::Tensor& bias,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel) {
  // This op does four things:
  // 1. Use quantize_per_tensor to quantize the input
  // 2. Use quantized::linear_prepack to prepack the weight and bias
  // 3. Use quantized::linear to do the int8 quantized linear computation
  // 4. Use dequantize to dequantize the result of quantized::linear
  // We provide this wrapper op to bypass the issue torch.export has with
  // the old quantize APIs.
#ifdef USE_FBGEMM
  auto qw = at::quantize_per_tensor(
      weight, weight_scale, weight_zero_point, c10::ScalarType::QInt8);
  auto op = Dispatcher::singleton()
                .findSchemaOrThrow("quantized::linear_prepack", "")
                .typed<c10::intrusive_ptr<LinearPackedParamsBase>(
                    at::Tensor, std::optional<at::Tensor>)>();
  auto packed_params = op.call(qw, bias);

  auto qx = at::quantize_per_tensor(
      input, input_scale, input_zero_point, c10::ScalarType::QUInt8);

  const auto scale_val = output_scale.item().toFloat();
  const auto zero_point_val = output_zero_point.item().toLong();

  auto result = callOpByName(
      "quantized::linear", "", qx, packed_params, scale_val, zero_point_val);

  return at::dequantize(result[0].toTensor());
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_quantized_linear_meta(
    at::Tensor input,
    [[maybe_unused]] const at::Tensor& input_scale,
    [[maybe_unused]] const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    [[maybe_unused]] const at::Tensor& weight_scale,
    [[maybe_unused]] const at::Tensor& weight_zero_point,
    [[maybe_unused]] const at::Tensor& bias,
    [[maybe_unused]] const at::Tensor& output_scale,
    [[maybe_unused]] const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel);

at::Tensor wrapped_quantized_linear_meta(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor input,
    [[maybe_unused]] const at::Tensor& input_scale,
    [[maybe_unused]] const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    [[maybe_unused]] const at::Tensor& weight_scale,
    [[maybe_unused]] const at::Tensor& weight_zero_point,
    [[maybe_unused]] const at::Tensor& bias,
    [[maybe_unused]] const at::Tensor& output_scale,
    [[maybe_unused]] const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel) {
#ifdef USE_FBGEMM
  // The meta kernel only computes the fp32 output shape: [M, N].
  const at::SymInt M = input.sym_size(0);
  const at::SymInt N = weight.sym_size(0);
  auto Y = at::empty_symint({M, N}, input.options().dtype(at::kFloat));
  return Y;
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

namespace {

class QLinearPackWeightInt8 final {
@@ -450,6 +569,11 @@ TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
  register_linear_params();
  m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16::run));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16_legacy"), TORCH_FN(QLinearPackWeightFp16Legacy::run));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear));
}

TORCH_LIBRARY_IMPL(_quantized, Meta, m) {
  m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear_meta));
}

TORCH_LIBRARY_IMPL(onednn, CPU, m) {

@@ -250,6 +250,7 @@ TORCH_LIBRARY(_quantized, m) {
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_fp16_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"));
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16(Tensor W) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_linear_fp16_weight(Tensor X, Tensor W, Tensor B, int out_channel) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_quantized_linear(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y"));
}

TORCH_LIBRARY(onednn, m) {

@@ -4184,6 +4184,46 @@ class TestQuantizedLinear(TestCase):
        np.testing.assert_equal(
            W_q.q_zero_point(), W_q_origin.q_zero_point())

    """Tests the correctness of the _quantized::wrapped_quantized_linear op."""
    @skipIfNoFBGEMM
    @given(
        m=st.integers(2, 6),
        k=st.integers(2, 6),
        n=st.integers(2, 6),
    )
    def test_wrapped_quantized_linear(self, m, n, k):
        input = torch.randn(m, k, dtype=torch.float32)
        input_scale = torch.tensor(0.1)
        input_zero_point = torch.tensor(0)
        weight = torch.randn(n, k, dtype=torch.float32)
        weight_scale = torch.tensor(0.1)
        weight_zero_point = torch.tensor(0)
        bias = torch.randn(n, dtype=torch.float32)
        output_scale = torch.tensor(0.1)
        output_zero_point = torch.tensor(0)
        out_channel = n

        ret = torch.ops._quantized.wrapped_quantized_linear(
            input,
            input_scale,
            input_zero_point,
            weight,
            weight_scale,
            weight_zero_point,
            bias,
            output_scale,
            output_zero_point,
            out_channel,
        )

        qinput = torch.quantize_per_tensor(input, input_scale, input_zero_point, torch.quint8)
        qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, torch.qint8)
        qlinear_prepack = torch.ops.quantized.linear_prepack(qweight, bias)
        qlinear = torch.ops.quantized.linear(qinput, qlinear_prepack, output_scale, output_zero_point)
        ret_ref = qlinear.dequantize()
        self.assertEqual(ret, ret_ref)


    """Tests the correctness of the quantized::linear_unpack after freeing original tensor op."""
    @skipIfNoQNNPACK
    @given(W=hu.tensor(shapes=hu.array_shapes(2, 2,),