Compare commits

...

2 Commits

SHA1       Message                                   Date
a2599dca0f fix missing removal                       2025-10-17 11:35:42 -07:00
           Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
3fd66b1e73 [Misc] Remove unused virtual engine flag  2025-10-16 23:04:05 -07:00
           Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
16 changed files with 39 additions and 208 deletions
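
Taken together, the two commits remove the v0 "virtual engine" concept from the forward path: ForwardContext drops its virtual_engine field, and each layer's kv_cache placeholder shrinks from a per-virtual-engine list of tensors to a single tensor. The short sketch below is illustrative only, not vLLM code (OldStyleLayer and NewStyleLayer are invented names); it shows the access-pattern change that repeats throughout the hunks that follow.

import torch

class OldStyleLayer:
    def __init__(self, pipeline_parallel_size: int) -> None:
        # One placeholder per virtual engine; bind_kv_cache filled each slot.
        self.kv_cache = [torch.tensor([]) for _ in range(pipeline_parallel_size)]

    def cache_for(self, virtual_engine: int) -> torch.Tensor:
        # Pre-change pattern: index by the forward context's virtual_engine.
        return self.kv_cache[virtual_engine]

class NewStyleLayer:
    def __init__(self) -> None:
        # Single placeholder tensor; bind_kv_cache replaces it directly.
        self.kv_cache = torch.tensor([])

    def cache(self) -> torch.Tensor:
        # Post-change pattern: the layer's kv_cache is used as-is.
        return self.kv_cache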

View File

@@ -24,7 +24,6 @@ from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
from vllm.utils import (
FlexibleArgumentParser,
MemorySnapshot,
bind_kv_cache,
common_broadcastable_dtype,
current_stream,
get_open_port,
@@ -343,87 +342,6 @@ def test_memory_profiling():
lib.cudaFree(handle2)
def test_bind_kv_cache():
from vllm.attention import Attention
ctx = {
"layers.0.self_attn": Attention(32, 128, 0.1),
"layers.1.self_attn": Attention(32, 128, 0.1),
"layers.2.self_attn": Attention(32, 128, 0.1),
"layers.3.self_attn": Attention(32, 128, 0.1),
}
kv_cache = [
torch.zeros((1,)),
torch.zeros((1,)),
torch.zeros((1,)),
torch.zeros((1,)),
]
bind_kv_cache(ctx, [kv_cache])
assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0]
assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1]
assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[2]
assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[3]
def test_bind_kv_cache_kv_sharing():
from vllm.attention import Attention
ctx = {
"layers.0.self_attn": Attention(32, 128, 0.1),
"layers.1.self_attn": Attention(32, 128, 0.1),
"layers.2.self_attn": Attention(32, 128, 0.1),
"layers.3.self_attn": Attention(32, 128, 0.1),
}
kv_cache = [
torch.zeros((1,)),
torch.zeros((1,)),
torch.zeros((1,)),
torch.zeros((1,)),
]
shared_kv_cache_layers = {
"layers.2.self_attn": "layers.1.self_attn",
"layers.3.self_attn": "layers.0.self_attn",
}
bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0]
assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1]
assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[1]
assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[0]
def test_bind_kv_cache_non_attention():
from vllm.attention import Attention
# example from Jamba PP=2
ctx = {
"model.layers.20.attn": Attention(32, 128, 0.1),
"model.layers.28.attn": Attention(32, 128, 0.1),
}
kv_cache = [
torch.zeros((1,)),
torch.zeros((1,)),
]
bind_kv_cache(ctx, [kv_cache])
assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache[0]
assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache[1]
def test_bind_kv_cache_pp():
with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
# this test runs with 1 GPU, but we simulate 2 GPUs
cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
with set_current_vllm_config(cfg):
from vllm.attention import Attention
ctx = {
"layers.0.self_attn": Attention(32, 128, 0.1),
}
kv_cache = [[torch.zeros((1,))], [torch.zeros((1,))]]
bind_kv_cache(ctx, kv_cache)
assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0][0]
assert ctx["layers.0.self_attn"].kv_cache[1] is kv_cache[1][0]
@pytest.mark.parametrize(
("src_dtype", "tgt_dtype", "expected_result"),
[

View File

@@ -382,7 +382,6 @@ class TestNixlHandshake:
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
@@ -450,7 +449,6 @@ class TestNixlHandshake:
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
@@ -506,7 +504,6 @@ class TestNixlHandshake:
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
@@ -666,7 +663,6 @@ def test_kv_connector_stats(dist_init):
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
connector.start_load_kv(dummy_ctx)
@@ -1241,7 +1237,6 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
connector.start_load_kv(dummy_ctx)
@@ -1344,7 +1339,6 @@ def test_handshake_failure_returns_finished(dist_init):
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
connector.start_load_kv(dummy_ctx)
@@ -1393,7 +1387,6 @@ def test_transfer_setup_failure_returns_finished(dist_init):
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
connector.start_load_kv(dummy_ctx)

View File

@@ -179,7 +179,7 @@ class RequestRunner:
self._block_hasher = get_request_block_hasher(gpu_block_size, sha256)
self._dummy_ctx: ForwardContext = ForwardContext(
no_compile_layers={}, attn_metadata={}, virtual_engine=0
no_compile_layers={}, attn_metadata={}
)
def new_request(self, token_ids: list[int]):

View File

@@ -272,14 +272,9 @@ class Attention(nn.Module, AttentionLayerBase):
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
self.kv_cache = [
torch.tensor([])
for _ in range(
get_current_vllm_config().parallel_config.pipeline_parallel_size
)
]
# by bind_kv_cache this variable will not be accessed if use_direct_call
# is True
self.kv_cache = torch.tensor([])
try:
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
@@ -361,9 +356,9 @@ class Attention(nn.Module, AttentionLayerBase):
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata, output=output
self, query, key, value, self.kv_cache, attn_metadata, output=output
)
else:
torch.ops.vllm.unified_attention_with_output(
@@ -376,9 +371,9 @@ class Attention(nn.Module, AttentionLayerBase):
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata
self, query, key, value, self.kv_cache, attn_metadata
)
else:
return torch.ops.vllm.unified_attention(
@@ -644,12 +639,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.kv_cache = [
torch.tensor([])
for _ in range(
get_current_vllm_config().parallel_config.pipeline_parallel_size
)
]
self.kv_cache = torch.tensor([])
# Align with Attention's scale attributes for MLA backends.
@@ -688,7 +678,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# Mirror Attention.forward scale calculation path
if self.calculate_kv_scales and getattr(
@@ -703,14 +692,14 @@ class MLAAttention(nn.Module, AttentionLayerBase):
q,
kv_c_normed,
k_pe,
self_kv_cache,
self.kv_cache,
attn_metadata,
output=output,
)
return output
else:
return self.impl.forward(
self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
self, q, kv_c_normed, k_pe, self.kv_cache, attn_metadata
)
else:
if self.attn_backend.accept_output_buffer:
@@ -785,7 +774,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
def maybe_save_kv_layer_to_connector(
layer_name: str,
kv_cache_layer: list[torch.Tensor],
kv_cache_layer: torch.Tensor,
):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
@@ -851,10 +840,9 @@ def unified_attention(
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
output = self.impl.forward(self, query, key, value, self.kv_cache, attn_metadata)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)
return output
@@ -889,20 +877,19 @@ def unified_attention_with_output(
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self,
query,
key,
value,
kv_cache,
self.kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)
def unified_attention_with_output_fake(
@@ -938,10 +925,9 @@ def unified_mla_attention(
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self: MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
output = self.impl.forward(self, q, kv_c_normed, k_pe, self.kv_cache, attn_metadata)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)
return output
@@ -978,20 +964,19 @@ def unified_mla_attention_with_output(
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self: MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self,
q,
kv_c_normed,
k_pe,
kv_cache,
self.kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)
def unified_mla_attention_with_output_fake(

View File

@@ -200,12 +200,10 @@ class P2pNcclConnector(KVConnectorBase_V1):
# Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention
# layers like FusedMoE
kv_cache = getattr(layer, "kv_cache", None)
if kv_cache is None:
layer = getattr(layer, "kv_cache", None)
if layer is None:
continue
layer = kv_cache[forward_context.virtual_engine]
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name, remote_address
)

View File

@@ -174,12 +174,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
# Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention
# layers like FusedMoE/MLP etc.
kv_cache_attr = getattr(layer, "kv_cache", None)
if kv_cache_attr is None:
kv_cache_layer = getattr(layer, "kv_cache", None)
if kv_cache_layer is None:
continue
kv_cache_layer = kv_cache_attr[forward_context.virtual_engine]
filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes
)

View File

@@ -37,7 +37,7 @@ class BatchDescriptor(NamedTuple):
num_tokens: int
uniform_decode: bool = False
"""
False can also be used for an uniform decode batch to dispatch to the
cudagraph supporting non-uniform batches.
"""
@@ -179,8 +179,8 @@ class ForwardContext:
# copy from vllm_config.compilation_config.static_forward_context
no_compile_layers: dict[str, Any]
"""
Type AttentionMetadata for v0,
Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
attention layer to its attention metadata
Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
for each microbatch.
@@ -191,8 +191,6 @@
dict[str, "AttentionMetadata"],
list[dict[str, "AttentionMetadata"]],
]
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
dp_metadata: DPMetadata | None = None
# determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
@@ -223,7 +221,6 @@ def get_forward_context() -> ForwardContext:
def create_forward_context(
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
dp_metadata: DPMetadata | None = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None,
@@ -231,7 +228,6 @@ ):
):
return ForwardContext(
no_compile_layers=vllm_config.compilation_config.static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata,
cudagraph_runtime_mode=cudagraph_runtime_mode,
@@ -259,7 +255,6 @@ def override_forward_context(forward_context: ForwardContext | None):
def set_forward_context(
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: int | None = None,
num_tokens_across_dp: torch.Tensor | None = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
@@ -305,7 +300,6 @@ def set_forward_context(
forward_context = create_forward_context(
attn_metadata,
vllm_config,
virtual_engine,
dp_metadata,
cudagraph_runtime_mode,
batch_descriptor,
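
For reference, a minimal sketch of constructing the trimmed-down context, mirroring the dummy_ctx construction in the test hunks above. It assumes a vLLM build that already includes this change, where ForwardContext no longer accepts or stores virtual_engine:

from vllm.forward_context import ForwardContext

dummy_ctx = ForwardContext(
    no_compile_layers={},
    attn_metadata={},
)
assert not hasattr(dummy_ctx, "virtual_engine")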

View File

@@ -328,7 +328,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
if attn_metadata is not None:
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
kv_cache = self.kv_cache[0]
state_indices_tensor = attn_metadata.state_indices_tensor
num_prefills = getattr(attn_metadata, "num_prefills", 0)

View File

@@ -248,9 +248,8 @@ class MambaMixer(MambaBase, CustomOp):
assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
query_start_loc = mamba1_metadata.query_start_loc
state_indices_tensor = mamba1_metadata.state_indices_tensor
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
conv_state = self.kv_cache[0].transpose(-1, -2)
ssm_state = self.kv_cache[1]
has_initial_states = mamba1_metadata.has_initial_states
num_padded_decodes = mamba1_metadata.num_padded_decodes

View File

@@ -511,10 +511,9 @@ class MambaMixer2(MambaBase, CustomOp):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
conv_state = self.kv_cache[0].transpose(-1, -2)
ssm_state = self.kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states

View File

@@ -118,8 +118,7 @@ class ShortConv(MambaBase, CustomOp):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
conv_state = self.kv_cache[0].transpose(-1, -2)
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states_p

View File

@@ -471,7 +471,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
):
super().__init__()
self.kv_cache = [torch.tensor([])]
self.kv_cache = torch.tensor([])
self.head_dim = head_dim
self.prefix = prefix
self.cache_config = cache_config

View File

@@ -258,10 +258,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
conv_state = self.kv_cache[0].transpose(-1, -2)
ssm_state = self.kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states

View File

@@ -458,9 +458,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
non_spec_token_indx = attn_metadata.non_spec_token_indx
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
conv_state = self.kv_cache[0].transpose(-1, -2)
ssm_state = self.kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
num_accepted_tokens = attn_metadata.num_accepted_tokens

View File

@@ -2013,55 +2013,6 @@ def get_mp_context():
return multiprocessing.get_context(mp_method)
def bind_kv_cache(
ctx: dict[str, Any],
kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index]
shared_kv_cache_layers: dict[str, str] | None = None,
) -> None:
# Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
# Special things handled here:
# 1. Some models have non-attention layers, e.g., Jamba
# 2. Pipeline parallelism, each rank only has a subset of layers
# 3. Encoder attention has no kv cache
# 4. Encoder-decoder models, encoder-decoder attention and decoder-only
# attention of the same layer (e.g., bart's decoder.layers.1.self_attn
# and decoder.layers.1.encoder_attn) is mapped to the same kv cache
# tensor
# 5. Some models have attention layers that share kv cache with previous
# layers, this is specified through shared_kv_cache_layers
if shared_kv_cache_layers is None:
shared_kv_cache_layers = {}
from vllm.attention import AttentionType
from vllm.model_executor.models.utils import extract_layer_index
layer_need_kv_cache = [
layer_name
for layer_name in ctx
if (
hasattr(ctx[layer_name], "attn_type")
and ctx[layer_name].attn_type
in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)
)
and ctx[layer_name].kv_sharing_target_layer_name is None
]
layer_index_sorted = sorted(
set(extract_layer_index(layer_name) for layer_name in layer_need_kv_cache)
)
for layer_name in layer_need_kv_cache:
kv_cache_idx = layer_index_sorted.index(extract_layer_index(layer_name))
forward_ctx = ctx[layer_name]
assert len(forward_ctx.kv_cache) == len(kv_cache)
for ve, ve_kv_cache in enumerate(kv_cache):
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
if shared_kv_cache_layers is not None:
for layer_name, target_layer_name in shared_kv_cache_layers.items():
assert extract_layer_index(target_layer_name) < extract_layer_index(
layer_name
), "v0 doesn't support interleaving kv sharing"
ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache
def run_method(
obj: Any,
method: str | bytes | Callable,

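As the comments in the removed helper describe, v0's bind_kv_cache mapped a flat per-virtual-engine cache list onto attention layers by sorted layer index, skipping non-attention layers. The sketch below reproduces just that index mapping for the Jamba-style case from the removed test_bind_kv_cache_non_attention; extract_layer_index here is a simplified stand-in for vLLM's helper of the same name, not the real implementation:

import torch

def extract_layer_index(layer_name: str) -> int:
    # Simplified stand-in: take the first integer in the dotted module path.
    return next(int(part) for part in layer_name.split(".") if part.isdigit())

# Only layers 20 and 28 on this pipeline-parallel rank are attention layers.
attn_layers = ["model.layers.20.attn", "model.layers.28.attn"]
kv_cache = [torch.zeros((1,)), torch.zeros((1,))]  # one tensor per attention layer

sorted_indices = sorted(extract_layer_index(name) for name in attn_layers)
for name in attn_layers:
    cache_idx = sorted_indices.index(extract_layer_index(name))
    print(f"{name} -> kv_cache[{cache_idx}]")
# model.layers.20.attn -> kv_cache[0]
# model.layers.28.attn -> kv_cache[1]
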
View File

@@ -318,8 +318,7 @@ def bind_kv_cache(
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
forward_context[layer_name].kv_cache = kv_cache
def is_residual_scattered_for_sp(
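
A hedged sketch of the binding flow after this change, following the per-layer assignment in the final hunk: bind_kv_cache now replaces each layer's single placeholder tensor directly, with no per-virtual-engine list. FakeAttention, the layer name, and the cache shape are invented for illustration; the plain dict stands in for vLLM's static forward context:

import torch

class FakeAttention:
    def __init__(self) -> None:
        # Placeholder tensor, as in Attention.__init__ after this change.
        self.kv_cache = torch.tensor([])

forward_context = {"layers.0.self_attn": FakeAttention()}
kv_caches = {"layers.0.self_attn": torch.zeros((2, 16))}  # illustrative shape

# Mirrors the updated loop in the v1 bind_kv_cache hunk above.
for layer_name, kv_cache in kv_caches.items():
    forward_context[layer_name].kv_cache = kv_cache

assert forward_context["layers.0.self_attn"].kv_cache is kv_caches["layers.0.self_attn"]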