[v1][attention] Support Hybrid Allocator + FlashInfer (#21412)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -198,7 +198,8 @@ class MockAttentionLayer:


 def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
-                          vllm_config, device: torch.device,
+                          layer_names: list[str], vllm_config,
+                          device: torch.device,
                           common_attn_metadata: CommonAttentionMetadata,
                           query: torch.Tensor, key: torch.Tensor,
                           value: torch.Tensor,
@@ -211,31 +212,33 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
     if backend == _Backend.FLASHINFER_VLLM_V1:
         import unittest.mock

-        from vllm.v1.attention.backends.flashinfer import PerLayerParameters
+        from vllm.v1.attention.backends.utils import PerLayerParameters

-        def mock_get_per_layer_parameters(vllm_config, impl_cls):
+        def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
             # Return mock parameters for a single layer
             head_size = vllm_config.model_config.get_head_size()
             return {
-                "mock_layer":
+                layer_name:
                 PerLayerParameters(
                     window_left=-1,  # No sliding window
                     logits_soft_cap=0.0,  # No soft cap
                     sm_scale=1.0 / (head_size**0.5)  # Standard scale
                 )
+                for layer_name in layer_names
             }

         with unittest.mock.patch(
                 'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
                 mock_get_per_layer_parameters):
-            builder = builder_cls(kv_cache_spec, vllm_config, device)
+            builder = builder_cls(kv_cache_spec, layer_names, vllm_config,
+                                  device)
             attn_metadata = builder.build(
                 common_prefix_len=0,
                 common_attn_metadata=common_attn_metadata,
             )
     else:
         # Build metadata
-        builder = builder_cls(kv_cache_spec, vllm_config, device)
+        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
         attn_metadata = builder.build(
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
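
For orientation, a minimal standalone sketch (not part of the diff) of what the patched helper returns now that it takes `layer_names`; the `head_size` default is an arbitrary stand-in, while `PerLayerParameters` and its fields come from the hunk above:

    from vllm.v1.attention.backends.utils import PerLayerParameters

    def fake_per_layer_params(layer_names: list[str], head_size: int = 128):
        # Mirrors the mocked helper: one identical entry per requested layer.
        return {
            name: PerLayerParameters(
                window_left=-1,                   # no sliding window
                logits_soft_cap=0.0,              # no soft cap
                sm_scale=1.0 / (head_size**0.5),  # standard scale
            )
            for name in layer_names
        }

    params = fake_per_layer_params(["placeholder"])
    assert set(params) == {"placeholder"}

The test below passes `["placeholder"]` as the layer-name list, so the mock produces exactly one entry keyed by that placeholder name.
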
@@ -427,8 +430,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
             set_kv_cache_layout("HND")

         backend_output = run_attention_backend(backend_name, kv_cache_spec,
-                                               vllm_config, device,
-                                               common_attn_metadata,
+                                               ["placeholder"], vllm_config,
+                                               device, common_attn_metadata,
                                                query_vllm, key_vllm,
                                                value_vllm,
                                                kv_cache_for_backend)
@@ -305,6 +305,7 @@ def test_propose(num_speculative_tokens):
                                                _Backend.FLASH_ATTN_VLLM_V1)
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        layer_names=proposer.attn_layer_names,
         vllm_config=proposer.vllm_config,
         device=device,
     )
@@ -745,7 +745,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
     layer_4 = "model.layers.4.mixer"
     layer_5 = "model.layers.5.mixer"

-    with set_current_vllm_config(vllm_config):
+    with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
+        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
         hf_config = vllm_config.model_config.hf_config
         fwd_context = {}
         for key in [layer_0, layer_1]:
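
`monkeypatch.context()` scopes the environment override to the `with` block, so `VLLM_ATTENTION_BACKEND` is restored as soon as the test body finishes. A small illustration of that pytest behaviour (independent of this diff):

    import os

    import pytest

    def test_env_override_is_scoped(monkeypatch: pytest.MonkeyPatch):
        before = os.environ.get("VLLM_ATTENTION_BACKEND")
        with monkeypatch.context() as m:
            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
            assert os.environ["VLLM_ATTENTION_BACKEND"] == "FLASHINFER"
        # Outside the context the previous value (or absence) is back.
        assert os.environ.get("VLLM_ATTENTION_BACKEND") == before
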
@@ -740,8 +740,8 @@ class ModelConfig:
                                      isinstance(sliding_window, list))

         if not self.disable_sliding_window and has_interleaved_attention:
-            if (backend :=
-                    envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
+            if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
+                                         ) in ("XFORMERS", "FLASHINFER"):
                 sliding_window_len_min = get_min_sliding_window(
                     self.hf_text_config.sliding_window)

@@ -5065,13 +5065,29 @@ def assert_hashable(text):
 T = TypeVar("T")


-def get_layers_from_vllm_config(vllm_config: VllmConfig,
-                                layer_type: type[T]) -> dict[str, T]:
+def get_layers_from_vllm_config(
+        vllm_config: VllmConfig,
+        layer_type: type[T],
+        layer_names: Optional[list[str]] = None) -> dict[str, T]:
+    """
+    Get layers from the vLLM config.
+
+    Args:
+        vllm_config: The vLLM config.
+        layer_type: The type of the layer to get.
+        layer_names: The names of the layers to get. If None, return all layers.
+    """
+
+    if layer_names is None:
+        layer_names = list(
+            vllm_config.compilation_config.static_forward_context.keys())
+
+    forward_context = vllm_config.compilation_config.static_forward_context
+
     return {
-        layer_name: layer
-        for layer_name, layer in
-        vllm_config.compilation_config.static_forward_context.items()
-        if isinstance(layer, layer_type)
+        layer_name: forward_context[layer_name]
+        for layer_name in layer_names
+        if isinstance(forward_context[layer_name], layer_type)
     }


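
A short usage sketch of the extended helper; the group layer names are hypothetical and the import locations are assumptions based on the surrounding code. Omitting `layer_names` (or passing None) keeps the old all-layers behaviour:

    from vllm.attention import Attention
    from vllm.config import VllmConfig, get_layers_from_vllm_config

    def attention_layers_for_group(vllm_config: VllmConfig,
                                   group_layer_names: list[str]):
        # Only the named layers are looked up in the static forward context,
        # so layers that belong to other KV-cache groups are never touched.
        return get_layers_from_vllm_config(vllm_config, Attention,
                                           group_layer_names)
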
@@ -315,8 +315,8 @@ class TorchSDPAMetadata(AttentionMetadata):

 class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):

-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device) -> None:
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device) -> None:
         self.kv_cache_spec = kv_cache_spec
         self.vllm_config = vllm_config
         self.scheduler_config = vllm_config.scheduler_config
@@ -148,8 +148,8 @@ class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3

-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
@@ -21,10 +21,9 @@ from vllm.platforms import current_platform
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
-    get_kv_cache_layout, get_per_layer_parameters,
-    infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
-    split_decodes_and_prefills)
+    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
+    get_per_layer_parameters, infer_global_hyperparameters,
+    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec

 if TYPE_CHECKING:
@@ -219,8 +218,8 @@ class FlashInferMetadata:

 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.device = device
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
@@ -228,7 +227,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self._cascade_wrapper = None  # Wrapper for cascade attention

         # Global hyperparameters shared by all attention layers
-        self.global_hyperparameters: Optional[PerLayerParameters] = None
+        self.global_hyperparameters = infer_global_hyperparameters(
+            get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))

         self.vllm_config = vllm_config
         self.cache_config = vllm_config.cache_config
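
Because the builder now knows its layer names at construction time, the shared FlashInfer hyperparameters can be resolved eagerly in `__init__` rather than lazily in `_plan` (the lazy path is deleted in the next hunk). A rough sketch of that inference step with two invented layer entries; only `PerLayerParameters`, `infer_global_hyperparameters`, and their import location come from this diff:

    from vllm.v1.attention.backends.utils import (PerLayerParameters,
                                                  infer_global_hyperparameters)

    # What get_per_layer_parameters() would report for a two-layer group.
    per_layer = {
        "model.layers.0.self_attn.attn":
        PerLayerParameters(window_left=-1, logits_soft_cap=0.0, sm_scale=0.125),
        "model.layers.1.self_attn.attn":
        PerLayerParameters(window_left=-1, logits_soft_cap=0.0, sm_scale=0.125),
    }

    # All layers agree, so a single set of globals is returned for plan().
    global_params = infer_global_hyperparameters(per_layer)
    assert global_params.window_left == -1
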
@@ -283,10 +283,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

     def _plan(self, num_prefills: int, num_decodes: int,
               attn_metadata: FlashInferMetadata):
-        if self.global_hyperparameters is None:
-            self.global_hyperparameters = infer_global_hyperparameters(
-                get_per_layer_parameters(self.vllm_config, FlashInferImpl))
-
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
@@ -258,8 +258,8 @@ class FlexAttentionMetadata:
 class FlexAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlexAttentionMetadata]):

-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
         self.cache_config = vllm_config.cache_config
@@ -87,8 +87,8 @@ class Mamba2AttentionMetadata:
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):

-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         assert isinstance(kv_cache_spec, MambaSpec)
         self.kv_cache_spec = kv_cache_spec
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
@@ -406,6 +406,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):

     def __init__(self,
                  kv_cache_spec: AttentionSpec,
+                 layer_names: list[str],
                  vllm_config: VllmConfig,
                  device: torch.device,
                  metadata_cls: Optional[type[M]] = None):
@@ -471,7 +472,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                                     BatchPrefillWithRaggedKVCacheWrapper] = []

             self._global_hyperparameters = infer_global_hyperparameters(
-                get_per_layer_parameters(vllm_config, MLACommonImpl))
+                get_per_layer_parameters(vllm_config, layer_names,
+                                         MLACommonImpl))

             if self._use_cudnn_prefill:
                 self.cudnn_workspace = torch.empty(
@@ -56,9 +56,10 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True  # Decode-only

-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
-        super().__init__(kv_cache_spec, vllm_config, device, FlashMLAMetadata)
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
+                         FlashMLAMetadata)

         self.compilation_config = vllm_config.compilation_config
         self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
@@ -66,9 +66,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True  # decode only

-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
-        super().__init__(kv_cache_spec, vllm_config, device, AiterMLAMetadata)
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
+                         AiterMLAMetadata)
         assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
             "only supports block size 1."

@@ -231,8 +231,8 @@ class AiterFlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True

-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
@@ -59,8 +59,8 @@ class TritonAttentionMetadataBuilder(
         AttentionMetadataBuilder[TritonAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True

-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.device = device
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
@@ -70,8 +70,8 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     full_cudagraph_supported: ClassVar[bool] = False

     @abstractmethod
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.kv_cache_spec = kv_cache_spec

     @abstractmethod
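
Every backend-specific builder now receives the layer names of the KV-cache group it serves. A rough sketch of a minimal subclass under the new signature; the class names are hypothetical and `build()` is abbreviated to the two arguments exercised by the tests above:

    import torch

    from vllm.config import VllmConfig
    from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                                  CommonAttentionMetadata)
    from vllm.v1.kv_cache_interface import AttentionSpec

    class MyMetadata:
        pass

    class MyMetadataBuilder(AttentionMetadataBuilder[MyMetadata]):

        def __init__(self, kv_cache_spec: AttentionSpec,
                     layer_names: list[str], vllm_config: VllmConfig,
                     device: torch.device):
            self.kv_cache_spec = kv_cache_spec
            # The builder only ever sees the layers of its own KV-cache group.
            self.layer_names = layer_names
            self.vllm_config = vllm_config
            self.device = device

        def build(self, common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata) -> MyMetadata:
            return MyMetadata()
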
@@ -164,14 +164,14 @@ class PerLayerParameters:


 def get_per_layer_parameters(
-        vllm_config: VllmConfig,
+        vllm_config: VllmConfig, layer_names: list[str],
         cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]:
     """
-    Scan all attention layers and determine some hyperparameters
+    Scan layers in `layer_names` and determine some hyperparameters
     to use during `plan`.
     """

-    layers = get_layers_from_vllm_config(vllm_config, Attention)
+    layers = get_layers_from_vllm_config(vllm_config, Attention, layer_names)
     per_layer_params: dict[str, PerLayerParameters] = {}

     for key, layer in layers.items():
@@ -208,6 +208,10 @@ def infer_global_hyperparameters(
     param_sets = list(per_layer_params.values())
     global_params = param_sets[0]
     for params in param_sets:
+        if params.window_left != global_params.window_left:
+            raise ValueError(
+                "Window left is not the same for all layers. One potential fix "
+                "is to set disable_sliding_window=True")
         assert params == global_params, (
             "FlashInfer backend currently only supports models in which all "
             "layers share the same values for the following hyperparameters: "
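
The added check turns a silent mismatch into an actionable error: when one layer uses a sliding window and another does not, the caller is pointed at `disable_sliding_window=True` instead of hitting the generic assertion. A tiny illustration with two invented layers:

    from vllm.v1.attention.backends.utils import (PerLayerParameters,
                                                  infer_global_hyperparameters)

    mismatched = {
        "full_attention_layer":
        PerLayerParameters(window_left=-1, logits_soft_cap=0.0, sm_scale=0.125),
        "sliding_window_layer":
        PerLayerParameters(window_left=1023, logits_soft_cap=0.0,
                           sm_scale=0.125),
    }

    try:
        infer_global_hyperparameters(mismatched)
    except ValueError as err:
        print(err)  # "Window left is not the same for all layers. ..."
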
@@ -2521,7 +2521,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     elapsed_time, cuda_graph_size / (1 << 30))

     def _initialize_single_attn_backend(
-        self, kv_cache_spec: KVCacheSpec
+        self, kv_cache_spec: KVCacheSpec, layer_names: list[str]
     ) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
         if isinstance(kv_cache_spec, AttentionSpec):
             attn_backend_i = get_attn_backend(
@@ -2551,6 +2551,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

             attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
                 kv_cache_spec,
+                layer_names,
                 self.vllm_config,
                 self.device,
             )
@@ -2574,8 +2575,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 kv_cache_config.kv_cache_groups):
             kv_cache_spec = kv_cache_group_spec.kv_cache_spec

-            attn_backend_i, attn_metadata_builder_i = \
-                self._initialize_single_attn_backend(kv_cache_spec)
+            attn_backend_i, attn_metadata_builder_i = (
+                self._initialize_single_attn_backend(
+                    kv_cache_spec, kv_cache_group_spec.layer_names))
             self.attn_backends.append(attn_backend_i)
             self.attn_metadata_builders.append(attn_metadata_builder_i)

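
This is the hybrid-allocator piece: each KV-cache group (for example full-attention layers versus sliding-window or Mamba layers) gets its own backend and metadata builder, built from that group's spec and restricted to that group's layer names. A condensed, hypothetical helper mirroring the loop above, using only attribute names that appear in this diff:

    def build_group_metadata_builders(runner, kv_cache_config):
        # One (backend, builder) pair per KV-cache group; every builder only
        # knows about the layers that share its page table.
        builders = []
        for group in kv_cache_config.kv_cache_groups:
            _, builder = runner._initialize_single_attn_backend(
                group.kv_cache_spec, group.layer_names)
            builders.append(builder)
        return builders
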
@@ -2606,8 +2608,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert len(attn_specs) == len(attn_layers), \
             "All or none of the layers are expected to be encoder-only"

-        attn_backend, attn_metadata_builder = \
-            self._initialize_single_attn_backend(attn_specs[0])
+        attn_backend, attn_metadata_builder = (
+            self._initialize_single_attn_backend(attn_specs[0],
+                                                 attn_layers.keys()))
         self.attn_backends.append(attn_backend)
         self.attn_metadata_builders.append(attn_metadata_builder)
         self.is_encoder_only_model = True