Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[AOTI] Enable OP test__weight_int4pack_mm_with_scales_and_zeros in AOTI (#155780)
The op _weight_int4pack_mm_with_scales_and_zeros is Intel GPU only. It is functionally equivalent to the CUDA/CPU op _weight_int4pack_mm, with the constraint that oneDNN supports only integer zero points, which is why this separate API is needed. Since _weight_int4pack_mm is already included in AOTI's fallback list, this PR adds its XPU counterpart to that list as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155780
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: 463fe36532
Commit: 670dab6c63
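For context on how the newly exposed fallback op is called, below is a minimal sketch (not part of this PR) that quantizes a weight and invokes torch._weight_int4pack_mm_with_scales_and_zeros directly. The shapes, the q_group value, and the pack_int4_weight helper are assumptions condensed from the test added in the diff below; it assumes a PyTorch build with Intel GPU (XPU) support.

import torch

def pack_int4_weight(b, q_group):
    # Per-group asymmetric 4-bit quantization of b [k, n]; returns the packed
    # int4 weight plus per-group scales and int8 zero points, following the
    # layout used by the test's _group_quantize_tensor_xpu helper.
    w = b.transpose(0, 1).contiguous()                      # [n, k]
    to_quant = w.reshape(-1, q_group)                       # one row per q-group
    max_val = to_quant.amax(dim=1, keepdim=True)
    min_val = to_quant.amin(dim=1, keepdim=True)
    scales = (max_val - min_val).clamp(min=1e-6) / 15       # 15 == 2**4 - 1
    zeros = (0 - min_val.div(scales).round()).clamp(0, 15).to(torch.int8)
    q = to_quant.div(scales).add(zeros).round().clamp_(0, 15)
    q = q.to(torch.int32).reshape(w.shape)
    q = (q[:, 1::2] << 4 | q[:, 0::2]).to(torch.uint8)      # two nibbles per byte
    b_int4pack = torch._convert_weight_to_int4pack(q, innerKTiles=2)
    scales = scales.view(w.shape[0], -1).transpose(0, 1).contiguous()
    zeros = zeros.view(w.shape[0], -1).transpose(0, 1).contiguous()
    return b_int4pack, scales, zeros

if torch.xpu.is_available():                                # Intel GPU only
    m, k, n, q_group = 32, 64, 64, 32                       # assumed sizes; k % q_group == 0
    a = torch.rand((m, k), device="xpu", dtype=torch.bfloat16)
    b = torch.rand((k, n), device="xpu", dtype=torch.bfloat16)
    b_int4pack, scales, zeros = pack_int4_weight(b, q_group)
    # Scales and integer zero points are passed as separate tensors, which is
    # what oneDNN requires and why this op exists alongside _weight_int4pack_mm.
    out = torch._weight_int4pack_mm_with_scales_and_zeros(
        a, b_int4pack, q_group, scales, zeros
    )
    print(out.shape)                                        # torch.Size([32, 64])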
@@ -5906,6 +5906,82 @@ class AOTInductorTestsTemplate:
        model = Model(b_int4pack, b_scales_and_zeros_f32)
        self.check_model(model, (a,))

    @parametrize("m", [32])
    @parametrize("n", [64])
    @parametrize("q_group", [32, 64])
    @parametrize("num_groups", [1, 2])
    def test__weight_int4pack_mm_with_scales_and_zeros(self, m, n, q_group, num_groups):
        if "xpu" not in self.device:
            raise unittest.SkipTest("requires Intel GPU")

        class Model(torch.nn.Module):
            def __init__(self, weight, scale, zeros) -> None:
                super().__init__()
                self.weight = weight
                self.scale = scale
                self.zeros = zeros

            def forward(self, a):
                return torch._weight_int4pack_mm_with_scales_and_zeros(
                    a, self.weight, q_group, self.scale, self.zeros
                )

        def _group_quantize_tensor_xpu(w, n_bit=4, q_group_size=16):
            # w [k, n] = [32, 48]
            assert w.dim() == 2
            # w [n, k] = [48, 32]
            w = w.transpose(0, 1).contiguous()
            assert q_group_size > 1
            assert w.shape[-1] % q_group_size == 0

            # to_quant: [n * k / group_size, group_size]
            to_quant = w.reshape(-1, q_group_size)
            assert torch.isnan(to_quant).sum() == 0

            max_val = to_quant.amax(dim=1, keepdim=True)
            min_val = to_quant.amin(dim=1, keepdim=True)
            max_int = 2**n_bit - 1
            min_int = 0
            scales = (max_val - min_val).clamp(min=1e-6) / max_int
            assert torch.isnan(scales).sum() == 0

            zeros = min_int - min_val.div(scales).round()
            zeros = torch.clamp(zeros, min_int, max_int)
            zeros = zeros.to(torch.int8)
            assert torch.isnan(zeros).sum() == 0

            out = to_quant.div(scales).add(zeros).round().clamp_(min_int, max_int)
            assert torch.isnan(out).sum() == 0

            # [n, k]
            out = out.to(dtype=torch.int32).reshape(w.shape)
            if out.device != torch.device("cpu"):
                out = (out[::, 1::2] << 4 | out[::, 0::2]).to(torch.uint8)

            # Scales and zeros for the same q-group should be contiguous, so we can
            # load as a 32-bit word.
            scales = scales.view(w.shape[0], -1).transpose(0, 1).contiguous()
            zeros = zeros.view(w.shape[0], -1).transpose(0, 1).contiguous()

            return out, scales, zeros

        def convert_weight_to_int4pack(b):
            # b_uint8 [n, k // 2]
            b_uint8, scales, zeros = _group_quantize_tensor_xpu(
                b, n_bit=4, q_group_size=q_group
            )
            # b_int4pack [k // 8, n]
            b_int4pack = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=2)

            return b_int4pack, scales, zeros

        k = q_group * num_groups
        a = torch.rand((m, k), device=self.device, dtype=torch.bfloat16)
        b = torch.rand((k, n), device=self.device, dtype=torch.bfloat16)
        b_int4pack, b_scales, zeros_int8 = convert_weight_to_int4pack(b)
        model = Model(b_int4pack, b_scales, zeros_int8)
        self.check_model(model, (a,))

    def test_assert_tensor_meta(self):
        class Module(torch.nn.Module):
            def forward(self, x):
@@ -16,6 +16,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int4pack_mm_with_scales_and_zeros(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScale, AtenTensorHandle qZeros, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_abs(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0);
@@ -169,4 +169,5 @@ inductor_fallback_ops: dict[str, dict[str, list[str]]] = {
    "aten.view_as_complex.default": {},
    "aten.view_as_real.default": {},
    "aten.view.dtype": {},
    "aten._weight_int4pack_mm_with_scales_and_zeros.default": {},
}