[CPU] fix _weight_int8pack_mm with large output shape (#158341)

**Summary**
`_weight_int8pack_mm` on CPU may cause segmentation fault if output shape is large (i.e., M * N is large). It's because the kernel compute output buffer address by
```c++
auto* C_ptr = C_data + mb_start * N + nb_start;
```
where both `mb_start` and `N` are `int` and when they are large their product may overflow.
The solution is simple: declare these variables as `int64_t` so that the product won't overflow.

**Test plan**
```
pytest -sv test/test_linalg.py -k test__int8_mm_large_shape
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158341
Approved by: https://github.com/mingfeima, https://github.com/drisspg
This commit is contained in:
Xia, Weiwen
2025-07-29 01:14:50 +00:00
committed by PyTorch MergeBot
parent 657e5e9aa6
commit e469414b59
2 changed files with 42 additions and 14 deletions

View File

@ -7811,6 +7811,34 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
@slowTest
@onlyCPU
def test__int8_mm_large_shape(self, device):
torch.manual_seed(1)
m = 65536
k = 64
n = 50400
a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
def convert_weight_to_int8pack(b):
b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
b, -128, 127, torch.int8
)
return b_int8pack, b_scales
def weight_int8pack_mm(a, b_int8pack, b_scales):
return torch._weight_int8pack_mm(
a, b_int8pack, b_scales
)
b_int8pack, b_scales = convert_weight_to_int8pack(b)
res = weight_int8pack_mm(a, b_int8pack, b_scales)
ref = torch.mm(a, b.transpose(0, 1))
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
@onlyCPU
@parametrize("m", [32, 35, 36, 40, 64])
@parametrize("k", [32, 35, 36, 40, 64])