Facilitate at::_weight_int4pack_mm_with_scales_and_zeros related registration (#147962)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147962
Approved by: https://github.com/jerryzh168, https://github.com/guangyey, https://github.com/EikanWang
ghstack dependencies: #137566

Co-authored-by: xiaolil1 <xiaoli.liu@intel.com>
Author: ZhiweiYan-96
Date: 2025-04-08 02:21:35 +00:00
Committed by: PyTorch MergeBot
Parent: da7322548b
Commit: 52d172eafd
5 changed files with 105 additions and 1 deletion

@@ -4165,6 +4165,10 @@
    MPS: _weight_int4pack_mm_mps
    CUDA: _weight_int4pack_mm_cuda

- func: _weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScale, Tensor qZeros) -> Tensor
  dispatch:
    XPU: _weight_int4pack_mm_xpu

# Split int4 pack weight between cpu and other devices due to
# https://github.com/pytorch/ao/issues/1117#issuecomment-2451252756.
- func: _convert_weight_to_int4pack_for_cpu(Tensor self, int innerKTiles) -> Tensor
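
A minimal sketch (not part of the diff, assuming a build that includes this commit) of how the entry above surfaces in Python: it declares the operator schema and registers only an XPU kernel here, with the meta/shape function added later in this commit.

    import torch

    # Overload created by the native_functions.yaml entry above.
    op = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros.default

    # Prints the registered schema:
    # aten::_weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2,
    #     int qGroupSize, Tensor qScale, Tensor qZeros) -> Tensor
    print(op._schema)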

@@ -651,6 +651,7 @@ aten::_values_copy
aten::_values_copy.out
aten::_weight_int4pack_mm
aten::_weight_int4pack_mm_for_cpu
aten::_weight_int4pack_mm_with_scales_and_zeros
aten::_weight_int8pack_mm
aten::_weight_norm_interface_backward
aten::_weight_norm_interface_backward.out

@@ -15,7 +15,12 @@ from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    precisionOverride,
)
from torch.testing._internal.common_utils import iter_indices, run_tests, TestCase
from torch.testing._internal.common_utils import (
    iter_indices,
    parametrize,
    run_tests,
    TestCase,
)


class TestBasicGEMM(TestCase):
@@ -1119,6 +1124,84 @@
        with torch.no_grad():
            torch.matmul(a, b, out=c)

    def _group_quantize_tensor(self, w, n_bit=4, q_group_size=16):
        # w [k, n] = [32, 48]
        assert w.dim() == 2
        # w [n, k] = [48, 32]
        w = w.transpose(0, 1).contiguous()
        assert q_group_size > 1
        assert w.shape[-1] % q_group_size == 0

        # to_quant: [n * k / group_size, group_size]
        to_quant = w.reshape(-1, q_group_size)
        assert torch.isnan(to_quant).sum() == 0

        max_val = to_quant.amax(dim=1, keepdim=True)
        min_val = to_quant.amin(dim=1, keepdim=True)
        max_int = 2**n_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-6) / max_int
        assert torch.isnan(scales).sum() == 0

        zeros = min_int - min_val.div(scales).round()
        zeros = torch.clamp(zeros, min_int, max_int)
        zeros = zeros.to(torch.int8)
        assert torch.isnan(zeros).sum() == 0

        out = to_quant.div(scales).add(zeros).round().clamp_(min_int, max_int)
        assert torch.isnan(out).sum() == 0

        # [n, k]
        out = out.to(dtype=torch.int32).reshape(w.shape)
        if out.device != torch.device("cpu"):
            out = (out[::, 1::2] << 4 | out[::, 0::2]).to(torch.uint8)

        # Scales and zeros for the same q-group should be contiguous, so we can
        # load as a 32-bit word
        scales = scales.view(w.shape[0], -1).transpose(0, 1).contiguous()
        zeros = zeros.view(w.shape[0], -1).transpose(0, 1).contiguous()

        return out, scales, zeros

    @parametrize("m", [128])
    @parametrize("k", [512, 1024])
    @parametrize("n", [512, 1024])
    def test__int4_mm(self, device, m, k, n):
        q_group = 32
        inner_k_tiles = 2

        torch.manual_seed(1)
        a_bf16 = torch.rand((m, k), dtype=torch.float32, device=device)
        b_bf16 = torch.rand((k, n), dtype=torch.float32, device=device)

        def convert_weight_to_int4pack(b):
            # b_uint8 [n, k // 2]
            b_uint8, scales, zeros = self._group_quantize_tensor(
                b, n_bit=4, q_group_size=q_group
            )
            # b_int4pack [k // 8, n]
            b_int4pack = torch._convert_weight_to_int4pack(b_uint8, inner_k_tiles)
            return b_int4pack, scales, zeros

        def weight_int4pack_mm(a, b_int4pack, qscale, qzeros):
            return torch._weight_int4pack_mm_with_scales_and_zeros(
                a, b_int4pack, q_group, qscale, qzeros
            )

        b_int4pack, b_scales, zeros_int8 = convert_weight_to_int4pack(b_bf16)

        for dtype in [torch.bfloat16, torch.float16]:
            a = a_bf16.to(dtype=dtype)
            b = b_bf16.to(dtype=dtype)
            b_scales = b_scales.to(dtype=dtype)
            ref = torch.mm(a, b)
            res = weight_int4pack_mm(a, b_int4pack, b_scales, zeros_int8)

            mean_err = ((res - ref).abs() / ref).mean()
            self.assertTrue(mean_err < 0.05)


instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True)
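
As a standalone sanity check (not part of the commit) on the group quantization used by _group_quantize_tensor above: the snippet below applies the same affine scheme on CPU, dequantizes, and confirms that a matmul against the dequantized weight stays within the ~5% mean relative error the test asserts. Shapes are illustrative only.

    import torch

    torch.manual_seed(0)
    k, n, q_group = 64, 32, 16
    w = torch.rand(k, n)

    # Same per-group affine quantization as the helper above, on the [n, k] view.
    wt = w.t().contiguous().reshape(-1, q_group)
    max_val = wt.amax(dim=1, keepdim=True)
    min_val = wt.amin(dim=1, keepdim=True)
    scales = (max_val - min_val).clamp(min=1e-6) / 15  # 2**4 - 1
    zeros = (0 - min_val.div(scales).round()).clamp(0, 15)
    q = wt.div(scales).add(zeros).round().clamp(0, 15)

    # Dequantize and compare a matmul against the float reference.
    w_dq = ((q - zeros) * scales).reshape(n, k).t()
    a = torch.rand(8, k)
    ref = a @ w
    mean_err = ((a @ w_dq - ref).abs() / ref).mean()
    print(f"mean relative error: {mean_err.item():.4f}")  # comfortably below 0.05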

@@ -1624,6 +1624,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
        "torch._values_copy",
        "torch._weight_int4pack_mm",
        "torch._weight_int4pack_mm_for_cpu",
        "torch._weight_int4pack_mm_with_scales_and_zeros",
        "torch._weight_int8pack_mm",
        "torch._weight_norm_interface",
        "torch._weight_norm",

@@ -3636,6 +3636,21 @@ def meta__weight_int4pack_mm_for_cpu(x, w, q_group_size, q_scale_and_zeros):
    return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)


@register_meta([aten._weight_int4pack_mm_with_scales_and_zeros])
def _weight_int4pack_mm_with_scales_and_zeros(x, w, q_group_size, qScale, qZeros):
    torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
    torch._check(
        x.dtype in [torch.float32, torch.float16, torch.bfloat16],
        lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
    )
    torch._check(
        w.dtype is torch.int32,
        lambda: f"expected w to be int32, got {w.dtype}",
    )
    return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)


def kai_roundup(a: int, b: int) -> int:
    return ((a + b - 1) // b) * b
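
The meta registration above is what allows fake/meta tensors, and hence torch.compile shape propagation, to infer the output shape without running the XPU kernel. A minimal sketch (not part of the commit), assuming a build that includes this change; the tensor shapes are illustrative only.

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        x = torch.empty(128, 512, dtype=torch.bfloat16)  # activations [m, k]
        w = torch.empty(64, 512, dtype=torch.int32)      # int4-packed weight
        scale = torch.empty(16, 512, dtype=torch.bfloat16)
        zeros = torch.empty(16, 512, dtype=torch.int8)

        out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
            x, w, 32, scale, zeros
        )
        # Shape comes from the meta function: x.new_empty(x.size(0), w.size(0))
        print(out.shape)  # torch.Size([128, 64])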