[V0 Deprecation] Remove unused classes in attention (#25541)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
Woosuk Kwon
2025-09-24 13:24:40 -07:00
committed by GitHub
parent 8c853050e7
commit e6750d0b18
6 changed files with 11 additions and 716 deletions

View File

@@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, AttentionType)
AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
@@ -13,7 +11,5 @@ __all__ = [
"AttentionBackend",
"AttentionMetadata",
"AttentionType",
"AttentionMetadataBuilder",
"AttentionState",
"get_attn_backend",
]
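
The practical effect of this first hunk for callers, as a hedged sketch (the import below uses only names still listed in __all__ above; the commented import is what this commit removes):

# Still resolves after this commit:
from vllm.attention import (AttentionBackend, AttentionMetadata,
                            AttentionType, get_attn_backend)
# Now raises ImportError, since the classes were removed:
# from vllm.attention import AttentionMetadataBuilder, AttentionState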

View File

@@ -2,10 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
Type, TypeVar)
from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
import torch
@@ -49,18 +46,13 @@ class AttentionBackend(ABC):
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_state_cls() -> Type["AttentionState"]:
raise NotImplementedError
@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)
@staticmethod
@abstractmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError
@staticmethod
@@ -77,149 +69,18 @@ class AttentionBackend(ABC):
def get_kv_cache_stride_order() -> Tuple[int, ...]:
raise NotImplementedError
@staticmethod
@abstractmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
raise NotImplementedError
@staticmethod
@abstractmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
raise NotImplementedError
@classmethod
def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__)
@dataclass
class AttentionMetadata:
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# Enable/disable KV scales calculation. This is so that we can disable the
# calculation until after prefill and cuda graph capture.
enable_kv_scales_calculation: bool
@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
"""Return the attention metadata that's required to run prefill
attention."""
pass
@property
@abstractmethod
def decode_metadata(self) -> Optional["AttentionMetadata"]:
"""Return the attention metadata that's required to run decode
attention."""
pass
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
if skip_fields is None:
skip_fields = set()
# Note that if we add dataclasses as fields, they will need
# similar handling.
return {
field.name: getattr(self, field.name)
for field in fields(self) if field.name not in skip_fields
}
pass
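
A small worked illustration (not part of the diff) of the slot_mapping encoding documented on AttentionMetadata above: each slot index splits into a block number and an in-block offset via divmod with the block size.

# Illustration only: decode the slot_mapping example [35, 2, 17] with block_size 16.
block_size = 16
for slot in [35, 2, 17]:
    block_number, block_offset = divmod(slot, block_size)
    print(f"slot {slot} -> block {block_number}, offset {block_offset}")
# slot 35 -> block 2, offset 3
# slot  2 -> block 0, offset 2
# slot 17 -> block 1, offset 1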
T = TypeVar("T", bound=AttentionMetadata)
class AttentionState(ABC, Generic[T]):
"""Holds attention backend-specific objects reused during the
lifetime of the model runner."""
@abstractmethod
def __init__(self, runner: Any):
...
@abstractmethod
@contextmanager
def graph_capture(self, max_batch_size: int):
"""Context manager used when capturing CUDA graphs."""
yield
@abstractmethod
def graph_clone(self, batch_size: int) -> "AttentionState[T]":
"""Clone attention state to save in CUDA graph metadata."""
...
@abstractmethod
def graph_capture_get_metadata_for_batch(
self,
batch_size: int,
is_encoder_decoder_model: bool = False) -> T:
"""Get attention metadata for CUDA graph capture of batch_size."""
...
@abstractmethod
def get_graph_input_buffers(
self,
attn_metadata: T,
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
"""Get attention-specific input buffers for CUDA graph capture."""
...
@abstractmethod
def prepare_graph_input_buffers(
self,
input_buffers: Dict[str, Any],
attn_metadata: T,
is_encoder_decoder_model: bool = False) -> None:
"""In-place modify input buffers dict for CUDA graph replay."""
...
@abstractmethod
def begin_forward(self, model_input) -> None:
"""Prepare state for forward pass."""
...
class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""
@abstractmethod
def __init__(self, input_builder) -> None:
"""Create the builder, remember some configuration and parameters."""
raise NotImplementedError
@abstractmethod
def prepare(self) -> None:
"""Prepare for one batch."""
raise NotImplementedError
@abstractmethod
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> T:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError
class AttentionLayer(Protocol):
_q_scale: torch.Tensor
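
For downstream code, the net effect of this file's hunk is a migration away from the deleted V0 classes; a hedged sketch follows (the replacement import is taken from the EagleProposer hunk at the end of this diff; the commented imports no longer resolve):

# Removed by this commit, now raise ImportError:
# from vllm.attention.backends.abstract import AttentionState
# from vllm.attention.backends.abstract import AttentionMetadataBuilder
# V1 builder used elsewhere in this same diff:
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder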

View File

@@ -1,559 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend utils"""
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Optional
import numpy as np
import torch
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState)
from vllm.attention.backends.abstract import AttentionType
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
logger = init_logger(__name__)
PAD_SLOT_ID = -1
# Switch to numpy implementation of compute_slot_mapping
# if we have at least this many elements. Could be tuned further.
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
def is_block_tables_empty(block_tables: Union[None, Dict]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if block_tables is None:
return True
return (isinstance(block_tables, dict)
and all(value is None for value in block_tables.values()))
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
context_len: int, sliding_window: int):
"""
Compute the start index of slot mapping.
"""
start_idx = 0
if is_prompt and sliding_window is not None:
start_idx = max(0, query_len - sliding_window)
return start_idx
def _compute_slot_mapping_python(slot_mapping: List[int],
block_table: List[int], range_start: int,
range_end: int, block_size: int):
for i in range(range_start, range_end):
block_number = block_table[i // block_size]
block_offset = i % block_size
slot = block_number * block_size + block_offset
slot_mapping.append(slot)
def _compute_slot_mapping_numpy(slot_mapping: List[int],
block_table: List[int], range_start: int,
range_end: int, block_size: int):
block_table_array = np.array(block_table)
idx = np.arange(range_start, range_end)
block_offset = idx % block_size
idx //= block_size
seq_slot_mapping_array = block_table_array[idx]
seq_slot_mapping_array *= block_size
seq_slot_mapping_array += block_offset
slot_mapping.extend(seq_slot_mapping_array)
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
seq_id: int, seq_len: int, context_len: int,
start_idx: int, block_size: int,
block_tables: Dict[int, List[int]]):
"""
Compute slot mapping.
"""
if is_profile_run:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping.extend([PAD_SLOT_ID] * seq_len)
return
# Mask the [0, start_idx) tokens of the prompt with
# PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
padding_mask_len = max(0, start_idx - context_len)
slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)
range_start = max(start_idx, context_len)
range_end = seq_len
numel = range_end - range_start
block_table = block_tables[seq_id]
# numpy implementation will be faster than python if we have
# many elements, otherwise it will be slower.
if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
_compute_slot_mapping_python(slot_mapping, block_table, range_start,
range_end, block_size)
else:
_compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
range_end, block_size)
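
A hedged usage sketch of the two helpers above, reproducing the sliding-window masking described in the comment inside compute_slot_mapping; the block table [0, 1, 2] is made up purely for illustration.

# Illustration only: a 10-token prompt, sliding_window=8, block_size=4.
slot_mapping: list[int] = []
start_idx = compute_slot_mapping_start_idx(is_prompt=True, query_len=10,
                                           context_len=0, sliding_window=8)
# start_idx == 2, so the first two tokens get PAD_SLOT_ID.
compute_slot_mapping(is_profile_run=False, slot_mapping=slot_mapping, seq_id=0,
                     seq_len=10, context_len=0, start_idx=start_idx,
                     block_size=4, block_tables={0: [0, 1, 2]})
print(slot_mapping)  # [-1, -1, 2, 3, 4, 5, 6, 7, 8, 9]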
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
_metadata_cls: Type[TAttentionMetadata]
def __init__(self, input_builder):
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool):
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if inter_data.prefix_cache_hit:
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.from_numpy(input_block_tables).to(
device, non_blocking=True)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, "query_lens: {}".format(query_lens)
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
device,
self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
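
A hedged lifecycle sketch of how a V0 model runner drove this now-removed builder; MyBackendMetadataBuilder (a CommonMetadataBuilder subclass that sets _metadata_cls) and input_builder are hypothetical stand-ins, not names from this diff.

# Illustration only.
builder = MyBackendMetadataBuilder(input_builder)
builder.prepare()                      # reset the per-batch lists and counters
attn_metadata = builder.build(seq_lens=[10, 1], query_lens=[10, 1],
                              cuda_graph_pad_size=-1,  # -1: CUDA graph not used
                              batch_size=2)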
class CommonAttentionState(AttentionState):
def __init__(self, runner):
self.runner = runner
self._is_graph_capturing = False
@contextmanager
def graph_capture(self, max_batch_size: int):
self._is_graph_capturing = True
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
def graph_clone(self, batch_size: int) -> "CommonAttentionState":
assert self._is_graph_capturing
return self.__class__(self.runner)
def graph_capture_get_metadata_for_batch(
self, batch_size: int, is_encoder_decoder_model: bool = False):
assert self._is_graph_capturing
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=1,
max_decode_query_len=1,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_model_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self._graph_block_tables[:batch_size],
use_cuda_graph=True,
)
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() in \
["XFORMERS", "FLASH_ATTN"], \
f"Expected attn_backend name to be either 'XFORMERS' or " \
f"'FLASH_ATTN', but got '{self.runner.attn_backend.get_name()}'"
self._update_captured_metadata_for_enc_dec_model(
batch_size=batch_size, attn_metadata=attn_metadata)
return attn_metadata
def get_graph_input_buffers(
self,
attn_metadata,
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
input_buffers = {
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() in \
["XFORMERS", "FLASH_ATTN"], \
f"Expected attn_backend name to be either 'XFORMERS' or " \
f"'FLASH_ATTN', but got '{self.runner.attn_backend.get_name()}'"
self._add_additional_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers)
return input_buffers
def prepare_graph_input_buffers(
self,
input_buffers,
attn_metadata,
is_encoder_decoder_model: bool = False) -> None:
input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() in\
["XFORMERS", "FLASH_ATTN"], \
f"Expected attn_backend name to be either 'XFORMERS' or "\
f"'FLASH_ATTN', but "\
f"got '{self.runner.attn_backend.get_name()}'"
self._prepare_input_buffers_for_enc_dec_model(
attn_metadata, input_buffers)
def begin_forward(self, model_input) -> None:
return
def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
attn_metadata):
"""
Updates the attention metadata parameters for CUDA graph capture in an
encoder-decoder model.
This method modifies attention-related tensors and metadata required
for CUDA graph capture in encoder-decoder models. Specifically, it
updates the cross-attention and encoder sequence tensors in the
AttentionMetadata object.
"""
# During decode phase the cross_slot_mapping will be empty. Hence set
# an empty tensor for CUDA Graph capture.
attn_metadata.cross_slot_mapping = torch.tensor(
[], dtype=torch.int).cuda()
attn_metadata.cross_block_tables = torch.full(
(batch_size, self.runner.get_max_block_per_batch()),
1,
dtype=torch.int).cuda()
attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
1,
dtype=torch.int).cuda()
attn_metadata.encoder_seq_lens_tensor = torch.full(
(batch_size, ), 1, dtype=torch.int).cuda()
attn_metadata.max_encoder_seq_len = self.runner.max_model_len
attn_metadata.num_encoder_tokens = 0
def _add_additional_input_buffers_for_enc_dec_model(
self, attn_metadata, input_buffers: Dict[str, Any]):
"""
Saves additional input buffers specific to the encoder-decoder model
from the attention metadata.
This method extracts and stores encoder-decoder related input buffers
from the `attn_metadata` into the `input_buffers` dictionary. The
buffers include encoder sequence lengths, cross-slot mappings, and
cross-block tables, which are essential for the encoder-decoder model
during CUDA graph replay.
"""
input_buffers["encoder_seq_lens_tensor"] = (
attn_metadata.decode_metadata.encoder_seq_lens_tensor)
input_buffers["cross_slot_mapping"] = (
attn_metadata.decode_metadata.cross_slot_mapping)
input_buffers["cross_block_tables"] = (
attn_metadata.decode_metadata.cross_block_tables)
def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
input_buffers: Dict[str,
Any]):
"""
Populates input buffers with data from the encoder-decoder model's
attention metadata.
This method fills the input buffers with encoder-decoder specific
tensors. It copies data from the `attn_metadata` and keyword arguments
(`kwargs`) into corresponding buffers in the `input_buffers` dictionary.
The copied data includes attention-related metadata as well as input
IDs and positional information for the encoder.
"""
input_buffers["encoder_seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.encoder_seq_lens_tensor,
non_blocking=True)
input_buffers["cross_slot_mapping"].copy_(
attn_metadata.decode_metadata.cross_slot_mapping,
non_blocking=True)
input_buffers["cross_block_tables"].copy_(
attn_metadata.decode_metadata.cross_block_tables,
non_blocking=True)
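
The removed CommonAttentionState existed for the V0 CUDA-graph path; below is a hedged sketch of its capture/replay flow, assuming a runner object that provides the attributes used above (device, graph_block_tables, attn_backend, ...).

# Illustration only.
state = CommonAttentionState(runner)
with state.graph_capture(max_batch_size=8):
    attn_metadata = state.graph_capture_get_metadata_for_batch(batch_size=8)
    input_buffers = state.get_graph_input_buffers(attn_metadata)
# At replay time the captured buffers are refreshed in place:
# state.prepare_graph_input_buffers(input_buffers, new_attn_metadata)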
def is_all_encoder_attn_metadata_set(attn_metadata):
'''
All attention metadata required for encoder attention is set.
'''
return ((attn_metadata.encoder_seq_lens is not None)
and (attn_metadata.encoder_seq_lens_tensor is not None)
and (attn_metadata.max_encoder_seq_len is not None))
def is_all_cross_attn_metadata_set(attn_metadata):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return (attn_metadata.is_all_encoder_attn_metadata_set
and (attn_metadata.cross_slot_mapping is not None)
and (attn_metadata.cross_block_tables is not None))
def get_seq_len_block_table_args(
attn_metadata,
is_prompt: bool,
attn_type: str,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_lens_tensor, max_seq_len,
attn_metadata.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_num_prefill_decode_query_kv_tokens(
attn_metadata,
attn_type: str,
) -> Tuple[int, int, int]:
"""
Calculate the number of prefill and decode tokens for query, key/value
based on the attention metadata and the specified attention type.
Args:
attn_metadata (AttentionMetadata): Attention Metadata object.
attn_type (AttentionType): The type of attention being used.
Returns:
Tuple[int, int, int]: A tuple containing three integers:
- The number of prefill query tokens.
- The number of prefill key/value tokens.
- The number of decode query tokens.
Raises:
AssertionError: If the number of encoder tokens in `attn_metadata`
is `None` when required for the calculations.
"""
num_prefill_query_tokens = 0
num_decode_query_tokens = 0
num_prefill_kv_tokens = 0
if attn_type == AttentionType.ENCODER:
# Encoder attention is only invoked during prefill phase.
# The same input serves as both query and key.
assert attn_metadata.num_encoder_tokens is not None
num_prefill_query_tokens = attn_metadata.num_encoder_tokens
num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
num_decode_query_tokens = 0
elif attn_type == AttentionType.ENCODER_DECODER:
assert attn_metadata.num_encoder_tokens is not None
num_prefill_query_tokens = attn_metadata.num_prefill_tokens
# The key is the encoder/cross-attention.
num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
num_decode_query_tokens = attn_metadata.num_decode_tokens
else: # attn_type == AttentionType.DECODER or
# attn_type == AttentionType.ENCODER_ONLY
num_prefill_query_tokens = attn_metadata.num_prefill_tokens
num_prefill_kv_tokens = attn_metadata.num_prefill_tokens
num_decode_query_tokens = attn_metadata.num_decode_tokens
return (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens)
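
A small hedged check of the decoder branch above, using a stand-in namespace object in place of the real attention metadata (only the two fields that branch reads are needed):

# Illustration only.
from types import SimpleNamespace
md = SimpleNamespace(num_prefill_tokens=7, num_decode_tokens=3)
print(get_num_prefill_decode_query_kv_tokens(md, AttentionType.DECODER))
# -> (7, 7, 3): prefill query tokens, prefill KV tokens, decode query tokens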
@dataclass
class MLADims:

View File

@@ -11,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -65,10 +64,6 @@ class TorchSDPABackend(AttentionBackend):
def get_metadata_cls() -> type["AttentionMetadata"]:
return TorchSDPAMetadata
@staticmethod
def get_state_cls() -> type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
return TorchSDPAMetadataBuilderV1
@@ -835,16 +830,6 @@ class _PagedAttention:
blocksparse_head_sliding_step,
)
@staticmethod
def copy_blocks(
kv_caches: list[torch.Tensor],
src_to_dists: torch.Tensor,
*args,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
class _IPEXPagedAttention(_PagedAttention):

View File

@@ -8,7 +8,6 @@ import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, next_power_of_2
@@ -97,10 +96,6 @@ class PallasAttentionBackend(AttentionBackend):
def get_metadata_cls() -> type["PallasMetadata"]:
return PallasMetadata
@staticmethod
def get_state_cls() -> type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,

View File

@@ -9,7 +9,6 @@ import numpy as np
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadataBuilder
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
@@ -25,7 +24,8 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
TreeAttentionMetadataBuilder)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -184,8 +184,9 @@ class EagleProposer:
builder = (self._get_attention_metadata_builder()
if self.attn_metadata_builder is None else
self.attn_metadata_builder)
attn_metadata = builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0)
attn_metadata = builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata,
draft_index=0)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
@@ -319,7 +320,7 @@ class EagleProposer:
exceeds_max_model_len, PADDING_SLOT_ID)
# Rebuild attention metadata
attn_metadata = builder.build_for_drafting(
attn_metadata = builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1)
for layer_name in self.attn_layer_names: