Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "[CPU] fix _weight_int8pack_mm with large output shape (#158341)"
This reverts commit e469414b59ceeaae2860e36708de8852b9892776. Reverted https://github.com/pytorch/pytorch/pull/158341 on behalf of https://github.com/albanD because it breaks a slow test ([comment](https://github.com/pytorch/pytorch/pull/158341#issuecomment-3132641530))
@@ -367,27 +367,27 @@ void int8pack_mm_kernel_(
   auto* C_data = C.data_ptr<T>();
   const auto* S_data = scales.const_data_ptr<T>();
 
-  int64_t M = A.size(0);
-  int64_t N = B.size(0);
-  int64_t K = A.size(1);
-  int64_t lda = A.stride(0);
-  constexpr int64_t BLOCK_M = 4;
-  constexpr int64_t BLOCK_N = 4;
+  int M = A.size(0);
+  int N = B.size(0);
+  int K = A.size(1);
+  int lda = A.stride(0);
+  constexpr int BLOCK_M = 4;
+  constexpr int BLOCK_N = 4;
 
-  const int64_t MB = (M + BLOCK_M - 1) / BLOCK_M;
-  const int64_t NB = (N + BLOCK_N - 1) / BLOCK_N;
+  const int MB = (M + BLOCK_M - 1) / BLOCK_M;
+  const int NB = (N + BLOCK_N - 1) / BLOCK_N;
 
-  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-    int64_t mb{0}, nb{0};
+  at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
+    int mb{0}, nb{0};
     data_index_init(begin, mb, MB, nb, NB);
 
     for (const auto i : c10::irange(begin, end)) {
       (void)i;
 
-      int64_t mb_start = mb * BLOCK_M;
-      int64_t mb_size = std::min(BLOCK_M, M - mb_start);
-      int64_t nb_start = nb * BLOCK_N;
-      int64_t nb_size = std::min(BLOCK_N, N - nb_start);
+      int mb_start = mb * BLOCK_M;
+      int mb_size = std::min(BLOCK_M, M - mb_start);
+      int nb_start = nb * BLOCK_N;
+      int nb_size = std::min(BLOCK_N, N - nb_start);
 
       const auto* A_ptr = A_data + mb_start * lda;
       const auto* B_ptr = B_data + nb_start * K;
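The hunk above restores plain 32-bit int for the block indices. For context on what the reverted PR targeted: with the output shape exercised by the removed test below (m = 65536, n = 50400), row-major offsets into the output exceed INT32_MAX. The following is a minimal sketch of that arithmetic only; the exact line where the kernel forms such an offset is not part of this hunk, so treat the offset expression as an assumption.

# Minimal sketch (not PyTorch code): offsets into an m x n output for the
# shapes from the removed test are too large for a 32-bit signed int.
INT32_MAX = 2**31 - 1

m, n, block_m = 65536, 50400, 4
mb_start = (m // block_m - 1) * block_m      # start row of the last block: 65532
offset = mb_start * n                        # assumed offset form: 3,302,812,800

print(offset > INT32_MAX)                    # True: does not fit in 32 bits
wrapped = (offset + 2**31) % 2**32 - 2**31   # what a wrapping int32 would hold
print(wrapped)                               # -992154496, a bogus negative offset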
@@ -7811,34 +7811,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
         mean_err = ((res - ref).abs() / ref).mean()
         self.assertTrue(mean_err < 0.05)
 
-    @slowTest
-    @onlyCPU
-    def test__int8_mm_large_shape(self, device):
-        torch.manual_seed(1)
-        m = 65536
-        k = 64
-        n = 50400
-        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
-        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
-
-        def convert_weight_to_int8pack(b):
-            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
-                b, -128, 127, torch.int8
-            )
-            return b_int8pack, b_scales
-
-        def weight_int8pack_mm(a, b_int8pack, b_scales):
-            return torch._weight_int8pack_mm(
-                a, b_int8pack, b_scales
-            )
-
-        b_int8pack, b_scales = convert_weight_to_int8pack(b)
-        res = weight_int8pack_mm(a, b_int8pack, b_scales)
-        ref = torch.mm(a, b.transpose(0, 1))
-
-        mean_err = ((res - ref).abs() / ref).mean()
-        self.assertTrue(mean_err < 0.05)
-
     @onlyCPU
     @parametrize("m", [32, 35, 36, 40, 64])
     @parametrize("k", [32, 35, 36, 40, 64])
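For reference, the removed test can be approximated as a standalone script. The quantization helper below is a simplified symmetric per-channel stand-in for the test suite's _dynamically_quantize_per_channel, which is not shown in this diff, so it is an assumption rather than the original helper; torch._weight_int8pack_mm, the shapes, the seed, and the 0.05 tolerance come from the hunk above. The buffers involved are large (the output alone is roughly 6.6 GB in bfloat16), so this needs a machine with plenty of RAM.

# Hedged standalone sketch of the removed test's flow, not the test itself.
import torch

torch.manual_seed(1)
m, k, n = 65536, 64, 50400
a = torch.rand((m, k), dtype=torch.bfloat16)
b = torch.rand((n, k), dtype=torch.bfloat16)

def quantize_per_channel_int8(w):
    # Simplified symmetric per-row (per output channel) int8 quantization;
    # stand-in for the suite's _dynamically_quantize_per_channel helper.
    scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / 127.0
    w_int8 = torch.round(w / scales).clamp(-128, 127).to(torch.int8)
    return w_int8, scales.squeeze(1)

b_int8, b_scales = quantize_per_channel_int8(b.float())
res = torch._weight_int8pack_mm(a, b_int8, b_scales.to(torch.bfloat16))
ref = torch.mm(a, b.transpose(0, 1))

mean_err = ((res - ref).abs() / ref).mean()
print(mean_err.item())   # the removed test asserted this stays below 0.05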