mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Force paged attention v2 for long contexts (#1510)
This commit is contained in:
@ -156,7 +156,9 @@ class PagedAttention(nn.Module):
|
|||||||
# sequences or heads is large, we use V1 since there is enough work
|
# sequences or heads is large, we use V1 since there is enough work
|
||||||
# to parallelize.
|
# to parallelize.
|
||||||
# TODO(woosuk): Tune this heuristic.
|
# TODO(woosuk): Tune this heuristic.
|
||||||
use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
|
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
|
||||||
|
use_v1 = input_metadata.max_context_len <= 8192 and (
|
||||||
|
max_num_partitions == 1 or num_seqs * num_heads > 512)
|
||||||
if use_v1:
|
if use_v1:
|
||||||
# Run PagedAttention V1.
|
# Run PagedAttention V1.
|
||||||
attention_ops.paged_attention_v1(
|
attention_ops.paged_attention_v1(
|
||||||
|
@ -13,7 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
|||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||||
from vllm.worker.cache_engine import CacheEngine
|
from vllm.worker.cache_engine import CacheEngine
|
||||||
from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
|
from vllm.utils import get_gpu_memory
|
||||||
|
|
||||||
|
|
||||||
class Worker:
|
class Worker:
|
||||||
@ -141,13 +141,6 @@ class Worker:
|
|||||||
self.block_size = cache_config.block_size
|
self.block_size = cache_config.block_size
|
||||||
self.sliding_window = cache_config.sliding_window
|
self.sliding_window = cache_config.sliding_window
|
||||||
|
|
||||||
if self.sliding_window is None:
|
|
||||||
max_seq_len = self.scheduler_config.max_model_len
|
|
||||||
else:
|
|
||||||
max_seq_len = min(self.scheduler_config.max_model_len,
|
|
||||||
self.sliding_window)
|
|
||||||
_check_if_can_support_max_seq_len(max_seq_len, self.block_size)
|
|
||||||
|
|
||||||
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
|
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
|
||||||
self.parallel_config)
|
self.parallel_config)
|
||||||
self.cache_events = self.cache_engine.events
|
self.cache_events = self.cache_engine.events
|
||||||
@ -421,26 +414,6 @@ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
|
|||||||
return x + [pad] * (max_len - len(x))
|
return x + [pad] * (max_len - len(x))
|
||||||
|
|
||||||
|
|
||||||
def _check_if_can_support_max_seq_len(max_seq_len: int,
|
|
||||||
block_size: int) -> None:
|
|
||||||
# Follows the logic in
|
|
||||||
# attention_kernels.cu::single_query_cached_kv_attention_launcher
|
|
||||||
max_shared_mem = get_max_shared_memory_bytes()
|
|
||||||
float32_bytes = torch.finfo(torch.float).bits // 8
|
|
||||||
padded_max_seq_len = (
|
|
||||||
(max_seq_len + block_size - 1) / block_size) * block_size
|
|
||||||
# padded_max_seq_len + extra buffer
|
|
||||||
required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
|
|
||||||
if padded_max_seq_len * float32_bytes > max_shared_mem:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"vLLM cannot currently support max_model_len={max_seq_len} "
|
|
||||||
f"with block_size={block_size} on GPU with compute "
|
|
||||||
f"capability {torch.cuda.get_device_capability()} "
|
|
||||||
f"(required shared memory {required_shared_mem} > "
|
|
||||||
f"available shared memory {max_shared_mem}). "
|
|
||||||
"This will be fixed in a future release.")
|
|
||||||
|
|
||||||
|
|
||||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||||
# Check if the GPU supports the dtype.
|
# Check if the GPU supports the dtype.
|
||||||
if torch_dtype == torch.bfloat16:
|
if torch_dtype == torch.bfloat16:
|
||||||
|
Reference in New Issue
Block a user