add int8 packed gemm support on CPU device (#118056)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118056
Approved by: https://github.com/mikekgfb
This commit is contained in:
mingfeima
2024-03-06 17:56:20 -08:00
committed by PyTorch MergeBot
parent e8e3049f57
commit b3065f6899
9 changed files with 446 additions and 0 deletions

View File

@ -38,6 +38,7 @@
#include <ATen/ops/_linalg_slogdet_native.h>
#include <ATen/ops/_unsafe_view.h>
#include <ATen/ops/_weight_int4pack_mm_native.h>
#include <ATen/ops/_weight_int8pack_mm_native.h>
#include <ATen/ops/addbmm_native.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addr.h>
@ -3394,6 +3395,7 @@ Tensor kron(const Tensor& self, const Tensor& other) {
// Weight Only Quantization Gemm
DEFINE_DISPATCH(weight_to_int4pack_stub);
DEFINE_DISPATCH(int4pack_mm_stub);
DEFINE_DISPATCH(int8pack_mm_stub);
Tensor _convert_weight_to_int4pack_cpu(
const Tensor& in,
@ -3472,5 +3474,37 @@ Tensor _weight_int4pack_mm_cpu(
return C;
}
Tensor _weight_int8pack_mm_cpu(
const Tensor& A,
const Tensor& B,
const Tensor& scales) {
auto M = A.size(0);
auto N = B.size(0);
auto K = A.size(1);
TORCH_CHECK(A.dtype() == kBFloat16,
__func__, " : expect A to be bfloat16 tensor.");
TORCH_CHECK(A.is_contiguous(),
__func__, " : expect A to be contiguous.");
TORCH_CHECK(A.dim() == 2,
__func__, " : expect A to be 2D tensor.");
TORCH_CHECK(B.dtype() == kChar,
__func__, " : expect B to be int8 tensor.");
TORCH_CHECK(B.is_contiguous(),
__func__, " : expect B to be contiguous.");
TORCH_CHECK(B.size(1) == K,
__func__, " : expect B.size(1) == ", K);
TORCH_CHECK(scales.dim() == 1 && scales.size(0) == N,
__func__, " : expect scales to be 1d tensor with size ", N);
auto C = at::empty({M, N}, A.options());
int8pack_mm_stub(kCPU, C, A, B, scales);
return C;
}
} // namespace native
} // namespace at

View File

@ -0,0 +1,302 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <c10/util/Unroll.h>
#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif
namespace at::native {
namespace {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
// A block : {BLOCK_M, BLOCK_K}, lda = K
// B block : {BLOCK_K, BLOCK_N}, ldb = K
// C block : {BLOCK_M, BLOCK_N}, ldc = N
//
// scales block: {BLOCK_N}
//
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
const BFloat16* RESTRICT A,
const int8_t* RESTRICT B,
const BFloat16* RESTRICT scales,
BFloat16* RESTRICT C,
int lda,
int ldb,
int ldc,
int K) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N;
const int PREFETCH_SIZE_K = 16 * 4;
__m512 va;
__m512 vb[COLS];
__m512 vc[ROWS * COLS];
__m512 scale[COLS];
auto load_scale = [&](int i) {
float ss = static_cast<float>(scales[i]);
scale[i] = _mm512_set1_ps(ss);
};
c10::ForcedUnroll<COLS>{}(load_scale);
auto loadc = [&](auto i) {
vc[i] = _mm512_setzero_ps();
};
c10::ForcedUnroll<ROWS * COLS>{}(loadc);
auto compute = [&](auto i, int k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
__m256i a16 = _mm256_load_si256((__m256i*)(A + row * lda + k));
if (k + PREFETCH_SIZE_K < K) {
_mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
}
vec::cvtbf16_fp32(a16, va);
}
if constexpr (row == 0) {
__m128i b8 = _mm_load_si128((__m128i*)(B + col * ldb + k));
if (k + PREFETCH_SIZE_K < K) {
_mm_prefetch(B + col * ldb + k + PREFETCH_SIZE_K, _MM_HINT_T0);
}
__m512i b32 = _mm512_cvtepi8_epi32(b8);
vb[col] = _mm512_cvtepi32_ps(b32);
vb[col] = _mm512_mul_ps(vb[col], scale[col]);
}
constexpr int idx = row * COLS + col;
vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
};
for (int k = 0; k < K; k += 16) {
c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
C[row * ldc + col] = static_cast<BFloat16>(_mm512_reduce_add_ps(vc[i]));
};
c10::ForcedUnroll<ROWS * COLS>{}(storec);
}
#elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
static inline float _mm256_reduce_add_ps(__m256& v) {
__m256 v1 = _mm256_permute2f128_ps(v, v, 0x1);
v = _mm256_add_ps(v, v1);
v1 = _mm256_shuffle_ps(v, v, 0x4E);
v = _mm256_add_ps(v, v1);
v1 = _mm256_shuffle_ps(v, v, 0xB1);
v = _mm256_add_ps(v, v1);
return _mm256_cvtss_f32(v);
}
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
const BFloat16* RESTRICT A,
const int8_t* RESTRICT B,
const BFloat16* RESTRICT scales,
BFloat16* RESTRICT C,
int lda,
int ldb,
int ldc,
int K) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N;
const int PREFETCH_SIZE_K = 16 * 4;
__m256 va;
__m256 vb[COLS];
__m256 vc[ROWS * COLS];
__m256 scale[COLS];
auto load_scale = [&](int i) {
float ss = static_cast<float>(scales[i]);
scale[i] = _mm256_set1_ps(ss);
};
c10::ForcedUnroll<COLS>{}(load_scale);
auto loadc = [&](auto i) {
vc[i] = _mm256_setzero_ps();
};
c10::ForcedUnroll<ROWS * COLS>{}(loadc);
auto compute = [&](auto i, int k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
__m128i a16 = _mm_load_si128((__m128i*)(A + row * lda + k));
if (k + PREFETCH_SIZE_K < K) {
_mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
}
vec::cvtbf16_fp32(a16, va);
}
if constexpr (row == 0) {
__m128i b8 = _mm_loadu_si64((__m128i*)(B + col * ldb + k));
if (k + PREFETCH_SIZE_K < K) {
_mm_prefetch(B + col * ldb + k + PREFETCH_SIZE_K, _MM_HINT_T0);
}
__m256i b32 = _mm256_cvtepi8_epi32(b8);
vb[col] = _mm256_cvtepi32_ps(b32);
vb[col] = _mm256_mul_ps(vb[col], scale[col]);
}
constexpr int idx = row * COLS + col;
vc[idx] = _mm256_fmadd_ps(va, vb[col], vc[idx]);
};
for (int k = 0; k < K; k += 8) {
c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
C[row * ldc + col] = static_cast<BFloat16>(_mm256_reduce_add_ps(vc[i]));
};
c10::ForcedUnroll<ROWS * COLS>{}(storec);
}
#else
// non-vectorized version
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
const BFloat16* RESTRICT A,
const int8_t* RESTRICT B,
const BFloat16* RESTRICT scales,
BFloat16* RESTRICT C,
int lda,
int ldb,
int ldc,
int K) {
for (const auto m : c10::irange(BLOCK_M)) {
for (const auto n : c10::irange(BLOCK_N)) {
float c_val = 0;
float scale_val = static_cast<float>(scales[n]);
for (const auto k : c10::irange(K)) {
float a_val = static_cast<float>(A[m * lda + k]);
float b_val = static_cast<float>(B[n * ldb + k]);
c_val += a_val * (b_val * scale_val);
}
C[m * ldc + n] = c_val;
}
}
}
#endif
#define LAUNCH_TINYGEMM_KERNEL(MB_SIZE, NB_SIZE) \
tinygemm_kernel<MB_SIZE, NB_SIZE>( \
A_ptr, B_ptr, S_ptr, C_ptr, \
K, K, N, K);
#define LAUNCH_TINYGEMM_NB_SIZE(MB_SIZE) \
switch (nb_size) { \
case 1: \
LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 1); \
break; \
case 2: \
LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 2); \
break; \
case 3: \
LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 3); \
break; \
case 4: \
LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 4); \
break; \
default: \
TORCH_CHECK(false, "Unsupported n block size: ", nb_size); \
break; \
}
void int8pack_mm_kernel(
const Tensor& C,
const Tensor& A,
const Tensor& B,
const Tensor& scales) {
const auto* A_data = A.data_ptr<BFloat16>();
const auto* B_data = B.data_ptr<int8_t>();
auto* C_data = C.data_ptr<BFloat16>();
const auto* S_data = scales.data_ptr<BFloat16>();
int M = A.size(0);
int N = B.size(0);
int K = A.size(1);
constexpr int BLOCK_M = 4;
constexpr int BLOCK_N = 4;
const int MB = (M + BLOCK_M - 1) / BLOCK_M;
const int NB = (N + BLOCK_N - 1) / BLOCK_N;
at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
int mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
for (const auto i : c10::irange(begin, end)) {
(void)i;
int mb_start = mb * BLOCK_M;
int mb_size = std::min(BLOCK_M, M - mb_start);
int nb_start = nb * BLOCK_N;
int nb_size = std::min(BLOCK_N, N - nb_start);
const auto* A_ptr = A_data + mb_start * K;
const auto* B_ptr = B_data + nb_start * K;
const auto* S_ptr = S_data + nb_start;
auto* C_ptr = C_data + mb_start * N + nb_start;
switch (mb_size) {
case 1:
LAUNCH_TINYGEMM_NB_SIZE(1);
break;
case 2:
LAUNCH_TINYGEMM_NB_SIZE(2);
break;
case 3:
LAUNCH_TINYGEMM_NB_SIZE(3);
break;
case 4:
LAUNCH_TINYGEMM_NB_SIZE(4);
break;
default:
TORCH_CHECK(false, "Unsupported m block size: ", mb_size);
}
// move to the next index
data_index_step(mb, MB, nb, NB);
}
});
}
} // anonymous namespace
ALSO_REGISTER_AVX512_DISPATCH(int8pack_mm_stub, &int8pack_mm_kernel);
} // at::native

View File

@ -7,8 +7,10 @@ namespace at::native {
using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&, int, int);
using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&, int, int);
using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub);
DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub);
DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub);
} // namespace at::native

View File

@ -4100,6 +4100,10 @@
CPU: _weight_int4pack_mm_cpu
CUDA: _weight_int4pack_mm_cuda
- func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor
dispatch:
CPU: _weight_int8pack_mm_cpu
- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
python_module: sparse

View File

@ -1155,6 +1155,7 @@ aten_native_source_codegen_list = [
"aten/src/ATen/native/cpu/batch_norm_kernel.cpp",
"aten/src/ATen/native/cpu/group_norm_kernel.cpp",
"aten/src/ATen/native/cpu/int4mm_kernel.cpp",
"aten/src/ATen/native/cpu/int8mm_kernel.cpp",
"aten/src/ATen/native/cpu/layer_norm_kernel.cpp",
"aten/src/ATen/native/cpu/AmpGradScalerKernels.cpp",
"aten/src/ATen/native/cpu/scaled_modified_bessel_k0.cpp",

View File

@ -606,6 +606,7 @@ aten::_values
aten::_values_copy
aten::_values_copy.out
aten::_weight_int4pack_mm
aten::_weight_int8pack_mm
aten::_weight_norm_interface_backward
aten::_weight_norm_interface_backward.out
aten::adaptive_avg_pool2d.out

View File

@ -5990,6 +5990,92 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
def _dynamically_quantize_per_channel(self, x, quant_min, quant_max, target_dtype):
# source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py
# default setup for affine quantization of activations
x_dtype = x.dtype
x = x.float()
eps = torch.finfo(torch.float32).eps
# get min and max
min_val, max_val = torch.aminmax(x, dim=1)
# calculate scales and zero_points based on min and max
# reference: https://fburl.com/code/srbiybme
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
device = min_val_neg.device
# reference: https://fburl.com/code/4wll53rk
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scales = max_val_pos / (float(quant_max - quant_min) / 2)
# ensure scales is the same dtype as the original tensor
scales = torch.clamp(scales, min=eps).to(x.dtype)
zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
# quantize based on qmin/qmax/scales/zp
x_div = x / scales.unsqueeze(-1)
x_round = torch.round(x_div)
x_zp = x_round + zero_points.unsqueeze(-1)
quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
return quant, scales.to(x_dtype), zero_points
@onlyCPU
@parametrize("m", [32, 64])
@parametrize("k", [32, 64])
@parametrize("n", [48, 64])
def test__int8_mm(self, device, m, k, n):
torch.manual_seed(1)
a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
def convert_weight_to_int8pack(b):
b_int8pack, b_scales, _ = self._dynamically_quantize_per_channel(
b, -128, 127, torch.int8
)
return b_int8pack, b_scales
def weight_int8pack_mm(a, b_int8pack, b_scales):
return torch._weight_int8pack_mm(
a, b_int8pack, b_scales
)
b_int8pack, b_scales = convert_weight_to_int8pack(b)
res = weight_int8pack_mm(a, b_int8pack, b_scales)
ref = torch.mm(a, b.transpose(0, 1))
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
@onlyCPU
@parametrize("m", [32, 64])
@parametrize("k", [32, 64])
@parametrize("n", [48, 64])
def test_compile_int8_mm(self, device, m, k, n):
if sys.version_info >= (3, 12):
self.skipTest("Dynamo is not supported on Python 3.12+")
torch.manual_seed(1)
a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
b_int8pack, b_scales, _ = self._dynamically_quantize_per_channel(
b, -128, 127, torch.int8
)
@torch.compile
def int8_mm(a, b_int8pack, b_scales):
return torch._weight_int8pack_mm(
a, b_int8pack, b_scales
)
res = int8_mm(a, b_int8pack, b_scales)
ref = torch.mm(a, b.transpose(0, 1))
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
@slowTest
@onlyNativeDeviceTypes
# bfloat16 doesn't have sufficient precision to pass this test

View File

@ -1487,6 +1487,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._use_cudnn_rnn_flatten_weight",
"torch._values_copy",
"torch._weight_int4pack_mm",
"torch._weight_int8pack_mm",
"torch._weight_norm_interface",
"torch._weight_norm",
"torch.abs_",

View File

@ -3455,6 +3455,21 @@ def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)
@register_meta([aten._weight_int8pack_mm])
def meta__weight_int8pack_mm(x, w, q_scales):
torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
torch._check(
x.dtype is torch.bfloat16,
lambda: f"expected x to be bf16, got {x.dtype}",
)
torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
torch._check(
w.dtype is torch.int8,
lambda: f"expected w to be int8, got {w.dtype}",
)
return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
@register_meta(aten._cdist_forward.default)
def meta_cdist_forward(x1, x2, p, compute_mode):
torch._check(