Compare commits

...

14 Commits

Author SHA1 Message Date
70b4e46e70 compilation is fixed 2025-02-06 20:49:29 +00:00
5fb9dbe6f6 fix capture model 2025-02-06 20:18:30 +00:00
996b92ccb4 swap works! 2025-02-05 20:28:33 +00:00
2b0526fa15 works! 2025-02-05 16:54:57 +00:00
7be649256f fixes 2025-02-05 15:36:38 +00:00
627efde813 fixes 2025-02-04 22:16:19 +00:00
c2867d5bc1 Optimize decode/prompt prepare code 2025-02-04 21:12:07 +00:00
39c4a4cdb5 review comments (Signed-off-by: Alexander Matveev <amatveev@redhat.com>) 2025-01-28 23:08:50 +00:00
1ccf100c6a clean-ups (Signed-off-by: Alexander Matveev <amatveev@redhat.com>) 2025-01-28 23:08:50 +00:00
248c5b632d works (Signed-off-by: Alexander Matveev <amatveev@redhat.com>) 2025-01-28 23:08:50 +00:00
950f349492 scheduler is clean (Signed-off-by: Alexander Matveev <amatveev@redhat.com>) 2025-01-28 23:08:50 +00:00
61bb55f3d5 Chunked prompt works! (Signed-off-by: Alexander Matveev <amatveev@redhat.com>) 2025-01-28 23:08:50 +00:00
0bddb6b9a5 reorder funcs (Signed-off-by: Alexander Matveev <amatveev@redhat.com>) 2025-01-28 23:08:50 +00:00
c715fb19e5 [V1] TPU support (Signed-off-by: Alexander Matveev <amatveev@redhat.com>) 2025-01-28 23:08:50 +00:00
19 changed files with 2142 additions and 407 deletions

View File

@@ -89,4 +89,4 @@ repos:
         name: Suggestion
         entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
         language: system
         verbose: true

View File

@ -8,10 +8,10 @@ prompts = [
"The future of AI is", "The future of AI is",
] ]
# Create a sampling params object. # Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)
# Create an LLM. # Create an LLM.
llm = LLM(model="facebook/opt-125m") llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16)
# Generate texts from the prompts. The output is a list of RequestOutput objects # Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information. # that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
@ -19,4 +19,4 @@ outputs = llm.generate(prompts, sampling_params)
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View File

@ -5,7 +5,7 @@ requests >= 2.26.0
tqdm tqdm
blake3 blake3
py-cpuinfo py-cpuinfo
transformers >= 4.45.2 # Required for Llama 3.2 and Qwen2-VL. transformers >= 4.48.2 # Required for Bamba model and Transformers backend.
tokenizers >= 0.19.1 # Required for Llama 3. tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer. protobuf # Required by LlamaTokenizer.
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9' fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
@ -34,6 +34,6 @@ pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops # Required for Qwen2-VL. einops # Required for Qwen2-VL.
compressed-tensors == 0.8.1 # required for compressed-tensors compressed-tensors == 0.9.1 # required for compressed-tensors
depyf==0.18.0 # required for profiling and debugging with compilation config depyf==0.18.0 # required for profiling and debugging with compilation config
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py cloudpickle # allows pickling lambda functions in model_executor/models/registry.py

View File

@@ -2,7 +2,7 @@
 # This file is autogenerated by pip-compile with Python 3.12
 # by the following command:
 #
 #    python3.12 -m piptools compile requirements-test.in -o requirements-test.txt
 #
 absl-py==2.1.0
     # via rouge-score
@@ -106,9 +106,17 @@ dnspython==2.7.0
 docutils==0.16
     # via awscli
 einops==0.8.0
-    # via -r requirements-test.in
+    # via
+    #   -r requirements-test.in
+    #   encodec
+    #   vector-quantize-pytorch
+    #   vocos
+einx==0.3.0
+    # via vector-quantize-pytorch
 email-validator==2.2.0
     # via pydantic
+encodec==0.1.1
+    # via vocos
 evaluate==0.4.3
     # via lm-eval
 fastparquet==2024.11.0
@@ -125,6 +133,8 @@ filelock==3.16.1
     #   triton
 fonttools==4.54.1
     # via matplotlib
+frozendict==2.4.6
+    # via einx
 frozenlist==1.5.0
     # via
     #   aiohttp
@@ -159,6 +169,7 @@ huggingface-hub==0.26.2
     #   timm
     #   tokenizers
     #   transformers
+    #   vocos
 idna==3.10
     # via
     #   anyio
@@ -261,6 +272,8 @@ numpy==1.26.4
     #   cupy-cuda12x
     #   datasets
     #   decord
+    #   einx
+    #   encodec
     #   evaluate
     #   fastparquet
     #   genai-perf
@@ -283,6 +296,7 @@ numpy==1.26.4
     #   torchvision
     #   transformers
     #   tritonclient
+    #   vocos
 nvidia-cublas-cu12==12.4.5.8
     # via
     #   nvidia-cudnn-cu12
@@ -455,6 +469,7 @@ pyyaml==6.0.2
     #   responses
     #   timm
     #   transformers
+    #   vocos
 ray[adag]==2.40.0
     # via -r requirements-test.in
 redis==5.2.0
@@ -517,6 +532,7 @@ scipy==1.13.1
     #   scikit-learn
     #   sentence-transformers
     #   statsmodels
+    #   vocos
 sentence-transformers==3.2.1
     # via -r requirements-test.in
 sentencepiece==0.2.0
@@ -540,7 +556,9 @@ sqlitedict==2.1.0
 statsmodels==0.14.4
     # via genai-perf
 sympy==1.13.1
-    # via torch
+    # via
+    #   einx
+    #   torch
 tabledata==1.3.3
     # via pytablewriter
 tabulate==0.9.0
@@ -568,12 +586,21 @@ torch==2.5.1
     #   -r requirements-test.in
     #   accelerate
     #   bitsandbytes
+    #   encodec
     #   lm-eval
     #   peft
     #   sentence-transformers
     #   tensorizer
     #   timm
+    #   torchaudio
     #   torchvision
+    #   vector-quantize-pytorch
+    #   vocos
+torchaudio==2.5.1
+    # via
+    #   -r requirements-test.in
+    #   encodec
+    #   vocos
 torchvision==0.20.1
     # via timm
 tqdm==4.66.6
@@ -584,13 +611,15 @@ tqdm==4.66.6
     #   lm-eval
     #   nltk
     #   peft
+    #   pqdm
     #   sentence-transformers
     #   tqdm-multiprocess
     #   transformers
 tqdm-multiprocess==0.0.11
     # via lm-eval
-transformers==4.47.0
+transformers==4.48.2
     # via
+    #   -r requirements-test.in
     #   genai-perf
     #   lm-eval
     #   peft
@@ -615,6 +644,7 @@ typing-extensions==4.12.2
     #   huggingface-hub
     #   librosa
     #   mistral-common
+    #   pqdm
     #   pydantic
     #   pydantic-core
     #   torch
@@ -626,6 +656,10 @@ urllib3==2.2.3
     #   requests
     #   responses
     #   tritonclient
+vector-quantize-pytorch==1.21.2
+    # via -r requirements-test.in
+vocos==0.1.0
+    # via -r requirements-test.in
 word2number==1.1
     # via lm-eval
 xxhash==3.5.0
@@ -638,4 +672,4 @@ zstandard==0.23.0
     # via lm-eval

 # The following packages are considered to be unsafe in a requirements file:
 # setuptools

View File

@ -13,13 +13,11 @@ ray[default]
# Install torch_xla # Install torch_xla
--pre --pre
--extra-index-url https://download.pytorch.org/whl/nightly/cpu --extra-index-url https://download.pytorch.org/whl/nightly/cpu
--find-links https://storage.googleapis.com/libtpu-wheels/index.html
--find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.6.0.dev20241126+cpu torch==2.6.0.dev20241216+cpu
torchvision==0.20.0.dev20241126+cpu torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
jaxlib==0.4.36.dev20241122
jax==0.4.36.dev20241122

View File

@ -20,7 +20,7 @@ TASK = "gsm8k"
FILTER = "exact_match,strict-match" FILTER = "exact_match,strict-match"
RTOL = 0.03 RTOL = 0.03
EXPECTED_VALUE = 0.58 EXPECTED_VALUE = 0.58
DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"] DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
MORE_ARGS_LIST = [ MORE_ARGS_LIST = [
[], # Default [], # Default
["--enable-chunked-prefill"], # Chunked ["--enable-chunked-prefill"], # Chunked
@ -66,14 +66,21 @@ def run_test(more_args):
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
@pytest.mark.skipif(not current_platform.is_cuda(), @pytest.mark.skipif(not current_platform.is_cuda()
reason="V1 currently only supported on CUDA") and not current_platform.is_tpu(),
reason="V1 currently only supported on CUDA and TPU")
def test_lm_eval_accuracy_v1_engine(monkeypatch): def test_lm_eval_accuracy_v1_engine(monkeypatch):
"""Run with the V1 Engine.""" """Run with the V1 Engine."""
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
run_test([]) more_args = []
# Limit compilation time for V1
if current_platform.is_tpu():
more_args = ["--max-num-seqs", "64"]
run_test(more_args)
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST) @pytest.mark.parametrize("more_args", MORE_ARGS_LIST)

View File

@ -34,4 +34,4 @@ run_mypy vllm/plugins
run_mypy vllm/prompt_adapter run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode run_mypy vllm/spec_decode
run_mypy vllm/worker run_mypy vllm/worker
run_mypy vllm/v1 run_mypy vllm/v1

View File

@ -135,7 +135,7 @@ class CudaPlatformBase(Platform):
else: else:
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
parallel_config.worker_cls = \ parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker" "vllm.v1.worker.gpu_worker.GPUWorker"
else: else:
parallel_config.worker_cls = "vllm.worker.worker.Worker" parallel_config.worker_cls = "vllm.worker.worker.Worker"

View File

@ -32,6 +32,7 @@ class _Backend(enum.Enum):
FLASHINFER = enum.auto() FLASHINFER = enum.auto()
HPU_ATTN = enum.auto() HPU_ATTN = enum.auto()
PALLAS = enum.auto() PALLAS = enum.auto()
PALLAS_VLLM_V1 = enum.auto()
IPEX = enum.auto() IPEX = enum.auto()
BLOCK_SPARSE_FLASH_ATTN = enum.auto() BLOCK_SPARSE_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto() NO_ATTENTION = enum.auto()

View File

@@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Optional

 import torch

+import vllm.envs as envs
 from vllm.logger import init_logger

 from .interface import Platform, PlatformEnum, _Backend
@@ -30,10 +31,16 @@ class TpuPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool) -> str:
-        if selected_backend != _Backend.PALLAS:
+        if (selected_backend != _Backend.PALLAS
+                and selected_backend != _Backend.PALLAS_VLLM_V1):
             logger.info("Cannot use %s backend on TPU.", selected_backend)
-        logger.info("Using Pallas backend.")
-        return "vllm.attention.backends.pallas.PallasAttentionBackend"
+
+        if use_v1:
+            logger.info("Using Pallas V1 backend.")
+            return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
+        else:
+            logger.info("Using Pallas backend.")
+            return "vllm.attention.backends.pallas.PallasAttentionBackend"

     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
@@ -45,7 +52,7 @@ class TpuPlatform(Platform):

     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return True
+        return not envs.VLLM_USE_V1

     @classmethod
     def inference_mode(cls):
@@ -60,11 +67,11 @@ class TpuPlatform(Platform):
             cache_config.block_size = 16

         compilation_config = vllm_config.compilation_config
-        if compilation_config.level == CompilationLevel.NO_COMPILATION:
-            # TPU does not support NO_COMPILATION
+
+        # TPU only supports DYNAMO_ONCE compilation level
+        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
+            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
             compilation_config.level = CompilationLevel.DYNAMO_ONCE
-        assert compilation_config.level < CompilationLevel.PIECEWISE,\
-            "TPU does not support Inductor."

         if compilation_config.backend == "":
             compilation_config.backend = "openxla"
@@ -72,10 +79,6 @@ class TpuPlatform(Platform):
         assert vllm_config.speculative_config is None, \
             "TPU does not support speculative decoding"

-        assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
-            "Chunked prefill is not yet supported for TPU backend")
-        assert not vllm_config.speculative_config, (
-            "Speculative decoding is not yet supported for TPU backend")
         if vllm_config.model_config.dtype in (torch.float16, torch.float32):
             logger.warning(
                 "The TPU backend currently does not support %s. "
@@ -85,8 +88,27 @@ class TpuPlatform(Platform):
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
-            if scheduler_config.is_multi_step:
+            if envs.VLLM_USE_V1:
                 parallel_config.worker_cls = \
-                    "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                    "vllm.v1.worker.tpu_worker.TPUWorker"
             else:
-                parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
+                if scheduler_config.is_multi_step:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.tpu_worker.TPUWorker"
+
+        # Adjust scheduler config for V1
+        # TODO: Add support for these
+        if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching:
+            logger.warning("[V1][TPU] Disable prefix caching")
+            vllm_config.cache_config.enable_prefix_caching = False
+
+        assert not vllm_config.speculative_config, (
+            "Speculative decoding is not yet supported for TPU backend")
+
+    @classmethod
+    def is_pin_memory_available(cls):
+        logger.warning("Pin memory is not supported on TPU.")
+        return False
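
The worker-class selection above nests V1, multi-step, and default V0 branches. As a plain illustration of the resulting decision (a standalone sketch, not vLLM API), the branches reduce to:

    def select_tpu_worker_cls(use_v1: bool, is_multi_step: bool) -> str:
        # Mirrors the branch order in the config hook above: V1 wins,
        # then multi-step, then the default V0 TPU worker.
        if use_v1:
            return "vllm.v1.worker.tpu_worker.TPUWorker"
        if is_multi_step:
            return "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
        return "vllm.worker.tpu_worker.TPUWorker"

    assert select_tpu_worker_cls(True, True) == "vllm.v1.worker.tpu_worker.TPUWorker"
    assert select_tpu_worker_cls(False, False) == "vllm.worker.tpu_worker.TPUWorker"

Note that with VLLM_USE_V1 set, the V1 worker is chosen even if multi-step scheduling is configured, because the V1 branch is checked first.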

View File

@@ -0,0 +1,351 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch_xla.experimental.custom_kernel  # Required to register custom ops.

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState


class PallasAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "PALLAS_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_kv_heads, num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
            k_cache[:, dst_indices] = k_cache[:, src_indices]
            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
            v_cache[:, dst_indices] = v_cache[:, src_indices]


@dataclass
class PallasMetadata(AttentionMetadata):

    # Currently, input sequences can only contain all prefills
    # or all decoding.
    block_tables: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None
    effective_query_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        if self.num_prefills == 0:
            return None

        assert self.num_decode_tokens == 0
        return self

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.block_tables is not None
        assert self.context_lens is not None
        return self


class PallasAttentionBackendImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if head_size % 128 != 0:
            raise NotImplementedError("Head size must be a multiple of 128.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")
        if logits_soft_cap is not None:
            raise NotImplementedError(
                "Attention logits soft-capping is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        self.megacore_mode = None
        tpu_env = torch_xla.tpu.get_tpu_env()
        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
                    or tpu_env.get("TYPE", None)
                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
        assert tpu_type is not None
        tpu_type = tpu_type.lower()

        if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
                with shape [0] for profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        if attn_metadata is None:
            if output is None:
                output = torch.ones_like(query)
            return output

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        if kv_cache[0].numel() > 0:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            if attn_metadata.block_tables is None:
                # Prefill without paged KV cache.
                assert seq_len % 16 == 0, (
                    "Pallas FlashAttention kernel requires seq_len to be a "
                    f"multiple of 16 but got {seq_len}")

                # Handle GQA/MQA.
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                    key = key.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=-2)
                    value = value.view(batch_size, seq_len, self.num_heads,
                                       self.head_size)
                # FlashAttention kernel requires the input shape to be
                # [batch_size, num_heads, seq_len, d_model]
                # while the input is [batch_size, seq_len, num_heads, d_model].
                # Permute the input to match the required format.
                output = torch.ops.xla.flash_attention(
                    query.permute(0, 2, 1, 3),
                    key.permute(0, 2, 1, 3),
                    value.permute(0, 2, 1, 3),
                    True,
                )
                output = output.permute(0, 2, 1, 3)
            else:
                # Prefill with paged KV cache.
                # TODO(woosuk): Tune the below knobs.
                num_kv_pages_per_compute_block = 16
                num_queries_per_compute_block = 16
                assert seq_len % num_queries_per_compute_block == 0
                output = torch.ops.xla.multi_queries_paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    attn_metadata.effective_query_lens,
                    num_kv_pages_per_compute_block,
                    num_queries_per_compute_block,
                    use_kernel=True,
                )
        else:
            # Decoding run.
            assert kv_cache[0].numel() > 0
            query = query.squeeze(dim=1)
            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.

            assert attn_metadata.block_tables is not None
            assert attn_metadata.context_lens is not None
            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
            # block table in SMEM. Therefore, if the block table is too large,
            # the kernel compilation will fail. To avoid this, we split the
            # batch dimension into smaller chunks and run the kernel multiple
            # times.
            MAX_SMEM_USAGE = 512 * 1024
            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
            max_num_seq = MAX_SMEM_USAGE // size_per_seq

            if batch_size <= max_num_seq:
                output = paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    self.megacore_mode,
                )
            else:
                chunk_size = max_num_seq
                # Make sure the chunk size is a multiple of 2.
                chunk_size = chunk_size // 2 * 2
                num_chunks = (batch_size + chunk_size - 1) // chunk_size

                output = torch.empty_like(query)
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = chunk_start + chunk_size
                    # NOTE(woosuk): We skip this line because it causes Dynamo
                    # compilation error. Instead, we rely on the slice operation
                    # to handle the out-of-bound case.
                    # chunk_end = min(chunk_end, batch_size)
                    chunk_output = paged_attention(
                        query[chunk_start:chunk_end],
                        key_cache,
                        value_cache,
                        attn_metadata.context_lens[chunk_start:chunk_end],
                        attn_metadata.block_tables[chunk_start:chunk_end],
                        pages_per_compute_block,
                        self.megacore_mode,
                    )
                    output[chunk_start:chunk_end] = chunk_output

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)


def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

    key = key.flatten(0, 2)
    value = value.flatten(0, 2)
    key_cache = key_cache.flatten(0, 2)
    value_cache = value_cache.flatten(0, 2)
    key_cache.index_copy_(0, slot_mapping, key)
    value_cache.index_copy_(0, slot_mapping, value)


def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
) -> torch.Tensor:
    batch_size = query.shape[0]
    if megacore_mode == "batch" and batch_size % 2 != 0:
        megacore_mode = None
    else:
        megacore_mode = megacore_mode

    # NOTE(woosuk): A temporary workaround to avoid the error:
    # "xla::paged_attention() Expected a value of type 'str' for
    # argument 'megacore_mode' but instead found type 'NoneType'."
    if megacore_mode is not None:
        output = torch.ops.xla.paged_attention(
            query,
            key_cache,
            value_cache,
            context_lens,
            block_tables,
            pages_per_compute_block,
            megacore_mode=megacore_mode,
        )
    else:
        output = torch.ops.xla.paged_attention(
            query,
            key_cache,
            value_cache,
            context_lens,
            block_tables,
            pages_per_compute_block,
        )
    return output
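
The SMEM guard in the decode path is plain integer arithmetic: each sequence contributes 4 bytes per block-table column, and the whole table must fit in the 512 KiB budget used above. A worked sketch with a hypothetical block-table width (the real width depends on the model context length and block size, so these numbers are illustrative only):

    MAX_SMEM_USAGE = 512 * 1024      # bytes, same budget as the kernel guard above
    blocks_per_req = 2048            # hypothetical block-table width (columns)

    size_per_seq = 4 * blocks_per_req             # 8192 bytes of SMEM per sequence
    max_num_seq = MAX_SMEM_USAGE // size_per_seq  # 64 sequences fit at once

    batch_size = 200                              # hypothetical oversized batch
    chunk_size = (max_num_seq // 2) * 2           # keep the chunk even: 64
    num_chunks = (batch_size + chunk_size - 1) // chunk_size  # ceil(200/64) = 4
    print(max_num_seq, chunk_size, num_chunks)    # 64 64 4

Under these assumed numbers, a 200-sequence decode batch would be split into four paged_attention calls instead of one.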

View File

@@ -57,6 +57,14 @@ class BlockTable:
             src, :num_blocks]
         self.num_blocks_per_row[tgt] = num_blocks

+    def swap_row(self, src: int, tgt: int) -> None:
+        num_blocks_src = self.num_blocks_per_row[src]
+        num_blocks_tgt = self.num_blocks_per_row[tgt]
+        self.num_blocks_per_row[src] = num_blocks_tgt
+        self.num_blocks_per_row[tgt] = num_blocks_src
+
+        self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
+
     def commit(self, num_reqs: int) -> None:
         self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
                                           non_blocking=True)
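
swap_row relies on NumPy fancy indexing to exchange two whole rows in a single assignment; the right-hand side of a fancy-indexed assignment is materialized first, so no explicit temporary is needed. A self-contained sketch of the same idiom (plain NumPy, not the vLLM BlockTable):

    import numpy as np

    block_table = np.arange(12).reshape(3, 4)  # stand-in for block_table_np
    src, tgt = 0, 2
    block_table[[src, tgt]] = block_table[[tgt, src]]  # swap both rows in place
    print(block_table[0])  # [ 8  9 10 11]
    print(block_table[2])  # [0 1 2 3]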

View File

@@ -72,7 +72,7 @@ class InputBatch:
         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
         self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
-        self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
+        self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32)

         # Block table.
         self.block_table = BlockTable(
@@ -436,3 +436,77 @@ class InputBatch:
     @property
     def no_prompt_logprob(self) -> bool:
         return len(self.prompt_logprob_reqs) == 0
+
+
+def swap_positions(b: InputBatch, id_1, id_2):
+    assert id_1 != id_2
+    req_id_1 = b.req_ids[id_1]
+    req_id_2 = b.req_ids[id_2]
+    assert req_id_1 is not None
+    assert req_id_2 is not None
+    assert id_1 == b.req_id_to_index[req_id_1]
+    assert id_2 == b.req_id_to_index[req_id_2]
+
+    b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1]
+    b.req_id_to_index[req_id_1], b.req_id_to_index[
+        req_id_2] = b.req_id_to_index[req_id_2], b.req_id_to_index[req_id_1]
+
+    ids = [id_1, id_2]
+    rev_ids = [id_2, id_1]
+    b.num_tokens[ids] = b.num_tokens[rev_ids]
+    b.token_ids_cpu[ids] = b.token_ids_cpu[rev_ids]
+    b.num_prompt_tokens[ids] = b.num_prompt_tokens[rev_ids]
+    b.num_computed_tokens_cpu[ids] = b.num_computed_tokens_cpu[rev_ids]
+
+    b.block_table.swap_row(id_1, id_2)
+
+    b.temperature_cpu[ids] = b.temperature_cpu[rev_ids]
+    b.top_p_cpu[ids] = b.top_p_cpu[rev_ids]
+    b.top_k_cpu[ids] = b.top_k_cpu[rev_ids]
+    b.frequency_penalties_cpu[ids] = b.frequency_penalties_cpu[rev_ids]
+    b.presence_penalties_cpu[ids] = b.presence_penalties_cpu[rev_ids]
+    b.repetition_penalties_cpu[ids] = b.repetition_penalties_cpu[rev_ids]
+
+    b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[
+        id_1]
+    b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[
+        id_2], b.stop_token_ids[id_1]
+
+    gen_1 = b.generators.pop(id_1, None)
+    gen_2 = b.generators.pop(id_2, None)
+    if gen_1 is not None:
+        b.generators[id_2] = gen_1
+    if gen_2 is not None:
+        b.generators[id_1] = gen_2
+
+
+def ensure_decodes_first(b: InputBatch):
+    num_reqs = b.num_reqs
+    while True:
+        # Find the first prompt index
+        first_prompt_index = None
+        for i in range(num_reqs):
+            if b.num_computed_tokens_cpu[i] < b.num_prompt_tokens[i]:
+                first_prompt_index = i
+                break
+        if first_prompt_index is None:
+            break
+
+        # Find the last decode index
+        last_decode_index = None
+        for i in reversed(range(num_reqs)):
+            if b.num_computed_tokens_cpu[i] >= b.num_prompt_tokens[i]:
+                last_decode_index = i
+                break
+        if last_decode_index is None:
+            break
+
+        # Sanity
+        assert first_prompt_index != last_decode_index
+
+        # Check if done
+        if first_prompt_index > last_decode_index:
+            break
+
+        # Swap
+        swap_positions(b, first_prompt_index, last_decode_index)
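
ensure_decodes_first is a two-pointer partition: it repeatedly swaps the first request that is still prefilling (computed tokens < prompt tokens) with the last request that is already decoding, until every decode sits ahead of every prompt. A toy illustration on plain lists (not the real InputBatch), with made-up request labels:

    # True = already decoding, False = still prefilling its prompt.
    reqs = ["p0", "d1", "p2", "d3"]
    is_decode = [False, True, False, True]

    while True:
        first_prompt = next((i for i, d in enumerate(is_decode) if not d), None)
        last_decode = next((i for i in reversed(range(len(reqs))) if is_decode[i]),
                           None)
        if first_prompt is None or last_decode is None or first_prompt > last_decode:
            break
        reqs[first_prompt], reqs[last_decode] = reqs[last_decode], reqs[first_prompt]
        is_decode[first_prompt], is_decode[last_decode] = (is_decode[last_decode],
                                                           is_decode[first_prompt])

    print(reqs)  # ['d3', 'd1', 'p2', 'p0'] -> decodes grouped in front

Grouping decodes ahead of prompts lines up with the PallasMetadata comment earlier in this compare that a batch contains either all prefills or all decodes.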

View File

@@ -5,32 +5,23 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 import numpy as np
 import torch
 import torch.distributed
-import torch.nn as nn

-from vllm.attention.backends.abstract import AttentionType
-from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
-from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.sampling_params import SamplingType
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, cdiv, is_pin_memory_available)
+from vllm.utils import DeviceMemoryProfiler, cdiv
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
-from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
-from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheSpec)
+from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.utils import bind_kv_cache
-from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase

 if TYPE_CHECKING:
     from vllm.v1.core.scheduler import SchedulerOutput
@@ -38,87 +29,17 @@ if TYPE_CHECKING:

 logger = init_logger(__name__)


-class GPUModelRunner:
+class GPUModelRunner(ModelRunnerBase):

     def __init__(
         self,
         vllm_config: VllmConfig,
         device: torch.device,
     ):
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
-        self.cache_config = vllm_config.cache_config
-        self.lora_config = vllm_config.lora_config
-        self.load_config = vllm_config.load_config
-        self.parallel_config = vllm_config.parallel_config
-        self.scheduler_config = vllm_config.scheduler_config
-        self.speculative_config = vllm_config.speculative_config
-        self.prompt_adapter_config = vllm_config.prompt_adapter_config
-        self.observability_config = vllm_config.observability_config
+        super().__init__(vllm_config, device)

-        model_config = self.model_config
-        cache_config = self.cache_config
-        scheduler_config = self.scheduler_config
-        parallel_config = self.parallel_config
-        self.device = device
-        self.pin_memory = is_pin_memory_available()
-        self.dtype = self.model_config.dtype
-        if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
-        else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
-
-        self.is_multimodal_model = model_config.is_multimodal_model
-        self.sliding_window = model_config.get_sliding_window()
-        self.block_size = cache_config.block_size
-        self.max_model_len = model_config.max_model_len
-        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = scheduler_config.max_num_seqs
-
-        # Model-related.
-        self.num_attn_layers = model_config.get_num_layers_by_block_type(
-            parallel_config, LayerBlockType.attention)
-        self.num_query_heads = model_config.get_num_attention_heads(
-            parallel_config)
-        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
-        self.head_size = model_config.get_head_size()
-        self.hidden_size = model_config.get_hidden_size()
-
-        # Multi-modal data support
-        self.input_registry = INPUT_REGISTRY
-        self.mm_registry = MULTIMODAL_REGISTRY
-
-        # NOTE: Initialized input mapper is only used for processing dummy
-        # multimodal data into multimodal kwargs for GPU memory profiling.
-        self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
-        self.mm_input_mapper_profiling.use_cache = False
-
-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
-        )
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size
-
-        # Lazy initialization
-        # self.model: nn.Module  # Set after load_model
+        # KV caches for forward pass
         self.kv_caches: List[torch.Tensor] = []
-        # req_id -> (input_id -> encoder_output)
-        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
-
-        # Request states.
-        self.requests: Dict[str, CachedRequestState] = {}
-        # Persistent batch.
-        self.input_batch = InputBatch(
-            max_num_reqs=self.max_num_reqs,
-            max_model_len=self.max_model_len,
-            max_num_blocks_per_req=self.max_num_blocks_per_req,
-            device=self.device,
-            pin_memory=self.pin_memory,
-            vocab_size=model_config.get_vocab_size(),
-        )

         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE
@@ -202,132 +123,6 @@ class GPUModelRunner:
                                          pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()

-    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
-        # Remove stopped requests from the cached states.
-        # Keep the states of the pre-empted requests.
-        for req_id in scheduler_output.finished_req_ids:
-            self.requests.pop(req_id, None)
-            self.encoder_cache.pop(req_id, None)
-
-        # Free the cached encoder outputs.
-        for req_id, input_id in scheduler_output.free_encoder_input_ids:
-            encoder_outputs = self.encoder_cache.get(req_id)
-            if encoder_outputs is not None:
-                encoder_outputs.pop(input_id, None)
-                if not encoder_outputs:
-                    self.encoder_cache.pop(req_id, None)
-
-        # Remove the requests from the persistent batch.
-        stopped_req_ids = set().union(
-            scheduler_output.preempted_req_ids,
-            scheduler_output.finished_req_ids,
-        )
-        removed_req_indices: List[int] = []
-        for req_id in stopped_req_ids:
-            req_index = self.input_batch.remove_request(req_id)
-            if req_index is not None:
-                removed_req_indices.append(req_index)
-
-        # Update the states of the running requests.
-        for req_data in scheduler_output.scheduled_running_reqs:
-            req_id = req_data.req_id
-            req_state = self.requests[req_id]
-            req_index = self.input_batch.req_id_to_index[req_id]
-
-            # Update the num_computed_tokens.
-            req_state.num_computed_tokens = req_data.num_computed_tokens
-            self.input_batch.num_computed_tokens_cpu[req_index] = (
-                req_data.num_computed_tokens)
-
-            # Update the block table.
-            num_new_blocks = len(req_data.new_block_ids)
-            if num_new_blocks == 0:
-                continue
-            start_index = len(req_state.block_ids)
-            req_state.block_ids.extend(req_data.new_block_ids)
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
-
-        req_ids_to_add: List[str] = []
-        # Add new requests to the cached states.
-        for new_req_data in scheduler_output.scheduled_new_reqs:
-            req_id = new_req_data.req_id
-            sampling_params = new_req_data.sampling_params
-            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                generator = torch.Generator(device=self.device)
-                generator.manual_seed(sampling_params.seed)
-            else:
-                generator = None
-
-            self.requests[req_id] = CachedRequestState(
-                req_id=req_id,
-                prompt_token_ids=new_req_data.prompt_token_ids,
-                prompt=new_req_data.prompt,
-                mm_inputs=new_req_data.mm_inputs,
-                mm_positions=new_req_data.mm_positions,
-                sampling_params=sampling_params,
-                generator=generator,
-                block_ids=new_req_data.block_ids,
-                num_computed_tokens=new_req_data.num_computed_tokens,
-                output_token_ids=[],
-            )
-
-            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
-            if self.model_config.uses_mrope:
-                image_grid_thw = []
-                video_grid_thw = []
-                for mm_input in self.requests[req_id].mm_inputs:
-                    if mm_input.get("image_grid_thw") is not None:
-                        image_grid_thw.extend(
-                            mm_input["image_grid_thw"].tolist())
-                    if mm_input.get("video_grid_thw") is not None:
-                        video_grid_thw.extend(
-                            mm_input["video_grid_thw"].tolist())
-
-                hf_config = self.model_config.hf_config
-
-                self.requests[req_id].mrope_positions, \
-                    self.requests[req_id].mrope_position_delta = \
-                    MRotaryEmbedding.get_input_positions_tensor(
-                        self.requests[req_id].prompt_token_ids,
-                        image_grid_thw=image_grid_thw,
-                        video_grid_thw=video_grid_thw,
-                        image_token_id=hf_config.image_token_id,
-                        video_token_id=hf_config.video_token_id,
-                        vision_start_token_id=hf_config.vision_start_token_id,
-                        vision_end_token_id=hf_config.vision_end_token_id,
-                        spatial_merge_size=hf_config.vision_config.
-                        spatial_merge_size,
-                    )
-
-            req_ids_to_add.append(req_id)
-
-        # Update the cached states of the resumed requests.
-        for res_req_data in scheduler_output.scheduled_resumed_reqs:
-            req_id = res_req_data.req_id
-            req_state = self.requests[req_id]
-
-            req_state.block_ids = res_req_data.block_ids
-            req_state.num_computed_tokens = res_req_data.num_computed_tokens
-            req_ids_to_add.append(req_id)
-
-        # Add the new or resumed requests to the persistent batch.
-        # The smaller empty indices are filled first.
-        removed_req_indices = sorted(removed_req_indices, reverse=True)
-        for req_id in req_ids_to_add:
-            req_state = self.requests[req_id]
-            if removed_req_indices:
-                # Fill the empty index.
-                req_index = removed_req_indices.pop()
-            else:
-                # Append to the end.
-                req_index = None
-            self.input_batch.add_request(req_state, req_index)
-
-        # Condense the batched states if there are empty indices.
-        if removed_req_indices:
-            self.input_batch.condense(removed_req_indices)
-
     def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -611,6 +406,8 @@
         return sampling_metadata

     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
+        assert self.model is not None
+
         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
         if not scheduled_encoder_inputs:
             return
@@ -698,15 +495,14 @@
             encoder_outputs.append(encoder_output[start_idx:end_idx])
         return encoder_outputs

-    def get_model(self) -> nn.Module:
-        return self.model
-
     @torch.inference_mode()
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> ModelRunnerOutput:
-        self._update_states(scheduler_output)
+        assert self.model is not None
+        self.update_states(scheduler_output)

         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -833,14 +629,15 @@
                     self.model_memory_usage / float(2**30))

     @torch.inference_mode()
-    def _dummy_run(
+    def dummy_run(
         self,
+        kv_caches,
         num_tokens: int,
-        kv_caches: Optional[List[torch.Tensor]] = None,
+        seq_len: Optional[int] = None,
+        exec_mode: Optional[ExecutionMode] = None,
     ) -> torch.Tensor:
-        model = self.model
-        if kv_caches is None:
-            kv_caches = self.kv_caches
+        assert self.model is not None
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -851,7 +648,7 @@
             positions = self.mrope_positions[:, :num_tokens] \
                 if self.model_config.uses_mrope \
                 else self.positions[:num_tokens]
-        hidden_states = model(
+        hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
             kv_caches=kv_caches,
@@ -861,6 +658,7 @@
         return hidden_states

     def profile_run(self) -> None:
+        assert self.model is not None
         # use an empty tensor instead of `None`` to force Dynamo to pass
         # it by reference, rather by specializing on the value `None`.
        # the `dtype` argument does not matter, and we use `float32` as
@@ -966,7 +764,7 @@
             self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

         # Trigger compilation for general shape.
-        hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
+        hidden_states = self.dummy_run(dummy_kv_caches, self.max_num_tokens)
         logits = self.model.compute_logits(hidden_states, None)
         logits = logits[:self.max_num_tokens]
         # TODO(woosuk): Consider the memory usage of the sampler.
@@ -992,8 +790,8 @@
             for num_tokens in reversed(self.cudagraph_batch_sizes):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
+                    self.dummy_run(None, num_tokens)
+                self.dummy_run(None, num_tokens)

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
@@ -1036,38 +834,3 @@
             kv_caches,
             self.vllm_config.compilation_config.static_forward_context,
             self.kv_caches)
-
-    def get_kv_cache_spec(self) -> KVCacheSpec:
-        """
-        Generates the KVCacheSpec by parsing the kv cache format from each
-        Attention module in the static forward context.
-        Returns:
-            KVCacheSpec: A dictionary mapping layer names to their KV cache
-            format. Layers that do not need KV cache are not included.
-        """
-
-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
-        block_size = self.vllm_config.cache_config.block_size
-        kv_cache_spec: KVCacheSpec = {}
-        for layer_name, attn_module in forward_ctx.items():
-            # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention, MLA.
-            assert isinstance(attn_module, Attention)
-            if attn_module.attn_type == AttentionType.DECODER:
-                kv_cache_spec[layer_name] = FullAttentionSpec(
-                    block_size=block_size,
-                    num_kv_heads=attn_module.num_kv_heads,
-                    head_size=attn_module.head_size,
-                    dtype=attn_module.dtype,
-                )
-            elif attn_module.attn_type in (AttentionType.ENCODER,
-                                           AttentionType.ENCODER_ONLY):
-                # encoder-only attention does not need KV cache.
-                continue
-            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
-                raise NotImplementedError
-            else:
-                raise ValueError(
-                    f"Unknown attention type: {attn_module.attn_type}")
-
-        return kv_cache_spec

View File

@@ -1,13 +1,11 @@
 """A GPU worker class."""
 import gc
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import Optional

 import torch
 import torch.distributed
-import torch.nn as nn

-import vllm.envs as envs
 from vllm.config import ParallelConfig, VllmConfig
 from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.distributed import (ensure_model_parallel_initialized,
@@ -15,20 +13,17 @@ from vllm.distributed import (ensure_model_parallel_initialized,
                               set_custom_all_reduce)
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
-from vllm.platforms import current_platform
 from vllm.utils import GiB_bytes
 from vllm.v1.core.scheduler import SchedulerOutput
-from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
+from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.worker_base import WorkerBase, check_if_gpu_supports_dtype

 logger = init_logger(__name__)

-if TYPE_CHECKING:
-    from vllm.v1.core.scheduler import SchedulerOutput
-

-class Worker:
+class GPUWorker(WorkerBase):

     def __init__(
         self,
@@ -38,46 +33,8 @@
         distributed_init_method: str,
         is_driver_worker: bool = False,
     ):
-
-        # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
-        self.cache_config = vllm_config.cache_config
-        self.lora_config = vllm_config.lora_config
-        self.load_config = vllm_config.load_config
-        self.parallel_config = vllm_config.parallel_config
-        self.scheduler_config = vllm_config.scheduler_config
-        self.device_config = vllm_config.device_config
-        self.speculative_config = vllm_config.speculative_config
-        self.prompt_adapter_config = vllm_config.prompt_adapter_config
-        self.observability_config = vllm_config.observability_config
-
-        self.parallel_config.rank = rank
-        self.local_rank = local_rank
-        self.rank = rank
-        self.distributed_init_method = distributed_init_method
-
-        if self.model_config.trust_remote_code:
-            # note: lazy import to avoid importing torch before initializing
-            from vllm.utils import init_cached_hf_modules
-            init_cached_hf_modules()
-
-        # Torch profiler. Enabled and configured through env vars:
-        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
-        if envs.VLLM_TORCH_PROFILER_DIR:
-            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
-            logger.info("Profiling enabled. Traces will be saved to: %s",
-                        torch_profiler_trace_dir)
-            self.profiler = torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                with_stack=True,
-                on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                    torch_profiler_trace_dir, use_gzip=True))
-        else:
-            self.profiler = None
+        super().__init__(vllm_config, local_rank, rank,
+                         distributed_init_method)

     def sleep(self, level: int = 1) -> None:
         free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
@@ -97,31 +54,39 @@
         allocator.wake_up()

     def init_device(self):
-        if self.device_config.device.type == "cuda":
-            # torch.distributed.all_reduce does not free the input tensor until
-            # the synchronization point. This causes the memory usage to grow
-            # as the number of all_reduce calls increases. This env var disables
-            # this behavior.
-            # Related issue:
-            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
-            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-            # This env var set by Ray causes exceptions with graph building.
-            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-            self.device = torch.device(f"cuda:{self.local_rank}")
-            torch.cuda.set_device(self.device)
-
-            _check_if_gpu_supports_dtype(self.model_config.dtype)
-            gc.collect()
-            torch.cuda.empty_cache()
-            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
-        else:
-            raise RuntimeError(
-                f"Not support device type: {self.device_config.device}")
+        assert self.device_config.device.type == "cuda"
+
+        # torch.distributed.all_reduce does not free the input tensor until
+        # the synchronization point. This causes the memory usage to grow
+        # as the number of all_reduce calls increases. This env var disables
+        # this behavior.
+        # Related issue:
+        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
+        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+        # torch.distributed.all_reduce does not free the input tensor until
+        # the synchronization point. This causes the memory usage to grow
+        # as the number of all_reduce calls increases. This env var disables
+        # this behavior.
+        # Related issue:
+        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
+        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+        # This env var set by Ray causes exceptions with graph building.
+        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+        self.device = torch.device(f"cuda:{self.local_rank}")
+        torch.cuda.set_device(self.device)
+
+        check_if_gpu_supports_dtype(self.model_config.dtype)
+        gc.collect()
+        torch.cuda.empty_cache()
+        self.init_gpu_memory = torch.cuda.mem_get_info()[0]

         # Initialize the distributed environment.
-        init_worker_distributed_environment(self.parallel_config, self.rank,
-                                            self.distributed_init_method,
-                                            self.local_rank)
+        init_cuda_worker_distributed_environment(self.parallel_config,
+                                                 self.rank,
+                                                 self.distributed_init_method,
+                                                 self.local_rank)
         # Set random seed.
         set_random_seed(self.model_config.seed)
@@ -139,6 +104,7 @@
         from contextlib import nullcontext
         context = nullcontext()
         with context:
+            assert self.model_runner is not None
             self.model_runner.load_model()

     @torch.inference_mode()
@@ -160,6 +126,7 @@
         _, total_gpu_memory = torch.cuda.mem_get_info()
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
+        assert self.model_runner is not None
        self.model_runner.profile_run()

        free_gpu_memory, _ = torch.cuda.mem_get_info()
@@ -191,9 +158,6 @@
         return int(available_kv_cache_memory)

-    def get_kv_cache_spec(self) -> KVCacheSpec:
-        return self.model_runner.get_kv_cache_spec()
-
     def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
         if self.vllm_config.model_config.enable_sleep_mode:
@@ -203,9 +167,12 @@
         from contextlib import nullcontext
         context = nullcontext()
         with context:
+            assert self.model_runner is not None
             self.model_runner.initialize_kv_cache(kv_cache_config)

     def compile_or_warm_up_model(self) -> None:
+        assert self.model_runner is not None
+
         # warm up sizes that are not in cudagraph capture sizes,
         # but users still want to compile for better performance,
         # e.g. for the max-num-batched token size in chunked prefill.
@@ -217,44 +184,32 @@
         ]
         for size in sorted(warmup_sizes, reverse=True):
             logger.info("Compile and warming up model for size %d", size)
-            self.model_runner._dummy_run(size)
+            self.model_runner.dummy_run(None, size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)

-    def get_model(self) -> nn.Module:
-        return self.model_runner.get_model()
-
     @torch.inference_mode()
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
+        assert self.model_runner is not None
         output = self.model_runner.execute_model(scheduler_output)
         return output if self.rank == 0 else None

-    def profile(self, is_start: bool = True):
-        if self.profiler is None:
-            raise RuntimeError("Profiler is not enabled.")
-        if is_start:
-            self.profiler.start()
-        else:
-            self.profiler.stop()
-
-    def check_health(self) -> None:
-        # worker will always be healthy as long as it's running.
-        return
-

-def init_worker_distributed_environment(
+def init_cuda_worker_distributed_environment(
     parallel_config: ParallelConfig,
     rank: int,
     distributed_init_method: Optional[str] = None,
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
     init_distributed_environment(parallel_config.world_size, rank,
@ -264,21 +219,22 @@ def init_worker_distributed_environment(
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # TODO: Remove
# Check if the GPU supports the dtype. # def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
if torch_dtype == torch.bfloat16: # noqa: SIM102 # # Check if the GPU supports the dtype.
if not current_platform.has_device_capability(80): # if torch_dtype == torch.bfloat16: # noqa: SIM102
capability = current_platform.get_device_capability() # if not current_platform.has_device_capability(80):
gpu_name = current_platform.get_device_name() # capability = current_platform.get_device_capability()
# gpu_name = current_platform.get_device_name()
if capability is None: # if capability is None:
compute_str = "does not have a compute capability" # compute_str = "does not have a compute capability"
else: # else:
version_str = capability.as_version_str() # version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}" # compute_str = f"has compute capability {version_str}"
raise ValueError( # raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability " # "Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " # f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the" # "You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.") # "`dtype` flag in CLI, for example: --dtype=half.")


@@ -0,0 +1,307 @@
import enum
from typing import TYPE_CHECKING, Dict, List, Optional
import torch
import torch.distributed
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingType
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
logger = init_logger(__name__)
class ExecutionMode(enum.Enum):
PREFILL = enum.auto()
DECODE = enum.auto()
PREFIX_PREFILL = enum.auto()
def is_prefill(self) -> bool:
return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
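# NOTE: As used below, PREFILL runs a prompt with no previously cached
# context, PREFIX_PREFILL runs a prompt chunk on top of already-computed
# context (chunked/prefix prefill), and DECODE generates one token per
# request from cached context.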
class ModelRunnerBase:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.device_config = vllm_config.device_config
model_config = self.model_config
cache_config = self.cache_config
scheduler_config = self.scheduler_config
parallel_config = self.parallel_config
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
self.is_multimodal_model = model_config.is_multimodal_model
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
self.num_query_heads = model_config.get_num_attention_heads(
parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
self.model: Optional[nn.Module] = None
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
)
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
# NOTE: Initialized input mapper is only used for processing dummy
# multimodal data into multimodal kwargs for GPU memory profiling.
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
def update_states(self, scheduler_output: "SchedulerOutput") -> None:
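        # Mirrors the scheduler output onto the runner's cached state:
        # drop finished/preempted requests (and their cached encoder
        # outputs), advance num_computed_tokens and block tables for running
        # requests, register new and resumed requests, then place them into
        # the persistent InputBatch, reusing freed slots first and condensing
        # any remaining holes.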
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
# Remove the requests from the persistent batch.
stopped_req_ids = set().union(
scheduler_output.preempted_req_ids,
scheduler_output.finished_req_ids,
)
removed_req_indices: List[int] = []
for req_id in stopped_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
# Update the states of the running requests.
for req_data in scheduler_output.scheduled_running_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
req_index = self.input_batch.req_id_to_index[req_id]
# Update the num_computed_tokens.
req_state.num_computed_tokens = req_data.num_computed_tokens
self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens)
# Update the block table.
num_new_blocks = len(req_data.new_block_ids)
if num_new_blocks == 0:
continue
start_index = len(req_state.block_ids)
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table.append_row(req_index, start_index,
req_data.new_block_ids)
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
generator = None
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt=new_req_data.prompt,
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.model_config.uses_mrope:
image_grid_thw = []
video_grid_thw = []
for mm_input in self.requests[req_id].mm_inputs:
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend(
mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend(
mm_input["video_grid_thw"].tolist())
hf_config = self.model_config.hf_config
self.requests[req_id].mrope_positions, \
self.requests[req_id].mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
self.requests[req_id].prompt_token_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
)
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
for res_req_data in scheduler_output.scheduled_resumed_reqs:
req_id = res_req_data.req_id
req_state = self.requests[req_id]
req_state.block_ids = res_req_data.block_ids
req_state.num_computed_tokens = res_req_data.num_computed_tokens
req_ids_to_add.append(req_id)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
removed_req_indices = sorted(removed_req_indices, reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
if removed_req_indices:
# Fill the empty index.
req_index = removed_req_indices.pop()
else:
# Append to the end.
req_index = None
self.input_batch.add_request(req_state, req_index)
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
def get_model(self) -> nn.Module:
assert self.model is not None
return self.model
def get_kv_cache_spec(self) -> KVCacheSpec:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size
kv_cache_spec: KVCacheSpec = {}
for layer_name, attn_module in forward_ctx.items():
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")
return kv_cache_spec
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
raise NotImplementedError()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
raise NotImplementedError()
def load_model(self) -> None:
raise NotImplementedError()
def dummy_run(
self,
kv_caches,
num_tokens: int,
seq_len: Optional[int] = None,
exec_mode: Optional[ExecutionMode] = None,
) -> torch.Tensor:
raise NotImplementedError()
def profile_run(self) -> None:
raise NotImplementedError()
def capture_model(self) -> None:
raise NotImplementedError()


@@ -0,0 +1,888 @@
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from unittest.mock import patch
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
# TPU XLA related
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
PallasMetadata)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import (CachedRequestState, InputBatch,
ensure_decodes_first)
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
logger = init_logger(__name__)
# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
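# Padding entries in the slot mapping point at this out-of-range slot, so
# their KV writes are dropped instead of overwriting real cache blocks.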
@dataclass
class PromptDecodeInfo:
prompt_req_ids: List[str]
decode_req_ids: List[str]
prompt_scheduled_tokens: List[int]
@dataclass
class PromptData:
input_tokens: torch.Tensor
input_positions: torch.Tensor
attn_metadata: PallasMetadata
@dataclass
class DecodeData:
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional[PallasMetadata] = None
class TPUModelRunner(ModelRunnerBase):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(vllm_config, device)
# KV caches for forward pass
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
# Cached torch/numpy tensors
self.num_swaps = 2
self.cur_swap_id = 0
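        # NOTE: The CPU staging tensors below are double-buffered
        # (num_swaps == 2): execute_model() prepares the next prompt/decode
        # inputs in one buffer while the TPU is still consuming tensors built
        # from the other, and swap_step() flips between the two buffers.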
self.input_ids_cpu = []
self.input_ids_np = []
self.input_positions_cpu = []
self.input_positions_np = []
self.slot_mapping_cpu = []
self.slot_mapping_np = []
self.prompt_context_lens_cpu = []
self.prompt_effective_query_lens_cpu = []
self.decode_context_lens_cpu = []
self.decode_context_lens_np = []
for _ in range(self.num_swaps):
self.input_ids_cpu.append(
torch.empty(self.max_num_tokens,
dtype=torch.int32,
device="cpu"))
self.input_ids_np.append(self.input_ids_cpu[-1].numpy())
self.input_positions_cpu.append(
torch.empty(self.max_num_tokens,
dtype=torch.int32,
device="cpu"))
self.input_positions_np.append(
self.input_positions_cpu[-1].numpy())
self.slot_mapping_cpu.append(
torch.empty(self.max_num_tokens,
dtype=torch.int64,
device="cpu"))
self.slot_mapping_np.append(self.slot_mapping_cpu[-1].numpy())
self.prompt_context_lens_cpu.append(
torch.empty((1), dtype=torch.int32, device="cpu"))
self.prompt_effective_query_lens_cpu.append(
torch.empty((1), dtype=torch.int32, device="cpu"))
self.decode_context_lens_cpu.append(
torch.empty(self.max_num_tokens,
dtype=torch.int32,
device="cpu"))
self.decode_context_lens_np.append(
self.decode_context_lens_cpu[-1].numpy())
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
def swap_step(self):
self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps
def _get_prompts_and_decodes(
self,
scheduler_output: "SchedulerOutput",
) -> PromptDecodeInfo:
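        # NOTE: Assumes the persistent batch is ordered decodes-first (see
        # ensure_decodes_first() in execute_model), so the batch can simply be
        # split at the first request that still has prompt tokens to compute.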
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
# Traverse decodes first
decode_req_ids = []
for i in range(num_reqs):
req_id = self.input_batch.req_ids[i]
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
if num_computed_tokens < num_prompt_tokens:
# This is prompt
break
# This is decode
assert num_scheduled_tokens == 1
decode_req_ids.append(req_id)
# Traverse prompts
prompt_req_ids = []
prompt_scheduled_tokens = []
for i in range(len(decode_req_ids), num_reqs):
req_id = self.input_batch.req_ids[i]
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
# Must be prompt
assert num_computed_tokens < num_prompt_tokens
prompt_req_ids.append(req_id)
prompt_scheduled_tokens.append(num_scheduled_tokens)
return PromptDecodeInfo(prompt_req_ids, decode_req_ids,
prompt_scheduled_tokens)
def _prepare_prompt(self, req_index: int,
num_scheduled_tokens: int) -> PromptData:
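        # Builds model inputs for a single prompt (or prompt chunk): token
        # ids, positions and the slot mapping are staged in the current CPU
        # swap buffer, padded to _get_padded_prompt_len(), and then copied to
        # the TPU together with the Pallas attention metadata.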
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[
req_index]
num_prompt_tokens = self.input_batch.num_prompt_tokens[req_index]
# Must be prompt
assert num_computed_tokens < num_prompt_tokens
# Prompt len
prompt_len = num_scheduled_tokens
padded_prompt_len = _get_padded_prompt_len(prompt_len)
assert padded_prompt_len <= self.max_model_len
# Seq len
seq_len = num_computed_tokens + prompt_len
padded_seq_len = num_computed_tokens + padded_prompt_len
# DEBUG
# print("_prepare_prompt:")
# print(" prompt_len = {}".format(prompt_len))
# print(" padded_prompt_len = {}".format(padded_prompt_len))
# print(" num_computed_tokens = {}".format(num_computed_tokens))
# print(" num_prompt_tokens = {}".format(num_prompt_tokens))
# print(" seq_len = {}".format(seq_len))
# print(" padded_seq_len = {}".format(padded_seq_len))
# Input tokens
input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[
req_index, num_computed_tokens:padded_seq_len]
input_tokens_cpu[prompt_len:] = 0
# DEBUG
# print(" input_tokens_cpu.shape = {} val = {}".format(
# input_tokens_cpu.shape, input_tokens_cpu))
# Input positions
input_positions_np = self.input_positions_np[
self.cur_swap_id][:padded_prompt_len]
np.add(num_computed_tokens,
self.arange_np[:padded_prompt_len],
out=input_positions_np)
input_positions_np[prompt_len:] = 0
# DEBUG
# print(" input_positions_np.shape = {} val = {}".format(
# input_positions_np.shape, input_positions_np))
# Slot mapping
block_table_np = \
self.input_batch.block_table.get_numpy_array()
block_numbers_np = block_table_np[req_index, input_positions_np //
self.block_size]
block_offsets_np = input_positions_np % self.block_size
slot_mapping_np = self.slot_mapping_np[
self.cur_swap_id][:padded_prompt_len]
np.add(block_numbers_np * self.block_size,
block_offsets_np,
out=slot_mapping_np)
slot_mapping_np[prompt_len:] = _PAD_SLOT_ID
# DEBUG
# print(" slot_mapping_np.shape = {} val = {}".format(
# slot_mapping_np.shape, slot_mapping_np))
# Block table
block_table_cpu = None
if num_computed_tokens > 0:
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
block_table_cpu = block_table_cpu[req_index]
# DEBUG
# print(" block_table_cpu = {}".format(block_table_cpu))
# Context len
self.prompt_context_lens_cpu[self.cur_swap_id][0] = 0
if num_computed_tokens > 0:
self.prompt_context_lens_cpu[self.cur_swap_id][0] = seq_len
# Effective query len
self.prompt_effective_query_lens_cpu[self.cur_swap_id][0] = prompt_len
# Get final tensors
input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device)
input_positions = self.input_positions_cpu[
self.cur_swap_id][:padded_prompt_len].reshape(1,
-1).to(self.device)
slot_mapping = self.slot_mapping_cpu[
self.cur_swap_id][:padded_prompt_len].reshape(1,
-1).to(self.device)
block_table = block_table_cpu.reshape(1, -1).to(
self.device) if block_table_cpu is not None else None
context_lens = self.prompt_context_lens_cpu[self.cur_swap_id].to(
self.device)
effective_query_lens = self.prompt_effective_query_lens_cpu[
self.cur_swap_id].to(self.device)
self.swap_step()
# DEBUG
# print(" input_tokens.shape = {} val = {}".format(
# input_tokens.shape, input_tokens))
# print(" input_positions.shape = {} val = {}".format(
# input_positions.shape, input_positions))
# print(" slot_mapping.shape = {} val = {}".format(
# slot_mapping.shape, slot_mapping))
# print(" block_table = {}".format(block_table))
# print(" context_lens.shape = {} val = {}".format(
# context_lens.shape, context_lens))
# print(" effective_query_lens.shape = {} val = {}".format(
# effective_query_lens.shape, effective_query_lens))
# Attn metadata
attn_metadata = PallasMetadata(
num_prefills=1,
num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_table,
context_lens=context_lens,
effective_query_lens=effective_query_lens,
)
return PromptData(input_tokens, input_positions, attn_metadata)
def _prepare_decode(
self,
decode_req_ids: List[str],
) -> DecodeData:
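        # Builds a single padded decode batch (one new token per request):
        # positions, token ids, slot mapping and context lengths are gathered
        # from the persistent batch on CPU, padded to
        # _get_padded_batch_size(), and moved to the TPU with the
        # corresponding Pallas attention metadata.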
# Batch size
batch_size = len(decode_req_ids)
padded_batch_size = _get_padded_batch_size(batch_size)
assert padded_batch_size <= self.max_model_len
# Init [0 .. batch_size - 1]
req_indices_np = self.arange_np[:padded_batch_size]
# DEBUG
# print("_prepare_decode:")
# print(" batch_size = {}".format(batch_size))
# print(" padded_batch_size = {}".format(padded_batch_size))
# print(" req_indices_np.shape = {} val = {}".format(
# req_indices_np.shape, req_indices_np))
# Input positions
input_positions_np = self.input_positions_np[
self.cur_swap_id][:padded_batch_size]
np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
0,
out=input_positions_np)
input_positions_np[batch_size:] = 0
input_positions_cpu = self.input_positions_cpu[
self.cur_swap_id][:padded_batch_size]
# DEBUG
# print(" input_positions_cpu.shape = {} data = {}".format(
# input_positions_cpu.shape, input_positions_cpu))
# Input tokens
token_indices_np = (
input_positions_np +
req_indices_np * self.input_batch.token_ids_cpu.shape[1])
input_tokens_cpu = self.input_ids_cpu[
self.cur_swap_id][:padded_batch_size]
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices_np),
out=input_tokens_cpu)
input_tokens_cpu[batch_size:] = 0
# DEBUG
# print(" token_indices_np.shape = {} val = {}".format(
# token_indices_np.shape, token_indices_np))
# print(" input_tokens_cpu.shape = {} data = {}".format(
# input_tokens_cpu.shape, input_tokens_cpu))
# Slot mapping
block_table_indices_np = (
req_indices_np * self.max_num_blocks_per_req +
input_positions_np // self.block_size)
# DEBUG
# print(
# " block_table_indices_np.shape = {} data = {} max_num_blocks_per_req = {}"
# .format(block_table_indices_np.shape, block_table_indices_np,
# self.max_num_blocks_per_req))
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
# DEBUG
# print(" block_table_cpu.shape = {} data = {}".format(
# block_table_cpu.shape, block_table_cpu[:padded_batch_size, :10]))
block_numbers_np = block_table_cpu.flatten(
)[block_table_indices_np].numpy()
# DEBUG
# print(" block_numbers_np.shape = {} data = {}".format(
# block_numbers_np.shape, block_numbers_np))
block_offsets_np = input_positions_np % self.block_size
# DEBUG
# print(" block_offsets_np.shape = {} data = {}".format(
# block_offsets_np.shape, block_offsets_np))
slot_mapping_np = self.slot_mapping_np[
self.cur_swap_id][:padded_batch_size]
np.add(block_numbers_np * self.block_size,
block_offsets_np,
out=slot_mapping_np)
slot_mapping_np[batch_size:] = _PAD_SLOT_ID
# DEBUG
# print(" slot_mapping_np.shape = {} data = {}".format(
# slot_mapping_np.shape, slot_mapping_np))
block_table_cpu = block_table_cpu[:padded_batch_size]
# Context lens
context_lens_np = self.decode_context_lens_np[
self.cur_swap_id][:padded_batch_size]
np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
1,
out=context_lens_np)
context_lens_np[batch_size:] = 0
# Get final tensors
input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device)
input_positions = input_positions_cpu.reshape(-1, 1).to(self.device)
slot_mapping = self.slot_mapping_cpu[
self.cur_swap_id][:padded_batch_size].reshape(-1,
1).to(self.device)
block_table = block_table_cpu.to(self.device)
context_lens = self.decode_context_lens_cpu[
self.cur_swap_id][:padded_batch_size].to(self.device)
self.swap_step()
# DEBUG
# print(" context_lens.shape = {} val = {}".format(
# context_lens.shape, context_lens))
# Attn metadata
attn_metadata = PallasMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=padded_batch_size,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_table,
context_lens=context_lens,
effective_query_lens=None,
)
return DecodeData(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata)
@torch.no_grad()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
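        # Execution strategy: prompts are run one request at a time, and
        # while the TPU executes the current prompt the next prompt (or the
        # decode batch) is prepared on the CPU; all decodes are then run as a
        # single padded batch. Sampling is greedy and happens on-device inside
        # ModelWrapperV1.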
# Update cached state
self.update_states(scheduler_output)
# If necessary, swap decodes/prompts to have all decodes on the start
ensure_decodes_first(self.input_batch)
# Prepare prompts/decodes info
pd_info = self._get_prompts_and_decodes(scheduler_output)
# Init
num_prompts = len(pd_info.prompt_req_ids)
num_decodes = len(pd_info.decode_req_ids)
decode_data = None
sampled_token_ids = [0] * self.input_batch.num_reqs
# Run each prompt individually
is_first = True
for i in range(num_prompts):
req_id = pd_info.prompt_req_ids[i]
req_index = num_decodes + i
assert req_index == self.input_batch.req_id_to_index[
req_id] # TODO: Remove
req_state = self.requests[req_id]
num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i]
prompt_len = num_scheduled_tokens
seq_len = req_state.num_computed_tokens + num_scheduled_tokens
# Prepare first prompt
if is_first:
prompt_data = self._prepare_prompt(req_index,
num_scheduled_tokens)
is_first = False
# Run forward pass
with set_forward_context(prompt_data.attn_metadata,
self.vllm_config):
assert self.model is not None
selected_token_ids = self.model(prompt_data.input_tokens,
prompt_data.input_positions,
prompt_data.attn_metadata,
self.kv_caches)
# In parallel to TPU execution, prepare the next iteration
if i < num_prompts - 1:
# There is next prompt => prepare it
prompt_data = self._prepare_prompt(
req_index + 1, pd_info.prompt_scheduled_tokens[i + 1])
elif i == num_prompts - 1 and num_decodes > 0:
# There is next decode => prepare it
decode_data = self._prepare_decode(pd_info.decode_req_ids)
# Update cached state (if prompt is fully done)
if seq_len >= len(req_state.prompt_token_ids):
# Transfer sampled tokens from TPU to CPU
selected_token_ids_cpu = selected_token_ids.cpu()
# Get output token
token_id = selected_token_ids_cpu[prompt_len - 1].item()
sampled_token_ids[req_index] = token_id
# DEBUG
# print(
# " -- Got token_id = {} for prompt_len = {} req_id = {} req_index = {} selected_token_ids_cpu = {}"
# .format(token_id, prompt_len, req_id, req_index,
# selected_token_ids_cpu))
# Add output token to the request
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
self.input_batch.num_tokens[req_index] += 1
req_state.output_token_ids.append(token_id)
# Run decodes (a single batch)
if num_decodes > 0:
# Prepare decode (if was not yet prepared)
if decode_data is None:
decode_data = self._prepare_decode(pd_info.decode_req_ids)
# Run forward pass
with set_forward_context(decode_data.attn_metadata,
self.vllm_config):
assert self.model is not None
selected_token_ids = self.model(decode_data.input_tokens,
decode_data.input_positions,
decode_data.attn_metadata,
self.kv_caches)
# Transfer sampled tokens from TPU to CPU
decode_token_ids_cpu = selected_token_ids.cpu()
# Convert to list
decode_token_ids_list = decode_token_ids_cpu.tolist()
# Update cached state for each decode request
for i in range(num_decodes):
req_id = pd_info.decode_req_ids[i]
req_index = i
assert req_index == self.input_batch.req_id_to_index[
req_id] # TODO: Remove
req_state = self.requests[req_id]
seq_len = req_state.num_computed_tokens + 1
token_id = decode_token_ids_list[i]
sampled_token_ids[req_index] = token_id
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
self.input_batch.num_tokens[req_index] += 1
req_state.output_token_ids.append(token_id)
# Create output
model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=sampled_token_ids,
logprob_token_ids_cpu=None,
logprobs_cpu=None,
)
return model_runner_output
def load_model(self) -> None:
self.device = self.device_config.device
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
# process, the ranks can be different from the ranks internally assigned
# by the xm runtime. Therefore, there is a mismatch in the rank
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
# This is not a problem in linear layers because all-reduce is
# rank-agnostic. However, it matters for all-gather as the ranks
# determine the order of concatenating the output tensors.
# As a workaround, we use the xm's rank assignment only when loading
# the embedding weights.
xm_tp_rank = xr.global_ordinal()
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
model = get_model(vllm_config=self.vllm_config)
model = model.eval()
xm.wait_device_ops()
model = ModelWrapperV1(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
def dummy_run(
self,
kv_caches,
num_tokens: int,
seq_len: Optional[int] = None,
exec_mode: Optional[ExecutionMode] = None,
) -> None:
assert seq_len is not None
assert exec_mode is not None
exec_mode = ExecutionMode(exec_mode)
if exec_mode.is_prefill():
seq_len = (seq_len + 15) // 16 * 16
token_ids = torch.zeros((num_tokens, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((num_tokens, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((num_tokens, seq_len),
dtype=torch.int64,
device=self.device)
if exec_mode == ExecutionMode.PREFILL:
attn_metadata = PallasMetadata(
num_prefills=num_tokens,
num_prefill_tokens=num_tokens * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=None,
context_lens=None,
effective_query_lens=None,
)
else:
context_lens = torch.ones((num_tokens, ),
dtype=torch.int32,
device=self.device)
block_tables = torch.zeros(
(num_tokens, self.max_num_blocks_per_req),
dtype=torch.int32,
device=self.device)
effective_query_lens = torch.ones_like(context_lens)
attn_metadata = PallasMetadata(
num_prefills=num_tokens,
num_prefill_tokens=num_tokens * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=effective_query_lens,
)
else:
assert seq_len == 1
token_ids = torch.zeros((num_tokens, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((num_tokens, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((num_tokens, seq_len),
dtype=torch.int64,
device=self.device)
block_tables = torch.zeros(
(num_tokens, self.max_num_blocks_per_req),
dtype=torch.int32,
device=self.device)
context_lens = torch.ones((num_tokens, ),
dtype=torch.int32,
device=self.device)
attn_metadata = PallasMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=num_tokens * seq_len,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_tables,
context_lens=context_lens,
)
# NOTE(woosuk): There are two stages of compilation: torch.compile and
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
# overhead by reusing the FX graph for different shapes.
# However, the XLA graph will still require static shapes and needs to
# be re-compiled for every different shapes. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs in the disk (VLLM_XLA_CACHE_PATH).
if exec_mode.is_prefill():
            # Prefill
torch._dynamo.mark_dynamic(token_ids, 1)
torch._dynamo.mark_dynamic(position_ids, 1)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
else:
# Decode
torch._dynamo.mark_dynamic(token_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None
self.model(token_ids, position_ids, attn_metadata, kv_caches)
def capture_model(self) -> None:
"""Compile the model."""
# Prefill
logger.info(
"Compiling the model with different input shapes for prefill:")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self.dummy_run(self.kv_caches,
batch_size,
seq_len,
exec_mode=ExecutionMode.PREFILL)
xm.wait_device_ops()
logger.info(" batch_size: %d, seq_len: %d", batch_size,
seq_len)
num_tokens = batch_size * seq_len
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
break
seq_len = seq_len * 2
end = time.time()
logger.info(" -- Compilation for prefill done in %.2f [secs].",
end - start)
# Prefix prefill
if self.scheduler_config.enable_chunked_prefill:
logger.info("Compiling the model with different input shapes for "
"prefix prefill:")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self.dummy_run(self.kv_caches,
batch_size,
seq_len,
exec_mode=ExecutionMode.PREFIX_PREFILL)
xm.wait_device_ops()
logger.info(" batch_size: %d, seq_len: %d", batch_size,
seq_len)
num_tokens = batch_size * seq_len
if (num_tokens
>= self.scheduler_config.max_num_batched_tokens):
break
seq_len = seq_len * 2
end = time.time()
logger.info(
" -- Compilation for prefix prefill done in %.2f [secs].",
end - start)
# Decode
logger.info(
"Compiling the model with different input shapes for decode:")
start = time.time()
seq_len = 1
batch_size = 8 # Must be in sync with _get_padded_batch_size()
while True:
self.dummy_run(self.kv_caches,
batch_size,
seq_len,
exec_mode=ExecutionMode.DECODE)
xm.wait_device_ops()
logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len)
if batch_size >= self.scheduler_config.max_num_seqs:
break
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
end = time.time()
logger.info(" -- Compilation for decode done in %.2f [secs].",
end - start)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
Args:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if len(kv_cache_config.groups) > 1:
raise NotImplementedError(
"Hybrid models with more than one KV cache type are not "
"supported yet.")
kv_caches: Dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % layer_spec.page_size_bytes == 0
num_blocks = tensor_config.size // layer_spec.page_size_bytes
if isinstance(layer_spec, FullAttentionSpec):
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
layer_spec.head_size)
dtype = layer_spec.dtype
tpu_k_cache = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
else:
raise NotImplementedError
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
class ModelWrapperV1(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
def forward(
self,
token_ids: torch.Tensor,
position_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
attn_metadata: The Pallas attention metadata.
            kv_caches: The key and value caches. They can be empty during
                the memory profiling at initialization.
"""
# Skip this in memory profiling at initialization.
if attn_metadata is not None and kv_caches[0][0].numel() > 0:
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
slot_mapping = attn_metadata.slot_mapping
slot_mapping = slot_mapping.flatten()
            head_indices = torch.arange(0,
                                        num_kv_heads,
                                        device=slot_mapping.device,
                                        dtype=slot_mapping.dtype)
            head_indices *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indices.view(1, -1)
slot_mapping = slot_mapping.flatten()
attn_metadata.slot_mapping = slot_mapping
assert self.model is not None
hidden_states = self.model(
token_ids,
position_ids,
kv_caches,
attn_metadata,
)
hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, None)
# Greedy sampling.
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
return argmax_token_ids
def _get_padded_prompt_len(x: int) -> int:
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. We pad the prompt length up to the next
    # power of two (minimum 16), which satisfies this requirement and is also
    # good for performance.
if x <= 16:
return 16
return 1 << (x - 1).bit_length()
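# Illustrative values: _get_padded_prompt_len(5) == 16,
# _get_padded_prompt_len(17) == 32, _get_padded_prompt_len(100) == 128.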
def _get_padded_batch_size(batch_size: int) -> int:
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
# To meet this requirement in the simplest way, we set the minimal batch
# size to 8.
if batch_size <= 8:
return 8
else:
return ((batch_size + 15) // 16) * 16
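# Illustrative values: _get_padded_batch_size(3) == 8,
# _get_padded_batch_size(9) == 16, _get_padded_batch_size(17) == 32.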


@@ -0,0 +1,153 @@
"""A TPU worker class."""
import os
from typing import Optional, Dict
import torch
import torch.distributed
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.v1.kv_cache_interface import FullAttentionSpec
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.utils import bind_kv_cache
logger = init_logger(__name__)
class TPUWorker(WorkerBase):
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
super().__init__(vllm_config, local_rank, rank,
distributed_init_method)
def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU"
torch.set_grad_enabled(False)
torch.set_default_dtype(self.model_config.dtype)
# Initialize the distributed environment.
init_tpu_worker_distributed_environment(self.parallel_config,
self.rank,
self.distributed_init_method,
self.local_rank)
# Device initialization should happen after initializing
# the distributed runtime.
self.device = xm.xla_device()
self.device_config.device = self.device
# Set random seed.
set_random_seed(self.model_config.seed)
xm.set_rng_state(self.model_config.seed, self.device)
# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
# 30-40 graphs for decode. 128 is an arbitrary safe number.
torch._dynamo.config.cache_size_limit = 128
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs.
world_size = self.parallel_config.world_size
rank = xr.global_ordinal()
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
f"tp{world_size}_rank{rank}")
xr.initialize_cache(per_rank_path, readonly=False)
# Init ModelRunner here, so that we have access to self.device.
self.model_runner = TPUModelRunner(self.vllm_config, self.device)
def determine_available_memory(self) -> int:
assert self.model_runner is not None
kv_caches: Dict[str, torch.Tensor] = {}
kv_cache_spec = self.model_runner.get_kv_cache_spec()
for layer_name, layer_spec in kv_cache_spec.items():
if isinstance(layer_spec, FullAttentionSpec):
dtype = layer_spec.dtype
                # Use an empty tensor instead of `None` to force Dynamo to pass
                # it by reference, rather than specializing on the value `None`.
tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device)
tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device)
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
else:
raise NotImplementedError
runner_kv_caches = []
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches)
self.model_runner.dummy_run(
runner_kv_caches,
num_tokens=1,
seq_len=self.scheduler_config.max_num_batched_tokens,
exec_mode=ExecutionMode.PREFILL,
)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
profiled = m["peak_bytes_used"] # Weights + intermediate activations.
# Calculate the TPU KV cache size based on profiling.
usable_memory_size = int(total_memory_size *
self.cache_config.gpu_memory_utilization)
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
return int(tpu_kv_cache_bytes)
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
assert self.model_runner is not None
output = self.model_runner.execute_model(scheduler_output)
return output if self.rank == 0 else None
def init_tpu_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
# NOTE(woosuk): This is just to initialize the TP group and broadcast
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
local_rank=local_rank,
distributed_init_method=distributed_init_method,
backend="gloo",
)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)


@@ -0,0 +1,173 @@
"""A GPU worker class."""
from typing import TYPE_CHECKING, Optional
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.model_runner_base import ModelRunnerBase
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
class WorkerBase:
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
if self.cache_config.cache_dtype == "auto":
self.cache_dtype = self.model_config.dtype
else:
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype]
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
# Initialized by the specific platform
self.model_runner: Optional[ModelRunnerBase] = None
def load_model(self) -> None:
assert self.model_runner is not None
self.model_runner.load_model()
def compile_or_warm_up_model(self) -> None:
assert self.model_runner is not None
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def get_model(self) -> nn.Module:
assert self.model_runner is not None
return self.model_runner.get_model()
def get_kv_cache_spec(self) -> KVCacheSpec:
assert self.model_runner is not None
return self.model_runner.get_kv_cache_spec()
def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
assert self.model_runner is not None
self.model_runner.initialize_kv_cache(kv_cache_config)
def profile(self, is_start: bool = True):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
if is_start:
self.profiler.start()
else:
self.profiler.stop()
def check_health(self) -> None:
# worker will always be healthy as long as it's running.
return
def init_device(self):
raise NotImplementedError()
def determine_available_memory(self) -> int:
raise NotImplementedError()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
raise NotImplementedError()
def check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: # noqa: SIM102
if not current_platform.has_device_capability(80):
capability = current_platform.get_device_capability()
gpu_name = current_platform.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.")
def get_cache_block_size(
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config)
num_attention_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_attention_layers * (key_cache_block + value_cache_block)
if cache_config.cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
dtype_size = get_dtype_size(dtype)
return dtype_size * total
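# Illustrative (assumed config, not taken from this PR): with block_size=16,
# 8 KV heads, head_size=128, 32 attention layers and fp16 caches,
# key_cache_block = 16 * 8 * 128 = 16384 elements, the value block is the
# same, total = 32 * (16384 + 16384) = 1,048,576 elements, so one cache
# block takes 2 * 1,048,576 bytes = 2 MiB.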