[Bugfix] [B200] cutlass_mla - ensure kv_split == 1 for batch size > 1 (#25509)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Alexander Matveev
2025-09-23 19:57:55 -04:00
committed by yewentao256
parent d562c2ea09
commit faae7a7eab

View File

@ -135,10 +135,10 @@ public:
max_splits = min(16, max_splits);
// TODO: This avoids a hang when the batch size larger than 1 and
// there is more than 4 kv_splits.
// there is more than 1 kv_splits.
// Discuss with NVIDIA how this can be fixed.
if (B > 1) {
max_splits = min(2, max_splits);
max_splits = min(1, max_splits);
}
// printf(" max_splits = %d\n", max_splits);