Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-11 22:34:53 +08:00)
Facilitate at::_weight_int4pack_mm_with_scales_and_zeros related registration (#147962)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147962
Approved by: https://github.com/jerryzh168, https://github.com/guangyey, https://github.com/EikanWang
ghstack dependencies: #137566
Co-authored-by: xiaolil1 <xiaoli.liu@intel.com>
Committed by: PyTorch MergeBot
Parent: da7322548b
Commit: 52d172eafd
@@ -4165,6 +4165,10 @@
    MPS: _weight_int4pack_mm_mps
    CUDA: _weight_int4pack_mm_cuda

- func: _weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScale, Tensor qZeros) -> Tensor
  dispatch:
    XPU: _weight_int4pack_mm_xpu

# Split int4 pack weight between cpu and other devices due to
# https://github.com/pytorch/ao/issues/1117#issuecomment-2451252756.
- func: _convert_weight_to_int4pack_for_cpu(Tensor self, int innerKTiles) -> Tensor
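Editor's note (not part of the diff): once this yaml entry is registered, the operator is exposed through the aten namespace, so its schema can be inspected even on builds that lack the XPU kernel. A minimal sketch:

import torch

# Look up the newly registered operator and print its schema.
# Only the XPU backend provides a kernel here, so calling the op on other
# devices raises NotImplementedError; inspecting the schema works everywhere.
op = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros
print(op.default._schema)  # prints the signature declared in the yaml entry above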
@@ -651,6 +651,7 @@ aten::_values_copy
aten::_values_copy.out
aten::_weight_int4pack_mm
aten::_weight_int4pack_mm_for_cpu
aten::_weight_int4pack_mm_with_scales_and_zeros
aten::_weight_int8pack_mm
aten::_weight_norm_interface_backward
aten::_weight_norm_interface_backward.out
@@ -15,7 +15,12 @@ from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    precisionOverride,
)
from torch.testing._internal.common_utils import iter_indices, run_tests, TestCase
from torch.testing._internal.common_utils import (
    iter_indices,
    parametrize,
    run_tests,
    TestCase,
)


class TestBasicGEMM(TestCase):
@@ -1119,6 +1124,84 @@
        with torch.no_grad():
            torch.matmul(a, b, out=c)

    def _group_quantize_tensor(self, w, n_bit=4, q_group_size=16):
        # w [k, n] = [32, 48]
        assert w.dim() == 2
        # w [n, k] = [48, 32]
        w = w.transpose(0, 1).contiguous()
        assert q_group_size > 1
        assert w.shape[-1] % q_group_size == 0

        # to_quant: [n * k / group_size, group_size]
        to_quant = w.reshape(-1, q_group_size)
        assert torch.isnan(to_quant).sum() == 0

        max_val = to_quant.amax(dim=1, keepdim=True)
        min_val = to_quant.amin(dim=1, keepdim=True)
        max_int = 2**n_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-6) / max_int
        assert torch.isnan(scales).sum() == 0

        zeros = min_int - min_val.div(scales).round()
        zeros = torch.clamp(zeros, min_int, max_int)
        zeros = zeros.to(torch.int8)
        assert torch.isnan(zeros).sum() == 0

        out = to_quant.div(scales).add(zeros).round().clamp_(min_int, max_int)
        assert torch.isnan(out).sum() == 0

        # [n, k]
        out = out.to(dtype=torch.int32).reshape(w.shape)
        if out.device != torch.device("cpu"):
            out = (out[::, 1::2] << 4 | out[::, 0::2]).to(torch.uint8)

        # Scales and zeros for the same q-group should be contiguous, so we can
        # load as a 32-bit word
        scales = scales.view(w.shape[0], -1).transpose(0, 1).contiguous()
        zeros = zeros.view(w.shape[0], -1).transpose(0, 1).contiguous()

        return out, scales, zeros

    @parametrize("m", [128])
    @parametrize("k", [512, 1024])
    @parametrize("n", [512, 1024])
    def test__int4_mm(self, device, m, k, n):
        q_group = 32
        inner_k_tiles = 2

        torch.manual_seed(1)
        a_bf16 = torch.rand((m, k), dtype=torch.float32, device=device)
        b_bf16 = torch.rand((k, n), dtype=torch.float32, device=device)

        def convert_weight_to_int4pack(b):
            # b_uint8 [n, k //2]
            b_uint8, scales, zeros = self._group_quantize_tensor(
                b, n_bit=4, q_group_size=q_group
            )
            # b_int4pack [k//8, n]
            b_int4pack = torch._convert_weight_to_int4pack(b_uint8, inner_k_tiles)

            return b_int4pack, scales, zeros

        def weight_int4pack_mm(a, b_int4pack, qscale, qzeros):
            return torch._weight_int4pack_mm_with_scales_and_zeros(
                a, b_int4pack, q_group, qscale, qzeros
            )

        b_int4pack, b_scales, zeros_int8 = convert_weight_to_int4pack(b_bf16)

        for dtype in [torch.bfloat16, torch.float16]:
            a = a_bf16.to(dtype=dtype)
            b = b_bf16.to(dtype=dtype)
            b_scales = b_scales.to(dtype=dtype)
            ref = torch.mm(a, b)

            res = weight_int4pack_mm(a, b_int4pack, b_scales, zeros_int8)

            mean_err = ((res - ref).abs() / ref).mean()
            self.assertTrue(mean_err < 0.05)


instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True)
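Editor's note (not part of the diff): the _group_quantize_tensor helper above encodes each q_group_size-wide slice of the weight as q = round(w / scale) + zero. A standalone sketch of the inverse mapping, assuming the unpacked int32 output produced on CPU (not the nibble-packed uint8 path), can be used to sanity-check the helper against the original weight:

import torch

def dequantize_groupwise(q, scales, zeros, q_group_size):
    # q:      [n, k] int32 codes in [0, 2**n_bit - 1]
    # scales: [k // q_group_size, n], as returned by the helper
    # zeros:  [k // q_group_size, n] int8, as returned by the helper
    n, k = q.shape
    q = q.reshape(n, k // q_group_size, q_group_size).float()
    s = scales.transpose(0, 1).unsqueeze(-1).float()  # [n, k // g, 1]
    z = zeros.transpose(0, 1).unsqueeze(-1).float()   # [n, k // g, 1]
    w_hat = (q - z) * s                               # undo q = w / s + z
    return w_hat.reshape(n, k).transpose(0, 1)        # back to [k, n]

Comparing dequantize_groupwise(out, scales, zeros, q_group_size) with the original w shows how much precision the 4-bit group-wise encoding loses, which is what the mean_err tolerance in test__int4_mm bounds end to end.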
@@ -1624,6 +1624,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
        "torch._values_copy",
        "torch._weight_int4pack_mm",
        "torch._weight_int4pack_mm_for_cpu",
        "torch._weight_int4pack_mm_with_scales_and_zeros",
        "torch._weight_int8pack_mm",
        "torch._weight_norm_interface",
        "torch._weight_norm",
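Editor's note (not part of the diff; assumes an XPU build and weights already quantized and packed via torch._convert_weight_to_int4pack): listing the binding above lets dynamo treat the call as a regular in-graph function rather than falling back to eager on it. A rough sketch:

import torch

@torch.compile(fullgraph=True)
def int4_linear(a, w_packed, scale, zeros, q_group=32):
    # Traced by dynamo as a single aten node; running it requires an XPU device.
    return torch._weight_int4pack_mm_with_scales_and_zeros(
        a, w_packed, q_group, scale, zeros
    )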
@@ -3636,6 +3636,21 @@ def meta__weight_int4pack_mm_for_cpu(x, w, q_group_size, q_scale_and_zeros):
    return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)


@register_meta([aten._weight_int4pack_mm_with_scales_and_zeros])
def _weight_int4pack_mm_with_scales_and_zeros(x, w, q_group_size, qScale, qZeros):
    torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
    torch._check(
        x.dtype in [torch.float32, torch.float16, torch.bfloat16],
        lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
    )
    torch._check(
        w.dtype is torch.int32,
        lambda: f"expected w to be int32, got {w.dtype}",
    )
    return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)


def kai_roundup(a: int, b: int) -> int:
    return ((a + b - 1) // b) * b
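Editor's note (not part of the diff; shapes are illustrative placeholders, not a real int4 packing layout): the meta registration above is what enables shape and dtype propagation without an XPU device, e.g. under FakeTensorMode as used by torch.compile. A small sketch:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(4, 64, dtype=torch.bfloat16)
    w = torch.empty(8, 32, dtype=torch.int32)        # placeholder packed weight
    q_scale = torch.empty(2, 8, dtype=torch.bfloat16)
    q_zeros = torch.empty(2, 8, dtype=torch.int8)
    out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
        x, w, 32, q_scale, q_zeros
    )
    print(out.shape, out.dtype)  # torch.Size([4, 8]) torch.bfloat16, per the meta kernel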