Force paged attention v2 for long contexts (#1510)

This commit is contained in:
Antoni Baum
2023-11-01 16:24:32 -07:00
committed by GitHub
parent 1fe0990023
commit 9738b84a08
2 changed files with 4 additions and 29 deletions

View File

@ -156,7 +156,9 @@ class PagedAttention(nn.Module):
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = input_metadata.max_context_len <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
# Run PagedAttention V1.
attention_ops.paged_attention_v1(

View File

@ -13,7 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
from vllm.utils import get_gpu_memory
class Worker:
@ -141,13 +141,6 @@ class Worker:
self.block_size = cache_config.block_size
self.sliding_window = cache_config.sliding_window
if self.sliding_window is None:
max_seq_len = self.scheduler_config.max_model_len
else:
max_seq_len = min(self.scheduler_config.max_model_len,
self.sliding_window)
_check_if_can_support_max_seq_len(max_seq_len, self.block_size)
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config)
self.cache_events = self.cache_engine.events
@ -421,26 +414,6 @@ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
return x + [pad] * (max_len - len(x))
def _check_if_can_support_max_seq_len(max_seq_len: int,
block_size: int) -> None:
# Follows the logic in
# attention_kernels.cu::single_query_cached_kv_attention_launcher
max_shared_mem = get_max_shared_memory_bytes()
float32_bytes = torch.finfo(torch.float).bits // 8
padded_max_seq_len = (
(max_seq_len + block_size - 1) / block_size) * block_size
# padded_max_seq_len + extra buffer
required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
if padded_max_seq_len * float32_bytes > max_shared_mem:
raise RuntimeError(
f"vLLM cannot currently support max_model_len={max_seq_len} "
f"with block_size={block_size} on GPU with compute "
f"capability {torch.cuda.get_device_capability()} "
f"(required shared memory {required_shared_mem} > "
f"available shared memory {max_shared_mem}). "
"This will be fixed in a future release.")
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: