mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[ROCm] Add skinny gemm bias support for dtypes fp16,bf16,fp8 (#24988)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com>
This commit is contained in:
@ -5,11 +5,14 @@
|
||||
torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
|
||||
const int64_t rows_per_block);
|
||||
|
||||
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
|
||||
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const c10::optional<at::Tensor>& in_bias,
|
||||
const int64_t CuCount);
|
||||
|
||||
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
|
||||
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
|
||||
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const c10::optional<at::Tensor>& in_bias, at::Tensor& out_c,
|
||||
const at::Tensor& scale_a, const at::Tensor& scale_b,
|
||||
const int64_t CuCount);
|
||||
|
||||
void paged_attention(
|
||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||
|
@ -292,8 +292,9 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
wvSplitK_hf_sml_(const int K, const int M, const int Bx, const int By,
|
||||
const scalar_t* B, const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
constexpr int max_lds_len = LDS_SIZE / 2;
|
||||
#if defined(__HIP__MI3XX__)
|
||||
@ -484,7 +485,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (threadIdx.x == 63) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int i = 0; i < YTILE; i++) {
|
||||
// if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
if (BIAS)
|
||||
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS)
|
||||
sum[n][i] +=
|
||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
}
|
||||
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
|
||||
}
|
||||
}
|
||||
@ -529,7 +537,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (threadIdx.x == 63) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int i = 0; i < YTILE; i++) {
|
||||
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
|
||||
if (BIAS)
|
||||
sum4[n][i][0] +=
|
||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
|
||||
}
|
||||
}
|
||||
@ -541,8 +551,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
__global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
|
||||
const int By, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
@ -553,8 +565,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitK_hf_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
wvSplitK_hf_(const int K, const int M, const int Bx, const int By,
|
||||
const scalar_t* B, const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
constexpr int max_lds_len = LDS_SIZE / 2;
|
||||
#if defined(__HIP__MI3XX__)
|
||||
@ -772,8 +785,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (threadIdx.x == 63) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int i = 0; i < YTILE; i++) {
|
||||
if (commitColumn[i])
|
||||
if (commitColumn[i]) {
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
if (BIAS)
|
||||
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS)
|
||||
sum[n][i] +=
|
||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
}
|
||||
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -818,8 +840,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (threadIdx.x == 63) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int i = 0; i < YTILE; i++) {
|
||||
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
|
||||
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
|
||||
if (commitColumn[i]) {
|
||||
if (BIAS)
|
||||
sum4[n][i][0] +=
|
||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -842,8 +868,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
__global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
|
||||
const int By, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
@ -854,8 +882,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
wvSplitK_hf_big_(const int K, const int M, const int Bx, const int By,
|
||||
const scalar_t* B, const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
constexpr int max_lds_len = LDS_SIZE / 2;
|
||||
#if defined(__HIP__MI3XX__)
|
||||
@ -1124,8 +1153,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (threadIdx.x == 63) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int i = 0; i < YTILE; i++) {
|
||||
if (commitColumn[i])
|
||||
if (commitColumn[i]) {
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
if (BIAS)
|
||||
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS)
|
||||
sum[n][i] +=
|
||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
}
|
||||
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1166,8 +1204,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (threadIdx.x == 63) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int i = 0; i < YTILE; i++) {
|
||||
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
|
||||
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
|
||||
if (commitColumn[i]) {
|
||||
if (BIAS)
|
||||
sum4[n][i][0] +=
|
||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1190,8 +1232,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
__global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
|
||||
const int By, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
@ -1226,11 +1270,20 @@ int mindiv(int N, int div1, int div2) {
|
||||
return rtn;
|
||||
}
|
||||
|
||||
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
|
||||
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const c10::optional<at::Tensor>& in_bias,
|
||||
const int64_t CuCount) {
|
||||
auto M_in = in_a.size(0);
|
||||
auto K_in = in_a.size(1);
|
||||
auto N_in = in_b.size(0);
|
||||
auto Bx_in =
|
||||
(in_bias.has_value() && in_bias->numel() > 0)
|
||||
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
|
||||
: 1;
|
||||
auto By_in = (in_bias.has_value() && in_bias->numel() > 0 &&
|
||||
in_bias->sizes().size() == 2)
|
||||
? in_bias->size(0)
|
||||
: 1;
|
||||
|
||||
TORCH_CHECK(in_a.dtype() == in_b.dtype());
|
||||
TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0");
|
||||
@ -1254,18 +1307,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
|
||||
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
|
||||
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
|
||||
CuCount); \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, c, __wvPrGrp, CuCount); \
|
||||
} else if (K_in * N_in <= max_lds_len * 1.2) { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
|
||||
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
|
||||
CuCount); \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, c, __wvPrGrp, CuCount); \
|
||||
} else { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
|
||||
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
|
||||
CuCount); \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, c, __wvPrGrp, CuCount); \
|
||||
} \
|
||||
}
|
||||
|
||||
@ -1273,6 +1326,10 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
|
||||
using fptype = typename scalar<scalar_t>::type;
|
||||
fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr());
|
||||
const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
|
||||
const fptype* biasf4 =
|
||||
(in_bias.has_value() && in_bias->numel() > 0)
|
||||
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
|
||||
: nullptr;
|
||||
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
|
||||
switch (N_in) {
|
||||
case 1:
|
||||
@ -1300,8 +1357,9 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A, scalar_t* C,
|
||||
wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int Bx,
|
||||
const int By, const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const float* __restrict__ s_A,
|
||||
const float* __restrict__ s_B, const int _WvPrGrp,
|
||||
const int CuCount) {
|
||||
@ -1453,7 +1511,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (threadIdx.x == 0) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0] * sA * sB);
|
||||
if (y + m >= M) break; // To avoid mem access fault.
|
||||
sum[n][y][0] *= sA * sB;
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] +=
|
||||
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
}
|
||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]); // * sA * sB);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1465,7 +1533,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
|
||||
const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
const int Bx, const int By, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS,
|
||||
scalar_t* C, const float* __restrict__ s_A,
|
||||
const float* __restrict__ s_B,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
@ -1477,8 +1547,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A, scalar_t* C,
|
||||
wvSplitKQ_hf_(const int K, const int Kp, const int M, const int Bx,
|
||||
const int By, const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const float* __restrict__ s_A, const float* __restrict__ s_B,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
constexpr int max_lds_len = LDS_SIZE;
|
||||
@ -1626,7 +1697,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
if (y + m >= M) break; // To avoid mem access fault.
|
||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0] * sA * sB);
|
||||
sum[n][y][0] *= sA * sB;
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] +=
|
||||
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
}
|
||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1638,16 +1718,19 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
|
||||
const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
scalar_t* C, const float* __restrict__ s_A,
|
||||
const int Bx, const int By, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const float* __restrict__ s_A,
|
||||
const float* __restrict__ s_B, const int _WvPrGrp,
|
||||
const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
|
||||
|
||||
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
|
||||
at::Tensor& scale_a, at::Tensor& scale_b,
|
||||
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const c10::optional<at::Tensor>& in_bias, at::Tensor& out_c,
|
||||
const at::Tensor& scale_a, const at::Tensor& scale_b,
|
||||
const int64_t CuCount) {
|
||||
static c10::ScalarType kFp8Type = is_fp8_ocp()
|
||||
? c10::ScalarType::Float8_e4m3fn
|
||||
@ -1656,6 +1739,15 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
|
||||
auto K_in = in_a.size(1);
|
||||
auto N_in = in_b.size(0);
|
||||
auto Kp_in = in_a.stride(0);
|
||||
auto Bx_in =
|
||||
(in_bias.has_value() && in_bias->numel() > 0)
|
||||
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
|
||||
: 1;
|
||||
auto By_in = (in_bias.has_value() && in_bias->numel() > 0 &&
|
||||
in_bias->sizes().size() == 2)
|
||||
? in_bias->size(0)
|
||||
: 1;
|
||||
|
||||
TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0");
|
||||
TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type);
|
||||
TORCH_CHECK(out_c.dtype() == torch::kFloat16 ||
|
||||
@ -1673,13 +1765,15 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
|
||||
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
|
||||
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
|
||||
s_a, s_b, __wvPrGrp, CuCount); \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
|
||||
b_ptr, bias_ptr, c_ptr, s_a, s_b, \
|
||||
__wvPrGrp, CuCount); \
|
||||
} else { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
|
||||
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
|
||||
s_a, s_b, __wvPrGrp, CuCount); \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
|
||||
b_ptr, bias_ptr, c_ptr, s_a, s_b, \
|
||||
__wvPrGrp, CuCount); \
|
||||
} \
|
||||
}
|
||||
|
||||
@ -1691,6 +1785,9 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
|
||||
VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] {
|
||||
auto a_ptr = in_a.data_ptr<fp8_t>();
|
||||
auto b_ptr = in_b.data_ptr<fp8_t>();
|
||||
auto bias_ptr = (in_bias.has_value() && in_bias->numel() > 0)
|
||||
? reinterpret_cast<fptype*>(in_bias->data_ptr())
|
||||
: nullptr;
|
||||
switch (N_in) {
|
||||
case 1:
|
||||
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1)
|
||||
|
@ -22,13 +22,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
|
||||
|
||||
// Custom gemm op for skinny matrix-matrix multiplication
|
||||
rocm_ops.def(
|
||||
"wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> "
|
||||
"wvSplitK(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount) -> "
|
||||
"Tensor");
|
||||
rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);
|
||||
|
||||
// wvSplitK for fp8
|
||||
rocm_ops.def(
|
||||
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, "
|
||||
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, "
|
||||
"Tensor scale_a, "
|
||||
" Tensor scale_b, int CuCount) -> ()");
|
||||
rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);
|
||||
|
||||
|
Reference in New Issue
Block a user