[v1][attention] Support Hybrid Allocator + FlashInfer (#21412)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -198,7 +198,8 @@ class MockAttentionLayer:
 
 
 def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
-                          vllm_config, device: torch.device,
+                          layer_names: list[str], vllm_config,
+                          device: torch.device,
                           common_attn_metadata: CommonAttentionMetadata,
                           query: torch.Tensor, key: torch.Tensor,
                           value: torch.Tensor,
@@ -211,31 +212,33 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
     if backend == _Backend.FLASHINFER_VLLM_V1:
         import unittest.mock
 
-        from vllm.v1.attention.backends.flashinfer import PerLayerParameters
+        from vllm.v1.attention.backends.utils import PerLayerParameters
 
-        def mock_get_per_layer_parameters(vllm_config, impl_cls):
+        def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
             # Return mock parameters for a single layer
             head_size = vllm_config.model_config.get_head_size()
             return {
-                "mock_layer":
+                layer_name:
                 PerLayerParameters(
                     window_left=-1,  # No sliding window
                     logits_soft_cap=0.0,  # No soft cap
                     sm_scale=1.0 / (head_size**0.5)  # Standard scale
                 )
+                for layer_name in layer_names
             }
 
         with unittest.mock.patch(
                 'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
                 mock_get_per_layer_parameters):
-            builder = builder_cls(kv_cache_spec, vllm_config, device)
+            builder = builder_cls(kv_cache_spec, layer_names, vllm_config,
+                                  device)
             attn_metadata = builder.build(
                 common_prefix_len=0,
                 common_attn_metadata=common_attn_metadata,
             )
     else:
         # Build metadata
-        builder = builder_cls(kv_cache_spec, vllm_config, device)
+        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
         attn_metadata = builder.build(
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
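Note: the mock now keys its result by every requested layer name rather than a hardcoded "mock_layer". A standalone sketch of that pattern, using a toy PerLayerParameters stand-in and illustrative values (not the vLLM classes themselves):

from dataclasses import dataclass


@dataclass
class PerLayerParameters:
    # Toy stand-in for the real PerLayerParameters in
    # vllm.v1.attention.backends.utils (illustrative only).
    window_left: int
    logits_soft_cap: float
    sm_scale: float


def mock_params_for(layer_names: list[str], head_size: int = 64):
    # One identical parameter set per requested layer name,
    # mirroring the dict comprehension used by the test mock above.
    return {
        name: PerLayerParameters(window_left=-1,
                                 logits_soft_cap=0.0,
                                 sm_scale=1.0 / (head_size**0.5))
        for name in layer_names
    }


print(mock_params_for(["model.layers.0.self_attn.attn"]))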
@@ -427,8 +430,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
             set_kv_cache_layout("HND")
 
         backend_output = run_attention_backend(backend_name, kv_cache_spec,
-                                               vllm_config, device,
-                                               common_attn_metadata,
+                                               ["placeholder"], vllm_config,
+                                               device, common_attn_metadata,
                                                query_vllm, key_vllm,
                                                value_vllm,
                                                kv_cache_for_backend)
@@ -305,6 +305,7 @@ def test_propose(num_speculative_tokens):
         _Backend.FLASH_ATTN_VLLM_V1)
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        layer_names=proposer.attn_layer_names,
         vllm_config=proposer.vllm_config,
         device=device,
     )
@@ -745,7 +745,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
     layer_4 = "model.layers.4.mixer"
     layer_5 = "model.layers.5.mixer"
 
-    with set_current_vllm_config(vllm_config):
+    with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
+        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
         hf_config = vllm_config.model_config.hf_config
         fwd_context = {}
         for key in [layer_0, layer_1]:
@@ -740,8 +740,8 @@ class ModelConfig:
                                      isinstance(sliding_window, list))
 
         if not self.disable_sliding_window and has_interleaved_attention:
-            if (backend :=
-                    envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
+            if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
+                                         ) in ("XFORMERS", "FLASHINFER"):
                 sliding_window_len_min = get_min_sliding_window(
                     self.hf_text_config.sliding_window)
 
@@ -5065,13 +5065,29 @@ def assert_hashable(text):
 T = TypeVar("T")
 
 
-def get_layers_from_vllm_config(vllm_config: VllmConfig,
-                                layer_type: type[T]) -> dict[str, T]:
+def get_layers_from_vllm_config(
+        vllm_config: VllmConfig,
+        layer_type: type[T],
+        layer_names: Optional[list[str]] = None) -> dict[str, T]:
+    """
+    Get layers from the vLLM config.
+
+    Args:
+        vllm_config: The vLLM config.
+        layer_type: The type of the layer to get.
+        layer_names: The names of the layers to get. If None, return all layers.
+    """
+
+    if layer_names is None:
+        layer_names = list(
+            vllm_config.compilation_config.static_forward_context.keys())
+
+    forward_context = vllm_config.compilation_config.static_forward_context
+
     return {
-        layer_name: layer
-        for layer_name, layer in
-        vllm_config.compilation_config.static_forward_context.items()
-        if isinstance(layer, layer_type)
+        layer_name: forward_context[layer_name]
+        for layer_name in layer_names
+        if isinstance(forward_context[layer_name], layer_type)
     }
 
 
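A minimal sketch of the filtering semantics introduced above, with toy layer classes and a plain dict standing in for vllm_config.compilation_config.static_forward_context (names here are illustrative, not vLLM APIs):

from typing import Optional, TypeVar

T = TypeVar("T")


class Attention:
    """Toy stand-in for an attention layer type."""


class MambaMixer:
    """Toy stand-in for a mamba layer type."""


def get_layers(forward_context: dict, layer_type: type[T],
               layer_names: Optional[list[str]] = None) -> dict[str, T]:
    # Mirrors the new get_layers_from_vllm_config: restrict to layer_names
    # when given, otherwise consider every registered layer, then keep only
    # entries of the requested type.
    if layer_names is None:
        layer_names = list(forward_context.keys())
    return {
        name: forward_context[name]
        for name in layer_names
        if isinstance(forward_context[name], layer_type)
    }


ctx = {"layers.0.attn": Attention(), "layers.1.mixer": MambaMixer()}
assert list(get_layers(ctx, Attention)) == ["layers.0.attn"]
assert get_layers(ctx, Attention, ["layers.1.mixer"]) == {}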
@@ -315,8 +315,8 @@ class TorchSDPAMetadata(AttentionMetadata):
 
 class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device) -> None:
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device) -> None:
         self.kv_cache_spec = kv_cache_spec
         self.vllm_config = vllm_config
         self.scheduler_config = vllm_config.scheduler_config
@@ -148,8 +148,8 @@ class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
@@ -21,10 +21,9 @@ from vllm.platforms import current_platform
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
-    get_kv_cache_layout, get_per_layer_parameters,
-    infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
-    split_decodes_and_prefills)
+    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
+    get_per_layer_parameters, infer_global_hyperparameters,
+    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 if TYPE_CHECKING:
@@ -219,8 +218,8 @@ class FlashInferMetadata:
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.device = device
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
@@ -228,7 +227,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
         # Global hyperparameters shared by all attention layers
-        self.global_hyperparameters: Optional[PerLayerParameters] = None
+        self.global_hyperparameters = infer_global_hyperparameters(
+            get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
 
         self.vllm_config = vllm_config
         self.cache_config = vllm_config.cache_config
@@ -283,10 +283,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
     def _plan(self, num_prefills: int, num_decodes: int,
               attn_metadata: FlashInferMetadata):
-        if self.global_hyperparameters is None:
-            self.global_hyperparameters = infer_global_hyperparameters(
-                get_per_layer_parameters(self.vllm_config, FlashInferImpl))
-
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
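The two hunks above move the FlashInfer hyperparameter resolution from a lazy branch in _plan to construction time: the builder now receives the layer names of its KV-cache group and fixes the shared parameters once in __init__. A simplified sketch of that construction-time flow, using hypothetical toy types rather than the actual vLLM classes:

from dataclasses import dataclass


@dataclass
class Params:
    # Illustrative stand-in for PerLayerParameters.
    window_left: int
    logits_soft_cap: float
    sm_scale: float


def infer_global(per_layer: dict) -> Params:
    # Every layer in the group must agree, so the first entry can serve as
    # the group-wide value.
    values = list(per_layer.values())
    first = values[0]
    assert all(p == first for p in values), "layers disagree"
    return first


class ToyFlashInferBuilder:
    # Hypothetical builder showing the flow; not FlashInferMetadataBuilder.

    def __init__(self, layer_names: list[str], per_layer_params: dict):
        # Hyperparameters are fixed per KV-cache group when the builder is
        # created, instead of lazily inside _plan.
        group = {name: per_layer_params[name] for name in layer_names}
        self.global_hyperparameters = infer_global(group)

    def plan(self):
        # No "if self.global_hyperparameters is None" branch needed here.
        return self.global_hyperparameters


params = {"l0": Params(-1, 0.0, 0.125), "l1": Params(-1, 0.0, 0.125)}
print(ToyFlashInferBuilder(["l0", "l1"], params).plan())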
@@ -258,8 +258,8 @@ class FlexAttentionMetadata:
 class FlexAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlexAttentionMetadata]):
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
         self.cache_config = vllm_config.cache_config
@@ -87,8 +87,8 @@ class Mamba2AttentionMetadata:
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         assert isinstance(kv_cache_spec, MambaSpec)
         self.kv_cache_spec = kv_cache_spec
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
@@ -406,6 +406,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
 
     def __init__(self,
                  kv_cache_spec: AttentionSpec,
+                 layer_names: list[str],
                  vllm_config: VllmConfig,
                  device: torch.device,
                  metadata_cls: Optional[type[M]] = None):
@@ -471,7 +472,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 BatchPrefillWithRaggedKVCacheWrapper] = []
 
             self._global_hyperparameters = infer_global_hyperparameters(
-                get_per_layer_parameters(vllm_config, MLACommonImpl))
+                get_per_layer_parameters(vllm_config, layer_names,
+                                         MLACommonImpl))
 
         if self._use_cudnn_prefill:
             self.cudnn_workspace = torch.empty(
@@ -56,9 +56,10 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True  # Decode-only
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
-        super().__init__(kv_cache_spec, vllm_config, device, FlashMLAMetadata)
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
+                         FlashMLAMetadata)
 
         self.compilation_config = vllm_config.compilation_config
         self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
@@ -66,9 +66,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True  # decode only
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
-        super().__init__(kv_cache_spec, vllm_config, device, AiterMLAMetadata)
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
+                         AiterMLAMetadata)
         assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
             "only supports block size 1."
 
@@ -231,8 +231,8 @@ class AiterFlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
@@ -59,8 +59,8 @@ class TritonAttentionMetadataBuilder(
         AttentionMetadataBuilder[TritonAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.device = device
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
@@ -70,8 +70,8 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     full_cudagraph_supported: ClassVar[bool] = False
 
     @abstractmethod
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.kv_cache_spec = kv_cache_spec
 
     @abstractmethod
@@ -164,14 +164,14 @@ class PerLayerParameters:
 
 
 def get_per_layer_parameters(
-        vllm_config: VllmConfig,
+        vllm_config: VllmConfig, layer_names: list[str],
         cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]:
     """
-    Scan all attention layers and determine some hyperparameters
+    Scan layers in `layer_names` and determine some hyperparameters
     to use during `plan`.
     """
 
-    layers = get_layers_from_vllm_config(vllm_config, Attention)
+    layers = get_layers_from_vllm_config(vllm_config, Attention, layer_names)
     per_layer_params: dict[str, PerLayerParameters] = {}
 
     for key, layer in layers.items():
@@ -208,6 +208,10 @@ def infer_global_hyperparameters(
     param_sets = list(per_layer_params.values())
     global_params = param_sets[0]
     for params in param_sets:
+        if params.window_left != global_params.window_left:
+            raise ValueError(
+                "Window left is not the same for all layers. One potential fix "
+                "is to set disable_sliding_window=True")
         assert params == global_params, (
             "FlashInfer backend currently only supports models in which all "
             "layers share the same values for the following hyperparameters: "
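The new guard turns a mismatched sliding window into an actionable error before the generic assertion fires. A self-contained sketch of that check, with a toy dataclass and illustrative layer names:

from dataclasses import dataclass


@dataclass
class PerLayerParameters:
    # Toy version of the dataclass used by the FlashInfer utils.
    window_left: int
    logits_soft_cap: float
    sm_scale: float


def infer_global_hyperparameters(per_layer_params):
    param_sets = list(per_layer_params.values())
    global_params = param_sets[0]
    for params in param_sets:
        if params.window_left != global_params.window_left:
            # Layers with different sliding windows cannot share one plan.
            raise ValueError(
                "Window left is not the same for all layers. One potential "
                "fix is to set disable_sliding_window=True")
        assert params == global_params
    return global_params


layers = {
    "layers.0.attn": PerLayerParameters(-1, 0.0, 0.125),
    "layers.1.attn": PerLayerParameters(1024, 0.0, 0.125),
}
try:
    infer_global_hyperparameters(layers)
except ValueError as exc:
    print(exc)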
@@ -2521,7 +2521,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 elapsed_time, cuda_graph_size / (1 << 30))
 
     def _initialize_single_attn_backend(
-        self, kv_cache_spec: KVCacheSpec
+        self, kv_cache_spec: KVCacheSpec, layer_names: list[str]
     ) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
         if isinstance(kv_cache_spec, AttentionSpec):
             attn_backend_i = get_attn_backend(
@@ -2551,6 +2551,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
             kv_cache_spec,
+            layer_names,
             self.vllm_config,
             self.device,
         )
@@ -2574,8 +2575,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 kv_cache_config.kv_cache_groups):
             kv_cache_spec = kv_cache_group_spec.kv_cache_spec
 
-            attn_backend_i, attn_metadata_builder_i = \
-                self._initialize_single_attn_backend(kv_cache_spec)
+            attn_backend_i, attn_metadata_builder_i = (
+                self._initialize_single_attn_backend(
+                    kv_cache_spec, kv_cache_group_spec.layer_names))
             self.attn_backends.append(attn_backend_i)
             self.attn_metadata_builders.append(attn_metadata_builder_i)
 
@@ -2606,8 +2608,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert len(attn_specs) == len(attn_layers), \
             "All or none of the layers are expected to be encoder-only"
 
-        attn_backend, attn_metadata_builder = \
-            self._initialize_single_attn_backend(attn_specs[0])
+        attn_backend, attn_metadata_builder = (
+            self._initialize_single_attn_backend(attn_specs[0],
+                                                 attn_layers.keys()))
        self.attn_backends.append(attn_backend)
        self.attn_metadata_builders.append(attn_metadata_builder)
        self.is_encoder_only_model = True
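With the hybrid allocator, each KV-cache group can carry a different spec (full attention, sliding window, Mamba), and the runner hunks above now hand each group's layer names to its backend builder. A schematic sketch of that per-group wiring, using hypothetical group/builder types in place of the real GPUModelRunner machinery:

from dataclasses import dataclass, field


@dataclass
class KVCacheGroupSpec:
    # Hypothetical stand-in: one cache spec plus the layers sharing it.
    kv_cache_spec: str
    layer_names: list = field(default_factory=list)


class ToyBuilder:

    def __init__(self, kv_cache_spec: str, layer_names: list):
        # Each builder only ever sees the layer names of its own group.
        self.kv_cache_spec = kv_cache_spec
        self.layer_names = list(layer_names)


def initialize_builders(groups):
    # One backend/builder pair per KV-cache group, as in the runner above.
    return [ToyBuilder(g.kv_cache_spec, g.layer_names) for g in groups]


groups = [
    KVCacheGroupSpec("full_attention", ["layers.0.attn", "layers.2.attn"]),
    KVCacheGroupSpec("mamba", ["layers.1.mixer", "layers.3.mixer"]),
]
for b in initialize_builders(groups):
    print(b.kv_cache_spec, b.layer_names)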