[ROCM][AMD][TRITON] Halving warps number for fw_prefill to reduce spilling (#12713)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
This commit is contained in:
Aleksandr Malyshev
2025-02-04 19:58:22 -08:00
committed by GitHub
parent b3a0d01e45
commit 64862d106e

View File

@ -11,7 +11,7 @@ from vllm.platforms import current_platform
# Static kernels parameters
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
NUM_WARPS = 8
NUM_WARPS = 4 if current_platform.is_rocm() else 8
# To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5)