mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Attention] Tune CUTLASS MLA num_splits (#26846)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -125,32 +125,37 @@ public:
|
||||
}
|
||||
|
||||
/// Chooses `args.split_kv` when the caller has not already fixed it.
///
/// If `args.split_kv >= 1` the value is treated as caller-provided and left
/// untouched. Otherwise an upper bound on the number of splits is taken from
/// a tuned table keyed on B and K (both read from `args.problem_shape`), and
/// that bound is then reduced so the splits along K fall into an integer
/// number of waves for the `args.hw_info.sm_count` SMs available.
///
/// @param args Kernel arguments; only `args.split_kv` is written.
static void set_split_kv (KernelArguments& args) {
  if (args.split_kv >= 1) return;

  auto [H, K, D, B] = args.problem_shape;
  int sm_count = args.hw_info.sm_count;

  // K in units of 1024, so the thresholds below read as "K >= 16 * 1024" etc.
  float seq_length_k = static_cast<float>(K) / 1024.0f;

  // Tuned table: smaller B combined with larger K allows more splits.
  // Anything that matches no row keeps the default of a single split.
  int max_splits = 1;
  if (B <= 4 && seq_length_k >= 16) {
    max_splits = 16;
  }
  else if (B <= 8 && seq_length_k >= 4) {
    max_splits = 8;
  }
  else if ((B <= 16 && seq_length_k >= 8) ||
           (B == 48 && seq_length_k >= 32)) {
    max_splits = 4;
  }
  else if ((B <= 32 && seq_length_k >= 16) ||
           (B == 96 && seq_length_k >= 16)) {
    max_splits = 2;
  }

  // Wave-aware scheduling: ensure integer number of waves in K dimension.
  // NOTE(review): B is used as a divisor here; presumably callers guarantee
  // B >= 1 -- confirm.
  int sms_per_batch = max(1, sm_count / B);
  // Never ask for more splits than the SMs available to one batch entry.
  int split_heur = min(max_splits, sms_per_batch);
  // Number of passes needed to cover max_splits at split_heur per pass ...
  int k_waves = ceil_div(max_splits, split_heur);
  // ... and the per-pass split count that spreads max_splits evenly over them.
  int split_wave_aware = ceil_div(max_splits, k_waves);
  args.split_kv = split_wave_aware;
}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
|
Reference in New Issue
Block a user