[WOQ] Add XPU kernel for _weight_int8pack_mm (#160938)

Summary:
This issue proposes implementing an XPU kernel for aten._weight_int8pack_mm, a weight-only quantized (WOQ) linear operation that is currently supported only on CPU, CUDA, and MPS.
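
For reference, torch._weight_int8pack_mm(A, B, scales) takes a floating-point activation A of shape [M, K], a per-output-channel int8 weight B of shape [N, K] together with its scales of shape [N], and computes A @ dequant(B).T. A minimal sketch of those semantics against the existing CPU kernel (the ad hoc quantization here is only illustrative, not code from this PR):

```python
import torch

M, K, N = 4, 8, 6
A = torch.randn(M, K, dtype=torch.bfloat16)
W = torch.randn(N, K, dtype=torch.bfloat16)

# Symmetric per-channel int8 quantization of the weight (zero point = 0).
scales = (W.abs().amax(dim=1) / 127.0).clamp(min=1e-6)                    # [N], same dtype as A
B = torch.round(W / scales.unsqueeze(1)).clamp(-128, 127).to(torch.int8)  # [N, K]

out = torch._weight_int8pack_mm(A, B, scales)        # [M, N], bf16
ref = A @ (B.to(A.dtype) * scales.unsqueeze(1)).t()  # dequantize, then matmul
print((out - ref).abs().max())                       # small (bf16 rounding only)
```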

Motivation:
Same as https://github.com/pytorch/pytorch/pull/159325.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160938
Approved by: https://github.com/EikanWang, https://github.com/ZhiweiYan-96, https://github.com/liangan1, https://github.com/jerryzh168
Xiao, Wang
2025-09-19 07:37:10 +00:00
committed by PyTorch MergeBot
parent 0815091d86
commit ab5086a7ae
5 changed files with 109 additions and 2 deletions


@@ -559,4 +559,60 @@ Tensor _int_mm_xpu(const Tensor& self, const Tensor& mat2) {
      at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
  return _int_mm_out_xpu(self, mat2, result);
}

Tensor _weight_int8pack_mm_xpu(
    const Tensor& A,
    const Tensor& B,
    const Tensor& scales) {
  auto M = A.size(0);
  auto N = B.size(0);
  auto K = A.size(1);

  // Input contract: A is a 2D bf16/fp16/fp32 activation that is contiguous on
  // its last dimension; B is a contiguous [N, K] int8 weight; scales holds one
  // dequantization factor per output channel.
  TORCH_CHECK(
      A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
      __func__,
      " : expect A to be either 32-bit or 16-bit float tensor.");
  TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor.");
  TORCH_CHECK(
      A.stride(1) == 1,
      __func__,
      " : A must be contiguous on the last dimension.");
  TORCH_CHECK(B.dtype() == kChar, __func__, " : expect B to be int8 tensor.");
  TORCH_CHECK(B.is_contiguous(), __func__, " : expect B to be contiguous.");
  TORCH_CHECK(B.size(1) == K, __func__, " : expect B.size(1) == ", K);
  TORCH_CHECK(
      scales.dim() == 1 && scales.size(0) == N,
      __func__,
      " : expect scales to be 1d tensor with size ",
      N);

  auto C = at::empty({M, N}, A.options());

  // --- Launch kernel ---
  // Weight-only quantization: no bias, no weight zero points (symmetric s8
  // weight), no output re-quantization and no post-ops; oneDNN dequantizes B
  // with the per-channel scales during the matmul.
  Tensor bias = at::Tensor();
  Tensor mat2_zero_points = at::Tensor();
  Tensor non_const_scales = scales;
  auto post_op_args = torch::List<std::optional<at::Scalar>>();

  at::native::onednn::quantized_matmul(
      A.contiguous(),
      1.0,
      0,
      B,
      non_const_scales,
      mat2_zero_points,
      bias,
      C,
      1.0,
      0,
      C.scalar_type(),
      /*other*/ std::nullopt,
      /*other scale*/ 1.0,
      /*other zp*/ 0,
      /*binary post op*/ "none",
      /*binary alpha*/ 1.0,
      /*post_op_name*/ "none",
      post_op_args,
      /*post_op_algorithm*/ "none",
      /*m2_trans*/ false);
  return C;
}
} // namespace at::native
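
The new kernel validates the WOQ input contract and hands the work to onednn::quantized_matmul with symmetric per-channel weight scales and no zero points, bias, or post-ops. A smoke test of that contract might look like the sketch below; it assumes an XPU-enabled build with at least one device visible to torch.xpu:

```python
import torch

# Hypothetical smoke test for the new XPU path; skipped when no XPU is present.
if torch.xpu.is_available():
    M, K, N = 32, 64, 48
    A = torch.randn(M, K, dtype=torch.bfloat16, device="xpu")             # [M, K] bf16/fp16/fp32, last dim contiguous
    B = torch.randint(-128, 128, (N, K), dtype=torch.int8, device="xpu")  # [N, K] int8, contiguous
    scales = torch.rand(N, dtype=torch.bfloat16, device="xpu")            # [N], one scale per output channel

    C = torch._weight_int8pack_mm(A, B, scales)  # routed to _weight_int8pack_mm_xpu via the dispatch entry below
    print(C.shape, C.dtype)                      # torch.Size([32, 48]) torch.bfloat16
```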


@@ -110,8 +110,9 @@ void quantized_matmul(
   // [Note] Quantized Matrix Multiplication at XPU
   // The following code integrates oneDNN quantized gemm. The quantization
   // config we support:
-  // activation: s8&u8; per tensor calibrated; symmetric&asymmetric
-  // weight: s8; per_tensor/per_channel calibrated; symmetric
+  // activation: s8, u8, fp16, bf16, fp32; per tensor calibrated;
+  // symmetric&asymmetric weight: s8; per_tensor/per_channel calibrated;
+  // symmetric
   auto attr = Attr(static_cast<float>(1.0 / output_scale), output_zero_point);
   construct_attr_by_post_op(
       binary_post_op,
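
The weight half of that config is what the WOQ path relies on: symmetric s8 with either a single scale for the whole matrix (per_tensor) or one scale per output channel (per_channel), and a zero point fixed at 0. A rough sketch of the two calibration granularities (illustrative only; the new test further down uses the _dynamically_quantize_per_channel helper instead):

```python
import torch

# Illustrative weight calibration; the shape of the scales is the salient difference.
w = torch.randn(48, 64)

# per_tensor: one scale for the whole weight matrix, zero point 0
scale_per_tensor = w.abs().max() / 127.0                               # scalar
q_per_tensor = torch.round(w / scale_per_tensor).clamp(-128, 127).to(torch.int8)

# per_channel: one scale per output channel (row); _weight_int8pack_mm takes
# these as a 1-D tensor of length N alongside the [N, K] int8 weight
scale_per_channel = w.abs().amax(dim=1, keepdim=True) / 127.0          # [48, 1]
q_per_channel = torch.round(w / scale_per_channel).clamp(-128, 127).to(torch.int8)

print(scale_per_tensor.shape, scale_per_channel.shape)  # torch.Size([]) torch.Size([48, 1])
```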


@@ -4243,6 +4243,7 @@
    CPU: _weight_int8pack_mm_cpu
    CUDA: _weight_int8pack_mm_cuda
    MPS: _weight_int8pack_mm_mps
    XPU: _weight_int8pack_mm_xpu

- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
  python_module: sparse


@@ -19,8 +19,12 @@ from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyNativeDeviceTypes,
    precisionOverride,
)
from torch.testing._internal.common_quantization import (
    _dynamically_quantize_per_channel,
)
from torch.testing._internal.common_utils import (
    iter_indices,
    parametrize,
@@ -1446,6 +1450,50 @@ def forward(self, x_1, w_1):
return out_dtype""",
)
    @onlyNativeDeviceTypes
    @parametrize("m", [32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [48, 64])
    @parametrize("compile", [True, False])
    @parametrize("slice", [True, False])
    def test__int8_mm(self, device, m, k, n, compile, slice):
        torch.manual_seed(1)
        if slice:
            # logits are generated from LLaMA LM head like this -
            # the activation to LM head is a slice of final hidden state
            # of shape (batch_size, sequence_length, hidden dim),
            # but is non-contiguous
            # Using arbitrary batch-size here, since it'd be converted to 2D
            batch_size = 4
            a = torch.rand((batch_size, m, k), dtype=torch.bfloat16, device=device)
            # Make a non-contiguous
            a = a[:, -1:, :]
            a = a.view(-1, a.size(-1))
        else:
            a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)

        def convert_weight_to_int8pack(b):
            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
                b, -128, 127, torch.int8
            )
            return b_int8pack, b_scales

        def weight_int8pack_mm(a, b_int8pack, b_scales):
            return torch._weight_int8pack_mm(a, b_int8pack, b_scales)

        b_int8pack, b_scales = convert_weight_to_int8pack(b)
        if compile:
            mod = torch.compile(weight_int8pack_mm)
        else:
            mod = weight_int8pack_mm
        res = mod(a, b_int8pack, b_scales)
        ref = torch.mm(a, b.transpose(0, 1))

        mean_err = ((res - ref).abs() / ref).mean()
        self.assertTrue(mean_err < 0.05)


instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True)


@@ -19,6 +19,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attent
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int4pack_mm_with_scales_and_zeros(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScale, AtenTensorHandle qZeros, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_abs(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0);