Mirror of https://github.com/vllm-project/vllm.git
Compare commits
7 Commits
39c4a4cdb5
1ccf100c6a
248c5b632d
950f349492
61bb55f3d5
0bddb6b9a5
c715fb19e5
@@ -89,4 +89,4 @@ repos:
         name: Suggestion
         entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
         language: system
         verbose: true
@@ -8,10 +8,10 @@ prompts = [
     "The future of AI is",
 ]
 # Create a sampling params object.
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)

 # Create an LLM.
-llm = LLM(model="facebook/opt-125m")
+llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
@@ -19,4 +19,4 @@ outputs = llm.generate(prompts, sampling_params)
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@@ -20,7 +20,7 @@ TASK = "gsm8k"
 FILTER = "exact_match,strict-match"
 RTOL = 0.03
 EXPECTED_VALUE = 0.58
-DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"]
+DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
 MORE_ARGS_LIST = [
     [],  # Default
     ["--enable-chunked-prefill"],  # Chunked
@@ -66,14 +66,21 @@ def run_test(more_args):
     ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"


-@pytest.mark.skipif(not current_platform.is_cuda(),
-                    reason="V1 currently only supported on CUDA")
+@pytest.mark.skipif(not current_platform.is_cuda()
+                    and not current_platform.is_tpu(),
+                    reason="V1 currently only supported on CUDA and TPU")
 def test_lm_eval_accuracy_v1_engine(monkeypatch):
     """Run with the V1 Engine."""

     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
-        run_test([])
+        more_args = []
+
+        # Limit compilation time for V1
+        if current_platform.is_tpu():
+            more_args = ["--max-num-seqs", "64"]
+
+        run_test(more_args)


 @pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
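A hedged usage sketch (not part of the compare view itself): the test above opts into the V1 engine through the VLLM_USE_V1 environment variable, and the example script earlier in this diff switches to a small instruct model with tight memory limits. Putting the two together would look roughly like the snippet below; the exact model/flag combination is an assumption taken from those two files, not something this page guarantees.

# Sketch: enable the V1 engine the way the test does (env var), then run the
# example script's model and limits. Assumes a working vLLM installation.
import os

os.environ["VLLM_USE_V1"] = "1"  # the test sets this via monkeypatch

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16)
outputs = llm.generate(["The future of AI is"], SamplingParams())
for output in outputs:
    print(f"Prompt: {output.prompt!r}, "
          f"Generated text: {output.outputs[0].text!r}")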
@@ -34,4 +34,4 @@ run_mypy vllm/plugins
 run_mypy vllm/prompt_adapter
 run_mypy vllm/spec_decode
 run_mypy vllm/worker
 run_mypy vllm/v1
@@ -135,7 +135,7 @@ class CudaPlatformBase(Platform):
         else:
             if envs.VLLM_USE_V1:
                 parallel_config.worker_cls = \
-                    "vllm.v1.worker.gpu_worker.Worker"
+                    "vllm.v1.worker.gpu_worker.GPUWorker"
             else:
                 parallel_config.worker_cls = "vllm.worker.worker.Worker"

@@ -32,6 +32,7 @@ class _Backend(enum.Enum):
     FLASHINFER = enum.auto()
     HPU_ATTN = enum.auto()
     PALLAS = enum.auto()
+    PALLAS_VLLM_V1 = enum.auto()
     IPEX = enum.auto()
     BLOCK_SPARSE_FLASH_ATTN = enum.auto()
     NO_ATTENTION = enum.auto()
@@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Optional

 import torch

+import vllm.envs as envs
 from vllm.logger import init_logger

 from .interface import Platform, PlatformEnum, _Backend
@@ -30,10 +31,16 @@ class TpuPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool) -> str:
-        if selected_backend != _Backend.PALLAS:
+        if (selected_backend != _Backend.PALLAS
+                and selected_backend != _Backend.PALLAS_VLLM_V1):
             logger.info("Cannot use %s backend on TPU.", selected_backend)
-        logger.info("Using Pallas backend.")
-        return "vllm.attention.backends.pallas.PallasAttentionBackend"
+
+        if use_v1:
+            logger.info("Using Pallas V1 backend.")
+            return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
+        else:
+            logger.info("Using Pallas backend.")
+            return "vllm.attention.backends.pallas.PallasAttentionBackend"

     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
@@ -45,7 +52,7 @@ class TpuPlatform(Platform):

     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return True
+        return not envs.VLLM_USE_V1

     @classmethod
     def inference_mode(cls):
@@ -60,11 +67,11 @@ class TpuPlatform(Platform):
             cache_config.block_size = 16

         compilation_config = vllm_config.compilation_config
-        if compilation_config.level == CompilationLevel.NO_COMPILATION:
-            # TPU does not support NO_COMPILATION
+
+        # TPU only supports DYNAMO_ONCE compilation level
+        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
+            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
             compilation_config.level = CompilationLevel.DYNAMO_ONCE
-        assert compilation_config.level < CompilationLevel.PIECEWISE,\
-            "TPU does not support Inductor."

         if compilation_config.backend == "":
             compilation_config.backend = "openxla"
@@ -72,10 +79,6 @@ class TpuPlatform(Platform):
         assert vllm_config.speculative_config is None, \
             "TPU does not support speculative decoding"

-        assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
-            "Chunked prefill is not yet supported for TPU backend")
-        assert not vllm_config.speculative_config, (
-            "Speculative decoding is not yet supported for TPU backend")
         if vllm_config.model_config.dtype in (torch.float16, torch.float32):
             logger.warning(
                 "The TPU backend currently does not support %s. "
@@ -85,8 +88,27 @@ class TpuPlatform(Platform):
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
-            if scheduler_config.is_multi_step:
+            if envs.VLLM_USE_V1:
                 parallel_config.worker_cls = \
-                    "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                    "vllm.v1.worker.tpu_worker.TPUWorker"
             else:
-                parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
+                if scheduler_config.is_multi_step:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.tpu_worker.TPUWorker"
+
+        # Adjust scheduler config for V1
+        # TODO: Add support for these
+        if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching:
+            logger.warning("[V1][TPU] Disable prefix caching")
+            vllm_config.cache_config.enable_prefix_caching = False
+
+        assert not vllm_config.speculative_config, (
+            "Speculative decoding is not yet supported for TPU backend")
+
+    @classmethod
+    def is_pin_memory_available(cls):
+        logger.warning("Pin memory is not supported on TPU.")
+        return False
vllm/v1/attention/backends/pallas.py (new file, 351 lines)
@@ -0,0 +1,351 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch_xla.experimental.custom_kernel  # Required to register custom ops.

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState


class PallasAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "PALLAS_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_kv_heads, num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
            k_cache[:, dst_indices] = k_cache[:, src_indices]
            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
            v_cache[:, dst_indices] = v_cache[:, src_indices]


@dataclass
class PallasMetadata(AttentionMetadata):

    # Currently, input sequences can only contain all prefills
    # or all decoding.
    block_tables: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None
    effective_query_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        if self.num_prefills == 0:
            return None

        assert self.num_decode_tokens == 0
        return self

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.block_tables is not None
        assert self.context_lens is not None
        return self


class PallasAttentionBackendImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if head_size % 128 != 0:
            raise NotImplementedError("Head size must be a multiple of 128.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")
        if logits_soft_cap is not None:
            raise NotImplementedError(
                "Attention logits soft-capping is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        self.megacore_mode = None
        tpu_env = torch_xla.tpu.get_tpu_env()
        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
                    or tpu_env.get("TYPE", None)
                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
        assert tpu_type is not None
        tpu_type = tpu_type.lower()

        if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
                with shape [0] for profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """

        if attn_metadata is None:
            if output is None:
                output = torch.ones_like(query)
            return output

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        if kv_cache[0].numel() > 0:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            if attn_metadata.block_tables is None:
                # Prefill without paged KV cache.
                assert seq_len % 16 == 0, (
                    "Pallas FlashAttention kernel requires seq_len to be a "
                    f"multiple of 16 but got {seq_len}")

                # Handle GQA/MQA.
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                    key = key.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=-2)
                    value = value.view(batch_size, seq_len, self.num_heads,
                                       self.head_size)
                # FlashAttention kernel requires the input shape to be
                # [batch_size, num_heads, seq_len, d_model]
                # while the input is [batch_size, seq_len, num_heads, d_model].
                # Permute the input to match the required format.
                output = torch.ops.xla.flash_attention(
                    query.permute(0, 2, 1, 3),
                    key.permute(0, 2, 1, 3),
                    value.permute(0, 2, 1, 3),
                    True,
                )
                output = output.permute(0, 2, 1, 3)
            else:
                # Prefill with paged KV cache.
                # TODO(woosuk): Tune the below knobs.
                num_kv_pages_per_compute_block = 16
                num_queries_per_compute_block = 16
                assert seq_len % num_queries_per_compute_block == 0
                output = torch.ops.xla.multi_queries_paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    attn_metadata.effective_query_lens,
                    num_kv_pages_per_compute_block,
                    num_queries_per_compute_block,
                    use_kernel=True,
                )
        else:
            # Decoding run.
            assert kv_cache[0].numel() > 0
            query = query.squeeze(dim=1)
            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.

            assert attn_metadata.block_tables is not None
            assert attn_metadata.context_lens is not None
            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
            # block table in SMEM. Therefore, if the block table is too large,
            # the kernel compilation will fail. To avoid this, we split the
            # batch dimension into smaller chunks and run the kernel multiple
            # times.
            MAX_SMEM_USAGE = 512 * 1024
            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
            max_num_seq = MAX_SMEM_USAGE // size_per_seq

            if batch_size <= max_num_seq:
                output = paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    self.megacore_mode,
                )
            else:
                chunk_size = max_num_seq
                # Make sure the chunk size is a multiple of 2.
                chunk_size = chunk_size // 2 * 2
                num_chunks = (batch_size + chunk_size - 1) // chunk_size

                output = torch.empty_like(query)
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = chunk_start + chunk_size
                    # NOTE(woosuk): We skip this line because it causes Dynamo
                    # compilation error. Instead, we rely on the slice operation
                    # to handle the out-of-bound case.
                    # chunk_end = min(chunk_end, batch_size)
                    chunk_output = paged_attention(
                        query[chunk_start:chunk_end],
                        key_cache,
                        value_cache,
                        attn_metadata.context_lens[chunk_start:chunk_end],
                        attn_metadata.block_tables[chunk_start:chunk_end],
                        pages_per_compute_block,
                        self.megacore_mode,
                    )
                    output[chunk_start:chunk_end] = chunk_output

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)


def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

    key = key.flatten(0, 2)
    value = value.flatten(0, 2)
    key_cache = key_cache.flatten(0, 2)
    value_cache = value_cache.flatten(0, 2)
    key_cache.index_copy_(0, slot_mapping, key)
    value_cache.index_copy_(0, slot_mapping, value)


def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
) -> torch.Tensor:
    batch_size = query.shape[0]
    if megacore_mode == "batch" and batch_size % 2 != 0:
        megacore_mode = None
    else:
        megacore_mode = megacore_mode

    # NOTE(woosuk): A temporary workaround to avoid the error:
    # "xla::paged_attention() Expected a value of type 'str' for
    # argument 'megacore_mode' but instead found type 'NoneType'."
    if megacore_mode is not None:
        output = torch.ops.xla.paged_attention(
            query,
            key_cache,
            value_cache,
            context_lens,
            block_tables,
            pages_per_compute_block,
            megacore_mode=megacore_mode,
        )
    else:
        output = torch.ops.xla.paged_attention(
            query,
            key_cache,
            value_cache,
            context_lens,
            block_tables,
            pages_per_compute_block,
        )
    return output
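For orientation, a small illustrative sketch of the KV-cache layout the new backend declares. The sizes are assumed values chosen for the example; only the shape contract from get_kv_cache_shape() above is mirrored, so this runs with plain PyTorch/NumPy semantics and does not need torch_xla or a TPU.

# Illustrative sketch (assumed sizes): the V1 Pallas backend lays each layer's
# K and V caches out as (num_kv_heads, num_blocks, block_size, head_size).
num_blocks, block_size, num_kv_heads, head_size = 512, 16, 8, 128  # assumed

kv_cache_shape = (num_kv_heads, num_blocks, block_size, head_size)
print(kv_cache_shape)  # (8, 512, 16, 128)

# write_to_kv_cache() flattens the first three dims and scatters new keys and
# values into that flattened view with index_copy_, so the number of
# addressable cache slots per layer is:
num_slots = num_kv_heads * num_blocks * block_size
print(num_slots)  # 65536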
@@ -72,7 +72,7 @@ class InputBatch:
         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
         self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
-        self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
+        self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32)

         # Block table.
         self.block_table = BlockTable(
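For context on the one-line InputBatch change above (an illustration, not taken from the diff): np.empty returns uninitialized memory, so entries of num_computed_tokens_cpu could hold arbitrary values for request slots that have not been written yet, while np.zeros guarantees a clean zero start.

# Minimal illustration of np.empty vs np.zeros; the np.empty contents are
# whatever happened to be in memory, so its printout is nondeterministic.
import numpy as np

max_num_reqs = 4  # assumed small value for the sketch
uninitialized = np.empty(max_num_reqs, dtype=np.int32)  # arbitrary contents
initialized = np.zeros(max_num_reqs, dtype=np.int32)    # always [0 0 0 0]
print(uninitialized, initialized)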
@@ -5,32 +5,23 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 import numpy as np
 import torch
 import torch.distributed
-import torch.nn as nn

-from vllm.attention.backends.abstract import AttentionType
-from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
-from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.sampling_params import SamplingType
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, cdiv, is_pin_memory_available)
+from vllm.utils import DeviceMemoryProfiler, cdiv
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
-from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
-from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheSpec)
+from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.utils import bind_kv_cache
-from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase

 if TYPE_CHECKING:
     from vllm.v1.core.scheduler import SchedulerOutput
@@ -38,87 +29,17 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)


-class GPUModelRunner:
+class GPUModelRunner(ModelRunnerBase):

     def __init__(
         self,
         vllm_config: VllmConfig,
         device: torch.device,
     ):
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
-        self.cache_config = vllm_config.cache_config
-        self.lora_config = vllm_config.lora_config
-        self.load_config = vllm_config.load_config
-        self.parallel_config = vllm_config.parallel_config
-        self.scheduler_config = vllm_config.scheduler_config
-        self.speculative_config = vllm_config.speculative_config
-        self.prompt_adapter_config = vllm_config.prompt_adapter_config
-        self.observability_config = vllm_config.observability_config
-
-        model_config = self.model_config
-        cache_config = self.cache_config
-        scheduler_config = self.scheduler_config
-        parallel_config = self.parallel_config
-        self.device = device
-        self.pin_memory = is_pin_memory_available()
-        self.dtype = self.model_config.dtype
-        if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
-        else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
-
-        self.is_multimodal_model = model_config.is_multimodal_model
-        self.sliding_window = model_config.get_sliding_window()
-        self.block_size = cache_config.block_size
-        self.max_model_len = model_config.max_model_len
-        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = scheduler_config.max_num_seqs
-
-        # Model-related.
-        self.num_attn_layers = model_config.get_num_layers_by_block_type(
-            parallel_config, LayerBlockType.attention)
-        self.num_query_heads = model_config.get_num_attention_heads(
-            parallel_config)
-        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
-        self.head_size = model_config.get_head_size()
-        self.hidden_size = model_config.get_hidden_size()
-
-        # Multi-modal data support
-        self.input_registry = INPUT_REGISTRY
-        self.mm_registry = MULTIMODAL_REGISTRY
-
-        # NOTE: Initialized input mapper is only used for processing dummy
-        # multimodal data into multimodal kwargs for GPU memory profiling.
-        self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
-        self.mm_input_mapper_profiling.use_cache = False
-
-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
-        )
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size
-
-        # Lazy initialization
-        # self.model: nn.Module  # Set after load_model
+        super().__init__(vllm_config, device)
+
+        # KV caches for forward pass
         self.kv_caches: List[torch.Tensor] = []
-        # req_id -> (input_id -> encoder_output)
-        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
-
-        # Request states.
-        self.requests: Dict[str, CachedRequestState] = {}
-        # Persistent batch.
-        self.input_batch = InputBatch(
-            max_num_reqs=self.max_num_reqs,
-            max_model_len=self.max_model_len,
-            max_num_blocks_per_req=self.max_num_blocks_per_req,
-            device=self.device,
-            pin_memory=self.pin_memory,
-            vocab_size=model_config.get_vocab_size(),
-        )

         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE
@@ -202,132 +123,6 @@ class GPUModelRunner:
                                          pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()

-    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
-        # Remove stopped requests from the cached states.
-        # Keep the states of the pre-empted requests.
-        for req_id in scheduler_output.finished_req_ids:
-            self.requests.pop(req_id, None)
-            self.encoder_cache.pop(req_id, None)
-
-        # Free the cached encoder outputs.
-        for req_id, input_id in scheduler_output.free_encoder_input_ids:
-            encoder_outputs = self.encoder_cache.get(req_id)
-            if encoder_outputs is not None:
-                encoder_outputs.pop(input_id, None)
-                if not encoder_outputs:
-                    self.encoder_cache.pop(req_id, None)
-
-        # Remove the requests from the persistent batch.
-        stopped_req_ids = set().union(
-            scheduler_output.preempted_req_ids,
-            scheduler_output.finished_req_ids,
-        )
-        removed_req_indices: List[int] = []
-        for req_id in stopped_req_ids:
-            req_index = self.input_batch.remove_request(req_id)
-            if req_index is not None:
-                removed_req_indices.append(req_index)
-
-        # Update the states of the running requests.
-        for req_data in scheduler_output.scheduled_running_reqs:
-            req_id = req_data.req_id
-            req_state = self.requests[req_id]
-            req_index = self.input_batch.req_id_to_index[req_id]
-
-            # Update the num_computed_tokens.
-            req_state.num_computed_tokens = req_data.num_computed_tokens
-            self.input_batch.num_computed_tokens_cpu[req_index] = (
-                req_data.num_computed_tokens)
-
-            # Update the block table.
-            num_new_blocks = len(req_data.new_block_ids)
-            if num_new_blocks == 0:
-                continue
-            start_index = len(req_state.block_ids)
-            req_state.block_ids.extend(req_data.new_block_ids)
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
-
-        req_ids_to_add: List[str] = []
-        # Add new requests to the cached states.
-        for new_req_data in scheduler_output.scheduled_new_reqs:
-            req_id = new_req_data.req_id
-            sampling_params = new_req_data.sampling_params
-            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                generator = torch.Generator(device=self.device)
-                generator.manual_seed(sampling_params.seed)
-            else:
-                generator = None
-
-            self.requests[req_id] = CachedRequestState(
-                req_id=req_id,
-                prompt_token_ids=new_req_data.prompt_token_ids,
-                prompt=new_req_data.prompt,
-                mm_inputs=new_req_data.mm_inputs,
-                mm_positions=new_req_data.mm_positions,
-                sampling_params=sampling_params,
-                generator=generator,
-                block_ids=new_req_data.block_ids,
-                num_computed_tokens=new_req_data.num_computed_tokens,
-                output_token_ids=[],
-            )
-
-            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
-            if self.model_config.uses_mrope:
-                image_grid_thw = []
-                video_grid_thw = []
-                for mm_input in self.requests[req_id].mm_inputs:
-                    if mm_input.get("image_grid_thw") is not None:
-                        image_grid_thw.extend(
-                            mm_input["image_grid_thw"].tolist())
-                    if mm_input.get("video_grid_thw") is not None:
-                        video_grid_thw.extend(
-                            mm_input["video_grid_thw"].tolist())
-
-                hf_config = self.model_config.hf_config
-
-                self.requests[req_id].mrope_positions, \
-                    self.requests[req_id].mrope_position_delta = \
-                    MRotaryEmbedding.get_input_positions_tensor(
-                        self.requests[req_id].prompt_token_ids,
-                        image_grid_thw=image_grid_thw,
-                        video_grid_thw=video_grid_thw,
-                        image_token_id=hf_config.image_token_id,
-                        video_token_id=hf_config.video_token_id,
-                        vision_start_token_id=hf_config.vision_start_token_id,
-                        vision_end_token_id=hf_config.vision_end_token_id,
-                        spatial_merge_size=hf_config.vision_config.
-                        spatial_merge_size,
-                    )
-
-            req_ids_to_add.append(req_id)
-
-        # Update the cached states of the resumed requests.
-        for res_req_data in scheduler_output.scheduled_resumed_reqs:
-            req_id = res_req_data.req_id
-            req_state = self.requests[req_id]
-
-            req_state.block_ids = res_req_data.block_ids
-            req_state.num_computed_tokens = res_req_data.num_computed_tokens
-            req_ids_to_add.append(req_id)
-
-        # Add the new or resumed requests to the persistent batch.
-        # The smaller empty indices are filled first.
-        removed_req_indices = sorted(removed_req_indices, reverse=True)
-        for req_id in req_ids_to_add:
-            req_state = self.requests[req_id]
-            if removed_req_indices:
-                # Fill the empty index.
-                req_index = removed_req_indices.pop()
-            else:
-                # Append to the end.
-                req_index = None
-            self.input_batch.add_request(req_state, req_index)
-
-        # Condense the batched states if there are empty indices.
-        if removed_req_indices:
-            self.input_batch.condense(removed_req_indices)
-
     def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -611,6 +406,8 @@ class GPUModelRunner:
         return sampling_metadata

     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
+        assert self.model is not None
+
         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
         if not scheduled_encoder_inputs:
             return
@@ -698,15 +495,14 @@ class GPUModelRunner:
                 encoder_outputs.append(encoder_output[start_idx:end_idx])
         return encoder_outputs

-    def get_model(self) -> nn.Module:
-        return self.model
-
     @torch.inference_mode()
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> ModelRunnerOutput:
-        self._update_states(scheduler_output)
+        assert self.model is not None
+
+        self.update_states(scheduler_output)
+
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -833,14 +629,15 @@ class GPUModelRunner:
                     self.model_memory_usage / float(2**30))

     @torch.inference_mode()
-    def _dummy_run(
+    def dummy_run(
         self,
+        kv_caches,
         num_tokens: int,
-        kv_caches: Optional[List[torch.Tensor]] = None,
+        seq_len: Optional[int] = None,
+        exec_mode: Optional[ExecutionMode] = None,
     ) -> torch.Tensor:
-        model = self.model
-        if kv_caches is None:
-            kv_caches = self.kv_caches
+        assert self.model is not None
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -851,7 +648,7 @@ class GPUModelRunner:
         positions = self.mrope_positions[:, :num_tokens] \
             if self.model_config.uses_mrope \
             else self.positions[:num_tokens]
-        hidden_states = model(
+        hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
             kv_caches=kv_caches,
@@ -861,6 +658,7 @@ class GPUModelRunner:
         return hidden_states

     def profile_run(self) -> None:
+        assert self.model is not None
         # use an empty tensor instead of `None`` to force Dynamo to pass
         # it by reference, rather by specializing on the value `None`.
         # the `dtype` argument does not matter, and we use `float32` as
@@ -966,7 +764,7 @@ class GPUModelRunner:
             self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

         # Trigger compilation for general shape.
-        hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
+        hidden_states = self.dummy_run(dummy_kv_caches, self.max_num_tokens)
         logits = self.model.compute_logits(hidden_states, None)
         logits = logits[:self.max_num_tokens]
         # TODO(woosuk): Consider the memory usage of the sampler.
@@ -992,8 +790,8 @@ class GPUModelRunner:
             for num_tokens in reversed(self.cudagraph_batch_sizes):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
+                    self.dummy_run(None, num_tokens)
+                self.dummy_run(None, num_tokens)

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
@@ -1036,38 +834,3 @@ class GPUModelRunner:
             kv_caches,
             self.vllm_config.compilation_config.static_forward_context,
             self.kv_caches)
-
-    def get_kv_cache_spec(self) -> KVCacheSpec:
-        """
-        Generates the KVCacheSpec by parsing the kv cache format from each
-        Attention module in the static forward context.
-        Returns:
-            KVCacheSpec: A dictionary mapping layer names to their KV cache
-            format. Layers that do not need KV cache are not included.
-        """
-
-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
-        block_size = self.vllm_config.cache_config.block_size
-        kv_cache_spec: KVCacheSpec = {}
-        for layer_name, attn_module in forward_ctx.items():
-            # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention, MLA.
-            assert isinstance(attn_module, Attention)
-            if attn_module.attn_type == AttentionType.DECODER:
-                kv_cache_spec[layer_name] = FullAttentionSpec(
-                    block_size=block_size,
-                    num_kv_heads=attn_module.num_kv_heads,
-                    head_size=attn_module.head_size,
-                    dtype=attn_module.dtype,
-                )
-            elif attn_module.attn_type in (AttentionType.ENCODER,
-                                           AttentionType.ENCODER_ONLY):
-                # encoder-only attention does not need KV cache.
-                continue
-            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
-                raise NotImplementedError
-            else:
-                raise ValueError(
-                    f"Unknown attention type: {attn_module.attn_type}")
-
-        return kv_cache_spec
@@ -1,13 +1,11 @@
 """A GPU worker class."""
 import gc
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import Optional

 import torch
 import torch.distributed
-import torch.nn as nn

-import vllm.envs as envs
 from vllm.config import ParallelConfig, VllmConfig
 from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.distributed import (ensure_model_parallel_initialized,
@@ -15,20 +13,17 @@ from vllm.distributed import (ensure_model_parallel_initialized,
                               set_custom_all_reduce)
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
-from vllm.platforms import current_platform
 from vllm.utils import GiB_bytes
 from vllm.v1.core.scheduler import SchedulerOutput
-from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
+from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.worker_base import WorkerBase, check_if_gpu_supports_dtype

 logger = init_logger(__name__)

-if TYPE_CHECKING:
-    from vllm.v1.core.scheduler import SchedulerOutput
-

-class Worker:
+class GPUWorker(WorkerBase):

     def __init__(
         self,
@@ -38,46 +33,8 @@ class Worker:
         distributed_init_method: str,
         is_driver_worker: bool = False,
     ):
-        # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
-        self.cache_config = vllm_config.cache_config
-        self.lora_config = vllm_config.lora_config
-        self.load_config = vllm_config.load_config
-        self.parallel_config = vllm_config.parallel_config
-        self.scheduler_config = vllm_config.scheduler_config
-        self.device_config = vllm_config.device_config
-        self.speculative_config = vllm_config.speculative_config
-        self.prompt_adapter_config = vllm_config.prompt_adapter_config
-        self.observability_config = vllm_config.observability_config
-
-        self.parallel_config.rank = rank
-        self.local_rank = local_rank
-        self.rank = rank
-        self.distributed_init_method = distributed_init_method
-
-        if self.model_config.trust_remote_code:
-            # note: lazy import to avoid importing torch before initializing
-            from vllm.utils import init_cached_hf_modules
-            init_cached_hf_modules()
-
-        # Torch profiler. Enabled and configured through env vars:
-        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
-        if envs.VLLM_TORCH_PROFILER_DIR:
-            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
-            logger.info("Profiling enabled. Traces will be saved to: %s",
-                        torch_profiler_trace_dir)
-            self.profiler = torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                with_stack=True,
-                on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                    torch_profiler_trace_dir, use_gzip=True))
-        else:
-            self.profiler = None
+        super().__init__(vllm_config, local_rank, rank,
+                         distributed_init_method)

     def sleep(self, level: int = 1) -> None:
         free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
@@ -97,31 +54,39 @@ class Worker:
             allocator.wake_up()

     def init_device(self):
-        if self.device_config.device.type == "cuda":
-            # torch.distributed.all_reduce does not free the input tensor until
-            # the synchronization point. This causes the memory usage to grow
-            # as the number of all_reduce calls increases. This env var disables
-            # this behavior.
-            # Related issue:
-            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
-            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-            # This env var set by Ray causes exceptions with graph building.
-            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-            self.device = torch.device(f"cuda:{self.local_rank}")
-            torch.cuda.set_device(self.device)
-
-            _check_if_gpu_supports_dtype(self.model_config.dtype)
-            gc.collect()
-            torch.cuda.empty_cache()
-            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
-        else:
-            raise RuntimeError(
-                f"Not support device type: {self.device_config.device}")
+        assert self.device_config.device.type == "cuda"
+
+        # torch.distributed.all_reduce does not free the input tensor until
+        # the synchronization point. This causes the memory usage to grow
+        # as the number of all_reduce calls increases. This env var disables
+        # this behavior.
+        # Related issue:
+        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
+        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+        # torch.distributed.all_reduce does not free the input tensor until
+        # the synchronization point. This causes the memory usage to grow
+        # as the number of all_reduce calls increases. This env var disables
+        # this behavior.
+        # Related issue:
+        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
+        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+        # This env var set by Ray causes exceptions with graph building.
+        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+        self.device = torch.device(f"cuda:{self.local_rank}")
+        torch.cuda.set_device(self.device)
+
+        check_if_gpu_supports_dtype(self.model_config.dtype)
+        gc.collect()
+        torch.cuda.empty_cache()
+        self.init_gpu_memory = torch.cuda.mem_get_info()[0]
+
         # Initialize the distributed environment.
-        init_worker_distributed_environment(self.parallel_config, self.rank,
-                                            self.distributed_init_method,
-                                            self.local_rank)
+        init_cuda_worker_distributed_environment(self.parallel_config,
                                                 self.rank,
+                                                 self.distributed_init_method,
+                                                 self.local_rank)
         # Set random seed.
         set_random_seed(self.model_config.seed)

@@ -139,6 +104,7 @@ class Worker:
             from contextlib import nullcontext
             context = nullcontext()
         with context:
+            assert self.model_runner is not None
             self.model_runner.load_model()

     @torch.inference_mode()
@@ -160,6 +126,7 @@ class Worker:
         _, total_gpu_memory = torch.cuda.mem_get_info()
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
+        assert self.model_runner is not None
         self.model_runner.profile_run()

         free_gpu_memory, _ = torch.cuda.mem_get_info()
@@ -191,9 +158,6 @@ class Worker:

         return int(available_kv_cache_memory)

-    def get_kv_cache_spec(self) -> KVCacheSpec:
-        return self.model_runner.get_kv_cache_spec()
-
     def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
         if self.vllm_config.model_config.enable_sleep_mode:
@@ -203,9 +167,12 @@ class Worker:
             from contextlib import nullcontext
             context = nullcontext()
         with context:
+            assert self.model_runner is not None
             self.model_runner.initialize_kv_cache(kv_cache_config)

     def compile_or_warm_up_model(self) -> None:
+        assert self.model_runner is not None
+
         # warm up sizes that are not in cudagraph capture sizes,
         # but users still want to compile for better performance,
         # e.g. for the max-num-batched token size in chunked prefill.
@@ -217,44 +184,32 @@ class Worker:
         ]
         for size in sorted(warmup_sizes, reverse=True):
             logger.info("Compile and warming up model for size %d", size)
-            self.model_runner._dummy_run(size)
+            self.model_runner.dummy_run(None, size)

         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)

-    def get_model(self) -> nn.Module:
-        return self.model_runner.get_model()
-
     @torch.inference_mode()
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
+        assert self.model_runner is not None
         output = self.model_runner.execute_model(scheduler_output)
         return output if self.rank == 0 else None

-    def profile(self, is_start: bool = True):
-        if self.profiler is None:
-            raise RuntimeError("Profiler is not enabled.")
-        if is_start:
-            self.profiler.start()
-        else:
-            self.profiler.stop()
-
-    def check_health(self) -> None:
-        # worker will always be healthy as long as it's running.
-        return
-

-def init_worker_distributed_environment(
+def init_cuda_worker_distributed_environment(
     parallel_config: ParallelConfig,
     rank: int,
     distributed_init_method: Optional[str] = None,
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""

     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

     init_distributed_environment(parallel_config.world_size, rank,
@ -264,21 +219,22 @@ def init_worker_distributed_environment(
|
|||||||
parallel_config.pipeline_parallel_size)
|
parallel_config.pipeline_parallel_size)
|
||||||
|
|
||||||
|
|
||||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
# TODO: Remove
|
||||||
# Check if the GPU supports the dtype.
|
# def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
# # Check if the GPU supports the dtype.
|
||||||
if not current_platform.has_device_capability(80):
|
# if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||||
capability = current_platform.get_device_capability()
|
# if not current_platform.has_device_capability(80):
|
||||||
gpu_name = current_platform.get_device_name()
|
# capability = current_platform.get_device_capability()
|
||||||
|
# gpu_name = current_platform.get_device_name()
|
||||||
|
|
||||||
if capability is None:
|
# if capability is None:
|
||||||
compute_str = "does not have a compute capability"
|
# compute_str = "does not have a compute capability"
|
||||||
else:
|
# else:
|
||||||
version_str = capability.as_version_str()
|
# version_str = capability.as_version_str()
|
||||||
compute_str = f"has compute capability {version_str}"
|
# compute_str = f"has compute capability {version_str}"
|
||||||
|
|
||||||
raise ValueError(
|
# raise ValueError(
|
||||||
"Bfloat16 is only supported on GPUs with compute capability "
|
# "Bfloat16 is only supported on GPUs with compute capability "
|
||||||
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
|
# f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
|
||||||
"You can use float16 instead by explicitly setting the"
|
# "You can use float16 instead by explicitly setting the"
|
||||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
# "`dtype` flag in CLI, for example: --dtype=half.")
|
||||||
|
307
vllm/v1/worker/model_runner_base.py
Normal file
307
vllm/v1/worker/model_runner_base.py
Normal file
@ -0,0 +1,307 @@
|
|||||||
|
import enum
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionType
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.inputs import INPUT_REGISTRY
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
from vllm.sampling_params import SamplingType
|
||||||
|
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
||||||
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
|
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
|
||||||
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
|
KVCacheSpec)
|
||||||
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.v1.core.scheduler import SchedulerOutput
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionMode(enum.Enum):
|
||||||
|
PREFILL = enum.auto()
|
||||||
|
DECODE = enum.auto()
|
||||||
|
PREFIX_PREFILL = enum.auto()
|
||||||
|
|
||||||
|
def is_prefill(self) -> bool:
|
||||||
|
return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRunnerBase:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.model_config = vllm_config.model_config
|
||||||
|
self.cache_config = vllm_config.cache_config
|
||||||
|
self.lora_config = vllm_config.lora_config
|
||||||
|
self.load_config = vllm_config.load_config
|
||||||
|
self.parallel_config = vllm_config.parallel_config
|
||||||
|
self.scheduler_config = vllm_config.scheduler_config
|
||||||
|
self.speculative_config = vllm_config.speculative_config
|
||||||
|
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
||||||
|
self.observability_config = vllm_config.observability_config
|
||||||
|
self.device_config = vllm_config.device_config
|
||||||
|
|
||||||
|
model_config = self.model_config
|
||||||
|
cache_config = self.cache_config
|
||||||
|
scheduler_config = self.scheduler_config
|
||||||
|
parallel_config = self.parallel_config
|
||||||
|
self.device = device
|
||||||
|
self.pin_memory = is_pin_memory_available()
|
||||||
|
self.dtype = self.model_config.dtype
|
||||||
|
|
||||||
|
self.is_multimodal_model = model_config.is_multimodal_model
|
||||||
|
self.sliding_window = model_config.get_sliding_window()
|
||||||
|
self.block_size = cache_config.block_size
|
||||||
|
self.max_model_len = model_config.max_model_len
|
||||||
|
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
|
||||||
|
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
||||||
|
self.max_num_reqs = scheduler_config.max_num_seqs
|
||||||
|
|
||||||
|
# Model-related.
|
||||||
|
self.num_attn_layers = model_config.get_num_layers_by_block_type(
|
||||||
|
parallel_config, LayerBlockType.attention)
|
||||||
|
self.num_query_heads = model_config.get_num_attention_heads(
|
||||||
|
parallel_config)
|
||||||
|
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
|
||||||
|
self.head_size = model_config.get_head_size()
|
||||||
|
self.hidden_size = model_config.get_hidden_size()
|
||||||
|
|
||||||
|
self.model: Optional[nn.Module] = None
|
||||||
|
|
||||||
|
# Persistent batch.
|
||||||
|
self.input_batch = InputBatch(
|
||||||
|
max_num_reqs=self.max_num_reqs,
|
||||||
|
max_model_len=self.max_model_len,
|
||||||
|
max_num_blocks_per_req=self.max_num_blocks_per_req,
|
||||||
|
device=self.device,
|
||||||
|
pin_memory=self.pin_memory,
|
||||||
|
vocab_size=self.model_config.get_vocab_size(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Request states.
|
||||||
|
self.requests: Dict[str, CachedRequestState] = {}
|
||||||
|
|
||||||
|
# Multi-modal data support
|
||||||
|
self.input_registry = INPUT_REGISTRY
|
||||||
|
self.mm_registry = MULTIMODAL_REGISTRY
|
||||||
|
|
||||||
|
# NOTE: Initialized input mapper is only used for processing dummy
|
||||||
|
# multimodal data into multimodal kwargs for GPU memory profiling.
|
||||||
|
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
|
||||||
|
self.mm_input_mapper_profiling.use_cache = False
|
||||||
|
|
||||||
|
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
|
||||||
|
model_config=self.model_config,
|
||||||
|
scheduler_config=self.scheduler_config,
|
||||||
|
)
|
||||||
|
self.max_num_encoder_input_tokens = encoder_compute_budget
|
||||||
|
self.encoder_cache_size = encoder_cache_size
|
||||||
|
|
||||||
|
# req_id -> (input_id -> encoder_output)
|
||||||
|
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
|
||||||
|
|
||||||
|
def update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
|
# Remove stopped requests from the cached states.
|
||||||
|
# Keep the states of the pre-empted requests.
|
||||||
|
for req_id in scheduler_output.finished_req_ids:
|
||||||
|
self.requests.pop(req_id, None)
|
||||||
|
self.encoder_cache.pop(req_id, None)
|
||||||
|
|
||||||
|
# Free the cached encoder outputs.
|
||||||
|
for req_id, input_id in scheduler_output.free_encoder_input_ids:
|
||||||
|
encoder_outputs = self.encoder_cache.get(req_id)
|
||||||
|
if encoder_outputs is not None:
|
||||||
|
encoder_outputs.pop(input_id, None)
|
||||||
|
if not encoder_outputs:
|
||||||
|
self.encoder_cache.pop(req_id, None)
|
||||||
|
|
||||||
|
# Remove the requests from the persistent batch.
|
||||||
|
stopped_req_ids = set().union(
|
||||||
|
scheduler_output.preempted_req_ids,
|
||||||
|
scheduler_output.finished_req_ids,
|
||||||
|
)
|
||||||
|
removed_req_indices: List[int] = []
|
||||||
|
for req_id in stopped_req_ids:
|
||||||
|
req_index = self.input_batch.remove_request(req_id)
|
||||||
|
if req_index is not None:
|
||||||
|
removed_req_indices.append(req_index)
|
||||||
|
|
||||||
|
# Update the states of the running requests.
|
||||||
|
for req_data in scheduler_output.scheduled_running_reqs:
|
||||||
|
req_id = req_data.req_id
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
req_index = self.input_batch.req_id_to_index[req_id]
|
||||||
|
|
||||||
|
# Update the num_computed_tokens.
|
||||||
|
req_state.num_computed_tokens = req_data.num_computed_tokens
|
||||||
|
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
||||||
|
req_data.num_computed_tokens)
|
||||||
|
|
||||||
|
# Update the block table.
|
||||||
|
num_new_blocks = len(req_data.new_block_ids)
|
||||||
|
if num_new_blocks == 0:
|
||||||
|
continue
|
||||||
|
start_index = len(req_state.block_ids)
|
||||||
|
req_state.block_ids.extend(req_data.new_block_ids)
|
||||||
|
self.input_batch.block_table.append_row(req_index, start_index,
|
||||||
|
req_data.new_block_ids)
|
||||||
|
|
||||||
|
req_ids_to_add: List[str] = []
|
||||||
|
# Add new requests to the cached states.
|
||||||
|
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||||
|
req_id = new_req_data.req_id
|
||||||
|
sampling_params = new_req_data.sampling_params
|
||||||
|
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||||
|
generator = torch.Generator(device=self.device)
|
||||||
|
generator.manual_seed(sampling_params.seed)
|
||||||
|
else:
|
||||||
|
generator = None
|
||||||
|
|
||||||
|
self.requests[req_id] = CachedRequestState(
|
||||||
|
req_id=req_id,
|
||||||
|
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||||
|
prompt=new_req_data.prompt,
|
||||||
|
mm_inputs=new_req_data.mm_inputs,
|
||||||
|
mm_positions=new_req_data.mm_positions,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
generator=generator,
|
||||||
|
block_ids=new_req_data.block_ids,
|
||||||
|
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||||
|
output_token_ids=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||||
|
if self.model_config.uses_mrope:
|
||||||
|
image_grid_thw = []
|
||||||
|
video_grid_thw = []
|
||||||
|
for mm_input in self.requests[req_id].mm_inputs:
|
||||||
|
if mm_input.get("image_grid_thw") is not None:
|
||||||
|
image_grid_thw.extend(
|
||||||
|
mm_input["image_grid_thw"].tolist())
|
||||||
|
if mm_input.get("video_grid_thw") is not None:
|
||||||
|
video_grid_thw.extend(
|
||||||
|
mm_input["video_grid_thw"].tolist())
|
||||||
|
|
||||||
|
hf_config = self.model_config.hf_config
|
||||||
|
|
||||||
|
self.requests[req_id].mrope_positions, \
|
||||||
|
self.requests[req_id].mrope_position_delta = \
|
||||||
|
MRotaryEmbedding.get_input_positions_tensor(
|
||||||
|
self.requests[req_id].prompt_token_ids,
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
video_grid_thw=video_grid_thw,
|
||||||
|
image_token_id=hf_config.image_token_id,
|
||||||
|
video_token_id=hf_config.video_token_id,
|
||||||
|
vision_start_token_id=hf_config.vision_start_token_id,
|
||||||
|
vision_end_token_id=hf_config.vision_end_token_id,
|
||||||
|
spatial_merge_size=hf_config.vision_config.
|
||||||
|
spatial_merge_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
req_ids_to_add.append(req_id)
|
||||||
|
|
||||||
|
# Update the cached states of the resumed requests.
|
||||||
|
for res_req_data in scheduler_output.scheduled_resumed_reqs:
|
||||||
|
req_id = res_req_data.req_id
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
|
||||||
|
req_state.block_ids = res_req_data.block_ids
|
||||||
|
req_state.num_computed_tokens = res_req_data.num_computed_tokens
|
||||||
|
req_ids_to_add.append(req_id)
|
||||||
|
|
||||||
|
# Add the new or resumed requests to the persistent batch.
|
||||||
|
# The smaller empty indices are filled first.
|
||||||
|
removed_req_indices = sorted(removed_req_indices, reverse=True)
|
||||||
|
for req_id in req_ids_to_add:
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
if removed_req_indices:
|
||||||
|
# Fill the empty index.
|
||||||
|
req_index = removed_req_indices.pop()
|
||||||
|
else:
|
||||||
|
# Append to the end.
|
||||||
|
req_index = None
|
||||||
|
self.input_batch.add_request(req_state, req_index)
|
||||||
|
|
||||||
|
# Condense the batched states if there are empty indices.
|
||||||
|
if removed_req_indices:
|
||||||
|
self.input_batch.condense(removed_req_indices)
|
||||||
|
|
||||||
|
def get_model(self) -> nn.Module:
|
||||||
|
assert self.model is not None
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||||
|
"""
|
||||||
|
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||||
|
Attention module in the static forward context.
|
||||||
|
Returns:
|
||||||
|
KVCacheSpec: A dictionary mapping layer names to their KV cache
|
||||||
|
format. Layers that do not need KV cache are not included.
|
||||||
|
"""
|
||||||
|
|
||||||
|
forward_ctx = self.vllm_config.compilation_config.static_forward_context
|
||||||
|
block_size = self.vllm_config.cache_config.block_size
|
||||||
|
kv_cache_spec: KVCacheSpec = {}
|
||||||
|
for layer_name, attn_module in forward_ctx.items():
|
||||||
|
# TODO: Support other attention modules, e.g., sliding window,
|
||||||
|
# cross-attention, MLA.
|
||||||
|
assert isinstance(attn_module, Attention)
|
||||||
|
if attn_module.attn_type == AttentionType.DECODER:
|
||||||
|
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=attn_module.num_kv_heads,
|
||||||
|
head_size=attn_module.head_size,
|
||||||
|
dtype=attn_module.dtype,
|
||||||
|
)
|
||||||
|
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||||
|
AttentionType.ENCODER_ONLY):
|
||||||
|
# encoder-only attention does not need KV cache.
|
||||||
|
continue
|
||||||
|
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown attention type: {attn_module.attn_type}")
|
||||||
|
|
||||||
|
return kv_cache_spec
|
||||||
|
|
||||||
|
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
) -> ModelRunnerOutput:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def load_model(self) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def dummy_run(
|
||||||
|
self,
|
||||||
|
kv_caches,
|
||||||
|
num_tokens: int,
|
||||||
|
seq_len: Optional[int] = None,
|
||||||
|
exec_mode: Optional[ExecutionMode] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def profile_run(self) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def capture_model(self) -> None:
|
||||||
|
raise NotImplementedError()
|
729
vllm/v1/worker/tpu_model_runner.py
Normal file
729
vllm/v1/worker/tpu_model_runner.py
Normal file
@ -0,0 +1,729 @@
|
|||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
import torch.nn as nn
|
||||||
|
# TPU XLA related
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
import torch_xla.runtime as xr
|
||||||
|
|
||||||
|
from vllm.attention import AttentionMetadata
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.model_loader import get_model
|
||||||
|
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
|
||||||
|
PallasMetadata)
|
||||||
|
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
||||||
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
|
from vllm.v1.utils import bind_kv_cache
|
||||||
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
|
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.v1.core.scheduler import SchedulerOutput
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Here we utilize the behavior that out-of-bound index is ignored.
|
||||||
|
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
|
||||||
|
_PAD_SLOT_ID = 1_000_000_000
|
||||||
|
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
|
||||||
|
_ENABLE_TOP_P = False
|
||||||
|
# FIXME(woosuk): A temporary hack to support `n > 1`.
|
||||||
|
# This can significantly affect the performance if too large.
|
||||||
|
_MAX_NUM_SAMPLES = 128
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PromptInputData:
|
||||||
|
|
||||||
|
req_ids: List
|
||||||
|
prompt_lens: List
|
||||||
|
input_tokens: List
|
||||||
|
input_positions: List
|
||||||
|
attn_metadata: List
|
||||||
|
|
||||||
|
def zipped(self):
|
||||||
|
return zip(self.req_ids, self.prompt_lens, self.input_tokens,
|
||||||
|
self.input_positions, self.attn_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DecodeInputData:
|
||||||
|
req_ids: List
|
||||||
|
input_tokens: Optional[torch.Tensor] = None
|
||||||
|
input_positions: Optional[torch.Tensor] = None
|
||||||
|
attn_metadata: Optional[PallasMetadata] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TPUModelRunner(ModelRunnerBase):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
super().__init__(vllm_config, device)
|
||||||
|
|
||||||
|
# Persistent batch.
|
||||||
|
self.input_batch = InputBatch(
|
||||||
|
max_num_reqs=self.max_num_reqs,
|
||||||
|
max_model_len=self.max_model_len,
|
||||||
|
max_num_blocks_per_req=self.max_num_blocks_per_req,
|
||||||
|
device=self.device,
|
||||||
|
pin_memory=self.pin_memory,
|
||||||
|
vocab_size=self.model_config.get_vocab_size(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Request states.
|
||||||
|
self.requests: Dict[str, CachedRequestState] = {}
|
||||||
|
|
||||||
|
# KV caches for forward pass
|
||||||
|
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
|
||||||
|
|
||||||
|
# Used to initialize positions for the individual prefills
|
||||||
|
self.prefill_input_positions = torch.tensor(range(self.max_model_len),
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.int32).reshape(
|
||||||
|
1, -1)
|
||||||
|
|
||||||
|
def _prepare_prompt_inputs(
|
||||||
|
self,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
) -> PromptInputData:
|
||||||
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
|
assert total_num_scheduled_tokens > 0
|
||||||
|
num_reqs = self.input_batch.num_reqs
|
||||||
|
assert num_reqs > 0
|
||||||
|
|
||||||
|
req_ids = []
|
||||||
|
prompt_lens = []
|
||||||
|
input_tokens_list = []
|
||||||
|
input_positions_list = []
|
||||||
|
attn_metadata_list = []
|
||||||
|
for req_id in self.input_batch.req_ids[:num_reqs]:
|
||||||
|
assert req_id is not None
|
||||||
|
req_index = self.input_batch.req_id_to_index[req_id]
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
|
||||||
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||||
|
req_id]
|
||||||
|
num_computed_tokens = req_state.num_computed_tokens
|
||||||
|
num_prompt_tokens = len(req_state.prompt_token_ids)
|
||||||
|
|
||||||
|
# Detect whether this is a prompt (can be full or chunked)
|
||||||
|
if num_computed_tokens >= num_prompt_tokens:
|
||||||
|
# This is a decode => Skip
|
||||||
|
continue
|
||||||
|
|
||||||
|
# This is a prompt
|
||||||
|
req_ids.append(req_id)
|
||||||
|
|
||||||
|
# Prompt len
|
||||||
|
prompt_len = num_scheduled_tokens
|
||||||
|
prompt_lens.append(prompt_len)
|
||||||
|
padded_prompt_len = _get_padded_prefill_len(prompt_len)
|
||||||
|
assert padded_prompt_len <= self.max_model_len
|
||||||
|
|
||||||
|
# Seq len
|
||||||
|
seq_len = num_computed_tokens + prompt_len
|
||||||
|
|
||||||
|
# Input tokens
|
||||||
|
input_tokens = torch.zeros((1, padded_prompt_len),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
input_tokens[:, :prompt_len] = torch.from_numpy(
|
||||||
|
self.input_batch.token_ids_cpu[req_index,
|
||||||
|
num_computed_tokens:seq_len])
|
||||||
|
# input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[
|
||||||
|
# req_index, num_computed_tokens:padded_seq_len].reshape(1, -1))
|
||||||
|
# input_tokens[:, prompt_len:] = 0
|
||||||
|
input_tokens_list.append(input_tokens.to(self.device))
|
||||||
|
|
||||||
|
# Input positions
|
||||||
|
input_positions = torch.zeros((1, padded_prompt_len),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
input_positions[:, :
|
||||||
|
prompt_len] = self.prefill_input_positions[:,
|
||||||
|
num_computed_tokens:
|
||||||
|
seq_len]
|
||||||
|
# input_positions[:, prompt_len:] = 0
|
||||||
|
input_positions_list.append(input_positions.to(self.device))
|
||||||
|
|
||||||
|
# Slot mapping
|
||||||
|
block_table_cpu_tensor = \
|
||||||
|
self.input_batch.block_table.get_cpu_tensor()
|
||||||
|
block_numbers = block_table_cpu_tensor[req_index,
|
||||||
|
input_positions //
|
||||||
|
self.block_size].reshape(
|
||||||
|
1, -1)
|
||||||
|
|
||||||
|
block_offsets = input_positions % self.block_size
|
||||||
|
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||||
|
slot_mapping[:, prompt_len:] = _PAD_SLOT_ID
|
||||||
|
slot_mapping = slot_mapping.long()
|
||||||
|
|
||||||
|
# Block table
|
||||||
|
block_table = None
|
||||||
|
if num_computed_tokens > 0:
|
||||||
|
block_table = block_table_cpu_tensor[req_index].unsqueeze(0)
|
||||||
|
block_table = block_table.to(self.device)
|
||||||
|
|
||||||
|
# Context len
|
||||||
|
context_len = 0
|
||||||
|
if num_computed_tokens > 0:
|
||||||
|
context_len = seq_len
|
||||||
|
context_lens = torch.tensor([context_len],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
|
||||||
|
# Effective query len
|
||||||
|
effective_query_lens = torch.tensor([prompt_len],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
|
||||||
|
# Attn metadata
|
||||||
|
attn_metadata_list.append(
|
||||||
|
PallasMetadata(
|
||||||
|
num_prefills=1,
|
||||||
|
num_prefill_tokens=0, # NOTE: This is not used.
|
||||||
|
num_decode_tokens=0,
|
||||||
|
slot_mapping=slot_mapping.to(self.device),
|
||||||
|
multi_modal_placeholder_index_maps=None,
|
||||||
|
enable_kv_scales_calculation=True,
|
||||||
|
block_tables=block_table,
|
||||||
|
context_lens=context_lens.to(self.device),
|
||||||
|
effective_query_lens=effective_query_lens.to(self.device),
|
||||||
|
))
|
||||||
|
|
||||||
|
# TODO: Remove this
|
||||||
|
# if num_computed_tokens > 0:
|
||||||
|
# print("-------------------")
|
||||||
|
# print("input_tokens.shape = {}".format(input_tokens.shape))
|
||||||
|
# print("input_positions.shape = {}".format(
|
||||||
|
# input_positions.shape))
|
||||||
|
# print("slot_mapping.shape = {}".format(slot_mapping.shape))
|
||||||
|
# print("block_table.shape = {}".format(block_table.shape))
|
||||||
|
# print("context_lens.shape = {} data = {}".format(
|
||||||
|
# context_lens.shape, context_lens))
|
||||||
|
# print("effective_query_lens.shape = {} data = {}".format(
|
||||||
|
# effective_query_lens.shape, effective_query_lens))
|
||||||
|
|
||||||
|
return PromptInputData(
|
||||||
|
req_ids=req_ids,
|
||||||
|
prompt_lens=prompt_lens,
|
||||||
|
input_tokens=input_tokens_list,
|
||||||
|
input_positions=input_positions_list,
|
||||||
|
attn_metadata=attn_metadata_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_decode_inputs(
|
||||||
|
self,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
) -> DecodeInputData:
|
||||||
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
|
assert total_num_scheduled_tokens > 0
|
||||||
|
num_reqs = self.input_batch.num_reqs
|
||||||
|
assert num_reqs > 0
|
||||||
|
|
||||||
|
block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor()
|
||||||
|
|
||||||
|
req_ids = []
|
||||||
|
req_indices = []
|
||||||
|
input_tokens = []
|
||||||
|
input_positions = []
|
||||||
|
slot_mapping = []
|
||||||
|
context_lens = []
|
||||||
|
for req_id in self.input_batch.req_ids[:num_reqs]:
|
||||||
|
assert req_id is not None
|
||||||
|
req_index = self.input_batch.req_id_to_index[req_id]
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
|
||||||
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||||
|
req_id]
|
||||||
|
num_computed_tokens = req_state.num_computed_tokens
|
||||||
|
num_prompt_tokens = len(req_state.prompt_token_ids)
|
||||||
|
|
||||||
|
# Detect whether this is a decode
|
||||||
|
if num_computed_tokens < num_prompt_tokens:
|
||||||
|
# This is a prompt => Skip
|
||||||
|
continue
|
||||||
|
|
||||||
|
# This is a decode
|
||||||
|
req_ids.append(req_id)
|
||||||
|
req_indices.append(req_index)
|
||||||
|
|
||||||
|
# Seq len
|
||||||
|
seq_len = num_computed_tokens + num_scheduled_tokens
|
||||||
|
|
||||||
|
# Sanity check decode
|
||||||
|
assert num_scheduled_tokens == 1
|
||||||
|
assert seq_len == req_state.num_tokens
|
||||||
|
|
||||||
|
# Input token
|
||||||
|
input_tokens.append([
|
||||||
|
self.input_batch.token_ids_cpu[req_index, num_computed_tokens]
|
||||||
|
])
|
||||||
|
|
||||||
|
# Position
|
||||||
|
input_positions.append([num_computed_tokens])
|
||||||
|
|
||||||
|
# Slot mapping
|
||||||
|
block_number = block_table_cpu_tensor[req_index,
|
||||||
|
num_computed_tokens //
|
||||||
|
self.block_size]
|
||||||
|
block_offset = num_computed_tokens % self.block_size
|
||||||
|
slot_id = block_number * self.block_size + block_offset
|
||||||
|
slot_mapping.append([slot_id])
|
||||||
|
|
||||||
|
# Context len
|
||||||
|
context_lens.append(seq_len)
|
||||||
|
|
||||||
|
# Compute padding
|
||||||
|
batch_size = len(input_tokens)
|
||||||
|
padded_batch_size = _get_padded_batch_size(batch_size)
|
||||||
|
num_padding = padded_batch_size - batch_size
|
||||||
|
|
||||||
|
# Add padding
|
||||||
|
input_tokens.extend([[0]] * num_padding)
|
||||||
|
input_positions.extend([[0]] * num_padding)
|
||||||
|
slot_mapping.extend([[_PAD_SLOT_ID]] * num_padding)
|
||||||
|
context_lens.extend([0] * num_padding)
|
||||||
|
req_indices.extend([0] * num_padding)
|
||||||
|
|
||||||
|
# Create tensors
|
||||||
|
input_tokens_tensor = torch.tensor(input_tokens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
input_positions_tensor = torch.tensor(input_positions,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
slot_mapping_tensor = torch.tensor(slot_mapping,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device="cpu")
|
||||||
|
context_lens_tensor = torch.tensor(context_lens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
block_tables_tensor = block_table_cpu_tensor[req_indices]
|
||||||
|
|
||||||
|
# Attn metadata
|
||||||
|
attn_metadata = PallasMetadata(
|
||||||
|
num_prefills=0,
|
||||||
|
num_prefill_tokens=0,
|
||||||
|
num_decode_tokens=padded_batch_size,
|
||||||
|
slot_mapping=slot_mapping_tensor.to(self.device),
|
||||||
|
multi_modal_placeholder_index_maps=None,
|
||||||
|
enable_kv_scales_calculation=True,
|
||||||
|
block_tables=block_tables_tensor.to(self.device),
|
||||||
|
context_lens=context_lens_tensor.to(self.device),
|
||||||
|
effective_query_lens=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return DecodeInputData(
|
||||||
|
req_ids=req_ids,
|
||||||
|
input_tokens=input_tokens_tensor.to(self.device),
|
||||||
|
input_positions=input_positions_tensor.to(self.device),
|
||||||
|
attn_metadata=attn_metadata)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
) -> ModelRunnerOutput:
|
||||||
|
# Update cached state
|
||||||
|
self.update_states(scheduler_output)
|
||||||
|
|
||||||
|
# Prepare inputs
|
||||||
|
prompt_data = self._prepare_prompt_inputs(scheduler_output)
|
||||||
|
decode_data = self._prepare_decode_inputs(scheduler_output)
|
||||||
|
|
||||||
|
# Init
|
||||||
|
num_reqs = self.input_batch.num_reqs
|
||||||
|
assert num_reqs > 0
|
||||||
|
sampled_token_ids_list = [0] * num_reqs
|
||||||
|
|
||||||
|
# Run decodes (a single batch)
|
||||||
|
if len(decode_data.req_ids) > 0:
|
||||||
|
# Forward
|
||||||
|
with set_forward_context(decode_data.attn_metadata,
|
||||||
|
self.vllm_config):
|
||||||
|
assert self.model is not None
|
||||||
|
selected_token_ids = self.model(decode_data.input_tokens,
|
||||||
|
decode_data.input_positions,
|
||||||
|
decode_data.attn_metadata,
|
||||||
|
self.kv_caches)
|
||||||
|
|
||||||
|
# Transfer sampled tokens from TPU to CPU
|
||||||
|
selected_token_ids_list = selected_token_ids.cpu().tolist()
|
||||||
|
|
||||||
|
# Update cached state
|
||||||
|
for i, req_id in enumerate(decode_data.req_ids):
|
||||||
|
req_index = self.input_batch.req_id_to_index[req_id]
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
|
||||||
|
seq_len = (req_state.num_computed_tokens +
|
||||||
|
scheduler_output.num_scheduled_tokens[req_id])
|
||||||
|
|
||||||
|
token_id = selected_token_ids_list[i]
|
||||||
|
|
||||||
|
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||||
|
self.input_batch.num_tokens[req_index] += 1
|
||||||
|
req_state.output_token_ids.append(token_id)
|
||||||
|
|
||||||
|
sampled_token_ids_list[req_index] = token_id
|
||||||
|
|
||||||
|
# Run each prompt
|
||||||
|
for (req_id, prompt_len, input_tokens, input_positions,
|
||||||
|
attn_metadata) in prompt_data.zipped():
|
||||||
|
assert req_id is not None
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
req_index = self.input_batch.req_id_to_index[req_id]
|
||||||
|
|
||||||
|
# Forward
|
||||||
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
|
assert self.model is not None
|
||||||
|
selected_token_ids = self.model(input_tokens, input_positions,
|
||||||
|
attn_metadata, self.kv_caches)
|
||||||
|
|
||||||
|
seq_len = (req_state.num_computed_tokens +
|
||||||
|
scheduler_output.num_scheduled_tokens[req_id])
|
||||||
|
if seq_len >= len(req_state.prompt_token_ids):
|
||||||
|
# Transfer sampled tokens from TPU to CPU
|
||||||
|
token_id = selected_token_ids.cpu()[prompt_len - 1].item()
|
||||||
|
sampled_token_ids_list[req_index] = token_id
|
||||||
|
|
||||||
|
# Update cached state
|
||||||
|
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||||
|
self.input_batch.num_tokens[req_index] += 1
|
||||||
|
req_state.output_token_ids.append(token_id)
|
||||||
|
|
||||||
|
# Get req_ids
|
||||||
|
assert all(
|
||||||
|
req_id is not None for req_id in
|
||||||
|
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
|
||||||
|
req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
|
||||||
|
|
||||||
|
model_runner_output = ModelRunnerOutput(
|
||||||
|
req_ids=req_ids,
|
||||||
|
req_id_to_index=self.input_batch.req_id_to_index,
|
||||||
|
sampled_token_ids=sampled_token_ids_list,
|
||||||
|
logprob_token_ids_cpu=None,
|
||||||
|
logprobs_cpu=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return model_runner_output
|
||||||
|
|
||||||
|
def load_model(self) -> None:
|
||||||
|
self.device = self.device_config.device
|
||||||
|
|
||||||
|
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
|
||||||
|
# process, the ranks can be different from the ranks internally assigned
|
||||||
|
# by the xm runtime. Therefore, there is a mismatch in the rank
|
||||||
|
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
|
||||||
|
# This is not a problem in linear layers because all-reduce is
|
||||||
|
# rank-agnostic. However, it matters for all-gather as the ranks
|
||||||
|
# determine the order of concatenating the output tensors.
|
||||||
|
# As a workaround, we use the xm's rank assignment only when loading
|
||||||
|
# the embedding weights.
|
||||||
|
xm_tp_rank = xr.global_ordinal()
|
||||||
|
with patch(
|
||||||
|
"vllm.model_executor.layers.vocab_parallel_embedding."
|
||||||
|
"get_tensor_model_parallel_rank",
|
||||||
|
return_value=xm_tp_rank):
|
||||||
|
model = get_model(vllm_config=self.vllm_config)
|
||||||
|
model = model.eval()
|
||||||
|
xm.wait_device_ops()
|
||||||
|
model = ModelWrapperV1(model)
|
||||||
|
self.model = torch.compile(model,
|
||||||
|
backend="openxla",
|
||||||
|
fullgraph=True,
|
||||||
|
dynamic=False)
|
||||||
|
|
||||||
|
def dummy_run(
|
||||||
|
self,
|
||||||
|
kv_caches,
|
||||||
|
num_tokens: int,
|
||||||
|
seq_len: Optional[int] = None,
|
||||||
|
exec_mode: Optional[ExecutionMode] = None,
|
||||||
|
) -> None:
|
||||||
|
assert seq_len is not None
|
||||||
|
assert exec_mode is not None
|
||||||
|
|
||||||
|
exec_mode = ExecutionMode(exec_mode)
|
||||||
|
if exec_mode.is_prefill():
|
||||||
|
seq_len = (seq_len + 15) // 16 * 16
|
||||||
|
token_ids = torch.zeros((num_tokens, seq_len),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
position_ids = torch.zeros((num_tokens, seq_len),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
slot_mapping = torch.zeros((num_tokens, seq_len),
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=self.device)
|
||||||
|
if exec_mode == ExecutionMode.PREFILL:
|
||||||
|
attn_metadata = PallasMetadata(
|
||||||
|
num_prefills=num_tokens,
|
||||||
|
num_prefill_tokens=num_tokens * seq_len,
|
||||||
|
num_decode_tokens=0,
|
||||||
|
slot_mapping=slot_mapping,
|
||||||
|
multi_modal_placeholder_index_maps=None,
|
||||||
|
enable_kv_scales_calculation=True,
|
||||||
|
block_tables=None,
|
||||||
|
context_lens=None,
|
||||||
|
effective_query_lens=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
context_lens = torch.ones((num_tokens, ),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
block_tables = torch.zeros(
|
||||||
|
(num_tokens, self.max_num_blocks_per_req),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
effective_query_lens = torch.ones_like(context_lens)
|
||||||
|
|
||||||
|
attn_metadata = PallasMetadata(
|
||||||
|
num_prefills=num_tokens,
|
||||||
|
num_prefill_tokens=num_tokens * seq_len,
|
||||||
|
num_decode_tokens=0,
|
||||||
|
slot_mapping=slot_mapping,
|
||||||
|
multi_modal_placeholder_index_maps=None,
|
||||||
|
enable_kv_scales_calculation=True,
|
||||||
|
block_tables=block_tables,
|
||||||
|
context_lens=context_lens,
|
||||||
|
effective_query_lens=effective_query_lens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert seq_len == 1
|
||||||
|
token_ids = torch.zeros((num_tokens, seq_len),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
position_ids = torch.zeros((num_tokens, seq_len),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
slot_mapping = torch.zeros((num_tokens, seq_len),
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=self.device)
|
||||||
|
block_tables = torch.zeros(
|
||||||
|
(num_tokens, self.max_num_blocks_per_req),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
context_lens = torch.ones((num_tokens, ),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
attn_metadata = PallasMetadata(
|
||||||
|
num_prefills=0,
|
||||||
|
num_prefill_tokens=0,
|
||||||
|
num_decode_tokens=num_tokens * seq_len,
|
||||||
|
slot_mapping=slot_mapping,
|
||||||
|
multi_modal_placeholder_index_maps=None,
|
||||||
|
enable_kv_scales_calculation=True,
|
||||||
|
block_tables=block_tables,
|
||||||
|
context_lens=context_lens,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE(woosuk): There are two stages of compilation: torch.compile and
|
||||||
|
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
|
||||||
|
# overhead by reusing the FX graph for different shapes.
|
||||||
|
# However, the XLA graph will still require static shapes and needs to
|
||||||
|
# be re-compiled for every different shapes. This overhead is inevitable
|
||||||
|
# in the first run, but can be skipped afterwards as we cache the XLA
|
||||||
|
# graphs in the disk (VLLM_XLA_CACHE_PATH).
|
||||||
|
if exec_mode.is_prefill():
|
||||||
|
# Prefll
|
||||||
|
torch._dynamo.mark_dynamic(token_ids, 1)
|
||||||
|
torch._dynamo.mark_dynamic(position_ids, 1)
|
||||||
|
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
|
||||||
|
else:
|
||||||
|
# Decode
|
||||||
|
torch._dynamo.mark_dynamic(token_ids, 0)
|
||||||
|
torch._dynamo.mark_dynamic(position_ids, 0)
|
||||||
|
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
|
||||||
|
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
|
||||||
|
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
|
||||||
|
|
||||||
|
# TODO: Remove the attn_metadata above
|
||||||
|
with set_forward_context(None, self.vllm_config):
|
||||||
|
assert self.model is not None
|
||||||
|
self.model(token_ids, position_ids, None, kv_caches)
|
||||||
|
|
||||||
|
def capture_model(self) -> None:
|
||||||
|
"""Compile the model."""
|
||||||
|
|
||||||
|
logger.info("Compiling the model with different input shapes.")
|
||||||
|
|
||||||
|
# Capture prefill shapes
|
||||||
|
start = time.perf_counter()
|
||||||
|
for batch_size in [1]:
|
||||||
|
seq_len = 16
|
||||||
|
while True:
|
||||||
|
self.dummy_run(self.kv_caches, batch_size, seq_len,
|
||||||
|
ExecutionMode.PREFILL)
|
||||||
|
xm.wait_device_ops()
|
||||||
|
logger.info(" -- batch_size: %d, seq_len: %d", batch_size,
|
||||||
|
seq_len)
|
||||||
|
|
||||||
|
if seq_len >= self.model_config.max_model_len:
|
||||||
|
break
|
||||||
|
|
||||||
|
num_tokens = batch_size * seq_len
|
||||||
|
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Move to next seq_len
|
||||||
|
seq_len = seq_len * 2
|
||||||
|
|
||||||
|
end = time.perf_counter()
|
||||||
|
logger.info("Compilation for prefill shapes is done in %.2f [secs].",
|
||||||
|
end - start)
|
||||||
|
|
||||||
|
# Capture decode shapes.
|
||||||
|
start = time.time()
|
||||||
|
seq_len = 1
|
||||||
|
batch_size = 8 # Must be in sync with _get_padded_batch_size()
|
||||||
|
while True:
|
||||||
|
self.dummy_run(self.kv_caches, batch_size, seq_len,
|
||||||
|
ExecutionMode.DECODE)
|
||||||
|
xm.wait_device_ops()
|
||||||
|
logger.info(" -- batch_size: %d, seq_len: %d, max_num_seqs = %d",
|
||||||
|
batch_size, seq_len,
|
||||||
|
self.scheduler_config.max_num_seqs)
|
||||||
|
|
||||||
|
if batch_size >= self.scheduler_config.max_num_seqs:
|
||||||
|
break
|
||||||
|
|
||||||
|
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
logger.info("Compilation for decode shapes is done in %.2f [secs].",
|
||||||
|
end - start)
|
||||||
|
|
||||||
|
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||||
|
"""
|
||||||
|
Initialize KV cache based on `kv_cache_config`.
|
||||||
|
Args:
|
||||||
|
kv_cache_config: Configuration for the KV cache, including the KV
|
||||||
|
cache size of each layer
|
||||||
|
"""
|
||||||
|
if len(kv_cache_config.groups) > 1:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Hybrid models with more than one KV cache type are not "
|
||||||
|
"supported yet.")
|
||||||
|
|
||||||
|
kv_caches: Dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
|
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
|
||||||
|
tensor_config = kv_cache_config.tensors[layer_name]
|
||||||
|
assert tensor_config.size % layer_spec.page_size_bytes == 0
|
||||||
|
num_blocks = tensor_config.size // layer_spec.page_size_bytes
|
||||||
|
if isinstance(layer_spec, FullAttentionSpec):
|
||||||
|
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
|
||||||
|
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
|
||||||
|
layer_spec.head_size)
|
||||||
|
dtype = layer_spec.dtype
|
||||||
|
|
||||||
|
tpu_k_cache = torch.zeros(kv_cache_shape,
|
||||||
|
dtype=dtype,
|
||||||
|
device=self.device)
|
||||||
|
tpu_v_cache = torch.zeros_like(tpu_k_cache)
|
||||||
|
|
||||||
|
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
bind_kv_cache(
|
||||||
|
kv_caches,
|
||||||
|
self.vllm_config.compilation_config.static_forward_context,
|
||||||
|
self.kv_caches)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelWrapperV1(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, model: nn.Module):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
token_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Executes the forward pass of the model and samples the next token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids: The input token IDs of shape [batch_size, seq_len].
|
||||||
|
position_ids: The input position IDs of shape [batch_size, seq_len].
|
||||||
|
attn_metadata: The Pallas attention metadata.
|
||||||
|
input_lens: The actual input lengths of shape [batch_size].
|
||||||
|
t: The sampling temperature of shape [batch_size].
|
||||||
|
p: The top-p probability of shape [batch_size].
|
||||||
|
num_samples: Number of samples to draw from each logits vector.
|
||||||
|
kv_caches: The key and value caches. They can be None during the
|
||||||
|
memory profiling at initialization.
|
||||||
|
"""
|
||||||
|
# Skip this in memory profiling at initialization.
|
||||||
|
if attn_metadata is not None:
|
||||||
|
# index_copy_(slot_mapping) only works when the inserted dimension
|
||||||
|
# is 0. However, the KV cache in the Pallas backend has the shape
|
||||||
|
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
|
||||||
|
# work, we need to flatten the first three dimensions and modify
|
||||||
|
# the slot_mapping accordingly.
|
||||||
|
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
|
||||||
|
slot_mapping = attn_metadata.slot_mapping
|
||||||
|
slot_mapping = slot_mapping.flatten()
|
||||||
|
head_indicies = torch.arange(0,
|
||||||
|
num_kv_heads,
|
||||||
|
device=slot_mapping.device,
|
||||||
|
dtype=slot_mapping.dtype)
|
||||||
|
head_indicies *= block_size * num_blocks
|
||||||
|
slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
|
||||||
|
-1, num_kv_heads)
|
||||||
|
slot_mapping = slot_mapping + head_indicies.view(1, -1)
|
||||||
|
slot_mapping = slot_mapping.flatten()
|
||||||
|
attn_metadata.slot_mapping = slot_mapping
|
||||||
|
|
||||||
|
assert self.model is not None
|
||||||
|
hidden_states = self.model(
|
||||||
|
token_ids,
|
||||||
|
position_ids,
|
||||||
|
kv_caches,
|
||||||
|
attn_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.flatten(0, 1)
|
||||||
|
logits = self.model.compute_logits(hidden_states, None)
|
||||||
|
|
||||||
|
# Greedy sampling.
|
||||||
|
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
|
||||||
|
argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
|
||||||
|
return argmax_token_ids
|
||||||
|
|
||||||
|
|
||||||
|
def _get_padded_prefill_len(x: int) -> int:
|
||||||
|
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
|
||||||
|
# length to be a multiple of 16. We pad the prompt length to the nearest
|
||||||
|
# multiple of 16. This is also good for performance.
|
||||||
|
if x <= 16:
|
||||||
|
return 16
|
||||||
|
return 1 << (x - 1).bit_length()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_padded_batch_size(batch_size: int) -> int:
|
||||||
|
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
|
||||||
|
# To meet this requirement in the simplest way, we set the minimal batch
|
||||||
|
# size to 8.
|
||||||
|
if batch_size <= 8:
|
||||||
|
return 8
|
||||||
|
else:
|
||||||
|
return ((batch_size + 15) // 16) * 16
|
141
vllm/v1/worker/tpu_worker.py
Normal file
141
vllm/v1/worker/tpu_worker.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
"""A TPU worker class."""
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
import torch_xla.runtime as xr
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.config import ParallelConfig, VllmConfig
|
||||||
|
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||||
|
init_distributed_environment)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor import set_random_seed
|
||||||
|
from vllm.v1.core.scheduler import SchedulerOutput
|
||||||
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
|
from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
|
||||||
|
from vllm.v1.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TPUWorker(WorkerBase):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
local_rank: int,
|
||||||
|
rank: int,
|
||||||
|
distributed_init_method: str,
|
||||||
|
is_driver_worker: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(vllm_config, local_rank, rank,
|
||||||
|
distributed_init_method)
|
||||||
|
|
||||||
|
def init_device(self):
|
||||||
|
os.environ["PJRT_DEVICE"] = "TPU"
|
||||||
|
torch.set_grad_enabled(False)
|
||||||
|
torch.set_default_dtype(self.model_config.dtype)
|
||||||
|
|
||||||
|
# Initialize the distributed environment.
|
||||||
|
init_tpu_worker_distributed_environment(self.parallel_config,
|
||||||
|
self.rank,
|
||||||
|
self.distributed_init_method,
|
||||||
|
self.local_rank)
|
||||||
|
|
||||||
|
# Device initialization should happen after initializing
|
||||||
|
# the distributed runtime.
|
||||||
|
self.device = xm.xla_device()
|
||||||
|
self.device_config.device = self.device
|
||||||
|
|
||||||
|
# Set random seed.
|
||||||
|
set_random_seed(self.model_config.seed)
|
||||||
|
xm.set_rng_state(self.model_config.seed, self.device)
|
||||||
|
|
||||||
|
# Increase the cache size limit, which is the maximum number of
|
||||||
|
# dynamo graphs that can be compiled.
|
||||||
|
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
|
||||||
|
# 30-40 graphs for decode. 128 is an arbitrary safe number.
|
||||||
|
torch._dynamo.config.cache_size_limit = 128
|
||||||
|
# Use persistent cache to avoid XLA recompilation.
|
||||||
|
# NOTE(woosuk): Set per-rank cache path since different ranks
|
||||||
|
# can have slightly different XLA graphs.
|
||||||
|
world_size = self.parallel_config.world_size
|
||||||
|
rank = xr.global_ordinal()
|
||||||
|
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
|
||||||
|
f"tp{world_size}_rank{rank}")
|
||||||
|
xr.initialize_cache(per_rank_path, readonly=False)
|
||||||
|
|
||||||
|
# Init ModelRunner here, so that we have access to self.device.
|
||||||
|
self.model_runner = TPUModelRunner(self.vllm_config, self.device)
|
||||||
|
|
||||||
|
def determine_available_memory(self) -> int:
|
||||||
|
assert self.model_runner is not None
|
||||||
|
|
||||||
|
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||||
|
|
||||||
|
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||||
|
# it by reference, rather by specializing on the value ``None``.
|
||||||
|
# the `dtype` argument does not matter, and we use `float32` as
|
||||||
|
# a placeholder (it has wide hardware support).
|
||||||
|
kv_caches = [(torch.tensor([], dtype=torch.float32,
|
||||||
|
device=self.device),
|
||||||
|
torch.tensor([], dtype=torch.float32,
|
||||||
|
device=self.device))
|
||||||
|
for _ in range(num_layers)]
|
||||||
|
|
||||||
|
self.model_runner.dummy_run(
|
||||||
|
kv_caches,
|
||||||
|
num_tokens=1,
|
||||||
|
seq_len=self.scheduler_config.max_num_batched_tokens,
|
||||||
|
exec_mode=ExecutionMode.PREFILL,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Synchronize before measuring the memory usage.
|
||||||
|
xm.wait_device_ops()
|
||||||
|
|
||||||
|
# Get the maximum amount of memory used by the model weights and
|
||||||
|
# intermediate activations.
|
||||||
|
m = xm.get_memory_info(self.device)
|
||||||
|
total_memory_size = m["bytes_limit"]
|
||||||
|
profiled = m["peak_bytes_used"] # Weights + intermediate activations.
|
||||||
|
|
||||||
|
# Calculate the TPU KV cache size based on profiling.
|
||||||
|
usable_memory_size = int(total_memory_size *
|
||||||
|
self.cache_config.gpu_memory_utilization)
|
||||||
|
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
|
||||||
|
|
||||||
|
return int(tpu_kv_cache_bytes)
|
||||||
|
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
) -> Optional[ModelRunnerOutput]:
|
||||||
|
assert self.model_runner is not None
|
||||||
|
output = self.model_runner.execute_model(scheduler_output)
|
||||||
|
return output if self.rank == 0 else None
|
||||||
|
|
||||||
|
|
||||||
|
def init_tpu_worker_distributed_environment(
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
rank: int,
|
||||||
|
distributed_init_method: Optional[str] = None,
|
||||||
|
local_rank: int = -1,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the distributed environment."""
|
||||||
|
|
||||||
|
# NOTE(woosuk): This is just to initialize the TP group and broadcast
|
||||||
|
# the input objects on CPU. The all-reduce and all-gather ops on TPU
|
||||||
|
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
|
||||||
|
# own context.
|
||||||
|
init_distributed_environment(
|
||||||
|
world_size=parallel_config.world_size,
|
||||||
|
rank=rank,
|
||||||
|
local_rank=local_rank,
|
||||||
|
distributed_init_method=distributed_init_method,
|
||||||
|
backend="gloo",
|
||||||
|
)
|
||||||
|
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||||
|
parallel_config.pipeline_parallel_size)
|
173
vllm/v1/worker/worker_base.py
Normal file
173
vllm/v1/worker/worker_base.py
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
"""A GPU worker class."""
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor import set_random_seed
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
|
||||||
|
from vllm.v1.core.scheduler import SchedulerOutput
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||||
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
|
from vllm.v1.worker.model_runner_base import ModelRunnerBase
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.v1.core.scheduler import SchedulerOutput
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerBase:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
local_rank: int,
|
||||||
|
rank: int,
|
||||||
|
distributed_init_method: str,
|
||||||
|
is_driver_worker: bool = False,
|
||||||
|
):
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.model_config = vllm_config.model_config
|
||||||
|
self.cache_config = vllm_config.cache_config
|
||||||
|
self.lora_config = vllm_config.lora_config
|
||||||
|
self.load_config = vllm_config.load_config
|
||||||
|
self.parallel_config = vllm_config.parallel_config
|
||||||
|
self.scheduler_config = vllm_config.scheduler_config
|
||||||
|
self.device_config = vllm_config.device_config
|
||||||
|
self.speculative_config = vllm_config.speculative_config
|
||||||
|
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
||||||
|
self.observability_config = vllm_config.observability_config
|
||||||
|
|
||||||
|
self.parallel_config.rank = rank
|
||||||
|
self.local_rank = local_rank
|
||||||
|
self.rank = rank
|
||||||
|
self.distributed_init_method = distributed_init_method
|
||||||
|
|
||||||
|
if self.cache_config.cache_dtype == "auto":
|
||||||
|
self.cache_dtype = self.model_config.dtype
|
||||||
|
else:
|
||||||
|
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||||
|
self.cache_config.cache_dtype]
|
||||||
|
|
||||||
|
if self.model_config.trust_remote_code:
|
||||||
|
# note: lazy import to avoid importing torch before initializing
|
||||||
|
from vllm.utils import init_cached_hf_modules
|
||||||
|
init_cached_hf_modules()
|
||||||
|
|
||||||
|
# Torch profiler. Enabled and configured through env vars:
|
||||||
|
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||||
|
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||||
|
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||||
|
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||||
|
torch_profiler_trace_dir)
|
||||||
|
self.profiler = torch.profiler.profile(
|
||||||
|
activities=[
|
||||||
|
torch.profiler.ProfilerActivity.CPU,
|
||||||
|
torch.profiler.ProfilerActivity.CUDA,
|
||||||
|
],
|
||||||
|
with_stack=True,
|
||||||
|
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||||
|
torch_profiler_trace_dir, use_gzip=True))
|
||||||
|
else:
|
||||||
|
self.profiler = None
|
||||||
|
|
||||||
|
# Initialized by the specific platform
|
||||||
|
self.model_runner: Optional[ModelRunnerBase] = None
|
||||||
|
|
||||||
|
def load_model(self) -> None:
|
||||||
|
assert self.model_runner is not None
|
||||||
|
self.model_runner.load_model()
|
||||||
|
|
||||||
|
def compile_or_warm_up_model(self) -> None:
|
||||||
|
assert self.model_runner is not None
|
||||||
|
|
||||||
|
if not self.model_config.enforce_eager:
|
||||||
|
self.model_runner.capture_model()
|
||||||
|
|
||||||
|
# Reset the seed to ensure that the random state is not affected by
|
||||||
|
# the model initialization and profiling.
|
||||||
|
set_random_seed(self.model_config.seed)
|
||||||
|
|
||||||
|
def get_model(self) -> nn.Module:
|
||||||
|
assert self.model_runner is not None
|
||||||
|
return self.model_runner.get_model()
|
||||||
|
|
||||||
|
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||||
|
assert self.model_runner is not None
|
||||||
|
return self.model_runner.get_kv_cache_spec()
|
||||||
|
|
||||||
|
def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||||
|
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||||
|
assert self.model_runner is not None
|
||||||
|
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||||
|
|
||||||
|
def profile(self, is_start: bool = True):
|
||||||
|
if self.profiler is None:
|
||||||
|
raise RuntimeError("Profiler is not enabled.")
|
||||||
|
if is_start:
|
||||||
|
self.profiler.start()
|
||||||
|
else:
|
||||||
|
self.profiler.stop()
|
||||||
|
|
||||||
|
def check_health(self) -> None:
|
||||||
|
# worker will always be healthy as long as it's running.
|
||||||
|
return
|
||||||
|
|
||||||
|
def init_device(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def determine_available_memory(self) -> int:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
) -> Optional[ModelRunnerOutput]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
def check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||||
|
# Check if the GPU supports the dtype.
|
||||||
|
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||||
|
if not current_platform.has_device_capability(80):
|
||||||
|
capability = current_platform.get_device_capability()
|
||||||
|
gpu_name = current_platform.get_device_name()
|
||||||
|
|
||||||
|
if capability is None:
|
||||||
|
compute_str = "does not have a compute capability"
|
||||||
|
else:
|
||||||
|
version_str = capability.as_version_str()
|
||||||
|
compute_str = f"has compute capability {version_str}"
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
"Bfloat16 is only supported on GPUs with compute capability "
|
||||||
|
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
|
||||||
|
"You can use float16 instead by explicitly setting the"
|
||||||
|
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_block_size(
|
||||||
|
cache_config: CacheConfig,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
) -> int:
|
||||||
|
head_size = model_config.get_head_size()
|
||||||
|
num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||||
|
num_attention_layers = model_config.get_num_layers_by_block_type(
|
||||||
|
parallel_config, LayerBlockType.attention)
|
||||||
|
|
||||||
|
key_cache_block = cache_config.block_size * num_heads * head_size
|
||||||
|
value_cache_block = key_cache_block
|
||||||
|
total = num_attention_layers * (key_cache_block + value_cache_block)
|
||||||
|
if cache_config.cache_dtype == "auto":
|
||||||
|
dtype = model_config.dtype
|
||||||
|
else:
|
||||||
|
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||||
|
dtype_size = get_dtype_size(dtype)
|
||||||
|
return dtype_size * total
|
Reference in New Issue
Block a user