[Kernel][Hardware][AMD] Bf16 mfma opt for ROCm skinny GEMMs (#17071)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
@@ -126,8 +126,8 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
   const int warp = threadIdx.x / WARP_SIZE;
   const int lane = threadIdx.x % WARP_SIZE;
   const int num_warps = blockDim.x / WARP_SIZE;
-  const int qwarpid = threadid / num_warps;
-  const int qthreadid = threadid % num_warps;
+  const int qwarpid = threadid / 16;
+  const int qthreadid = threadid % 16;
   float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK];
   scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w;
   float acc[NUM_A_ROWS_PER_BLOCK];
@@ -142,15 +142,13 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
       // rowA_elem4[i] holds 8 * half numbers seen as a single float4.
       rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]);
     }
-    colB_elem4x = bf4[threadid * 4 + 0];
-    colB_elem4y = bf4[threadid * 4 + 1];
-    colB_elem4z = bf4[threadid * 4 + 2];
-    colB_elem4w = bf4[threadid * 4 + 3];
   }

+  colB_elem4x = bf4[threadid * 4 + 0];
+  colB_elem4y = bf4[threadid * 4 + 1];
+  colB_elem4z = bf4[threadid * 4 + 2];
+  colB_elem4w = bf4[threadid * 4 + 3];

   scalar2_t Af2;
   [[maybe_unused]] scalar2_t Bf2;
   float2 S;

   auto Ah2ptr = reinterpret_cast<scalar2_t*>(&rowA_elem4);
@@ -193,12 +191,13 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,

   if (qwarpid < NUM_A_ROWS_PER_BLOCK) {
     acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f;
-    for (int mask = num_warps / 2; mask >= 1; mask /= 2) {
+#pragma unroll
+    for (int mask = 16 / 2; mask >= 1; mask /= 2) {
       acc[qwarpid] += __shfl_xor(acc[qwarpid], mask);
     }
-    float oval2 = __shfl_xor(acc[qwarpid], num_warps);
+    float oval2 = __shfl_xor(acc[qwarpid], 16);

-    if (lane % (num_warps * 2) == 0) {
+    if (lane % 32 == 0) {
       oval = __float22s2_rn<scalar2_t>(make_float2(acc[qwarpid], oval2));
       c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval;
     }
@@ -222,9 +221,10 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
   // NUM_THREADS needs to be a multiple of WARP_SIZE, as we are using warp
   // shuffle operations.
   const int NUM_THREADS =
-      K * 2 / 16 % WARP_SIZE == 0
-          ? K * 2 / 16
-          : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE);
+      max(rows_per_block * 16,
+          K * 2 / 16 % WARP_SIZE == 0
+              ? K * 2 / 16
+              : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE));

   int NUM_BLOCKS = M / rows_per_block;
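For reference, a hedged host-side sketch of the thread-count rule above (num_threads_for and kWarpSize are illustrative names, not vLLM's; WARP_SIZE is 64 on CDNA wavefronts): one thread is launched per 8 elements of K, rounded up to a whole wavefront so the warp-shuffle reduction never touches undefined lanes, and the new max() term keeps at least 16 lanes per row for rows_per_block rows.

#include <algorithm>

constexpr int kWarpSize = 64;  // assumption: CDNA wavefront size

constexpr int num_threads_for(int K, int rows_per_block) {
  const int per_k = K * 2 / 16;  // one thread per 8 K-elements
  const int rounded = per_k % kWarpSize == 0
                          ? per_k
                          : per_k + (kWarpSize - per_k % kWarpSize);
  return std::max(rows_per_block * 16, rounded);  // the new lower bound
}

static_assert(num_threads_for(4096, 4) % kWarpSize == 0);
static_assert(num_threads_for(64, 8) == 128);  // rows term raises the floor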
@@ -275,13 +275,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
+#if defined(__HIP__MI300__)
+  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
+#else
+  constexpr bool use_mfma = false;
+#endif
+
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
+  using half4 =
+      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
   union bigType {
     scalar_t h[A_CHUNK];
     float f[A_CHUNK / 2];
     float2 f2[A_CHUNK / 4];
     double d[A_CHUNK / 4];
+    half4 h4[A_CHUNK / 4];
     scalar8 h8;
   };
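A hedged, host-side illustration of the types introduced above (assumes a compiler with __bf16 and the vector_size extension, and fixes A_CHUNK to 8 for the sketch): the same 16-byte chunk is viewed as eight bf16 scalars, four packed-pair floats for the DOT2C path, two bf16x4 vectors for the new MFMA path, or one scalar8 as the non-temporal copy unit.

#include <cstdio>

constexpr int A_CHUNK = 8;  // elements handled per thread, as in the kernel

using half4 = __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16;
using scalar8 = __attribute__((__vector_size__(4 * sizeof(float)))) float;

union bigType {
  __bf16 h[A_CHUNK];      // element view
  float f[A_CHUNK / 2];   // bf16 pairs packed in 32 bits (v_dot2c operands)
  half4 h4[A_CHUNK / 4];  // 4-wide bf16 vectors (MFMA operands)
  scalar8 h8;             // the whole chunk as one load/store unit
};

int main() {
  bigType b{};
  for (int i = 0; i < A_CHUNK; i++) b.h[i] = static_cast<__bf16>(i + 1);
  // Same bytes, read back as the two MFMA operand vectors:
  printf("%g %g\n", static_cast<float>(b.h4[0][0]),
         static_cast<float>(b.h4[1][3]));  // prints: 1 8
}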
@@ -318,6 +327,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;

   float sum[N][YTILE];
+  scalar8 sum4[N][YTILE];

   //----------------------------------------------------
   // Each wave works on a single column of weight matrix.
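The comment above describes the work split; a hedged sketch of the indexing (function names invented here): wave w of workgroup g starts at column (g * WvPrGrp + w) * YTILE and, as the end of the main loop shows, strides by CuCount * WvPrGrp * YTILE.

// Illustrative only: the wave-to-column mapping used by the kernel.
int first_column(int block, int wave, int WvPrGrp, int YTILE) {
  return (block * WvPrGrp + wave) * YTILE;  // YTILE consecutive columns
}
int next_column(int m, int CuCount, int WvPrGrp, int YTILE) {
  return m + CuCount * WvPrGrp * YTILE;  // grid-stride over M
}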
@@ -343,7 +353,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // are being worked on by each wave.
   //----------------------------------------------------
   for (int i = 0; i < YTILE; i++)
-    for (int n = 0; n < N; n++) sum[n][i] = 0;
+    for (int n = 0; n < N; n++)
+      if constexpr (!use_mfma)
+        sum[n][i] = 0;
+      else
+        sum4[n][i] = {0, 0, 0, 0};

   bigType bigA[N][UNRL];
   bigType bigB[YTILE][UNRL];
@@ -374,24 +388,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
         if (k_ >= K) break;

         const scalar_t* B_ = &B[(m + 0) * K + k_];
-        bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
-        //----------------------------------------------------
-        // The following code with YTILE > 1 has to be deleted
-        //----------------------------------------------------
-        if constexpr (YTILE >= 2)
-          bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
-        if constexpr (YTILE >= 3)
-          bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
-        if constexpr (YTILE >= 4)
-          bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
-        if constexpr (YTILE >= 5)
-          bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
-        if constexpr (YTILE >= 6)
-          bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
-        if constexpr (YTILE >= 7)
-          bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
-        if constexpr (YTILE >= 8)
-          bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
+        for (int y = 0; y < YTILE; y++)
+          bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K])));
       }

       // Fetch activation matrix from either just LDS or from both LDS / memory
@@ -419,32 +417,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
 #pragma unroll
       for (uint32_t n = 0; n < N; n++) {
-#pragma unroll
-        for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-          DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b])
-          //----------------------------------------------------
-          // The following code with YTILE > 1
-          //----------------------------------------------------
-          if constexpr (YTILE >= 2) {
-            DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
-          }
-          if constexpr (YTILE >= 3) {
-            DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
-          }
-          if constexpr (YTILE >= 4) {
-            DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
-          }
-          if constexpr (YTILE >= 5) {
-            DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
-          }
-          if constexpr (YTILE >= 6) {
-            DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
-          }
-          if constexpr (YTILE >= 7) {
-            DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
-          }
-          if constexpr (YTILE >= 8) {
-            DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
-          }
+        for (int y = 0; y < YTILE; y++) {
+          if constexpr (!use_mfma)
+#pragma unroll
+            for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+              DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
+            }
+          else
+#pragma unroll
+            for (uint32_t b = 0; b < A_CHUNK / 4; b++)
+              sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
+                  bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
         }
       }
     }
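Isolating the intrinsic the new branch relies on, as a hedged sketch (not vLLM code; gfx94x assumed): __builtin_amdgcn_mfma_f32_4x4x4bf16_1k multiplies 4x4 bf16 tiles into fp32 accumulators across the 64-lane wavefront, and its accumulator operand is what lets the kernel above chain it over the A_CHUNK / 4 bf16x4 chunks.

#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>

using bf16x4 = __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16;
using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;

// Each lane supplies one bf16x4 slice of A and B and keeps one floatx4
// slice of the accumulator; cbsz/abid/blgp modifiers are 0, as in the kernel.
__global__ void mfma_demo(const bf16x4* a, const bf16x4* b, floatx4* c) {
  floatx4 acc = {0.f, 0.f, 0.f, 0.f};
  acc = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a[threadIdx.x],
                                               b[threadIdx.x], acc, 0, 0, 0);
  c[threadIdx.x] = acc;  // fp32 partials; the kernel reduces these via DPP
}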
@@ -453,37 +436,84 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     //----------------------------------------------------
     // Final reduction step using shuffle
     //----------------------------------------------------
-    for (int n = 0; n < N; n++) {
-      for (int y = 0; y < YTILE; y++) {
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-      }
-    }
-    if (threadIdx.x == 63) {
-      for (int n = 0; n < N; n++) {
-        for (int i = 0; i < YTILE; i++) {
-          // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
-          C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
-        }
-      }
-    }
+    if constexpr (!use_mfma) {
+      for (int n = 0; n < N; n++) {
+        for (int y = 0; y < YTILE; y++) {
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+        }
+      }
+
+      if (threadIdx.x == 63) {
+        for (int n = 0; n < N; n++) {
+          for (int i = 0; i < YTILE; i++) {
+            // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
+            C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
+          }
+        }
+      }
+    } else {
+#pragma unroll
+      for (int n = 0; n < N; n++) {
+#pragma unroll
+        for (int y = 0; y < YTILE; y++) {
+          // float accm1 = 0;
+          // for (int i=0; i<64; i++)
+          //   accm1 += __shfl(sum4[n][y][i%4], i);
+          float accm = sum4[n][y][0];
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+
+          sum4[n][y][0] = accm;
+        }
+      }
+      if (threadIdx.x == 63) {
+        for (int n = 0; n < N; n++) {
+          for (int i = 0; i < YTILE; i++) {
+            // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
+            C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
+          }
+        }
+      }
+    }

     m += CuCount * _WvPrGrp * YTILE;
   }
 }
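The commented-out loop kept in the new code documents what the DPP sequence computes; a hedged __shfl-based equivalent (wave_reduce is an invented name; wavefront of 64 assumed):

using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;

// Readable but slower restatement of the row_shl/row_bcast asm above:
// sum component (i % 4) of sum4 taken from lane i, over all 64 lanes.
__device__ float wave_reduce(floatx4 sum4_lane) {
  float accm = 0.f;
  for (int i = 0; i < 64; i++)
    accm += __shfl(sum4_lane[i % 4], i);
  return accm;  // the asm does this with cross-lane adds, no LDS traffic
}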
@@ -505,13 +535,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                  const scalar_t* __restrict__ A, scalar_t* C,
                  const int _WvPrGrp, const int CuCount) {
+#if defined(__HIP__MI300__)
+  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
+#else
+  constexpr bool use_mfma = false;
+#endif
+
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
+  using half4 =
+      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
   union bigType {
     scalar_t h[A_CHUNK];
     float f[A_CHUNK / 2];
     float2 f2[A_CHUNK / 4];
     double d[A_CHUNK / 4];
+    half4 h4[A_CHUNK / 4];
     scalar8 h8;
   };
@@ -573,6 +612,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   if (threadIdx.y >= _WvPrGrp) return;

   float sum[N][YTILE];
+  scalar8 sum4[N][YTILE];

   //----------------------------------------------------
   // Each wave works on a single column of weight matrix.
@@ -598,7 +638,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // are being worked on by each wave.
   //----------------------------------------------------
   for (int i = 0; i < YTILE; i++)
-    for (int n = 0; n < N; n++) sum[n][i] = 0;
+    for (int n = 0; n < N; n++)
+      if constexpr (!use_mfma)
+        sum[n][i] = 0;
+      else
+        sum4[n][i] = {0, 0, 0, 0};

   bigType bigA[N][UNRL];
   bigType bigB[YTILE][UNRL];
@@ -628,24 +672,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
         if (k_ >= K) break;

         const scalar_t* B_ = &B[(m + 0) * K + k_];
-        bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
-        //----------------------------------------------------
-        // The following code with YTILE > 1 has to be deleted
-        //----------------------------------------------------
-        if constexpr (YTILE >= 2)
-          bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
-        if constexpr (YTILE >= 3)
-          bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
-        if constexpr (YTILE >= 4)
-          bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
-        if constexpr (YTILE >= 5)
-          bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
-        if constexpr (YTILE >= 6)
-          bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
-        if constexpr (YTILE >= 7)
-          bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
-        if constexpr (YTILE >= 8)
-          bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
+        for (int b = 0; b < YTILE; b++)
+          bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
       }

       // Fetch activation matrix from either just LDS or from both LDS / memory
@@ -676,32 +704,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
         // Do the matrix multiplication of activation and weight matrix
         // - Remember the accumulation is happening for K-split of 64!
-#pragma unroll
-        for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-          DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
-          //----------------------------------------------------
-          // The following code with YTILE > 1
-          //----------------------------------------------------
-          if constexpr (YTILE >= 2) {
-            DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
-          }
-          if constexpr (YTILE >= 3) {
-            DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
-          }
-          if constexpr (YTILE >= 4) {
-            DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
-          }
-          if constexpr (YTILE >= 5) {
-            DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
-          }
-          if constexpr (YTILE >= 6) {
-            DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
-          }
-          if constexpr (YTILE >= 7) {
-            DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
-          }
-          if constexpr (YTILE >= 8) {
-            DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
-          }
+        for (int y = 0; y < YTILE; y++) {
+          if constexpr (!use_mfma)
+#pragma unroll
+            for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+              DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
+            }
+          else
+#pragma unroll
+            for (uint32_t b = 0; b < A_CHUNK / 4; b++)
+              sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
+                  bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
         }
       }
     }
@@ -710,34 +723,82 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     //----------------------------------------------------
     // Final reduction step using shuffle
     //----------------------------------------------------
-    for (int n = 0; n < N; n++) {
-      for (int y = 0; y < YTILE; y++) {
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-      }
-    }
-
-    if (threadIdx.x == 63) {
-      for (int n = 0; n < N; n++) {
-        for (int i = 0; i < YTILE; i++) {
-          if (commitColumn[i])
-            C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
-        }
-      }
-    }
+    if constexpr (!use_mfma) {
+      for (int n = 0; n < N; n++) {
+        for (int y = 0; y < YTILE; y++) {
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+        }
+      }
+
+      if (threadIdx.x == 63) {
+        for (int n = 0; n < N; n++) {
+          for (int i = 0; i < YTILE; i++) {
+            if (commitColumn[i])
+              C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
+          }
+        }
+      }
+    } else {
+#pragma unroll
+      for (int n = 0; n < N; n++) {
+#pragma unroll
+        for (int y = 0; y < YTILE; y++) {
+          // float accm1 = 0;
+          // for (int i=0; i<64; i++)
+          //   accm1 += __shfl(sum4[n][y][i%4], i);
+
+          float accm = sum4[n][y][0];
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+
+          sum4[n][y][0] = accm;
+        }
+      }
+      if (threadIdx.x == 63) {
+        for (int n = 0; n < N; n++) {
+          for (int i = 0; i < YTILE; i++) {
+            // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
+            C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
+          }
+        }
+      }
+    }
@@ -774,14 +835,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
+#if defined(__HIP__MI300__)
+  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
+#else
+  constexpr bool use_mfma = false;
+#endif
+
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;

+  using half4 =
+      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
   union bigType {
     scalar_t h[A_CHUNK];
     float f[A_CHUNK / 2];
     float2 f2[A_CHUNK / 4];
     double d[A_CHUNK / 4];
+    half4 h4[A_CHUNK / 4];
     scalar8 h8;
   };
@@ -857,6 +926,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   kFit = min(kFit, K);

   float sum[N][YTILE];
+  scalar8 sum4[N][YTILE];

   //----------------------------------------------------
   // Each wave works on a single column of weight matrix.
@@ -888,7 +958,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // are being worked on by each wave.
   //----------------------------------------------------
   for (int i = 0; i < YTILE; i++)
-    for (int n = 0; n < N; n++) sum[n][i] = 0;
+    for (int n = 0; n < N; n++)
+      if constexpr (!use_mfma)
+        sum[n][i] = 0;
+      else
+        sum4[n][i] = {0, 0, 0, 0};

   bigType bigA[N][UNRL];
   bigType bigB[YTILE][UNRL];
@@ -937,24 +1011,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
        if (k_ >= K) break;

         const scalar_t* B_ = &B[(m + 0) * K + k_];
-        bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
-        //----------------------------------------------------
-        // The following code with YTILE > 1 has to be deleted
-        //----------------------------------------------------
-        if constexpr (YTILE >= 2)
-          bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
-        if constexpr (YTILE >= 3)
-          bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
-        if constexpr (YTILE >= 4)
-          bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
-        if constexpr (YTILE >= 5)
-          bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
-        if constexpr (YTILE >= 6)
-          bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
-        if constexpr (YTILE >= 7)
-          bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
-        if constexpr (YTILE >= 8)
-          bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
+        for (int b = 0; b < YTILE; b++)
+          bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
       }

       // Fetch activation matrix from either just LDS or from both LDS / memory
@@ -989,32 +1047,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
         // Do the matrix multiplication of activation and weight matrix
         // - Remember the accumulation is happening for K-split of 64!
-#pragma unroll
-        for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-          DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
-          //----------------------------------------------------
-          // The following code with YTILE > 1
-          //----------------------------------------------------
-          if constexpr (YTILE >= 2) {
-            DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
-          }
-          if constexpr (YTILE >= 3) {
-            DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
-          }
-          if constexpr (YTILE >= 4) {
-            DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
-          }
-          if constexpr (YTILE >= 5) {
-            DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
-          }
-          if constexpr (YTILE >= 6) {
-            DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
-          }
-          if constexpr (YTILE >= 7) {
-            DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
-          }
-          if constexpr (YTILE >= 8) {
-            DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
-          }
+        for (int y = 0; y < YTILE; y++) {
+          if constexpr (!use_mfma)
+#pragma unroll
+            for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+              DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
+            }
+          else
+#pragma unroll
+            for (uint32_t b = 0; b < A_CHUNK / 4; b++)
+              sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
+                  bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
         }
       }
     }
@@ -1031,34 +1074,78 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     //----------------------------------------------------
     // Final reduction step using shuffle
     //----------------------------------------------------
-    for (int n = 0; n < N; n++) {
-      for (int y = 0; y < YTILE; y++) {
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
-            : "=v"(sum[n][y])
-            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-      }
-    }
-
-    if (threadIdx.x == 63) {
-      for (int n = 0; n < N; n++) {
-        for (int i = 0; i < YTILE; i++) {
-          if (commitColumn[i])
-            C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
-        }
-      }
-    }
+    if constexpr (!use_mfma) {
+      for (int n = 0; n < N; n++) {
+        for (int y = 0; y < YTILE; y++) {
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
+              : "=v"(sum[n][y])
+              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+        }
+      }
+
+      if (threadIdx.x == 63) {
+        for (int n = 0; n < N; n++) {
+          for (int i = 0; i < YTILE; i++) {
+            if (commitColumn[i])
+              C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
+          }
+        }
+      }
+    } else {
+#pragma unroll
+      for (int n = 0; n < N; n++) {
+#pragma unroll
+        for (int y = 0; y < YTILE; y++) {
+          float accm = sum4[n][y][0];
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+
+          sum4[n][y][0] = accm;
+        }
+      }
+      if (threadIdx.x == 63) {
+        for (int n = 0; n < N; n++) {
+          for (int i = 0; i < YTILE; i++) {
+            // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
+            C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
+          }
+        }
+      }
+    }
@ -8,7 +8,7 @@ from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float16]
|
||||
M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
|
||||
K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # k % 8 == 0
|
||||
K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192] # k % 8 == 0
|
||||
N = [1, 2, 3, 4]
|
||||
SEEDS = [0]
|
||||
|
||||
|
@ -84,7 +84,7 @@ def rocm_unquantized_gemm(x: torch.Tensor,
|
||||
m = weight.shape[0]
|
||||
cu_count = current_platform.get_cu_count()
|
||||
|
||||
if m > 8 and 0 < n < 4:
|
||||
if m > 8 and 0 < n <= 4:
|
||||
out = ops.wvSplitK(weight, x_view, cu_count)
|
||||
return out.view(*x.shape[:-1], weight.shape[0])
|
||||
elif m % 4 == 0 and n == 1 and k <= 8192:
|
||||
|
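The widened guard above sends one more shape class to the skinny kernel. As a hedged restatement of the rule (C++ for uniformity with the sketches above; use_wv_split_k is an invented name): batches of up to four tokens now qualify, where previously n == 4 fell through to the generic path.

// Skinny-GEMM dispatch predicate mirrored from the Python change.
bool use_wv_split_k(int m, int n) {
  return m > 8 && 0 < n && n <= 4;  // previously n < 4
}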
@ -104,6 +104,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
|
||||
return device_id
|
||||
|
||||
|
||||
@cache
|
||||
def on_mi250_mi300() -> bool:
|
||||
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
|
||||
|