Mirror of https://github.com/vllm-project/vllm.git
[Bugfix] Use PIECEWISE cudagraphs on Blackwell if max_model_len > 131072 (#27114)
Signed-off-by: mgoin <mgoin64@gmail.com>
@@ -350,26 +350,48 @@ class VllmConfig:
                 self.compilation_config.cudagraph_mode = (
                     CUDAGraphMode.FULL_AND_PIECEWISE
                 )
-
-                # pooling models and encoder-decoder models
-                # do not support full cudagraphs
-                if self.model_config is not None and (
-                    self.model_config.pooler_config is not None
-                    or self.model_config.is_encoder_decoder
-                ):
-                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-
-                # decode context parallel do not support full cudagraphs now.
-                if self.parallel_config.decode_context_parallel_size > 1:
-                    logger.warning(
-                        "Decode context parallel (DCP) is enabled, which is "
-                        "incompatible with full CUDA graphs. Set "
-                        "cudagraph_mode to PIECEWISE."
-                    )
-                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
             else:
                 self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
+        # if cudagraph_mode has full cudagraphs, we need to check support
+        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+            # decode context parallel does not support full cudagraphs
+            if self.parallel_config.decode_context_parallel_size > 1:
+                logger.warning_once(
+                    "Decode context parallel (DCP) is enabled, which is "
+                    "incompatible with full CUDA graphs. "
+                    "Overriding cudagraph_mode to PIECEWISE."
+                )
+                self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+            elif self.model_config is not None:
+                if self.model_config.pooler_config is not None:
+                    logger.warning_once(
+                        "Pooling models do not support full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
+                    )
+                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+                elif self.model_config.is_encoder_decoder:
+                    logger.warning_once(
+                        "Encoder-decoder models do not support full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
+                    )
+                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+                elif (
+                    current_platform.is_cuda()
+                    and current_platform.is_device_capability(100)
+                    and self.model_config.max_model_len > 131072
+                    and not self.model_config.use_mla
+                ):
+                    # Refer to vllm/utils/flashinfer.py::use_trtllm_attention()
+                    logger.warning_once(
+                        "NVIDIA Blackwell TRTLLM attention cannot support "
+                        "max_model_len >= 131072 (found "
+                        f"{self.model_config.max_model_len}), causing dynamic "
+                        "dispatching that breaks full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
+                    )
+                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+
         # disable cudagraph when enforce eager execution
         if self.model_config is not None and self.model_config.enforce_eager:
             logger.info("Cudagraph is disabled under eager mode")
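The last elif branch is the actual bugfix: per the comment and warning text in the diff, on Blackwell GPUs (compute capability 10.0) the TRTLLM attention path cannot handle sequences beyond 131072 tokens, so a larger max_model_len causes dynamic kernel dispatch at runtime that a fully captured CUDA graph cannot replay. Below is a minimal standalone sketch of that check, not vLLM's implementation: only the condition itself (capability 100, max_model_len > 131072, no MLA) is taken from the diff, while ModelInfo and needs_piecewise_fallback are hypothetical names used for illustration.

    # Minimal standalone sketch of the check added in this commit; not vLLM's code.
    from dataclasses import dataclass


    @dataclass
    class ModelInfo:
        max_model_len: int  # maximum context length the server is configured for
        use_mla: bool       # whether the model uses Multi-head Latent Attention


    def needs_piecewise_fallback(is_cuda: bool, device_capability: int,
                                 model: ModelInfo) -> bool:
        """Return True when full cudagraphs should be downgraded to PIECEWISE.

        On Blackwell, TRTLLM attention is limited to 131072-token sequences, so a
        larger max_model_len forces dynamic kernel dispatch at runtime, which a
        fully captured CUDA graph cannot replay.
        """
        is_blackwell = is_cuda and device_capability == 100
        return is_blackwell and model.max_model_len > 131072 and not model.use_mla


    if __name__ == "__main__":
        # Non-MLA model with a 200k context on Blackwell: fall back to PIECEWISE.
        assert needs_piecewise_fallback(True, 100, ModelInfo(200_000, False))
        # Same model on Hopper (capability 90): full cudagraphs stay enabled.
        assert not needs_piecewise_fallback(True, 90, ModelInfo(200_000, False))

In practice this means a non-MLA model served on Blackwell with a context window above 131072 tokens now logs the warning from the diff and runs with cudagraph_mode = PIECEWISE instead of FULL_AND_PIECEWISE.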