[AOTI] Enable OP test__weight_int4pack_mm_with_scales_and_zeros in AOTI. (#155780)

The op _weight_int4pack_mm_with_scales_and_zeros is specific to Intel GPU (XPU). It is functionally equivalent to the CUDA/CPU op _weight_int4pack_mm; the separate API is needed because oneDNN only supports integer zero points, so the scales and zero points are passed as two tensors instead of one packed tensor. Since _weight_int4pack_mm is already in AOTI's fallback list, this PR adds the XPU variant to the fallback list and the corresponding XPU C shim.
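For orientation, a minimal sketch (not part of this commit) of the two call signatures side by side. The wrapper names int4_mm_xpu and int4_mm_cuda_cpu are illustrative only; the argument order follows the shim declaration and the test added below.

import torch

def int4_mm_xpu(a, b_int4pack, q_group, scales, zeros_int8):
    # XPU variant: zero points are a separate integer tensor, matching the
    # oneDNN requirement described above.
    return torch._weight_int4pack_mm_with_scales_and_zeros(
        a, b_int4pack, q_group, scales, zeros_int8
    )

def int4_mm_cuda_cpu(a, b_int4pack, q_group, scales_and_zeros):
    # CUDA/CPU variant: scales and zero points travel in one packed tensor.
    return torch._weight_int4pack_mm(a, b_int4pack, q_group, scales_and_zeros)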

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155780
Approved by: https://github.com/jansel
xinan.lin
2025-06-11 22:50:25 -07:00
committed by PyTorch MergeBot
parent 463fe36532
commit 670dab6c63
3 changed files with 78 additions and 0 deletions


@@ -5906,6 +5906,82 @@ class AOTInductorTestsTemplate:
        model = Model(b_int4pack, b_scales_and_zeros_f32)
        self.check_model(model, (a,))

    @parametrize("m", [32])
    @parametrize("n", [64])
    @parametrize("q_group", [32, 64])
    @parametrize("num_groups", [1, 2])
    def test__weight_int4pack_mm_with_scales_and_zeros(
        self, m, n, q_group, num_groups
    ):
        if "xpu" not in self.device:
            raise unittest.SkipTest("requires Intel GPU")

        class Model(torch.nn.Module):
            def __init__(self, weight, scale, zeros) -> None:
                super().__init__()
                self.weight = weight
                self.scale = scale
                self.zeros = zeros

            def forward(self, a):
                return torch._weight_int4pack_mm_with_scales_and_zeros(
                    a, self.weight, q_group, self.scale, self.zeros
                )

        def _group_quantize_tensor_xpu(w, n_bit=4, q_group_size=16):
            # w [k, n] = [32, 48]
            assert w.dim() == 2
            # w [n, k] = [48, 32]
            w = w.transpose(0, 1).contiguous()
            assert q_group_size > 1
            assert w.shape[-1] % q_group_size == 0

            # to_quant: [n * k / group_size, group_size]
            to_quant = w.reshape(-1, q_group_size)
            assert torch.isnan(to_quant).sum() == 0

            max_val = to_quant.amax(dim=1, keepdim=True)
            min_val = to_quant.amin(dim=1, keepdim=True)
            max_int = 2**n_bit - 1
            min_int = 0
            scales = (max_val - min_val).clamp(min=1e-6) / max_int
            assert torch.isnan(scales).sum() == 0

            zeros = min_int - min_val.div(scales).round()
            zeros = torch.clamp(zeros, min_int, max_int)
            zeros = zeros.to(torch.int8)
            assert torch.isnan(zeros).sum() == 0

            out = to_quant.div(scales).add(zeros).round().clamp_(min_int, max_int)
            assert torch.isnan(out).sum() == 0

            # [n, k]
            out = out.to(dtype=torch.int32).reshape(w.shape)
            if out.device != torch.device("cpu"):
                out = (out[::, 1::2] << 4 | out[::, 0::2]).to(torch.uint8)

            # Scales and zeros for the same q-group should be contiguous, so we can
            # load as a 32-bit word
            scales = scales.view(w.shape[0], -1).transpose(0, 1).contiguous()
            zeros = zeros.view(w.shape[0], -1).transpose(0, 1).contiguous()

            return out, scales, zeros

        def convert_weight_to_int4pack(b):
            # b_uint8 [n, k // 2]
            b_uint8, scales, zeros = _group_quantize_tensor_xpu(
                b, n_bit=4, q_group_size=q_group
            )
            # b_int4pack [k // 8, n]
            b_int4pack = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=2)
            return b_int4pack, scales, zeros

        k = q_group * num_groups
        a = torch.rand((m, k), device=self.device, dtype=torch.bfloat16)
        b = torch.rand((k, n), device=self.device, dtype=torch.bfloat16)
        b_int4pack, b_scales, zeros_int8 = convert_weight_to_int4pack(b)
        model = Model(b_int4pack, b_scales, zeros_int8)
        self.check_model(model, (a,))

    def test_assert_tensor_meta(self):
        class Module(torch.nn.Module):
            def forward(self, x):


@@ -16,6 +16,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int4pack_mm_with_scales_and_zeros(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScale, AtenTensorHandle qZeros, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_abs(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0);


@@ -169,4 +169,5 @@ inductor_fallback_ops: dict[str, dict[str, list[str]]] = {
    "aten.view_as_complex.default": {},
    "aten.view_as_real.default": {},
    "aten.view.dtype": {},
    "aten._weight_int4pack_mm_with_scales_and_zeros.default": {},
}
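
For context, a hedged sketch of how this fallback entry is exercised end to end. It is not part of the commit (the test above relies on self.check_model, which wraps similar machinery), and compile_int4_model_with_aoti is an illustrative name.

import torch

def compile_int4_model_with_aoti(model: torch.nn.Module, example_inputs: tuple) -> str:
    # Export the eager module, then AOT-compile it with Inductor. With
    # aten._weight_int4pack_mm_with_scales_and_zeros.default in the fallback
    # list, AOTI emits a call to the generated XPU C shim for this op instead
    # of attempting to codegen a kernel for it.
    ep = torch.export.export(model, example_inputs)
    return torch._inductor.aoti_compile_and_package(ep)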