add int4 packed gemm support on CPU device (#117475)
This patch adds int4 packed gemm support on CPU; both `avx512` and `avx2` are supported. It is used to speed up https://github.com/pytorch-labs/gpt-fast.

Performance measured on Intel(R) Xeon(R) CPU Max 9480, single socket (56 cores):

* default: `16.13 sec total, 12.40 tokens/sec`
* WOQ int4 on avx512: `5.92 sec total, 33.79 tokens/sec`
* WOQ int4 on avx2: `6.90 sec total, 29.00 tokens/sec`

WOQ int4 is measured with the method described at https://github.com/pytorch-labs/gpt-fast?tab=readme-ov-file#int4-weight-only-quantization

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117475
Approved by: https://github.com/jgong5, https://github.com/malfet
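For orientation, here is a minimal sketch of how the new CPU path can be driven from Python, based on the `_convert_weight_to_int4pack` / `_weight_int4pack_mm` schemas and the CPU-side checks added in this patch. The random weight, scale and zero values are placeholders, not a real quantization recipe; variable names are illustrative.

import torch

n, k = 64, 128          # N divisible by 16, K divisible by 16 * inner_k_tiles
inner_k_tiles = 2
q_group = 32            # group size along K: one of 32, 64, 128, 256

a = torch.randn(4, k, dtype=torch.bfloat16)                  # bf16 activations, shape [M, K]
w_int4 = torch.randint(0, 16, (n, k), dtype=torch.int32)     # int4 values stored as int32, shape [N, K]
scales_and_zeros = torch.rand(k // q_group, n, 2, dtype=torch.bfloat16)  # [K / q_group, N, 2]

w_packed = torch._convert_weight_to_int4pack(w_int4, inner_k_tiles)
c = torch._weight_int4pack_mm(a, w_packed, q_group, scales_and_zeros)    # bf16 output, shape [M, N]
print(c.shape)  # torch.Size([4, 64])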
Committed by: PyTorch MergeBot
Parent: 54d92f2e37
Commit: a427d90411
@@ -30,6 +30,25 @@ static inline void cvtbf16_fp32(const __m512i& a, __m512& o1, __m512& o2) {
  cvtbf16_fp32(hi, o2);
}

static inline __m256i cvtfp32_bf16(const __m512& src) {
  __m512i value = _mm512_castps_si512(src);
  __m512i nan = _mm512_set1_epi32(0xffff);
  auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
  __m512i ones = _mm512_set1_epi32(0x1);
  __m512i vec_bias = _mm512_set1_epi32(0x7fff);
  // uint32_t lsb = (input >> 16) & 1;
  auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
  // uint32_t rounding_bias = 0x7fff + lsb;
  t_value = _mm512_add_epi32(t_value, vec_bias);
  // input += rounding_bias;
  t_value = _mm512_add_epi32(t_value, value);
  // input = input >> 16;
  t_value = _mm512_srli_epi32(t_value, 16);
  // Check NaN before converting back to bf16
  t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
  return _mm512_cvtusepi32_epi16(t_value);
}

static inline __m512i cvtfp32_bf16(const __m512& a, const __m512& b) {
  __m512i lo = _mm512_castps_si512(a);
  __m512i hi = _mm512_castps_si512(b);
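As a cross-check of the rounding sequence commented in the intrinsics above, here is a scalar sketch of the same round-to-nearest-even fp32 to bf16 conversion. It is illustrative only, uses NumPy bit tricks rather than the vector instructions, and the helper name is made up for this note.

import numpy as np

def fp32_to_bf16_bits(x):
    # Mirrors the commented steps above: lsb = (bits >> 16) & 1,
    # bits += 0x7fff + lsb, result = bits >> 16; NaN lanes map to 0xffff.
    bits = int(np.frombuffer(np.float32(x).tobytes(), dtype=np.uint32)[0])
    if np.isnan(x):
        return 0xffff
    lsb = (bits >> 16) & 1
    return ((bits + 0x7fff + lsb) & 0xffffffff) >> 16

print(hex(fp32_to_bf16_bits(1.0)))  # 0x3f80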
@@ -12,6 +12,7 @@
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/LinearAlgebra.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/ReduceOps.h>
@@ -29,12 +30,14 @@
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_compute_linear_combination_native.h>
#include <ATen/ops/_convert_weight_to_int4pack_native.h>
#include <ATen/ops/_linalg_check_errors.h>
#include <ATen/ops/_linalg_det.h>
#include <ATen/ops/_linalg_det_native.h>
#include <ATen/ops/_linalg_slogdet.h>
#include <ATen/ops/_linalg_slogdet_native.h>
#include <ATen/ops/_unsafe_view.h>
#include <ATen/ops/_weight_int4pack_mm_native.h>
#include <ATen/ops/addbmm_native.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addr.h>
@@ -3388,5 +3391,86 @@ Tensor kron(const Tensor& self, const Tensor& other) {
  return KronImpl(self, other).kron();
}

// Weight Only Quantization Gemm
DEFINE_DISPATCH(weight_to_int4pack_stub);
DEFINE_DISPATCH(int4pack_mm_stub);

Tensor _convert_weight_to_int4pack_cpu(
    const Tensor& in,
    int64_t innerKTiles) {

  TORCH_CHECK(in.dim() == 2,
      __func__, " : expect weight to be 2D tensor.");
  TORCH_CHECK(in.dtype() == at::kInt,
      __func__, " : expect weight to be kInt.");
  TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8,
      __func__, " : innerKTiles need to be 2, 4, or 8, got ", innerKTiles);

  auto weight = in.contiguous();
  auto N = weight.size(0);
  auto K = weight.size(1);

  // Create fake shapes for cpu. The meta registration in dynamo requires
  // operator has the same output shape for each device. So creating a fake
  // shape {N / 8, K / (16 * innerKTiles), 32, innerKTiles / 2}
  constexpr int64_t kNTileSize = 8;
  constexpr int64_t kKTileSize = 16;
  auto nTiles = (N + kNTileSize - 1) / kNTileSize;

  TORCH_CHECK(N % 16 == 0,
      __func__, " : expect N to be divisible by 16");
  const int64_t kSuperKTileSize = kKTileSize * innerKTiles;
  TORCH_CHECK(K % kSuperKTileSize == 0,
      __func__, " : expect K to be divisible by ", kSuperKTileSize);
  auto kSuperTiles = (K + kSuperKTileSize - 1) / kSuperKTileSize;

  auto weight_packed = at::empty(
      {nTiles, kSuperTiles, 32, innerKTiles / 2},
      at::TensorOptions().dtype(at::kInt));

  weight_to_int4pack_stub(kCPU, weight_packed, weight, N, K);
  return weight_packed;
}

Tensor _weight_int4pack_mm_cpu(
    const Tensor& A,
    const Tensor& B,
    int64_t qGroupSize,
    const Tensor& qScaleAndZeros) {

  constexpr int64_t kNTileSize = 8;

  auto M = A.size(0);
  auto N = B.size(0) * kNTileSize;
  auto K = A.size(1);

  TORCH_CHECK(A.dtype() == kBFloat16,
      __func__, " : expect A to be bfloat16 tensor.");
  TORCH_CHECK(A.is_contiguous(),
      __func__, " : expect A to be contiguous.");
  TORCH_CHECK(A.dim() == 2,
      __func__, " : expect A to be 2D tensor.");

  TORCH_CHECK(B.dtype() == kInt,
      __func__, " : expect B to be int32 tensor.");
  TORCH_CHECK(B.is_contiguous(),
      __func__, " : expect B to be contiguous.");
  TORCH_CHECK(B.dim() == 4,
      __func__, " : expect B to be 4D tensor.");

  TORCH_CHECK(qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128
      || qGroupSize == 256,
      __func__, ": expect qGroupSize to be 32, 64, 128 or 256, got ", qGroupSize);

  TORCH_CHECK(qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(1) == N
      && qScaleAndZeros.size(2) == 2,
      __func__, ": expect qScaleAndZeros to be 3d tensor with sizes [:, ", N, ", 2]");

  auto C = at::empty({M, N}, A.options());
  int4pack_mm_stub(kCPU, C, A, B, qGroupSize, qScaleAndZeros, N, K);

  return C;
}

} // namespace native
} // namespace at
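The fake CPU packing shape above is plain integer arithmetic on N, K and innerKTiles. A small illustrative check follows; the helper name is made up for this note and the divisibility asserts mirror the TORCH_CHECKs above.

def int4pack_fake_shape(n, k, inner_k_tiles):
    # {N / 8, K / (16 * innerKTiles), 32, innerKTiles / 2}
    assert n % 16 == 0 and k % (16 * inner_k_tiles) == 0
    return (n // 8, k // (16 * inner_k_tiles), 32, inner_k_tiles // 2)

print(int4pack_fake_shape(64, 128, 2))  # (8, 4, 32, 1)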
aten/src/ATen/native/cpu/int4mm_kernel.cpp (new file, 597 lines)
@@ -0,0 +1,597 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <c10/util/Unroll.h>

#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif

namespace at::native {

namespace {

inline bool is_block_start(int index, int BLOCK_SIZE) {
  return !(index & (BLOCK_SIZE - 1));
}

#if (defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
// convert 16x int4 to int8, handle 64 bits at a time
// used in avx2 and avx512
inline __m128i conver_int4_to_int8(const uint8_t* data) {
  __m128i tmp = _mm_loadu_si64((const __m128i*)data);
  __m128i bytes = _mm_cvtepu8_epi16(tmp);
  const __m128i lowMask = _mm_set1_epi8(0xF);
  __m128i high = _mm_andnot_si128(lowMask, bytes);
  __m128i low = _mm_and_si128(lowMask, bytes);
  high = _mm_slli_epi16(high, 4);
  bytes = _mm_or_si128(low, high);
  return bytes;
}
#endif

#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)

// A block : {BLOCK_M, BLOCK_K}, lda = K
// B block : {BLOCK_K, BLOCK_N / 2}, ldb = BLOCK_N / 2
// C block : {BLOCK_M, BLOCK_N}, ldc = N
//
// ScaleAndZeros block : {1, BLOCK_N, 2}
//
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const uint8_t* RESTRICT B,
    const BFloat16* RESTRICT ScaleAndZeros,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {

  constexpr int ROWS = BLOCK_M;
  constexpr int COLS = BLOCK_N / 16;

  const int PREFETCH_SIZE_K = 16 * 4;
  const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;

  // number of blocks on K
  const int KB = K / BLOCK_K;

  __m512 va;
  __m512 vb[COLS];
  __m512 vc[ROWS * COLS];
  __m512 scale[COLS];
  __m512 zero[COLS];

  // Lookup table to de-quantize int4 values to bf16.
  // Values are dequantized as truly int4 [-8, 7] range;
  //
  //   dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
  //
  static const __m512 lut = _mm512_set_ps(
      7.0f, 6.0f, 5.0f, 4.0f,
      3.0f, 2.0f, 1.0f, 0.0f,
      -1.0f, -2.0f, -3.0f, -4.0f,
      -5.0f, -6.0f, -7.0f, -8.0f);

  // index for transpose
  static const __m512i idx1 = _mm512_set_epi32(
      30, 28, 26, 24, 22, 20, 18, 16,
      14, 12, 10, 8, 6, 4, 2, 0);
  static const __m512i idx2 = _mm512_set_epi32(
      31, 29, 27, 25, 23, 21, 19, 17,
      15, 13, 11, 9, 7, 5, 3, 1);

  // load scale and zero point
  auto load_scale_and_zeros = [&](int i, int _kb) {
    // load 2x bfloat16 vector
    __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * ldc * 2 + 32 * i));
    if (_kb + PREFETCH_SIZE_KB < KB) {
      _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * ldc * 2 + 32 * i, _MM_HINT_T0);
    }

    // convert to 2x f32 vector
    __m512 a, b;
    vec::cvtbf16_fp32(t, a, b);

    // transpose scale_and_zero from {16, 2} to {2, 16}
    // inputs:
    //   a: {s0, z0, s1, z1, ..., s7, z7}
    //   b: {s8, z8, s9, z9, ..., s15, z15}
    // output:
    //   scale: {s0, s1, s2, ..., s15}
    //   zero:  {z0, z1, z2, ..., z15}
    scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
    zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
  };

  auto loadc = [&](auto i) {
    vc[i] = _mm512_setzero_ps();
  };
  c10::ForcedUnroll<ROWS * COLS>{}(loadc);

  auto compute = [&, COLS](auto i, int k) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;

    if constexpr (col == 0) {
      float aa = static_cast<float>(A[row * lda + k]);
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      va = _mm512_set1_ps(aa);
    }

    if constexpr (row == 0) {
      if constexpr (COLS == 4) {
        // when BLOCK_N = 64, handle each row at a time
        // to reduce de-quantize overhead.
        if constexpr (col == 0) {
          __m256i b4 = _mm256_load_si256((__m256i*)(B + k * ldb));
          if (k + PREFETCH_SIZE_K < K) {
            _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0);
          }

          __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4));
          vb[0] = _mm512_permutexvar_ps(b32, lut);
          vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
          vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
          vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]);

          b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1));
          vb[1] = _mm512_permutexvar_ps(b32, lut);
          vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
          vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
          vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]);
        }
      } else {
        __m128i b8 = conver_int4_to_int8(B + k * ldb + col * 8);
        __m512i b32 = _mm512_cvtepu8_epi32(b8);
        vb[col] = _mm512_permutexvar_ps(b32, lut);
        vb[col] = _mm512_fmadd_ps(vb[col], scale[col], zero[col]);
      }
    }

    constexpr int idx = row * COLS + col;
    vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
  };

  for (int k = 0, kb = 0; k < K; ++k) {
    if (is_block_start(k, BLOCK_K)) {
      c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
    }
    c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
  }

  // store to C
  auto storec = [&, COLS](auto i) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;
    if constexpr (COLS == 4) {
      // when BLOCK_N = 64, handle each row at a time
      // to reduce `cvtfp32_bf16` overhead.
      if constexpr (col == 0) {
        __m512i c01 = vec::cvtfp32_bf16(vc[row * 4 + 0], vc[row * 4 + 1]);
        __m512i c23 = vec::cvtfp32_bf16(vc[row * 4 + 2], vc[row * 4 + 3]);
        _mm512_storeu_si512((__m512i*)(C + row * ldc + 0 * 32), c01);
        _mm512_storeu_si512((__m512i*)(C + row * ldc + 1 * 32), c23);
      }
    } else {
      __m256i ci = vec::cvtfp32_bf16(vc[i]);
      _mm256_storeu_si256((__m256i*)(C + row * ldc + col * 16), ci);
    }
  };
  c10::ForcedUnroll<ROWS * COLS>{}(storec);
}

#elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const uint8_t* RESTRICT B,
    const BFloat16* RESTRICT ScaleAndZeros,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {

  constexpr int ROWS = BLOCK_M;
  constexpr int COLS = BLOCK_N / 8;

  const int PREFETCH_SIZE_K = 16 * 4;
  const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;

  // number of blocks on K
  const int KB = K / BLOCK_K;

  __m256 va;
  __m256 vb[COLS];
  __m256 vc[ROWS * COLS];
  __m256 scale[COLS];
  __m256 zero[COLS];

  static const __m256i idx1 = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);

  // offset to shift from range [0, 15] to [-8, 7]
  const __m256 offset = _mm256_set1_ps(-8.0f);

  // load scale and zero point
  auto load_scale_and_zeros = [&](int i, int _kb) {
    // load 2x bfloat16 vector
    __m256i t = _mm256_loadu_si256((__m256i*)(ScaleAndZeros + _kb * ldc * 2 + 16 * i));
    if (_kb + PREFETCH_SIZE_KB < KB) {
      _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * ldc * 2 + 16 * i, _MM_HINT_T0);
    }

    // convert to 2x f32 vector
    __m256 a, b;
    vec::cvtbf16_fp32(t, a, b);

    // transpose scale_and_zero from {8, 2} to {2, 8}
    // inputs:
    //   a: {s0, z0, s1, z1, s2, z2, s3, z3}
    //   b: {s4, z4, s5, z5, s6, z6, s7, z7}
    // output:
    //   scale: {s0, s1, s2, s3, s4, s5, s6, s7}
    //   zero:  {z0, z1, z2, z3, z4, z5, z6, z7}
    a = _mm256_permutevar8x32_ps(a, idx1);
    b = _mm256_permutevar8x32_ps(b, idx1);
    scale[i] = _mm256_permute2f128_ps(a, b, 0b0100000);
    zero[i] = _mm256_permute2f128_ps(a, b, 0b0110001);

    // zero = -8 * scale + zero
    zero[i] = _mm256_fmadd_ps(scale[i], offset, zero[i]);
  };

  auto loadc = [&](auto i) {
    vc[i] = _mm256_setzero_ps();
  };
  c10::ForcedUnroll<ROWS * COLS>{}(loadc);

  auto compute = [&, COLS](auto i, int k) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;

    if constexpr (col == 0) {
      float aa = static_cast<float>(A[row * lda + k]);
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      va = _mm256_set1_ps(aa);
    }

    if constexpr (row == 0) {
      if constexpr (COLS == 4) {
        // when BLOCK_N = 32, handle each row at a time
        if constexpr (col == 0) {
          __m256i mask = _mm256_set1_epi32(0xF);
          __m128i b4 = _mm_load_si128((__m128i*)(B + k * ldb));
          if (k + PREFETCH_SIZE_K < K) {
            _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0);
          }

          __m256i b32 = _mm256_cvtepu8_epi32(b4);
          vb[0] = _mm256_cvtepi32_ps(_mm256_and_si256(b32, mask));
          vb[0] = _mm256_fmadd_ps(vb[0], scale[0], zero[0]);
          vb[2] = _mm256_cvtepi32_ps(_mm256_srli_epi32(b32, 4));
          vb[2] = _mm256_fmadd_ps(vb[2], scale[2], zero[2]);

          b32 = _mm256_cvtepu8_epi32(_mm_shuffle_epi32(b4, _MM_SHUFFLE(3, 2, 3, 2)));
          vb[1] = _mm256_cvtepi32_ps(_mm256_and_si256(b32, mask));
          vb[1] = _mm256_fmadd_ps(vb[1], scale[1], zero[1]);
          vb[3] = _mm256_cvtepi32_ps(_mm256_srli_epi32(b32, 4));
          vb[3] = _mm256_fmadd_ps(vb[3], scale[3], zero[3]);
        }
      } else {
        if constexpr (col % 2 == 0) {
          // de-quantize per 64 bits (16x int4)
          __m128i b8 = conver_int4_to_int8(B + k * ldb + col * 4);
          __m128i b8_val0 = _mm_set1_epi64x(_mm_extract_epi64(b8, 0));
          __m128i b8_val1 = _mm_set1_epi64x(_mm_extract_epi64(b8, 1));
          if (k + PREFETCH_SIZE_K < K) {
            _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb + col * 4, _MM_HINT_T0);
          }

          vb[col] = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(b8_val0));
          vb[col] = _mm256_fmadd_ps(vb[col], scale[col], zero[col]);
          vb[col + 1] = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(b8_val1));
          vb[col + 1] = _mm256_fmadd_ps(vb[col + 1], scale[col + 1], zero[col + 1]);
        }
      }
    }

    constexpr int idx = row * COLS + col;
    vc[idx] = _mm256_fmadd_ps(va, vb[col], vc[idx]);
  };
  for (int k = 0, kb = 0; k < K; ++k) {
    if (is_block_start(k, BLOCK_K)) {
      c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
    }
    c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
  }

  // store to C
  auto storec = [&](auto i) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;
    if constexpr (col % 2 == 0) {
      __m256i ci = vec::cvtfp32_bf16(vc[row * COLS + col], vc[row * COLS + col + 1]);
      _mm256_storeu_si256((__m256i*)(C + row * ldc + col * 8), ci);
    }
  };
  c10::ForcedUnroll<ROWS * COLS>{}(storec);
}

#else

inline float convert_int4_to_float(uint8_t a, bool is_even) {
  static constexpr float lut[16] = {
      -8.0f, -7.0f, -6.0f, -5.0f,
      -4.0f, -3.0f, -2.0f, -1.0f,
      0.0f, 1.0f, 2.0f, 3.0f,
      4.0f, 5.0f, 6.0f, 7.0f
  };

  int index = is_even ? (a & 0x0F) : (a >> 4);
  return lut[index];
}

// non-vectorized version
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const uint8_t* RESTRICT B,
    const BFloat16* RESTRICT ScaleAndZeros,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {

  for (const auto m : c10::irange(BLOCK_M)) {
    for (const auto n : c10::irange(BLOCK_N)) {
      float c_val = 0;
      for (const auto k : c10::irange(K)) {
        int kb = k / BLOCK_K;
        const auto scale = static_cast<float>(ScaleAndZeros[kb * ldc * 2 + n * 2]);
        const auto zero = static_cast<float>(ScaleAndZeros[kb * ldc * 2 + n * 2 + 1]);
        const auto a_val = static_cast<float>(A[m * lda + k]);
        uint8_t b_pack = B[k * ldb + n / 2];
        // range [-8, 7]: B_val = (bf16(B_int4_val) * scale) + zero
        float b_val = convert_int4_to_float(b_pack, n % 2 == 0);
        b_val = b_val * scale + zero;

        c_val += a_val * b_val;
      }
      C[m * ldc + n] = c_val;
    }
  }
}

#endif

#define LAUNCH_TINYGEMM_KERNEL(MB_SIZE, NB_SIZE)                 \
  tinygemm_kernel<MB_SIZE, NB_SIZE>(                             \
      A_ptr, B_ptr, S_ptr, C_ptr,                                \
      K, NB_SIZE / 2, N, K, BLOCK_K);

#define LAUNCH_TINYGEMM_NB_SIZE(MB_SIZE)                         \
  switch (nb_size) {                                             \
    case 16:                                                     \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 16);                       \
      break;                                                     \
    case 32:                                                     \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 32);                       \
      break;                                                     \
    case 48:                                                     \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 48);                       \
      break;                                                     \
    case 64:                                                     \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 64);                       \
      break;                                                     \
    default:                                                     \
      TORCH_CHECK(false, "Unsupported n block size: ", nb_size); \
      break;                                                     \
  }

// NB: int4 weight pack (with BLOCK_N 64)
//   weight (int32): {N/64, 64, K}
//   packed (uint8): {N/64, K, 32}
//
// 1. avx512 packed format:
//   When N is 64, to do 256-bit unpacking at a time, we pack Lane0 with Lane2,
//   Lane1 with Lane3 since we can only do shift on a 128-bit basis.
//
//   weight:
//     [Lane0] N0...15:  {a00, a01, a02, ...}
//     [Lane1] N16...31: {a10, a11, a12, ...}
//     [Lane2] N32...47: {a20, a21, a22, ...}
//     [Lane3] N48...63: {a30, a31, a32, ...}
//
//   packed:
//     [Lane02] N0...31:  {a20|a00, a21|a01, a22|a02, ...}
//     [Lane13] N32...63: {a30|a10, a31|a11, a32|a12, ...}
//
//   Note: when N is 16, 32 or 48, pack with 64-bit format.
//
// 2. avx2 packed format:
//   When N is 32, to do 128-bit unpacking at a time.
//
//   weight:
//     [Lane0] N0...15:  { a0, a1, a2, ...}
//     [Lane1] N16...32: {a16, a17, a18, ...}
//
//   packed:
//     [Lane01] N0...32: {a16|a0, a17|a1, a18|a2, ...}
//
//   Note: When N is 16, pack with 64-bit format
//
// 3. non-vectorized packed format:
//   Do 64-bit unpacking at a time.
//
//   weight: {a0, a1, a2, a3, ..., a14, a15}
//   packed: {a1|a0, a3|a2, ..., a15|a14}
//
void weight_to_int4pack_kernel(
    const Tensor& weight_packed,
    const Tensor& weight,
    int N, int K) {

  auto weight_packed_data = reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
  const auto weight_data = weight.data_ptr<int32_t>();

  // 64 for avx512 and 32 for avx2/non-vectorized
  constexpr int BLOCK_N = vec::Vectorized<float>::size() * 4;
  const int NB = (N + BLOCK_N - 1) / BLOCK_N;

  // parallel on NB blocks
  at::parallel_for(0, NB, 0, [&](int begin, int end) {
    for (const auto i : c10::irange(begin, end)) {
      int nb_size = std::min(BLOCK_N, N - i * BLOCK_N);

      const int32_t* src = weight_data + i * BLOCK_N * K;
      uint8_t* dst = weight_packed_data + i * K * BLOCK_N / 2;
      for (const auto k : c10::irange(K)) {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
        if (nb_size == BLOCK_N) {
          for (const auto d : c10::irange(16)) {
            int32_t val0 = src[(d + 0) * K + k];
            int32_t val1 = src[(d + 16) * K + k];
            int32_t val2 = src[(d + 32) * K + k];
            int32_t val3 = src[(d + 48) * K + k];

            uint8_t packed02 = (((uint8_t)(val2) << 4)) | ((uint8_t)(val0));
            uint8_t packed13 = (((uint8_t)(val3) << 4)) | ((uint8_t)(val1));

            dst[k * 32 + d] = packed02;
            dst[k * 32 + 16 + d] = packed13;
          }
        } else {
          // for nb_size 16, 32, 48
          for (int n = 0; n < nb_size; n += 2) {
            int32_t val0 = src[n * K + k];
            int32_t val1 = src[n * K + K + k];

            uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
            dst[k * nb_size / 2 + n / 2] = packed;
          }
        }
#elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
        if (nb_size == BLOCK_N) {
          // for nb_size 32
          for (const auto d : c10::irange(16)) {
            int32_t val0 = src[(d + 0) * K + k];
            int32_t val1 = src[(d + 16) * K + k];

            uint8_t packed01 = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
            dst[k * 16 + d] = packed01;
          }
        } else {
          // for nb_size 16
          for (int n = 0; n < nb_size; n += 2) {
            int32_t val0 = src[n * K + k];
            int32_t val1 = src[n * K + K + k];

            uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
            dst[k * nb_size / 2 + n / 2] = packed;
          }
        }
#else
        for (int n = 0; n < nb_size; n += 2) {
          int32_t val0 = src[n * K + k];
          int32_t val1 = src[n * K + K + k];

          uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
          dst[k * nb_size / 2 + n / 2] = packed;
        }
#endif
      }
    }
  });
}

void int4pack_mm_kernel(
    const Tensor& C,
    const Tensor& A,
    const Tensor& B,
    int qGroupSize,
    const Tensor& qScaleAndZeros,
    int N, int K) {

  const auto* A_data = A.data_ptr<BFloat16>();
  const auto* B_data = reinterpret_cast<uint8_t*>(B.data_ptr());
  auto* C_data = C.data_ptr<BFloat16>();
  const auto* S_data = qScaleAndZeros.data_ptr<BFloat16>();

  int M = A.size(0);

  constexpr int BLOCK_M = 4;
  // 64 for avx512 and 32 for avx2/non-vectorized
  constexpr int BLOCK_N = vec::Vectorized<float>::size() * 4;
  // 32, 64, 128, 256
  const int BLOCK_K = qGroupSize;

  const int MB = (M + BLOCK_M - 1) / BLOCK_M;
  const int NB = (N + BLOCK_N - 1) / BLOCK_N;

  at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
    int mb{0}, nb{0};
    data_index_init(begin, mb, MB, nb, NB);

    for (const auto i : c10::irange(begin, end)) {
      (void)i;

      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(BLOCK_M, M - mb_start);
      int nb_start = nb * BLOCK_N;
      int nb_size = std::min(BLOCK_N, N - nb_start);

      const auto* A_ptr = A_data + mb_start * K;
      const auto* B_ptr = B_data + nb_start * K / 2;
      const auto* S_ptr = S_data + nb_start * 2;
      auto* C_ptr = C_data + mb_start * N + nb_start;

      switch (mb_size) {
        case 1:
          LAUNCH_TINYGEMM_NB_SIZE(1);
          break;
        case 2:
          LAUNCH_TINYGEMM_NB_SIZE(2);
          break;
        case 3:
          LAUNCH_TINYGEMM_NB_SIZE(3);
          break;
        case 4:
          LAUNCH_TINYGEMM_NB_SIZE(4);
          break;
        default:
          TORCH_CHECK(false, "Unsupported m block size: ", mb_size);
      }

      // move to the next index
      data_index_step(mb, MB, nb, NB);
    }
  });
}

} // anonymous namespace

ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel);
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel);

} // at::native
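For readers who want the math without the intrinsics, here is a small Python rendering of the non-vectorized path in the kernel above. The argument names and list-based layout are illustrative; the indexing follows the comments in the fallback tinygemm_kernel.

def int4_gemm_reference(A, B_packed, scales_and_zeros, block_k, N):
    # A: [M, K] floats; B_packed: [K, N/2] bytes holding two int4 values each;
    # scales_and_zeros: [K // block_k][N] pairs of (scale, zero), one group per block_k rows of K.
    M, K = len(A), len(A[0])
    C = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for k in range(K):
                kb = k // block_k
                scale, zero = scales_and_zeros[kb][n]
                byte = B_packed[k][n // 2]
                int4 = (byte & 0x0F) if n % 2 == 0 else (byte >> 4)
                b_val = (int4 - 8) * scale + zero   # map [0, 15] -> [-8, 7], then dequantize
                acc += A[m][k] * b_val
            C[m][n] = acc
    return C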
aten/src/ATen/native/cpu/int_mm_kernel.h (new file, 14 lines)
@@ -0,0 +1,14 @@
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

namespace at::native {

using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&, int, int);
using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&, int, int);

DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub);
DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub);

} // namespace at::native
@@ -4092,10 +4092,12 @@

- func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor
  dispatch:
    CPU: _convert_weight_to_int4pack_cpu
    CUDA: _convert_weight_to_int4pack_cuda

- func: _weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor
  dispatch:
    CPU: _weight_int4pack_mm_cpu
    CUDA: _weight_int4pack_mm_cuda

- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
@@ -1154,6 +1154,7 @@ aten_native_source_codegen_list = [
    "aten/src/ATen/native/cpu/airy_ai.cpp",
    "aten/src/ATen/native/cpu/batch_norm_kernel.cpp",
    "aten/src/ATen/native/cpu/group_norm_kernel.cpp",
    "aten/src/ATen/native/cpu/int4mm_kernel.cpp",
    "aten/src/ATen/native/cpu/layer_norm_kernel.cpp",
    "aten/src/ATen/native/cpu/AmpGradScalerKernels.cpp",
    "aten/src/ATen/native/cpu/scaled_modified_bessel_k0.cpp",
@@ -1,5 +1,6 @@
#pragma once
#include <c10/macros/Macros.h>
#include <type_traits>

// Utility to guarantee complete unrolling of a loop where the bounds are known
// at compile time. Various pragmas achieve similar effects, but are not as
@@ -11,18 +12,18 @@ namespace c10 {

template <int n>
struct ForcedUnroll {
  template <typename Func>
  C10_ALWAYS_INLINE void operator()(const Func& f) const {
    ForcedUnroll<n - 1>{}(f);
    f(n - 1);
  template <typename Func, typename... Args>
  C10_ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
    ForcedUnroll<n - 1>{}(f, args...);
    f(std::integral_constant<int, n - 1>{}, args...);
  }
};

template <>
struct ForcedUnroll<1> {
  template <typename Func>
  C10_ALWAYS_INLINE void operator()(const Func& f) const {
    f(0);
  template <typename Func, typename... Args>
  C10_ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
    f(std::integral_constant<int, 0>{}, args...);
  }
};
@@ -7,6 +7,7 @@ import unittest
import itertools
import warnings
import math
import sys
from math import inf, nan, isnan
import random
from random import randrange
@@ -5907,12 +5908,14 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:

    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @onlyCUDA
    @onlyNativeDeviceTypes
    @parametrize("m", [32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [48, 64])
    def test__int4_mm(self, device, m, k, n):
        if self.device_type == 'cuda' and not SM80OrLater:
            self.skipTest("requires SM80 or later")

        if TEST_WITH_ROCM:
            self.skipTest("_int4_mm not compiled for ROCM")

@@ -5947,15 +5950,20 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:

    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @onlyCUDA
    @onlyNativeDeviceTypes
    @parametrize("m", [32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [48, 64])
    def test_compile_int4_mm(self, device, m, k, n):
        if self.device_type == 'cuda' and not SM80OrLater:
            self.skipTest("requires SM80 or later")

        if TEST_WITH_ROCM:
            self.skipTest("_int4_mm not compiled for ROCM")

        if sys.version_info >= (3, 12):
            self.skipTest("Dynamo is not supported on Python 3.12+")

        q_group = 32
        inner_k_tiles = 2
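The compile test added here wraps the same op in `torch.compile`. Schematically, reusing the shapes from the sketch under the commit message (the wrapper name is illustrative, not the test's exact code):

import torch

# q_group fixed at 32, matching `q_group = 32` in the test above
compiled_int4_mm = torch.compile(
    lambda a, w_packed, scales_and_zeros:
        torch._weight_int4pack_mm(a, w_packed, 32, scales_and_zeros))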