Add new ops wrapped_linear_prepack and wrapped_quantized_linear_prepacked (#134232)

Summary:
This diff adds two new operators, torch.ops._quantized.wrapped_linear_prepack and torch.ops._quantized.wrapped_quantized_linear_prepacked. Together they form a decomposition of the op torch.ops._quantized.wrapped_quantized_linear added in the previous diff.

We decompose the op this way because the packed weight can be computed ahead of time, so we don't need to recompute it in every forward pass in AOTI.
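For context, a minimal usage sketch in Python (shapes and quantization parameters are illustrative; it mirrors the new test added in this PR): the weight is quantized and prepacked once, and only the prepacked op runs on each forward.

import torch

x = torch.randn(4, 16)                      # fp32 activation
w = torch.randn(8, 16)                      # fp32 weight, out_channel = 8
b = torch.randn(8)
x_scale, x_zp = torch.tensor(0.1), torch.tensor(0)
w_scale, w_zp = torch.tensor(0.1), torch.tensor(0)
y_scale, y_zp = torch.tensor(0.1), torch.tensor(0)

# Done once, ahead of time (e.g. at AOTI compile time): quantize and prepack the weight.
packed = torch.ops._quantized.wrapped_linear_prepack(w, w_scale, w_zp, b)

# Done on every forward: consumes the prepacked weight directly.
y = torch.ops._quantized.wrapped_quantized_linear_prepacked(
    x, x_scale, x_zp, packed, y_scale, y_zp, 8)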

Reviewed By: jerryzh168

Differential Revision: D61395887

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134232
Approved by: https://github.com/houseroad
Author: Huamin Li
Committed: 2024-08-23 04:54:24 +00:00 by PyTorch MergeBot
Parent: b23779ef0a
Commit: 311af3b988
6 changed files with 189 additions and 0 deletions


@@ -3398,6 +3398,10 @@
- func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor
- func: wrapped_linear_prepack(Tensor weight, Tensor weight_scale, Tensor weight_zero_point, Tensor bias) -> Tensor
- func: wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor
- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
- func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor


@@ -1,5 +1,6 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/Context.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
@@ -435,6 +436,129 @@ at::Tensor wrapped_quantized_linear_meta(
#endif // USE_FBGEMM
}
at::Tensor wrapped_linear_prepack(const at::Tensor& weight,
const at::Tensor& weight_scale,
const at::Tensor& weight_zero_point,
const at::Tensor& bias);
at::Tensor wrapped_linear_prepack(const at::Tensor& weight,
const at::Tensor& weight_scale,
const at::Tensor& weight_zero_point,
const at::Tensor& bias) {
// This op does two things
// 1. Use quantize_per_tensor to quantize the weight
// 2. Use quantized::linear_prepack to prepack the weight and bias
// The reason we do this is because we want to have such wrapper op to
// save the quantized weight as constants for AOTI
#ifdef USE_FBGEMM
TORCH_CHECK(
weight.dim() == 2,
"fbgemm weight packing only packs matrices not vectors.");
auto qw = at::quantize_per_tensor(
weight, weight_scale, weight_zero_point, c10::ScalarType::QInt8);
auto op = Dispatcher::singleton()
.findSchemaOrThrow("quantized::linear_prepack", "")
.typed<c10::intrusive_ptr<LinearPackedParamsBase>(
at::Tensor, std::optional<at::Tensor>)>();
auto packed_params = op.call(qw, bias);
auto unique_ptr_wrapper =
std::make_unique<decltype(packed_params)>(std::move(packed_params));
auto ret = cpp_custom_type_hack::create(
std::move(unique_ptr_wrapper), weight.options());
return ret;
#else // USE_FBGEMM
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}
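In eager-mode Python terms, the CPU kernel above is roughly the following sketch (illustrative shapes; assumes an FBGEMM build, and the real op additionally wraps the packed params via cpp_custom_type_hack so they can travel as a plain Tensor):

import torch

w = torch.randn(8, 16)                                # weight must be 2-D
b = torch.randn(8)
# Quantize the weight per tensor, then prepack it together with the bias.
qw = torch.quantize_per_tensor(w, torch.tensor(0.1), torch.tensor(0), torch.qint8)
packed = torch.ops.quantized.linear_prepack(qw, b)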
at::Tensor wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale,
const at::Tensor& input_zero_point,
const at::Tensor& packed_weight,
const at::Tensor& output_scale,
const at::Tensor& output_zero_point,
[[maybe_unused]] const int64_t out_channel);
at::Tensor wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale,
const at::Tensor& input_zero_point,
const at::Tensor& packed_weight,
const at::Tensor& output_scale,
const at::Tensor& output_zero_point,
[[maybe_unused]] const int64_t out_channel) {
// This op is similar to wrapped_quantized_linear, but it takes the prepacked weight
#ifdef USE_FBGEMM
auto qx = at::quantize_per_tensor(
input, input_scale, input_zero_point, c10::ScalarType::QUInt8);
const auto scale_val = output_scale.item().toFloat();
const auto zero_point_val = output_zero_point.item().toLong();
auto packed_weight_ptr =
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
cpp_custom_type_hack::cast<c10::intrusive_ptr<LinearPackedParamsBase>>(
packed_weight);
auto result = callOpByName(
"quantized::linear", "", qx, packed_weight_ptr, scale_val, zero_point_val);
return at::dequantize(result[0].toTensor());
#else // USE_FBGEMM
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}
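Likewise, the prepacked kernel above is roughly a quantize / quantized::linear / dequantize sequence; continuing the sketch from wrapped_linear_prepack with an illustrative input:

x = torch.randn(4, 16)
# Quantize the activation, run the quantized linear against the prepacked weight,
# then dequantize back to fp32.
qx = torch.quantize_per_tensor(x, torch.tensor(0.1), torch.tensor(0), torch.quint8)
qy = torch.ops.quantized.linear(qx, packed, 0.1, 0)   # output scale, output zero point
y = qy.dequantize()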
at::Tensor wrapped_linear_prepack_meta(const at::Tensor& weight,
[[maybe_unused]] const at::Tensor& weight_scale,
[[maybe_unused]] const at::Tensor& weight_zero_point,
[[maybe_unused]] const at::Tensor& bias);
at::Tensor wrapped_linear_prepack_meta(const at::Tensor& weight,
[[maybe_unused]] const at::Tensor& weight_scale,
[[maybe_unused]] const at::Tensor& weight_zero_point,
[[maybe_unused]] const at::Tensor& bias) {
#ifdef USE_FBGEMM
TORCH_CHECK(
weight.dim() == 2,
"fbgemm weight packing only packs matrices not vectors.");
const at::SymInt M = weight.sym_size(0);
const at::SymInt N = weight.sym_size(1);
auto Y = at::empty_symint({M, N}, weight.options().dtype(at::kFloat));
return Y;
#else // USE_FBGEMM
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}
at::Tensor wrapped_quantized_linear_prepacked_meta(const at::Tensor& input,
[[maybe_unused]] const at::Tensor& input_scale,
[[maybe_unused]] const at::Tensor& input_zero_point,
[[maybe_unused]] const at::Tensor& packed_weight,
[[maybe_unused]] const at::Tensor& output_scale,
[[maybe_unused]] const at::Tensor& output_zero_point,
const int64_t out_channel);
at::Tensor wrapped_quantized_linear_prepacked_meta(const at::Tensor& input,
[[maybe_unused]] const at::Tensor& input_scale,
[[maybe_unused]] const at::Tensor& input_zero_point,
[[maybe_unused]] const at::Tensor& packed_weight,
[[maybe_unused]] const at::Tensor& output_scale,
[[maybe_unused]] const at::Tensor& output_zero_point,
const int64_t out_channel) {
#ifdef USE_FBGEMM
auto out_sizes = input.sym_sizes().vec();
TORCH_CHECK(
out_sizes.size() == 2,
"The dimension of weight tensor should be equal to 2");
out_sizes[out_sizes.size() - 1] = out_channel;
return at::empty_symint(out_sizes, input.options());
#else // USE_FBGEMM
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}
namespace {
class QLinearPackWeightInt8 final {
@@ -570,10 +694,22 @@ TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16::run));
m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16_legacy"), TORCH_FN(QLinearPackWeightFp16Legacy::run));
m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear));
m.impl(
TORCH_SELECTIVE_NAME("_quantized::wrapped_linear_prepack"),
wrapped_linear_prepack);
m.impl(
TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear_prepacked"),
wrapped_quantized_linear_prepacked);
}
TORCH_LIBRARY_IMPL(_quantized, Meta, m) {
m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear_meta));
m.impl(
TORCH_SELECTIVE_NAME("_quantized::wrapped_linear_prepack"),
wrapped_linear_prepack_meta);
m.impl(
TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear_prepacked"),
wrapped_quantized_linear_prepacked_meta);
}
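The Meta kernels registered above only propagate shapes, which is what AOTI needs at compile time. A rough sketch (not part of this PR) of the shapes they report when the ops are called on meta tensors; this still assumes a build with FBGEMM, since the Meta kernels are guarded by USE_FBGEMM as well:

import torch

w = torch.empty(8, 16, device="meta")
scale = torch.empty((), device="meta")
zp = torch.empty((), device="meta")
b = torch.empty(8, device="meta")

packed = torch.ops._quantized.wrapped_linear_prepack(w, scale, zp, b)
# packed.shape == (8, 16): a float placeholder carrying the weight's shape.

x = torch.empty(4, 16, device="meta")
y = torch.ops._quantized.wrapped_quantized_linear_prepacked(
    x, scale, zp, packed, scale, zp, 8)
# y.shape == (4, 8): batch dim from the input, last dim == out_channel.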
TORCH_LIBRARY_IMPL(onednn, CPU, m) {


@@ -251,6 +251,8 @@ TORCH_LIBRARY(_quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16(Tensor W) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_linear_fp16_weight(Tensor X, Tensor W, Tensor B, int out_channel) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_quantized_linear(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_linear_prepack(Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_quantized_linear_prepacked(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W_prepack, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y"));
}
TORCH_LIBRARY(onednn, m) {


@@ -228,6 +228,8 @@ xfail_not_implemented = {
"aten::var_mean.correction_names",
"aten::var_mean.names_dim",
"aten::where",
"aten::wrapped_linear_prepack",
"aten::wrapped_quantized_linear_prepacked",
}


@@ -4223,6 +4223,47 @@ class TestQuantizedLinear(TestCase):
ret_ref = qlinear.dequantize()
self.assertEqual(ret, ret_ref)
"""Tests the correctness of the _quantized::wrapped_linear_prepack and
_quantized::wrapped_quantized_linear_prepacked ops."""
@skipIfNoFBGEMM
@given(
m=st.integers(2, 6),
k=st.integers(2, 6),
n=st.integers(2, 6),
)
def test_wrapped_quantized_linear_prepacked(self, m, n, k):
input = torch.randn(m, k, dtype=torch.float32)
input_scale = torch.tensor(0.1)
input_zero_point = torch.tensor(0)
weight = torch.randn(n, k, dtype=torch.float32)
weight_scale = torch.tensor(0.1)
weight_zero_point = torch.tensor(0)
bias = torch.randn(n, dtype=torch.float32)
output_scale = torch.tensor(0.1)
output_zero_point = torch.tensor(0)
out_channel = n
ret_1 = torch.ops._quantized.wrapped_linear_prepack(
weight,
weight_scale,
weight_zero_point,
bias
)
ret_2 = torch.ops._quantized.wrapped_quantized_linear_prepacked(
input,
input_scale,
input_zero_point,
ret_1,
output_scale,
output_zero_point,
out_channel
)
qinput = torch.quantize_per_tensor(input, input_scale, input_zero_point, torch.quint8)
qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, torch.qint8)
qlinear_prepack = torch.ops.quantized.linear_prepack(qweight, bias)
qlinear = torch.ops.quantized.linear(qinput, qlinear_prepack, output_scale, output_zero_point)
ret_ref = qlinear.dequantize()
self.assertEqual(ret_2, ret_ref)
"""Tests the correctness of the quantized::linear_unpack after freeing original tensor op."""
@skipIfNoQNNPACK


@@ -1252,6 +1252,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.vsplit: lambda input, indices_or_sections: -1,
torch.vstack: lambda tensors, out=None: -1,
torch.where: lambda condition, x=None, y=None: -1,
torch.wrapped_linear_prepack: lambda weight, weight_scale, weight_zero_point, bias : -1,
torch.wrapped_quantized_linear_prepacked: (
lambda input, input_scale, input_zero_point, prepacked, out_scale, out_zero_point, out_channel : -1 # noqa: B950
),
torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch._fw_primal_copy: lambda self, level: -1,
torch._make_dual_copy: lambda primal, tangent, level: -1,