From 2b22290ce01b033cc692e7dce159d74a43f6f2c5 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 20 Mar 2025 15:24:16 -0700
Subject: [PATCH] [V1] Add flag to disable cascade attention (#15243)

Signed-off-by: Woosuk Kwon
---
 vllm/config.py                     |  2 ++
 vllm/engine/arg_utils.py           | 12 ++++++++++++
 vllm/v1/worker/gpu_model_runner.py | 14 +++++++++-----
 3 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 74d7d9b17c..1f7147f7cf 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -246,6 +246,7 @@ class ModelConfig:
         max_seq_len_to_capture: Optional[int] = None,
         max_logprobs: int = 20,
         disable_sliding_window: bool = False,
+        disable_cascade_attn: bool = False,
         skip_tokenizer_init: bool = False,
         served_model_name: Optional[Union[str, list[str]]] = None,
         limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
@@ -322,6 +323,7 @@ class ModelConfig:
         self.max_seq_len_to_capture = max_seq_len_to_capture
         self.max_logprobs = max_logprobs
         self.disable_sliding_window = disable_sliding_window
+        self.disable_cascade_attn = disable_cascade_attn
         self.skip_tokenizer_init = skip_tokenizer_init
         self.enable_sleep_mode = enable_sleep_mode
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 43bf2fe8f0..5015f1d684 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -120,6 +120,7 @@ class EngineArgs:
     block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None
     disable_sliding_window: bool = False
+    disable_cascade_attn: bool = False
     use_v2_block_manager: bool = True
     swap_space: float = 4  # GiB
     cpu_offload_gb: float = 0  # GiB
@@ -1096,6 +1097,16 @@ class EngineArgs:
             "using. This is used to parse the reasoning content into OpenAI "
             "API format. Required for ``--enable-reasoning``.")

+        parser.add_argument(
+            "--disable-cascade-attn",
+            action="store_true",
+            default=False,
+            help="Disable cascade attention for V1. While cascade attention "
+            "does not change the mathematical correctness, disabling it "
+            "could be useful for preventing potential numerical issues. "
+            "Note that even if this is set to False, cascade attention will "
+            "only be used when the heuristic tells that it's beneficial.")
+
         return parser

     @classmethod
@@ -1141,6 +1152,7 @@ class EngineArgs:
             max_seq_len_to_capture=self.max_seq_len_to_capture,
             max_logprobs=self.max_logprobs,
             disable_sliding_window=self.disable_sliding_window,
+            disable_cascade_attn=self.disable_cascade_attn,
             skip_tokenizer_init=self.skip_tokenizer_init,
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 7faf666dc6..c82bcec25d 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -127,6 +127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
+        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

         # Multi-modal data support
         self.input_registry = INPUT_REGISTRY
@@ -565,11 +566,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)

-        # Prepare for cascade attention if needed.
-        common_prefix_len = self._compute_cascade_attn_prefix_len(
-            num_scheduled_tokens,
-            scheduler_output.num_common_prefix_blocks,
-        )
+        # Prepare for cascade attention if enabled & beneficial.
+        common_prefix_len = 0
+        if self.cascade_attn_enabled:
+            common_prefix_len = self._compute_cascade_attn_prefix_len(
+                num_scheduled_tokens,
+                scheduler_output.num_common_prefix_blocks,
+            )
+
         attn_metadata = self.attn_metadata_builder.build(
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
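
Usage note (not part of the patch): the sketch below shows how the new option could be exercised from the offline API. It assumes that the LLM entry point forwards extra keyword arguments to EngineArgs (so disable_cascade_attn reaches ModelConfig as wired above), that the model name is only a placeholder, and that on vLLM builds of this era the V1 engine may still need to be selected via VLLM_USE_V1=1.

    from vllm import LLM, SamplingParams

    # Sketch only: placeholder model name; disable_cascade_attn=True is
    # assumed to flow LLM -> EngineArgs -> ModelConfig as in this patch.
    # Depending on the build, export VLLM_USE_V1=1 to select the V1 engine.
    llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
              disable_cascade_attn=True)

    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)

The online-serving equivalent is the CLI flag added by this patch, e.g. passing --disable-cascade-attn when launching the server.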