[CPU] fix _weight_int8pack_mm with large output shape (#158341)
**Summary**

`_weight_int8pack_mm` on CPU may cause a segmentation fault when the output shape is large (i.e., when M * N is large). The kernel computes the output buffer address as

```c++
auto* C_ptr = C_data + mb_start * N + nb_start;
```

where both `mb_start` and `N` are `int`, so their product can overflow when they are large. The fix is simple: declare these variables as `int64_t` so that the product cannot overflow.

**Test plan**

```
pytest -sv test/test_linalg.py -k test__int8_mm_large_shape
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158341
Approved by: https://github.com/mingfeima, https://github.com/drisspg
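For context, here is a minimal C++ sketch of the overflow pattern described above. The names `C_data`, `C_ptr`, `mb_start`, `nb_start`, and `N` follow the snippet in the commit message, but the function wrappers and the `float` element type are illustrative assumptions, not the actual kernel code.

```c++
#include <cstdint>

// Illustrative sketch only, not the actual PyTorch kernel.
// Before the fix: mb_start and N are 32-bit ints, so mb_start * N is evaluated
// in 32-bit arithmetic. Once the product exceeds INT32_MAX it wraps around,
// producing a bogus pointer offset and, eventually, a segmentation fault.
float* block_addr_before(float* C_data, int mb_start, int nb_start, int N) {
    auto* C_ptr = C_data + mb_start * N + nb_start;  // 32-bit product can overflow
    return C_ptr;
}

// After the fix: the index variables are int64_t, so the product is evaluated
// in 64-bit arithmetic and the offset stays correct even for very large M * N.
float* block_addr_after(float* C_data, int64_t mb_start, int64_t nb_start, int64_t N) {
    auto* C_ptr = C_data + mb_start * N + nb_start;  // 64-bit product, no overflow
    return C_ptr;
}
```

The element type of the output buffer in the real kernel depends on the inputs (e.g. BFloat16); `float` above is only a placeholder to keep the sketch self-contained.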
Committed by: PyTorch MergeBot
Parent: 657e5e9aa6
Commit: e469414b59
```diff
@@ -7811,6 +7811,34 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
         mean_err = ((res - ref).abs() / ref).mean()
         self.assertTrue(mean_err < 0.05)
 
+    @slowTest
+    @onlyCPU
+    def test__int8_mm_large_shape(self, device):
+        torch.manual_seed(1)
+        m = 65536
+        k = 64
+        n = 50400
+        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
+        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
+
+        def convert_weight_to_int8pack(b):
+            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
+                b, -128, 127, torch.int8
+            )
+            return b_int8pack, b_scales
+
+        def weight_int8pack_mm(a, b_int8pack, b_scales):
+            return torch._weight_int8pack_mm(
+                a, b_int8pack, b_scales
+            )
+
+        b_int8pack, b_scales = convert_weight_to_int8pack(b)
+        res = weight_int8pack_mm(a, b_int8pack, b_scales)
+        ref = torch.mm(a, b.transpose(0, 1))
+
+        mean_err = ((res - ref).abs() / ref).mean()
+        self.assertTrue(mean_err < 0.05)
+
     @onlyCPU
     @parametrize("m", [32, 35, 36, 40, 64])
     @parametrize("k", [32, 35, 36, 40, 64])
```
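As a rough sanity check on the shapes chosen by the new test (this arithmetic is not part of the patch): with m = 65536 and n = 50400, the output has 65536 * 50400 = 3,303,014,400 elements, which exceeds INT32_MAX (2,147,483,647), so the pre-fix 32-bit offset computation into the output buffer would wrap around.

```c++
#include <cstdint>
#include <cstdio>

int main() {
    // Shapes from the new test above: m = 65536, n = 50400.
    const int64_t m = 65536, n = 50400;
    const int64_t elements = m * n;  // 3,303,014,400 output elements
    std::printf("m*n = %lld, INT32_MAX = %d\n",
                static_cast<long long>(elements), INT32_MAX);
    // 3,303,014,400 > 2,147,483,647, so a 32-bit product would overflow here.
    return 0;
}
```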