Add new ops wrapped_linear_prepack and wrapped_quantized_linear_prepacked (#134232)
Summary: This diff adds two new operators, torch.ops._quantized.wrapped_linear_prepack and torch.ops._quantized.wrapped_quantized_linear_prepacked. Together they form a decomposition of the op torch.ops._quantized.wrapped_quantized_linear added in the previous diff. We decompose it this way because the packed weight can be computed ahead of time, so it does not have to be recomputed on every forward pass in AOTI.

Reviewed By: jerryzh168

Differential Revision: D61395887

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134232

Approved by: https://github.com/houseroad
Committed by: PyTorch MergeBot
Parent: b23779ef0a
Commit: 311af3b988
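For orientation, here is a minimal usage sketch of the two new ops (adapted from the test added in this diff; it is not part of the change itself, the shapes, scales, and zero points are illustrative placeholders, and an FBGEMM-enabled build is assumed). The weight is quantized and prepacked once, and only the prepacked op runs per forward:

    import torch

    m, k, n = 4, 8, 16                      # illustrative sizes
    x = torch.randn(m, k)
    weight = torch.randn(n, k)
    bias = torch.randn(n)
    x_scale, x_zp = torch.tensor(0.1), torch.tensor(0)
    w_scale, w_zp = torch.tensor(0.1), torch.tensor(0)
    y_scale, y_zp = torch.tensor(0.1), torch.tensor(0)

    # Done once (e.g., saved as an AOTI constant): quantize and prepack the weight.
    packed = torch.ops._quantized.wrapped_linear_prepack(weight, w_scale, w_zp, bias)

    # Done on every forward: quantized linear on the already-prepacked weight.
    y = torch.ops._quantized.wrapped_quantized_linear_prepacked(
        x, x_scale, x_zp, packed, y_scale, y_zp, n)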
@@ -3398,6 +3398,10 @@
- func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor

- func: wrapped_linear_prepack(Tensor weight, Tensor weight_scale, Tensor weight_zero_point, Tensor bias) -> Tensor

- func: wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor

- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor

- func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
@@ -1,5 +1,6 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/Context.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
@@ -435,6 +436,129 @@ at::Tensor wrapped_quantized_linear_meta(
#endif // USE_FBGEMM
}

at::Tensor wrapped_linear_prepack(const at::Tensor& weight,
    const at::Tensor& weight_scale,
    const at::Tensor& weight_zero_point,
    const at::Tensor& bias);

at::Tensor wrapped_linear_prepack(const at::Tensor& weight,
    const at::Tensor& weight_scale,
    const at::Tensor& weight_zero_point,
    const at::Tensor& bias) {
  // This op does two things:
  // 1. Use quantize_per_tensor to quantize the weight
  // 2. Use quantized::linear_prepack to prepack the weight and bias
  // The reason we do this is that we want such a wrapper op to save the
  // quantized weight as constants for AOTI
#ifdef USE_FBGEMM
  TORCH_CHECK(
      weight.dim() == 2,
      "fbgemm weight packing only packs matrices not vectors.");
  auto qw = at::quantize_per_tensor(
      weight, weight_scale, weight_zero_point, c10::ScalarType::QInt8);

  auto op = Dispatcher::singleton()
                .findSchemaOrThrow("quantized::linear_prepack", "")
                .typed<c10::intrusive_ptr<LinearPackedParamsBase>(
                    at::Tensor, std::optional<at::Tensor>)>();
  auto packed_params = op.call(qw, bias);

  auto unique_ptr_wrapper =
      std::make_unique<decltype(packed_params)>(std::move(packed_params));
  auto ret = cpp_custom_type_hack::create(
      std::move(unique_ptr_wrapper), weight.options());
  return ret;
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_quantized_linear_prepacked(const at::Tensor& input,
    const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& packed_weight,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel);

at::Tensor wrapped_quantized_linear_prepacked(const at::Tensor& input,
    const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& packed_weight,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel) {
  // This op is similar to wrapped_quantized_linear, but it takes the
  // prepacked weight
#ifdef USE_FBGEMM
  auto qx = at::quantize_per_tensor(
      input, input_scale, input_zero_point, c10::ScalarType::QUInt8);
  const auto scale_val = output_scale.item().toFloat();
  const auto zero_point_val = output_zero_point.item().toLong();
  auto packed_weight_ptr =
      // @lint-ignore CLANGTIDY facebook-hte-Deprecated
      cpp_custom_type_hack::cast<c10::intrusive_ptr<LinearPackedParamsBase>>(
          packed_weight);
  auto result = callOpByName(
      "quantized::linear", "", qx, packed_weight_ptr, scale_val, zero_point_val);

  return at::dequantize(result[0].toTensor());
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_linear_prepack_meta(const at::Tensor& weight,
    [[maybe_unused]] const at::Tensor& weight_scale,
    [[maybe_unused]] const at::Tensor& weight_zero_point,
    [[maybe_unused]] const at::Tensor& bias);

at::Tensor wrapped_linear_prepack_meta(const at::Tensor& weight,
    [[maybe_unused]] const at::Tensor& weight_scale,
    [[maybe_unused]] const at::Tensor& weight_zero_point,
    [[maybe_unused]] const at::Tensor& bias) {
#ifdef USE_FBGEMM
  TORCH_CHECK(
      weight.dim() == 2,
      "fbgemm weight packing only packs matrices not vectors.");
  const at::SymInt M = weight.sym_size(0);
  const at::SymInt N = weight.sym_size(1);
  auto Y = at::empty_symint({M, N}, weight.options().dtype(at::kFloat));
  return Y;
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_quantized_linear_prepacked_meta(const at::Tensor& input,
    [[maybe_unused]] const at::Tensor& input_scale,
    [[maybe_unused]] const at::Tensor& input_zero_point,
    [[maybe_unused]] const at::Tensor& packed_weight,
    [[maybe_unused]] const at::Tensor& output_scale,
    [[maybe_unused]] const at::Tensor& output_zero_point,
    const int64_t out_channel);

at::Tensor wrapped_quantized_linear_prepacked_meta(const at::Tensor& input,
    [[maybe_unused]] const at::Tensor& input_scale,
    [[maybe_unused]] const at::Tensor& input_zero_point,
    [[maybe_unused]] const at::Tensor& packed_weight,
    [[maybe_unused]] const at::Tensor& output_scale,
    [[maybe_unused]] const at::Tensor& output_zero_point,
    const int64_t out_channel) {
#ifdef USE_FBGEMM
  auto out_sizes = input.sym_sizes().vec();
  TORCH_CHECK(
      out_sizes.size() == 2,
      "The dimension of the input tensor should be equal to 2");
  out_sizes[out_sizes.size() - 1] = out_channel;

  return at::empty_symint(out_sizes, input.options());
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

namespace {

class QLinearPackWeightInt8 final {
@@ -570,10 +694,22 @@ TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16::run));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16_legacy"), TORCH_FN(QLinearPackWeightFp16Legacy::run));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear));
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_linear_prepack"),
      wrapped_linear_prepack);
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear_prepacked"),
      wrapped_quantized_linear_prepacked);
}

TORCH_LIBRARY_IMPL(_quantized, Meta, m) {
  m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear_meta));
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_linear_prepack"),
      wrapped_linear_prepack_meta);
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear_prepacked"),
      wrapped_quantized_linear_prepacked_meta);
}

TORCH_LIBRARY_IMPL(onednn, CPU, m) {
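The Meta-key registrations above are what let the export/AOTI tracing stack compute output shapes without running FBGEMM kernels. As a hedged sketch (not part of this diff, and assuming these ops dispatch to their meta kernels like other meta-registered ops in an FBGEMM-enabled build), shape propagation can be exercised with tensors on the meta device:

    import torch

    # Meta tensors carry only shape/dtype, no data.
    weight = torch.randn(16, 8, device="meta")
    bias = torch.randn(16, device="meta")
    scale = torch.empty((), device="meta")
    zp = torch.empty((), dtype=torch.long, device="meta")

    packed = torch.ops._quantized.wrapped_linear_prepack(weight, scale, zp, bias)
    print(packed.shape)  # torch.Size([16, 8]): the meta kernel returns an (M, N) float tensor

    x = torch.randn(4, 8, device="meta")
    y = torch.ops._quantized.wrapped_quantized_linear_prepacked(
        x, scale, zp, packed, scale, zp, 16)
    print(y.shape)  # torch.Size([4, 16]): last dimension replaced by out_channel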
@@ -251,6 +251,8 @@ TORCH_LIBRARY(_quantized, m) {
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16(Tensor W) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_linear_fp16_weight(Tensor X, Tensor W, Tensor B, int out_channel) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_quantized_linear(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_linear_prepack(Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_quantized_linear_prepacked(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W_prepack, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y"));
}

TORCH_LIBRARY(onednn, m) {
@@ -228,6 +228,8 @@ xfail_not_implemented = {
    "aten::var_mean.correction_names",
    "aten::var_mean.names_dim",
    "aten::where",
    "aten::wrapped_linear_prepack",
    "aten::wrapped_quantized_linear_prepacked",
}
@@ -4223,6 +4223,47 @@ class TestQuantizedLinear(TestCase):
        ret_ref = qlinear.dequantize()
        self.assertEqual(ret, ret_ref)

    """Tests the correctness of the _quantized::wrapped_linear_prepack and
    _quantized::wrapped_quantized_linear_prepacked ops."""
    @skipIfNoFBGEMM
    @given(
        m=st.integers(2, 6),
        k=st.integers(2, 6),
        n=st.integers(2, 6),
    )
    def test_wrapped_quantized_linear_prepacked(self, m, n, k):
        input = torch.randn(m, k, dtype=torch.float32)
        input_scale = torch.tensor(0.1)
        input_zero_point = torch.tensor(0)
        weight = torch.randn(n, k, dtype=torch.float32)
        weight_scale = torch.tensor(0.1)
        weight_zero_point = torch.tensor(0)
        bias = torch.randn(n, dtype=torch.float32)
        output_scale = torch.tensor(0.1)
        output_zero_point = torch.tensor(0)
        out_channel = n

        ret_1 = torch.ops._quantized.wrapped_linear_prepack(
            weight,
            weight_scale,
            weight_zero_point,
            bias
        )
        ret_2 = torch.ops._quantized.wrapped_quantized_linear_prepacked(
            input,
            input_scale,
            input_zero_point,
            ret_1,
            output_scale,
            output_zero_point,
            out_channel
        )
        qinput = torch.quantize_per_tensor(input, input_scale, input_zero_point, torch.quint8)
        qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, torch.qint8)
        qlinear_prepack = torch.ops.quantized.linear_prepack(qweight, bias)
        qlinear = torch.ops.quantized.linear(qinput, qlinear_prepack, output_scale, output_zero_point)
        ret_ref = qlinear.dequantize()
        self.assertEqual(ret_2, ret_ref)

    """Tests the correctness of the quantized::linear_unpack after freeing original tensor op."""
    @skipIfNoQNNPACK
@@ -1252,6 +1252,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.vsplit: lambda input, indices_or_sections: -1,
        torch.vstack: lambda tensors, out=None: -1,
        torch.where: lambda condition, x=None, y=None: -1,
        torch.wrapped_linear_prepack: lambda weight, weight_scale, weight_zero_point, bias: -1,
        torch.wrapped_quantized_linear_prepacked: (
            lambda input, input_scale, input_zero_point, prepacked, out_scale, out_zero_point, out_channel: -1  # noqa: B950
        ),
        torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
        torch._fw_primal_copy: lambda self, level: -1,
        torch._make_dual_copy: lambda primal, tangent, level: -1,