mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00

Compare commits: d31f7844f8 ... zhuohan/re (2 commits)

Commits: a2599dca0f, 3fd66b1e73
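Both commits apply one pattern that repeats through every hunk below: a layer's kv_cache changes from a per-virtual-engine list indexed by ForwardContext.virtual_engine to a single tensor (or a single tuple of states for the Mamba-style layers), and the virtual_engine plumbing is dropped from ForwardContext, the connectors, and the tests. A standalone sketch of the access-pattern change (toy classes, not vLLM code; names only mirror the diff):

```python
import torch


class OldStyleLayer:
    """Pre-change shape: one placeholder tensor per virtual engine."""

    def __init__(self, pipeline_parallel_size: int) -> None:
        self.kv_cache = [torch.tensor([]) for _ in range(pipeline_parallel_size)]

    def read_cache(self, virtual_engine: int) -> torch.Tensor:
        # Callers had to index with forward_context.virtual_engine.
        return self.kv_cache[virtual_engine]


class NewStyleLayer:
    """Post-change shape: one placeholder tensor, rebound later by bind_kv_cache."""

    def __init__(self) -> None:
        self.kv_cache = torch.tensor([])

    def read_cache(self) -> torch.Tensor:
        # Callers use the tensor directly; no virtual_engine indexing.
        return self.kv_cache


old = OldStyleLayer(pipeline_parallel_size=2)
new = NewStyleLayer()
print(old.read_cache(virtual_engine=0).shape, new.read_cache().shape)
```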
@@ -24,7 +24,6 @@ from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
 from vllm.utils import (
     FlexibleArgumentParser,
     MemorySnapshot,
-    bind_kv_cache,
     common_broadcastable_dtype,
     current_stream,
     get_open_port,
@@ -343,87 +342,6 @@ def test_memory_profiling():
     lib.cudaFree(handle2)


-def test_bind_kv_cache():
-    from vllm.attention import Attention
-
-    ctx = {
-        "layers.0.self_attn": Attention(32, 128, 0.1),
-        "layers.1.self_attn": Attention(32, 128, 0.1),
-        "layers.2.self_attn": Attention(32, 128, 0.1),
-        "layers.3.self_attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    bind_kv_cache(ctx, [kv_cache])
-    assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[2]
-    assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[3]
-
-
-def test_bind_kv_cache_kv_sharing():
-    from vllm.attention import Attention
-
-    ctx = {
-        "layers.0.self_attn": Attention(32, 128, 0.1),
-        "layers.1.self_attn": Attention(32, 128, 0.1),
-        "layers.2.self_attn": Attention(32, 128, 0.1),
-        "layers.3.self_attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    shared_kv_cache_layers = {
-        "layers.2.self_attn": "layers.1.self_attn",
-        "layers.3.self_attn": "layers.0.self_attn",
-    }
-    bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
-    assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[0]
-
-
-def test_bind_kv_cache_non_attention():
-    from vllm.attention import Attention
-
-    # example from Jamba PP=2
-    ctx = {
-        "model.layers.20.attn": Attention(32, 128, 0.1),
-        "model.layers.28.attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    bind_kv_cache(ctx, [kv_cache])
-    assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache[1]
-
-
-def test_bind_kv_cache_pp():
-    with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
-        # this test runs with 1 GPU, but we simulate 2 GPUs
-        cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
-    with set_current_vllm_config(cfg):
-        from vllm.attention import Attention
-
-        ctx = {
-            "layers.0.self_attn": Attention(32, 128, 0.1),
-        }
-        kv_cache = [[torch.zeros((1,))], [torch.zeros((1,))]]
-        bind_kv_cache(ctx, kv_cache)
-        assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0][0]
-        assert ctx["layers.0.self_attn"].kv_cache[1] is kv_cache[1][0]
-
-
 @pytest.mark.parametrize(
     ("src_dtype", "tgt_dtype", "expected_result"),
     [
@@ -382,7 +382,6 @@ class TestNixlHandshake:
         dummy_ctx = ForwardContext(
             no_compile_layers={},
             attn_metadata={},
-            virtual_engine=0,
         )
         _before_load = time.perf_counter()
         connector.start_load_kv(dummy_ctx)
@@ -450,7 +449,6 @@ class TestNixlHandshake:
         dummy_ctx = ForwardContext(
             no_compile_layers={},
             attn_metadata={},
-            virtual_engine=0,
         )
         _before_load = time.perf_counter()
         connector.start_load_kv(dummy_ctx)
@@ -506,7 +504,6 @@ class TestNixlHandshake:
         dummy_ctx = ForwardContext(
             no_compile_layers={},
             attn_metadata={},
-            virtual_engine=0,
         )
         _before_load = time.perf_counter()
         connector.start_load_kv(dummy_ctx)
@@ -666,7 +663,6 @@ def test_kv_connector_stats(dist_init):
     dummy_ctx = ForwardContext(
         no_compile_layers={},
         attn_metadata={},
-        virtual_engine=0,
     )
     connector.start_load_kv(dummy_ctx)

@@ -1241,7 +1237,6 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
     dummy_ctx = ForwardContext(
         no_compile_layers={},
         attn_metadata={},
-        virtual_engine=0,
     )
     connector.start_load_kv(dummy_ctx)

@@ -1344,7 +1339,6 @@ def test_handshake_failure_returns_finished(dist_init):
     dummy_ctx = ForwardContext(
         no_compile_layers={},
         attn_metadata={},
-        virtual_engine=0,
     )
     connector.start_load_kv(dummy_ctx)

@@ -1393,7 +1387,6 @@ def test_transfer_setup_failure_returns_finished(dist_init):
     dummy_ctx = ForwardContext(
         no_compile_layers={},
         attn_metadata={},
-        virtual_engine=0,
     )
     connector.start_load_kv(dummy_ctx)

@@ -179,7 +179,7 @@ class RequestRunner:
         self._block_hasher = get_request_block_hasher(gpu_block_size, sha256)

         self._dummy_ctx: ForwardContext = ForwardContext(
-            no_compile_layers={}, attn_metadata={}, virtual_engine=0
+            no_compile_layers={}, attn_metadata={}
         )

     def new_request(self, token_ids: list[int]):
@@ -272,14 +272,9 @@ class Attention(nn.Module, AttentionLayerBase):
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

         # use a placeholder kv cache tensor during init, which will be replaced
-        # by bind_kv_cache
-        # this variable will not be accessed if use_direct_call is True
-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
-        ]
+        # by bind_kv_cache this variable will not be accessed if use_direct_call
+        # is True
+        self.kv_cache = torch.tensor([])

         try:
             self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
@@ -361,9 +356,9 @@ class Attention(nn.Module, AttentionLayerBase):
             attn_metadata = forward_context.attn_metadata
             if isinstance(attn_metadata, dict):
                 attn_metadata = attn_metadata[self.layer_name]
-            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             self.impl.forward(
-                self, query, key, value, self_kv_cache, attn_metadata, output=output
+                self, query, key, value, self.kv_cache, attn_metadata, output=output
             )
         else:
             torch.ops.vllm.unified_attention_with_output(
@@ -376,9 +371,9 @@ class Attention(nn.Module, AttentionLayerBase):
             attn_metadata = forward_context.attn_metadata
             if isinstance(attn_metadata, dict):
                 attn_metadata = attn_metadata[self.layer_name]
-            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             return self.impl.forward(
-                self, query, key, value, self_kv_cache, attn_metadata
+                self, query, key, value, self.kv_cache, attn_metadata
             )
         else:
             return torch.ops.vllm.unified_attention(
@@ -644,12 +639,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self

-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
-        ]
+        self.kv_cache = torch.tensor([])

         # Align with Attention's scale attributes for MLA backends.

@@ -688,7 +678,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             attn_metadata = forward_context.attn_metadata
             if isinstance(attn_metadata, dict):
                 attn_metadata = attn_metadata[self.layer_name]
-            self_kv_cache = self.kv_cache[forward_context.virtual_engine]

             # Mirror Attention.forward scale calculation path
             if self.calculate_kv_scales and getattr(
@@ -703,14 +692,14 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                     q,
                     kv_c_normed,
                     k_pe,
-                    self_kv_cache,
+                    self.kv_cache,
                     attn_metadata,
                     output=output,
                 )
                 return output
             else:
                 return self.impl.forward(
-                    self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
+                    self, q, kv_c_normed, k_pe, self.kv_cache, attn_metadata
                 )
         else:
             if self.attn_backend.accept_output_buffer:
@@ -785,7 +774,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):

 def maybe_save_kv_layer_to_connector(
     layer_name: str,
-    kv_cache_layer: list[torch.Tensor],
+    kv_cache_layer: torch.Tensor,
 ):
     if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
         return
@@ -851,10 +840,9 @@ def unified_attention(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
-    output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
+    output = self.impl.forward(self, query, key, value, self.kv_cache, attn_metadata)

-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)
     return output


@@ -889,20 +877,19 @@ def unified_attention_with_output(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(
         self,
         query,
         key,
         value,
-        kv_cache,
+        self.kv_cache,
         attn_metadata,
         output=output,
         output_scale=output_scale,
         output_block_scale=output_block_scale,
     )

-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)


 def unified_attention_with_output_fake(
@@ -938,10 +925,9 @@ def unified_mla_attention(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     self: MLAAttention = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
-    output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
+    output = self.impl.forward(self, q, kv_c_normed, k_pe, self.kv_cache, attn_metadata)

-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)
     return output


@@ -978,20 +964,19 @@ def unified_mla_attention_with_output(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     self: MLAAttention = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(
         self,
         q,
         kv_c_normed,
         k_pe,
-        kv_cache,
+        self.kv_cache,
         attn_metadata,
         output=output,
         output_scale=output_scale,
         output_block_scale=output_block_scale,
     )

-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)


 def unified_mla_attention_with_output_fake(
@@ -200,12 +200,10 @@ class P2pNcclConnector(KVConnectorBase_V1):
             # Only process layers that have kv_cache
             # attribute (attention layers) Skip non-attention
             # layers like FusedMoE
-            kv_cache = getattr(layer, "kv_cache", None)
-            if kv_cache is None:
+            layer = getattr(layer, "kv_cache", None)
+            if layer is None:
                 continue

-            layer = kv_cache[forward_context.virtual_engine]
-
             kv_cache = self.p2p_nccl_engine.recv_tensor(
                 request.request_id + "#" + layer_name, remote_address
             )
@@ -174,12 +174,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
            # Only process layers that have kv_cache
            # attribute (attention layers) Skip non-attention
            # layers like FusedMoE/MLP etc.
-           kv_cache_attr = getattr(layer, "kv_cache", None)
-           if kv_cache_attr is None:
+           kv_cache_layer = getattr(layer, "kv_cache", None)
+           if kv_cache_layer is None:
                continue

-           kv_cache_layer = kv_cache_attr[forward_context.virtual_engine]
-
            filename = self._generate_filename_debug(
                layer_name, request.token_ids, request.mm_hashes
            )
@@ -37,7 +37,7 @@ class BatchDescriptor(NamedTuple):
     num_tokens: int
     uniform_decode: bool = False
     """
     False can also be used for an uniform decode batch to dispatch to the
     cudagraph supporting non-uniform batches.
     """

@@ -179,8 +179,8 @@ class ForwardContext:
     # copy from vllm_config.compilation_config.static_forward_context
     no_compile_layers: dict[str, Any]
     """
     Type AttentionMetadata for v0,
     Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
     attention layer to its attention metadata
     Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
     for each microbatch.
@@ -191,8 +191,6 @@ class ForwardContext:
         dict[str, "AttentionMetadata"],
         list[dict[str, "AttentionMetadata"]],
     ]
-    # TODO: remove after making all virtual_engines share the same kv cache
-    virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
     dp_metadata: DPMetadata | None = None
     # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
@@ -223,7 +221,6 @@ def get_forward_context() -> ForwardContext:
 def create_forward_context(
     attn_metadata: Any,
     vllm_config: VllmConfig,
-    virtual_engine: int = 0,
     dp_metadata: DPMetadata | None = None,
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
     batch_descriptor: BatchDescriptor | None = None,
@@ -231,7 +228,6 @@ def create_forward_context(
 ):
     return ForwardContext(
         no_compile_layers=vllm_config.compilation_config.static_forward_context,
-        virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
         dp_metadata=dp_metadata,
         cudagraph_runtime_mode=cudagraph_runtime_mode,
@@ -259,7 +255,6 @@ def override_forward_context(forward_context: ForwardContext | None):
 def set_forward_context(
     attn_metadata: Any,
     vllm_config: VllmConfig,
-    virtual_engine: int = 0,
     num_tokens: int | None = None,
     num_tokens_across_dp: torch.Tensor | None = None,
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
@@ -305,7 +300,6 @@ def set_forward_context(
     forward_context = create_forward_context(
         attn_metadata,
         vllm_config,
-        virtual_engine,
         dp_metadata,
         cudagraph_runtime_mode,
         batch_descriptor,
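Taken together, the forward_context hunks remove virtual_engine from the ForwardContext dataclass and from the create_forward_context / set_forward_context signatures. A minimal sketch of constructing the post-change context, mirroring the dummy contexts in the updated tests above (assumes these two commits are applied; the remaining optional fields keep their defaults):

```python
from vllm.forward_context import ForwardContext

# Only no_compile_layers and attn_metadata remain required; there is no
# virtual_engine field to pass anymore.
dummy_ctx = ForwardContext(
    no_compile_layers={},
    attn_metadata={},
)
print(dummy_ctx.dp_metadata)  # None by default, per the dataclass fields shown above
```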
@@ -328,7 +328,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
         if attn_metadata is not None:
-            kv_cache = self.kv_cache[forward_context.virtual_engine][0]
+            kv_cache = self.kv_cache[0]
             state_indices_tensor = attn_metadata.state_indices_tensor

             num_prefills = getattr(attn_metadata, "num_prefills", 0)
@@ -248,9 +248,8 @@ class MambaMixer(MambaBase, CustomOp):
         assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
         query_start_loc = mamba1_metadata.query_start_loc
         state_indices_tensor = mamba1_metadata.state_indices_tensor
-        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
-        conv_state = self_kv_cache[0].transpose(-1, -2)
-        ssm_state = self_kv_cache[1]
+        conv_state = self.kv_cache[0].transpose(-1, -2)
+        ssm_state = self.kv_cache[1]
         has_initial_states = mamba1_metadata.has_initial_states
         num_padded_decodes = mamba1_metadata.num_padded_decodes

@@ -511,10 +511,9 @@ class MambaMixer2(MambaBase, CustomOp):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, Mamba2AttentionMetadata)
-        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
         # conv_state = (..., dim, width-1) yet contiguous along 'dim'
-        conv_state = self_kv_cache[0].transpose(-1, -2)
-        ssm_state = self_kv_cache[1]
+        conv_state = self.kv_cache[0].transpose(-1, -2)
+        ssm_state = self.kv_cache[1]
         state_indices_tensor = attn_metadata.state_indices_tensor
         has_initial_states_p = attn_metadata.has_initial_states_p
         prep_initial_states = attn_metadata.prep_initial_states
@@ -118,8 +118,7 @@ class ShortConv(MambaBase, CustomOp):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, ShortConvAttentionMetadata)
-        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
-        conv_state = self_kv_cache[0].transpose(-1, -2)
+        conv_state = self.kv_cache[0].transpose(-1, -2)
         state_indices_tensor = attn_metadata.state_indices_tensor
         has_initial_states_p = attn_metadata.has_initial_states_p

@@ -471,7 +471,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
         self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
     ):
         super().__init__()
-        self.kv_cache = [torch.tensor([])]
+        self.kv_cache = torch.tensor([])
         self.head_dim = head_dim
         self.prefix = prefix
         self.cache_config = cache_config
@@ -258,10 +258,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, Mamba2AttentionMetadata)
-        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
         # conv_state = (..., dim, width-1) yet contiguous along 'dim'
-        conv_state = self_kv_cache[0].transpose(-1, -2)
-        ssm_state = self_kv_cache[1]
+        conv_state = self.kv_cache[0].transpose(-1, -2)
+        ssm_state = self.kv_cache[1]
         state_indices_tensor = attn_metadata.state_indices_tensor
         has_initial_states_p = attn_metadata.has_initial_states_p
         prep_initial_states = attn_metadata.prep_initial_states
@@ -458,9 +458,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
-        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
-        conv_state = self_kv_cache[0].transpose(-1, -2)
-        ssm_state = self_kv_cache[1]
+        conv_state = self.kv_cache[0].transpose(-1, -2)
+        ssm_state = self.kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens

@@ -2013,55 +2013,6 @@ def get_mp_context():
     return multiprocessing.get_context(mp_method)


-def bind_kv_cache(
-    ctx: dict[str, Any],
-    kv_cache: list[list[torch.Tensor]],  # [virtual_engine][layer_index]
-    shared_kv_cache_layers: dict[str, str] | None = None,
-) -> None:
-    # Bind the kv_cache tensor to Attention modules, similar to
-    # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
-    # Special things handled here:
-    # 1. Some models have non-attention layers, e.g., Jamba
-    # 2. Pipeline parallelism, each rank only has a subset of layers
-    # 3. Encoder attention has no kv cache
-    # 4. Encoder-decoder models, encoder-decoder attention and decoder-only
-    #    attention of the same layer (e.g., bart's decoder.layers.1.self_attn
-    #    and decoder.layers.1.encoder_attn) is mapped to the same kv cache
-    #    tensor
-    # 5. Some models have attention layers that share kv cache with previous
-    #    layers, this is specified through shared_kv_cache_layers
-    if shared_kv_cache_layers is None:
-        shared_kv_cache_layers = {}
-    from vllm.attention import AttentionType
-    from vllm.model_executor.models.utils import extract_layer_index
-
-    layer_need_kv_cache = [
-        layer_name
-        for layer_name in ctx
-        if (
-            hasattr(ctx[layer_name], "attn_type")
-            and ctx[layer_name].attn_type
-            in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)
-        )
-        and ctx[layer_name].kv_sharing_target_layer_name is None
-    ]
-    layer_index_sorted = sorted(
-        set(extract_layer_index(layer_name) for layer_name in layer_need_kv_cache)
-    )
-    for layer_name in layer_need_kv_cache:
-        kv_cache_idx = layer_index_sorted.index(extract_layer_index(layer_name))
-        forward_ctx = ctx[layer_name]
-        assert len(forward_ctx.kv_cache) == len(kv_cache)
-        for ve, ve_kv_cache in enumerate(kv_cache):
-            forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
-    if shared_kv_cache_layers is not None:
-        for layer_name, target_layer_name in shared_kv_cache_layers.items():
-            assert extract_layer_index(target_layer_name) < extract_layer_index(
-                layer_name
-            ), "v0 doesn't support interleaving kv sharing"
-            ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache
-
-
 def run_method(
     obj: Any,
     method: str | bytes | Callable,
@@ -318,8 +318,7 @@ def bind_kv_cache(

     # Bind kv_caches to forward context
     for layer_name, kv_cache in kv_caches.items():
-        # NOTE: Use list because of v0 PP virtual engine.
-        forward_context[layer_name].kv_cache = [kv_cache]
+        forward_context[layer_name].kv_cache = kv_cache


 def is_residual_scattered_for_sp(
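A toy, self-contained rendering of the binding loop in the hunk above (SimpleNamespace stands in for the real attention modules): each layer's kv_cache attribute now holds the kv-cache tensor itself rather than a one-element list.

```python
from types import SimpleNamespace

import torch

# Stand-ins for the worker's forward_context (layer name -> module) and kv_caches.
forward_context = {"layers.0.self_attn": SimpleNamespace(kv_cache=torch.tensor([]))}
kv_caches = {"layers.0.self_attn": torch.zeros((1,))}

# Bind kv_caches to forward context: assign the tensor, not [kv_cache].
for layer_name, kv_cache in kv_caches.items():
    forward_context[layer_name].kv_cache = kv_cache

assert forward_context["layers.0.self_attn"].kv_cache is kv_caches["layers.0.self_attn"]
```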