Compare commits

...

2 Commits

SHA1        Message                                    Date
a2599dca0f  fix missing removal                        2025-10-17 11:35:42 -07:00
            Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
3fd66b1e73  [Misc] Remove unused virtual engine flag   2025-10-16 23:04:05 -07:00
            Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
16 changed files with 39 additions and 208 deletions
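Taken together, these commits change the per-layer KV cache attribute from a list indexed by a virtual-engine id (a leftover of v0 pipeline parallelism) to a single tensor, and drop the matching virtual_engine plumbing from ForwardContext. A minimal before/after sketch of the access pattern, using a toy layer object rather than the real vLLM Attention class:

    import torch

    class ToyLayer:
        def __init__(self) -> None:
            # old convention: one placeholder tensor per virtual engine,
            # replaced later by the binding step
            self.kv_cache_per_ve = [torch.tensor([])]
            # new convention: a single placeholder tensor
            self.kv_cache = torch.tensor([])

    layer = ToyLayer()
    real_cache = torch.zeros(2, 16)

    # before: bind and read through the virtual-engine index
    layer.kv_cache_per_ve[0] = real_cache
    assert layer.kv_cache_per_ve[0] is real_cache

    # after: bind and read the tensor directly
    layer.kv_cache = real_cache
    assert layer.kv_cache is real_cache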

View File

@@ -24,7 +24,6 @@ from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
 from vllm.utils import (
     FlexibleArgumentParser,
     MemorySnapshot,
-    bind_kv_cache,
     common_broadcastable_dtype,
     current_stream,
     get_open_port,
@@ -343,87 +342,6 @@ def test_memory_profiling():
     lib.cudaFree(handle2)


-def test_bind_kv_cache():
-    from vllm.attention import Attention
-
-    ctx = {
-        "layers.0.self_attn": Attention(32, 128, 0.1),
-        "layers.1.self_attn": Attention(32, 128, 0.1),
-        "layers.2.self_attn": Attention(32, 128, 0.1),
-        "layers.3.self_attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    bind_kv_cache(ctx, [kv_cache])
-    assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[2]
-    assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[3]
-
-
-def test_bind_kv_cache_kv_sharing():
-    from vllm.attention import Attention
-
-    ctx = {
-        "layers.0.self_attn": Attention(32, 128, 0.1),
-        "layers.1.self_attn": Attention(32, 128, 0.1),
-        "layers.2.self_attn": Attention(32, 128, 0.1),
-        "layers.3.self_attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    shared_kv_cache_layers = {
-        "layers.2.self_attn": "layers.1.self_attn",
-        "layers.3.self_attn": "layers.0.self_attn",
-    }
-    bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
-    assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[0]
-
-
-def test_bind_kv_cache_non_attention():
-    from vllm.attention import Attention
-
-    # example from Jamba PP=2
-    ctx = {
-        "model.layers.20.attn": Attention(32, 128, 0.1),
-        "model.layers.28.attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    bind_kv_cache(ctx, [kv_cache])
-    assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache[1]
-
-
-def test_bind_kv_cache_pp():
-    with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
-        # this test runs with 1 GPU, but we simulate 2 GPUs
-        cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
-    with set_current_vllm_config(cfg):
-        from vllm.attention import Attention
-
-        ctx = {
-            "layers.0.self_attn": Attention(32, 128, 0.1),
-        }
-        kv_cache = [[torch.zeros((1,))], [torch.zeros((1,))]]
-        bind_kv_cache(ctx, kv_cache)
-        assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0][0]
-        assert ctx["layers.0.self_attn"].kv_cache[1] is kv_cache[1][0]
-
-
 @pytest.mark.parametrize(
     ("src_dtype", "tgt_dtype", "expected_result"),
     [

View File

@@ -382,7 +382,6 @@ class TestNixlHandshake:
         dummy_ctx = ForwardContext(
             no_compile_layers={},
             attn_metadata={},
-            virtual_engine=0,
         )
         _before_load = time.perf_counter()
         connector.start_load_kv(dummy_ctx)
@@ -450,7 +449,6 @@ class TestNixlHandshake:
         dummy_ctx = ForwardContext(
             no_compile_layers={},
             attn_metadata={},
-            virtual_engine=0,
         )
         _before_load = time.perf_counter()
         connector.start_load_kv(dummy_ctx)
@@ -506,7 +504,6 @@ class TestNixlHandshake:
         dummy_ctx = ForwardContext(
             no_compile_layers={},
             attn_metadata={},
-            virtual_engine=0,
         )
         _before_load = time.perf_counter()
         connector.start_load_kv(dummy_ctx)
@@ -666,7 +663,6 @@ def test_kv_connector_stats(dist_init):
     dummy_ctx = ForwardContext(
         no_compile_layers={},
         attn_metadata={},
-        virtual_engine=0,
     )
     connector.start_load_kv(dummy_ctx)
@@ -1241,7 +1237,6 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
     dummy_ctx = ForwardContext(
         no_compile_layers={},
         attn_metadata={},
-        virtual_engine=0,
     )
     connector.start_load_kv(dummy_ctx)
@@ -1344,7 +1339,6 @@ def test_handshake_failure_returns_finished(dist_init):
     dummy_ctx = ForwardContext(
         no_compile_layers={},
         attn_metadata={},
-        virtual_engine=0,
     )
     connector.start_load_kv(dummy_ctx)
@@ -1393,7 +1387,6 @@ def test_transfer_setup_failure_returns_finished(dist_init):
     dummy_ctx = ForwardContext(
         no_compile_layers={},
         attn_metadata={},
-        virtual_engine=0,
     )
     connector.start_load_kv(dummy_ctx)

View File

@@ -179,7 +179,7 @@ class RequestRunner:
         self._block_hasher = get_request_block_hasher(gpu_block_size, sha256)

         self._dummy_ctx: ForwardContext = ForwardContext(
-            no_compile_layers={}, attn_metadata={}, virtual_engine=0
+            no_compile_layers={}, attn_metadata={}
         )

     def new_request(self, token_ids: list[int]):

View File

@@ -272,14 +272,9 @@ class Attention(nn.Module, AttentionLayerBase):
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

         # use a placeholder kv cache tensor during init, which will be replaced
-        # by bind_kv_cache
-        # this variable will not be accessed if use_direct_call is True
-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
-        ]
+        # by bind_kv_cache this variable will not be accessed if use_direct_call
+        # is True
+        self.kv_cache = torch.tensor([])

         try:
             self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
@@ -361,9 +356,9 @@ class Attention(nn.Module, AttentionLayerBase):
             attn_metadata = forward_context.attn_metadata
             if isinstance(attn_metadata, dict):
                 attn_metadata = attn_metadata[self.layer_name]
-            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             self.impl.forward(
-                self, query, key, value, self_kv_cache, attn_metadata, output=output
+                self, query, key, value, self.kv_cache, attn_metadata, output=output
             )
         else:
             torch.ops.vllm.unified_attention_with_output(
@@ -376,9 +371,9 @@ class Attention(nn.Module, AttentionLayerBase):
             attn_metadata = forward_context.attn_metadata
             if isinstance(attn_metadata, dict):
                 attn_metadata = attn_metadata[self.layer_name]
-            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             return self.impl.forward(
-                self, query, key, value, self_kv_cache, attn_metadata
+                self, query, key, value, self.kv_cache, attn_metadata
             )
         else:
             return torch.ops.vllm.unified_attention(
@@ -644,12 +639,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self

-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
-        ]
+        self.kv_cache = torch.tensor([])

         # Align with Attention's scale attributes for MLA backends.
@@ -688,7 +678,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             attn_metadata = forward_context.attn_metadata
             if isinstance(attn_metadata, dict):
                 attn_metadata = attn_metadata[self.layer_name]
-            self_kv_cache = self.kv_cache[forward_context.virtual_engine]

             # Mirror Attention.forward scale calculation path
             if self.calculate_kv_scales and getattr(
@@ -703,14 +692,14 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                     q,
                     kv_c_normed,
                     k_pe,
-                    self_kv_cache,
+                    self.kv_cache,
                     attn_metadata,
                     output=output,
                 )
                 return output
             else:
                 return self.impl.forward(
-                    self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
+                    self, q, kv_c_normed, k_pe, self.kv_cache, attn_metadata
                 )
         else:
             if self.attn_backend.accept_output_buffer:
@@ -785,7 +774,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
 def maybe_save_kv_layer_to_connector(
     layer_name: str,
-    kv_cache_layer: list[torch.Tensor],
+    kv_cache_layer: torch.Tensor,
 ):
     if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
         return
@@ -851,10 +840,9 @@ def unified_attention(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
-    output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
+    output = self.impl.forward(self, query, key, value, self.kv_cache, attn_metadata)

-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)
     return output
@@ -889,20 +877,19 @@ def unified_attention_with_output(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(
         self,
         query,
         key,
         value,
-        kv_cache,
+        self.kv_cache,
         attn_metadata,
         output=output,
         output_scale=output_scale,
         output_block_scale=output_block_scale,
     )

-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)


 def unified_attention_with_output_fake(
@@ -938,10 +925,9 @@ def unified_mla_attention(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     self: MLAAttention = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
-    output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
+    output = self.impl.forward(self, q, kv_c_normed, k_pe, self.kv_cache, attn_metadata)

-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)
     return output
@@ -978,20 +964,19 @@ def unified_mla_attention_with_output(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     self: MLAAttention = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(
         self,
         q,
         kv_c_normed,
         k_pe,
-        kv_cache,
+        self.kv_cache,
         attn_metadata,
         output=output,
         output_scale=output_scale,
         output_block_scale=output_block_scale,
     )

-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)


 def unified_mla_attention_with_output_fake(

View File

@@ -200,12 +200,10 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 # Only process layers that have kv_cache
                 # attribute (attention layers) Skip non-attention
                 # layers like FusedMoE
-                kv_cache = getattr(layer, "kv_cache", None)
-                if kv_cache is None:
+                layer = getattr(layer, "kv_cache", None)
+                if layer is None:
                     continue

-                layer = kv_cache[forward_context.virtual_engine]
-
                 kv_cache = self.p2p_nccl_engine.recv_tensor(
                     request.request_id + "#" + layer_name, remote_address
                 )

View File

@@ -174,12 +174,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
             # Only process layers that have kv_cache
             # attribute (attention layers) Skip non-attention
             # layers like FusedMoE/MLP etc.
-            kv_cache_attr = getattr(layer, "kv_cache", None)
-            if kv_cache_attr is None:
+            kv_cache_layer = getattr(layer, "kv_cache", None)
+            if kv_cache_layer is None:
                 continue

-            kv_cache_layer = kv_cache_attr[forward_context.virtual_engine]
-
             filename = self._generate_filename_debug(
                 layer_name, request.token_ids, request.mm_hashes
             )

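Both connector hunks above reduce to the same loop shape: getattr(layer, "kv_cache", None) now returns the tensor itself, so the extra [forward_context.virtual_engine] lookup disappears. A rough sketch of that iteration (the helper name is illustrative, not actual connector code):

    def iter_kv_cache_layers(forward_context):
        # Walk the attention layers registered in the forward context and
        # yield (layer_name, kv_cache tensor). Non-attention layers such as
        # FusedMoE or plain MLPs have no kv_cache attribute and are skipped.
        for layer_name, layer in forward_context.no_compile_layers.items():
            kv_cache_layer = getattr(layer, "kv_cache", None)
            if kv_cache_layer is None:
                continue
            # previously: kv_cache_layer = kv_cache_layer[forward_context.virtual_engine]
            yield layer_name, kv_cache_layer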
View File

@@ -37,7 +37,7 @@ class BatchDescriptor(NamedTuple):
     num_tokens: int
     uniform_decode: bool = False
     """
     False can also be used for an uniform decode batch to dispatch to the
     cudagraph supporting non-uniform batches.
     """
@@ -179,8 +179,8 @@ class ForwardContext:
     # copy from vllm_config.compilation_config.static_forward_context
     no_compile_layers: dict[str, Any]
     """
     Type AttentionMetadata for v0,
     Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
     attention layer to its attention metadata
     Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
     for each microbatch.
@@ -191,8 +191,6 @@ class ForwardContext:
         dict[str, "AttentionMetadata"],
         list[dict[str, "AttentionMetadata"]],
     ]
-    # TODO: remove after making all virtual_engines share the same kv cache
-    virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
     dp_metadata: DPMetadata | None = None
     # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
@@ -223,7 +221,6 @@ def get_forward_context() -> ForwardContext:
 def create_forward_context(
     attn_metadata: Any,
     vllm_config: VllmConfig,
-    virtual_engine: int = 0,
     dp_metadata: DPMetadata | None = None,
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
     batch_descriptor: BatchDescriptor | None = None,
@@ -231,7 +228,6 @@ def create_forward_context(
 ):
     return ForwardContext(
         no_compile_layers=vllm_config.compilation_config.static_forward_context,
-        virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
         dp_metadata=dp_metadata,
         cudagraph_runtime_mode=cudagraph_runtime_mode,
@@ -259,7 +255,6 @@ def override_forward_context(forward_context: ForwardContext | None):
 def set_forward_context(
     attn_metadata: Any,
     vllm_config: VllmConfig,
-    virtual_engine: int = 0,
     num_tokens: int | None = None,
     num_tokens_across_dp: torch.Tensor | None = None,
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
@@ -305,7 +300,6 @@ def set_forward_context(
     forward_context = create_forward_context(
         attn_metadata,
         vllm_config,
-        virtual_engine,
         dp_metadata,
         cudagraph_runtime_mode,
         batch_descriptor,

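With the field removed, a ForwardContext is built from the layer map and attention metadata alone, which is exactly the reduced form the dummy contexts in the test diffs above use. A minimal sketch, assuming vllm.forward_context is importable as shown in this file:

    from vllm.forward_context import ForwardContext

    # no virtual_engine argument anymore; the remaining fields keep their defaults
    dummy_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
    )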
View File

@@ -328,7 +328,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
         if attn_metadata is not None:
-            kv_cache = self.kv_cache[forward_context.virtual_engine][0]
+            kv_cache = self.kv_cache[0]
             state_indices_tensor = attn_metadata.state_indices_tensor
             num_prefills = getattr(attn_metadata, "num_prefills", 0)

View File

@@ -248,9 +248,8 @@ class MambaMixer(MambaBase, CustomOp):
         assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
         query_start_loc = mamba1_metadata.query_start_loc
         state_indices_tensor = mamba1_metadata.state_indices_tensor
-        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
-        conv_state = self_kv_cache[0].transpose(-1, -2)
-        ssm_state = self_kv_cache[1]
+        conv_state = self.kv_cache[0].transpose(-1, -2)
+        ssm_state = self.kv_cache[1]
         has_initial_states = mamba1_metadata.has_initial_states
         num_padded_decodes = mamba1_metadata.num_padded_decodes

View File

@@ -511,10 +511,9 @@ class MambaMixer2(MambaBase, CustomOp):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, Mamba2AttentionMetadata)
-        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
         # conv_state = (..., dim, width-1) yet contiguous along 'dim'
-        conv_state = self_kv_cache[0].transpose(-1, -2)
-        ssm_state = self_kv_cache[1]
+        conv_state = self.kv_cache[0].transpose(-1, -2)
+        ssm_state = self.kv_cache[1]
         state_indices_tensor = attn_metadata.state_indices_tensor
         has_initial_states_p = attn_metadata.has_initial_states_p
         prep_initial_states = attn_metadata.prep_initial_states

View File

@@ -118,8 +118,7 @@ class ShortConv(MambaBase, CustomOp):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, ShortConvAttentionMetadata)
-        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
-        conv_state = self_kv_cache[0].transpose(-1, -2)
+        conv_state = self.kv_cache[0].transpose(-1, -2)
         state_indices_tensor = attn_metadata.state_indices_tensor
         has_initial_states_p = attn_metadata.has_initial_states_p

View File

@@ -471,7 +471,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
         self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
     ):
         super().__init__()
-        self.kv_cache = [torch.tensor([])]
+        self.kv_cache = torch.tensor([])
         self.head_dim = head_dim
         self.prefix = prefix
         self.cache_config = cache_config

View File

@@ -258,10 +258,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, Mamba2AttentionMetadata)
-        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
         # conv_state = (..., dim, width-1) yet contiguous along 'dim'
-        conv_state = self_kv_cache[0].transpose(-1, -2)
-        ssm_state = self_kv_cache[1]
+        conv_state = self.kv_cache[0].transpose(-1, -2)
+        ssm_state = self.kv_cache[1]
         state_indices_tensor = attn_metadata.state_indices_tensor
         has_initial_states_p = attn_metadata.has_initial_states_p
         prep_initial_states = attn_metadata.prep_initial_states

View File

@@ -458,9 +458,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
-        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
-        conv_state = self_kv_cache[0].transpose(-1, -2)
-        ssm_state = self_kv_cache[1]
+        conv_state = self.kv_cache[0].transpose(-1, -2)
+        ssm_state = self.kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens

View File

@@ -2013,55 +2013,6 @@ def get_mp_context():
     return multiprocessing.get_context(mp_method)


-def bind_kv_cache(
-    ctx: dict[str, Any],
-    kv_cache: list[list[torch.Tensor]],  # [virtual_engine][layer_index]
-    shared_kv_cache_layers: dict[str, str] | None = None,
-) -> None:
-    # Bind the kv_cache tensor to Attention modules, similar to
-    # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
-    # Special things handled here:
-    # 1. Some models have non-attention layers, e.g., Jamba
-    # 2. Pipeline parallelism, each rank only has a subset of layers
-    # 3. Encoder attention has no kv cache
-    # 4. Encoder-decoder models, encoder-decoder attention and decoder-only
-    #    attention of the same layer (e.g., bart's decoder.layers.1.self_attn
-    #    and decoder.layers.1.encoder_attn) is mapped to the same kv cache
-    #    tensor
-    # 5. Some models have attention layers that share kv cache with previous
-    #    layers, this is specified through shared_kv_cache_layers
-    if shared_kv_cache_layers is None:
-        shared_kv_cache_layers = {}
-    from vllm.attention import AttentionType
-    from vllm.model_executor.models.utils import extract_layer_index
-
-    layer_need_kv_cache = [
-        layer_name
-        for layer_name in ctx
-        if (
-            hasattr(ctx[layer_name], "attn_type")
-            and ctx[layer_name].attn_type
-            in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)
-        )
-        and ctx[layer_name].kv_sharing_target_layer_name is None
-    ]
-    layer_index_sorted = sorted(
-        set(extract_layer_index(layer_name) for layer_name in layer_need_kv_cache)
-    )
-    for layer_name in layer_need_kv_cache:
-        kv_cache_idx = layer_index_sorted.index(extract_layer_index(layer_name))
-        forward_ctx = ctx[layer_name]
-        assert len(forward_ctx.kv_cache) == len(kv_cache)
-        for ve, ve_kv_cache in enumerate(kv_cache):
-            forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
-    if shared_kv_cache_layers is not None:
-        for layer_name, target_layer_name in shared_kv_cache_layers.items():
-            assert extract_layer_index(target_layer_name) < extract_layer_index(
-                layer_name
-            ), "v0 doesn't support interleaving kv sharing"
-            ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache
-
-
 def run_method(
     obj: Any,
     method: str | bytes | Callable,

View File

@@ -318,8 +318,7 @@ def bind_kv_cache(
     # Bind kv_caches to forward context
     for layer_name, kv_cache in kv_caches.items():
-        # NOTE: Use list because of v0 PP virtual engine.
-        forward_context[layer_name].kv_cache = [kv_cache]
+        forward_context[layer_name].kv_cache = kv_cache


 def is_residual_scattered_for_sp(
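With the list wrapper gone, the v1 bind_kv_cache reduces to a direct per-layer assignment, and forward code reads layer.kv_cache as a tensor. A small self-contained sketch of that contract (toy objects, not the real forward-context types):

    import torch

    class ToyAttention:
        def __init__(self) -> None:
            self.kv_cache = torch.tensor([])  # placeholder until binding

    forward_context = {"layers.0.self_attn": ToyAttention()}
    kv_caches = {"layers.0.self_attn": torch.zeros(2, 16)}

    # the new binding loop: assign each layer's tensor directly,
    # with no list wrapping and no virtual-engine index
    for layer_name, kv_cache in kv_caches.items():
        forward_context[layer_name].kv_cache = kv_cache

    assert forward_context["layers.0.self_attn"].kv_cache is kv_caches["layers.0.self_attn"]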