mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[misc] add forward context for attention (#9029)
This commit is contained in:
@ -3,9 +3,9 @@ from typing import List, Optional, Tuple
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.attention.backends.flash_attn # noqa: F401
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.utils import seed_everything
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache)
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
@ -112,10 +112,10 @@ def test_flash_attn_with_paged_kv(
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
|
||||
output = torch.ops.vllm.flash_attn_with_kvcache(
|
||||
decode_query=query.unsqueeze(1),
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
output = flash_attn_with_kvcache(
|
||||
q=query.unsqueeze(1),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
block_table=block_tables,
|
||||
@ -123,25 +123,6 @@ def test_flash_attn_with_paged_kv(
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
).squeeze(1)
|
||||
|
||||
if num_blocks <= 2048:
|
||||
test_utils = ["test_faketensor", "test_schema"]
|
||||
else:
|
||||
test_utils = ["test_faketensor"]
|
||||
|
||||
opcheck(torch.ops.vllm.flash_attn_with_kvcache,
|
||||
args=tuple(),
|
||||
kwargs=dict(
|
||||
decode_query=query.unsqueeze(1),
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
block_table=block_tables,
|
||||
cache_seqlens=kv_lens_tensor,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
),
|
||||
test_utils=test_utils)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
@ -213,7 +194,7 @@ def test_varlen_with_paged_kv(
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
|
||||
output = torch.ops.vllm.flash_attn_varlen_func(
|
||||
output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
@ -228,29 +209,6 @@ def test_varlen_with_paged_kv(
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
)
|
||||
|
||||
if num_blocks <= 2048:
|
||||
test_utils = ["test_faketensor", "test_schema"]
|
||||
else:
|
||||
test_utils = ["test_faketensor"]
|
||||
|
||||
opcheck(torch.ops.vllm.flash_attn_varlen_func,
|
||||
args=tuple(),
|
||||
kwargs=dict(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
cu_seqlens_k=cu_kv_lens,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
block_table=block_tables,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
),
|
||||
test_utils=test_utils)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
|
@ -13,152 +13,15 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
||||
compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
|
||||
# yapf: disable
|
||||
from vllm.vllm_flash_attn import (
|
||||
flash_attn_varlen_func as _flash_attn_varlen_func)
|
||||
from vllm.vllm_flash_attn import (
|
||||
flash_attn_with_kvcache as _flash_attn_with_kvcache)
|
||||
|
||||
# yapf: enable
|
||||
|
||||
|
||||
@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
|
||||
def flash_attn_varlen_func(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
cu_seqlens_k: torch.Tensor,
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
window_size: Optional[List[int]] = None,
|
||||
softcap: float = 0.0,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
block_table: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# custom op does not support tuple input
|
||||
real_window_size: Tuple[int, int]
|
||||
if window_size is None:
|
||||
real_window_size = (-1, -1)
|
||||
else:
|
||||
assert len(window_size) == 2
|
||||
real_window_size = (window_size[0], window_size[1])
|
||||
return _flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
window_size=real_window_size,
|
||||
softcap=softcap,
|
||||
alibi_slopes=alibi_slopes,
|
||||
block_table=block_table,
|
||||
)
|
||||
|
||||
|
||||
@flash_attn_varlen_func.register_fake # type: ignore
|
||||
def _(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
cu_seqlens_k: torch.Tensor,
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
window_size: Optional[List[int]] = None,
|
||||
softcap: float = 0.0,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
block_table: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(q)
|
||||
|
||||
|
||||
@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[])
|
||||
def flash_attn_with_kvcache(
|
||||
decode_query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
cache_seqlens: Optional[torch.Tensor] = None,
|
||||
block_table: Optional[torch.Tensor] = None,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
softcap: float = 0.0,
|
||||
) -> torch.Tensor:
|
||||
return _flash_attn_with_kvcache(
|
||||
decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
cache_seqlens=cache_seqlens,
|
||||
block_table=block_table,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=softcap,
|
||||
)
|
||||
|
||||
|
||||
@flash_attn_with_kvcache.register_fake # type: ignore
|
||||
def _(
|
||||
decode_query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
cache_seqlens: Optional[torch.Tensor] = None,
|
||||
block_table: Optional[torch.Tensor] = None,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
softcap: float = 0.0,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(decode_query)
|
||||
|
||||
|
||||
@torch.library.custom_op("vllm::reshape_and_cache_flash",
|
||||
mutates_args=["kv_cache"])
|
||||
def reshape_and_cache_flash(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
) -> None:
|
||||
"""Inductor cannot deal with inplace operations on views.
|
||||
See https://github.com/pytorch/pytorch/issues/131192
|
||||
and https://github.com/pytorch/pytorch/issues/130174
|
||||
This is a workaround to hide the view operation from the inductor.
|
||||
"""
|
||||
return torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype,
|
||||
k_scale, v_scale)
|
||||
|
||||
|
||||
@reshape_and_cache_flash.register_fake # type: ignore
|
||||
def _(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
) -> None:
|
||||
pass
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
@ -721,118 +584,182 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
assert k_scale == 1.0 and v_scale == 1.0, (
|
||||
"key/v_scale is not supported in FlashAttention.")
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
output = torch.ops.vllm.unified_flash_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.num_kv_heads,
|
||||
kv_cache,
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
self.scale,
|
||||
self.sliding_window,
|
||||
self.alibi_slopes,
|
||||
self.logits_soft_cap,
|
||||
)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache = kv_cache[0]
|
||||
value_cache = kv_cache[1]
|
||||
return output
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory profiling run.
|
||||
torch.ops.vllm.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
||||
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
||||
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
|
||||
@torch.library.custom_op("vllm::unified_flash_attention",
|
||||
mutates_args=["kv_cache"])
|
||||
def unified_flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
softmax_scale: float,
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
current_metadata = get_forward_context()
|
||||
assert current_metadata is not None
|
||||
assert isinstance(current_metadata, FlashAttentionMetadata)
|
||||
attn_metadata: FlashAttentionMetadata = current_metadata
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
num_tokens, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
value = value.view(-1, num_kv_heads, head_size)
|
||||
|
||||
prefill_output: Optional[torch.Tensor] = None
|
||||
decode_output: Optional[torch.Tensor] = None
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache = kv_cache[0]
|
||||
value_cache = kv_cache[1]
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
# normal attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
prefill_output = torch.ops.vllm.flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softcap=self.logits_soft_cap,
|
||||
)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
assert prefill_meta.seq_lens is not None
|
||||
max_seq_len = max(prefill_meta.seq_lens)
|
||||
prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=prefill_meta.query_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_query_len,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_k=max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
block_table=prefill_meta.block_tables,
|
||||
softcap=self.logits_soft_cap,
|
||||
)
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory profiling run.
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
# Decoding run.
|
||||
_, num_head, head_dim = decode_query.shape
|
||||
decode_query = decode_query.reshape(-1,
|
||||
decode_meta.decode_query_len,
|
||||
num_head, head_dim)
|
||||
decode_output = torch.ops.vllm.flash_attn_with_kvcache(
|
||||
decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_table=decode_meta.block_tables,
|
||||
cache_seqlens=decode_meta.seq_lens_tensor,
|
||||
softmax_scale=self.scale,
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
||||
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
||||
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
|
||||
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
|
||||
prefill_output: Optional[torch.Tensor] = None
|
||||
decode_output: Optional[torch.Tensor] = None
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
# normal attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
prefill_output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softcap=self.logits_soft_cap,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
assert prefill_meta.seq_lens is not None
|
||||
max_seq_len = max(prefill_meta.seq_lens)
|
||||
prefill_output = flash_attn_varlen_func( # noqa
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=prefill_meta.query_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_query_len,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_k=max_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
alibi_slopes=alibi_slopes,
|
||||
block_table=prefill_meta.block_tables,
|
||||
softcap=logits_soft_cap,
|
||||
)
|
||||
|
||||
if prefill_output is None:
|
||||
assert decode_output is not None
|
||||
return decode_output.view(num_decode_tokens, hidden_size)
|
||||
if decode_output is None:
|
||||
assert prefill_output is not None
|
||||
return prefill_output.view(num_prefill_tokens, hidden_size)
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
# Decoding run.
|
||||
_, num_head, head_dim = decode_query.shape
|
||||
decode_query = decode_query.reshape(-1, decode_meta.decode_query_len,
|
||||
num_head, head_dim)
|
||||
decode_output = flash_attn_with_kvcache(
|
||||
q=decode_query,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
block_table=decode_meta.block_tables,
|
||||
cache_seqlens=decode_meta.seq_lens_tensor,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
).squeeze(1)
|
||||
|
||||
# Chunked prefill does not work with speculative decoding.
|
||||
# Therefore, the query length for decode should be 1 in chunked prefill.
|
||||
assert decode_meta is not None
|
||||
assert decode_meta.decode_query_len == 1
|
||||
decode_output = decode_output.squeeze(1)
|
||||
output = torch.cat([prefill_output, decode_output], dim=0)
|
||||
return output.view(num_tokens, hidden_size)
|
||||
if prefill_output is None:
|
||||
assert decode_output is not None
|
||||
return decode_output.view(num_decode_tokens, hidden_size)
|
||||
if decode_output is None:
|
||||
assert prefill_output is not None
|
||||
return prefill_output.view(num_prefill_tokens, hidden_size)
|
||||
|
||||
# Chunked prefill does not work with speculative decoding.
|
||||
# Therefore, the query length for decode should be 1 in chunked prefill.
|
||||
assert decode_meta is not None
|
||||
assert decode_meta.decode_query_len == 1
|
||||
decode_output = decode_output.squeeze(1)
|
||||
output = torch.cat([prefill_output, decode_output], dim=0)
|
||||
return output.view(num_tokens, hidden_size)
|
||||
|
||||
|
||||
@unified_flash_attention.register_fake
|
||||
def _(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
softmax_scale: float,
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(query)
|
||||
|
@ -7,7 +7,7 @@ try:
|
||||
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
|
||||
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
|
||||
|
||||
import vllm.attention.backends.flash_attn # noqa
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||
except ImportError:
|
||||
BatchDecodeWithPagedKVCacheWrapper = None
|
||||
@ -799,7 +799,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
# This happens when vllm runs the profiling to
|
||||
# determine the number of blocks.
|
||||
if kv_cache.numel() == 0:
|
||||
output = torch.ops.vllm.flash_attn_varlen_func(
|
||||
output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
|
22
vllm/forward_context.py
Normal file
22
vllm/forward_context.py
Normal file
@ -0,0 +1,22 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
_forward_context: Any = None
|
||||
|
||||
|
||||
def get_forward_context() -> Any:
|
||||
"""Get the current forward context."""
|
||||
return _forward_context
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_forward_context(context: Any):
|
||||
"""A context manager that stores the current forward context,
|
||||
can be attention metadata, etc."""
|
||||
global _forward_context
|
||||
prev_context = _forward_context
|
||||
_forward_context = context
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_forward_context = prev_context
|
@ -2,6 +2,7 @@ from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
|
||||
try:
|
||||
@ -291,16 +292,17 @@ class TP1DraftModelRunner(ModelRunner):
|
||||
if previous_hidden_states is not None else {}
|
||||
|
||||
# Run model
|
||||
hidden_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
**kwargs,
|
||||
)
|
||||
with set_forward_context(model_input.attn_metadata):
|
||||
hidden_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Compute the logits.
|
||||
logits = self.model.compute_logits(hidden_states,
|
||||
|
@ -6,6 +6,7 @@ import torch
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ObservabilityConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.multimodal import MultiModalInputs
|
||||
@ -119,7 +120,8 @@ class EmbeddingModelRunner(
|
||||
device=self.device),
|
||||
}
|
||||
|
||||
hidden_states = model_executable(**execute_model_kwargs)
|
||||
with set_forward_context(model_input.attn_metadata):
|
||||
hidden_states = model_executable(**execute_model_kwargs)
|
||||
|
||||
# Only perform pooling in the driver worker.
|
||||
if not self.is_driver_worker:
|
||||
|
@ -14,6 +14,7 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ObservabilityConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
@ -198,17 +199,18 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
} if self.has_seqlen_agnostic else {}
|
||||
|
||||
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
encoder_input_ids=model_input.encoder_input_tokens,
|
||||
encoder_positions=model_input.encoder_input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
**seqlen_agnostic_kwargs)
|
||||
with set_forward_context(model_input.attn_metadata):
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
encoder_input_ids=model_input.encoder_input_tokens,
|
||||
encoder_positions=model_input.encoder_input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
**seqlen_agnostic_kwargs)
|
||||
|
||||
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
||||
model_input.sampling_metadata)
|
||||
|
@ -24,6 +24,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.distributed.parallel_state import graph_capture
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
@ -1499,7 +1500,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
self._update_inputs_to_capture_for_enc_dec_model(
|
||||
capture_inputs)
|
||||
|
||||
graph_runner.capture(**capture_inputs)
|
||||
with set_forward_context(attn_metadata):
|
||||
graph_runner.capture(**capture_inputs)
|
||||
self.graph_memory_pool = graph_runner.graph.pool()
|
||||
self.graph_runners[virtual_engine][batch_size] = (
|
||||
graph_runner)
|
||||
@ -1641,15 +1643,16 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
model_forward_end = torch.cuda.Event(enable_timing=True)
|
||||
model_forward_start.record()
|
||||
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
**seqlen_agnostic_kwargs)
|
||||
with set_forward_context(model_input.attn_metadata):
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
**seqlen_agnostic_kwargs)
|
||||
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
|
Reference in New Issue
Block a user