[v1][attention] Support Hybrid Allocator + FlashInfer (#21412)

commit 555e7225bc (parent 0e36abf993)
Author: Chen Zhang <zhangch99@outlook.com>
Date:   2025-07-29 18:45:29 -07:00
Committer: GitHub

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
16 changed files with 85 additions and 57 deletions


@@ -198,7 +198,8 @@ class MockAttentionLayer:
 def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
-                          vllm_config, device: torch.device,
+                          layer_names: list[str], vllm_config,
+                          device: torch.device,
                           common_attn_metadata: CommonAttentionMetadata,
                           query: torch.Tensor, key: torch.Tensor,
                           value: torch.Tensor,
@@ -211,31 +212,33 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
     if backend == _Backend.FLASHINFER_VLLM_V1:
         import unittest.mock
-        from vllm.v1.attention.backends.flashinfer import PerLayerParameters
+        from vllm.v1.attention.backends.utils import PerLayerParameters
-        def mock_get_per_layer_parameters(vllm_config, impl_cls):
+        def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
             # Return mock parameters for a single layer
             head_size = vllm_config.model_config.get_head_size()
             return {
-                "mock_layer":
+                layer_name:
                 PerLayerParameters(
                     window_left=-1,  # No sliding window
                     logits_soft_cap=0.0,  # No soft cap
                     sm_scale=1.0 / (head_size**0.5)  # Standard scale
                 )
+                for layer_name in layer_names
             }
         with unittest.mock.patch(
                 'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
                 mock_get_per_layer_parameters):
-            builder = builder_cls(kv_cache_spec, vllm_config, device)
+            builder = builder_cls(kv_cache_spec, layer_names, vllm_config,
+                                  device)
             attn_metadata = builder.build(
                 common_prefix_len=0,
                 common_attn_metadata=common_attn_metadata,
             )
     else:
         # Build metadata
-        builder = builder_cls(kv_cache_spec, vllm_config, device)
+        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
         attn_metadata = builder.build(
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
@@ -427,8 +430,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
         set_kv_cache_layout("HND")
     backend_output = run_attention_backend(backend_name, kv_cache_spec,
-                                           vllm_config, device,
-                                           common_attn_metadata,
+                                           ["placeholder"], vllm_config,
+                                           device, common_attn_metadata,
                                            query_vllm, key_vllm,
                                            value_vllm,
                                            kv_cache_for_backend)


@@ -305,6 +305,7 @@ def test_propose(num_speculative_tokens):
         _Backend.FLASH_ATTN_VLLM_V1)
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        layer_names=proposer.attn_layer_names,
         vllm_config=proposer.vllm_config,
         device=device,
     )


@@ -745,7 +745,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
     layer_4 = "model.layers.4.mixer"
     layer_5 = "model.layers.5.mixer"
-    with set_current_vllm_config(vllm_config):
+    with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
+        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
         hf_config = vllm_config.model_config.hf_config
         fwd_context = {}
         for key in [layer_0, layer_1]:


@@ -740,8 +740,8 @@ class ModelConfig:
                                      isinstance(sliding_window, list))
         if not self.disable_sliding_window and has_interleaved_attention:
-            if (backend :=
-                    envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
+            if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
+                                         ) in ("XFORMERS", "FLASHINFER"):
                 sliding_window_len_min = get_min_sliding_window(
                     self.hf_text_config.sliding_window)
@@ -5065,13 +5065,29 @@ def assert_hashable(text):
 T = TypeVar("T")
-def get_layers_from_vllm_config(vllm_config: VllmConfig,
-                                layer_type: type[T]) -> dict[str, T]:
+def get_layers_from_vllm_config(
+        vllm_config: VllmConfig,
+        layer_type: type[T],
+        layer_names: Optional[list[str]] = None) -> dict[str, T]:
+    """
+    Get layers from the vLLM config.
+
+    Args:
+        vllm_config: The vLLM config.
+        layer_type: The type of the layer to get.
+        layer_names: The names of the layers to get. If None, return all layers.
+    """
+    if layer_names is None:
+        layer_names = list(
+            vllm_config.compilation_config.static_forward_context.keys())
+    forward_context = vllm_config.compilation_config.static_forward_context
     return {
-        layer_name: layer
-        for layer_name, layer in
-        vllm_config.compilation_config.static_forward_context.items()
-        if isinstance(layer, layer_type)
+        layer_name: forward_context[layer_name]
+        for layer_name in layer_names
+        if isinstance(forward_context[layer_name], layer_type)
     }
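A minimal sketch of how the extended helper can be used; the wrapper function and layer-name list are illustrative, not part of this commit:

    from vllm.attention import Attention
    from vllm.config import VllmConfig, get_layers_from_vllm_config

    def attention_layers_for_group(vllm_config: VllmConfig,
                                   group_layer_names: list[str]):
        # With layer_names given, only the layers belonging to this KV cache
        # group are returned; passing None keeps the old "all layers" behavior.
        return get_layers_from_vllm_config(vllm_config, Attention,
                                           group_layer_names)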


@@ -315,8 +315,8 @@ class TorchSDPAMetadata(AttentionMetadata):
 class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device) -> None:
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device) -> None:
         self.kv_cache_spec = kv_cache_spec
         self.vllm_config = vllm_config
         self.scheduler_config = vllm_config.scheduler_config


@@ -148,8 +148,8 @@ class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config


@@ -21,10 +21,9 @@ from vllm.platforms import current_platform
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
-    get_kv_cache_layout, get_per_layer_parameters,
-    infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
-    split_decodes_and_prefills)
+    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
+    get_per_layer_parameters, infer_global_hyperparameters,
+    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
 if TYPE_CHECKING:
@@ -219,8 +218,8 @@ class FlashInferMetadata:
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.device = device
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
@@ -228,7 +227,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self._cascade_wrapper = None  # Wrapper for cascade attention
         # Global hyperparameters shared by all attention layers
-        self.global_hyperparameters: Optional[PerLayerParameters] = None
+        self.global_hyperparameters = infer_global_hyperparameters(
+            get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
         self.vllm_config = vllm_config
         self.cache_config = vllm_config.cache_config
@@ -283,10 +283,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     def _plan(self, num_prefills: int, num_decodes: int,
               attn_metadata: FlashInferMetadata):
-        if self.global_hyperparameters is None:
-            self.global_hyperparameters = infer_global_hyperparameters(
-                get_per_layer_parameters(self.vllm_config, FlashInferImpl))
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
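Net effect of the three FlashInfer hunks above: global hyperparameters are resolved once in `__init__` from the layers the builder actually owns, instead of lazily in `_plan` over every layer in the model. A rough sketch of the equivalent computation, assuming the builder serves the full-attention KV cache group of a hybrid model (the helper function name is illustrative):

    from vllm.config import VllmConfig
    from vllm.v1.attention.backends.flashinfer import FlashInferImpl
    from vllm.v1.attention.backends.utils import (
        get_per_layer_parameters, infer_global_hyperparameters)

    def resolve_flashinfer_hyperparameters(vllm_config: VllmConfig,
                                           layer_names: list[str]):
        # Only the named layers are scanned, so layers from other KV cache
        # groups (e.g. Mamba or sliding-window layers in a hybrid model)
        # cannot leak into the shared FlashInfer plan parameters.
        return infer_global_hyperparameters(
            get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))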


@@ -258,8 +258,8 @@ class FlexAttentionMetadata:
 class FlexAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlexAttentionMetadata]):
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
         self.cache_config = vllm_config.cache_config


@@ -87,8 +87,8 @@ class Mamba2AttentionMetadata:
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         assert isinstance(kv_cache_spec, MambaSpec)
         self.kv_cache_spec = kv_cache_spec
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()


@@ -406,6 +406,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     def __init__(self,
                  kv_cache_spec: AttentionSpec,
+                 layer_names: list[str],
                  vllm_config: VllmConfig,
                  device: torch.device,
                  metadata_cls: Optional[type[M]] = None):
@@ -471,7 +472,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 BatchPrefillWithRaggedKVCacheWrapper] = []
         self._global_hyperparameters = infer_global_hyperparameters(
-            get_per_layer_parameters(vllm_config, MLACommonImpl))
+            get_per_layer_parameters(vllm_config, layer_names,
+                                     MLACommonImpl))
         if self._use_cudnn_prefill:
             self.cudnn_workspace = torch.empty(


@@ -56,9 +56,10 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True  # Decode-only
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
-        super().__init__(kv_cache_spec, vllm_config, device, FlashMLAMetadata)
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
+                         FlashMLAMetadata)
         self.compilation_config = vllm_config.compilation_config
         self.num_q_heads = vllm_config.model_config.get_num_attention_heads(


@@ -66,9 +66,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True  # decode only
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
-        super().__init__(kv_cache_spec, vllm_config, device, AiterMLAMetadata)
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
+                         AiterMLAMetadata)
         assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
             "only supports block size 1."


@@ -231,8 +231,8 @@ class AiterFlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config


@@ -59,8 +59,8 @@ class TritonAttentionMetadataBuilder(
         AttentionMetadataBuilder[TritonAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.device = device
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec


@@ -70,8 +70,8 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     full_cudagraph_supported: ClassVar[bool] = False
     @abstractmethod
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.kv_cache_spec = kv_cache_spec
     @abstractmethod
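For out-of-tree backends, the new constructor contract looks roughly like the sketch below; `MyMetadataBuilder` and the stored fields are placeholders, not part of this commit, and the other abstract methods are omitted:

    import torch
    from vllm.config import VllmConfig
    from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
    from vllm.v1.kv_cache_interface import AttentionSpec

    class MyMetadataBuilder(AttentionMetadataBuilder):
        """Partial sketch: only the constructor is shown; build() is omitted."""

        def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                     vllm_config: VllmConfig, device: torch.device):
            # layer_names is the new second positional argument: the names of
            # the layers in the KV cache group this builder serves.
            self.kv_cache_spec = kv_cache_spec
            self.layer_names = layer_names
            self.vllm_config = vllm_config
            self.device = device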
@@ -164,14 +164,14 @@ class PerLayerParameters:
 def get_per_layer_parameters(
-        vllm_config: VllmConfig,
+        vllm_config: VllmConfig, layer_names: list[str],
         cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]:
     """
-    Scan all attention layers and determine some hyperparameters
+    Scan layers in `layer_names` and determine some hyperparameters
     to use during `plan`.
     """
-    layers = get_layers_from_vllm_config(vllm_config, Attention)
+    layers = get_layers_from_vllm_config(vllm_config, Attention, layer_names)
     per_layer_params: dict[str, PerLayerParameters] = {}
     for key, layer in layers.items():
@@ -208,6 +208,10 @@ def infer_global_hyperparameters(
     param_sets = list(per_layer_params.values())
     global_params = param_sets[0]
     for params in param_sets:
+        if params.window_left != global_params.window_left:
+            raise ValueError(
+                "Window left is not the same for all layers. One potential fix "
+                "is to set disable_sliding_window=True")
         assert params == global_params, (
             "FlashInfer backend currently only supports models in which all "
             "layers share the same values for the following hyperparameters: "


@@ -2521,7 +2521,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     elapsed_time, cuda_graph_size / (1 << 30))
     def _initialize_single_attn_backend(
-        self, kv_cache_spec: KVCacheSpec
+        self, kv_cache_spec: KVCacheSpec, layer_names: list[str]
     ) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
         if isinstance(kv_cache_spec, AttentionSpec):
             attn_backend_i = get_attn_backend(
@@ -2551,6 +2551,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
                 kv_cache_spec,
+                layer_names,
                 self.vllm_config,
                 self.device,
             )
@@ -2574,8 +2575,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 kv_cache_config.kv_cache_groups):
             kv_cache_spec = kv_cache_group_spec.kv_cache_spec
-            attn_backend_i, attn_metadata_builder_i = \
-                self._initialize_single_attn_backend(kv_cache_spec)
+            attn_backend_i, attn_metadata_builder_i = (
+                self._initialize_single_attn_backend(
+                    kv_cache_spec, kv_cache_group_spec.layer_names))
             self.attn_backends.append(attn_backend_i)
             self.attn_metadata_builders.append(attn_metadata_builder_i)
@@ -2606,8 +2608,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             assert len(attn_specs) == len(attn_layers), \
                 "All or none of the layers are expected to be encoder-only"
-            attn_backend, attn_metadata_builder = \
-                self._initialize_single_attn_backend(attn_specs[0])
+            attn_backend, attn_metadata_builder = (
+                self._initialize_single_attn_backend(attn_specs[0],
+                                                     attn_layers.keys()))
             self.attn_backends.append(attn_backend)
             self.attn_metadata_builders.append(attn_metadata_builder)
             self.is_encoder_only_model = True
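Conceptually, the runner now builds one backend/builder pair per KV cache group, each constructed with only that group's layer names. A hedged sketch of that pattern as a hypothetical standalone helper (the function itself is not part of this commit; field names follow the diff above):

    from vllm.v1.kv_cache_interface import KVCacheConfig

    def builders_per_group(runner, kv_cache_config: KVCacheConfig):
        """Mirror of the loop in the diff: one (backend, builder) pair per KV
        cache group. Under the hybrid allocator, full-attention, sliding-window
        and Mamba layers live in different groups, so each builder only ever
        sees layers with compatible hyperparameters."""
        pairs = []
        for group in kv_cache_config.kv_cache_groups:
            pairs.append(
                runner._initialize_single_attn_backend(group.kv_cache_spec,
                                                       group.layer_names))
        return pairs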