Compare commits

...

14 Commits

Author SHA1 Message Date
70b4e46e70 compilation is fixed 2025-02-06 20:49:29 +00:00
5fb9dbe6f6 fix capture model 2025-02-06 20:18:30 +00:00
996b92ccb4 swap works! 2025-02-05 20:28:33 +00:00
2b0526fa15 works! 2025-02-05 16:54:57 +00:00
7be649256f fixes 2025-02-05 15:36:38 +00:00
627efde813 fixes 2025-02-04 22:16:19 +00:00
c2867d5bc1 Optimize decode/prompt prepare code 2025-02-04 21:12:07 +00:00
39c4a4cdb5 review comments
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
2025-01-28 23:08:50 +00:00
1ccf100c6a clean-ups
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
2025-01-28 23:08:50 +00:00
248c5b632d works
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
2025-01-28 23:08:50 +00:00
950f349492 scheduler is clean
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
2025-01-28 23:08:50 +00:00
61bb55f3d5 Chunked prompt works!
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
2025-01-28 23:08:50 +00:00
0bddb6b9a5 reorder funcs
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
2025-01-28 23:08:50 +00:00
c715fb19e5 [V1] TPU support
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
2025-01-28 23:08:50 +00:00
19 changed files with 2142 additions and 407 deletions

View File

@ -8,10 +8,10 @@ prompts = [
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)

View File

@ -5,7 +5,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers >= 4.45.2 # Required for Llama 3.2 and Qwen2-VL.
transformers >= 4.48.2 # Required for Bamba model and Transformers backend.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
@ -34,6 +34,6 @@ pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops # Required for Qwen2-VL.
compressed-tensors == 0.8.1 # required for compressed-tensors
compressed-tensors == 0.9.1 # required for compressed-tensors
depyf==0.18.0 # required for profiling and debugging with compilation config
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py

View File

@ -106,9 +106,17 @@ dnspython==2.7.0
docutils==0.16
# via awscli
einops==0.8.0
# via -r requirements-test.in
# via
# -r requirements-test.in
# encodec
# vector-quantize-pytorch
# vocos
einx==0.3.0
# via vector-quantize-pytorch
email-validator==2.2.0
# via pydantic
encodec==0.1.1
# via vocos
evaluate==0.4.3
# via lm-eval
fastparquet==2024.11.0
@ -125,6 +133,8 @@ filelock==3.16.1
# triton
fonttools==4.54.1
# via matplotlib
frozendict==2.4.6
# via einx
frozenlist==1.5.0
# via
# aiohttp
@ -159,6 +169,7 @@ huggingface-hub==0.26.2
# timm
# tokenizers
# transformers
# vocos
idna==3.10
# via
# anyio
@ -261,6 +272,8 @@ numpy==1.26.4
# cupy-cuda12x
# datasets
# decord
# einx
# encodec
# evaluate
# fastparquet
# genai-perf
@ -283,6 +296,7 @@ numpy==1.26.4
# torchvision
# transformers
# tritonclient
# vocos
nvidia-cublas-cu12==12.4.5.8
# via
# nvidia-cudnn-cu12
@ -455,6 +469,7 @@ pyyaml==6.0.2
# responses
# timm
# transformers
# vocos
ray[adag]==2.40.0
# via -r requirements-test.in
redis==5.2.0
@ -517,6 +532,7 @@ scipy==1.13.1
# scikit-learn
# sentence-transformers
# statsmodels
# vocos
sentence-transformers==3.2.1
# via -r requirements-test.in
sentencepiece==0.2.0
@ -540,7 +556,9 @@ sqlitedict==2.1.0
statsmodels==0.14.4
# via genai-perf
sympy==1.13.1
# via torch
# via
# einx
# torch
tabledata==1.3.3
# via pytablewriter
tabulate==0.9.0
@ -568,12 +586,21 @@ torch==2.5.1
# -r requirements-test.in
# accelerate
# bitsandbytes
# encodec
# lm-eval
# peft
# sentence-transformers
# tensorizer
# timm
# torchaudio
# torchvision
# vector-quantize-pytorch
# vocos
torchaudio==2.5.1
# via
# -r requirements-test.in
# encodec
# vocos
torchvision==0.20.1
# via timm
tqdm==4.66.6
@ -584,13 +611,15 @@ tqdm==4.66.6
# lm-eval
# nltk
# peft
# pqdm
# sentence-transformers
# tqdm-multiprocess
# transformers
tqdm-multiprocess==0.0.11
# via lm-eval
transformers==4.47.0
transformers==4.48.2
# via
# -r requirements-test.in
# genai-perf
# lm-eval
# peft
@ -615,6 +644,7 @@ typing-extensions==4.12.2
# huggingface-hub
# librosa
# mistral-common
# pqdm
# pydantic
# pydantic-core
# torch
@ -626,6 +656,10 @@ urllib3==2.2.3
# requests
# responses
# tritonclient
vector-quantize-pytorch==1.21.2
# via -r requirements-test.in
vocos==0.1.0
# via -r requirements-test.in
word2number==1.1
# via lm-eval
xxhash==3.5.0

View File

@ -13,13 +13,11 @@ ray[default]
# Install torch_xla
--pre
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
--find-links https://storage.googleapis.com/libtpu-wheels/index.html
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.6.0.dev20241126+cpu
torchvision==0.20.0.dev20241126+cpu
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
jaxlib==0.4.36.dev20241122
jax==0.4.36.dev20241122
torch==2.6.0.dev20241216+cpu
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"

View File

@ -20,7 +20,7 @@ TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUE = 0.58
DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"]
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
MORE_ARGS_LIST = [
[], # Default
["--enable-chunked-prefill"], # Chunked
@ -66,14 +66,21 @@ def run_test(more_args):
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="V1 currently only supported on CUDA")
@pytest.mark.skipif(not current_platform.is_cuda()
and not current_platform.is_tpu(),
reason="V1 currently only supported on CUDA and TPU")
def test_lm_eval_accuracy_v1_engine(monkeypatch):
"""Run with the V1 Engine."""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
run_test([])
more_args = []
# Limit compilation time for V1
if current_platform.is_tpu():
more_args = ["--max-num-seqs", "64"]
run_test(more_args)
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)

View File

@ -135,7 +135,7 @@ class CudaPlatformBase(Platform):
else:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker"
"vllm.v1.worker.gpu_worker.GPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"

View File

@ -32,6 +32,7 @@ class _Backend(enum.Enum):
FLASHINFER = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
PALLAS_VLLM_V1 = enum.auto()
IPEX = enum.auto()
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto()

View File

@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Optional
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum, _Backend
@ -30,8 +31,14 @@ class TpuPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
if selected_backend != _Backend.PALLAS:
if (selected_backend != _Backend.PALLAS
and selected_backend != _Backend.PALLAS_VLLM_V1):
logger.info("Cannot use %s backend on TPU.", selected_backend)
if use_v1:
logger.info("Using Pallas V1 backend.")
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
else:
logger.info("Using Pallas backend.")
return "vllm.attention.backends.pallas.PallasAttentionBackend"
@ -45,7 +52,7 @@ class TpuPlatform(Platform):
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True
return not envs.VLLM_USE_V1
@classmethod
def inference_mode(cls):
@ -60,11 +67,11 @@ class TpuPlatform(Platform):
cache_config.block_size = 16
compilation_config = vllm_config.compilation_config
if compilation_config.level == CompilationLevel.NO_COMPILATION:
# TPU does not support NO_COMPILATION
# TPU only supports DYNAMO_ONCE compilation level
if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
compilation_config.level = CompilationLevel.DYNAMO_ONCE
assert compilation_config.level < CompilationLevel.PIECEWISE,\
"TPU does not support Inductor."
if compilation_config.backend == "":
compilation_config.backend = "openxla"
@ -72,10 +79,6 @@ class TpuPlatform(Platform):
assert vllm_config.speculative_config is None, \
"TPU does not support speculative decoding"
assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
"Chunked prefill is not yet supported for TPU backend")
assert not vllm_config.speculative_config, (
"Speculative decoding is not yet supported for TPU backend")
if vllm_config.model_config.dtype in (torch.float16, torch.float32):
logger.warning(
"The TPU backend currently does not support %s. "
@ -85,8 +88,27 @@ class TpuPlatform(Platform):
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.tpu_worker.TPUWorker"
else:
if scheduler_config.is_multi_step:
parallel_config.worker_cls = \
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
parallel_config.worker_cls = \
"vllm.worker.tpu_worker.TPUWorker"
# Adjust scheduler config for V1
# TODO: Add support for these
if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching:
logger.warning("[V1][TPU] Disable prefix caching")
vllm_config.cache_config.enable_prefix_caching = False
assert not vllm_config.speculative_config, (
"Speculative decoding is not yet supported for TPU backend")
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on TPU.")
return False

View File

@ -0,0 +1,351 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import torch_xla.experimental.custom_kernel # Required to register custom ops.
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
class PallasAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "PALLAS_VLLM_V1"
@staticmethod
def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
return PallasAttentionBackendImpl
@staticmethod
def get_metadata_cls() -> Type["PallasMetadata"]:
return PallasMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_kv_heads, num_blocks, block_size, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
raise RuntimeError("swap_blocks is not used for the TPU backend.")
@torch.compile(backend="openxla")
@staticmethod
def copy_blocks(
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
src_to_dists: Tuple[torch.Tensor, torch.Tensor],
) -> None:
src_indices, dst_indices = src_to_dists
for k_cache, v_cache in kv_caches:
torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
k_cache[:, dst_indices] = k_cache[:, src_indices]
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
v_cache[:, dst_indices] = v_cache[:, src_indices]
@dataclass
class PallasMetadata(AttentionMetadata):
# Currently, input sequences can only contain all prefills
# or all decoding.
block_tables: Optional[torch.Tensor] = None
context_lens: Optional[torch.Tensor] = None
effective_query_lens: Optional[torch.Tensor] = None
@property
def prefill_metadata(self) -> Optional["PallasMetadata"]:
if self.num_prefills == 0:
return None
assert self.num_decode_tokens == 0
return self
@property
def decode_metadata(self) -> Optional["PallasMetadata"]:
if self.num_decode_tokens == 0:
return None
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.block_tables is not None
assert self.context_lens is not None
return self
class PallasAttentionBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if head_size % 128 != 0:
raise NotImplementedError("Head size must be a multiple of 128.")
if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.")
if sliding_window is not None:
raise NotImplementedError("Sliding window is not supported.")
if kv_cache_dtype != "auto":
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.")
if logits_soft_cap is not None:
raise NotImplementedError(
"Attention logits soft-capping is not supported.")
if torch_xla.tpu.version() < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
self.megacore_mode = None
tpu_env = torch_xla.tpu.get_tpu_env()
tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
or tpu_env.get("TYPE", None)
or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
assert tpu_type is not None
tpu_type = tpu_type.lower()
if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head"
else:
# NOTE(woosuk): If the batch size is not a multiple of 2, the
# megacore mode will be None.
self.megacore_mode = "batch"
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
with shape [0] for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
if attn_metadata is None:
if output is None:
output = torch.ones_like(query)
return output
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
value = value.view(batch_size, seq_len, self.num_kv_heads,
self.head_size)
if kv_cache[0].numel() > 0:
slot_mapping = attn_metadata.slot_mapping
key_cache, value_cache = kv_cache
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
query = query * self.scale
if attn_metadata.num_prefills > 0:
if attn_metadata.block_tables is None:
# Prefill without paged KV cache.
assert seq_len % 16 == 0, (
"Pallas FlashAttention kernel requires seq_len to be a "
f"multiple of 16 but got {seq_len}")
# Handle GQA/MQA.
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv,
dim=-2)
key = key.view(batch_size, seq_len, self.num_heads,
self.head_size)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=-2)
value = value.view(batch_size, seq_len, self.num_heads,
self.head_size)
# FlashAttention kernel requires the input shape to be
# [batch_size, num_heads, seq_len, d_model]
# while the input is [batch_size, seq_len, num_heads, d_model].
# Permute the input to match the required format.
output = torch.ops.xla.flash_attention(
query.permute(0, 2, 1, 3),
key.permute(0, 2, 1, 3),
value.permute(0, 2, 1, 3),
True,
)
output = output.permute(0, 2, 1, 3)
else:
# Prefill with paged KV cache.
# TODO(woosuk): Tune the below knobs.
num_kv_pages_per_compute_block = 16
num_queries_per_compute_block = 16
assert seq_len % num_queries_per_compute_block == 0
output = torch.ops.xla.multi_queries_paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
attn_metadata.effective_query_lens,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel=True,
)
else:
# Decoding run.
assert kv_cache[0].numel() > 0
query = query.squeeze(dim=1)
pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
assert attn_metadata.block_tables is not None
assert attn_metadata.context_lens is not None
# NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
# block table in SMEM. Therefore, if the block table is too large,
# the kernel compilation will fail. To avoid this, we split the
# batch dimension into smaller chunks and run the kernel multiple
# times.
MAX_SMEM_USAGE = 512 * 1024
size_per_seq = 4 * attn_metadata.block_tables.shape[1]
max_num_seq = MAX_SMEM_USAGE // size_per_seq
if batch_size <= max_num_seq:
output = paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
pages_per_compute_block,
self.megacore_mode,
)
else:
chunk_size = max_num_seq
# Make sure the chunk size is a multiple of 2.
chunk_size = chunk_size // 2 * 2
num_chunks = (batch_size + chunk_size - 1) // chunk_size
output = torch.empty_like(query)
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * chunk_size
chunk_end = chunk_start + chunk_size
# NOTE(woosuk): We skip this line because it causes Dynamo
# compilation error. Instead, we rely on the slice operation
# to handle the out-of-bound case.
# chunk_end = min(chunk_end, batch_size)
chunk_output = paged_attention(
query[chunk_start:chunk_end],
key_cache,
value_cache,
attn_metadata.context_lens[chunk_start:chunk_end],
attn_metadata.block_tables[chunk_start:chunk_end],
pages_per_compute_block,
self.megacore_mode,
)
output[chunk_start:chunk_end] = chunk_output
# Reshape the output tensor.
return output.reshape(batch_size, seq_len, hidden_size)
def write_to_kv_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
key = key.flatten(0, 2)
value = value.flatten(0, 2)
key_cache = key_cache.flatten(0, 2)
value_cache = value_cache.flatten(0, 2)
key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
pages_per_compute_block: int,
megacore_mode: Optional[str],
) -> torch.Tensor:
batch_size = query.shape[0]
if megacore_mode == "batch" and batch_size % 2 != 0:
megacore_mode = None
else:
megacore_mode = megacore_mode
# NOTE(woosuk): A temporary workaround to avoid the error:
# "xla::paged_attention() Expected a value of type 'str' for
# argument 'megacore_mode' but instead found type 'NoneType'."
if megacore_mode is not None:
output = torch.ops.xla.paged_attention(
query,
key_cache,
value_cache,
context_lens,
block_tables,
pages_per_compute_block,
megacore_mode=megacore_mode,
)
else:
output = torch.ops.xla.paged_attention(
query,
key_cache,
value_cache,
context_lens,
block_tables,
pages_per_compute_block,
)
return output

View File

@ -57,6 +57,14 @@ class BlockTable:
src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks
def swap_row(self, src: int, tgt: int) -> None:
num_blocks_src = self.num_blocks_per_row[src]
num_blocks_tgt = self.num_blocks_per_row[tgt]
self.num_blocks_per_row[src] = num_blocks_tgt
self.num_blocks_per_row[tgt] = num_blocks_src
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
def commit(self, num_reqs: int) -> None:
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
non_blocking=True)

View File

@ -72,7 +72,7 @@ class InputBatch:
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32)
# Block table.
self.block_table = BlockTable(
@ -436,3 +436,77 @@ class InputBatch:
@property
def no_prompt_logprob(self) -> bool:
return len(self.prompt_logprob_reqs) == 0
def swap_positions(b: InputBatch, id_1, id_2):
assert id_1 != id_2
req_id_1 = b.req_ids[id_1]
req_id_2 = b.req_ids[id_2]
assert req_id_1 is not None
assert req_id_2 is not None
assert id_1 == b.req_id_to_index[req_id_1]
assert id_2 == b.req_id_to_index[req_id_2]
b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1]
b.req_id_to_index[req_id_1], b.req_id_to_index[
req_id_2] = b.req_id_to_index[req_id_2], b.req_id_to_index[req_id_1]
ids = [id_1, id_2]
rev_ids = [id_2, id_1]
b.num_tokens[ids] = b.num_tokens[rev_ids]
b.token_ids_cpu[ids] = b.token_ids_cpu[rev_ids]
b.num_prompt_tokens[ids] = b.num_prompt_tokens[rev_ids]
b.num_computed_tokens_cpu[ids] = b.num_computed_tokens_cpu[rev_ids]
b.block_table.swap_row(id_1, id_2)
b.temperature_cpu[ids] = b.temperature_cpu[rev_ids]
b.top_p_cpu[ids] = b.top_p_cpu[rev_ids]
b.top_k_cpu[ids] = b.top_k_cpu[rev_ids]
b.frequency_penalties_cpu[ids] = b.frequency_penalties_cpu[rev_ids]
b.presence_penalties_cpu[ids] = b.presence_penalties_cpu[rev_ids]
b.repetition_penalties_cpu[ids] = b.repetition_penalties_cpu[rev_ids]
b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[
id_1]
b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[
id_2], b.stop_token_ids[id_1]
gen_1 = b.generators.pop(id_1, None)
gen_2 = b.generators.pop(id_2, None)
if gen_1 is not None:
b.generators[id_2] = gen_1
if gen_2 is not None:
b.generators[id_1] = gen_2
def ensure_decodes_first(b: InputBatch):
num_reqs = b.num_reqs
while True:
# Find the first prompt index
first_prompt_index = None
for i in range(num_reqs):
if b.num_computed_tokens_cpu[i] < b.num_prompt_tokens[i]:
first_prompt_index = i
break
if first_prompt_index is None:
break
# Find the last decode index
last_decode_index = None
for i in reversed(range(num_reqs)):
if b.num_computed_tokens_cpu[i] >= b.num_prompt_tokens[i]:
last_decode_index = i
break
if last_decode_index is None:
break
# Sanity
assert first_prompt_index != last_decode_index
# Check if done
if first_prompt_index > last_decode_index:
break
# Swap
swap_positions(b, first_prompt_index, last_decode_index)

View File

@ -5,32 +5,23 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available)
from vllm.utils import DeviceMemoryProfiler, cdiv
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
@ -38,87 +29,17 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
class GPUModelRunner:
class GPUModelRunner(ModelRunnerBase):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
super().__init__(vllm_config, device)
model_config = self.model_config
cache_config = self.cache_config
scheduler_config = self.scheduler_config
parallel_config = self.parallel_config
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
if cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
self.is_multimodal_model = model_config.is_multimodal_model
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
self.num_query_heads = model_config.get_num_attention_heads(
parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
# NOTE: Initialized input mapper is only used for processing dummy
# multimodal data into multimodal kwargs for GPU memory profiling.
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config,
scheduler_config=scheduler_config,
)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
# Lazy initialization
# self.model: nn.Module # Set after load_model
# KV caches for forward pass
self.kv_caches: List[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=model_config.get_vocab_size(),
)
self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
@ -202,132 +123,6 @@ class GPUModelRunner:
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
# Remove the requests from the persistent batch.
stopped_req_ids = set().union(
scheduler_output.preempted_req_ids,
scheduler_output.finished_req_ids,
)
removed_req_indices: List[int] = []
for req_id in stopped_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
# Update the states of the running requests.
for req_data in scheduler_output.scheduled_running_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
req_index = self.input_batch.req_id_to_index[req_id]
# Update the num_computed_tokens.
req_state.num_computed_tokens = req_data.num_computed_tokens
self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens)
# Update the block table.
num_new_blocks = len(req_data.new_block_ids)
if num_new_blocks == 0:
continue
start_index = len(req_state.block_ids)
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table.append_row(req_index, start_index,
req_data.new_block_ids)
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
generator = None
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt=new_req_data.prompt,
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.model_config.uses_mrope:
image_grid_thw = []
video_grid_thw = []
for mm_input in self.requests[req_id].mm_inputs:
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend(
mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend(
mm_input["video_grid_thw"].tolist())
hf_config = self.model_config.hf_config
self.requests[req_id].mrope_positions, \
self.requests[req_id].mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
self.requests[req_id].prompt_token_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
)
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
for res_req_data in scheduler_output.scheduled_resumed_reqs:
req_id = res_req_data.req_id
req_state = self.requests[req_id]
req_state.block_ids = res_req_data.block_ids
req_state.num_computed_tokens = res_req_data.num_computed_tokens
req_ids_to_add.append(req_id)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
removed_req_indices = sorted(removed_req_indices, reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
if removed_req_indices:
# Fill the empty index.
req_index = removed_req_indices.pop()
else:
# Append to the end.
req_index = None
self.input_batch.add_request(req_state, req_index)
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
@ -611,6 +406,8 @@ class GPUModelRunner:
return sampling_metadata
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
assert self.model is not None
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
@ -698,15 +495,14 @@ class GPUModelRunner:
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs
def get_model(self) -> nn.Module:
return self.model
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
self._update_states(scheduler_output)
assert self.model is not None
self.update_states(scheduler_output)
if self.is_multimodal_model:
# Run the multimodal encoder if any.
@ -833,14 +629,15 @@ class GPUModelRunner:
self.model_memory_usage / float(2**30))
@torch.inference_mode()
def _dummy_run(
def dummy_run(
self,
kv_caches,
num_tokens: int,
kv_caches: Optional[List[torch.Tensor]] = None,
seq_len: Optional[int] = None,
exec_mode: Optional[ExecutionMode] = None,
) -> torch.Tensor:
model = self.model
if kv_caches is None:
kv_caches = self.kv_caches
assert self.model is not None
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
@ -851,7 +648,7 @@ class GPUModelRunner:
positions = self.mrope_positions[:, :num_tokens] \
if self.model_config.uses_mrope \
else self.positions[:num_tokens]
hidden_states = model(
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
@ -861,6 +658,7 @@ class GPUModelRunner:
return hidden_states
def profile_run(self) -> None:
assert self.model is not None
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
@ -966,7 +764,7 @@ class GPUModelRunner:
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
hidden_states = self.dummy_run(dummy_kv_caches, self.max_num_tokens)
logits = self.model.compute_logits(hidden_states, None)
logits = logits[:self.max_num_tokens]
# TODO(woosuk): Consider the memory usage of the sampler.
@ -992,8 +790,8 @@ class GPUModelRunner:
for num_tokens in reversed(self.cudagraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens)
self._dummy_run(num_tokens)
self.dummy_run(None, num_tokens)
self.dummy_run(None, num_tokens)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
@ -1036,38 +834,3 @@ class GPUModelRunner:
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
def get_kv_cache_spec(self) -> KVCacheSpec:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size
kv_cache_spec: KVCacheSpec = {}
for layer_name, attn_module in forward_ctx.items():
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")
return kv_cache_spec

View File

@ -1,13 +1,11 @@
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional
from typing import Optional
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.distributed import (ensure_model_parallel_initialized,
@ -15,20 +13,17 @@ from vllm.distributed import (ensure_model_parallel_initialized,
set_custom_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase, check_if_gpu_supports_dtype
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
class Worker:
class GPUWorker(WorkerBase):
def __init__(
self,
@ -38,46 +33,8 @@ class Worker:
distributed_init_method: str,
is_driver_worker: bool = False,
):
# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
super().__init__(vllm_config, local_rank, rank,
distributed_init_method)
def sleep(self, level: int = 1) -> None:
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
@ -97,7 +54,16 @@ class Worker:
allocator.wake_up()
def init_device(self):
if self.device_config.device.type == "cuda":
assert self.device_config.device.type == "cuda"
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
@ -111,15 +77,14 @@ class Worker:
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.parallel_config, self.rank,
init_cuda_worker_distributed_environment(self.parallel_config,
self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
@ -139,6 +104,7 @@ class Worker:
from contextlib import nullcontext
context = nullcontext()
with context:
assert self.model_runner is not None
self.model_runner.load_model()
@torch.inference_mode()
@ -160,6 +126,7 @@ class Worker:
_, total_gpu_memory = torch.cuda.mem_get_info()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
assert self.model_runner is not None
self.model_runner.profile_run()
free_gpu_memory, _ = torch.cuda.mem_get_info()
@ -191,9 +158,6 @@ class Worker:
return int(available_kv_cache_memory)
def get_kv_cache_spec(self) -> KVCacheSpec:
return self.model_runner.get_kv_cache_spec()
def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
if self.vllm_config.model_config.enable_sleep_mode:
@ -203,9 +167,12 @@ class Worker:
from contextlib import nullcontext
context = nullcontext()
with context:
assert self.model_runner is not None
self.model_runner.initialize_kv_cache(kv_cache_config)
def compile_or_warm_up_model(self) -> None:
assert self.model_runner is not None
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
@ -217,44 +184,32 @@ class Worker:
]
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size)
self.model_runner.dummy_run(None, size)
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
assert self.model_runner is not None
output = self.model_runner.execute_model(scheduler_output)
return output if self.rank == 0 else None
def profile(self, is_start: bool = True):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
if is_start:
self.profiler.start()
else:
self.profiler.stop()
def check_health(self) -> None:
# worker will always be healthy as long as it's running.
return
def init_worker_distributed_environment(
def init_cuda_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank,
@ -264,21 +219,22 @@ def init_worker_distributed_environment(
parallel_config.pipeline_parallel_size)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: # noqa: SIM102
if not current_platform.has_device_capability(80):
capability = current_platform.get_device_capability()
gpu_name = current_platform.get_device_name()
# TODO: Remove
# def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# # Check if the GPU supports the dtype.
# if torch_dtype == torch.bfloat16: # noqa: SIM102
# if not current_platform.has_device_capability(80):
# capability = current_platform.get_device_capability()
# gpu_name = current_platform.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
# if capability is None:
# compute_str = "does not have a compute capability"
# else:
# version_str = capability.as_version_str()
# compute_str = f"has compute capability {version_str}"
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.")
# raise ValueError(
# "Bfloat16 is only supported on GPUs with compute capability "
# f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
# "You can use float16 instead by explicitly setting the"
# "`dtype` flag in CLI, for example: --dtype=half.")

View File

@ -0,0 +1,307 @@
import enum
from typing import TYPE_CHECKING, Dict, List, Optional
import torch
import torch.distributed
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingType
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
logger = init_logger(__name__)
class ExecutionMode(enum.Enum):
PREFILL = enum.auto()
DECODE = enum.auto()
PREFIX_PREFILL = enum.auto()
def is_prefill(self) -> bool:
return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
class ModelRunnerBase:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.device_config = vllm_config.device_config
model_config = self.model_config
cache_config = self.cache_config
scheduler_config = self.scheduler_config
parallel_config = self.parallel_config
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
self.is_multimodal_model = model_config.is_multimodal_model
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
self.num_query_heads = model_config.get_num_attention_heads(
parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
self.model: Optional[nn.Module] = None
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
)
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
# NOTE: Initialized input mapper is only used for processing dummy
# multimodal data into multimodal kwargs for GPU memory profiling.
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
def update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
# Remove the requests from the persistent batch.
stopped_req_ids = set().union(
scheduler_output.preempted_req_ids,
scheduler_output.finished_req_ids,
)
removed_req_indices: List[int] = []
for req_id in stopped_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
# Update the states of the running requests.
for req_data in scheduler_output.scheduled_running_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
req_index = self.input_batch.req_id_to_index[req_id]
# Update the num_computed_tokens.
req_state.num_computed_tokens = req_data.num_computed_tokens
self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens)
# Update the block table.
num_new_blocks = len(req_data.new_block_ids)
if num_new_blocks == 0:
continue
start_index = len(req_state.block_ids)
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table.append_row(req_index, start_index,
req_data.new_block_ids)
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
generator = None
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt=new_req_data.prompt,
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.model_config.uses_mrope:
image_grid_thw = []
video_grid_thw = []
for mm_input in self.requests[req_id].mm_inputs:
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend(
mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend(
mm_input["video_grid_thw"].tolist())
hf_config = self.model_config.hf_config
self.requests[req_id].mrope_positions, \
self.requests[req_id].mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
self.requests[req_id].prompt_token_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
)
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
for res_req_data in scheduler_output.scheduled_resumed_reqs:
req_id = res_req_data.req_id
req_state = self.requests[req_id]
req_state.block_ids = res_req_data.block_ids
req_state.num_computed_tokens = res_req_data.num_computed_tokens
req_ids_to_add.append(req_id)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
removed_req_indices = sorted(removed_req_indices, reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
if removed_req_indices:
# Fill the empty index.
req_index = removed_req_indices.pop()
else:
# Append to the end.
req_index = None
self.input_batch.add_request(req_state, req_index)
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
def get_model(self) -> nn.Module:
assert self.model is not None
return self.model
def get_kv_cache_spec(self) -> KVCacheSpec:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size
kv_cache_spec: KVCacheSpec = {}
for layer_name, attn_module in forward_ctx.items():
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")
return kv_cache_spec
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
raise NotImplementedError()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
raise NotImplementedError()
def load_model(self) -> None:
raise NotImplementedError()
def dummy_run(
self,
kv_caches,
num_tokens: int,
seq_len: Optional[int] = None,
exec_mode: Optional[ExecutionMode] = None,
) -> torch.Tensor:
raise NotImplementedError()
def profile_run(self) -> None:
raise NotImplementedError()
def capture_model(self) -> None:
raise NotImplementedError()

View File

@ -0,0 +1,888 @@
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from unittest.mock import patch
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
# TPU XLA related
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
PallasMetadata)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import (CachedRequestState, InputBatch,
ensure_decodes_first)
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
logger = init_logger(__name__)
# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
@dataclass
class PromptDecodeInfo:
prompt_req_ids: List[str]
decode_req_ids: List[str]
prompt_scheduled_tokens: List[int]
@dataclass
class PromptData:
input_tokens: torch.Tensor
input_positions: torch.Tensor
attn_metadata: PallasMetadata
@dataclass
class DecodeData:
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional[PallasMetadata] = None
class TPUModelRunner(ModelRunnerBase):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(vllm_config, device)
# KV caches for forward pass
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
# Cached torch/numpy tensors
self.num_swaps = 2
self.cur_swap_id = 0
self.input_ids_cpu = []
self.input_ids_np = []
self.input_positions_cpu = []
self.input_positions_np = []
self.slot_mapping_cpu = []
self.slot_mapping_np = []
self.prompt_context_lens_cpu = []
self.prompt_effective_query_lens_cpu = []
self.decode_context_lens_cpu = []
self.decode_context_lens_np = []
for _ in range(self.num_swaps):
self.input_ids_cpu.append(
torch.empty(self.max_num_tokens,
dtype=torch.int32,
device="cpu"))
self.input_ids_np.append(self.input_ids_cpu[-1].numpy())
self.input_positions_cpu.append(
torch.empty(self.max_num_tokens,
dtype=torch.int32,
device="cpu"))
self.input_positions_np.append(
self.input_positions_cpu[-1].numpy())
self.slot_mapping_cpu.append(
torch.empty(self.max_num_tokens,
dtype=torch.int64,
device="cpu"))
self.slot_mapping_np.append(self.slot_mapping_cpu[-1].numpy())
self.prompt_context_lens_cpu.append(
torch.empty((1), dtype=torch.int32, device="cpu"))
self.prompt_effective_query_lens_cpu.append(
torch.empty((1), dtype=torch.int32, device="cpu"))
self.decode_context_lens_cpu.append(
torch.empty(self.max_num_tokens,
dtype=torch.int32,
device="cpu"))
self.decode_context_lens_np.append(
self.decode_context_lens_cpu[-1].numpy())
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
def swap_step(self):
self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps
def _get_prompts_and_decodes(
self,
scheduler_output: "SchedulerOutput",
) -> PromptDecodeInfo:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
# Traverse decodes first
decode_req_ids = []
for i in range(num_reqs):
req_id = self.input_batch.req_ids[i]
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
if num_computed_tokens < num_prompt_tokens:
# This is prompt
break
# This is decode
assert num_scheduled_tokens == 1
decode_req_ids.append(req_id)
# Traverse prompts
prompt_req_ids = []
prompt_scheduled_tokens = []
for i in range(len(decode_req_ids), num_reqs):
req_id = self.input_batch.req_ids[i]
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
# Must be prompt
assert num_computed_tokens < num_prompt_tokens
prompt_req_ids.append(req_id)
prompt_scheduled_tokens.append(num_scheduled_tokens)
return PromptDecodeInfo(prompt_req_ids, decode_req_ids,
prompt_scheduled_tokens)
def _prepare_prompt(self, req_index: int,
num_scheduled_tokens: int) -> PromptData:
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[
req_index]
num_prompt_tokens = self.input_batch.num_prompt_tokens[req_index]
# Must be prompt
assert num_computed_tokens < num_prompt_tokens
# Prompt len
prompt_len = num_scheduled_tokens
padded_prompt_len = _get_padded_prompt_len(prompt_len)
assert padded_prompt_len <= self.max_model_len
# Seq len
seq_len = num_computed_tokens + prompt_len
padded_seq_len = num_computed_tokens + padded_prompt_len
# DEBUG
# print("_prepare_prompt:")
# print(" prompt_len = {}".format(prompt_len))
# print(" padded_prompt_len = {}".format(padded_prompt_len))
# print(" num_computed_tokens = {}".format(num_computed_tokens))
# print(" num_prompt_tokens = {}".format(num_prompt_tokens))
# print(" seq_len = {}".format(seq_len))
# print(" padded_seq_len = {}".format(padded_seq_len))
# Input tokens
input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[
req_index, num_computed_tokens:padded_seq_len]
input_tokens_cpu[prompt_len:] = 0
# DEBUG
# print(" input_tokens_cpu.shape = {} val = {}".format(
# input_tokens_cpu.shape, input_tokens_cpu))
# Input positions
input_positions_np = self.input_positions_np[
self.cur_swap_id][:padded_prompt_len]
np.add(num_computed_tokens,
self.arange_np[:padded_prompt_len],
out=input_positions_np)
input_positions_np[prompt_len:] = 0
# DEBUG
# print(" input_positions_np.shape = {} val = {}".format(
# input_positions_np.shape, input_positions_np))
# Slot mapping
block_table_np = \
self.input_batch.block_table.get_numpy_array()
block_numbers_np = block_table_np[req_index, input_positions_np //
self.block_size]
block_offsets_np = input_positions_np % self.block_size
slot_mapping_np = self.slot_mapping_np[
self.cur_swap_id][:padded_prompt_len]
np.add(block_numbers_np * self.block_size,
block_offsets_np,
out=slot_mapping_np)
slot_mapping_np[prompt_len:] = _PAD_SLOT_ID
# DEBUG
# print(" slot_mapping_np.shape = {} val = {}".format(
# slot_mapping_np.shape, slot_mapping_np))
# Block table
block_table_cpu = None
if num_computed_tokens > 0:
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
block_table_cpu = block_table_cpu[req_index]
# DEBUG
# print(" block_table_cpu = {}".format(block_table_cpu))
# Context len
self.prompt_context_lens_cpu[self.cur_swap_id][0] = 0
if num_computed_tokens > 0:
self.prompt_context_lens_cpu[self.cur_swap_id][0] = seq_len
# Effective query len
self.prompt_effective_query_lens_cpu[self.cur_swap_id][0] = prompt_len
# Get final tensors
input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device)
input_positions = self.input_positions_cpu[
self.cur_swap_id][:padded_prompt_len].reshape(1,
-1).to(self.device)
slot_mapping = self.slot_mapping_cpu[
self.cur_swap_id][:padded_prompt_len].reshape(1,
-1).to(self.device)
block_table = block_table_cpu.reshape(1, -1).to(
self.device) if block_table_cpu is not None else None
context_lens = self.prompt_context_lens_cpu[self.cur_swap_id].to(
self.device)
effective_query_lens = self.prompt_effective_query_lens_cpu[
self.cur_swap_id].to(self.device)
self.swap_step()
# DEBUG
# print(" input_tokens.shape = {} val = {}".format(
# input_tokens.shape, input_tokens))
# print(" input_positions.shape = {} val = {}".format(
# input_positions.shape, input_positions))
# print(" slot_mapping.shape = {} val = {}".format(
# slot_mapping.shape, slot_mapping))
# print(" block_table = {}".format(block_table))
# print(" context_lens.shape = {} val = {}".format(
# context_lens.shape, context_lens))
# print(" effective_query_lens.shape = {} val = {}".format(
# effective_query_lens.shape, effective_query_lens))
# Attn metadata
attn_metadata = PallasMetadata(
num_prefills=1,
num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_table,
context_lens=context_lens,
effective_query_lens=effective_query_lens,
)
return PromptData(input_tokens, input_positions, attn_metadata)
def _prepare_decode(
self,
decode_req_ids: List[str],
) -> DecodeData:
# Batch size
batch_size = len(decode_req_ids)
padded_batch_size = _get_padded_batch_size(batch_size)
assert padded_batch_size <= self.max_model_len
# Init [0 .. batch_size - 1]
req_indices_np = self.arange_np[:padded_batch_size]
# DEBUG
# print("_prepare_decode:")
# print(" batch_size = {}".format(batch_size))
# print(" padded_batch_size = {}".format(padded_batch_size))
# print(" req_indices_np.shape = {} val = {}".format(
# req_indices_np.shape, req_indices_np))
# Input positions
input_positions_np = self.input_positions_np[
self.cur_swap_id][:padded_batch_size]
np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
0,
out=input_positions_np)
input_positions_np[batch_size:] = 0
input_positions_cpu = self.input_positions_cpu[
self.cur_swap_id][:padded_batch_size]
# DEBUG
# print(" input_positions_cpu.shape = {} data = {}".format(
# input_positions_cpu.shape, input_positions_cpu))
# Input tokens
token_indices_np = (
input_positions_np +
req_indices_np * self.input_batch.token_ids_cpu.shape[1])
input_tokens_cpu = self.input_ids_cpu[
self.cur_swap_id][:padded_batch_size]
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices_np),
out=input_tokens_cpu)
input_tokens_cpu[batch_size:] = 0
# DEBUG
# print(" token_indices_np.shape = {} val = {}".format(
# token_indices_np.shape, token_indices_np))
# print(" input_tokens_cpu.shape = {} data = {}".format(
# input_tokens_cpu.shape, input_tokens_cpu))
# Slot mapping
block_table_indices_np = (
req_indices_np * self.max_num_blocks_per_req +
input_positions_np // self.block_size)
# DEBUG
# print(
# " block_table_indices_np.shape = {} data = {} max_num_blocks_per_req = {}"
# .format(block_table_indices_np.shape, block_table_indices_np,
# self.max_num_blocks_per_req))
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
# DEBUG
# print(" block_table_cpu.shape = {} data = {}".format(
# block_table_cpu.shape, block_table_cpu[:padded_batch_size, :10]))
block_numbers_np = block_table_cpu.flatten(
)[block_table_indices_np].numpy()
# DEBUG
# print(" block_numbers_np.shape = {} data = {}".format(
# block_numbers_np.shape, block_numbers_np))
block_offsets_np = input_positions_np % self.block_size
# DEBUG
# print(" block_offsets_np.shape = {} data = {}".format(
# block_offsets_np.shape, block_offsets_np))
slot_mapping_np = self.slot_mapping_np[
self.cur_swap_id][:padded_batch_size]
np.add(block_numbers_np * self.block_size,
block_offsets_np,
out=slot_mapping_np)
slot_mapping_np[batch_size:] = _PAD_SLOT_ID
# DEBUG
# print(" slot_mapping_np.shape = {} data = {}".format(
# slot_mapping_np.shape, slot_mapping_np))
block_table_cpu = block_table_cpu[:padded_batch_size]
# Context lens
context_lens_np = self.decode_context_lens_np[
self.cur_swap_id][:padded_batch_size]
np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
1,
out=context_lens_np)
context_lens_np[batch_size:] = 0
# Get final tensors
input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device)
input_positions = input_positions_cpu.reshape(-1, 1).to(self.device)
slot_mapping = self.slot_mapping_cpu[
self.cur_swap_id][:padded_batch_size].reshape(-1,
1).to(self.device)
block_table = block_table_cpu.to(self.device)
context_lens = self.decode_context_lens_cpu[
self.cur_swap_id][:padded_batch_size].to(self.device)
self.swap_step()
# DEBUG
# print(" context_lens.shape = {} val = {}".format(
# context_lens.shape, context_lens))
# Attn metadata
attn_metadata = PallasMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=padded_batch_size,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_table,
context_lens=context_lens,
effective_query_lens=None,
)
return DecodeData(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata)
@torch.no_grad()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
# Update cached state
self.update_states(scheduler_output)
# If necessary, swap decodes/prompts to have all decodes on the start
ensure_decodes_first(self.input_batch)
# Prepare prompts/decodes info
pd_info = self._get_prompts_and_decodes(scheduler_output)
# Init
num_prompts = len(pd_info.prompt_req_ids)
num_decodes = len(pd_info.decode_req_ids)
decode_data = None
sampled_token_ids = [0] * self.input_batch.num_reqs
# Run each prompt individually
is_first = True
for i in range(num_prompts):
req_id = pd_info.prompt_req_ids[i]
req_index = num_decodes + i
assert req_index == self.input_batch.req_id_to_index[
req_id] # TODO: Remove
req_state = self.requests[req_id]
num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i]
prompt_len = num_scheduled_tokens
seq_len = req_state.num_computed_tokens + num_scheduled_tokens
# Prepare first prompt
if is_first:
prompt_data = self._prepare_prompt(req_index,
num_scheduled_tokens)
is_first = False
# Run forward pass
with set_forward_context(prompt_data.attn_metadata,
self.vllm_config):
assert self.model is not None
selected_token_ids = self.model(prompt_data.input_tokens,
prompt_data.input_positions,
prompt_data.attn_metadata,
self.kv_caches)
# In parallel to TPU execution, prepare the next iteration
if i < num_prompts - 1:
# There is next prompt => prepare it
prompt_data = self._prepare_prompt(
req_index + 1, pd_info.prompt_scheduled_tokens[i + 1])
elif i == num_prompts - 1 and num_decodes > 0:
# There is next decode => prepare it
decode_data = self._prepare_decode(pd_info.decode_req_ids)
# Update cached state (if prompt is fully done)
if seq_len >= len(req_state.prompt_token_ids):
# Transfer sampled tokens from TPU to CPU
selected_token_ids_cpu = selected_token_ids.cpu()
# Get output token
token_id = selected_token_ids_cpu[prompt_len - 1].item()
sampled_token_ids[req_index] = token_id
# DEBUG
# print(
# " -- Got token_id = {} for prompt_len = {} req_id = {} req_index = {} selected_token_ids_cpu = {}"
# .format(token_id, prompt_len, req_id, req_index,
# selected_token_ids_cpu))
# Add output token to the request
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
self.input_batch.num_tokens[req_index] += 1
req_state.output_token_ids.append(token_id)
# Run decodes (a single batch)
if num_decodes > 0:
# Prepare decode (if was not yet prepared)
if decode_data is None:
decode_data = self._prepare_decode(pd_info.decode_req_ids)
# Run forward pass
with set_forward_context(decode_data.attn_metadata,
self.vllm_config):
assert self.model is not None
selected_token_ids = self.model(decode_data.input_tokens,
decode_data.input_positions,
decode_data.attn_metadata,
self.kv_caches)
# Transfer sampled tokens from TPU to CPU
decode_token_ids_cpu = selected_token_ids.cpu()
# Convert to list
decode_token_ids_list = decode_token_ids_cpu.tolist()
# Update cached state for each decode request
for i in range(num_decodes):
req_id = pd_info.decode_req_ids[i]
req_index = i
assert req_index == self.input_batch.req_id_to_index[
req_id] # TODO: Remove
req_state = self.requests[req_id]
seq_len = req_state.num_computed_tokens + 1
token_id = decode_token_ids_list[i]
sampled_token_ids[req_index] = token_id
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
self.input_batch.num_tokens[req_index] += 1
req_state.output_token_ids.append(token_id)
# Create output
model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=sampled_token_ids,
logprob_token_ids_cpu=None,
logprobs_cpu=None,
)
return model_runner_output
def load_model(self) -> None:
self.device = self.device_config.device
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
# process, the ranks can be different from the ranks internally assigned
# by the xm runtime. Therefore, there is a mismatch in the rank
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
# This is not a problem in linear layers because all-reduce is
# rank-agnostic. However, it matters for all-gather as the ranks
# determine the order of concatenating the output tensors.
# As a workaround, we use the xm's rank assignment only when loading
# the embedding weights.
xm_tp_rank = xr.global_ordinal()
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
model = get_model(vllm_config=self.vllm_config)
model = model.eval()
xm.wait_device_ops()
model = ModelWrapperV1(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
def dummy_run(
self,
kv_caches,
num_tokens: int,
seq_len: Optional[int] = None,
exec_mode: Optional[ExecutionMode] = None,
) -> None:
assert seq_len is not None
assert exec_mode is not None
exec_mode = ExecutionMode(exec_mode)
if exec_mode.is_prefill():
seq_len = (seq_len + 15) // 16 * 16
token_ids = torch.zeros((num_tokens, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((num_tokens, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((num_tokens, seq_len),
dtype=torch.int64,
device=self.device)
if exec_mode == ExecutionMode.PREFILL:
attn_metadata = PallasMetadata(
num_prefills=num_tokens,
num_prefill_tokens=num_tokens * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=None,
context_lens=None,
effective_query_lens=None,
)
else:
context_lens = torch.ones((num_tokens, ),
dtype=torch.int32,
device=self.device)
block_tables = torch.zeros(
(num_tokens, self.max_num_blocks_per_req),
dtype=torch.int32,
device=self.device)
effective_query_lens = torch.ones_like(context_lens)
attn_metadata = PallasMetadata(
num_prefills=num_tokens,
num_prefill_tokens=num_tokens * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=effective_query_lens,
)
else:
assert seq_len == 1
token_ids = torch.zeros((num_tokens, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((num_tokens, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((num_tokens, seq_len),
dtype=torch.int64,
device=self.device)
block_tables = torch.zeros(
(num_tokens, self.max_num_blocks_per_req),
dtype=torch.int32,
device=self.device)
context_lens = torch.ones((num_tokens, ),
dtype=torch.int32,
device=self.device)
attn_metadata = PallasMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=num_tokens * seq_len,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_tables,
context_lens=context_lens,
)
# NOTE(woosuk): There are two stages of compilation: torch.compile and
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
# overhead by reusing the FX graph for different shapes.
# However, the XLA graph will still require static shapes and needs to
# be re-compiled for every different shapes. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs in the disk (VLLM_XLA_CACHE_PATH).
if exec_mode.is_prefill():
# Prefll
torch._dynamo.mark_dynamic(token_ids, 1)
torch._dynamo.mark_dynamic(position_ids, 1)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
else:
# Decode
torch._dynamo.mark_dynamic(token_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None
self.model(token_ids, position_ids, attn_metadata, kv_caches)
def capture_model(self) -> None:
"""Compile the model."""
# Prefill
logger.info(
"Compiling the model with different input shapes for prefill:")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self.dummy_run(self.kv_caches,
batch_size,
seq_len,
exec_mode=ExecutionMode.PREFILL)
xm.wait_device_ops()
logger.info(" batch_size: %d, seq_len: %d", batch_size,
seq_len)
num_tokens = batch_size * seq_len
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
break
seq_len = seq_len * 2
end = time.time()
logger.info(" -- Compilation for prefill done in %.2f [secs].",
end - start)
# Prefix prefill
if self.scheduler_config.enable_chunked_prefill:
logger.info("Compiling the model with different input shapes for "
"prefix prefill:")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self.dummy_run(self.kv_caches,
batch_size,
seq_len,
exec_mode=ExecutionMode.PREFIX_PREFILL)
xm.wait_device_ops()
logger.info(" batch_size: %d, seq_len: %d", batch_size,
seq_len)
num_tokens = batch_size * seq_len
if (num_tokens
>= self.scheduler_config.max_num_batched_tokens):
break
seq_len = seq_len * 2
end = time.time()
logger.info(
" -- Compilation for prefix prefill done in %.2f [secs].",
end - start)
# Decode
logger.info(
"Compiling the model with different input shapes for decode:")
start = time.time()
seq_len = 1
batch_size = 8 # Must be in sync with _get_padded_batch_size()
while True:
self.dummy_run(self.kv_caches,
batch_size,
seq_len,
exec_mode=ExecutionMode.DECODE)
xm.wait_device_ops()
logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len)
if batch_size >= self.scheduler_config.max_num_seqs:
break
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
end = time.time()
logger.info(" -- Compilation for decode done in %.2f [secs].",
end - start)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
Args:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if len(kv_cache_config.groups) > 1:
raise NotImplementedError(
"Hybrid models with more than one KV cache type are not "
"supported yet.")
kv_caches: Dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % layer_spec.page_size_bytes == 0
num_blocks = tensor_config.size // layer_spec.page_size_bytes
if isinstance(layer_spec, FullAttentionSpec):
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
layer_spec.head_size)
dtype = layer_spec.dtype
tpu_k_cache = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
else:
raise NotImplementedError
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
class ModelWrapperV1(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
def forward(
self,
token_ids: torch.Tensor,
position_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
num_samples: Number of samples to draw from each logits vector.
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
"""
# Skip this in memory profiling at initialization.
if attn_metadata is not None and kv_caches[0][0].numel() > 0:
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
slot_mapping = attn_metadata.slot_mapping
slot_mapping = slot_mapping.flatten()
head_indicies = torch.arange(0,
num_kv_heads,
device=slot_mapping.device,
dtype=slot_mapping.dtype)
head_indicies *= block_size * num_blocks
slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
-1, num_kv_heads)
slot_mapping = slot_mapping + head_indicies.view(1, -1)
slot_mapping = slot_mapping.flatten()
attn_metadata.slot_mapping = slot_mapping
assert self.model is not None
hidden_states = self.model(
token_ids,
position_ids,
kv_caches,
attn_metadata,
)
hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, None)
# Greedy sampling.
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
return argmax_token_ids
def _get_padded_prompt_len(x: int) -> int:
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
# length to be a multiple of 16. We pad the prompt length to the nearest
# multiple of 16. This is also good for performance.
if x <= 16:
return 16
return 1 << (x - 1).bit_length()
def _get_padded_batch_size(batch_size: int) -> int:
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
# To meet this requirement in the simplest way, we set the minimal batch
# size to 8.
if batch_size <= 8:
return 8
else:
return ((batch_size + 15) // 16) * 16

View File

@ -0,0 +1,153 @@
"""A TPU worker class."""
import os
from typing import Optional, Dict
import torch
import torch.distributed
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.v1.kv_cache_interface import FullAttentionSpec
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.utils import bind_kv_cache
logger = init_logger(__name__)
class TPUWorker(WorkerBase):
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
super().__init__(vllm_config, local_rank, rank,
distributed_init_method)
def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU"
torch.set_grad_enabled(False)
torch.set_default_dtype(self.model_config.dtype)
# Initialize the distributed environment.
init_tpu_worker_distributed_environment(self.parallel_config,
self.rank,
self.distributed_init_method,
self.local_rank)
# Device initialization should happen after initializing
# the distributed runtime.
self.device = xm.xla_device()
self.device_config.device = self.device
# Set random seed.
set_random_seed(self.model_config.seed)
xm.set_rng_state(self.model_config.seed, self.device)
# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
# 30-40 graphs for decode. 128 is an arbitrary safe number.
torch._dynamo.config.cache_size_limit = 128
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs.
world_size = self.parallel_config.world_size
rank = xr.global_ordinal()
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
f"tp{world_size}_rank{rank}")
xr.initialize_cache(per_rank_path, readonly=False)
# Init ModelRunner here, so that we have access to self.device.
self.model_runner = TPUModelRunner(self.vllm_config, self.device)
def determine_available_memory(self) -> int:
assert self.model_runner is not None
kv_caches: Dict[str, torch.Tensor] = {}
kv_cache_spec = self.model_runner.get_kv_cache_spec()
for layer_name, layer_spec in kv_cache_spec.items():
if isinstance(layer_spec, FullAttentionSpec):
dtype = layer_spec.dtype
# Use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device)
tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device)
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
else:
raise NotImplementedError
runner_kv_caches = []
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches)
self.model_runner.dummy_run(
runner_kv_caches,
num_tokens=1,
seq_len=self.scheduler_config.max_num_batched_tokens,
exec_mode=ExecutionMode.PREFILL,
)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
profiled = m["peak_bytes_used"] # Weights + intermediate activations.
# Calculate the TPU KV cache size based on profiling.
usable_memory_size = int(total_memory_size *
self.cache_config.gpu_memory_utilization)
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
return int(tpu_kv_cache_bytes)
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
assert self.model_runner is not None
output = self.model_runner.execute_model(scheduler_output)
return output if self.rank == 0 else None
def init_tpu_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
# NOTE(woosuk): This is just to initialize the TP group and broadcast
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
local_rank=local_rank,
distributed_init_method=distributed_init_method,
backend="gloo",
)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)

View File

@ -0,0 +1,173 @@
"""A GPU worker class."""
from typing import TYPE_CHECKING, Optional
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.model_runner_base import ModelRunnerBase
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
class WorkerBase:
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
if self.cache_config.cache_dtype == "auto":
self.cache_dtype = self.model_config.dtype
else:
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype]
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
# Initialized by the specific platform
self.model_runner: Optional[ModelRunnerBase] = None
def load_model(self) -> None:
assert self.model_runner is not None
self.model_runner.load_model()
def compile_or_warm_up_model(self) -> None:
assert self.model_runner is not None
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def get_model(self) -> nn.Module:
assert self.model_runner is not None
return self.model_runner.get_model()
def get_kv_cache_spec(self) -> KVCacheSpec:
assert self.model_runner is not None
return self.model_runner.get_kv_cache_spec()
def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
assert self.model_runner is not None
self.model_runner.initialize_kv_cache(kv_cache_config)
def profile(self, is_start: bool = True):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
if is_start:
self.profiler.start()
else:
self.profiler.stop()
def check_health(self) -> None:
# worker will always be healthy as long as it's running.
return
def init_device(self):
raise NotImplementedError()
def determine_available_memory(self) -> int:
raise NotImplementedError()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
raise NotImplementedError()
def check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: # noqa: SIM102
if not current_platform.has_device_capability(80):
capability = current_platform.get_device_capability()
gpu_name = current_platform.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.")
def get_cache_block_size(
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config)
num_attention_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_attention_layers * (key_cache_block + value_cache_block)
if cache_config.cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
dtype_size = get_dtype_size(dtype)
return dtype_size * total