mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-31 14:24:37 +08:00 
			
		
		
		
	Compare commits
	
		
			14 Commits
		
	
	
		
			v0.11.1rc3
			...
			tpu_v1_opt
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 70b4e46e70 | |||
| 5fb9dbe6f6 | |||
| 996b92ccb4 | |||
| 2b0526fa15 | |||
| 7be649256f | |||
| 627efde813 | |||
| c2867d5bc1 | |||
| 39c4a4cdb5 | |||
| 1ccf100c6a | |||
| 248c5b632d | |||
| 950f349492 | |||
| 61bb55f3d5 | |||
| 0bddb6b9a5 | |||
| c715fb19e5 | 
| @ -89,4 +89,4 @@ repos: | ||||
|     name: Suggestion | ||||
|     entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."' | ||||
|     language: system | ||||
|     verbose: true | ||||
|     verbose: true | ||||
| @ -8,10 +8,10 @@ prompts = [ | ||||
|     "The future of AI is", | ||||
| ] | ||||
| # Create a sampling params object. | ||||
| sampling_params = SamplingParams(temperature=0.8, top_p=0.95) | ||||
| sampling_params = SamplingParams()  #temperature=0.8, top_p=0.95) | ||||
|  | ||||
| # Create an LLM. | ||||
| llm = LLM(model="facebook/opt-125m") | ||||
| llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16) | ||||
| # Generate texts from the prompts. The output is a list of RequestOutput objects | ||||
| # that contain the prompt, generated text, and other information. | ||||
| outputs = llm.generate(prompts, sampling_params) | ||||
| @ -19,4 +19,4 @@ outputs = llm.generate(prompts, sampling_params) | ||||
| for output in outputs: | ||||
|     prompt = output.prompt | ||||
|     generated_text = output.outputs[0].text | ||||
|     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||||
|     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||||
|  | ||||
| @ -5,7 +5,7 @@ requests >= 2.26.0 | ||||
| tqdm | ||||
| blake3 | ||||
| py-cpuinfo | ||||
| transformers >= 4.45.2  # Required for Llama 3.2 and Qwen2-VL. | ||||
| transformers >= 4.48.2  # Required for Bamba model and Transformers backend. | ||||
| tokenizers >= 0.19.1  # Required for Llama 3. | ||||
| protobuf # Required by LlamaTokenizer. | ||||
| fastapi >= 0.107.0, < 0.113.0; python_version < '3.9' | ||||
| @ -34,6 +34,6 @@ pyyaml | ||||
| six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 | ||||
| setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 | ||||
| einops # Required for Qwen2-VL. | ||||
| compressed-tensors == 0.8.1 # required for compressed-tensors | ||||
| compressed-tensors == 0.9.1 # required for compressed-tensors | ||||
| depyf==0.18.0 # required for profiling and debugging with compilation config | ||||
| cloudpickle # allows pickling lambda functions in model_executor/models/registry.py | ||||
| cloudpickle # allows pickling lambda functions in model_executor/models/registry.py | ||||
| @ -2,7 +2,7 @@ | ||||
| # This file is autogenerated by pip-compile with Python 3.12 | ||||
| # by the following command: | ||||
| # | ||||
| #    python3.12 -m piptools compile requirements-test.in -o requirements-test.txt | ||||
| # python3.12 -m piptools compile requirements-test.in -o requirements-test.txt | ||||
| # | ||||
| absl-py==2.1.0 | ||||
|     # via rouge-score | ||||
| @ -106,9 +106,17 @@ dnspython==2.7.0 | ||||
| docutils==0.16 | ||||
|     # via awscli | ||||
| einops==0.8.0 | ||||
|     # via -r requirements-test.in | ||||
|     # via | ||||
|     #   -r requirements-test.in | ||||
|     #   encodec | ||||
|     #   vector-quantize-pytorch | ||||
|     #   vocos | ||||
| einx==0.3.0 | ||||
|     # via vector-quantize-pytorch | ||||
| email-validator==2.2.0 | ||||
|     # via pydantic | ||||
| encodec==0.1.1 | ||||
|     # via vocos | ||||
| evaluate==0.4.3 | ||||
|     # via lm-eval | ||||
| fastparquet==2024.11.0 | ||||
| @ -125,6 +133,8 @@ filelock==3.16.1 | ||||
|     #   triton | ||||
| fonttools==4.54.1 | ||||
|     # via matplotlib | ||||
| frozendict==2.4.6 | ||||
|     # via einx | ||||
| frozenlist==1.5.0 | ||||
|     # via | ||||
|     #   aiohttp | ||||
| @ -159,6 +169,7 @@ huggingface-hub==0.26.2 | ||||
|     #   timm | ||||
|     #   tokenizers | ||||
|     #   transformers | ||||
|     #   vocos | ||||
| idna==3.10 | ||||
|     # via | ||||
|     #   anyio | ||||
| @ -261,6 +272,8 @@ numpy==1.26.4 | ||||
|     #   cupy-cuda12x | ||||
|     #   datasets | ||||
|     #   decord | ||||
|     #   einx | ||||
|     #   encodec | ||||
|     #   evaluate | ||||
|     #   fastparquet | ||||
|     #   genai-perf | ||||
| @ -283,6 +296,7 @@ numpy==1.26.4 | ||||
|     #   torchvision | ||||
|     #   transformers | ||||
|     #   tritonclient | ||||
|     #   vocos | ||||
| nvidia-cublas-cu12==12.4.5.8 | ||||
|     # via | ||||
|     #   nvidia-cudnn-cu12 | ||||
| @ -455,6 +469,7 @@ pyyaml==6.0.2 | ||||
|     #   responses | ||||
|     #   timm | ||||
|     #   transformers | ||||
|     #   vocos | ||||
| ray[adag]==2.40.0 | ||||
|     # via -r requirements-test.in | ||||
| redis==5.2.0 | ||||
| @ -517,6 +532,7 @@ scipy==1.13.1 | ||||
|     #   scikit-learn | ||||
|     #   sentence-transformers | ||||
|     #   statsmodels | ||||
|     #   vocos | ||||
| sentence-transformers==3.2.1 | ||||
|     # via -r requirements-test.in | ||||
| sentencepiece==0.2.0 | ||||
| @ -540,7 +556,9 @@ sqlitedict==2.1.0 | ||||
| statsmodels==0.14.4 | ||||
|     # via genai-perf | ||||
| sympy==1.13.1 | ||||
|     # via torch | ||||
|     # via | ||||
|     #   einx | ||||
|     #   torch | ||||
| tabledata==1.3.3 | ||||
|     # via pytablewriter | ||||
| tabulate==0.9.0 | ||||
| @ -568,12 +586,21 @@ torch==2.5.1 | ||||
|     #   -r requirements-test.in | ||||
|     #   accelerate | ||||
|     #   bitsandbytes | ||||
|     #   encodec | ||||
|     #   lm-eval | ||||
|     #   peft | ||||
|     #   sentence-transformers | ||||
|     #   tensorizer | ||||
|     #   timm | ||||
|     #   torchaudio | ||||
|     #   torchvision | ||||
|     #   vector-quantize-pytorch | ||||
|     #   vocos | ||||
| torchaudio==2.5.1 | ||||
|     # via | ||||
|     #   -r requirements-test.in | ||||
|     #   encodec | ||||
|     #   vocos | ||||
| torchvision==0.20.1 | ||||
|     # via timm | ||||
| tqdm==4.66.6 | ||||
| @ -584,13 +611,15 @@ tqdm==4.66.6 | ||||
|     #   lm-eval | ||||
|     #   nltk | ||||
|     #   peft | ||||
|     #   pqdm | ||||
|     #   sentence-transformers | ||||
|     #   tqdm-multiprocess | ||||
|     #   transformers | ||||
| tqdm-multiprocess==0.0.11 | ||||
|     # via lm-eval | ||||
| transformers==4.47.0 | ||||
| transformers==4.48.2 | ||||
|     # via | ||||
|     #   -r requirements-test.in | ||||
|     #   genai-perf | ||||
|     #   lm-eval | ||||
|     #   peft | ||||
| @ -615,6 +644,7 @@ typing-extensions==4.12.2 | ||||
|     #   huggingface-hub | ||||
|     #   librosa | ||||
|     #   mistral-common | ||||
|     #   pqdm | ||||
|     #   pydantic | ||||
|     #   pydantic-core | ||||
|     #   torch | ||||
| @ -626,6 +656,10 @@ urllib3==2.2.3 | ||||
|     #   requests | ||||
|     #   responses | ||||
|     #   tritonclient | ||||
| vector-quantize-pytorch==1.21.2 | ||||
|     # via -r requirements-test.in | ||||
| vocos==0.1.0 | ||||
|     # via -r requirements-test.in | ||||
| word2number==1.1 | ||||
|     # via lm-eval | ||||
| xxhash==3.5.0 | ||||
| @ -638,4 +672,4 @@ zstandard==0.23.0 | ||||
|     # via lm-eval | ||||
|  | ||||
| # The following packages are considered to be unsafe in a requirements file: | ||||
| # setuptools | ||||
| # setuptools | ||||
| @ -13,13 +13,11 @@ ray[default] | ||||
| # Install torch_xla | ||||
| --pre | ||||
| --extra-index-url https://download.pytorch.org/whl/nightly/cpu | ||||
| --find-links https://storage.googleapis.com/libtpu-wheels/index.html | ||||
| --find-links https://storage.googleapis.com/libtpu-releases/index.html | ||||
| --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html | ||||
| --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html | ||||
| torch==2.6.0.dev20241126+cpu | ||||
| torchvision==0.20.0.dev20241126+cpu | ||||
| torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" | ||||
| torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" | ||||
| torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" | ||||
| jaxlib==0.4.36.dev20241122 | ||||
| jax==0.4.36.dev20241122 | ||||
| torch==2.6.0.dev20241216+cpu | ||||
| torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" | ||||
| torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" | ||||
| torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" | ||||
| @ -20,7 +20,7 @@ TASK = "gsm8k" | ||||
| FILTER = "exact_match,strict-match" | ||||
| RTOL = 0.03 | ||||
| EXPECTED_VALUE = 0.58 | ||||
| DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"] | ||||
| DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] | ||||
| MORE_ARGS_LIST = [ | ||||
|     [],  # Default | ||||
|     ["--enable-chunked-prefill"],  # Chunked | ||||
| @ -66,14 +66,21 @@ def run_test(more_args): | ||||
|                 ), f"Expected: {EXPECTED_VALUE} |  Measured: {measured_value}" | ||||
|  | ||||
|  | ||||
| @pytest.mark.skipif(not current_platform.is_cuda(), | ||||
|                     reason="V1 currently only supported on CUDA") | ||||
| @pytest.mark.skipif(not current_platform.is_cuda() | ||||
|                     and not current_platform.is_tpu(), | ||||
|                     reason="V1 currently only supported on CUDA and TPU") | ||||
| def test_lm_eval_accuracy_v1_engine(monkeypatch): | ||||
|     """Run with the V1 Engine.""" | ||||
|  | ||||
|     with monkeypatch.context() as m: | ||||
|         m.setenv("VLLM_USE_V1", "1") | ||||
|         run_test([]) | ||||
|         more_args = [] | ||||
|  | ||||
|         # Limit compilation time for V1 | ||||
|         if current_platform.is_tpu(): | ||||
|             more_args = ["--max-num-seqs", "64"] | ||||
|  | ||||
|         run_test(more_args) | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("more_args", MORE_ARGS_LIST) | ||||
|  | ||||
| @ -34,4 +34,4 @@ run_mypy vllm/plugins | ||||
| run_mypy vllm/prompt_adapter | ||||
| run_mypy vllm/spec_decode | ||||
| run_mypy vllm/worker | ||||
| run_mypy vllm/v1 | ||||
| run_mypy vllm/v1 | ||||
| @ -135,7 +135,7 @@ class CudaPlatformBase(Platform): | ||||
|             else: | ||||
|                 if envs.VLLM_USE_V1: | ||||
|                     parallel_config.worker_cls = \ | ||||
|                             "vllm.v1.worker.gpu_worker.Worker" | ||||
|                             "vllm.v1.worker.gpu_worker.GPUWorker" | ||||
|                 else: | ||||
|                     parallel_config.worker_cls = "vllm.worker.worker.Worker" | ||||
|  | ||||
|  | ||||
| @ -32,6 +32,7 @@ class _Backend(enum.Enum): | ||||
|     FLASHINFER = enum.auto() | ||||
|     HPU_ATTN = enum.auto() | ||||
|     PALLAS = enum.auto() | ||||
|     PALLAS_VLLM_V1 = enum.auto() | ||||
|     IPEX = enum.auto() | ||||
|     BLOCK_SPARSE_FLASH_ATTN = enum.auto() | ||||
|     NO_ATTENTION = enum.auto() | ||||
|  | ||||
| @ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Optional | ||||
|  | ||||
| import torch | ||||
|  | ||||
| import vllm.envs as envs | ||||
| from vllm.logger import init_logger | ||||
|  | ||||
| from .interface import Platform, PlatformEnum, _Backend | ||||
| @ -30,10 +31,16 @@ class TpuPlatform(Platform): | ||||
|     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, | ||||
|                              dtype: torch.dtype, kv_cache_dtype: Optional[str], | ||||
|                              block_size: int, use_v1: bool) -> str: | ||||
|         if selected_backend != _Backend.PALLAS: | ||||
|         if (selected_backend != _Backend.PALLAS | ||||
|                 and selected_backend != _Backend.PALLAS_VLLM_V1): | ||||
|             logger.info("Cannot use %s backend on TPU.", selected_backend) | ||||
|         logger.info("Using Pallas backend.") | ||||
|         return "vllm.attention.backends.pallas.PallasAttentionBackend" | ||||
|  | ||||
|         if use_v1: | ||||
|             logger.info("Using Pallas V1 backend.") | ||||
|             return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" | ||||
|         else: | ||||
|             logger.info("Using Pallas backend.") | ||||
|             return "vllm.attention.backends.pallas.PallasAttentionBackend" | ||||
|  | ||||
|     @classmethod | ||||
|     def get_device_name(cls, device_id: int = 0) -> str: | ||||
| @ -45,7 +52,7 @@ class TpuPlatform(Platform): | ||||
|  | ||||
|     @classmethod | ||||
|     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: | ||||
|         return True | ||||
|         return not envs.VLLM_USE_V1 | ||||
|  | ||||
|     @classmethod | ||||
|     def inference_mode(cls): | ||||
| @ -60,11 +67,11 @@ class TpuPlatform(Platform): | ||||
|             cache_config.block_size = 16 | ||||
|  | ||||
|         compilation_config = vllm_config.compilation_config | ||||
|         if compilation_config.level == CompilationLevel.NO_COMPILATION: | ||||
|             # TPU does not support NO_COMPILATION | ||||
|  | ||||
|         # TPU only supports DYNAMO_ONCE compilation level | ||||
|         if compilation_config.level != CompilationLevel.DYNAMO_ONCE: | ||||
|             logger.info("[TPU] Forcing DYNAMO_ONCE compilation level") | ||||
|             compilation_config.level = CompilationLevel.DYNAMO_ONCE | ||||
|         assert compilation_config.level < CompilationLevel.PIECEWISE,\ | ||||
|             "TPU does not support Inductor." | ||||
|  | ||||
|         if compilation_config.backend == "": | ||||
|             compilation_config.backend = "openxla" | ||||
| @ -72,10 +79,6 @@ class TpuPlatform(Platform): | ||||
|         assert vllm_config.speculative_config is None, \ | ||||
|             "TPU does not support speculative decoding" | ||||
|  | ||||
|         assert not vllm_config.scheduler_config.chunked_prefill_enabled, ( | ||||
|             "Chunked prefill is not yet supported for TPU backend") | ||||
|         assert not vllm_config.speculative_config, ( | ||||
|             "Speculative decoding is not yet supported for TPU backend") | ||||
|         if vllm_config.model_config.dtype in (torch.float16, torch.float32): | ||||
|             logger.warning( | ||||
|                 "The TPU backend currently does not support %s. " | ||||
| @ -85,8 +88,27 @@ class TpuPlatform(Platform): | ||||
|         parallel_config = vllm_config.parallel_config | ||||
|         scheduler_config = vllm_config.scheduler_config | ||||
|         if parallel_config.worker_cls == "auto": | ||||
|             if scheduler_config.is_multi_step: | ||||
|             if envs.VLLM_USE_V1: | ||||
|                 parallel_config.worker_cls = \ | ||||
|                     "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" | ||||
|                     "vllm.v1.worker.tpu_worker.TPUWorker" | ||||
|             else: | ||||
|                 parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker" | ||||
|                 if scheduler_config.is_multi_step: | ||||
|                     parallel_config.worker_cls = \ | ||||
|                         "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" | ||||
|                 else: | ||||
|                     parallel_config.worker_cls = \ | ||||
|                         "vllm.worker.tpu_worker.TPUWorker" | ||||
|  | ||||
|         # Adjust scheduler config for V1 | ||||
|         # TODO: Add support for these | ||||
|         if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching: | ||||
|             logger.warning("[V1][TPU] Disable prefix caching") | ||||
|             vllm_config.cache_config.enable_prefix_caching = False | ||||
|  | ||||
|         assert not vllm_config.speculative_config, ( | ||||
|             "Speculative decoding is not yet supported for TPU backend") | ||||
|  | ||||
|     @classmethod | ||||
|     def is_pin_memory_available(cls): | ||||
|         logger.warning("Pin memory is not supported on TPU.") | ||||
|         return False | ||||
|  | ||||
							
								
								
									
										351
									
								
								vllm/v1/attention/backends/pallas.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										351
									
								
								vllm/v1/attention/backends/pallas.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,351 @@ | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Dict, List, Optional, Tuple, Type | ||||
|  | ||||
| import torch | ||||
| import torch_xla.experimental.custom_kernel  # Required to register custom ops. | ||||
|  | ||||
| from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, | ||||
|                                               AttentionLayer, | ||||
|                                               AttentionMetadata, AttentionType) | ||||
| from vllm.attention.backends.utils import CommonAttentionState | ||||
|  | ||||
|  | ||||
| class PallasAttentionBackend(AttentionBackend): | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_name() -> str: | ||||
|         return "PALLAS_VLLM_V1" | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: | ||||
|         return PallasAttentionBackendImpl | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_metadata_cls() -> Type["PallasMetadata"]: | ||||
|         return PallasMetadata | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_state_cls() -> Type["CommonAttentionState"]: | ||||
|         return CommonAttentionState | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_kv_cache_shape( | ||||
|         num_blocks: int, | ||||
|         block_size: int, | ||||
|         num_kv_heads: int, | ||||
|         head_size: int, | ||||
|     ) -> Tuple[int, ...]: | ||||
|         return (num_kv_heads, num_blocks, block_size, head_size) | ||||
|  | ||||
|     @staticmethod | ||||
|     def swap_blocks( | ||||
|         src_kv_cache: torch.Tensor, | ||||
|         dst_kv_cache: torch.Tensor, | ||||
|         src_to_dst: torch.Tensor, | ||||
|     ) -> None: | ||||
|         raise RuntimeError("swap_blocks is not used for the TPU backend.") | ||||
|  | ||||
|     @torch.compile(backend="openxla") | ||||
|     @staticmethod | ||||
|     def copy_blocks( | ||||
|         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], | ||||
|         src_to_dists: Tuple[torch.Tensor, torch.Tensor], | ||||
|     ) -> None: | ||||
|         src_indices, dst_indices = src_to_dists | ||||
|         for k_cache, v_cache in kv_caches: | ||||
|             torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) | ||||
|             k_cache[:, dst_indices] = k_cache[:, src_indices] | ||||
|             torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) | ||||
|             v_cache[:, dst_indices] = v_cache[:, src_indices] | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class PallasMetadata(AttentionMetadata): | ||||
|  | ||||
|     # Currently, input sequences can only contain all prefills | ||||
|     # or all decoding. | ||||
|     block_tables: Optional[torch.Tensor] = None | ||||
|     context_lens: Optional[torch.Tensor] = None | ||||
|     effective_query_lens: Optional[torch.Tensor] = None | ||||
|  | ||||
|     @property | ||||
|     def prefill_metadata(self) -> Optional["PallasMetadata"]: | ||||
|         if self.num_prefills == 0: | ||||
|             return None | ||||
|  | ||||
|         assert self.num_decode_tokens == 0 | ||||
|         return self | ||||
|  | ||||
|     @property | ||||
|     def decode_metadata(self) -> Optional["PallasMetadata"]: | ||||
|         if self.num_decode_tokens == 0: | ||||
|             return None | ||||
|  | ||||
|         assert self.num_prefills == 0 | ||||
|         assert self.num_prefill_tokens == 0 | ||||
|         assert self.block_tables is not None | ||||
|         assert self.context_lens is not None | ||||
|         return self | ||||
|  | ||||
|  | ||||
| class PallasAttentionBackendImpl(AttentionImpl): | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         num_heads: int, | ||||
|         head_size: int, | ||||
|         scale: float, | ||||
|         num_kv_heads: int, | ||||
|         alibi_slopes: Optional[List[float]], | ||||
|         sliding_window: Optional[int], | ||||
|         kv_cache_dtype: str, | ||||
|         blocksparse_params: Optional[Dict[str, Any]] = None, | ||||
|         logits_soft_cap: Optional[float] = None, | ||||
|         attn_type: str = AttentionType.DECODER, | ||||
|     ) -> None: | ||||
|         self.num_heads = num_heads | ||||
|         self.head_size = head_size | ||||
|         self.scale = float(scale) | ||||
|         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads | ||||
|  | ||||
|         assert self.num_heads % self.num_kv_heads == 0 | ||||
|         self.num_queries_per_kv = self.num_heads // self.num_kv_heads | ||||
|         if head_size % 128 != 0: | ||||
|             raise NotImplementedError("Head size must be a multiple of 128.") | ||||
|         if alibi_slopes is not None: | ||||
|             raise NotImplementedError("Alibi slopes is not supported.") | ||||
|         if sliding_window is not None: | ||||
|             raise NotImplementedError("Sliding window is not supported.") | ||||
|         if kv_cache_dtype != "auto": | ||||
|             raise NotImplementedError("FP8 KV cache dtype is not supported.") | ||||
|         if blocksparse_params is not None: | ||||
|             raise NotImplementedError("Blocksparse is not supported.") | ||||
|         if logits_soft_cap is not None: | ||||
|             raise NotImplementedError( | ||||
|                 "Attention logits soft-capping is not supported.") | ||||
|  | ||||
|         if torch_xla.tpu.version() < 4: | ||||
|             raise NotImplementedError("TPU version must be 4 or higher.") | ||||
|  | ||||
|         self.megacore_mode = None | ||||
|         tpu_env = torch_xla.tpu.get_tpu_env() | ||||
|         tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) | ||||
|                     or tpu_env.get("TYPE", None) | ||||
|                     or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) | ||||
|         assert tpu_type is not None | ||||
|         tpu_type = tpu_type.lower() | ||||
|  | ||||
|         if (("lite" not in tpu_type) and ("v6" not in tpu_type)): | ||||
|             if self.num_kv_heads % 2 == 0: | ||||
|                 self.megacore_mode = "kv_head" | ||||
|             else: | ||||
|                 # NOTE(woosuk): If the batch size is not a multiple of 2, the | ||||
|                 # megacore mode will be None. | ||||
|                 self.megacore_mode = "batch" | ||||
|  | ||||
|         if attn_type != AttentionType.DECODER: | ||||
|             raise NotImplementedError("Encoder self-attention and " | ||||
|                                       "encoder/decoder cross-attention " | ||||
|                                       "are not implemented for " | ||||
|                                       "PallasAttentionBackendImpl") | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         layer: AttentionLayer, | ||||
|         query: torch.Tensor, | ||||
|         key: torch.Tensor, | ||||
|         value: torch.Tensor, | ||||
|         kv_cache: Tuple[torch.Tensor, torch.Tensor], | ||||
|         attn_metadata: PallasMetadata, | ||||
|         output: Optional[torch.Tensor] = None, | ||||
|     ) -> torch.Tensor: | ||||
|         """Forward pass with Pallas attention. | ||||
|  | ||||
|         Args: | ||||
|             query: shape = [batch_size, seq_len, num_heads * head_size] | ||||
|             key: shape = [batch_size, seq_len, num_kv_heads * head_size] | ||||
|             value: shape = [batch_size, seq_len, num_kv_heads * head_size] | ||||
|             kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] | ||||
|             kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] | ||||
|                 NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor  | ||||
|                 with shape [0] for profiling run. | ||||
|             attn_metadata: Metadata for attention. | ||||
|         Returns: | ||||
|             shape = [batch_size, seq_len, num_heads * head_size] | ||||
|         """ | ||||
|  | ||||
|         if attn_metadata is None: | ||||
|             if output is None: | ||||
|                 output = torch.ones_like(query) | ||||
|             return output | ||||
|  | ||||
|         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 | ||||
|         batch_size, seq_len, hidden_size = query.shape | ||||
|         query = query.view(batch_size, seq_len, self.num_heads, self.head_size) | ||||
|         key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) | ||||
|         value = value.view(batch_size, seq_len, self.num_kv_heads, | ||||
|                            self.head_size) | ||||
|  | ||||
|         if kv_cache[0].numel() > 0: | ||||
|             slot_mapping = attn_metadata.slot_mapping | ||||
|             key_cache, value_cache = kv_cache | ||||
|             write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) | ||||
|  | ||||
|         query = query * self.scale | ||||
|         if attn_metadata.num_prefills > 0: | ||||
|             if attn_metadata.block_tables is None: | ||||
|                 # Prefill without paged KV cache. | ||||
|                 assert seq_len % 16 == 0, ( | ||||
|                     "Pallas FlashAttention kernel requires seq_len to be a " | ||||
|                     f"multiple of 16 but got {seq_len}") | ||||
|  | ||||
|                 # Handle GQA/MQA. | ||||
|                 if self.num_kv_heads != self.num_heads: | ||||
|                     key = key.repeat_interleave(self.num_queries_per_kv, | ||||
|                                                 dim=-2) | ||||
|                     key = key.view(batch_size, seq_len, self.num_heads, | ||||
|                                    self.head_size) | ||||
|                     value = value.repeat_interleave(self.num_queries_per_kv, | ||||
|                                                     dim=-2) | ||||
|                     value = value.view(batch_size, seq_len, self.num_heads, | ||||
|                                        self.head_size) | ||||
|                 # FlashAttention kernel requires the input shape to be | ||||
|                 # [batch_size, num_heads, seq_len, d_model] | ||||
|                 # while the input is [batch_size, seq_len, num_heads, d_model]. | ||||
|                 # Permute the input to match the required format. | ||||
|                 output = torch.ops.xla.flash_attention( | ||||
|                     query.permute(0, 2, 1, 3), | ||||
|                     key.permute(0, 2, 1, 3), | ||||
|                     value.permute(0, 2, 1, 3), | ||||
|                     True, | ||||
|                 ) | ||||
|                 output = output.permute(0, 2, 1, 3) | ||||
|             else: | ||||
|                 # Prefill with paged KV cache. | ||||
|                 # TODO(woosuk): Tune the below knobs. | ||||
|                 num_kv_pages_per_compute_block = 16 | ||||
|                 num_queries_per_compute_block = 16 | ||||
|                 assert seq_len % num_queries_per_compute_block == 0 | ||||
|                 output = torch.ops.xla.multi_queries_paged_attention( | ||||
|                     query, | ||||
|                     key_cache, | ||||
|                     value_cache, | ||||
|                     attn_metadata.context_lens, | ||||
|                     attn_metadata.block_tables, | ||||
|                     attn_metadata.effective_query_lens, | ||||
|                     num_kv_pages_per_compute_block, | ||||
|                     num_queries_per_compute_block, | ||||
|                     use_kernel=True, | ||||
|                 ) | ||||
|         else: | ||||
|             # Decoding run. | ||||
|             assert kv_cache[0].numel() > 0 | ||||
|             query = query.squeeze(dim=1) | ||||
|             pages_per_compute_block = 16  # TODO(woosuk): Tune this value. | ||||
|  | ||||
|             assert attn_metadata.block_tables is not None | ||||
|             assert attn_metadata.context_lens is not None | ||||
|             # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire | ||||
|             # block table in SMEM. Therefore, if the block table is too large, | ||||
|             # the kernel compilation will fail. To avoid this, we split the | ||||
|             # batch dimension into smaller chunks and run the kernel multiple | ||||
|             # times. | ||||
|             MAX_SMEM_USAGE = 512 * 1024 | ||||
|             size_per_seq = 4 * attn_metadata.block_tables.shape[1] | ||||
|             max_num_seq = MAX_SMEM_USAGE // size_per_seq | ||||
|  | ||||
|             if batch_size <= max_num_seq: | ||||
|                 output = paged_attention( | ||||
|                     query, | ||||
|                     key_cache, | ||||
|                     value_cache, | ||||
|                     attn_metadata.context_lens, | ||||
|                     attn_metadata.block_tables, | ||||
|                     pages_per_compute_block, | ||||
|                     self.megacore_mode, | ||||
|                 ) | ||||
|             else: | ||||
|                 chunk_size = max_num_seq | ||||
|                 # Make sure the chunk size is a multiple of 2. | ||||
|                 chunk_size = chunk_size // 2 * 2 | ||||
|                 num_chunks = (batch_size + chunk_size - 1) // chunk_size | ||||
|  | ||||
|                 output = torch.empty_like(query) | ||||
|                 for chunk_idx in range(num_chunks): | ||||
|                     chunk_start = chunk_idx * chunk_size | ||||
|                     chunk_end = chunk_start + chunk_size | ||||
|                     # NOTE(woosuk): We skip this line because it causes Dynamo | ||||
|                     # compilation error. Instead, we rely on the slice operation | ||||
|                     # to handle the out-of-bound case. | ||||
|                     # chunk_end = min(chunk_end, batch_size) | ||||
|                     chunk_output = paged_attention( | ||||
|                         query[chunk_start:chunk_end], | ||||
|                         key_cache, | ||||
|                         value_cache, | ||||
|                         attn_metadata.context_lens[chunk_start:chunk_end], | ||||
|                         attn_metadata.block_tables[chunk_start:chunk_end], | ||||
|                         pages_per_compute_block, | ||||
|                         self.megacore_mode, | ||||
|                     ) | ||||
|                     output[chunk_start:chunk_end] = chunk_output | ||||
|  | ||||
|         # Reshape the output tensor. | ||||
|         return output.reshape(batch_size, seq_len, hidden_size) | ||||
|  | ||||
|  | ||||
| def write_to_kv_cache( | ||||
|     key: torch.Tensor, | ||||
|     value: torch.Tensor, | ||||
|     key_cache: torch.Tensor, | ||||
|     value_cache: torch.Tensor, | ||||
|     slot_mapping: torch.Tensor, | ||||
| ) -> None: | ||||
|     torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) | ||||
|     torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) | ||||
|  | ||||
|     key = key.flatten(0, 2) | ||||
|     value = value.flatten(0, 2) | ||||
|     key_cache = key_cache.flatten(0, 2) | ||||
|     value_cache = value_cache.flatten(0, 2) | ||||
|     key_cache.index_copy_(0, slot_mapping, key) | ||||
|     value_cache.index_copy_(0, slot_mapping, value) | ||||
|  | ||||
|  | ||||
| def paged_attention( | ||||
|     query: torch.Tensor, | ||||
|     key_cache: torch.Tensor, | ||||
|     value_cache: torch.Tensor, | ||||
|     context_lens: torch.Tensor, | ||||
|     block_tables: torch.Tensor, | ||||
|     pages_per_compute_block: int, | ||||
|     megacore_mode: Optional[str], | ||||
| ) -> torch.Tensor: | ||||
|     batch_size = query.shape[0] | ||||
|     if megacore_mode == "batch" and batch_size % 2 != 0: | ||||
|         megacore_mode = None | ||||
|     else: | ||||
|         megacore_mode = megacore_mode | ||||
|  | ||||
|     # NOTE(woosuk): A temporary workaround to avoid the error: | ||||
|     # "xla::paged_attention() Expected a value of type 'str' for | ||||
|     # argument 'megacore_mode' but instead found type 'NoneType'." | ||||
|     if megacore_mode is not None: | ||||
|         output = torch.ops.xla.paged_attention( | ||||
|             query, | ||||
|             key_cache, | ||||
|             value_cache, | ||||
|             context_lens, | ||||
|             block_tables, | ||||
|             pages_per_compute_block, | ||||
|             megacore_mode=megacore_mode, | ||||
|         ) | ||||
|     else: | ||||
|         output = torch.ops.xla.paged_attention( | ||||
|             query, | ||||
|             key_cache, | ||||
|             value_cache, | ||||
|             context_lens, | ||||
|             block_tables, | ||||
|             pages_per_compute_block, | ||||
|         ) | ||||
|     return output | ||||
| @ -57,6 +57,14 @@ class BlockTable: | ||||
|             src, :num_blocks] | ||||
|         self.num_blocks_per_row[tgt] = num_blocks | ||||
|  | ||||
|     def swap_row(self, src: int, tgt: int) -> None: | ||||
|         num_blocks_src = self.num_blocks_per_row[src] | ||||
|         num_blocks_tgt = self.num_blocks_per_row[tgt] | ||||
|         self.num_blocks_per_row[src] = num_blocks_tgt | ||||
|         self.num_blocks_per_row[tgt] = num_blocks_src | ||||
|  | ||||
|         self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]] | ||||
|  | ||||
|     def commit(self, num_reqs: int) -> None: | ||||
|         self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], | ||||
|                                           non_blocking=True) | ||||
|  | ||||
| @ -72,7 +72,7 @@ class InputBatch: | ||||
|         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() | ||||
|         self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) | ||||
|         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) | ||||
|         self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) | ||||
|         self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32) | ||||
|  | ||||
|         # Block table. | ||||
|         self.block_table = BlockTable( | ||||
| @ -436,3 +436,77 @@ class InputBatch: | ||||
|     @property | ||||
|     def no_prompt_logprob(self) -> bool: | ||||
|         return len(self.prompt_logprob_reqs) == 0 | ||||
|  | ||||
|  | ||||
| def swap_positions(b: InputBatch, id_1, id_2): | ||||
|     assert id_1 != id_2 | ||||
|     req_id_1 = b.req_ids[id_1] | ||||
|     req_id_2 = b.req_ids[id_2] | ||||
|     assert req_id_1 is not None | ||||
|     assert req_id_2 is not None | ||||
|     assert id_1 == b.req_id_to_index[req_id_1] | ||||
|     assert id_2 == b.req_id_to_index[req_id_2] | ||||
|  | ||||
|     b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1] | ||||
|     b.req_id_to_index[req_id_1], b.req_id_to_index[ | ||||
|         req_id_2] = b.req_id_to_index[req_id_2], b.req_id_to_index[req_id_1] | ||||
|  | ||||
|     ids = [id_1, id_2] | ||||
|     rev_ids = [id_2, id_1] | ||||
|     b.num_tokens[ids] = b.num_tokens[rev_ids] | ||||
|     b.token_ids_cpu[ids] = b.token_ids_cpu[rev_ids] | ||||
|     b.num_prompt_tokens[ids] = b.num_prompt_tokens[rev_ids] | ||||
|     b.num_computed_tokens_cpu[ids] = b.num_computed_tokens_cpu[rev_ids] | ||||
|  | ||||
|     b.block_table.swap_row(id_1, id_2) | ||||
|  | ||||
|     b.temperature_cpu[ids] = b.temperature_cpu[rev_ids] | ||||
|     b.top_p_cpu[ids] = b.top_p_cpu[rev_ids] | ||||
|     b.top_k_cpu[ids] = b.top_k_cpu[rev_ids] | ||||
|     b.frequency_penalties_cpu[ids] = b.frequency_penalties_cpu[rev_ids] | ||||
|     b.presence_penalties_cpu[ids] = b.presence_penalties_cpu[rev_ids] | ||||
|     b.repetition_penalties_cpu[ids] = b.repetition_penalties_cpu[rev_ids] | ||||
|  | ||||
|     b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[ | ||||
|         id_1] | ||||
|     b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[ | ||||
|         id_2], b.stop_token_ids[id_1] | ||||
|  | ||||
|     gen_1 = b.generators.pop(id_1, None) | ||||
|     gen_2 = b.generators.pop(id_2, None) | ||||
|     if gen_1 is not None: | ||||
|         b.generators[id_2] = gen_1 | ||||
|     if gen_2 is not None: | ||||
|         b.generators[id_1] = gen_2 | ||||
|  | ||||
|  | ||||
| def ensure_decodes_first(b: InputBatch): | ||||
|     num_reqs = b.num_reqs | ||||
|     while True: | ||||
|         # Find the first prompt index | ||||
|         first_prompt_index = None | ||||
|         for i in range(num_reqs): | ||||
|             if b.num_computed_tokens_cpu[i] < b.num_prompt_tokens[i]: | ||||
|                 first_prompt_index = i | ||||
|                 break | ||||
|         if first_prompt_index is None: | ||||
|             break | ||||
|  | ||||
|         # Find the last decode index | ||||
|         last_decode_index = None | ||||
|         for i in reversed(range(num_reqs)): | ||||
|             if b.num_computed_tokens_cpu[i] >= b.num_prompt_tokens[i]: | ||||
|                 last_decode_index = i | ||||
|                 break | ||||
|         if last_decode_index is None: | ||||
|             break | ||||
|  | ||||
|         # Sanity | ||||
|         assert first_prompt_index != last_decode_index | ||||
|  | ||||
|         # Check if done | ||||
|         if first_prompt_index > last_decode_index: | ||||
|             break | ||||
|  | ||||
|         # Swap | ||||
|         swap_positions(b, first_prompt_index, last_decode_index) | ||||
|  | ||||
| @ -5,32 +5,23 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.distributed | ||||
| import torch.nn as nn | ||||
|  | ||||
| from vllm.attention.backends.abstract import AttentionType | ||||
| from vllm.attention.layer import Attention | ||||
| from vllm.config import CompilationLevel, VllmConfig | ||||
| from vllm.distributed.parallel_state import graph_capture | ||||
| from vllm.forward_context import set_forward_context | ||||
| from vllm.inputs import INPUT_REGISTRY | ||||
| from vllm.logger import init_logger | ||||
| from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding | ||||
| from vllm.model_executor.model_loader import get_model | ||||
| from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs | ||||
| from vllm.multimodal.utils import group_mm_inputs_by_modality | ||||
| from vllm.sampling_params import SamplingType | ||||
| from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, | ||||
|                         LayerBlockType, cdiv, is_pin_memory_available) | ||||
| from vllm.utils import DeviceMemoryProfiler, cdiv | ||||
| from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, | ||||
|                                                    FlashAttentionMetadata) | ||||
| from vllm.v1.core.encoder_cache_manager import compute_encoder_budget | ||||
| from vllm.v1.engine.mm_input_mapper import MMInputMapperClient | ||||
| from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, | ||||
|                                         KVCacheSpec) | ||||
| from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig | ||||
| from vllm.v1.outputs import ModelRunnerOutput | ||||
| from vllm.v1.sample.metadata import SamplingMetadata | ||||
| from vllm.v1.utils import bind_kv_cache | ||||
| from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch | ||||
| from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from vllm.v1.core.scheduler import SchedulerOutput | ||||
| @ -38,87 +29,17 @@ if TYPE_CHECKING: | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
|  | ||||
| class GPUModelRunner: | ||||
| class GPUModelRunner(ModelRunnerBase): | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         vllm_config: VllmConfig, | ||||
|         device: torch.device, | ||||
|     ): | ||||
|         self.vllm_config = vllm_config | ||||
|         self.model_config = vllm_config.model_config | ||||
|         self.cache_config = vllm_config.cache_config | ||||
|         self.lora_config = vllm_config.lora_config | ||||
|         self.load_config = vllm_config.load_config | ||||
|         self.parallel_config = vllm_config.parallel_config | ||||
|         self.scheduler_config = vllm_config.scheduler_config | ||||
|         self.speculative_config = vllm_config.speculative_config | ||||
|         self.prompt_adapter_config = vllm_config.prompt_adapter_config | ||||
|         self.observability_config = vllm_config.observability_config | ||||
|         super().__init__(vllm_config, device) | ||||
|  | ||||
|         model_config = self.model_config | ||||
|         cache_config = self.cache_config | ||||
|         scheduler_config = self.scheduler_config | ||||
|         parallel_config = self.parallel_config | ||||
|         self.device = device | ||||
|         self.pin_memory = is_pin_memory_available() | ||||
|         self.dtype = self.model_config.dtype | ||||
|         if cache_config.cache_dtype == "auto": | ||||
|             self.kv_cache_dtype = self.dtype | ||||
|         else: | ||||
|             self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ | ||||
|                 cache_config.cache_dtype] | ||||
|  | ||||
|         self.is_multimodal_model = model_config.is_multimodal_model | ||||
|         self.sliding_window = model_config.get_sliding_window() | ||||
|         self.block_size = cache_config.block_size | ||||
|         self.max_model_len = model_config.max_model_len | ||||
|         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) | ||||
|         self.max_num_tokens = scheduler_config.max_num_batched_tokens | ||||
|         self.max_num_reqs = scheduler_config.max_num_seqs | ||||
|  | ||||
|         # Model-related. | ||||
|         self.num_attn_layers = model_config.get_num_layers_by_block_type( | ||||
|             parallel_config, LayerBlockType.attention) | ||||
|         self.num_query_heads = model_config.get_num_attention_heads( | ||||
|             parallel_config) | ||||
|         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) | ||||
|         self.head_size = model_config.get_head_size() | ||||
|         self.hidden_size = model_config.get_hidden_size() | ||||
|  | ||||
|         # Multi-modal data support | ||||
|         self.input_registry = INPUT_REGISTRY | ||||
|         self.mm_registry = MULTIMODAL_REGISTRY | ||||
|  | ||||
|         # NOTE: Initialized input mapper is only used for processing dummy | ||||
|         # multimodal data into multimodal kwargs for GPU memory profiling. | ||||
|         self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config) | ||||
|         self.mm_input_mapper_profiling.use_cache = False | ||||
|  | ||||
|         encoder_compute_budget, encoder_cache_size = compute_encoder_budget( | ||||
|             model_config=model_config, | ||||
|             scheduler_config=scheduler_config, | ||||
|         ) | ||||
|         self.max_num_encoder_input_tokens = encoder_compute_budget | ||||
|         self.encoder_cache_size = encoder_cache_size | ||||
|  | ||||
|         # Lazy initialization | ||||
|         # self.model: nn.Module  # Set after load_model | ||||
|         # KV caches for forward pass | ||||
|         self.kv_caches: List[torch.Tensor] = [] | ||||
|         # req_id -> (input_id -> encoder_output) | ||||
|         self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} | ||||
|  | ||||
|         # Request states. | ||||
|         self.requests: Dict[str, CachedRequestState] = {} | ||||
|         # Persistent batch. | ||||
|         self.input_batch = InputBatch( | ||||
|             max_num_reqs=self.max_num_reqs, | ||||
|             max_model_len=self.max_model_len, | ||||
|             max_num_blocks_per_req=self.max_num_blocks_per_req, | ||||
|             device=self.device, | ||||
|             pin_memory=self.pin_memory, | ||||
|             vocab_size=model_config.get_vocab_size(), | ||||
|         ) | ||||
|  | ||||
|         self.use_cuda_graph = (self.vllm_config.compilation_config.level | ||||
|                                == CompilationLevel.PIECEWISE | ||||
| @ -202,132 +123,6 @@ class GPUModelRunner: | ||||
|                                         pin_memory=self.pin_memory) | ||||
|         self.seq_lens_np = self.seq_lens_cpu.numpy() | ||||
|  | ||||
|     def _update_states(self, scheduler_output: "SchedulerOutput") -> None: | ||||
|         # Remove stopped requests from the cached states. | ||||
|         # Keep the states of the pre-empted requests. | ||||
|         for req_id in scheduler_output.finished_req_ids: | ||||
|             self.requests.pop(req_id, None) | ||||
|             self.encoder_cache.pop(req_id, None) | ||||
|  | ||||
|         # Free the cached encoder outputs. | ||||
|         for req_id, input_id in scheduler_output.free_encoder_input_ids: | ||||
|             encoder_outputs = self.encoder_cache.get(req_id) | ||||
|             if encoder_outputs is not None: | ||||
|                 encoder_outputs.pop(input_id, None) | ||||
|                 if not encoder_outputs: | ||||
|                     self.encoder_cache.pop(req_id, None) | ||||
|  | ||||
|         # Remove the requests from the persistent batch. | ||||
|         stopped_req_ids = set().union( | ||||
|             scheduler_output.preempted_req_ids, | ||||
|             scheduler_output.finished_req_ids, | ||||
|         ) | ||||
|         removed_req_indices: List[int] = [] | ||||
|         for req_id in stopped_req_ids: | ||||
|             req_index = self.input_batch.remove_request(req_id) | ||||
|             if req_index is not None: | ||||
|                 removed_req_indices.append(req_index) | ||||
|  | ||||
|         # Update the states of the running requests. | ||||
|         for req_data in scheduler_output.scheduled_running_reqs: | ||||
|             req_id = req_data.req_id | ||||
|             req_state = self.requests[req_id] | ||||
|             req_index = self.input_batch.req_id_to_index[req_id] | ||||
|  | ||||
|             # Update the num_computed_tokens. | ||||
|             req_state.num_computed_tokens = req_data.num_computed_tokens | ||||
|             self.input_batch.num_computed_tokens_cpu[req_index] = ( | ||||
|                 req_data.num_computed_tokens) | ||||
|  | ||||
|             # Update the block table. | ||||
|             num_new_blocks = len(req_data.new_block_ids) | ||||
|             if num_new_blocks == 0: | ||||
|                 continue | ||||
|             start_index = len(req_state.block_ids) | ||||
|             req_state.block_ids.extend(req_data.new_block_ids) | ||||
|             self.input_batch.block_table.append_row(req_index, start_index, | ||||
|                                                     req_data.new_block_ids) | ||||
|  | ||||
|         req_ids_to_add: List[str] = [] | ||||
|         # Add new requests to the cached states. | ||||
|         for new_req_data in scheduler_output.scheduled_new_reqs: | ||||
|             req_id = new_req_data.req_id | ||||
|             sampling_params = new_req_data.sampling_params | ||||
|             if sampling_params.sampling_type == SamplingType.RANDOM_SEED: | ||||
|                 generator = torch.Generator(device=self.device) | ||||
|                 generator.manual_seed(sampling_params.seed) | ||||
|             else: | ||||
|                 generator = None | ||||
|  | ||||
|             self.requests[req_id] = CachedRequestState( | ||||
|                 req_id=req_id, | ||||
|                 prompt_token_ids=new_req_data.prompt_token_ids, | ||||
|                 prompt=new_req_data.prompt, | ||||
|                 mm_inputs=new_req_data.mm_inputs, | ||||
|                 mm_positions=new_req_data.mm_positions, | ||||
|                 sampling_params=sampling_params, | ||||
|                 generator=generator, | ||||
|                 block_ids=new_req_data.block_ids, | ||||
|                 num_computed_tokens=new_req_data.num_computed_tokens, | ||||
|                 output_token_ids=[], | ||||
|             ) | ||||
|  | ||||
|             # Only relevant for models using M-RoPE (e.g, Qwen2-VL) | ||||
|             if self.model_config.uses_mrope: | ||||
|                 image_grid_thw = [] | ||||
|                 video_grid_thw = [] | ||||
|                 for mm_input in self.requests[req_id].mm_inputs: | ||||
|                     if mm_input.get("image_grid_thw") is not None: | ||||
|                         image_grid_thw.extend( | ||||
|                             mm_input["image_grid_thw"].tolist()) | ||||
|                     if mm_input.get("video_grid_thw") is not None: | ||||
|                         video_grid_thw.extend( | ||||
|                             mm_input["video_grid_thw"].tolist()) | ||||
|  | ||||
|                 hf_config = self.model_config.hf_config | ||||
|  | ||||
|                 self.requests[req_id].mrope_positions, \ | ||||
|                     self.requests[req_id].mrope_position_delta = \ | ||||
|                     MRotaryEmbedding.get_input_positions_tensor( | ||||
|                         self.requests[req_id].prompt_token_ids, | ||||
|                         image_grid_thw=image_grid_thw, | ||||
|                         video_grid_thw=video_grid_thw, | ||||
|                         image_token_id=hf_config.image_token_id, | ||||
|                         video_token_id=hf_config.video_token_id, | ||||
|                         vision_start_token_id=hf_config.vision_start_token_id, | ||||
|                         vision_end_token_id=hf_config.vision_end_token_id, | ||||
|                         spatial_merge_size=hf_config.vision_config. | ||||
|                         spatial_merge_size, | ||||
|                     ) | ||||
|  | ||||
|             req_ids_to_add.append(req_id) | ||||
|  | ||||
|         # Update the cached states of the resumed requests. | ||||
|         for res_req_data in scheduler_output.scheduled_resumed_reqs: | ||||
|             req_id = res_req_data.req_id | ||||
|             req_state = self.requests[req_id] | ||||
|  | ||||
|             req_state.block_ids = res_req_data.block_ids | ||||
|             req_state.num_computed_tokens = res_req_data.num_computed_tokens | ||||
|             req_ids_to_add.append(req_id) | ||||
|  | ||||
|         # Add the new or resumed requests to the persistent batch. | ||||
|         # The smaller empty indices are filled first. | ||||
|         removed_req_indices = sorted(removed_req_indices, reverse=True) | ||||
|         for req_id in req_ids_to_add: | ||||
|             req_state = self.requests[req_id] | ||||
|             if removed_req_indices: | ||||
|                 # Fill the empty index. | ||||
|                 req_index = removed_req_indices.pop() | ||||
|             else: | ||||
|                 # Append to the end. | ||||
|                 req_index = None | ||||
|             self.input_batch.add_request(req_state, req_index) | ||||
|  | ||||
|         # Condense the batched states if there are empty indices. | ||||
|         if removed_req_indices: | ||||
|             self.input_batch.condense(removed_req_indices) | ||||
|  | ||||
|     def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): | ||||
|         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens | ||||
|         assert total_num_scheduled_tokens > 0 | ||||
| @ -611,6 +406,8 @@ class GPUModelRunner: | ||||
|         return sampling_metadata | ||||
|  | ||||
|     def _execute_encoder(self, scheduler_output: "SchedulerOutput"): | ||||
|         assert self.model is not None | ||||
|  | ||||
|         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs | ||||
|         if not scheduled_encoder_inputs: | ||||
|             return | ||||
| @ -698,15 +495,14 @@ class GPUModelRunner: | ||||
|                 encoder_outputs.append(encoder_output[start_idx:end_idx]) | ||||
|         return encoder_outputs | ||||
|  | ||||
|     def get_model(self) -> nn.Module: | ||||
|         return self.model | ||||
|  | ||||
|     @torch.inference_mode() | ||||
|     def execute_model( | ||||
|         self, | ||||
|         scheduler_output: "SchedulerOutput", | ||||
|     ) -> ModelRunnerOutput: | ||||
|         self._update_states(scheduler_output) | ||||
|         assert self.model is not None | ||||
|  | ||||
|         self.update_states(scheduler_output) | ||||
|  | ||||
|         if self.is_multimodal_model: | ||||
|             # Run the multimodal encoder if any. | ||||
| @ -833,14 +629,15 @@ class GPUModelRunner: | ||||
|                     self.model_memory_usage / float(2**30)) | ||||
|  | ||||
|     @torch.inference_mode() | ||||
|     def _dummy_run( | ||||
|     def dummy_run( | ||||
|         self, | ||||
|         kv_caches, | ||||
|         num_tokens: int, | ||||
|         kv_caches: Optional[List[torch.Tensor]] = None, | ||||
|         seq_len: Optional[int] = None, | ||||
|         exec_mode: Optional[ExecutionMode] = None, | ||||
|     ) -> torch.Tensor: | ||||
|         model = self.model | ||||
|         if kv_caches is None: | ||||
|             kv_caches = self.kv_caches | ||||
|         assert self.model is not None | ||||
|  | ||||
|         if self.is_multimodal_model: | ||||
|             input_ids = None | ||||
|             inputs_embeds = self.inputs_embeds[:num_tokens] | ||||
| @ -851,7 +648,7 @@ class GPUModelRunner: | ||||
|             positions = self.mrope_positions[:, :num_tokens] \ | ||||
|                 if self.model_config.uses_mrope \ | ||||
|                 else self.positions[:num_tokens] | ||||
|             hidden_states = model( | ||||
|             hidden_states = self.model( | ||||
|                 input_ids=input_ids, | ||||
|                 positions=positions, | ||||
|                 kv_caches=kv_caches, | ||||
| @ -861,6 +658,7 @@ class GPUModelRunner: | ||||
|         return hidden_states | ||||
|  | ||||
|     def profile_run(self) -> None: | ||||
|         assert self.model is not None | ||||
|         # use an empty tensor instead of `None`` to force Dynamo to pass | ||||
|         # it by reference, rather by specializing on the value `None`. | ||||
|         # the `dtype` argument does not matter, and we use `float32` as | ||||
| @ -966,7 +764,7 @@ class GPUModelRunner: | ||||
|             self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) | ||||
|  | ||||
|         # Trigger compilation for general shape. | ||||
|         hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches) | ||||
|         hidden_states = self.dummy_run(dummy_kv_caches, self.max_num_tokens) | ||||
|         logits = self.model.compute_logits(hidden_states, None) | ||||
|         logits = logits[:self.max_num_tokens] | ||||
|         # TODO(woosuk): Consider the memory usage of the sampler. | ||||
| @ -992,8 +790,8 @@ class GPUModelRunner: | ||||
|             for num_tokens in reversed(self.cudagraph_batch_sizes): | ||||
|                 for _ in range(self.vllm_config.compilation_config. | ||||
|                                cudagraph_num_of_warmups): | ||||
|                     self._dummy_run(num_tokens) | ||||
|                 self._dummy_run(num_tokens) | ||||
|                     self.dummy_run(None, num_tokens) | ||||
|                 self.dummy_run(None, num_tokens) | ||||
|  | ||||
|         end_time = time.perf_counter() | ||||
|         end_free_gpu_memory = torch.cuda.mem_get_info()[0] | ||||
| @ -1036,38 +834,3 @@ class GPUModelRunner: | ||||
|             kv_caches, | ||||
|             self.vllm_config.compilation_config.static_forward_context, | ||||
|             self.kv_caches) | ||||
|  | ||||
|     def get_kv_cache_spec(self) -> KVCacheSpec: | ||||
|         """ | ||||
|         Generates the KVCacheSpec by parsing the kv cache format from each  | ||||
|         Attention module in the static forward context. | ||||
|         Returns: | ||||
|             KVCacheSpec: A dictionary mapping layer names to their KV cache  | ||||
|             format. Layers that do not need KV cache are not included. | ||||
|         """ | ||||
|  | ||||
|         forward_ctx = self.vllm_config.compilation_config.static_forward_context | ||||
|         block_size = self.vllm_config.cache_config.block_size | ||||
|         kv_cache_spec: KVCacheSpec = {} | ||||
|         for layer_name, attn_module in forward_ctx.items(): | ||||
|             # TODO: Support other attention modules, e.g., sliding window, | ||||
|             # cross-attention, MLA. | ||||
|             assert isinstance(attn_module, Attention) | ||||
|             if attn_module.attn_type == AttentionType.DECODER: | ||||
|                 kv_cache_spec[layer_name] = FullAttentionSpec( | ||||
|                     block_size=block_size, | ||||
|                     num_kv_heads=attn_module.num_kv_heads, | ||||
|                     head_size=attn_module.head_size, | ||||
|                     dtype=attn_module.dtype, | ||||
|                 ) | ||||
|             elif attn_module.attn_type in (AttentionType.ENCODER, | ||||
|                                            AttentionType.ENCODER_ONLY): | ||||
|                 # encoder-only attention does not need KV cache. | ||||
|                 continue | ||||
|             elif attn_module.attn_type == AttentionType.ENCODER_DECODER: | ||||
|                 raise NotImplementedError | ||||
|             else: | ||||
|                 raise ValueError( | ||||
|                     f"Unknown attention type: {attn_module.attn_type}") | ||||
|  | ||||
|         return kv_cache_spec | ||||
|  | ||||
| @ -1,13 +1,11 @@ | ||||
| """A GPU worker class.""" | ||||
| import gc | ||||
| import os | ||||
| from typing import TYPE_CHECKING, Optional | ||||
| from typing import Optional | ||||
|  | ||||
| import torch | ||||
| import torch.distributed | ||||
| import torch.nn as nn | ||||
|  | ||||
| import vllm.envs as envs | ||||
| from vllm.config import ParallelConfig, VllmConfig | ||||
| from vllm.device_allocator.cumem import CuMemAllocator | ||||
| from vllm.distributed import (ensure_model_parallel_initialized, | ||||
| @ -15,20 +13,17 @@ from vllm.distributed import (ensure_model_parallel_initialized, | ||||
|                               set_custom_all_reduce) | ||||
| from vllm.logger import init_logger | ||||
| from vllm.model_executor import set_random_seed | ||||
| from vllm.platforms import current_platform | ||||
| from vllm.utils import GiB_bytes | ||||
| from vllm.v1.core.scheduler import SchedulerOutput | ||||
| from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec | ||||
| from vllm.v1.kv_cache_interface import KVCacheConfig | ||||
| from vllm.v1.outputs import ModelRunnerOutput | ||||
| from vllm.v1.worker.gpu_model_runner import GPUModelRunner | ||||
| from vllm.v1.worker.worker_base import WorkerBase, check_if_gpu_supports_dtype | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from vllm.v1.core.scheduler import SchedulerOutput | ||||
|  | ||||
|  | ||||
| class Worker: | ||||
| class GPUWorker(WorkerBase): | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
| @ -38,46 +33,8 @@ class Worker: | ||||
|         distributed_init_method: str, | ||||
|         is_driver_worker: bool = False, | ||||
|     ): | ||||
|  | ||||
|         # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) | ||||
|         self.vllm_config = vllm_config | ||||
|         self.model_config = vllm_config.model_config | ||||
|         self.cache_config = vllm_config.cache_config | ||||
|         self.lora_config = vllm_config.lora_config | ||||
|         self.load_config = vllm_config.load_config | ||||
|         self.parallel_config = vllm_config.parallel_config | ||||
|         self.scheduler_config = vllm_config.scheduler_config | ||||
|         self.device_config = vllm_config.device_config | ||||
|         self.speculative_config = vllm_config.speculative_config | ||||
|         self.prompt_adapter_config = vllm_config.prompt_adapter_config | ||||
|         self.observability_config = vllm_config.observability_config | ||||
|  | ||||
|         self.parallel_config.rank = rank | ||||
|         self.local_rank = local_rank | ||||
|         self.rank = rank | ||||
|         self.distributed_init_method = distributed_init_method | ||||
|  | ||||
|         if self.model_config.trust_remote_code: | ||||
|             # note: lazy import to avoid importing torch before initializing | ||||
|             from vllm.utils import init_cached_hf_modules | ||||
|             init_cached_hf_modules() | ||||
|  | ||||
|         # Torch profiler. Enabled and configured through env vars: | ||||
|         # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace | ||||
|         if envs.VLLM_TORCH_PROFILER_DIR: | ||||
|             torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR | ||||
|             logger.info("Profiling enabled. Traces will be saved to: %s", | ||||
|                         torch_profiler_trace_dir) | ||||
|             self.profiler = torch.profiler.profile( | ||||
|                 activities=[ | ||||
|                     torch.profiler.ProfilerActivity.CPU, | ||||
|                     torch.profiler.ProfilerActivity.CUDA, | ||||
|                 ], | ||||
|                 with_stack=True, | ||||
|                 on_trace_ready=torch.profiler.tensorboard_trace_handler( | ||||
|                     torch_profiler_trace_dir, use_gzip=True)) | ||||
|         else: | ||||
|             self.profiler = None | ||||
|         super().__init__(vllm_config, local_rank, rank, | ||||
|                          distributed_init_method) | ||||
|  | ||||
|     def sleep(self, level: int = 1) -> None: | ||||
|         free_bytes_before_sleep = torch.cuda.mem_get_info()[0] | ||||
| @ -97,31 +54,39 @@ class Worker: | ||||
|         allocator.wake_up() | ||||
|  | ||||
|     def init_device(self): | ||||
|         if self.device_config.device.type == "cuda": | ||||
|             # torch.distributed.all_reduce does not free the input tensor until | ||||
|             # the synchronization point. This causes the memory usage to grow | ||||
|             # as the number of all_reduce calls increases. This env var disables | ||||
|             # this behavior. | ||||
|             # Related issue: | ||||
|             # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 | ||||
|             os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" | ||||
|         assert self.device_config.device.type == "cuda" | ||||
|  | ||||
|             # This env var set by Ray causes exceptions with graph building. | ||||
|             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) | ||||
|             self.device = torch.device(f"cuda:{self.local_rank}") | ||||
|             torch.cuda.set_device(self.device) | ||||
|         # torch.distributed.all_reduce does not free the input tensor until | ||||
|         # the synchronization point. This causes the memory usage to grow | ||||
|         # as the number of all_reduce calls increases. This env var disables | ||||
|         # this behavior. | ||||
|         # Related issue: | ||||
|         # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 | ||||
|         os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" | ||||
|  | ||||
|         # torch.distributed.all_reduce does not free the input tensor until | ||||
|         # the synchronization point. This causes the memory usage to grow | ||||
|         # as the number of all_reduce calls increases. This env var disables | ||||
|         # this behavior. | ||||
|         # Related issue: | ||||
|         # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 | ||||
|         os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" | ||||
|  | ||||
|         # This env var set by Ray causes exceptions with graph building. | ||||
|         os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) | ||||
|         self.device = torch.device(f"cuda:{self.local_rank}") | ||||
|         torch.cuda.set_device(self.device) | ||||
|  | ||||
|         check_if_gpu_supports_dtype(self.model_config.dtype) | ||||
|         gc.collect() | ||||
|         torch.cuda.empty_cache() | ||||
|         self.init_gpu_memory = torch.cuda.mem_get_info()[0] | ||||
|  | ||||
|             _check_if_gpu_supports_dtype(self.model_config.dtype) | ||||
|             gc.collect() | ||||
|             torch.cuda.empty_cache() | ||||
|             self.init_gpu_memory = torch.cuda.mem_get_info()[0] | ||||
|         else: | ||||
|             raise RuntimeError( | ||||
|                 f"Not support device type: {self.device_config.device}") | ||||
|         # Initialize the distributed environment. | ||||
|         init_worker_distributed_environment(self.parallel_config, self.rank, | ||||
|                                             self.distributed_init_method, | ||||
|                                             self.local_rank) | ||||
|         init_cuda_worker_distributed_environment(self.parallel_config, | ||||
|                                                  self.rank, | ||||
|                                                  self.distributed_init_method, | ||||
|                                                  self.local_rank) | ||||
|         # Set random seed. | ||||
|         set_random_seed(self.model_config.seed) | ||||
|  | ||||
| @ -139,6 +104,7 @@ class Worker: | ||||
|             from contextlib import nullcontext | ||||
|             context = nullcontext() | ||||
|         with context: | ||||
|             assert self.model_runner is not None | ||||
|             self.model_runner.load_model() | ||||
|  | ||||
|     @torch.inference_mode() | ||||
| @ -160,6 +126,7 @@ class Worker: | ||||
|         _, total_gpu_memory = torch.cuda.mem_get_info() | ||||
|         # Execute a forward pass with dummy inputs to profile the memory usage | ||||
|         # of the model. | ||||
|         assert self.model_runner is not None | ||||
|         self.model_runner.profile_run() | ||||
|  | ||||
|         free_gpu_memory, _ = torch.cuda.mem_get_info() | ||||
| @ -191,9 +158,6 @@ class Worker: | ||||
|  | ||||
|         return int(available_kv_cache_memory) | ||||
|  | ||||
|     def get_kv_cache_spec(self) -> KVCacheSpec: | ||||
|         return self.model_runner.get_kv_cache_spec() | ||||
|  | ||||
|     def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None: | ||||
|         """Allocate GPU KV cache with the specified kv_cache_config.""" | ||||
|         if self.vllm_config.model_config.enable_sleep_mode: | ||||
| @ -203,9 +167,12 @@ class Worker: | ||||
|             from contextlib import nullcontext | ||||
|             context = nullcontext() | ||||
|         with context: | ||||
|             assert self.model_runner is not None | ||||
|             self.model_runner.initialize_kv_cache(kv_cache_config) | ||||
|  | ||||
|     def compile_or_warm_up_model(self) -> None: | ||||
|         assert self.model_runner is not None | ||||
|  | ||||
|         # warm up sizes that are not in cudagraph capture sizes, | ||||
|         # but users still want to compile for better performance, | ||||
|         # e.g. for the max-num-batched token size in chunked prefill. | ||||
| @ -217,44 +184,32 @@ class Worker: | ||||
|             ] | ||||
|         for size in sorted(warmup_sizes, reverse=True): | ||||
|             logger.info("Compile and warming up model for size %d", size) | ||||
|             self.model_runner._dummy_run(size) | ||||
|             self.model_runner.dummy_run(None, size) | ||||
|  | ||||
|         if not self.model_config.enforce_eager: | ||||
|             self.model_runner.capture_model() | ||||
|         # Reset the seed to ensure that the random state is not affected by | ||||
|         # the model initialization and profiling. | ||||
|         set_random_seed(self.model_config.seed) | ||||
|  | ||||
|     def get_model(self) -> nn.Module: | ||||
|         return self.model_runner.get_model() | ||||
|  | ||||
|     @torch.inference_mode() | ||||
|     def execute_model( | ||||
|         self, | ||||
|         scheduler_output: "SchedulerOutput", | ||||
|     ) -> Optional[ModelRunnerOutput]: | ||||
|         assert self.model_runner is not None | ||||
|         output = self.model_runner.execute_model(scheduler_output) | ||||
|         return output if self.rank == 0 else None | ||||
|  | ||||
|     def profile(self, is_start: bool = True): | ||||
|         if self.profiler is None: | ||||
|             raise RuntimeError("Profiler is not enabled.") | ||||
|         if is_start: | ||||
|             self.profiler.start() | ||||
|         else: | ||||
|             self.profiler.stop() | ||||
|  | ||||
|     def check_health(self) -> None: | ||||
|         # worker will always be healthy as long as it's running. | ||||
|         return | ||||
|  | ||||
|  | ||||
| def init_worker_distributed_environment( | ||||
| def init_cuda_worker_distributed_environment( | ||||
|     parallel_config: ParallelConfig, | ||||
|     rank: int, | ||||
|     distributed_init_method: Optional[str] = None, | ||||
|     local_rank: int = -1, | ||||
| ) -> None: | ||||
|     """Initialize the distributed environment.""" | ||||
|  | ||||
|     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) | ||||
|  | ||||
|     init_distributed_environment(parallel_config.world_size, rank, | ||||
| @ -264,21 +219,22 @@ def init_worker_distributed_environment( | ||||
|                                       parallel_config.pipeline_parallel_size) | ||||
|  | ||||
|  | ||||
| def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): | ||||
|     # Check if the GPU supports the dtype. | ||||
|     if torch_dtype == torch.bfloat16:  # noqa: SIM102 | ||||
|         if not current_platform.has_device_capability(80): | ||||
|             capability = current_platform.get_device_capability() | ||||
|             gpu_name = current_platform.get_device_name() | ||||
| # TODO: Remove | ||||
| # def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): | ||||
| #     # Check if the GPU supports the dtype. | ||||
| #     if torch_dtype == torch.bfloat16:  # noqa: SIM102 | ||||
| #         if not current_platform.has_device_capability(80): | ||||
| #             capability = current_platform.get_device_capability() | ||||
| #             gpu_name = current_platform.get_device_name() | ||||
|  | ||||
|             if capability is None: | ||||
|                 compute_str = "does not have a compute capability" | ||||
|             else: | ||||
|                 version_str = capability.as_version_str() | ||||
|                 compute_str = f"has compute capability {version_str}" | ||||
| #             if capability is None: | ||||
| #                 compute_str = "does not have a compute capability" | ||||
| #             else: | ||||
| #                 version_str = capability.as_version_str() | ||||
| #                 compute_str = f"has compute capability {version_str}" | ||||
|  | ||||
|             raise ValueError( | ||||
|                 "Bfloat16 is only supported on GPUs with compute capability " | ||||
|                 f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " | ||||
|                 "You can use float16 instead by explicitly setting the" | ||||
|                 "`dtype` flag in CLI, for example: --dtype=half.") | ||||
| #             raise ValueError( | ||||
| #                 "Bfloat16 is only supported on GPUs with compute capability " | ||||
| #                 f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " | ||||
| #                 "You can use float16 instead by explicitly setting the" | ||||
| #                 "`dtype` flag in CLI, for example: --dtype=half.") | ||||
|  | ||||
							
								
								
									
										307
									
								
								vllm/v1/worker/model_runner_base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										307
									
								
								vllm/v1/worker/model_runner_base.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,307 @@ | ||||
| import enum | ||||
| from typing import TYPE_CHECKING, Dict, List, Optional | ||||
|  | ||||
| import torch | ||||
| import torch.distributed | ||||
| import torch.nn as nn | ||||
|  | ||||
| from vllm.attention.backends.abstract import AttentionType | ||||
| from vllm.attention.layer import Attention | ||||
| from vllm.config import VllmConfig | ||||
| from vllm.inputs import INPUT_REGISTRY | ||||
| from vllm.logger import init_logger | ||||
| from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding | ||||
| from vllm.multimodal import MULTIMODAL_REGISTRY | ||||
| from vllm.sampling_params import SamplingType | ||||
| from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available | ||||
| from vllm.v1.core.encoder_cache_manager import compute_encoder_budget | ||||
| from vllm.v1.engine.mm_input_mapper import MMInputMapperClient | ||||
| from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, | ||||
|                                         KVCacheSpec) | ||||
| from vllm.v1.outputs import ModelRunnerOutput | ||||
| from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from vllm.v1.core.scheduler import SchedulerOutput | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
|  | ||||
| class ExecutionMode(enum.Enum): | ||||
|     PREFILL = enum.auto() | ||||
|     DECODE = enum.auto() | ||||
|     PREFIX_PREFILL = enum.auto() | ||||
|  | ||||
|     def is_prefill(self) -> bool: | ||||
|         return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) | ||||
|  | ||||
|  | ||||
| class ModelRunnerBase: | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         vllm_config: VllmConfig, | ||||
|         device: torch.device, | ||||
|     ): | ||||
|         self.vllm_config = vllm_config | ||||
|         self.model_config = vllm_config.model_config | ||||
|         self.cache_config = vllm_config.cache_config | ||||
|         self.lora_config = vllm_config.lora_config | ||||
|         self.load_config = vllm_config.load_config | ||||
|         self.parallel_config = vllm_config.parallel_config | ||||
|         self.scheduler_config = vllm_config.scheduler_config | ||||
|         self.speculative_config = vllm_config.speculative_config | ||||
|         self.prompt_adapter_config = vllm_config.prompt_adapter_config | ||||
|         self.observability_config = vllm_config.observability_config | ||||
|         self.device_config = vllm_config.device_config | ||||
|  | ||||
|         model_config = self.model_config | ||||
|         cache_config = self.cache_config | ||||
|         scheduler_config = self.scheduler_config | ||||
|         parallel_config = self.parallel_config | ||||
|         self.device = device | ||||
|         self.pin_memory = is_pin_memory_available() | ||||
|         self.dtype = self.model_config.dtype | ||||
|  | ||||
|         self.is_multimodal_model = model_config.is_multimodal_model | ||||
|         self.sliding_window = model_config.get_sliding_window() | ||||
|         self.block_size = cache_config.block_size | ||||
|         self.max_model_len = model_config.max_model_len | ||||
|         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) | ||||
|         self.max_num_tokens = scheduler_config.max_num_batched_tokens | ||||
|         self.max_num_reqs = scheduler_config.max_num_seqs | ||||
|  | ||||
|         # Model-related. | ||||
|         self.num_attn_layers = model_config.get_num_layers_by_block_type( | ||||
|             parallel_config, LayerBlockType.attention) | ||||
|         self.num_query_heads = model_config.get_num_attention_heads( | ||||
|             parallel_config) | ||||
|         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) | ||||
|         self.head_size = model_config.get_head_size() | ||||
|         self.hidden_size = model_config.get_hidden_size() | ||||
|  | ||||
|         self.model: Optional[nn.Module] = None | ||||
|  | ||||
|         # Persistent batch. | ||||
|         self.input_batch = InputBatch( | ||||
|             max_num_reqs=self.max_num_reqs, | ||||
|             max_model_len=self.max_model_len, | ||||
|             max_num_blocks_per_req=self.max_num_blocks_per_req, | ||||
|             device=self.device, | ||||
|             pin_memory=self.pin_memory, | ||||
|             vocab_size=self.model_config.get_vocab_size(), | ||||
|         ) | ||||
|  | ||||
|         # Request states. | ||||
|         self.requests: Dict[str, CachedRequestState] = {} | ||||
|  | ||||
|         # Multi-modal data support | ||||
|         self.input_registry = INPUT_REGISTRY | ||||
|         self.mm_registry = MULTIMODAL_REGISTRY | ||||
|  | ||||
|         # NOTE: Initialized input mapper is only used for processing dummy | ||||
|         # multimodal data into multimodal kwargs for GPU memory profiling. | ||||
|         self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config) | ||||
|         self.mm_input_mapper_profiling.use_cache = False | ||||
|  | ||||
|         encoder_compute_budget, encoder_cache_size = compute_encoder_budget( | ||||
|             model_config=self.model_config, | ||||
|             scheduler_config=self.scheduler_config, | ||||
|         ) | ||||
|         self.max_num_encoder_input_tokens = encoder_compute_budget | ||||
|         self.encoder_cache_size = encoder_cache_size | ||||
|  | ||||
|         # req_id -> (input_id -> encoder_output) | ||||
|         self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} | ||||
|  | ||||
|     def update_states(self, scheduler_output: "SchedulerOutput") -> None: | ||||
|         # Remove stopped requests from the cached states. | ||||
|         # Keep the states of the pre-empted requests. | ||||
|         for req_id in scheduler_output.finished_req_ids: | ||||
|             self.requests.pop(req_id, None) | ||||
|             self.encoder_cache.pop(req_id, None) | ||||
|  | ||||
|         # Free the cached encoder outputs. | ||||
|         for req_id, input_id in scheduler_output.free_encoder_input_ids: | ||||
|             encoder_outputs = self.encoder_cache.get(req_id) | ||||
|             if encoder_outputs is not None: | ||||
|                 encoder_outputs.pop(input_id, None) | ||||
|                 if not encoder_outputs: | ||||
|                     self.encoder_cache.pop(req_id, None) | ||||
|  | ||||
|         # Remove the requests from the persistent batch. | ||||
|         stopped_req_ids = set().union( | ||||
|             scheduler_output.preempted_req_ids, | ||||
|             scheduler_output.finished_req_ids, | ||||
|         ) | ||||
|         removed_req_indices: List[int] = [] | ||||
|         for req_id in stopped_req_ids: | ||||
|             req_index = self.input_batch.remove_request(req_id) | ||||
|             if req_index is not None: | ||||
|                 removed_req_indices.append(req_index) | ||||
|  | ||||
|         # Update the states of the running requests. | ||||
|         for req_data in scheduler_output.scheduled_running_reqs: | ||||
|             req_id = req_data.req_id | ||||
|             req_state = self.requests[req_id] | ||||
|             req_index = self.input_batch.req_id_to_index[req_id] | ||||
|  | ||||
|             # Update the num_computed_tokens. | ||||
|             req_state.num_computed_tokens = req_data.num_computed_tokens | ||||
|             self.input_batch.num_computed_tokens_cpu[req_index] = ( | ||||
|                 req_data.num_computed_tokens) | ||||
|  | ||||
|             # Update the block table. | ||||
|             num_new_blocks = len(req_data.new_block_ids) | ||||
|             if num_new_blocks == 0: | ||||
|                 continue | ||||
|             start_index = len(req_state.block_ids) | ||||
|             req_state.block_ids.extend(req_data.new_block_ids) | ||||
|             self.input_batch.block_table.append_row(req_index, start_index, | ||||
|                                                     req_data.new_block_ids) | ||||
|  | ||||
|         req_ids_to_add: List[str] = [] | ||||
|         # Add new requests to the cached states. | ||||
|         for new_req_data in scheduler_output.scheduled_new_reqs: | ||||
|             req_id = new_req_data.req_id | ||||
|             sampling_params = new_req_data.sampling_params | ||||
|             if sampling_params.sampling_type == SamplingType.RANDOM_SEED: | ||||
|                 generator = torch.Generator(device=self.device) | ||||
|                 generator.manual_seed(sampling_params.seed) | ||||
|             else: | ||||
|                 generator = None | ||||
|  | ||||
|             self.requests[req_id] = CachedRequestState( | ||||
|                 req_id=req_id, | ||||
|                 prompt_token_ids=new_req_data.prompt_token_ids, | ||||
|                 prompt=new_req_data.prompt, | ||||
|                 mm_inputs=new_req_data.mm_inputs, | ||||
|                 mm_positions=new_req_data.mm_positions, | ||||
|                 sampling_params=sampling_params, | ||||
|                 generator=generator, | ||||
|                 block_ids=new_req_data.block_ids, | ||||
|                 num_computed_tokens=new_req_data.num_computed_tokens, | ||||
|                 output_token_ids=[], | ||||
|             ) | ||||
|  | ||||
|             # Only relevant for models using M-RoPE (e.g, Qwen2-VL) | ||||
|             if self.model_config.uses_mrope: | ||||
|                 image_grid_thw = [] | ||||
|                 video_grid_thw = [] | ||||
|                 for mm_input in self.requests[req_id].mm_inputs: | ||||
|                     if mm_input.get("image_grid_thw") is not None: | ||||
|                         image_grid_thw.extend( | ||||
|                             mm_input["image_grid_thw"].tolist()) | ||||
|                     if mm_input.get("video_grid_thw") is not None: | ||||
|                         video_grid_thw.extend( | ||||
|                             mm_input["video_grid_thw"].tolist()) | ||||
|  | ||||
|                 hf_config = self.model_config.hf_config | ||||
|  | ||||
|                 self.requests[req_id].mrope_positions, \ | ||||
|                     self.requests[req_id].mrope_position_delta = \ | ||||
|                     MRotaryEmbedding.get_input_positions_tensor( | ||||
|                         self.requests[req_id].prompt_token_ids, | ||||
|                         image_grid_thw=image_grid_thw, | ||||
|                         video_grid_thw=video_grid_thw, | ||||
|                         image_token_id=hf_config.image_token_id, | ||||
|                         video_token_id=hf_config.video_token_id, | ||||
|                         vision_start_token_id=hf_config.vision_start_token_id, | ||||
|                         vision_end_token_id=hf_config.vision_end_token_id, | ||||
|                         spatial_merge_size=hf_config.vision_config. | ||||
|                         spatial_merge_size, | ||||
|                     ) | ||||
|  | ||||
|             req_ids_to_add.append(req_id) | ||||
|  | ||||
|         # Update the cached states of the resumed requests. | ||||
|         for res_req_data in scheduler_output.scheduled_resumed_reqs: | ||||
|             req_id = res_req_data.req_id | ||||
|             req_state = self.requests[req_id] | ||||
|  | ||||
|             req_state.block_ids = res_req_data.block_ids | ||||
|             req_state.num_computed_tokens = res_req_data.num_computed_tokens | ||||
|             req_ids_to_add.append(req_id) | ||||
|  | ||||
|         # Add the new or resumed requests to the persistent batch. | ||||
|         # The smaller empty indices are filled first. | ||||
|         removed_req_indices = sorted(removed_req_indices, reverse=True) | ||||
|         for req_id in req_ids_to_add: | ||||
|             req_state = self.requests[req_id] | ||||
|             if removed_req_indices: | ||||
|                 # Fill the empty index. | ||||
|                 req_index = removed_req_indices.pop() | ||||
|             else: | ||||
|                 # Append to the end. | ||||
|                 req_index = None | ||||
|             self.input_batch.add_request(req_state, req_index) | ||||
|  | ||||
|         # Condense the batched states if there are empty indices. | ||||
|         if removed_req_indices: | ||||
|             self.input_batch.condense(removed_req_indices) | ||||
|  | ||||
|     def get_model(self) -> nn.Module: | ||||
|         assert self.model is not None | ||||
|         return self.model | ||||
|  | ||||
|     def get_kv_cache_spec(self) -> KVCacheSpec: | ||||
|         """ | ||||
|         Generates the KVCacheSpec by parsing the kv cache format from each  | ||||
|         Attention module in the static forward context. | ||||
|         Returns: | ||||
|             KVCacheSpec: A dictionary mapping layer names to their KV cache  | ||||
|             format. Layers that do not need KV cache are not included. | ||||
|         """ | ||||
|  | ||||
|         forward_ctx = self.vllm_config.compilation_config.static_forward_context | ||||
|         block_size = self.vllm_config.cache_config.block_size | ||||
|         kv_cache_spec: KVCacheSpec = {} | ||||
|         for layer_name, attn_module in forward_ctx.items(): | ||||
|             # TODO: Support other attention modules, e.g., sliding window, | ||||
|             # cross-attention, MLA. | ||||
|             assert isinstance(attn_module, Attention) | ||||
|             if attn_module.attn_type == AttentionType.DECODER: | ||||
|                 kv_cache_spec[layer_name] = FullAttentionSpec( | ||||
|                     block_size=block_size, | ||||
|                     num_kv_heads=attn_module.num_kv_heads, | ||||
|                     head_size=attn_module.head_size, | ||||
|                     dtype=attn_module.dtype, | ||||
|                 ) | ||||
|             elif attn_module.attn_type in (AttentionType.ENCODER, | ||||
|                                            AttentionType.ENCODER_ONLY): | ||||
|                 # encoder-only attention does not need KV cache. | ||||
|                 continue | ||||
|             elif attn_module.attn_type == AttentionType.ENCODER_DECODER: | ||||
|                 raise NotImplementedError | ||||
|             else: | ||||
|                 raise ValueError( | ||||
|                     f"Unknown attention type: {attn_module.attn_type}") | ||||
|  | ||||
|         return kv_cache_spec | ||||
|  | ||||
|     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     def execute_model( | ||||
|         self, | ||||
|         scheduler_output: "SchedulerOutput", | ||||
|     ) -> ModelRunnerOutput: | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     def load_model(self) -> None: | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     def dummy_run( | ||||
|         self, | ||||
|         kv_caches, | ||||
|         num_tokens: int, | ||||
|         seq_len: Optional[int] = None, | ||||
|         exec_mode: Optional[ExecutionMode] = None, | ||||
|     ) -> torch.Tensor: | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     def profile_run(self) -> None: | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     def capture_model(self) -> None: | ||||
|         raise NotImplementedError() | ||||
							
								
								
									
										888
									
								
								vllm/v1/worker/tpu_model_runner.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										888
									
								
								vllm/v1/worker/tpu_model_runner.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,888 @@ | ||||
| import time | ||||
| from dataclasses import dataclass | ||||
| from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast | ||||
| from unittest.mock import patch | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.distributed | ||||
| import torch.nn as nn | ||||
| # TPU XLA related | ||||
| import torch_xla.core.xla_model as xm | ||||
| import torch_xla.runtime as xr | ||||
|  | ||||
| from vllm.attention import AttentionMetadata | ||||
| from vllm.config import VllmConfig | ||||
| from vllm.forward_context import set_forward_context | ||||
| from vllm.logger import init_logger | ||||
| from vllm.model_executor.model_loader import get_model | ||||
| from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, | ||||
|                                                PallasMetadata) | ||||
| from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig | ||||
| from vllm.v1.outputs import ModelRunnerOutput | ||||
| from vllm.v1.utils import bind_kv_cache | ||||
| from vllm.v1.worker.gpu_input_batch import (CachedRequestState, InputBatch, | ||||
|                                             ensure_decodes_first) | ||||
| from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase | ||||
| from vllm.v1.core.kv_cache_utils import get_kv_cache_config | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from vllm.v1.core.scheduler import SchedulerOutput | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
| # Here we utilize the behavior that out-of-bound index is ignored. | ||||
| # FIXME(woosuk): Find a more reliable way to prevent possible bugs. | ||||
| _PAD_SLOT_ID = 1_000_000_000 | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class PromptDecodeInfo: | ||||
|     prompt_req_ids: List[str] | ||||
|     decode_req_ids: List[str] | ||||
|     prompt_scheduled_tokens: List[int] | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class PromptData: | ||||
|     input_tokens: torch.Tensor | ||||
|     input_positions: torch.Tensor | ||||
|     attn_metadata: PallasMetadata | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class DecodeData: | ||||
|     input_tokens: Optional[torch.Tensor] = None | ||||
|     input_positions: Optional[torch.Tensor] = None | ||||
|     attn_metadata: Optional[PallasMetadata] = None | ||||
|  | ||||
|  | ||||
| class TPUModelRunner(ModelRunnerBase): | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         vllm_config: VllmConfig, | ||||
|         device: torch.device, | ||||
|     ): | ||||
|         super().__init__(vllm_config, device) | ||||
|  | ||||
|         # KV caches for forward pass | ||||
|         self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] | ||||
|  | ||||
|         # Cached torch/numpy tensors | ||||
|         self.num_swaps = 2 | ||||
|         self.cur_swap_id = 0 | ||||
|         self.input_ids_cpu = [] | ||||
|         self.input_ids_np = [] | ||||
|         self.input_positions_cpu = [] | ||||
|         self.input_positions_np = [] | ||||
|         self.slot_mapping_cpu = [] | ||||
|         self.slot_mapping_np = [] | ||||
|         self.prompt_context_lens_cpu = [] | ||||
|         self.prompt_effective_query_lens_cpu = [] | ||||
|         self.decode_context_lens_cpu = [] | ||||
|         self.decode_context_lens_np = [] | ||||
|         for _ in range(self.num_swaps): | ||||
|             self.input_ids_cpu.append( | ||||
|                 torch.empty(self.max_num_tokens, | ||||
|                             dtype=torch.int32, | ||||
|                             device="cpu")) | ||||
|             self.input_ids_np.append(self.input_ids_cpu[-1].numpy()) | ||||
|  | ||||
|             self.input_positions_cpu.append( | ||||
|                 torch.empty(self.max_num_tokens, | ||||
|                             dtype=torch.int32, | ||||
|                             device="cpu")) | ||||
|             self.input_positions_np.append( | ||||
|                 self.input_positions_cpu[-1].numpy()) | ||||
|  | ||||
|             self.slot_mapping_cpu.append( | ||||
|                 torch.empty(self.max_num_tokens, | ||||
|                             dtype=torch.int64, | ||||
|                             device="cpu")) | ||||
|             self.slot_mapping_np.append(self.slot_mapping_cpu[-1].numpy()) | ||||
|  | ||||
|             self.prompt_context_lens_cpu.append( | ||||
|                 torch.empty((1), dtype=torch.int32, device="cpu")) | ||||
|             self.prompt_effective_query_lens_cpu.append( | ||||
|                 torch.empty((1), dtype=torch.int32, device="cpu")) | ||||
|  | ||||
|             self.decode_context_lens_cpu.append( | ||||
|                 torch.empty(self.max_num_tokens, | ||||
|                             dtype=torch.int32, | ||||
|                             device="cpu")) | ||||
|             self.decode_context_lens_np.append( | ||||
|                 self.decode_context_lens_cpu[-1].numpy()) | ||||
|  | ||||
|         # Range tensor with values [0 .. self.max_num_tokens - 1]. | ||||
|         # Used to initialize positions / context_lens / seq_lens | ||||
|         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32) | ||||
|  | ||||
|     def swap_step(self): | ||||
|         self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps | ||||
|  | ||||
|     def _get_prompts_and_decodes( | ||||
|         self, | ||||
|         scheduler_output: "SchedulerOutput", | ||||
|     ) -> PromptDecodeInfo: | ||||
|         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens | ||||
|         assert total_num_scheduled_tokens > 0 | ||||
|         num_reqs = self.input_batch.num_reqs | ||||
|         assert num_reqs > 0 | ||||
|  | ||||
|         # Traverse decodes first | ||||
|         decode_req_ids = [] | ||||
|         for i in range(num_reqs): | ||||
|             req_id = self.input_batch.req_ids[i] | ||||
|  | ||||
|             num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i] | ||||
|             num_prompt_tokens = self.input_batch.num_prompt_tokens[i] | ||||
|             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ | ||||
|                 req_id] | ||||
|  | ||||
|             if num_computed_tokens < num_prompt_tokens: | ||||
|                 # This is prompt | ||||
|                 break | ||||
|  | ||||
|             # This is decode | ||||
|             assert num_scheduled_tokens == 1 | ||||
|             decode_req_ids.append(req_id) | ||||
|  | ||||
|         # Traverse prompts | ||||
|         prompt_req_ids = [] | ||||
|         prompt_scheduled_tokens = [] | ||||
|         for i in range(len(decode_req_ids), num_reqs): | ||||
|             req_id = self.input_batch.req_ids[i] | ||||
|  | ||||
|             num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i] | ||||
|             num_prompt_tokens = self.input_batch.num_prompt_tokens[i] | ||||
|             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ | ||||
|                 req_id] | ||||
|  | ||||
|             # Must be prompt | ||||
|             assert num_computed_tokens < num_prompt_tokens | ||||
|  | ||||
|             prompt_req_ids.append(req_id) | ||||
|             prompt_scheduled_tokens.append(num_scheduled_tokens) | ||||
|  | ||||
|         return PromptDecodeInfo(prompt_req_ids, decode_req_ids, | ||||
|                                 prompt_scheduled_tokens) | ||||
|  | ||||
|     def _prepare_prompt(self, req_index: int, | ||||
|                         num_scheduled_tokens: int) -> PromptData: | ||||
|         num_computed_tokens = self.input_batch.num_computed_tokens_cpu[ | ||||
|             req_index] | ||||
|         num_prompt_tokens = self.input_batch.num_prompt_tokens[req_index] | ||||
|  | ||||
|         # Must be prompt | ||||
|         assert num_computed_tokens < num_prompt_tokens | ||||
|  | ||||
|         # Prompt len | ||||
|         prompt_len = num_scheduled_tokens | ||||
|         padded_prompt_len = _get_padded_prompt_len(prompt_len) | ||||
|         assert padded_prompt_len <= self.max_model_len | ||||
|  | ||||
|         # Seq len | ||||
|         seq_len = num_computed_tokens + prompt_len | ||||
|         padded_seq_len = num_computed_tokens + padded_prompt_len | ||||
|  | ||||
|         # DEBUG | ||||
|         # print("_prepare_prompt:") | ||||
|         # print("    prompt_len = {}".format(prompt_len)) | ||||
|         # print("    padded_prompt_len = {}".format(padded_prompt_len)) | ||||
|         # print("    num_computed_tokens = {}".format(num_computed_tokens)) | ||||
|         # print("    num_prompt_tokens = {}".format(num_prompt_tokens)) | ||||
|         # print("    seq_len = {}".format(seq_len)) | ||||
|         # print("    padded_seq_len = {}".format(padded_seq_len)) | ||||
|  | ||||
|         # Input tokens | ||||
|         input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[ | ||||
|             req_index, num_computed_tokens:padded_seq_len] | ||||
|         input_tokens_cpu[prompt_len:] = 0 | ||||
|  | ||||
|         # DEBUG | ||||
|         # print("    input_tokens_cpu.shape = {} val = {}".format( | ||||
|         # input_tokens_cpu.shape, input_tokens_cpu)) | ||||
|  | ||||
|         # Input positions | ||||
|         input_positions_np = self.input_positions_np[ | ||||
|             self.cur_swap_id][:padded_prompt_len] | ||||
|         np.add(num_computed_tokens, | ||||
|                self.arange_np[:padded_prompt_len], | ||||
|                out=input_positions_np) | ||||
|         input_positions_np[prompt_len:] = 0 | ||||
|  | ||||
|         # DEBUG | ||||
|         # print("    input_positions_np.shape = {} val = {}".format( | ||||
|         # input_positions_np.shape, input_positions_np)) | ||||
|  | ||||
|         # Slot mapping | ||||
|         block_table_np = \ | ||||
|             self.input_batch.block_table.get_numpy_array() | ||||
|         block_numbers_np = block_table_np[req_index, input_positions_np // | ||||
|                                           self.block_size] | ||||
|         block_offsets_np = input_positions_np % self.block_size | ||||
|  | ||||
|         slot_mapping_np = self.slot_mapping_np[ | ||||
|             self.cur_swap_id][:padded_prompt_len] | ||||
|         np.add(block_numbers_np * self.block_size, | ||||
|                block_offsets_np, | ||||
|                out=slot_mapping_np) | ||||
|         slot_mapping_np[prompt_len:] = _PAD_SLOT_ID | ||||
|  | ||||
|         # DEBUG | ||||
|         # print("    slot_mapping_np.shape = {} val = {}".format( | ||||
|         #     slot_mapping_np.shape, slot_mapping_np)) | ||||
|  | ||||
|         # Block table | ||||
|         block_table_cpu = None | ||||
|         if num_computed_tokens > 0: | ||||
|             block_table_cpu = self.input_batch.block_table.get_cpu_tensor() | ||||
|             block_table_cpu = block_table_cpu[req_index] | ||||
|  | ||||
|         # DEBUG | ||||
|         # print("    block_table_cpu = {}".format(block_table_cpu)) | ||||
|  | ||||
|         # Context len | ||||
|         self.prompt_context_lens_cpu[self.cur_swap_id][0] = 0 | ||||
|         if num_computed_tokens > 0: | ||||
|             self.prompt_context_lens_cpu[self.cur_swap_id][0] = seq_len | ||||
|  | ||||
|         # Effective query len | ||||
|         self.prompt_effective_query_lens_cpu[self.cur_swap_id][0] = prompt_len | ||||
|  | ||||
|         # Get final tensors | ||||
|         input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device) | ||||
|         input_positions = self.input_positions_cpu[ | ||||
|             self.cur_swap_id][:padded_prompt_len].reshape(1, | ||||
|                                                           -1).to(self.device) | ||||
|         slot_mapping = self.slot_mapping_cpu[ | ||||
|             self.cur_swap_id][:padded_prompt_len].reshape(1, | ||||
|                                                           -1).to(self.device) | ||||
|         block_table = block_table_cpu.reshape(1, -1).to( | ||||
|             self.device) if block_table_cpu is not None else None | ||||
|  | ||||
|         context_lens = self.prompt_context_lens_cpu[self.cur_swap_id].to( | ||||
|             self.device) | ||||
|         effective_query_lens = self.prompt_effective_query_lens_cpu[ | ||||
|             self.cur_swap_id].to(self.device) | ||||
|  | ||||
|         self.swap_step() | ||||
|  | ||||
|         # DEBUG | ||||
|         # print("    input_tokens.shape = {} val = {}".format( | ||||
|         #     input_tokens.shape, input_tokens)) | ||||
|         # print("    input_positions.shape = {} val = {}".format( | ||||
|         #     input_positions.shape, input_positions)) | ||||
|         # print("    slot_mapping.shape = {} val = {}".format( | ||||
|         #     slot_mapping.shape, slot_mapping)) | ||||
|         # print("    block_table = {}".format(block_table)) | ||||
|         # print("    context_lens.shape = {} val = {}".format( | ||||
|         #     context_lens.shape, context_lens)) | ||||
|         # print("    effective_query_lens.shape = {} val = {}".format( | ||||
|         #     effective_query_lens.shape, effective_query_lens)) | ||||
|  | ||||
|         # Attn metadata | ||||
|         attn_metadata = PallasMetadata( | ||||
|             num_prefills=1, | ||||
|             num_prefill_tokens=0,  # NOTE: This is not used. | ||||
|             num_decode_tokens=0, | ||||
|             slot_mapping=slot_mapping, | ||||
|             multi_modal_placeholder_index_maps=None, | ||||
|             enable_kv_scales_calculation=True, | ||||
|             block_tables=block_table, | ||||
|             context_lens=context_lens, | ||||
|             effective_query_lens=effective_query_lens, | ||||
|         ) | ||||
|  | ||||
|         return PromptData(input_tokens, input_positions, attn_metadata) | ||||
|  | ||||
|     def _prepare_decode( | ||||
|         self, | ||||
|         decode_req_ids: List[str], | ||||
|     ) -> DecodeData: | ||||
|         # Batch size | ||||
|         batch_size = len(decode_req_ids) | ||||
|         padded_batch_size = _get_padded_batch_size(batch_size) | ||||
|         assert padded_batch_size <= self.max_model_len | ||||
|  | ||||
|         # Init [0 .. batch_size - 1] | ||||
|         req_indices_np = self.arange_np[:padded_batch_size] | ||||
|  | ||||
|         # DEBUG | ||||
|         # print("_prepare_decode:") | ||||
|         # print("    batch_size = {}".format(batch_size)) | ||||
|         # print("    padded_batch_size = {}".format(padded_batch_size)) | ||||
|         # print("    req_indices_np.shape = {} val = {}".format( | ||||
|         #     req_indices_np.shape, req_indices_np)) | ||||
|  | ||||
|         # Input positions | ||||
|         input_positions_np = self.input_positions_np[ | ||||
|             self.cur_swap_id][:padded_batch_size] | ||||
|         np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size], | ||||
|                0, | ||||
|                out=input_positions_np) | ||||
|         input_positions_np[batch_size:] = 0 | ||||
|         input_positions_cpu = self.input_positions_cpu[ | ||||
|             self.cur_swap_id][:padded_batch_size] | ||||
|  | ||||
|         # DEBUG | ||||
|         # print("    input_positions_cpu.shape = {} data = {}".format( | ||||
|         #     input_positions_cpu.shape, input_positions_cpu)) | ||||
|  | ||||
|         # Input tokens | ||||
|         token_indices_np = ( | ||||
|             input_positions_np + | ||||
|             req_indices_np * self.input_batch.token_ids_cpu.shape[1]) | ||||
|         input_tokens_cpu = self.input_ids_cpu[ | ||||
|             self.cur_swap_id][:padded_batch_size] | ||||
|         torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), | ||||
|                            0, | ||||
|                            torch.from_numpy(token_indices_np), | ||||
|                            out=input_tokens_cpu) | ||||
|         input_tokens_cpu[batch_size:] = 0 | ||||
|  | ||||
|         # DEBUG | ||||
|         # print("    token_indices_np.shape = {} val = {}".format( | ||||
|         #     token_indices_np.shape, token_indices_np)) | ||||
|         # print("    input_tokens_cpu.shape = {} data = {}".format( | ||||
|         #     input_tokens_cpu.shape, input_tokens_cpu)) | ||||
|  | ||||
|         # Slot mapping | ||||
|         block_table_indices_np = ( | ||||
|             req_indices_np * self.max_num_blocks_per_req + | ||||
|             input_positions_np // self.block_size) | ||||
|  | ||||
|         # DEBUG | ||||
|         # print( | ||||
|         #     "    block_table_indices_np.shape = {} data = {} max_num_blocks_per_req = {}" | ||||
|         #     .format(block_table_indices_np.shape, block_table_indices_np, | ||||
|         #             self.max_num_blocks_per_req)) | ||||
|  | ||||
|         block_table_cpu = self.input_batch.block_table.get_cpu_tensor() | ||||
|  | ||||
|         # DEBUG | ||||
|         # print("    block_table_cpu.shape = {} data = {}".format( | ||||
|         #     block_table_cpu.shape, block_table_cpu[:padded_batch_size, :10])) | ||||
|  | ||||
|         block_numbers_np = block_table_cpu.flatten( | ||||
|         )[block_table_indices_np].numpy() | ||||
|  | ||||
|         # DEBUG | ||||
|         # print("    block_numbers_np.shape = {} data = {}".format( | ||||
|         #     block_numbers_np.shape, block_numbers_np)) | ||||
|  | ||||
|         block_offsets_np = input_positions_np % self.block_size | ||||
|  | ||||
|         # DEBUG | ||||
|         # print("    block_offsets_np.shape = {} data = {}".format( | ||||
|         #     block_offsets_np.shape, block_offsets_np)) | ||||
|  | ||||
|         slot_mapping_np = self.slot_mapping_np[ | ||||
|             self.cur_swap_id][:padded_batch_size] | ||||
|         np.add(block_numbers_np * self.block_size, | ||||
|                block_offsets_np, | ||||
|                out=slot_mapping_np) | ||||
|         slot_mapping_np[batch_size:] = _PAD_SLOT_ID | ||||
|  | ||||
|         # DEBUG | ||||
|         # print("    slot_mapping_np.shape = {} data = {}".format( | ||||
|         #     slot_mapping_np.shape, slot_mapping_np)) | ||||
|  | ||||
|         block_table_cpu = block_table_cpu[:padded_batch_size] | ||||
|  | ||||
|         # Context lens | ||||
|         context_lens_np = self.decode_context_lens_np[ | ||||
|             self.cur_swap_id][:padded_batch_size] | ||||
|         np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size], | ||||
|                1, | ||||
|                out=context_lens_np) | ||||
|         context_lens_np[batch_size:] = 0 | ||||
|  | ||||
|         # Get final tensors | ||||
|         input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device) | ||||
|         input_positions = input_positions_cpu.reshape(-1, 1).to(self.device) | ||||
|         slot_mapping = self.slot_mapping_cpu[ | ||||
|             self.cur_swap_id][:padded_batch_size].reshape(-1, | ||||
|                                                           1).to(self.device) | ||||
|         block_table = block_table_cpu.to(self.device) | ||||
|         context_lens = self.decode_context_lens_cpu[ | ||||
|             self.cur_swap_id][:padded_batch_size].to(self.device) | ||||
|  | ||||
|         self.swap_step() | ||||
|  | ||||
|         # DEBUG | ||||
|         # print("    context_lens.shape = {} val = {}".format( | ||||
|         #     context_lens.shape, context_lens)) | ||||
|  | ||||
|         # Attn metadata | ||||
|         attn_metadata = PallasMetadata( | ||||
|             num_prefills=0, | ||||
|             num_prefill_tokens=0, | ||||
|             num_decode_tokens=padded_batch_size, | ||||
|             slot_mapping=slot_mapping, | ||||
|             multi_modal_placeholder_index_maps=None, | ||||
|             enable_kv_scales_calculation=True, | ||||
|             block_tables=block_table, | ||||
|             context_lens=context_lens, | ||||
|             effective_query_lens=None, | ||||
|         ) | ||||
|  | ||||
|         return DecodeData(input_tokens=input_tokens, | ||||
|                           input_positions=input_positions, | ||||
|                           attn_metadata=attn_metadata) | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def execute_model( | ||||
|         self, | ||||
|         scheduler_output: "SchedulerOutput", | ||||
|     ) -> ModelRunnerOutput: | ||||
|         # Update cached state | ||||
|         self.update_states(scheduler_output) | ||||
|  | ||||
|         # If necessary, swap decodes/prompts to have all decodes on the start | ||||
|         ensure_decodes_first(self.input_batch) | ||||
|  | ||||
|         # Prepare prompts/decodes info | ||||
|         pd_info = self._get_prompts_and_decodes(scheduler_output) | ||||
|  | ||||
|         # Init | ||||
|         num_prompts = len(pd_info.prompt_req_ids) | ||||
|         num_decodes = len(pd_info.decode_req_ids) | ||||
|         decode_data = None | ||||
|         sampled_token_ids = [0] * self.input_batch.num_reqs | ||||
|  | ||||
|         # Run each prompt individually | ||||
|         is_first = True | ||||
|         for i in range(num_prompts): | ||||
|             req_id = pd_info.prompt_req_ids[i] | ||||
|             req_index = num_decodes + i | ||||
|             assert req_index == self.input_batch.req_id_to_index[ | ||||
|                 req_id]  # TODO: Remove | ||||
|             req_state = self.requests[req_id] | ||||
|             num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i] | ||||
|             prompt_len = num_scheduled_tokens | ||||
|             seq_len = req_state.num_computed_tokens + num_scheduled_tokens | ||||
|  | ||||
|             # Prepare first prompt | ||||
|             if is_first: | ||||
|                 prompt_data = self._prepare_prompt(req_index, | ||||
|                                                    num_scheduled_tokens) | ||||
|                 is_first = False | ||||
|  | ||||
|             # Run forward pass | ||||
|             with set_forward_context(prompt_data.attn_metadata, | ||||
|                                      self.vllm_config): | ||||
|                 assert self.model is not None | ||||
|                 selected_token_ids = self.model(prompt_data.input_tokens, | ||||
|                                                 prompt_data.input_positions, | ||||
|                                                 prompt_data.attn_metadata, | ||||
|                                                 self.kv_caches) | ||||
|  | ||||
|             # In parallel to TPU execution, prepare the next iteration | ||||
|             if i < num_prompts - 1: | ||||
|                 # There is next prompt => prepare it | ||||
|                 prompt_data = self._prepare_prompt( | ||||
|                     req_index + 1, pd_info.prompt_scheduled_tokens[i + 1]) | ||||
|             elif i == num_prompts - 1 and num_decodes > 0: | ||||
|                 # There is next decode => prepare it | ||||
|                 decode_data = self._prepare_decode(pd_info.decode_req_ids) | ||||
|  | ||||
|             # Update cached state (if prompt is fully done) | ||||
|             if seq_len >= len(req_state.prompt_token_ids): | ||||
|                 # Transfer sampled tokens from TPU to CPU | ||||
|                 selected_token_ids_cpu = selected_token_ids.cpu() | ||||
|  | ||||
|                 # Get output token | ||||
|                 token_id = selected_token_ids_cpu[prompt_len - 1].item() | ||||
|                 sampled_token_ids[req_index] = token_id | ||||
|  | ||||
|                 # DEBUG | ||||
|                 # print( | ||||
|                 #     "    -- Got token_id = {} for prompt_len = {} req_id = {} req_index = {} selected_token_ids_cpu = {}" | ||||
|                 #     .format(token_id, prompt_len, req_id, req_index, | ||||
|                 #             selected_token_ids_cpu)) | ||||
|  | ||||
|                 # Add output token to the request | ||||
|                 self.input_batch.token_ids_cpu[req_index, seq_len] = token_id | ||||
|                 self.input_batch.num_tokens[req_index] += 1 | ||||
|                 req_state.output_token_ids.append(token_id) | ||||
|  | ||||
|         # Run decodes (a single batch) | ||||
|         if num_decodes > 0: | ||||
|  | ||||
|             # Prepare decode (if was not yet prepared) | ||||
|             if decode_data is None: | ||||
|                 decode_data = self._prepare_decode(pd_info.decode_req_ids) | ||||
|  | ||||
|             # Run forward pass | ||||
|             with set_forward_context(decode_data.attn_metadata, | ||||
|                                      self.vllm_config): | ||||
|                 assert self.model is not None | ||||
|                 selected_token_ids = self.model(decode_data.input_tokens, | ||||
|                                                 decode_data.input_positions, | ||||
|                                                 decode_data.attn_metadata, | ||||
|                                                 self.kv_caches) | ||||
|  | ||||
|             # Transfer sampled tokens from TPU to CPU | ||||
|             decode_token_ids_cpu = selected_token_ids.cpu() | ||||
|             # Convert to list | ||||
|             decode_token_ids_list = decode_token_ids_cpu.tolist() | ||||
|  | ||||
|             # Update cached state for each decode request | ||||
|             for i in range(num_decodes): | ||||
|                 req_id = pd_info.decode_req_ids[i] | ||||
|                 req_index = i | ||||
|                 assert req_index == self.input_batch.req_id_to_index[ | ||||
|                     req_id]  # TODO: Remove | ||||
|                 req_state = self.requests[req_id] | ||||
|                 seq_len = req_state.num_computed_tokens + 1 | ||||
|  | ||||
|                 token_id = decode_token_ids_list[i] | ||||
|                 sampled_token_ids[req_index] = token_id | ||||
|  | ||||
|                 self.input_batch.token_ids_cpu[req_index, seq_len] = token_id | ||||
|                 self.input_batch.num_tokens[req_index] += 1 | ||||
|                 req_state.output_token_ids.append(token_id) | ||||
|  | ||||
|         # Create output | ||||
|         model_runner_output = ModelRunnerOutput( | ||||
|             req_ids=self.input_batch.req_ids, | ||||
|             req_id_to_index=self.input_batch.req_id_to_index, | ||||
|             sampled_token_ids=sampled_token_ids, | ||||
|             logprob_token_ids_cpu=None, | ||||
|             logprobs_cpu=None, | ||||
|         ) | ||||
|  | ||||
|         return model_runner_output | ||||
|  | ||||
|     def load_model(self) -> None: | ||||
|         self.device = self.device_config.device | ||||
|  | ||||
|         # NOTE(woosuk): While the executor assigns the TP ranks to the worker | ||||
|         # process, the ranks can be different from the ranks internally assigned | ||||
|         # by the xm runtime. Therefore, there is a mismatch in the rank | ||||
|         # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. | ||||
|         # This is not a problem in linear layers because all-reduce is | ||||
|         # rank-agnostic. However, it matters for all-gather as the ranks | ||||
|         # determine the order of concatenating the output tensors. | ||||
|         # As a workaround, we use the xm's rank assignment only when loading | ||||
|         # the embedding weights. | ||||
|         xm_tp_rank = xr.global_ordinal() | ||||
|         with patch( | ||||
|                 "vllm.model_executor.layers.vocab_parallel_embedding." | ||||
|                 "get_tensor_model_parallel_rank", | ||||
|                 return_value=xm_tp_rank): | ||||
|             model = get_model(vllm_config=self.vllm_config) | ||||
|         model = model.eval() | ||||
|         xm.wait_device_ops() | ||||
|         model = ModelWrapperV1(model) | ||||
|         self.model = torch.compile(model, | ||||
|                                    backend="openxla", | ||||
|                                    fullgraph=True, | ||||
|                                    dynamic=False) | ||||
|  | ||||
|     def dummy_run( | ||||
|         self, | ||||
|         kv_caches, | ||||
|         num_tokens: int, | ||||
|         seq_len: Optional[int] = None, | ||||
|         exec_mode: Optional[ExecutionMode] = None, | ||||
|     ) -> None: | ||||
|         assert seq_len is not None | ||||
|         assert exec_mode is not None | ||||
|  | ||||
|         exec_mode = ExecutionMode(exec_mode) | ||||
|         if exec_mode.is_prefill(): | ||||
|             seq_len = (seq_len + 15) // 16 * 16 | ||||
|             token_ids = torch.zeros((num_tokens, seq_len), | ||||
|                                     dtype=torch.int32, | ||||
|                                     device=self.device) | ||||
|             position_ids = torch.zeros((num_tokens, seq_len), | ||||
|                                        dtype=torch.int32, | ||||
|                                        device=self.device) | ||||
|             slot_mapping = torch.zeros((num_tokens, seq_len), | ||||
|                                        dtype=torch.int64, | ||||
|                                        device=self.device) | ||||
|             if exec_mode == ExecutionMode.PREFILL: | ||||
|                 attn_metadata = PallasMetadata( | ||||
|                     num_prefills=num_tokens, | ||||
|                     num_prefill_tokens=num_tokens * seq_len, | ||||
|                     num_decode_tokens=0, | ||||
|                     slot_mapping=slot_mapping, | ||||
|                     multi_modal_placeholder_index_maps=None, | ||||
|                     enable_kv_scales_calculation=True, | ||||
|                     block_tables=None, | ||||
|                     context_lens=None, | ||||
|                     effective_query_lens=None, | ||||
|                 ) | ||||
|  | ||||
|             else: | ||||
|                 context_lens = torch.ones((num_tokens, ), | ||||
|                                           dtype=torch.int32, | ||||
|                                           device=self.device) | ||||
|  | ||||
|                 block_tables = torch.zeros( | ||||
|                     (num_tokens, self.max_num_blocks_per_req), | ||||
|                     dtype=torch.int32, | ||||
|                     device=self.device) | ||||
|  | ||||
|                 effective_query_lens = torch.ones_like(context_lens) | ||||
|  | ||||
|                 attn_metadata = PallasMetadata( | ||||
|                     num_prefills=num_tokens, | ||||
|                     num_prefill_tokens=num_tokens * seq_len, | ||||
|                     num_decode_tokens=0, | ||||
|                     slot_mapping=slot_mapping, | ||||
|                     multi_modal_placeholder_index_maps=None, | ||||
|                     enable_kv_scales_calculation=True, | ||||
|                     block_tables=block_tables, | ||||
|                     context_lens=context_lens, | ||||
|                     effective_query_lens=effective_query_lens, | ||||
|                 ) | ||||
|         else: | ||||
|             assert seq_len == 1 | ||||
|             token_ids = torch.zeros((num_tokens, seq_len), | ||||
|                                     dtype=torch.int32, | ||||
|                                     device=self.device) | ||||
|             position_ids = torch.zeros((num_tokens, seq_len), | ||||
|                                        dtype=torch.int32, | ||||
|                                        device=self.device) | ||||
|             slot_mapping = torch.zeros((num_tokens, seq_len), | ||||
|                                        dtype=torch.int64, | ||||
|                                        device=self.device) | ||||
|             block_tables = torch.zeros( | ||||
|                 (num_tokens, self.max_num_blocks_per_req), | ||||
|                 dtype=torch.int32, | ||||
|                 device=self.device) | ||||
|             context_lens = torch.ones((num_tokens, ), | ||||
|                                       dtype=torch.int32, | ||||
|                                       device=self.device) | ||||
|             attn_metadata = PallasMetadata( | ||||
|                 num_prefills=0, | ||||
|                 num_prefill_tokens=0, | ||||
|                 num_decode_tokens=num_tokens * seq_len, | ||||
|                 slot_mapping=slot_mapping, | ||||
|                 multi_modal_placeholder_index_maps=None, | ||||
|                 enable_kv_scales_calculation=True, | ||||
|                 block_tables=block_tables, | ||||
|                 context_lens=context_lens, | ||||
|             ) | ||||
|  | ||||
|         # NOTE(woosuk): There are two stages of compilation: torch.compile and | ||||
|         # XLA compilation. Using `mark_dynamic` can reduce the torch.compile | ||||
|         # overhead by reusing the FX graph for different shapes. | ||||
|         # However, the XLA graph will still require static shapes and needs to | ||||
|         # be re-compiled for every different shapes. This overhead is inevitable | ||||
|         # in the first run, but can be skipped afterwards as we cache the XLA | ||||
|         # graphs in the disk (VLLM_XLA_CACHE_PATH). | ||||
|         if exec_mode.is_prefill(): | ||||
|             # Prefll | ||||
|             torch._dynamo.mark_dynamic(token_ids, 1) | ||||
|             torch._dynamo.mark_dynamic(position_ids, 1) | ||||
|             torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) | ||||
|         else: | ||||
|             # Decode | ||||
|             torch._dynamo.mark_dynamic(token_ids, 0) | ||||
|             torch._dynamo.mark_dynamic(position_ids, 0) | ||||
|             torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) | ||||
|             torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) | ||||
|             torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) | ||||
|  | ||||
|         with set_forward_context(attn_metadata, self.vllm_config, 0): | ||||
|             assert self.model is not None | ||||
|             self.model(token_ids, position_ids, attn_metadata, kv_caches) | ||||
|  | ||||
|     def capture_model(self) -> None: | ||||
|         """Compile the model.""" | ||||
|  | ||||
|         # Prefill | ||||
|         logger.info( | ||||
|             "Compiling the model with different input shapes for prefill:") | ||||
|         start = time.time() | ||||
|         for batch_size in [1]: | ||||
|             seq_len = 16 | ||||
|             while seq_len <= self.model_config.max_model_len: | ||||
|                 self.dummy_run(self.kv_caches, | ||||
|                                batch_size, | ||||
|                                seq_len, | ||||
|                                exec_mode=ExecutionMode.PREFILL) | ||||
|                 xm.wait_device_ops() | ||||
|                 logger.info("  batch_size: %d, seq_len: %d", batch_size, | ||||
|                             seq_len) | ||||
|                 num_tokens = batch_size * seq_len | ||||
|                 if num_tokens >= self.scheduler_config.max_num_batched_tokens: | ||||
|                     break | ||||
|                 seq_len = seq_len * 2 | ||||
|  | ||||
|         end = time.time() | ||||
|         logger.info("    -- Compilation for prefill done in %.2f [secs].", | ||||
|                     end - start) | ||||
|  | ||||
|         # Prefix prefill | ||||
|         if self.scheduler_config.enable_chunked_prefill: | ||||
|             logger.info("Compiling the model with different input shapes for " | ||||
|                         "prefix prefill:") | ||||
|             start = time.time() | ||||
|             for batch_size in [1]: | ||||
|                 seq_len = 16 | ||||
|                 while seq_len <= self.model_config.max_model_len: | ||||
|                     self.dummy_run(self.kv_caches, | ||||
|                                    batch_size, | ||||
|                                    seq_len, | ||||
|                                    exec_mode=ExecutionMode.PREFIX_PREFILL) | ||||
|                     xm.wait_device_ops() | ||||
|                     logger.info("  batch_size: %d, seq_len: %d", batch_size, | ||||
|                                 seq_len) | ||||
|                     num_tokens = batch_size * seq_len | ||||
|                     if (num_tokens | ||||
|                             >= self.scheduler_config.max_num_batched_tokens): | ||||
|                         break | ||||
|                     seq_len = seq_len * 2 | ||||
|             end = time.time() | ||||
|             logger.info( | ||||
|                 "    -- Compilation for prefix prefill done in %.2f [secs].", | ||||
|                 end - start) | ||||
|  | ||||
|         # Decode | ||||
|         logger.info( | ||||
|             "Compiling the model with different input shapes for decode:") | ||||
|         start = time.time() | ||||
|         seq_len = 1 | ||||
|         batch_size = 8  # Must be in sync with _get_padded_batch_size() | ||||
|         while True: | ||||
|             self.dummy_run(self.kv_caches, | ||||
|                            batch_size, | ||||
|                            seq_len, | ||||
|                            exec_mode=ExecutionMode.DECODE) | ||||
|             xm.wait_device_ops() | ||||
|             logger.info("  batch_size: %d, seq_len: %d", batch_size, seq_len) | ||||
|  | ||||
|             if batch_size >= self.scheduler_config.max_num_seqs: | ||||
|                 break | ||||
|             batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 | ||||
|  | ||||
|         end = time.time() | ||||
|         logger.info("    -- Compilation for decode done in %.2f [secs].", | ||||
|                     end - start) | ||||
|  | ||||
|     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: | ||||
|         """ | ||||
|         Initialize KV cache based on `kv_cache_config`. | ||||
|         Args: | ||||
|             kv_cache_config: Configuration for the KV cache, including the KV  | ||||
|             cache size of each layer | ||||
|         """ | ||||
|         if len(kv_cache_config.groups) > 1: | ||||
|             raise NotImplementedError( | ||||
|                 "Hybrid models with more than one KV cache type are not " | ||||
|                 "supported yet.") | ||||
|  | ||||
|         kv_caches: Dict[str, torch.Tensor] = {} | ||||
|  | ||||
|         for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): | ||||
|             tensor_config = kv_cache_config.tensors[layer_name] | ||||
|             assert tensor_config.size % layer_spec.page_size_bytes == 0 | ||||
|             num_blocks = tensor_config.size // layer_spec.page_size_bytes | ||||
|             if isinstance(layer_spec, FullAttentionSpec): | ||||
|                 kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( | ||||
|                     num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, | ||||
|                     layer_spec.head_size) | ||||
|                 dtype = layer_spec.dtype | ||||
|  | ||||
|                 tpu_k_cache = torch.zeros(kv_cache_shape, | ||||
|                                           dtype=dtype, | ||||
|                                           device=self.device) | ||||
|                 tpu_v_cache = torch.zeros_like(tpu_k_cache) | ||||
|  | ||||
|                 kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache) | ||||
|             else: | ||||
|                 raise NotImplementedError | ||||
|  | ||||
|         bind_kv_cache( | ||||
|             kv_caches, | ||||
|             self.vllm_config.compilation_config.static_forward_context, | ||||
|             self.kv_caches) | ||||
|  | ||||
|  | ||||
| class ModelWrapperV1(nn.Module): | ||||
|  | ||||
|     def __init__(self, model: nn.Module): | ||||
|         super().__init__() | ||||
|         self.model = model | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         token_ids: torch.Tensor, | ||||
|         position_ids: torch.Tensor, | ||||
|         attn_metadata: AttentionMetadata, | ||||
|         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], | ||||
|     ) -> torch.Tensor: | ||||
|         """Executes the forward pass of the model and samples the next token. | ||||
|  | ||||
|         Args: | ||||
|             token_ids: The input token IDs of shape [batch_size, seq_len]. | ||||
|             position_ids: The input position IDs of shape [batch_size, seq_len]. | ||||
|             attn_metadata: The Pallas attention metadata. | ||||
|             input_lens: The actual input lengths of shape [batch_size]. | ||||
|             t: The sampling temperature of shape [batch_size]. | ||||
|             p: The top-p probability of shape [batch_size]. | ||||
|             num_samples: Number of samples to draw from each logits vector. | ||||
|             kv_caches: The key and value caches. They can be None during the | ||||
|                 memory profiling at initialization. | ||||
|         """ | ||||
|         # Skip this in memory profiling at initialization. | ||||
|         if attn_metadata is not None and kv_caches[0][0].numel() > 0: | ||||
|             # index_copy_(slot_mapping) only works when the inserted dimension | ||||
|             # is 0. However, the KV cache in the Pallas backend has the shape | ||||
|             # [num_kv_heads, num_blocks, block_size, head_size]. To make it | ||||
|             # work, we need to flatten the first three dimensions and modify | ||||
|             # the slot_mapping accordingly. | ||||
|             num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape | ||||
|             slot_mapping = attn_metadata.slot_mapping | ||||
|             slot_mapping = slot_mapping.flatten() | ||||
|             head_indicies = torch.arange(0, | ||||
|                                          num_kv_heads, | ||||
|                                          device=slot_mapping.device, | ||||
|                                          dtype=slot_mapping.dtype) | ||||
|             head_indicies *= block_size * num_blocks | ||||
|             slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( | ||||
|                 -1, num_kv_heads) | ||||
|             slot_mapping = slot_mapping + head_indicies.view(1, -1) | ||||
|             slot_mapping = slot_mapping.flatten() | ||||
|             attn_metadata.slot_mapping = slot_mapping | ||||
|  | ||||
|         assert self.model is not None | ||||
|         hidden_states = self.model( | ||||
|             token_ids, | ||||
|             position_ids, | ||||
|             kv_caches, | ||||
|             attn_metadata, | ||||
|         ) | ||||
|  | ||||
|         hidden_states = hidden_states.flatten(0, 1) | ||||
|         logits = self.model.compute_logits(hidden_states, None) | ||||
|  | ||||
|         # Greedy sampling. | ||||
|         argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) | ||||
|         argmax_token_ids = argmax_token_ids.squeeze(dim=-1) | ||||
|         return argmax_token_ids | ||||
|  | ||||
|  | ||||
| def _get_padded_prompt_len(x: int) -> int: | ||||
|     # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence | ||||
|     # length to be a multiple of 16. We pad the prompt length to the nearest | ||||
|     # multiple of 16. This is also good for performance. | ||||
|     if x <= 16: | ||||
|         return 16 | ||||
|     return 1 << (x - 1).bit_length() | ||||
|  | ||||
|  | ||||
| def _get_padded_batch_size(batch_size: int) -> int: | ||||
|     # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. | ||||
|     # To meet this requirement in the simplest way, we set the minimal batch | ||||
|     # size to 8. | ||||
|     if batch_size <= 8: | ||||
|         return 8 | ||||
|     else: | ||||
|         return ((batch_size + 15) // 16) * 16 | ||||
							
								
								
									
										153
									
								
								vllm/v1/worker/tpu_worker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										153
									
								
								vllm/v1/worker/tpu_worker.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,153 @@ | ||||
| """A TPU worker class.""" | ||||
| import os | ||||
| from typing import Optional, Dict | ||||
|  | ||||
| import torch | ||||
| import torch.distributed | ||||
| import torch_xla.core.xla_model as xm | ||||
| import torch_xla.runtime as xr | ||||
|  | ||||
| import vllm.envs as envs | ||||
| from vllm.config import ParallelConfig, VllmConfig | ||||
| from vllm.distributed import (ensure_model_parallel_initialized, | ||||
|                               init_distributed_environment) | ||||
| from vllm.logger import init_logger | ||||
| from vllm.model_executor import set_random_seed | ||||
| from vllm.v1.kv_cache_interface import FullAttentionSpec | ||||
| from vllm.v1.attention.backends.pallas import PallasAttentionBackend | ||||
| from vllm.v1.core.scheduler import SchedulerOutput | ||||
| from vllm.v1.outputs import ModelRunnerOutput | ||||
| from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner | ||||
| from vllm.v1.worker.worker_base import WorkerBase | ||||
| from vllm.v1.utils import bind_kv_cache | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
|  | ||||
| class TPUWorker(WorkerBase): | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         vllm_config: VllmConfig, | ||||
|         local_rank: int, | ||||
|         rank: int, | ||||
|         distributed_init_method: str, | ||||
|         is_driver_worker: bool = False, | ||||
|     ): | ||||
|         super().__init__(vllm_config, local_rank, rank, | ||||
|                          distributed_init_method) | ||||
|  | ||||
|     def init_device(self): | ||||
|         os.environ["PJRT_DEVICE"] = "TPU" | ||||
|         torch.set_grad_enabled(False) | ||||
|         torch.set_default_dtype(self.model_config.dtype) | ||||
|  | ||||
|         # Initialize the distributed environment. | ||||
|         init_tpu_worker_distributed_environment(self.parallel_config, | ||||
|                                                 self.rank, | ||||
|                                                 self.distributed_init_method, | ||||
|                                                 self.local_rank) | ||||
|  | ||||
|         # Device initialization should happen after initializing | ||||
|         # the distributed runtime. | ||||
|         self.device = xm.xla_device() | ||||
|         self.device_config.device = self.device | ||||
|  | ||||
|         # Set random seed. | ||||
|         set_random_seed(self.model_config.seed) | ||||
|         xm.set_rng_state(self.model_config.seed, self.device) | ||||
|  | ||||
|         # Increase the cache size limit, which is the maximum number of | ||||
|         # dynamo graphs that can be compiled. | ||||
|         # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and | ||||
|         # 30-40 graphs for decode. 128 is an arbitrary safe number. | ||||
|         torch._dynamo.config.cache_size_limit = 128 | ||||
|         # Use persistent cache to avoid XLA recompilation. | ||||
|         # NOTE(woosuk): Set per-rank cache path since different ranks | ||||
|         # can have slightly different XLA graphs. | ||||
|         world_size = self.parallel_config.world_size | ||||
|         rank = xr.global_ordinal() | ||||
|         per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, | ||||
|                                      f"tp{world_size}_rank{rank}") | ||||
|         xr.initialize_cache(per_rank_path, readonly=False) | ||||
|  | ||||
|         # Init ModelRunner here, so that we have access to self.device. | ||||
|         self.model_runner = TPUModelRunner(self.vllm_config, self.device) | ||||
|  | ||||
|     def determine_available_memory(self) -> int: | ||||
|         assert self.model_runner is not None | ||||
|  | ||||
|         kv_caches: Dict[str, torch.Tensor] = {} | ||||
|         kv_cache_spec = self.model_runner.get_kv_cache_spec() | ||||
|         for layer_name, layer_spec in kv_cache_spec.items(): | ||||
|             if isinstance(layer_spec, FullAttentionSpec): | ||||
|                 dtype = layer_spec.dtype | ||||
|  | ||||
|                 # Use an empty tensor instead of `None`` to force Dynamo to pass | ||||
|                 # it by reference, rather by specializing on the value ``None``. | ||||
|                 tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device) | ||||
|                 tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device) | ||||
|  | ||||
|                 kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache) | ||||
|             else: | ||||
|                 raise NotImplementedError | ||||
|  | ||||
|         runner_kv_caches = [] | ||||
|         bind_kv_cache( | ||||
|             kv_caches, | ||||
|             self.vllm_config.compilation_config.static_forward_context, | ||||
|             runner_kv_caches) | ||||
|  | ||||
|         self.model_runner.dummy_run( | ||||
|             runner_kv_caches, | ||||
|             num_tokens=1, | ||||
|             seq_len=self.scheduler_config.max_num_batched_tokens, | ||||
|             exec_mode=ExecutionMode.PREFILL, | ||||
|         ) | ||||
|  | ||||
|         # Synchronize before measuring the memory usage. | ||||
|         xm.wait_device_ops() | ||||
|  | ||||
|         # Get the maximum amount of memory used by the model weights and | ||||
|         # intermediate activations. | ||||
|         m = xm.get_memory_info(self.device) | ||||
|         total_memory_size = m["bytes_limit"] | ||||
|         profiled = m["peak_bytes_used"]  # Weights + intermediate activations. | ||||
|  | ||||
|         # Calculate the TPU KV cache size based on profiling. | ||||
|         usable_memory_size = int(total_memory_size * | ||||
|                                  self.cache_config.gpu_memory_utilization) | ||||
|         tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) | ||||
|  | ||||
|         return int(tpu_kv_cache_bytes) | ||||
|  | ||||
|     def execute_model( | ||||
|         self, | ||||
|         scheduler_output: "SchedulerOutput", | ||||
|     ) -> Optional[ModelRunnerOutput]: | ||||
|         assert self.model_runner is not None | ||||
|         output = self.model_runner.execute_model(scheduler_output) | ||||
|         return output if self.rank == 0 else None | ||||
|  | ||||
|  | ||||
| def init_tpu_worker_distributed_environment( | ||||
|     parallel_config: ParallelConfig, | ||||
|     rank: int, | ||||
|     distributed_init_method: Optional[str] = None, | ||||
|     local_rank: int = -1, | ||||
| ) -> None: | ||||
|     """Initialize the distributed environment.""" | ||||
|  | ||||
|     # NOTE(woosuk): This is just to initialize the TP group and broadcast | ||||
|     # the input objects on CPU. The all-reduce and all-gather ops on TPU | ||||
|     # are invoked by `xm.all_reduce` and `xm.all_gather` which use their | ||||
|     # own context. | ||||
|     init_distributed_environment( | ||||
|         world_size=parallel_config.world_size, | ||||
|         rank=rank, | ||||
|         local_rank=local_rank, | ||||
|         distributed_init_method=distributed_init_method, | ||||
|         backend="gloo", | ||||
|     ) | ||||
|     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, | ||||
|                                       parallel_config.pipeline_parallel_size) | ||||
							
								
								
									
										173
									
								
								vllm/v1/worker/worker_base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										173
									
								
								vllm/v1/worker/worker_base.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,173 @@ | ||||
| """A GPU worker class.""" | ||||
| from typing import TYPE_CHECKING, Optional | ||||
|  | ||||
| import torch | ||||
| import torch.distributed | ||||
| import torch.nn as nn | ||||
|  | ||||
| import vllm.envs as envs | ||||
| from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig | ||||
| from vllm.logger import init_logger | ||||
| from vllm.model_executor import set_random_seed | ||||
| from vllm.platforms import current_platform | ||||
| from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size | ||||
| from vllm.v1.core.scheduler import SchedulerOutput | ||||
| from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec | ||||
| from vllm.v1.outputs import ModelRunnerOutput | ||||
| from vllm.v1.worker.model_runner_base import ModelRunnerBase | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from vllm.v1.core.scheduler import SchedulerOutput | ||||
|  | ||||
|  | ||||
| class WorkerBase: | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         vllm_config: VllmConfig, | ||||
|         local_rank: int, | ||||
|         rank: int, | ||||
|         distributed_init_method: str, | ||||
|         is_driver_worker: bool = False, | ||||
|     ): | ||||
|         self.vllm_config = vllm_config | ||||
|         self.model_config = vllm_config.model_config | ||||
|         self.cache_config = vllm_config.cache_config | ||||
|         self.lora_config = vllm_config.lora_config | ||||
|         self.load_config = vllm_config.load_config | ||||
|         self.parallel_config = vllm_config.parallel_config | ||||
|         self.scheduler_config = vllm_config.scheduler_config | ||||
|         self.device_config = vllm_config.device_config | ||||
|         self.speculative_config = vllm_config.speculative_config | ||||
|         self.prompt_adapter_config = vllm_config.prompt_adapter_config | ||||
|         self.observability_config = vllm_config.observability_config | ||||
|  | ||||
|         self.parallel_config.rank = rank | ||||
|         self.local_rank = local_rank | ||||
|         self.rank = rank | ||||
|         self.distributed_init_method = distributed_init_method | ||||
|  | ||||
|         if self.cache_config.cache_dtype == "auto": | ||||
|             self.cache_dtype = self.model_config.dtype | ||||
|         else: | ||||
|             self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ | ||||
|                 self.cache_config.cache_dtype] | ||||
|  | ||||
|         if self.model_config.trust_remote_code: | ||||
|             # note: lazy import to avoid importing torch before initializing | ||||
|             from vllm.utils import init_cached_hf_modules | ||||
|             init_cached_hf_modules() | ||||
|  | ||||
|         # Torch profiler. Enabled and configured through env vars: | ||||
|         # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace | ||||
|         if envs.VLLM_TORCH_PROFILER_DIR: | ||||
|             torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR | ||||
|             logger.info("Profiling enabled. Traces will be saved to: %s", | ||||
|                         torch_profiler_trace_dir) | ||||
|             self.profiler = torch.profiler.profile( | ||||
|                 activities=[ | ||||
|                     torch.profiler.ProfilerActivity.CPU, | ||||
|                     torch.profiler.ProfilerActivity.CUDA, | ||||
|                 ], | ||||
|                 with_stack=True, | ||||
|                 on_trace_ready=torch.profiler.tensorboard_trace_handler( | ||||
|                     torch_profiler_trace_dir, use_gzip=True)) | ||||
|         else: | ||||
|             self.profiler = None | ||||
|  | ||||
|         # Initialized by the specific platform | ||||
|         self.model_runner: Optional[ModelRunnerBase] = None | ||||
|  | ||||
|     def load_model(self) -> None: | ||||
|         assert self.model_runner is not None | ||||
|         self.model_runner.load_model() | ||||
|  | ||||
|     def compile_or_warm_up_model(self) -> None: | ||||
|         assert self.model_runner is not None | ||||
|  | ||||
|         if not self.model_config.enforce_eager: | ||||
|             self.model_runner.capture_model() | ||||
|  | ||||
|         # Reset the seed to ensure that the random state is not affected by | ||||
|         # the model initialization and profiling. | ||||
|         set_random_seed(self.model_config.seed) | ||||
|  | ||||
|     def get_model(self) -> nn.Module: | ||||
|         assert self.model_runner is not None | ||||
|         return self.model_runner.get_model() | ||||
|  | ||||
|     def get_kv_cache_spec(self) -> KVCacheSpec: | ||||
|         assert self.model_runner is not None | ||||
|         return self.model_runner.get_kv_cache_spec() | ||||
|  | ||||
|     def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None: | ||||
|         """Allocate GPU KV cache with the specified kv_cache_config.""" | ||||
|         assert self.model_runner is not None | ||||
|         self.model_runner.initialize_kv_cache(kv_cache_config) | ||||
|  | ||||
|     def profile(self, is_start: bool = True): | ||||
|         if self.profiler is None: | ||||
|             raise RuntimeError("Profiler is not enabled.") | ||||
|         if is_start: | ||||
|             self.profiler.start() | ||||
|         else: | ||||
|             self.profiler.stop() | ||||
|  | ||||
|     def check_health(self) -> None: | ||||
|         # worker will always be healthy as long as it's running. | ||||
|         return | ||||
|  | ||||
|     def init_device(self): | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     def determine_available_memory(self) -> int: | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     def execute_model( | ||||
|         self, | ||||
|         scheduler_output: "SchedulerOutput", | ||||
|     ) -> Optional[ModelRunnerOutput]: | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|  | ||||
| def check_if_gpu_supports_dtype(torch_dtype: torch.dtype): | ||||
|     # Check if the GPU supports the dtype. | ||||
|     if torch_dtype == torch.bfloat16:  # noqa: SIM102 | ||||
|         if not current_platform.has_device_capability(80): | ||||
|             capability = current_platform.get_device_capability() | ||||
|             gpu_name = current_platform.get_device_name() | ||||
|  | ||||
|             if capability is None: | ||||
|                 compute_str = "does not have a compute capability" | ||||
|             else: | ||||
|                 version_str = capability.as_version_str() | ||||
|                 compute_str = f"has compute capability {version_str}" | ||||
|  | ||||
|             raise ValueError( | ||||
|                 "Bfloat16 is only supported on GPUs with compute capability " | ||||
|                 f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " | ||||
|                 "You can use float16 instead by explicitly setting the" | ||||
|                 "`dtype` flag in CLI, for example: --dtype=half.") | ||||
|  | ||||
|  | ||||
| def get_cache_block_size( | ||||
|     cache_config: CacheConfig, | ||||
|     model_config: ModelConfig, | ||||
|     parallel_config: ParallelConfig, | ||||
| ) -> int: | ||||
|     head_size = model_config.get_head_size() | ||||
|     num_heads = model_config.get_num_kv_heads(parallel_config) | ||||
|     num_attention_layers = model_config.get_num_layers_by_block_type( | ||||
|         parallel_config, LayerBlockType.attention) | ||||
|  | ||||
|     key_cache_block = cache_config.block_size * num_heads * head_size | ||||
|     value_cache_block = key_cache_block | ||||
|     total = num_attention_layers * (key_cache_block + value_cache_block) | ||||
|     if cache_config.cache_dtype == "auto": | ||||
|         dtype = model_config.dtype | ||||
|     else: | ||||
|         dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] | ||||
|     dtype_size = get_dtype_size(dtype) | ||||
|     return dtype_size * total | ||||
		Reference in New Issue
	
	Block a user
	