From ab5086a7ae24b42fd4d43985f2870a44e9e778cd Mon Sep 17 00:00:00 2001
From: "Xiao, Wang"
Date: Fri, 19 Sep 2025 07:37:10 +0000
Subject: [PATCH] [WOQ] Add XPU kernel for _weight_int8pack_mm (#160938)

Summary: This PR adds an XPU kernel for aten._weight_int8pack_mm, a
weight-only quantized (WOQ) linear operation that was previously only
supported on CPU, CUDA, and MPS.

Motivation: Same as https://github.com/pytorch/pytorch/pull/159325.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160938
Approved by: https://github.com/EikanWang, https://github.com/ZhiweiYan-96, https://github.com/liangan1, https://github.com/jerryzh168
---
 aten/src/ATen/native/mkldnn/xpu/Blas.cpp      | 56 +++++++++++++++++++
 .../ATen/native/mkldnn/xpu/detail/QMatmul.cpp |  5 +-
 aten/src/ATen/native/native_functions.yaml    |  1 +
 test/xpu/test_gemm.py                         | 48 ++++++++++++++++
 .../aoti_torch/generated/c_shim_xpu.h         |  1 +
 5 files changed, 109 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp
index 6a66abc7b062..7ef9aa5689d5 100644
--- a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp
+++ b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp
@@ -559,4 +559,60 @@ Tensor _int_mm_xpu(const Tensor& self, const Tensor& mat2) {
       at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
   return _int_mm_out_xpu(self, mat2, result);
 }
+
+Tensor _weight_int8pack_mm_xpu(
+    const Tensor& A,
+    const Tensor& B,
+    const Tensor& scales) {
+  auto M = A.size(0);
+  auto N = B.size(0);
+  auto K = A.size(1);
+
+  TORCH_CHECK(
+      A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
+      " : expect A to be either 32-bit or 16-bit float tensor.");
+  TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor.");
+  TORCH_CHECK(
+      A.stride(1) == 1, " : A must be contiguous on the last dimension.");
+  TORCH_CHECK(B.dtype() == kChar, " : expect B to be int8 tensor.");
+  TORCH_CHECK(B.is_contiguous(), " : expect B to be contiguous.");
+  TORCH_CHECK(B.size(1) == K, " : expect B.size(1) == ", K);
+
+  TORCH_CHECK(
+      scales.dim() == 1 && scales.size(0) == N,
+      " : expect scales to be 1d tensor with size ",
+      N);
+
+  auto C = at::empty({M, N}, A.options());
+
+  // --- Launch kernel ---
+  Tensor bias = at::Tensor();
+  Tensor mat2_zero_points = at::Tensor();
+  Tensor non_const_scales = scales;
+  auto post_op_args = torch::List<std::optional<at::Scalar>>();
+
+  at::native::onednn::quantized_matmul(
+      A.contiguous(),
+      1.0,
+      0,
+      B,
+      non_const_scales,
+      mat2_zero_points,
+      bias,
+      C,
+      1.0,
+      0,
+      C.scalar_type(),
+      /*other*/ std::nullopt,
+      /*other scale*/ 1.0,
+      /*other zp*/ 0,
+      /*binary post op*/ "none",
+      /*binary alpha*/ 1.0,
+      /*post_op_name*/ "none",
+      post_op_args,
+      /*post_op_algorithm*/ "none",
+      /*m2_trans*/ false);
+
+  return C;
+}
 } // namespace at::native
diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp
index 41da31c7eb6b..ede01093ff3e 100644
--- a/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp
+++ b/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp
@@ -110,8 +110,9 @@ void quantized_matmul(
   // [Note] Quantized Matrix Multiplication at XPU
   // The following code integrates oneDNN quantized gemm. The quantization
   // config we support:
-  // activation: s8&u8; per tensor calibrated; symmetric&asymmetric
-  // weight: s8; per_tensor/per_channel calibrated; symmetric
+  // activation: s8, u8, fp16, bf16, fp32; per tensor calibrated;
+  // symmetric&asymmetric weight: s8; per_tensor/per_channel calibrated;
+  // symmetric
   auto attr = Attr(static_cast<float>(1.0 / output_scale), output_zero_point);
   construct_attr_by_post_op(
       binary_post_op,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 21470ee78b41..687288bb18b7 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4243,6 +4243,7 @@
     CPU: _weight_int8pack_mm_cpu
     CUDA: _weight_int8pack_mm_cuda
     MPS: _weight_int8pack_mm_mps
+    XPU: _weight_int8pack_mm_xpu
 
 - func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
   python_module: sparse
diff --git a/test/xpu/test_gemm.py b/test/xpu/test_gemm.py
index 1164a2b67636..f2a273ccc330 100644
--- a/test/xpu/test_gemm.py
+++ b/test/xpu/test_gemm.py
@@ -19,8 +19,12 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_device_type import (
     dtypes,
     instantiate_device_type_tests,
+    onlyNativeDeviceTypes,
     precisionOverride,
 )
+from torch.testing._internal.common_quantization import (
+    _dynamically_quantize_per_channel,
+)
 from torch.testing._internal.common_utils import (
     iter_indices,
     parametrize,
@@ -1446,6 +1450,50 @@ def forward(self, x_1, w_1):
     return out_dtype""",
         )
 
+    @onlyNativeDeviceTypes
+    @parametrize("m", [32, 64])
+    @parametrize("k", [32, 64])
+    @parametrize("n", [48, 64])
+    @parametrize("compile", [True, False])
+    @parametrize("slice", [True, False])
+    def test__int8_mm(self, device, m, k, n, compile, slice):
+        torch.manual_seed(1)
+        if slice:
+            # logits are generated from LLaMA LM head like this -
+            # the activation to LM head is a slice of final hidden state
+            # of shape (batch_size, sequence_length, hidden dim),
+            # but is non-contiguous
+            # Using arbitrary batch-size here, since it'd be converted to 2D
+            batch_size = 4
+            a = torch.rand((batch_size, m, k), dtype=torch.bfloat16, device=device)
+            # Make a non-contiguous
+            a = a[:, -1:, :]
+            a = a.view(-1, a.size(-1))
+        else:
+            a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
+
+        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
+
+        def convert_weight_to_int8pack(b):
+            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
+                b, -128, 127, torch.int8
+            )
+            return b_int8pack, b_scales
+
+        def weight_int8pack_mm(a, b_int8pack, b_scales):
+            return torch._weight_int8pack_mm(a, b_int8pack, b_scales)
+
+        b_int8pack, b_scales = convert_weight_to_int8pack(b)
+        if compile:
+            mod = torch.compile(weight_int8pack_mm)
+        else:
+            mod = weight_int8pack_mm
+        res = mod(a, b_int8pack, b_scales)
+        ref = torch.mm(a, b.transpose(0, 1))
+
+        mean_err = ((res - ref).abs() / ref).mean()
+        self.assertTrue(mean_err < 0.05)
+
 
 instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True)
 
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h
index 09ebbb76d0b2..39f0dec86165 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h
@@ -19,6 +19,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attent
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int4pack_mm_with_scales_and_zeros(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScale, AtenTensorHandle qZeros, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_abs(AtenTensorHandle self, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0);
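
Note for reviewers (not part of the patch): below is a minimal usage sketch that mirrors test__int8_mm above. It assumes a PyTorch build containing this patch with an XPU device available, and it reuses the private test helper _dynamically_quantize_per_channel for per-channel symmetric int8 weight quantization; it is illustrative only.

    # Minimal sketch: weight-only int8 matmul on XPU.
    # Assumes this patch is built in and an XPU device is available.
    import torch
    from torch.testing._internal.common_quantization import (
        _dynamically_quantize_per_channel,
    )

    device = "xpu"
    m, k, n = 32, 64, 48

    a = torch.rand((m, k), dtype=torch.bfloat16, device=device)  # activation, M x K
    b = torch.rand((n, k), dtype=torch.bfloat16, device=device)  # weight, N x K

    # Per-channel symmetric int8 quantization of the weight, as in the test above.
    b_int8, b_scales, _ = _dynamically_quantize_per_channel(b, -128, 127, torch.int8)

    # WOQ linear: bf16 (M x K) @ int8 (N x K)^T, rescaled per output channel -> M x N.
    c = torch._weight_int8pack_mm(a, b_int8, b_scales)

    # Rough accuracy check against a plain bf16 matmul, like the test's 5% tolerance.
    ref = torch.mm(a, b.transpose(0, 1))
    print(((c - ref).abs() / ref).mean())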