Revert "[CPU] fix _weight_int8pack_mm with large output shape (#158341)"

This reverts commit e469414b59ceeaae2860e36708de8852b9892776.

Reverted https://github.com/pytorch/pytorch/pull/158341 on behalf of https://github.com/albanD due to Breaks slowtest ([comment](https://github.com/pytorch/pytorch/pull/158341#issuecomment-3132641530))
This commit is contained in:
PyTorch MergeBot
2025-07-29 13:56:20 +00:00
parent 9d32aa9789
commit 61aa2ae20f
2 changed files with 14 additions and 42 deletions

View File

@ -367,27 +367,27 @@ void int8pack_mm_kernel_(
auto* C_data = C.data_ptr<T>();
const auto* S_data = scales.const_data_ptr<T>();
int64_t M = A.size(0);
int64_t N = B.size(0);
int64_t K = A.size(1);
int64_t lda = A.stride(0);
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 4;
int M = A.size(0);
int N = B.size(0);
int K = A.size(1);
int lda = A.stride(0);
constexpr int BLOCK_M = 4;
constexpr int BLOCK_N = 4;
const int64_t MB = (M + BLOCK_M - 1) / BLOCK_M;
const int64_t NB = (N + BLOCK_N - 1) / BLOCK_N;
const int MB = (M + BLOCK_M - 1) / BLOCK_M;
const int NB = (N + BLOCK_N - 1) / BLOCK_N;
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
int mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
for (const auto i : c10::irange(begin, end)) {
(void)i;
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(BLOCK_M, M - mb_start);
int nb_start = nb * BLOCK_N;
int nb_size = std::min(BLOCK_N, N - nb_start);
const auto* A_ptr = A_data + mb_start * lda;
const auto* B_ptr = B_data + nb_start * K;

View File

@ -7811,34 +7811,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
@slowTest
@onlyCPU
def test__int8_mm_large_shape(self, device):
torch.manual_seed(1)
m = 65536
k = 64
n = 50400
a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
def convert_weight_to_int8pack(b):
b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
b, -128, 127, torch.int8
)
return b_int8pack, b_scales
def weight_int8pack_mm(a, b_int8pack, b_scales):
return torch._weight_int8pack_mm(
a, b_int8pack, b_scales
)
b_int8pack, b_scales = convert_weight_to_int8pack(b)
res = weight_int8pack_mm(a, b_int8pack, b_scales)
ref = torch.mm(a, b.transpose(0, 1))
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
@onlyCPU
@parametrize("m", [32, 35, 36, 40, 64])
@parametrize("k", [32, 35, 36, 40, 64])