[Kernel][Hardware][AMD] Bf16 mfma opt for ROCm skinny GEMMs (#17071)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
Authored by Hashem Hashemi on 2025-05-07 22:34:49 -07:00, committed by GitHub
parent 6930a41116
commit 5a499e70d5
4 changed files with 321 additions and 233 deletions

@@ -126,8 +126,8 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
const int warp = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
const int num_warps = blockDim.x / WARP_SIZE;
const int qwarpid = threadid / num_warps;
const int qthreadid = threadid % num_warps;
const int qwarpid = threadid / 16;
const int qthreadid = threadid % 16;
float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK];
scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w;
float acc[NUM_A_ROWS_PER_BLOCK];
@@ -142,15 +142,13 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
// rowA_elem4[i] holds 8 * half numbers seen as a single float4.
rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]);
}
colB_elem4x = bf4[threadid * 4 + 0];
colB_elem4y = bf4[threadid * 4 + 1];
colB_elem4z = bf4[threadid * 4 + 2];
colB_elem4w = bf4[threadid * 4 + 3];
}
colB_elem4x = bf4[threadid * 4 + 0];
colB_elem4y = bf4[threadid * 4 + 1];
colB_elem4z = bf4[threadid * 4 + 2];
colB_elem4w = bf4[threadid * 4 + 3];
scalar2_t Af2;
[[maybe_unused]] scalar2_t Bf2;
float2 S;
auto Ah2ptr = reinterpret_cast<scalar2_t*>(&rowA_elem4);
@@ -193,12 +191,13 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
if (qwarpid < NUM_A_ROWS_PER_BLOCK) {
acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f;
for (int mask = num_warps / 2; mask >= 1; mask /= 2) {
#pragma unroll
for (int mask = 16 / 2; mask >= 1; mask /= 2) {
acc[qwarpid] += __shfl_xor(acc[qwarpid], mask);
}
float oval2 = __shfl_xor(acc[qwarpid], num_warps);
float oval2 = __shfl_xor(acc[qwarpid], 16);
if (lane % (num_warps * 2) == 0) {
if (lane % 32 == 0) {
oval = __float22s2_rn<scalar2_t>(make_float2(acc[qwarpid], oval2));
c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval;
}
@@ -222,9 +221,10 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
// NUM_THREADS needs to be a multiple of WARP_SIZE, as we are using warp shuffle
// operations.
const int NUM_THREADS =
K * 2 / 16 % WARP_SIZE == 0
? K * 2 / 16
: K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE);
max(rows_per_block * 16,
K * 2 / 16 % WARP_SIZE == 0
? K * 2 / 16
: K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE));
int NUM_BLOCKS = M / rows_per_block;
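
The new NUM_THREADS formula keeps the old round-up of K * 2 / 16 to a whole wavefront, but additionally floors the count at rows_per_block * 16 so that, with the reduction above now indexing by a fixed 16 (qwarpid = threadid / 16), every A row still gets a full 16-thread reduction group. A rough Python mirror of the arithmetic (hypothetical helper; assumes WARP_SIZE = 64 as on CDNA):

WARP_SIZE = 64

def llmm1_num_threads(K: int, rows_per_block: int) -> int:
    base = K * 2 // 16                        # one thread per 8 scalar_t elements of a row
    if base % WARP_SIZE != 0:
        base += WARP_SIZE - base % WARP_SIZE  # round up to a whole wavefront
    return max(rows_per_block * 16, base)     # at least 16 threads per A row

print(llmm1_num_threads(6144, 8))  # 768: K*2/16 is already wavefront-aligned here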
@@ -275,13 +275,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
#if defined(__HIP__MI300__)
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
constexpr bool use_mfma = false;
#endif
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
using half4 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
half4 h4[A_CHUNK / 4];
scalar8 h8;
};
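
Adding half4 to bigType is what lets one non-temporal load feed both compute paths: the A_CHUNK scalars a lane fetches can be read back as packed 2-element floats for DOT2C or as two 4-wide bf16 vectors for the MFMA builtin. A loose numpy analogue of that aliasing (illustrative only; assumes A_CHUNK == 8 and bf16 carried as raw uint16):

import numpy as np

chunk = np.arange(8, dtype=np.uint16)     # stand-in for h[A_CHUNK]: 8 raw bf16 values
dot2c_view = chunk.view(np.uint32)        # like f[A_CHUNK / 2]: four packed pairs
mfma_view = chunk.reshape(2, 4)           # like h4[A_CHUNK / 4]: two half4 operands
print(dot2c_view.shape, mfma_view.shape)  # (4,) (2, 4)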
@@ -318,6 +327,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
@@ -343,7 +353,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = 0;
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
@@ -374,24 +388,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
for (int y = 0; y < YTILE; y++)
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
@@ -419,32 +417,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b])
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
}
}
}
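
Both branches consume the same 8-element chunk per lane, only in different granularities: the generic path issues four packed 2-wide dot products (DOT2C) into a scalar float, while the bf16 path issues two v_mfma_f32_4x4x4bf16_1k ops on half4 slices into a float4 accumulator, with fp32 accumulation either way. A toy numpy check of the chunking equivalence this relies on (fp32 stand-ins for the packed inputs; the MFMA's cross-lane 4x4 block mapping is deliberately omitted):

import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(8).astype(np.float32)  # one lane's A_CHUNK of activations
b = rng.standard_normal(8).astype(np.float32)  # the matching chunk of a weight row

two_wide = sum(a[i] * b[i] + a[i + 1] * b[i + 1] for i in range(0, 8, 2))
four_wide = sum(np.dot(a[i:i + 4], b[i:i + 4]) for i in range(0, 8, 4))
assert np.isclose(two_wide, four_wide)  # same fp32 dot product, different step size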
@@ -453,37 +436,84 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
} else {
#pragma unroll
for (int n = 0; n < N; n++) {
#pragma unroll
for (int y = 0; y < YTILE; y++) {
// float accm1 = 0;
// for (int i=0; i<64; i++)
// accm1 += __shfl(sum4[n][y][i%4], i);
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
}
}
}
}
m += CuCount * _WvPrGrp * YTILE;
}
}
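
The commented-out __shfl loop documents the contract of the new epilogue: because the 4x4x4 MFMA scatters partial products across the wave, the value to store is the sum of accumulator slot (i % 4) taken from each lane i, and the row_shl/row_shr/row_bcast DPP chain computes exactly that without touching shared memory, leaving the total in lane 63. A conceptual numpy stand-in for that contract:

import numpy as np

partials = np.random.default_rng(1).standard_normal((64, 4)).astype(np.float32)
total = sum(partials[i, i % 4] for i in range(64))  # what lane 63 writes to C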
@@ -505,13 +535,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
#if defined(__HIP__MI300__)
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
constexpr bool use_mfma = false;
#endif
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
using half4 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
half4 h4[A_CHUNK / 4];
scalar8 h8;
};
@@ -573,6 +612,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.y >= _WvPrGrp) return;
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
@@ -598,7 +638,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = 0;
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
@@ -628,24 +672,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
for (int b = 0; b < YTILE; b++)
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
@@ -676,32 +704,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
}
}
}
@@ -710,34 +723,82 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
} else {
#pragma unroll
for (int n = 0; n < N; n++) {
#pragma unroll
for (int y = 0; y < YTILE; y++) {
// float accm1 = 0;
// for (int i=0; i<64; i++)
// accm1 += __shfl(sum4[n][y][i%4], i);
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
}
}
}
}
@@ -774,14 +835,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
#if defined(__HIP__MI300__)
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
constexpr bool use_mfma = false;
#endif
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
using half4 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
half4 h4[A_CHUNK / 4];
scalar8 h8;
};
@@ -857,6 +926,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
kFit = min(kFit, K);
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
@@ -888,7 +958,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = 0;
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
@@ -937,24 +1011,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
for (int b = 0; b < YTILE; b++)
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
@@ -989,32 +1047,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
}
}
}
@@ -1031,34 +1074,78 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
} else {
#pragma unroll
for (int n = 0; n < N; n++) {
#pragma unroll
for (int y = 0; y < YTILE; y++) {
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
}
}
}
}

@@ -8,7 +8,7 @@ from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float16]
M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # k % 8 == 0
K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192] # k % 8 == 0
N = [1, 2, 3, 4]
SEEDS = [0]
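
The added K = 6144 exercises a non-power-of-two (but still 8-divisible) inner dimension. A minimal pytest-style sweep over these lists (function name and body are placeholders, not from this diff):

import pytest

@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("k", K)
@pytest.mark.parametrize("n", N)
def test_rocm_skinny_gemm_shapes(dtype, k, n):
    ...  # would compare the skinny-GEMM kernels against a torch.matmul reference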

@@ -84,7 +84,7 @@ def rocm_unquantized_gemm(x: torch.Tensor,
m = weight.shape[0]
cu_count = current_platform.get_cu_count()
if m > 8 and 0 < n < 4:
if m > 8 and 0 < n <= 4:
out = ops.wvSplitK(weight, x_view, cu_count)
return out.view(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192:
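
Relaxing the guard from n < 4 to n <= 4 routes batch-size-4 activations through wvSplitK as well, matching the N = [1, 2, 3, 4] sweep in the test file above. For reference, the result the op must match (shapes as used here: x_view is (n, k), weight is (m, k)) is ordinary linear-layer semantics; a hypothetical reference:

import torch

def wvSplitK_reference(weight: torch.Tensor, x_view: torch.Tensor) -> torch.Tensor:
    # out[n_, m_] = sum over k_ of x_view[n_, k_] * weight[m_, k_]
    return x_view @ weight.t()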

@@ -104,6 +104,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
return device_id
@cache
def on_mi250_mi300() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
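
on_mi250_mi300() is cached per process and keys off gcnArchName, returning True on MI250 (gfx90a) and MI300 (gfx942) parts; note that the bf16 MFMA branch itself is additionally compile-time gated by __HIP__MI300__, so gfx90a still takes the DOT2C path. A hedged usage sketch (the gating function below is illustrative, not from this diff):

import torch

def can_use_skinny_gemm(n: int, dtype: torch.dtype) -> bool:
    # Hypothetical caller-side gate combining the arch check with the
    # shape/dtype conditions used by rocm_unquantized_gemm above.
    return on_mi250_mi300() and 0 < n <= 4 and dtype in (torch.bfloat16,
                                                         torch.float16)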