mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-20 23:03:52 +08:00 
			
		
		
		
	Compare commits
	
		
			1 Commits
		
	
	
		
			lwilkinson
			...
			v0.9.2
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| a5dd03c1eb | 
| @ -66,10 +66,10 @@ function cpu_tests() { | ||||
|     tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token" | ||||
|  | ||||
|   # Run AWQ test | ||||
|   # docker exec cpu-test-"$NUMA_NODE" bash -c " | ||||
|   #   set -e | ||||
|   #   VLLM_USE_V1=0 pytest -s -v \ | ||||
|   #   tests/quantization/test_ipex_quant.py" | ||||
|   docker exec cpu-test-"$NUMA_NODE" bash -c " | ||||
|     set -e | ||||
|     VLLM_USE_V1=0 pytest -s -v \ | ||||
|     tests/quantization/test_ipex_quant.py" | ||||
|  | ||||
|   # Run chunked-prefill and prefix-cache test | ||||
|   docker exec cpu-test-"$NUMA_NODE" bash -c " | ||||
|  | ||||
| @ -26,5 +26,7 @@ docker run \ | ||||
|     --name "${container_name}" \ | ||||
|     "${image_name}" \ | ||||
|     sh -c ' | ||||
|     VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m | ||||
|     VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2 | ||||
|     VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager | ||||
| ' | ||||
|  | ||||
| @ -8,7 +8,7 @@ image: | ||||
|   # -- Image tag | ||||
|   tag: "latest" | ||||
|   # -- Container launch command | ||||
|   command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--enforce-eager", "--dtype", "bfloat16", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"] | ||||
|   command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--dtype", "float32", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"] | ||||
|  | ||||
| # -- Container port | ||||
| containerPort: 8000 | ||||
|  | ||||
| @ -36,8 +36,7 @@ DEVICE_REGULAR_ATTN_BACKENDS = { | ||||
| DEVICE_MLA_BLOCK_SIZES = { | ||||
|     "cuda": [16, 64],  # CUDA supports both standard and extended block sizes | ||||
|     "hip": [16, 1],  # HIP requires special handling for block_size=1 | ||||
|     # "cpu": [16]  # CPU uses fixed block size from test cases | ||||
|     "cpu": []  # FIXME(woosuk): Temporarily disable CPU tests | ||||
|     "cpu": [16]  # CPU uses fixed block size from test cases | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -82,14 +81,14 @@ def test_env( | ||||
|         m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") | ||||
|  | ||||
|         if device == "cpu": | ||||
|             if not use_v1: | ||||
|                 pytest.skip("CPU backend only supports V1") | ||||
|  | ||||
|             with patch("vllm.attention.selector.current_platform", | ||||
|                        CpuPlatform()): | ||||
|                 backend = get_attn_backend(16, torch.float16, torch.float16, | ||||
|                                            block_size, False) | ||||
|             if use_v1: | ||||
|                 assert backend.get_name() == "TORCH_SDPA_VLLM_V1" | ||||
|             else: | ||||
|                 assert backend.get_name() == "TORCH_SDPA" | ||||
|  | ||||
|         elif device == "hip": | ||||
|             with patch("vllm.attention.selector.current_platform", | ||||
| @ -205,14 +204,12 @@ def test_fp32_fallback( | ||||
|         m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") | ||||
|  | ||||
|         if device == "cpu": | ||||
|             if not use_v1: | ||||
|                 pytest.skip("CPU backend only supports V1") | ||||
|  | ||||
|             with patch("vllm.attention.selector.current_platform", | ||||
|                        CpuPlatform()): | ||||
|                 backend = get_attn_backend(16, torch.float32, torch.float32, | ||||
|                                            16, False) | ||||
|             assert backend.get_name() == "TORCH_SDPA_VLLM_V1" | ||||
|             assert (backend.get_name() == "TORCH_SDPA_VLLM_V1" | ||||
|                     if use_v1 else "TORCH_SDPA") | ||||
|  | ||||
|         elif device == "cuda": | ||||
|             with patch("vllm.attention.selector.current_platform", | ||||
|  | ||||
							
								
								
									
										307
									
								
								vllm/attention/backends/cpu_mla.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										307
									
								
								vllm/attention/backends/cpu_mla.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,307 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
|  | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Dict, List, Optional, Tuple, Type | ||||
|  | ||||
| import torch | ||||
|  | ||||
| import vllm._custom_ops as ops | ||||
| from vllm._ipex_ops import ipex_ops | ||||
| from vllm.attention.backends.abstract import (AttentionBackend, | ||||
|                                               AttentionMetadataBuilder, | ||||
|                                               AttentionType, | ||||
|                                               is_quantized_kv_cache) | ||||
| from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState | ||||
| from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata | ||||
| from vllm.utils import make_tensor_with_pad | ||||
| from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder | ||||
|  | ||||
|  | ||||
class CPUMLABackend(AttentionBackend):
    """Backend descriptor for MLA attention on CPU.

    Every member is a static method: the class only names the companion
    classes (metadata, builder, state, implementation) and provides the
    KV-cache layout and block swap/copy helpers.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "CPU_MLA"

    @staticmethod
    def get_impl_cls() -> Type["CPUMLAImpl"]:
        """Return the attention implementation class."""
        return CPUMLAImpl

    @staticmethod
    def get_metadata_cls() -> Type["CPUMLAMetadata"]:
        """Return the attention metadata class."""
        return CPUMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
        """Return the class that builds the attention metadata."""
        return CPUMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["MLACommonState"]:
        """Return the attention state class."""
        return MLACommonState

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the list of head sizes this backend accepts."""
        return [576]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape ``(num_blocks, block_size, head_size)``.

        ``num_kv_heads`` is accepted for interface compatibility but does
        not contribute to the shape.
        """
        cache_shape = (num_blocks, block_size, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Forward a block swap request to ``ops.swap_blocks``."""
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Forward a block copy request to ``ops.copy_blocks_mla``."""
        ops.copy_blocks_mla(kv_caches, src_to_dists)
|  | ||||
|  | ||||
@dataclass
class CPUMLAMetadata(TorchSDPAMetadata):
    """TorchSDPAMetadata extended with the extra fields used by MLA."""

    # New for MLA
    # Input positions for rotary embeddings, since for MLA the rotary
    # position embeddings are applied inside the attention backend.
    # The default is None, hence the Optional annotation.
    input_positions: Optional[torch.Tensor] = None

    # required by MLACommonImpl
    is_profile_run: bool = False
|  | ||||
|  | ||||
class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
    """Assemble a ``CPUMLAMetadata`` from the data held by a CPU input builder.

    Usage order is ``prepare()`` followed by ``build()``: ``build`` reads
    ``self.input_data``, which only ``prepare`` sets.
    """

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        """Keep a reference to ``input_builder`` and reject chunked prefill."""
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder
        assert not self.chunked_prefill, \
            "chunked prefill is currently not supported"

    def prepare(self):
        """Snapshot the input builder's current ``input_data``."""
        self.input_data = self.input_builder.input_data

    def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
        """Create the attention metadata for the current batch.

        Args:
            seq_lens: per-sequence lengths; the first ``num_prefills``
                entries are treated as the prefill sequences.
            query_lens: per-sequence query lengths, sliced the same way.
            cuda_graph_pad_size: not used by this builder.
            batch_size: not used by this builder.

        Returns:
            A ``CPUMLAMetadata`` whose tensors all live on the CPU.
        """
        input_data = self.input_data
        num_prefills = input_data.num_prefills
        prefill_seq_lens = seq_lens[0:num_prefills]
        prefill_query_lens = query_lens[0:num_prefills]
        slot_mapping = torch.tensor(input_data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # metadata for prefill
        if num_prefills > 0:
            query_lens_tensor = torch.tensor(prefill_query_lens,
                                             dtype=torch.int32,
                                             device="cpu")
            kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                          dtype=torch.int32,
                                          device="cpu")
            # Start offsets: element 0 stays 0, the rest is the running sum
            # of the lengths written in place through ``out=``.
            query_start_loc = torch.zeros(num_prefills + 1,
                                          dtype=torch.int32,
                                          device="cpu")
            kv_start_loc = torch.zeros(num_prefills + 1,
                                       dtype=torch.int32,
                                       device="cpu")
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(kv_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)

            # for chunked-prefill
            # (__init__ asserts chunked prefill is off, so this branch is
            # only reachable when assertions are disabled.)
            if self.chunked_prefill:
                prefill_block_tables = make_tensor_with_pad(
                    input_data.prefill_block_tables,
                    pad=0,
                    dtype=torch.int32,
                    device="cpu",
                )
            else:
                prefill_block_tables = None

        else:
            query_start_loc = None
            kv_start_loc = None
            max_query_len = None
            max_kv_len = None
            prefill_block_tables = None

        # metadata for decode
        if input_data.num_decode_tokens != 0:
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[num_prefills:],
                dtype=torch.int32,
                device="cpu",
            )
            block_tables = make_tensor_with_pad(
                input_data.decode_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
        else:
            # Empty placeholder with the same dtype/device as the real
            # block table (a bare torch.tensor([]) would be float32).
            block_tables = torch.tensor([], dtype=torch.int32, device="cpu")
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[:num_prefills],
                dtype=torch.int32,
                device="cpu",
            )

        # For multi-modal models
        placeholder_index_maps = None
        if len(input_data.multi_modal_inputs_list) != 0:
            placeholder_index_maps = {
                modality: placeholder_map.index_map()
                for modality, placeholder_map in
                input_data.multi_modal_placeholder_maps.items()
            }

        return CPUMLAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            prefill_query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=input_data.max_decode_seq_len,
            num_prefills=num_prefills,
            num_prefill_tokens=input_data.num_prefill_tokens,
            num_decode_tokens=input_data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            # The positions are wrapped in an outer list, which adds a
            # leading dimension of size 1 (assuming input_positions is a
            # flat sequence - TODO confirm against the input builder).
            input_positions=torch.tensor([input_data.input_positions]))
|  | ||||
|  | ||||
class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
    """MLA attention implementation for CPU.

    Prefill goes through ``ipex_ops.varlen_attention``; decode goes through
    ``ops.mla_decode_kvcache_cpu`` followed by ``self._v_up_proj``.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the common MLA state, then reject unsupported options.

        Raises:
            NotImplementedError: if any of alibi_slopes, sliding_window,
                blocksparse_params or logits_soft_cap is truthy, if
                ``attn_type`` is not ``AttentionType.DECODER``, or if the
                KV-cache dtype is quantized.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        # Evaluated left to right, stopping at the first truthy value.
        if (alibi_slopes or sliding_window or blocksparse_params
                or logits_soft_cap):
            raise NotImplementedError(
                "CPUMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CPUMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "CPUMLAImpl with FP8 KV cache not yet supported")

    def _forward_prefill(
            self,
            q: torch.Tensor,
            kv_c_normed: torch.Tensor,
            k_pe: torch.Tensor,
            kv_c_and_k_pe_cache: torch.Tensor,
            attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        """Run causal prefill attention and return the flattened output.

        ``kv_c_and_k_pe_cache`` is not read here. The returned tensor has
        shape ``(-1, num_heads * v_head_dim)``.
        """
        prefill_md = attn_metadata.prefill_metadata
        assert prefill_md is not None

        # Up-project the compressed KV, then separate keys from values.
        nope_dim = self.qk_nope_head_dim
        value_dim = self.v_head_dim
        kv_up = self.kv_b_proj(kv_c_normed)[0]
        kv_up = kv_up.view(-1, self.num_heads, nope_dim + value_dim)
        k_nope, v = kv_up.split([nope_dim, value_dim], dim=-1)

        # Append the positional part of the key, expanded to k_nope's
        # leading dimensions.
        k_pe_expanded = k_pe.expand((*k_nope.shape[:-1], -1))
        k = torch.cat((k_nope, k_pe_expanded), dim=-1)

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim
        qk_dim = q.shape[-1]
        pad_width = qk_dim - v.shape[-1]
        v_padded = torch.nn.functional.pad(v, [0, pad_width], value=0)

        attn_out = torch.empty_like(q)
        # The key-side offsets and max length reuse the query-side values.
        start_loc = prefill_md.prefill_query_start_loc
        longest = prefill_md.max_query_len
        ipex_ops.varlen_attention(
            query=q,
            key=k,
            value=v_padded,
            out=attn_out,
            seqlen_q=start_loc,
            seqlen_k=start_loc,
            max_seqlen_q=longest,
            max_seqlen_k=longest,
            pdropout=0.0,
            softmax_scale=self.scale,
            zero_tensors=False,
            is_causal=True,
            return_softmax=False,
            gen_=None,
            logits_soft_cap=0.0,
            window_size_left=-1,
            window_size_right=-1,
            alibi_slopes=None,
        )

        # remove padding
        per_head = attn_out.view(-1, self.num_heads, qk_dim)
        unpadded = per_head[..., :v.shape[-1]]
        return unpadded.reshape(-1, self.num_heads * v.shape[-1])

    def _forward_decode(
            self,
            q_nope: torch.Tensor,
            q_pe: torch.Tensor,
            kv_c_and_k_pe_cache: torch.Tensor,
            attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        """Run decode attention against the cache and up-project the result.

        Requires a non-empty cache and decode metadata on ``attn_metadata``.
        """
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_md = attn_metadata.decode_metadata
        assert decode_md is not None

        full_q = torch.cat([q_nope, q_pe], dim=-1)
        latent_out = full_q.new_empty(full_q.shape[0], self.num_heads,
                                      self.kv_lora_rank)

        # Run MQA
        ops.mla_decode_kvcache_cpu(latent_out, full_q, kv_c_and_k_pe_cache,
                                   self.scale, decode_md.block_tables,
                                   decode_md.seq_lens_tensor)
        return self._v_up_proj(latent_out)
							
								
								
									
										403
									
								
								vllm/attention/backends/ipex_attn.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										403
									
								
								vllm/attention/backends/ipex_attn.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,403 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
| """ Attention layer with torch scaled_dot_product_attention | ||||
|     and PagedAttention.""" | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Dict, List, Optional, Tuple, Type | ||||
|  | ||||
| import torch | ||||
|  | ||||
| from vllm._ipex_ops import ipex_ops | ||||
| from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, | ||||
|                                               AttentionLayer, | ||||
|                                               AttentionMetadata, AttentionType, | ||||
|                                               is_quantized_kv_cache) | ||||
| from vllm.attention.backends.utils import CommonAttentionState | ||||
| from vllm.attention.ops.paged_attn import (PagedAttention, | ||||
|                                            PagedAttentionMetadata) | ||||
| from vllm.logger import init_logger | ||||
|  | ||||
# Module-level logger, named after this module.
logger = init_logger(__name__)

# Partition length: forward() divides the maximum decode sequence length by
# this value (rounding up) to obtain the number of partitions.
_PARTITION_SIZE = 512
|  | ||||
|  | ||||
class IpexAttnBackend(AttentionBackend):
    """Backend descriptor for IPEX attention.

    All members are static: they name the companion classes and delegate
    cache layout and block movement to PagedAttention / ipex_ops.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "IPEX"

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_metadata_cls() -> Type["IpexAttnMetadata"]:
        """Return the attention metadata class."""
        return IpexAttnMetadata

    @staticmethod
    def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
        """Return the attention implementation class."""
        return IpexAttnBackendImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape as computed by PagedAttention."""
        shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                  num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Forward a block swap request to ``ipex_ops.swap_blocks``."""
        from vllm._ipex_ops import ipex_ops as ops
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Split each cache into its key/value halves and copy the blocks.

        Index 0 of every entry in ``kv_caches`` is taken as the key cache
        and index 1 as the value cache.
        """
        from vllm._ipex_ops import ipex_ops as ops
        key_caches = []
        value_caches = []
        for layer_cache in kv_caches:
            key_caches.append(layer_cache[0])
            value_caches.append(layer_cache[1])
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
|  | ||||
|  | ||||
@dataclass
class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Attention metadata consumed by IpexAttnBackend.

    A batch is either entirely prefill or entirely decode; the two
    properties below return ``self`` for the phase that applies and
    ``None`` for the other.
    """
    # True when every sequence in the batch is a prompt; mixed batches
    # of prompts and decodes are not handled.
    is_prompt: bool
    slot_mapping: torch.Tensor
    seq_lens: Optional[List[int]]
    seqlen_q: Optional[torch.Tensor]
    max_seqlen: Optional[int]

    def __post_init__(self):
        # Filled in lazily by the first attention op that needs it.
        # Kept as a list so a bias can be stored per prompt (needed when
        # alibi slopes are used). Being assigned here rather than declared
        # as a field, it is absent from __init__ and __repr__.
        self.attn_bias: Optional[List[torch.Tensor]] = None

    @property
    def prefill_metadata(self) -> Optional["IpexAttnMetadata"]:
        """Return ``self`` for a prefill-only batch, otherwise ``None``."""
        # Chunked prefill is not supported, so any decode token means
        # this is not a prefill batch.
        if self.num_decode_tokens != 0:
            return None

        assert self.num_prefills > 0
        return self

    @property
    def decode_metadata(self) -> Optional["IpexAttnMetadata"]:
        """Return ``self`` for a decode-only batch, otherwise ``None``."""
        # Chunked prefill is not supported, so prefills and decode tokens
        # must not appear together.
        has_prefills = self.num_prefills > 0
        if not has_prefills:
            return self

        assert self.num_decode_tokens == 0
        return None
|  | ||||
|  | ||||
| class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         num_heads: int, | ||||
|         head_size: int, | ||||
|         scale: float, | ||||
|         num_kv_heads: int, | ||||
|         alibi_slopes: Optional[List[float]], | ||||
|         sliding_window: Optional[int], | ||||
|         kv_cache_dtype: str, | ||||
|         blocksparse_params: Optional[Dict[str, Any]] = None, | ||||
|         logits_soft_cap: Optional[float] = None, | ||||
|         attn_type: str = AttentionType.DECODER, | ||||
|         kv_sharing_target_layer_name: Optional[str] = None, | ||||
|         use_irope: bool = False, | ||||
|     ) -> None: | ||||
|         if kv_sharing_target_layer_name is not None: | ||||
|             raise NotImplementedError("KV sharing is not supported in V0.") | ||||
|         if use_irope: | ||||
|             logger.warning_once( | ||||
|                 "Using irope in Ipex is not supported yet, it will fall" | ||||
|                 " back to global attention for long context.") | ||||
|         if blocksparse_params is not None: | ||||
|             raise ValueError( | ||||
|                 "IPEX backend does not support block-sparse attention.") | ||||
|         self.num_heads = num_heads | ||||
|         self.head_size = head_size | ||||
|         self.scale = float(scale) | ||||
|         self.num_kv_heads = num_kv_heads | ||||
|         if alibi_slopes is not None: | ||||
|             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) | ||||
|         self.alibi_slopes = alibi_slopes | ||||
|         self.sliding_window = sliding_window | ||||
|         self.kv_cache_dtype = kv_cache_dtype | ||||
|  | ||||
|         self.num_queries_per_kv = self.num_heads // self.num_kv_heads | ||||
|         self.need_mask = (self.sliding_window is not None) | ||||
|         if logits_soft_cap is None: | ||||
|             logits_soft_cap = -1 | ||||
|         self.logits_soft_cap = logits_soft_cap | ||||
|  | ||||
|         supported_head_sizes = PagedAttention.get_supported_head_sizes() | ||||
|         if head_size not in supported_head_sizes: | ||||
|             raise ValueError( | ||||
|                 f"Head size {head_size} is not supported by PagedAttention. " | ||||
|                 f"Supported head sizes are: {supported_head_sizes}.") | ||||
|         if is_quantized_kv_cache(kv_cache_dtype): | ||||
|             raise NotImplementedError( | ||||
|                 "IPEX backend does not support FP8 KV cache. " | ||||
|                 "Please use xFormers backend instead.") | ||||
|         if attn_type != AttentionType.DECODER: | ||||
|             raise NotImplementedError("Encoder self-attention and " | ||||
|                                       "encoder/decoder cross-attention " | ||||
|                                       "are not implemented for " | ||||
|                                       "IpexAttnBackendImpl") | ||||
|  | ||||
|     def split_kv_cache( | ||||
|         self, | ||||
|         kv_cache: torch.Tensor, | ||||
|         num_kv_heads: int, | ||||
|         head_size: int, | ||||
|     ) -> Tuple[torch.Tensor, torch.Tensor]: | ||||
|         x = 1 | ||||
|         num_blocks = kv_cache.shape[1] | ||||
|  | ||||
|         key_cache = kv_cache[0] | ||||
|         key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, | ||||
|                                    -1, x) | ||||
|         value_cache = kv_cache[1] | ||||
|         value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) | ||||
|         return key_cache, value_cache | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         layer: AttentionLayer, | ||||
|         query: torch.Tensor, | ||||
|         key: torch.Tensor, | ||||
|         value: torch.Tensor, | ||||
|         kv_cache: torch.Tensor, | ||||
|         attn_metadata: IpexAttnMetadata,  # type: ignore | ||||
|         output: Optional[torch.Tensor] = None, | ||||
|         output_scale: Optional[torch.Tensor] = None, | ||||
|     ) -> torch.Tensor: | ||||
|         """Forward pass with IPEX varlen_attention and PagedAttention. | ||||
|  | ||||
|         Args: | ||||
|             query: shape = [num_tokens, num_heads * head_size] | ||||
|             key: shape = [num_tokens, num_kv_heads * head_size] | ||||
|             value: shape = [num_tokens, num_kv_heads * head_size] | ||||
|             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] | ||||
|                 NOTE: kv_cache will be an empty tensor with shape [0] | ||||
|                 for profiling run. | ||||
|             attn_metadata: Metadata for attention. | ||||
|         Returns: | ||||
|             shape = [num_tokens, num_heads * head_size] | ||||
|         """ | ||||
|         if output_scale is not None: | ||||
|             raise NotImplementedError( | ||||
|                 "fused output quantization is not yet supported" | ||||
|                 " for IpexAttentionImpl") | ||||
|  | ||||
|         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 | ||||
|         num_tokens, hidden_size = query.shape | ||||
|         # Reshape the query, key, and value tensors. | ||||
|         query = query.view(-1, self.num_heads, self.head_size) | ||||
|         key = key.view(-1, self.num_kv_heads, self.head_size) | ||||
|         value = value.view(-1, self.num_kv_heads, self.head_size) | ||||
|  | ||||
|         if kv_cache.numel() > 0: | ||||
|             key_cache, value_cache = self.split_kv_cache( | ||||
|                 kv_cache, self.num_kv_heads, self.head_size) | ||||
|             ipex_ops.reshape_and_cache( | ||||
|                 key, | ||||
|                 value, | ||||
|                 key_cache, | ||||
|                 value_cache, | ||||
|                 attn_metadata.slot_mapping.flatten(), | ||||
|                 self.kv_cache_dtype, | ||||
|                 layer._k_scale_float, | ||||
|                 layer._v_scale_float, | ||||
|             ) | ||||
|  | ||||
|         if attn_metadata.is_prompt: | ||||
|             assert attn_metadata.seq_lens is not None | ||||
|             if (kv_cache.numel() == 0 | ||||
|                     or attn_metadata.block_tables.numel() == 0): | ||||
|                 if self.num_kv_heads != self.num_heads: | ||||
|                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1) | ||||
|                     value = value.repeat_interleave(self.num_queries_per_kv, | ||||
|                                                     dim=1) | ||||
|  | ||||
|                 if attn_metadata.attn_bias is None: | ||||
|                     if self.sliding_window is not None: | ||||
|                         att_masks = _make_sliding_window_bias( | ||||
|                             attn_metadata.seq_lens, self.sliding_window, | ||||
|                             query.dtype)  # type: ignore | ||||
|                     else: | ||||
|                         att_masks = _make_sliding_window_bias( | ||||
|                             attn_metadata.seq_lens, None, dtype=query.dtype) | ||||
|                     attn_metadata.attn_bias = att_masks | ||||
|  | ||||
|                 output = torch.empty( | ||||
|                     (num_tokens, self.num_heads, self.head_size), | ||||
|                     dtype=query.dtype, | ||||
|                     device=query.device) | ||||
|                 ipex_ops.varlen_attention( | ||||
|                     query, | ||||
|                     key, | ||||
|                     value, | ||||
|                     output, | ||||
|                     attn_metadata.seqlen_q, | ||||
|                     attn_metadata.seqlen_q, | ||||
|                     self.alibi_slopes, | ||||
|                     attn_metadata.max_seqlen, | ||||
|                     attn_metadata.max_seqlen, | ||||
|                     pdropout=0.0, | ||||
|                     softmax_scale=self.scale, | ||||
|                     zero_tensors=False, | ||||
|                     is_causal=True, | ||||
|                     return_softmax=False, | ||||
|                     gen_=None, | ||||
|                     window_size_left=-1, | ||||
|                     window_size_right=-1, | ||||
|                     logits_soft_cap=self.logits_soft_cap, | ||||
|                 ) | ||||
|             else: | ||||
|                 # prefix-enabled attention | ||||
|                 raise RuntimeError( | ||||
|                     "IPEX backend doesn't support prefix decoding.") | ||||
|  | ||||
|         else: | ||||
|             # Decoding run. | ||||
|             max_seq_len = attn_metadata.max_decode_seq_len | ||||
|             output = torch.empty_like(query) | ||||
|             block_size = value_cache.shape[3] | ||||
|             num_seqs, num_heads, head_size = query.shape | ||||
|             max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // | ||||
|                                   _PARTITION_SIZE) | ||||
|             # NOTE(woosuk): We use a simple heuristic to decide whether to use | ||||
|             # PagedAttention V1 or V2. If the number of partitions is 1, we use | ||||
|             # V1 to avoid the overhead of reduction. Also, if the number of | ||||
|             # sequences or heads is large, we use V1 since there is enough work | ||||
|             # to parallelize. | ||||
|             # TODO(woosuk): Tune this heuristic. | ||||
|             # For context len > 8192, use V2 kernel to avoid shared memory | ||||
|             # shortage. | ||||
|             use_v1 = (max_seq_len <= 8192 and | ||||
|                       (max_num_partitions == 1 or num_seqs * num_heads > 512)) | ||||
|             if use_v1: | ||||
|                 # Run PagedAttention V1. | ||||
|                 ipex_ops.paged_attention_v1( | ||||
|                     output, | ||||
|                     query, | ||||
|                     key_cache, | ||||
|                     value_cache, | ||||
|                     self.num_kv_heads, | ||||
|                     self.scale, | ||||
|                     attn_metadata.block_tables, | ||||
|                     attn_metadata.seq_lens_tensor, | ||||
|                     block_size, | ||||
|                     max_seq_len, | ||||
|                     self.alibi_slopes, | ||||
|                     self.kv_cache_dtype, | ||||
|                     layer._k_scale_float, | ||||
|                     layer._v_scale_float, | ||||
|                 ) | ||||
|             else: | ||||
|                 # Run PagedAttention V2. | ||||
|                 assert _PARTITION_SIZE % block_size == 0 | ||||
|                 tmp_output = torch.empty( | ||||
|                     size=(num_seqs, num_heads, max_num_partitions, head_size), | ||||
|                     dtype=output.dtype, | ||||
|                     device=output.device, | ||||
|                 ) | ||||
|                 exp_sums = torch.empty( | ||||
|                     size=(num_seqs, num_heads, max_num_partitions), | ||||
|                     dtype=torch.float32, | ||||
|                     device=output.device, | ||||
|                 ) | ||||
|                 max_logits = torch.empty_like(exp_sums) | ||||
|                 ipex_ops.paged_attention_v2( | ||||
|                     output, | ||||
|                     exp_sums, | ||||
|                     max_logits, | ||||
|                     tmp_output, | ||||
|                     query, | ||||
|                     key_cache, | ||||
|                     value_cache, | ||||
|                     self.num_kv_heads, | ||||
|                     self.scale, | ||||
|                     attn_metadata.block_tables, | ||||
|                     attn_metadata.seq_lens_tensor, | ||||
|                     block_size, | ||||
|                     max_seq_len, | ||||
|                     self.alibi_slopes, | ||||
|                     self.kv_cache_dtype, | ||||
|                     layer._k_scale_float, | ||||
|                     layer._v_scale_float, | ||||
|                 ) | ||||
|  | ||||
|             # Reshape the output tensor. | ||||
|         return output.view(-1, self.num_heads * self.head_size) | ||||
|  | ||||
|  | ||||
| def _make_alibi_bias( | ||||
|     alibi_slopes: torch.Tensor, | ||||
|     dtype: torch.dtype, | ||||
|     seq_lens: List[int], | ||||
| ) -> List[torch.Tensor]: | ||||
|     attn_biases = [] | ||||
|     for seq_len in seq_lens: | ||||
|         bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device) | ||||
|         # NOTE(zhuohan): HF uses | ||||
|         #     `bias = bias[None, :].repeat(seq_len, 1)` | ||||
|         # here. We find that both biases give the same results, but | ||||
|         # the bias below more accurately follows the original ALiBi | ||||
|         # paper. | ||||
|         bias = bias[None, :] - bias[:, None] | ||||
|  | ||||
|         num_heads = alibi_slopes.shape[0] | ||||
|         bias = bias[None, :].repeat((num_heads, 1, 1)) | ||||
|         bias.mul_(alibi_slopes[:, None, None]) | ||||
|         inf_mask = torch.empty( | ||||
|             (1, seq_len, seq_len), | ||||
|             dtype=bias.dtype, | ||||
|             device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1) | ||||
|         attn_biases.append((bias + inf_mask).to(dtype)) | ||||
|  | ||||
|     return attn_biases | ||||
|  | ||||
|  | ||||
| def _make_sliding_window_bias( | ||||
|     seq_lens: List[int], | ||||
|     window_size: Optional[int], | ||||
|     dtype: torch.dtype, | ||||
| ) -> List[torch.Tensor]: | ||||
|     attn_biases = [] | ||||
|     for seq_len in seq_lens: | ||||
|         tensor = torch.full( | ||||
|             (1, seq_len, seq_len), | ||||
|             dtype=dtype, | ||||
|             fill_value=1, | ||||
|         ) | ||||
|         shift = 0 | ||||
|         mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore | ||||
|         if window_size is not None: | ||||
|             mask = torch.triu(mask, diagonal=shift - window_size + 1) | ||||
|         mask = torch.log(mask) | ||||
|         attn_biases.append(mask.to(dtype)) | ||||
|  | ||||
|     return attn_biases | ||||
							
								
								
									
										356
									
								
								vllm/attention/backends/pallas.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										356
									
								
								vllm/attention/backends/pallas.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,356 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
|  | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Dict, List, Optional, Tuple, Type | ||||
|  | ||||
| import torch | ||||
| import torch_xla.experimental.custom_kernel  # Required to register custom ops. | ||||
|  | ||||
| from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, | ||||
|                                               AttentionLayer, | ||||
|                                               AttentionMetadata, AttentionType, | ||||
|                                               is_quantized_kv_cache) | ||||
| from vllm.attention.backends.utils import CommonAttentionState | ||||
| from vllm.logger import init_logger | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
|  | ||||
class PallasAttentionBackend(AttentionBackend):
    """Attention backend descriptor for the Pallas (TPU / XLA) kernels."""

    @staticmethod
    def get_name() -> str:
        """Return the backend's registry name."""
        return "PALLAS"

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        """Return the attention implementation class."""
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        """Return the attention metadata class."""
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape, with the KV-head axis leading."""
        return (num_kv_heads, num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Always raise: block swapping is not used on the TPU backend."""
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        """Copy cache blocks in place along the block axis (dim 1).

        Args:
            kv_caches: One ``(key_cache, value_cache)`` pair per entry.
            src_to_dists: ``(src_indices, dst_indices)`` block indices.
        """
        src_indices, dst_indices = src_to_dists
        for pair in kv_caches:
            k_cache, v_cache = pair
            # Key cache first, then value cache; each is marked as a buffer
            # donor immediately before it is written.
            for cache in (k_cache, v_cache):
                torch.ops.xla.dynamo_set_buffer_donor_(cache, True)
                cache[:, dst_indices] = cache[:, src_indices]
|  | ||||
|  | ||||
@dataclass
class PallasMetadata(AttentionMetadata):
    """Attention metadata consumed by the Pallas backend."""

    # Currently, a batch is either all prefills or all decodes; the two
    # properties below assert this.
    block_tables: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None
    effective_query_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a prefill batch, ``None`` if it has no prefills."""
        if self.num_prefills != 0:
            # A prefill batch must not carry any decode tokens.
            assert self.num_decode_tokens == 0
            return self

        return None

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a decode batch, ``None`` if it has no decodes."""
        if self.num_decode_tokens != 0:
            # A decode batch must not carry prefills and needs paging info.
            assert self.num_prefills == 0
            assert self.num_prefill_tokens == 0
            assert self.block_tables is not None
            assert self.context_lens is not None
            return self

        return None
|  | ||||
|  | ||||
class PallasAttentionBackendImpl(AttentionImpl):
    """Attention implementation that dispatches to Pallas/XLA TPU kernels.

    The constructor rejects every unsupported configuration up front; the
    forward pass then selects a prefill or decode kernel based on the
    attention metadata.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Validate the configuration and pick the megacore mode.

        Raises:
            NotImplementedError: If KV sharing, ALiBi slopes, a sliding
                window, a quantized KV cache, blocksparse parameters, a
                non-decoder attention type, a head size that is not a
                multiple of 128, or a TPU version below 4 is requested.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if use_irope:
            # Only warned about, not rejected.
            logger.warning_once(
                "Using irope in Pallas is not supported yet, it will fall back "
                "to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads

        # Number of query heads that share one KV head (GQA/MQA).
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.logits_soft_cap = logits_soft_cap
        if head_size % 128 != 0:
            raise NotImplementedError(
                f"Head size must be a multiple of 128, found {head_size}.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        # megacore_mode stays None unless the TPU type check below enables it.
        self.megacore_mode = None
        tpu_env = torch_xla.tpu.get_tpu_env()
        # The accelerator type is looked up under several possible keys.
        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
                    or tpu_env.get("TYPE", None)
                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
        assert tpu_type is not None
        tpu_type = tpu_type.lower()

        # Megacore is only enabled when the type string contains neither
        # "lite" nor "v6".
        if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            layer: Attention layer; its ``_k_scale_float`` and
                ``_v_scale_float`` must both be 1.0.
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
                with shape [0] for profiling run.
            attn_metadata: Metadata for attention.
            output: Not read; the name is rebound to the kernel result.
            output_scale: Must be None (fused output quantization is not
                supported).
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        Raises:
            NotImplementedError: If ``output_scale`` is given.
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for PallasAttentionImpl")

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        batch_size, seq_len, hidden_size = query.shape
        # Split the hidden dimension into (heads, head_size).
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        # key_cache / value_cache are only bound when the cache is non-empty;
        # the branches below that use them rely on that.
        if kv_cache[0].numel() > 0:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        # The softmax scale is applied to the query up front.
        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            if attn_metadata.block_tables is None:
                # Prefill without paged KV cache.
                assert seq_len % 16 == 0, (
                    "Pallas FlashAttention kernel requires seq_len to be a "
                    f"multiple of 16 but got {seq_len}")

                # Handle GQA/MQA.
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                    key = key.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=-2)
                    value = value.view(batch_size, seq_len, self.num_heads,
                                       self.head_size)
                # FlashAttention kernel requires the input shape to be
                # [batch_size, num_heads, seq_len, d_model]
                # while the input is [batch_size, seq_len, num_heads, d_model].
                # Permute the input to match the required format.
                # NOTE(review): the positional True is presumably the causal
                # flag -- confirm against the torch_xla kernel signature.
                output = torch.ops.xla.flash_attention(
                    query.permute(0, 2, 1, 3),
                    key.permute(0, 2, 1, 3),
                    value.permute(0, 2, 1, 3),
                    True,
                )
                # Back to [batch_size, seq_len, num_heads, d_model].
                output = output.permute(0, 2, 1, 3)
            else:
                # Prefill with paged KV cache.
                # TODO(woosuk): Tune the below knobs.
                num_kv_pages_per_compute_block = 16
                num_queries_per_compute_block = 16
                assert seq_len % num_queries_per_compute_block == 0
                output = torch.ops.xla.multi_queries_paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    attn_metadata.effective_query_lens,
                    num_kv_pages_per_compute_block,
                    num_queries_per_compute_block,
                    use_kernel=True,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
        else:
            # Decoding run.
            assert kv_cache[0].numel() > 0
            # Drop the seq_len axis (dim 1) before calling the decode kernel.
            query = query.squeeze(dim=1)
            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.

            assert attn_metadata.block_tables is not None
            assert attn_metadata.context_lens is not None
            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
            # block table in SMEM. Therefore, if the block table is too large,
            # the kernel compilation will fail. To avoid this, we split the
            # batch dimension into smaller chunks and run the kernel multiple
            # times.
            MAX_SMEM_USAGE = 512 * 1024
            # NOTE(review): the factor 4 is presumably the byte size of one
            # block-table entry -- confirm.
            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
            max_num_seq = MAX_SMEM_USAGE // size_per_seq

            if batch_size <= max_num_seq:
                # The whole batch fits; run the kernel once.
                output = paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    self.megacore_mode,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
            else:
                chunk_size = max_num_seq
                # Make sure the chunk size is a multiple of 2.
                chunk_size = chunk_size // 2 * 2
                # Ceiling division: number of chunks needed for the batch.
                num_chunks = (batch_size + chunk_size - 1) // chunk_size

                output = torch.empty_like(query)
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = chunk_start + chunk_size
                    # NOTE(woosuk): We skip this line because it causes Dynamo
                    # compilation error. Instead, we rely on the slice operation
                    # to handle the out-of-bound case.
                    # chunk_end = min(chunk_end, batch_size)
                    chunk_output = paged_attention(
                        query[chunk_start:chunk_end],
                        key_cache,
                        value_cache,
                        attn_metadata.context_lens[chunk_start:chunk_end],
                        attn_metadata.block_tables[chunk_start:chunk_end],
                        pages_per_compute_block,
                        self.megacore_mode,
                        attn_logits_soft_cap=self.logits_soft_cap,
                    )
                    output[chunk_start:chunk_end] = chunk_output

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)
|  | ||||
|  | ||||
def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Write new key/value entries into the caches at ``slot_mapping``.

    The first three dimensions of every tensor are merged, then the flattened
    new entries are copied into the flattened caches with ``index_copy_``
    along dim 0.

    NOTE(review): the in-place copy reaches the caller's cache only if
    ``flatten`` yields a view of it -- presumably the caches are contiguous;
    confirm.
    """
    # Mark both caches as donor buffers before anything is written.
    for cache in (key_cache, value_cache):
        torch.ops.xla.dynamo_set_buffer_donor_(cache, True)

    for cache, new_entries in ((key_cache, key), (value_cache, value)):
        flat_cache = cache.flatten(0, 2)
        flat_cache.index_copy_(0, slot_mapping, new_entries.flatten(0, 2))
|  | ||||
|  | ||||
def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
    *,
    attn_logits_soft_cap: Optional[float],
) -> torch.Tensor:
    """Call the XLA paged-attention op, disabling megacore when unusable.

    ``"batch"`` megacore mode is replaced by ``None`` when the leading
    dimension of ``query`` (the batch size) is odd; every other mode is
    forwarded unchanged.
    """
    odd_batch = query.shape[0] % 2 != 0
    effective_mode = (None if megacore_mode == "batch" and odd_batch else
                      megacore_mode)

    return torch.ops.xla.paged_attention(
        query,
        key_cache,
        value_cache,
        context_lens,
        block_tables,
        pages_per_compute_block,
        megacore_mode=effective_mode,
        attn_logits_soft_cap=attn_logits_soft_cap,
    )
| @ -3,24 +3,78 @@ | ||||
| """ Attention layer with torch scaled_dot_product_attention | ||||
|     and PagedAttention.""" | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Dict, List, Optional | ||||
| from typing import Any, Dict, List, Optional, Tuple, Type | ||||
|  | ||||
| import torch | ||||
| from torch.nn.functional import scaled_dot_product_attention | ||||
|  | ||||
| # yapf conflicts with isort for this block | ||||
| # yapf: disable | ||||
| from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer, | ||||
|                                               AttentionMetadata, AttentionType, | ||||
| from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, | ||||
|                                               AttentionLayer, | ||||
|                                               AttentionMetadata, | ||||
|                                               AttentionMetadataBuilder, | ||||
|                                               AttentionType, | ||||
|                                               is_quantized_kv_cache) | ||||
| # yapf: enable | ||||
| from vllm.attention.backends.utils import CommonAttentionState | ||||
| from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex | ||||
| from vllm.attention.ops.paged_attn import PagedAttentionMetadata | ||||
| from vllm.logger import init_logger | ||||
| from vllm.utils import make_tensor_with_pad | ||||
| from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
|  | ||||
class TorchSDPABackend(AttentionBackend):
    """Attention backend descriptor for the Torch SDPA implementation.

    Exposes the implementation, metadata, state and builder classes, and
    delegates KV-cache layout and block copying to ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the backend's registry name."""
        return "TORCH_SDPA"

    @staticmethod
    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
        """Return the attention implementation class."""
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
        """Return the metadata builder class."""
        return TorchSDPAMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape as defined by ``PagedAttention``."""
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Always raise: block swapping is not supported by this backend."""
        raise NotImplementedError("Swap is not supported in TorchSDPABackend.")

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks by delegating to ``PagedAttention.copy_blocks``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
|  | ||||
|  | ||||
| @dataclass | ||||
| class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): | ||||
|     """Metadata for TorchSDPABackend. | ||||
| @ -233,6 +287,113 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): | ||||
|             raise AttributeError(f"Invalid attention type {str(attn_type)}") | ||||
|  | ||||
|  | ||||
class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
    """Builds ``TorchSDPAMetadata`` from the CPU model-input builder state."""

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        """Remember the input builder and its chunked-prefill setting."""
        self.input_builder = input_builder
        self.chunked_prefill = input_builder.chunked_prefill

    def prepare(self):
        """Snapshot the input builder's current ``input_data``."""
        self.input_data = self.input_builder.input_data

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
        """Assemble the attention metadata for the current batch.

        Args:
            seq_lens: Per-sequence lengths; the first ``num_prefills``
                entries are treated as the prefill sequences.
            query_lens: Per-sequence query lengths, ordered like
                ``seq_lens``.
            cuda_graph_pad_size: Unused here.
            batch_size: Unused here.

        Returns:
            The populated ``TorchSDPAMetadata``; all tensors are on CPU.
        """
        data = self.input_data
        num_prefills = data.num_prefills
        prefill_seq_lens = seq_lens[:num_prefills]
        prefill_query_lens = query_lens[:num_prefills]
        int32_on_cpu = {"dtype": torch.int32, "device": "cpu"}

        slot_mapping = torch.tensor(data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # Chunked-prefill fields stay None unless chunked prefill is enabled
        # and the batch actually contains prefill tokens.
        prefill_block_tables = None
        query_start_loc = None
        kv_start_loc = None
        max_query_len = None
        max_kv_len = None
        if self.chunked_prefill and data.num_prefill_tokens != 0:
            prefill_block_tables = make_tensor_with_pad(
                data.prefill_block_tables,
                pad=0,
                **int32_on_cpu,
            )
            # Start offsets: a leading zero followed by the running sum of
            # the per-sequence lengths.
            query_start_loc = torch.zeros(num_prefills + 1, **int32_on_cpu)
            kv_start_loc = torch.zeros(num_prefills + 1, **int32_on_cpu)
            torch.cumsum(torch.tensor(prefill_query_lens, **int32_on_cpu),
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(torch.tensor(prefill_seq_lens, **int32_on_cpu),
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)

        # Paged-attention fields: with decode tokens, use the decode
        # sequences and their block tables; otherwise use the prefill
        # sequence lengths and an empty block table.
        has_decode = data.num_decode_tokens != 0
        if has_decode:
            lens_for_tensor = data.seq_lens[num_prefills:]
            block_tables = make_tensor_with_pad(
                data.decode_block_tables,
                pad=0,
                **int32_on_cpu,
            )
        else:
            lens_for_tensor = data.seq_lens[:num_prefills]
            block_tables = torch.tensor([])
        seq_lens_tensor = torch.tensor(lens_for_tensor, **int32_on_cpu)

        # Multi-modal placeholder maps are only built when the batch carries
        # multi-modal inputs.
        placeholder_index_maps = None
        if len(data.multi_modal_inputs_list) > 0:
            placeholder_index_maps = {}
            for modality, placeholder_map in (
                    data.multi_modal_placeholder_maps.items()):
                placeholder_index_maps[modality] = placeholder_map.index_map()

        return TorchSDPAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            prefill_query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=data.max_decode_seq_len,
            num_prefills=data.num_prefills,
            num_prefill_tokens=data.num_prefill_tokens,
            num_decode_tokens=data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
        )
|  | ||||
|  | ||||
| class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): | ||||
|  | ||||
|     def __init__( | ||||
|  | ||||
| @ -64,11 +64,13 @@ class CpuPlatform(Platform): | ||||
|         if selected_backend and selected_backend != _Backend.TORCH_SDPA: | ||||
|             logger.info("Cannot use %s backend on CPU.", selected_backend) | ||||
|         if use_mla: | ||||
|             raise NotImplementedError("MLA is not supported on CPU.") | ||||
|             logger.info("Using CPU MLA backend.") | ||||
|             return "vllm.attention.backends.cpu_mla.CPUMLABackend" | ||||
|         logger.info("Using Torch SDPA backend.") | ||||
|         if not use_v1: | ||||
|             raise ValueError("CPU backend only supports V1.") | ||||
|         if use_v1: | ||||
|             return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" | ||||
|         else: | ||||
|             return "vllm.attention.backends.torch_sdpa.TorchSDPABackend" | ||||
|  | ||||
|     @classmethod | ||||
|     def get_device_total_memory(cls, device_id: int = 0) -> int: | ||||
| @ -145,14 +147,26 @@ class CpuPlatform(Platform): | ||||
|                            parallel_config.distributed_executor_backend) | ||||
|             parallel_config.distributed_executor_backend = "mp" | ||||
|         if parallel_config.worker_cls == "auto": | ||||
|             parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker" | ||||
|             if vllm_config.speculative_config: | ||||
|                 parallel_config.worker_cls = \ | ||||
|                     "vllm.spec_decode.spec_decode_worker.create_spec_worker" | ||||
|                 parallel_config.sd_worker_cls = \ | ||||
|                     "vllm.worker.cpu_worker.CPUWorker" | ||||
|             else: | ||||
|                 if envs.VLLM_USE_V1: | ||||
|                     parallel_config.worker_cls = \ | ||||
|                         "vllm.v1.worker.cpu_worker.CPUWorker" | ||||
|                 else: | ||||
|                     parallel_config.worker_cls = \ | ||||
|                         "vllm.worker.cpu_worker.CPUWorker" | ||||
|  | ||||
|         # Note: workaround for v1 gpu_model_runner | ||||
|         from vllm.config import CompilationLevel | ||||
|         vllm_config.compilation_config.cudagraph_capture_sizes = [] | ||||
|  | ||||
|         compilation_config = vllm_config.compilation_config | ||||
|         if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE: | ||||
|         if (envs.VLLM_USE_V1 and vllm_config.compilation_config.level | ||||
|                 == CompilationLevel.PIECEWISE): | ||||
|  | ||||
|             # Note: vLLM V1 is using PIECEWISE level compilation, which will | ||||
|             # take time to compile kernels just-in-time with the inductor | ||||
|  | ||||
| @ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union, cast | ||||
| import torch | ||||
| from tpu_info import device | ||||
|  | ||||
| import vllm.envs as envs | ||||
| from vllm.inputs import ProcessorInputs, PromptType | ||||
| from vllm.logger import init_logger | ||||
| from vllm.sampling_params import SamplingParams, SamplingType | ||||
| @ -49,10 +50,12 @@ class TpuPlatform(Platform): | ||||
|                 and selected_backend != _Backend.PALLAS_VLLM_V1): | ||||
|             logger.info("Cannot use %s backend on TPU.", selected_backend) | ||||
|  | ||||
|         if not use_v1: | ||||
|             raise ValueError("TPU backend only supports V1.") | ||||
|         if use_v1: | ||||
|             logger.info("Using Pallas V1 backend.") | ||||
|             return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" | ||||
|         else: | ||||
|             logger.info("Using Pallas backend.") | ||||
|             return "vllm.attention.backends.pallas.PallasAttentionBackend" | ||||
|  | ||||
|     @classmethod | ||||
|     def get_device_name(cls, device_id: int = 0) -> str: | ||||
| @ -65,7 +68,7 @@ class TpuPlatform(Platform): | ||||
|  | ||||
|     @classmethod | ||||
|     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: | ||||
|         return False | ||||
|         return not envs.VLLM_USE_V1 | ||||
|  | ||||
|     @classmethod | ||||
|     def get_punica_wrapper(cls) -> str: | ||||
| @ -114,7 +117,9 @@ class TpuPlatform(Platform): | ||||
|                 "Using bfloat16 instead.", vllm_config.model_config.dtype) | ||||
|             vllm_config.model_config.dtype = torch.bfloat16 | ||||
|  | ||||
|         from vllm.v1.attention.backends.pallas import PallasAttentionBackend | ||||
|         if envs.VLLM_USE_V1: | ||||
|             from vllm.v1.attention.backends.pallas import ( | ||||
|                 PallasAttentionBackend) | ||||
|             cache_config.block_size = PallasAttentionBackend.get_page_size( | ||||
|                 vllm_config)  # type: ignore[assignment] | ||||
|  | ||||
| @ -122,11 +127,21 @@ class TpuPlatform(Platform): | ||||
|         scheduler_config = vllm_config.scheduler_config | ||||
|         if parallel_config.worker_cls == "auto": | ||||
|             if scheduler_config.is_multi_step: | ||||
|                 if envs.VLLM_USE_V1: | ||||
|                     raise NotImplementedError( | ||||
|                         "Multi-step scheduling is not supported (and not " | ||||
|                         "needed) on vLLM V1. Please launch without " | ||||
|                         "--num-scheduler-steps.") | ||||
|             parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker" | ||||
|                 else: | ||||
|                     parallel_config.worker_cls = \ | ||||
|                         "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" | ||||
|             else: | ||||
|                 if envs.VLLM_USE_V1: | ||||
|                     parallel_config.worker_cls = \ | ||||
|                         "vllm.v1.worker.tpu_worker.TPUWorker" | ||||
|                 else: | ||||
|                     parallel_config.worker_cls = \ | ||||
|                         "vllm.worker.tpu_worker.TPUWorker" | ||||
|  | ||||
|         assert not vllm_config.speculative_config, ( | ||||
|             "Speculative decoding is not yet supported for TPU backend") | ||||
| @ -174,9 +189,13 @@ class TpuPlatform(Platform): | ||||
|         processed_inputs: ProcessorInputs, | ||||
|     ) -> None: | ||||
|         """Raises if this request is unsupported on this platform""" | ||||
|         if (isinstance(params, SamplingParams) | ||||
|                 and params.sampling_type == SamplingType.RANDOM_SEED): | ||||
|             raise ValueError("Torch XLA does not support per-request seed.") | ||||
|         if isinstance(params, SamplingParams): | ||||
|             if params.guided_decoding is not None and not envs.VLLM_USE_V1: | ||||
|                 raise ValueError("Structured output is not supported on " | ||||
|                                  f"{cls.device_name} V0.") | ||||
|             if params.sampling_type == SamplingType.RANDOM_SEED: | ||||
|                 raise ValueError( | ||||
|                     "Torch XLA does not support per-request seed.") | ||||
|  | ||||
|  | ||||
| try: | ||||
|  | ||||
| @ -39,10 +39,12 @@ class XPUPlatform(Platform): | ||||
|         if selected_backend != _Backend.IPEX: | ||||
|             logger.info("Cannot use %s backend on XPU.", selected_backend) | ||||
|         use_v1 = envs.VLLM_USE_V1 | ||||
|         if not use_v1: | ||||
|             raise ValueError("XPU backend only supports V1.") | ||||
|         if use_v1: | ||||
|             logger.info("Using Flash Attention backend on V1 engine.") | ||||
|             return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" | ||||
|         else: | ||||
|             logger.info("Using IPEX attention backend.") | ||||
|             return "vllm.attention.backends.ipex_attn.IpexAttnBackend" | ||||
|  | ||||
|     @classmethod | ||||
|     def get_device_capability( | ||||
| @ -75,7 +77,10 @@ class XPUPlatform(Platform): | ||||
|         cache_config = vllm_config.cache_config | ||||
|         # in V1(or with ipex chunked prefill) block_size is 64 | ||||
|         if cache_config and cache_config.block_size is None: | ||||
|             if envs.VLLM_USE_V1: | ||||
|                 cache_config.block_size = 64 | ||||
|             else: | ||||
|                 cache_config.block_size = 16 | ||||
|  | ||||
|         # Instances created using VllmConfig() typically have model_config as | ||||
|         # None by default. The modification involves adding a check to prevent | ||||
| @ -101,7 +106,11 @@ class XPUPlatform(Platform): | ||||
|  | ||||
|         # check and update parallel config | ||||
|         parallel_config = vllm_config.parallel_config | ||||
|         parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker" | ||||
|         if envs.VLLM_USE_V1: | ||||
|             parallel_config.worker_cls =\ | ||||
|                 "vllm.v1.worker.xpu_worker.XPUWorker" | ||||
|         else: | ||||
|             parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker" | ||||
|  | ||||
|         if parallel_config.distributed_executor_backend is None: | ||||
|             if parallel_config.world_size > 1: | ||||
|  | ||||
							
								
								
									
										326
									
								
								vllm/worker/cpu_enc_dec_model_runner.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										326
									
								
								vllm/worker/cpu_enc_dec_model_runner.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,326 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
|  | ||||
| import dataclasses | ||||
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast | ||||
|  | ||||
| import torch | ||||
|  | ||||
| from vllm.attention import AttentionMetadata | ||||
| from vllm.forward_context import set_forward_context | ||||
| from vllm.model_executor import SamplingMetadata | ||||
| from vllm.model_executor.layers.sampler import SamplerOutput | ||||
| from vllm.multimodal import MultiModalKwargs | ||||
| from vllm.sequence import IntermediateTensors, SequenceGroupMetadata | ||||
| from vllm.utils import make_tensor_with_pad | ||||
| from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, | ||||
|                                           ModelInputForCPUBuilder, | ||||
|                                           ModelInputForCPUWithSamplingMetadata) | ||||
| from vllm.worker.model_runner_base import ( | ||||
|     _add_attn_metadata_broadcastable_dict, | ||||
|     _add_sampling_metadata_broadcastable_dict) | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from vllm.attention.backends.abstract import AttentionBackend | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass(frozen=True) | ||||
| class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata): | ||||
|     """ | ||||
|     Used by the EncoderDecoderModelRunner. | ||||
|     """ | ||||
|     encoder_input_tokens: Optional[torch.Tensor] = None | ||||
|     encoder_input_positions: Optional[torch.Tensor] = None | ||||
|  | ||||
|     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: | ||||
|         tensor_dict = { | ||||
|             "input_tokens": self.input_tokens, | ||||
|             "input_positions": self.input_positions, | ||||
|             "encoder_input_tokens": self.encoder_input_tokens, | ||||
|             "encoder_input_positions": self.encoder_input_positions, | ||||
|             "multi_modal_kwargs": self.multi_modal_kwargs, | ||||
|         } | ||||
|         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) | ||||
|         _add_sampling_metadata_broadcastable_dict(tensor_dict, | ||||
|                                                   self.sampling_metadata) | ||||
|         return tensor_dict | ||||
|  | ||||
|     @classmethod | ||||
|     def from_broadcasted_tensor_dict( | ||||
|         cls, | ||||
|         tensor_dict: Dict[str, Any], | ||||
|         attn_backend: Optional["AttentionBackend"] = None, | ||||
|     ) -> "EncoderDecoderModelInputForCPU": | ||||
|         return cast( | ||||
|             EncoderDecoderModelInputForCPU, | ||||
|             super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) | ||||
|  | ||||
|  | ||||
| class CPUEncoderDecoderModelRunner( | ||||
|         CPUModelRunnerBase[EncoderDecoderModelInputForCPU]): | ||||
|     """Model runner for encoder/decoder models on CPU. | ||||
|  | ||||
|     Decoder-side inputs come from `self._prepare_model_input_tensors` (not | ||||
|     defined in this class). This class then adds the encoder tokens, encoder | ||||
|     positions and cross-attention metadata, and runs the model plus sampler. | ||||
|     """ | ||||
|     _model_input_cls: Type[EncoderDecoderModelInputForCPU] = ( | ||||
|         EncoderDecoderModelInputForCPU) | ||||
|     _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder | ||||
|  | ||||
|     def _list_to_int32_tensor( | ||||
|         self, | ||||
|         _list: List[int], | ||||
|     ) -> torch.Tensor: | ||||
|         """Convert `_list` to an int32 tensor on `self.device`.""" | ||||
|         return torch.tensor(_list, dtype=torch.int32, device=self.device) | ||||
|  | ||||
|     def _list_to_long_tensor( | ||||
|         self, | ||||
|         _list: List[int], | ||||
|     ) -> torch.Tensor: | ||||
|         """Convert `_list` to a long (int64) tensor on `self.device`.""" | ||||
|         return torch.tensor(_list, dtype=torch.long, device=self.device) | ||||
|  | ||||
|     def _empty_int32_tensor(self) -> torch.Tensor: | ||||
|         """Return an empty int32 tensor on `self.device`.""" | ||||
|         return self._list_to_int32_tensor([]) | ||||
|  | ||||
|     def _empty_long_tensor(self) -> torch.Tensor: | ||||
|         """Return an empty long (int64) tensor on `self.device`.""" | ||||
|         return self._list_to_long_tensor([]) | ||||
|  | ||||
|     def make_model_input_from_broadcasted_tensor_dict( | ||||
|             self, tensor_dict: Dict[str, | ||||
|                                     Any]) -> EncoderDecoderModelInputForCPU: | ||||
|         """Rebuild the model input from a broadcasted tensor dict, using this | ||||
|         runner's attention backend.""" | ||||
|         return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict( | ||||
|             tensor_dict, | ||||
|             attn_backend=self.attn_backend, | ||||
|         ) | ||||
|  | ||||
|     def prepare_model_input( | ||||
|         self, | ||||
|         seq_group_metadata_list: List[SequenceGroupMetadata], | ||||
|         virtual_engine: int = 0, | ||||
|         finished_requests_ids: Optional[List[str]] = None | ||||
|     ) -> EncoderDecoderModelInputForCPU: | ||||
|         """Build the complete model input for one step. | ||||
|  | ||||
|         Arguments: | ||||
|  | ||||
|         * seq_group_metadata_list: sequence groups scheduled for this step | ||||
|         * virtual_engine: stored unchanged on the returned model input | ||||
|         * finished_requests_ids: forwarded to the decoder input preparation | ||||
|                                  and to `get_generators` | ||||
|  | ||||
|         Return: | ||||
|  | ||||
|         * A copy of the decoder model input with sampling metadata, updated | ||||
|           attention metadata, encoder tokens/positions and `virtual_engine` | ||||
|           filled in. | ||||
|         """ | ||||
|         # Decoder-side tensors first; encoder-side data is layered on top. | ||||
|         model_input = self._prepare_model_input_tensors( | ||||
|             seq_group_metadata_list, finished_requests_ids) | ||||
|         ( | ||||
|             attn_metadata, | ||||
|             encoder_input_tokens_tensor, | ||||
|             encoder_input_positions_tensor, | ||||
|         ) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list, | ||||
|                                                       model_input) | ||||
|         # Sampling metadata is only required for the final pp group | ||||
|         generators = self.get_generators(finished_requests_ids) | ||||
|         sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, | ||||
|                                                      model_input.seq_lens, | ||||
|                                                      model_input.query_lens, | ||||
|                                                      self.device, | ||||
|                                                      pin_memory=False, | ||||
|                                                      generators=generators) | ||||
|         # The model input is a frozen dataclass, so build a modified copy. | ||||
|         return dataclasses.replace( | ||||
|             model_input, | ||||
|             sampling_metadata=sampling_metadata, | ||||
|             attn_metadata=attn_metadata, | ||||
|             encoder_input_tokens=encoder_input_tokens_tensor, | ||||
|             encoder_input_positions=encoder_input_positions_tensor, | ||||
|             virtual_engine=virtual_engine, | ||||
|         ) | ||||
|  | ||||
|     def _prepare_encoder_model_input_tensors( | ||||
|         self, | ||||
|         seq_group_metadata_list: List[SequenceGroupMetadata], | ||||
|         model_input: EncoderDecoderModelInputForCPU, | ||||
|     ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], | ||||
|                Optional[torch.Tensor]]: | ||||
|         """Helper method to prepare the encoder- and cross-attn-related | ||||
|         model inputs based on a given sequence group. These additional inputs | ||||
|         are used to augment an already-computed `EncoderDecoderModelInputForCPU` | ||||
|         data structure which already has decoder-related model inputs | ||||
|         populated. | ||||
|  | ||||
|         Sets the following attn_metadata fields (in place, on | ||||
|         `model_input.attn_metadata`): | ||||
|         * `num_encoder_tokens` | ||||
|         * `encoder_seq_lens` | ||||
|         * `encoder_seq_lens_tensor` | ||||
|         * `max_encoder_seq_len` | ||||
|         * `cross_slot_mapping` | ||||
|         * `cross_block_tables` | ||||
|  | ||||
|         Arguments: | ||||
|  | ||||
|         * seq_group_metadata_list: list of sequence groups for which to | ||||
|                                    compute inputs | ||||
|         * model_input: model inputs data structure with decoder-oriented | ||||
|                        fields already computed. | ||||
|  | ||||
|         Return: | ||||
|  | ||||
|         * A 3-tuple `(attn_metadata, encoder_input_tokens, | ||||
|           encoder_input_positions)`. For an empty | ||||
|           `seq_group_metadata_list` the attention metadata is returned | ||||
|           untouched and both tensors are None. | ||||
|         """ | ||||
|  | ||||
|         if len(seq_group_metadata_list) == 0: | ||||
|             return (model_input.attn_metadata, None, None) | ||||
|  | ||||
|         # Since we are not supporting chunked prefill either the entire | ||||
|         # batch is prefill or it is decode | ||||
|         # (so the first group alone decides the phase for the whole batch). | ||||
|         is_prompt = seq_group_metadata_list[0].is_prompt | ||||
|  | ||||
|         # Build encoder inputs | ||||
|         encoder_seq_lens: List[int] = [] | ||||
|         if is_prompt: | ||||
|             # Prefill phase. | ||||
|             # An empty tensor viewed as (num_groups, -1), i.e. zero columns. | ||||
|             cross_block_tables = self._empty_int32_tensor().view( | ||||
|                 len(seq_group_metadata_list), -1) | ||||
|  | ||||
|             # Extract input tokens/positions, cross-attention slot-mapping, | ||||
|             # & seq len from each sequence group metadata | ||||
|             ( | ||||
|                 encoder_input_tokens, | ||||
|                 encoder_input_positions, | ||||
|                 cross_slot_mapping, | ||||
|             ) = ( | ||||
|                 [], | ||||
|                 [], | ||||
|                 [], | ||||
|             ) | ||||
|             for seq_group_metadata in seq_group_metadata_list: | ||||
|                 # Build seq lens | ||||
|                 seq_len = seq_group_metadata.encoder_seq_data.get_len() | ||||
|                 token_ids = seq_group_metadata.encoder_seq_data.get_token_ids() | ||||
|                 encoder_seq_lens.append(seq_len) | ||||
|  | ||||
|                 # Build slot mapping | ||||
|                 # Token i lives in block cross_block_table[i // block_size] | ||||
|                 # at offset i % block_size; the slot is the flat index. | ||||
|                 for i in range(0, seq_len): | ||||
|                     block_number = seq_group_metadata.cross_block_table[ | ||||
|                         i // self.block_size] | ||||
|                     block_offset = i % self.block_size | ||||
|                     slot = block_number * self.block_size + block_offset | ||||
|                     cross_slot_mapping.append(slot) | ||||
|  | ||||
|                 # Build encoder input tokens | ||||
|                 encoder_input_tokens.extend(token_ids) | ||||
|                 encoder_input_positions.extend(list(range(0, seq_len))) | ||||
|  | ||||
|             # Convert tokens/positions & cross-attention | ||||
|             # slot-mapping to encoder input tensors | ||||
|             encoder_input_tokens_tensor = self._list_to_long_tensor( | ||||
|                 encoder_input_tokens) | ||||
|             encoder_input_positions_tensor = self._list_to_long_tensor( | ||||
|                 encoder_input_positions) | ||||
|             cross_slot_mapping_tensor = self._list_to_long_tensor( | ||||
|                 cross_slot_mapping) | ||||
|  | ||||
|         else: | ||||
|             # Decode phase. | ||||
|             # No encoder tokens are fed during decode, so these stay empty. | ||||
|             encoder_input_tokens_tensor = self._empty_long_tensor() | ||||
|             encoder_input_positions_tensor = self._empty_long_tensor() | ||||
|             cross_slot_mapping_tensor = self._empty_long_tensor() | ||||
|             # Extract cross-attention block tables & | ||||
|             # seq len from each sequence group metadata. | ||||
|             # Cross-attention block tables are empty | ||||
|             # during vLLM memory profiling. | ||||
|             cross_block_tables = [] | ||||
|             for seq_group_metadata in seq_group_metadata_list: | ||||
|                 # One entry per sequence in the group; every sequence of a | ||||
|                 # group gets the group's encoder length and block table. | ||||
|                 for _ in range(len(seq_group_metadata.seq_data)): | ||||
|                     encoder_seq_lens.append( | ||||
|                         seq_group_metadata.encoder_seq_data.get_len()) | ||||
|                     cross_block_table = seq_group_metadata.cross_block_table | ||||
|                     cross_block_tables.append([] if ( | ||||
|                         cross_block_table is None) else cross_block_table) | ||||
|  | ||||
|             # NOTE(review): max() has no default here, so this raises | ||||
|             # ValueError if no sequence was collected above -- confirm that | ||||
|             # cannot happen for a non-empty decode batch. | ||||
|             max_len_of_block_table = max( | ||||
|                 len(block_table) for block_table in cross_block_tables) | ||||
|  | ||||
|             # Shorter block tables are padded with 0 to a common width. | ||||
|             cross_block_tables = make_tensor_with_pad( | ||||
|                 cross_block_tables, | ||||
|                 max_len=max_len_of_block_table, | ||||
|                 pad=0, | ||||
|                 dtype=torch.int32, | ||||
|                 device=self.device, | ||||
|             ) | ||||
|  | ||||
|         # Compute encoder sequence lengths & encoder | ||||
|         # sequence starting offset tensors | ||||
|         max_encoder_seq_len = max(encoder_seq_lens, default=0) | ||||
|         encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens) | ||||
|         # Element 0 stays 0; elements 1.. receive the running sum of lengths. | ||||
|         # NOTE(review): encoder_seq_start_loc is not stored on attn_metadata | ||||
|         # nor returned below -- confirm whether it is still needed. | ||||
|         encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + | ||||
|                                             1, | ||||
|                                             dtype=torch.int32, | ||||
|                                             device=self.device) | ||||
|         torch.cumsum(encoder_seq_lens_tensor, | ||||
|                      dim=0, | ||||
|                      dtype=encoder_seq_start_loc.dtype, | ||||
|                      out=encoder_seq_start_loc[1:]) | ||||
|  | ||||
|         # Update attention metadata with encoder-oriented attributes | ||||
|         attn_metadata = model_input.attn_metadata | ||||
|         assert attn_metadata is not None | ||||
|         ( | ||||
|             attn_metadata.num_encoder_tokens, | ||||
|             attn_metadata.encoder_seq_lens, | ||||
|             attn_metadata.encoder_seq_lens_tensor, | ||||
|             attn_metadata.max_encoder_seq_len, | ||||
|             attn_metadata.cross_slot_mapping, | ||||
|             attn_metadata.cross_block_tables, | ||||
|         ) = ( | ||||
|             sum(encoder_seq_lens), | ||||
|             encoder_seq_lens, | ||||
|             encoder_seq_lens_tensor, | ||||
|             max_encoder_seq_len, | ||||
|             cross_slot_mapping_tensor, | ||||
|             cross_block_tables, | ||||
|         ) | ||||
|  | ||||
|         return (attn_metadata, encoder_input_tokens_tensor, | ||||
|                 encoder_input_positions_tensor) | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def execute_model( | ||||
|         self, | ||||
|         model_input: EncoderDecoderModelInputForCPU, | ||||
|         kv_caches: List[torch.Tensor], | ||||
|         intermediate_tensors: Optional[IntermediateTensors] = None, | ||||
|         num_steps: int = 1, | ||||
|     ) -> Optional[List[SamplerOutput]]: | ||||
|         """Run one forward pass and, on the driver worker, sample. | ||||
|  | ||||
|         Arguments: | ||||
|  | ||||
|         * model_input: prepared decoder and encoder inputs | ||||
|         * kv_caches: not referenced in this method body | ||||
|         * intermediate_tensors: passed through to the model call | ||||
|         * num_steps: must be 1 | ||||
|  | ||||
|         Return: | ||||
|  | ||||
|         * `[output]` with the sampler output on the driver worker, or an | ||||
|           empty list on any other worker. | ||||
|  | ||||
|         Raises: | ||||
|  | ||||
|         * ValueError if `num_steps > 1`. | ||||
|         """ | ||||
|         if num_steps > 1: | ||||
|             raise ValueError( | ||||
|                 "CPU worker does not support multi-step execution.") | ||||
|  | ||||
|         model_executable = self.model | ||||
|         execute_model_kwargs = { | ||||
|             "input_ids": | ||||
|             model_input.input_tokens, | ||||
|             "positions": | ||||
|             model_input.input_positions, | ||||
|             "encoder_input_ids": | ||||
|             model_input.encoder_input_tokens, | ||||
|             "encoder_positions": | ||||
|             model_input.encoder_input_positions, | ||||
|             **MultiModalKwargs.as_kwargs( | ||||
|                 model_input.multi_modal_kwargs or {}, | ||||
|                 device=self.device, | ||||
|             ), | ||||
|             "intermediate_tensors": | ||||
|             intermediate_tensors, | ||||
|         } | ||||
|  | ||||
|         # The forward context is set only for the duration of the model call. | ||||
|         with set_forward_context(model_input.attn_metadata, self.vllm_config, | ||||
|                                  model_input.virtual_engine): | ||||
|             hidden_states = model_executable(**execute_model_kwargs) | ||||
|  | ||||
|         # Compute the logits. | ||||
|         logits = self.model.compute_logits(hidden_states, | ||||
|                                            model_input.sampling_metadata) | ||||
|  | ||||
|         # Only perform sampling in the driver worker. | ||||
|         if not self.is_driver_worker: | ||||
|             return [] | ||||
|  | ||||
|         # Sample the next token. | ||||
|         output = self.sampler( | ||||
|             logits=logits, | ||||
|             sampling_metadata=model_input.sampling_metadata, | ||||
|         ) | ||||
|         return [output] | ||||
							
								
								
									
										671
									
								
								vllm/worker/cpu_model_runner.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										671
									
								
								vllm/worker/cpu_model_runner.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,671 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
|  | ||||
| import dataclasses | ||||
| import weakref | ||||
| from collections import defaultdict | ||||
| from dataclasses import dataclass | ||||
| from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, | ||||
|                     TypeVar, Union) | ||||
|  | ||||
| import torch | ||||
| from torch import nn | ||||
|  | ||||
| from vllm.attention import AttentionMetadata, get_attn_backend | ||||
| from vllm.config import VllmConfig | ||||
| from vllm.forward_context import set_forward_context | ||||
| from vllm.logger import init_logger | ||||
| from vllm.lora.layers import LoRAMapping | ||||
| from vllm.lora.request import LoRARequest | ||||
| from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager | ||||
| from vllm.model_executor import SamplingMetadata | ||||
| from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding | ||||
| from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler | ||||
| from vllm.model_executor.model_loader import get_model | ||||
| from vllm.model_executor.models import supports_lora, supports_multimodal | ||||
| from vllm.multimodal import (BatchedTensorInputs, MultiModalKwargs, | ||||
|                              MultiModalPlaceholderMap) | ||||
| from vllm.sequence import (IntermediateTensors, SequenceData, | ||||
|                            SequenceGroupMetadata) | ||||
| from vllm.worker.model_runner_base import ( | ||||
|     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, | ||||
|     _add_attn_metadata_broadcastable_dict, | ||||
|     _add_sampling_metadata_broadcastable_dict, | ||||
|     _init_attn_metadata_from_tensor_dict, | ||||
|     _init_sampling_metadata_from_tensor_dict) | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from vllm.attention.backends.abstract import AttentionBackend | ||||
|  | ||||
| # Module-level logger for this file. | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
| # Type variable bound to ModelInputForCPU, so that classmethod constructors | ||||
| # can be annotated as returning the class they are invoked on. | ||||
| TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU") | ||||
| # NOTE(review): presumably the slot id used to mark padded positions in a | ||||
| # slot mapping; its use is not visible in this part of the file -- confirm. | ||||
| _PAD_SLOT_ID = -1 | ||||
|  | ||||
|  | ||||
| @dataclass(frozen=True) | ||||
| class ModelInputForCPU(ModelRunnerInputBase): | ||||
|     """ | ||||
|     Base class contains metadata needed for the base model forward pass on CPU | ||||
|     """ | ||||
|     input_tokens: Optional[torch.Tensor] = None | ||||
|     input_positions: Optional[torch.Tensor] = None | ||||
|     token_type_ids: Optional[torch.Tensor] = None | ||||
|     attn_metadata: Optional["AttentionMetadata"] = None | ||||
|     multi_modal_kwargs: Optional[BatchedTensorInputs] = None | ||||
|     virtual_engine: Optional[int] = None | ||||
|     seq_lens: Optional[List[int]] = None | ||||
|     query_lens: Optional[List[int]] = None | ||||
|     lora_mapping: Optional["LoRAMapping"] = None | ||||
|     lora_requests: Optional[Set[LoRARequest]] = None | ||||
|  | ||||
|     def as_broadcastable_tensor_dict( | ||||
|             self) -> Dict[str, Union[int, torch.Tensor]]: | ||||
|         tensor_dict = { | ||||
|             "input_tokens": self.input_tokens, | ||||
|             "input_positions": self.input_positions, | ||||
|             "token_type_ids": self.token_type_ids, | ||||
|             "multi_modal_kwargs": self.multi_modal_kwargs, | ||||
|             "lora_requests": self.lora_requests, | ||||
|             "lora_mapping": self.lora_mapping, | ||||
|         } | ||||
|         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) | ||||
|  | ||||
|         return tensor_dict | ||||
|  | ||||
|     @classmethod | ||||
|     def from_broadcasted_tensor_dict( | ||||
|         cls: Type[TModelInputForCPU], | ||||
|         tensor_dict: Dict[str, Any], | ||||
|         attn_backend: Optional["AttentionBackend"] = None | ||||
|     ) -> TModelInputForCPU: | ||||
|         if attn_backend is not None: | ||||
|             tensor_dict = _init_attn_metadata_from_tensor_dict( | ||||
|                 attn_backend, tensor_dict) | ||||
|         return cls(**tensor_dict) | ||||
|  | ||||
|  | ||||
| @dataclass(frozen=True) | ||||
| class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU): | ||||
|     """ | ||||
|     Used by the ModelRunner. | ||||
|     """ | ||||
|     sampling_metadata: Optional["SamplingMetadata"] = None | ||||
|     is_prompt: Optional[bool] = None | ||||
|  | ||||
|     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: | ||||
|         tensor_dict = { | ||||
|             "input_tokens": self.input_tokens, | ||||
|             "input_positions": self.input_positions, | ||||
|             "token_type_ids": self.token_type_ids, | ||||
|             "multi_modal_kwargs": self.multi_modal_kwargs, | ||||
|         } | ||||
|         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) | ||||
|         _add_sampling_metadata_broadcastable_dict(tensor_dict, | ||||
|                                                   self.sampling_metadata) | ||||
|         return tensor_dict | ||||
|  | ||||
|     @classmethod | ||||
|     def from_broadcasted_tensor_dict( | ||||
|         cls, | ||||
|         tensor_dict: Dict[str, Any], | ||||
|         attn_backend: Optional["AttentionBackend"] = None, | ||||
|     ) -> "ModelInputForCPUWithSamplingMetadata": | ||||
|         tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) | ||||
|         if attn_backend is not None: | ||||
|             tensor_dict = _init_attn_metadata_from_tensor_dict( | ||||
|                 attn_backend, tensor_dict) | ||||
|         return cls(**tensor_dict) | ||||
|  | ||||
|  | ||||
| class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): | ||||
|  | ||||
|     class ModelInputData: | ||||
|  | ||||
|         def __init__(self, use_mrope: bool): | ||||
|             self.use_mrope = use_mrope | ||||
|             self.input_tokens: List[int] = [] | ||||
|             self.input_positions: List[int] = [] | ||||
|             self.token_type_ids: Optional[List[int]] = [] | ||||
|             self.seq_lens: List[int] = [] | ||||
|             self.query_lens: List[int] = [] | ||||
|             self.prefill_block_tables: List[List[int]] = [] | ||||
|             self.decode_block_tables: List[List[int]] = [] | ||||
|             self.max_decode_seq_len: int = 0 | ||||
|             self.num_prefills: int = 0 | ||||
|             self.num_prefill_tokens: int = 0 | ||||
|             self.num_decode_tokens: int = 0 | ||||
|             self.slot_mapping: List[int] = [] | ||||
|             self.multi_modal_inputs_list: List[MultiModalKwargs] = [] | ||||
|             self.multi_modal_placeholder_maps: Dict[ | ||||
|                 str, MultiModalPlaceholderMap] = defaultdict( | ||||
|                     MultiModalPlaceholderMap) | ||||
|             self.input_mrope_positions: List[List[int]] = [[] | ||||
|                                                            for _ in range(3)] | ||||
|  | ||||
|     def __init__(self, | ||||
|                  runner: "CPUModelRunner", | ||||
|                  finished_requests_ids: Optional[List[str]] = None) -> None: | ||||
|         super().__init__() | ||||
|         self.runner = runner | ||||
|         self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled | ||||
|                                 or runner.cache_config.enable_prefix_caching) | ||||
|         self.model_input_cls = self.runner._model_input_cls | ||||
|         self.attn_backend = self.runner.attn_backend | ||||
|         self.sliding_window = self.runner.sliding_window | ||||
|         self.block_size = self.runner.block_size | ||||
|         self.device = self.runner.device | ||||
|         self.enable_lora = self.runner.lora_config is not None | ||||
|         if self.runner.attn_backend is not None: | ||||
|             # spec decode (e.g. Medusa) does not have atten backend | ||||
|             attn_backend = self.runner.attn_backend | ||||
|             self.att_metadata_builder = attn_backend.get_builder_cls()(self) | ||||
|  | ||||
|     def prepare(self, | ||||
|                 finished_requests_ids: Optional[List[str]] = None) -> None: | ||||
|         self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] | ||||
|         self.input_data = ModelInputForCPUBuilder.ModelInputData( | ||||
|             self.runner.model_config.uses_mrope) | ||||
|         self.att_metadata_builder.prepare() | ||||
|  | ||||
|     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): | ||||
|         """Append one sequence group to the list consumed by build().""" | ||||
|         self.seq_group_metadata_list.append(seq_group_metadata) | ||||
|  | ||||
|     def set_seq_group_list( | ||||
|             self, seq_group_metadata_list: List[SequenceGroupMetadata]): | ||||
|         """Replace the pending sequence groups with the given list. | ||||
|  | ||||
|         The list object is stored as-is (no copy is made). | ||||
|         """ | ||||
|         self.seq_group_metadata_list = seq_group_metadata_list | ||||
|  | ||||
|     def build(self) -> ModelInputForCPU: | ||||
|         self._build_input_data() | ||||
|  | ||||
|         input_data = self.input_data | ||||
|         input_tokens = torch.tensor(input_data.input_tokens, | ||||
|                                     dtype=torch.long, | ||||
|                                     device="cpu") | ||||
|         input_positions = torch.tensor( | ||||
|             input_data.input_positions | ||||
|             if not any(input_data.input_mrope_positions) else | ||||
|             input_data.input_mrope_positions, | ||||
|             dtype=torch.long, | ||||
|             device="cpu") | ||||
|         token_type_ids = torch.tensor(input_data.token_type_ids, | ||||
|                                     dtype=torch.long, | ||||
|                                     device="cpu") \ | ||||
|                                     if input_data.token_type_ids else None | ||||
|  | ||||
|         # For multi-modal models | ||||
|         multi_modal_kwargs = None | ||||
|         if len(input_data.multi_modal_inputs_list) != 0: | ||||
|             multi_modal_kwargs = MultiModalKwargs.batch( | ||||
|                 input_data.multi_modal_inputs_list) | ||||
|  | ||||
|         attn_metadata = self.att_metadata_builder.build( | ||||
|             input_data.seq_lens, input_data.query_lens, -1, -1) | ||||
|  | ||||
|         is_prompt = (self.seq_group_metadata_list[0].is_prompt | ||||
|                      if self.seq_group_metadata_list else None) | ||||
|         # LoRA data. | ||||
|         lora_requests = set() | ||||
|         lora_mapping = None | ||||
|         if self.enable_lora: | ||||
|             lora_requests = set(seq.lora_request | ||||
|                                 for seq in self.seq_group_metadata_list | ||||
|                                 if seq.lora_request is not None) | ||||
|  | ||||
|             lora_mapping = self._prepare_lora_input( | ||||
|                 self.seq_group_metadata_list, is_prompt) | ||||
|  | ||||
|         return self.model_input_cls(input_tokens=input_tokens, | ||||
|                                     input_positions=input_positions, | ||||
|                                     token_type_ids=token_type_ids, | ||||
|                                     seq_lens=input_data.seq_lens, | ||||
|                                     query_lens=input_data.query_lens, | ||||
|                                     attn_metadata=attn_metadata, | ||||
|                                     multi_modal_kwargs=multi_modal_kwargs, | ||||
|                                     lora_mapping=lora_mapping, | ||||
|                                     lora_requests=lora_requests) | ||||
|  | ||||
|     def _build_input_data(self): | ||||
|         for seq_group_metadata in self.seq_group_metadata_list: | ||||
|             for seq_id, seq_data in seq_group_metadata.seq_data.items(): | ||||
|                 if seq_group_metadata.is_prompt: | ||||
|                     self._compute_prompt_input_tokens(self.input_data, | ||||
|                                                       seq_group_metadata, | ||||
|                                                       seq_data, seq_id) | ||||
|                     if seq_group_metadata.multi_modal_data: | ||||
|                         self._compute_multi_modal_input( | ||||
|                             seq_group_metadata, seq_data) | ||||
|                 else: | ||||
|                     self._compute_decode_input_tokens(self.input_data, | ||||
|                                                       seq_group_metadata, | ||||
|                                                       seq_data, seq_id) | ||||
|  | ||||
|     def _compute_decode_input_tokens(self, data: ModelInputData, | ||||
|                                      seq_group_metadata: SequenceGroupMetadata, | ||||
|                                      seq_data: SequenceData, seq_id: int): | ||||
|         """ | ||||
|         Compute decode input tokens, positions, block table and slot mapping. | ||||
|         """ | ||||
|         block_size = self.runner.block_size | ||||
|  | ||||
|         block_table = seq_group_metadata.block_tables[seq_id] | ||||
|         seq_len = seq_data.get_len() | ||||
|         context_len = seq_data.get_num_computed_tokens() | ||||
|  | ||||
|         tokens = seq_data.get_last_token_id() | ||||
|         token_positions = seq_len - 1 | ||||
|         block_number = block_table[token_positions // block_size] | ||||
|         block_offset = token_positions % block_size | ||||
|         slot = block_number * block_size + block_offset | ||||
|  | ||||
|         # For paged_attention kernel | ||||
|         if self.runner.sliding_window: | ||||
|             start_idx = max(0, seq_len - self.runner.sliding_window) | ||||
|             start_block = start_idx // block_size | ||||
|             start_idx = start_block * block_size | ||||
|             seq_len = seq_len - start_idx | ||||
|             block_table = block_table[start_block:] | ||||
|  | ||||
|         # For MRotaryEmbedding | ||||
|         if seq_data.mrope_position_delta is not None: | ||||
|             next_pos = MRotaryEmbedding.get_next_input_positions( | ||||
|                 seq_data.mrope_position_delta, | ||||
|                 context_len, | ||||
|                 seq_len, | ||||
|             ) | ||||
|             for idx in range(3): | ||||
|                 data.input_mrope_positions[idx].extend(  # type: ignore | ||||
|                     next_pos[idx]) | ||||
|         else: | ||||
|             data.input_positions.append(token_positions)  # type: ignore | ||||
|  | ||||
|         # Update fields | ||||
|         data.input_tokens.append(tokens) | ||||
|         data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len) | ||||
|         data.num_decode_tokens += 1 | ||||
|         data.slot_mapping.append(slot) | ||||
|         data.decode_block_tables.append(block_table) | ||||
|         data.query_lens.append(1) | ||||
|         data.seq_lens.append(seq_len) | ||||
|  | ||||
|     def _compute_prompt_input_tokens(self, data: ModelInputData, | ||||
|                                      seq_group_metadata: SequenceGroupMetadata, | ||||
|                                      seq_data: SequenceData, seq_id: int): | ||||
|         """ | ||||
|         Compute prompt input tokens, positions, block table and slot mapping. | ||||
|         """ | ||||
|         token_chunk_size = seq_group_metadata.token_chunk_size | ||||
|         block_size = self.runner.block_size | ||||
|  | ||||
|         block_table = seq_group_metadata.block_tables[seq_id] | ||||
|         seq_len = seq_data.get_len() | ||||
|         context_len = seq_data.get_num_computed_tokens() | ||||
|         seq_len = min(seq_len, context_len + token_chunk_size) | ||||
|  | ||||
|         # For prefix caching | ||||
|         prefix_cache_block_num = len(seq_group_metadata.computed_block_nums) | ||||
|         if prefix_cache_block_num > 0: | ||||
|             prefix_cache_len = (prefix_cache_block_num * | ||||
|                                 self.runner.block_size) | ||||
|             if prefix_cache_len <= context_len: | ||||
|                 # We already passed the cache hit region, | ||||
|                 # so do normal computation. | ||||
|                 pass | ||||
|             elif context_len < prefix_cache_len < seq_len: | ||||
|                 # Partial hit. Compute the missing part. | ||||
|                 context_len = prefix_cache_len | ||||
|                 token_chunk_size = seq_len - context_len | ||||
|             elif seq_len <= prefix_cache_len: | ||||
|                 # Full hit. Only compute the last token to avoid | ||||
|                 # erroneous behavior. FIXME: Ideally we should directly | ||||
|                 # mark all tokens as computed in the scheduler and do not | ||||
|                 # schedule this sequence, so this case should not happen. | ||||
|                 context_len = seq_len - 1 | ||||
|                 token_chunk_size = 1 | ||||
|  | ||||
|         tokens = seq_data.get_token_ids() | ||||
|         tokens = tokens[context_len:seq_len] | ||||
|         token_positions = range(context_len, seq_len) | ||||
|         token_types = seq_group_metadata.token_type_ids | ||||
|  | ||||
|         # For encoder-only models, the block_table is None, | ||||
|         # and there is no need to initialize the slot_mapping. | ||||
|         if block_table is not None: | ||||
|             slot_mapping = [_PAD_SLOT_ID] * len(token_positions) | ||||
|             for i, pos in enumerate(token_positions): | ||||
|                 block_number = block_table[pos // block_size] | ||||
|                 block_offset = pos % block_size | ||||
|                 slot = block_number * block_size + block_offset | ||||
|                 slot_mapping[i] = slot | ||||
|             data.slot_mapping.extend(slot_mapping) | ||||
|  | ||||
|         # The MROPE positions are prepared in _compute_multi_modal_input | ||||
|         data.input_positions.extend(token_positions) | ||||
|  | ||||
|         if data.token_type_ids is not None: | ||||
|             data.token_type_ids.extend(token_types if token_types else []) | ||||
|  | ||||
|         # Update fields | ||||
|         data.input_tokens.extend(tokens) | ||||
|         data.num_prefills += 1 | ||||
|         data.num_prefill_tokens += len(tokens) | ||||
|         data.query_lens.append(len(tokens)) | ||||
|         data.prefill_block_tables.append(block_table) | ||||
|         data.seq_lens.append(seq_len) | ||||
|  | ||||
|     def _compute_multi_modal_input(self, | ||||
|                                    seq_group_metadata: SequenceGroupMetadata, | ||||
|                                    seq_data: SequenceData): | ||||
|         """Collect the multi-modal inputs for the most recently added prefill. | ||||
|  | ||||
|         Reads the last entry of ``self.input_data.seq_lens``, so it has to | ||||
|         run right after the prompt tokens of the same sequence were added. | ||||
|         When the model uses M-RoPE, the M-RoPE positions are appended to | ||||
|         ``self.input_data`` and the position delta is stored on ``seq_data``. | ||||
|         Returns early without changes when no multi-modal kwargs are found. | ||||
|         """ | ||||
|         computed_len = seq_data.get_num_computed_tokens() | ||||
|         seq_len = self.input_data.seq_lens[-1] | ||||
|  | ||||
|         # NOTE: mm_kwargs only includes the subset of multi-modal items that | ||||
|         # intersect with the current prefill positions. | ||||
|         mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( | ||||
|             seq_group_metadata, range(computed_len, seq_len)) | ||||
|  | ||||
|         if not mm_kwargs: | ||||
|             return | ||||
|  | ||||
|         # special processing for mrope position deltas. | ||||
|         if self.runner.model_config.uses_mrope: | ||||
|             assert not self.chunked_prefill, \ | ||||
|                 "MROPE on CPU does not support chunked-prefill." | ||||
|  | ||||
|             image_grid_thw = mm_kwargs.get("image_grid_thw", None) | ||||
|             video_grid_thw = mm_kwargs.get("video_grid_thw", None) | ||||
|             audio_feature_lengths = mm_kwargs.get("audio_feature_lengths", | ||||
|                                                   None) | ||||
|             # At least one of the three modality descriptors must be present. | ||||
|             assert ( | ||||
|                 image_grid_thw is not None or video_grid_thw is not None | ||||
|                 or audio_feature_lengths is not None), ( | ||||
|                     "mrope embedding type requires multi-modal input mapper " | ||||
|                     "returns 'image_grid_thw' or 'video_grid_thw' or " | ||||
|                     "'audio_feature_lengths'.") | ||||
|  | ||||
|             second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) | ||||
|             use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) | ||||
|             hf_config = self.runner.model_config.hf_config | ||||
|             token_ids = seq_data.get_token_ids() | ||||
|  | ||||
|             mrope_positions, mrope_position_delta = \ | ||||
|                 MRotaryEmbedding.get_input_positions( | ||||
|                     token_ids, | ||||
|                     hf_config=hf_config, | ||||
|                     image_grid_thw=image_grid_thw, | ||||
|                     video_grid_thw=video_grid_thw, | ||||
|                     second_per_grid_ts=second_per_grid_ts, | ||||
|                     context_len=computed_len, | ||||
|                     audio_feature_lengths=audio_feature_lengths, | ||||
|                     use_audio_in_video=use_audio_in_video, | ||||
|                 ) | ||||
|             # Stored on the sequence; the decode path reads it back to | ||||
|             # compute the next M-RoPE positions. | ||||
|             seq_data.mrope_position_delta = mrope_position_delta | ||||
|  | ||||
|             for i in range(3): | ||||
|                 self.input_data.input_mrope_positions[  # type: ignore | ||||
|                     i].extend(mrope_positions[i]) | ||||
|  | ||||
|         self.input_data.multi_modal_inputs_list.append(mm_kwargs) | ||||
|         # Merge this sequence's placeholder maps into the per-modality maps. | ||||
|         for modality, placeholder_map in placeholder_maps.items(): | ||||
|             self.input_data.multi_modal_placeholder_maps[modality].extend( | ||||
|                 placeholder_map) | ||||
|  | ||||
|     def _prepare_lora_input( | ||||
|             self, seq_group_metadata_list: List[SequenceGroupMetadata], | ||||
|             is_prefill: bool) -> LoRAMapping: | ||||
|         index_mapping = [] | ||||
|         prompt_mapping = [] | ||||
|         for seq in seq_group_metadata_list: | ||||
|             lora_id = seq.lora_int_id | ||||
|             query_len = seq.token_chunk_size | ||||
|  | ||||
|             index_mapping += [lora_id] * query_len | ||||
|             prompt_mapping += [lora_id] * ( | ||||
|                 query_len if seq.sampling_params | ||||
|                 and seq.sampling_params.prompt_logprobs is not None else 1) | ||||
|  | ||||
|         return LoRAMapping(index_mapping=tuple(index_mapping), | ||||
|                            prompt_mapping=tuple(prompt_mapping), | ||||
|                            is_prefill=is_prefill) | ||||
|  | ||||
|  | ||||
| class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): | ||||
|     """ | ||||
|     Helper class for shared methods between CPU model runners. | ||||
|     """ | ||||
|     _model_input_cls: Type[TModelInputForCPU] | ||||
|     _builder_cls: Type[ModelInputForCPUBuilder] | ||||
|     builder: ModelInputForCPUBuilder | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         vllm_config: VllmConfig, | ||||
|         kv_cache_dtype: Optional[str] = "auto", | ||||
|         is_driver_worker: bool = False, | ||||
|         return_hidden_states: bool = False, | ||||
|         *args, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         ModelRunnerBase.__init__(self, vllm_config) | ||||
|         model_config = self.model_config | ||||
|         cache_config = self.cache_config | ||||
|  | ||||
|         self.is_driver_worker = is_driver_worker | ||||
|         self.return_hidden_states = return_hidden_states | ||||
|  | ||||
|         self.device = self.device_config.device | ||||
|         self.pin_memory = False | ||||
|  | ||||
|         self.kv_cache_dtype = kv_cache_dtype | ||||
|         self.sliding_window = model_config.get_sliding_window() | ||||
|         self.block_size = cache_config.block_size | ||||
|         num_attn_heads = self.model_config.get_num_attention_heads( | ||||
|             self.parallel_config) | ||||
|         needs_attn_backend = (num_attn_heads != 0 | ||||
|                               or self.model_config.is_attention_free) | ||||
|         self.attn_backend = get_attn_backend( | ||||
|             self.model_config.get_head_size(), | ||||
|             self.model_config.dtype, | ||||
|             self.kv_cache_dtype, | ||||
|             self.block_size, | ||||
|             self.model_config.is_attention_free, | ||||
|             use_mla=self.model_config.use_mla, | ||||
|         ) if needs_attn_backend else None | ||||
|  | ||||
|         # Lazy initialization. | ||||
|         self.model: nn.Module  # Set after init_Model | ||||
|         # Set after load_model. | ||||
|         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None | ||||
|         self.sampler = get_sampler() | ||||
|  | ||||
|         if hasattr(self, "_builder_cls"): | ||||
|             # multi-step model runner does not have `_builder_cls` | ||||
|             self.builder = self._builder_cls(weakref.proxy(self)) | ||||
|  | ||||
|     def load_model(self) -> None: | ||||
|         self.model = get_model(vllm_config=self.vllm_config) | ||||
|  | ||||
|         if self.lora_config: | ||||
|             assert supports_lora( | ||||
|                 self.model | ||||
|             ), f"{self.model.__class__.__name__} does not support LoRA yet." | ||||
|  | ||||
|             if supports_multimodal(self.model): | ||||
|                 logger.warning("Regarding multimodal models, vLLM currently " | ||||
|                                "only supports adding LoRA to language model.") | ||||
|  | ||||
|             # Use get_text_config() in case of multimodal models | ||||
|             text_config = self.model_config.hf_config.get_text_config() | ||||
|  | ||||
|             self.lora_manager = LRUCacheWorkerLoRAManager( | ||||
|                 self.scheduler_config.max_num_seqs, | ||||
|                 self.scheduler_config.max_num_batched_tokens, | ||||
|                 self.vocab_size, | ||||
|                 self.lora_config, | ||||
|                 self.device, | ||||
|                 self.model.embedding_modules, | ||||
|                 self.model.embedding_padding_modules, | ||||
|                 max_position_embeddings=text_config.max_position_embeddings, | ||||
|             ) | ||||
|             self.model = self.lora_manager.create_lora_manager(self.model) | ||||
|  | ||||
|     def get_model(self) -> nn.Module: | ||||
|         """Return the model object stored by load_model().""" | ||||
|         return self.model | ||||
|  | ||||
|     def _prepare_model_input_tensors( | ||||
|         self, | ||||
|         seq_group_metadata_list: List[SequenceGroupMetadata], | ||||
|         finished_requests_ids: Optional[List[str]] = None | ||||
|     ) -> TModelInputForCPU: | ||||
|         """Helper method to prepare the model input based on a given sequence | ||||
|         group. Prepares metadata needed for the base model forward pass but not | ||||
|         metadata for possible additional steps, e.g., sampling. | ||||
|  | ||||
|         """ | ||||
|         self.builder.prepare(finished_requests_ids) | ||||
|         self.builder.set_seq_group_list(seq_group_metadata_list) | ||||
|  | ||||
|         return self.builder.build()  # type: ignore | ||||
|  | ||||
|     @property | ||||
|     def vocab_size(self) -> int: | ||||
|         """Vocabulary size as reported by the model config.""" | ||||
|         return self.model_config.get_vocab_size() | ||||
|  | ||||
|     def remove_all_loras(self): | ||||
|         if not self.lora_manager: | ||||
|             raise RuntimeError("LoRA is not enabled.") | ||||
|         self.lora_manager.remove_all_adapters() | ||||
|  | ||||
|     def set_active_loras(self, lora_requests: Set[LoRARequest], | ||||
|                          lora_mapping: LoRAMapping) -> None: | ||||
|         if not self.lora_manager: | ||||
|             raise RuntimeError("LoRA is not enabled.") | ||||
|         self.lora_manager.set_active_adapters(lora_requests, lora_mapping) | ||||
|  | ||||
|     def add_lora(self, lora_request: LoRARequest) -> bool: | ||||
|         if not self.lora_manager: | ||||
|             raise RuntimeError("LoRA is not enabled.") | ||||
|         return self.lora_manager.add_adapter(lora_request) | ||||
|  | ||||
|     def remove_lora(self, lora_id: int) -> bool: | ||||
|         if not self.lora_manager: | ||||
|             raise RuntimeError("LoRA is not enabled.") | ||||
|         return self.lora_manager.remove_adapter(lora_id) | ||||
|  | ||||
|     def pin_lora(self, lora_id: int) -> bool: | ||||
|         if not self.lora_manager: | ||||
|             raise RuntimeError("LoRA is not enabled.") | ||||
|         return self.lora_manager.pin_adapter(lora_id) | ||||
|  | ||||
|     def list_loras(self) -> Set[int]: | ||||
|         if not self.lora_manager: | ||||
|             raise RuntimeError("LoRA is not enabled.") | ||||
|         return self.lora_manager.list_adapters() | ||||
|  | ||||
|  | ||||
| class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): | ||||
|     _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( | ||||
|         ModelInputForCPUWithSamplingMetadata) | ||||
|     _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder | ||||
|  | ||||
|     def make_model_input_from_broadcasted_tensor_dict( | ||||
|         self, | ||||
|         tensor_dict: Dict[str, Any], | ||||
|     ) -> ModelInputForCPUWithSamplingMetadata: | ||||
|         """Rebuild a model input from a broadcasted dict. | ||||
|  | ||||
|         The runner's attention backend is passed along so that the | ||||
|         attention metadata can be restored from the dict. | ||||
|         """ | ||||
|         return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict(  # noqa: E501 | ||||
|             tensor_dict, | ||||
|             attn_backend=self.attn_backend, | ||||
|         ) | ||||
|  | ||||
|     def prepare_model_input( | ||||
|         self, | ||||
|         seq_group_metadata_list: List[SequenceGroupMetadata], | ||||
|         virtual_engine: int = 0, | ||||
|         finished_requests_ids: Optional[List[str]] = None | ||||
|     ) -> ModelInputForCPUWithSamplingMetadata: | ||||
|         """Prepare the model input based on a given sequence group, including | ||||
|         metadata for the sampling step. | ||||
|  | ||||
|         """ | ||||
|         model_input = self._prepare_model_input_tensors( | ||||
|             seq_group_metadata_list, finished_requests_ids) | ||||
|         # Sampling metadata is only required for the final pp group | ||||
|         generators = self.get_generators(finished_requests_ids) | ||||
|         sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, | ||||
|                                                      model_input.seq_lens, | ||||
|                                                      model_input.query_lens, | ||||
|                                                      self.device, | ||||
|                                                      pin_memory=False, | ||||
|                                                      generators=generators) | ||||
|  | ||||
|         is_prompt = (seq_group_metadata_list[0].is_prompt | ||||
|                      if seq_group_metadata_list else None) | ||||
|         return dataclasses.replace(model_input, | ||||
|                                    sampling_metadata=sampling_metadata, | ||||
|                                    virtual_engine=virtual_engine, | ||||
|                                    is_prompt=is_prompt) | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def execute_model( | ||||
|         self, | ||||
|         model_input: ModelInputForCPUWithSamplingMetadata, | ||||
|         kv_caches: List[torch.Tensor], | ||||
|         intermediate_tensors: Optional[IntermediateTensors] = None, | ||||
|         num_steps: int = 1, | ||||
|         previous_hidden_states: Optional[torch.Tensor] = None, | ||||
|     ) -> Optional[List[SamplerOutput]]: | ||||
|         """Run one forward pass and, on the driver worker, sample a token. | ||||
|  | ||||
|         Args: | ||||
|             model_input: The prepared model input. | ||||
|             kv_caches: Not referenced in this method body. | ||||
|             intermediate_tensors: Forwarded to the model call. | ||||
|             num_steps: Must be 1. | ||||
|             previous_hidden_states: Forwarded to the model call when set. | ||||
|  | ||||
|         Returns: | ||||
|             A one-element list with the sampler output on the driver | ||||
|             worker, and an empty list on every other worker. | ||||
|  | ||||
|         Raises: | ||||
|             ValueError: If ``num_steps`` is greater than 1. | ||||
|         """ | ||||
|         if num_steps > 1: | ||||
|             raise ValueError( | ||||
|                 "CPU worker does not support multi-step execution.") | ||||
|  | ||||
|         # With LoRA configured, the adapters for this batch are activated | ||||
|         # before the forward pass. | ||||
|         if self.lora_config: | ||||
|             assert model_input.lora_requests is not None | ||||
|             assert model_input.lora_mapping is not None | ||||
|             self.set_active_loras(model_input.lora_requests, | ||||
|                                   model_input.lora_mapping) | ||||
|  | ||||
|         model_executable = self.model | ||||
|  | ||||
|         multimodal_kwargs = {} | ||||
|         if model_input.multi_modal_kwargs is not None: | ||||
|             multimodal_kwargs = MultiModalKwargs.as_kwargs( | ||||
|                 model_input.multi_modal_kwargs, | ||||
|                 device=self.device, | ||||
|             ) | ||||
|         execute_model_kwargs = {} | ||||
|         if previous_hidden_states is not None: | ||||
|             execute_model_kwargs.update( | ||||
|                 {"previous_hidden_states": previous_hidden_states}) | ||||
|  | ||||
|         with set_forward_context(model_input.attn_metadata, self.vllm_config, | ||||
|                                  model_input.virtual_engine): | ||||
|             hidden_states = model_executable( | ||||
|                 input_ids=model_input.input_tokens, | ||||
|                 positions=model_input.input_positions, | ||||
|                 intermediate_tensors=intermediate_tensors, | ||||
|                 **execute_model_kwargs, | ||||
|                 **multimodal_kwargs, | ||||
|             ) | ||||
|  | ||||
|         # Compute the logits. | ||||
|         # NOTE(review): this runs on every worker, before the driver check | ||||
|         # below -- confirm whether that is required before moving it. | ||||
|         logits = self.model.compute_logits(hidden_states, | ||||
|                                            model_input.sampling_metadata) | ||||
|  | ||||
|         # Only perform sampling in the driver worker. | ||||
|         if not self.is_driver_worker: | ||||
|             return [] | ||||
|  | ||||
|         # Sample the next token. | ||||
|         output = self.sampler( | ||||
|             logits=logits, | ||||
|             sampling_metadata=model_input.sampling_metadata, | ||||
|         ) | ||||
|         if self.return_hidden_states: | ||||
|             # we only need to pass hidden states of most recent token | ||||
|             if model_input.is_prompt: | ||||
|                 output.prefill_hidden_states = hidden_states | ||||
|             output.hidden_states = hidden_states | ||||
|         return [output] | ||||
|  | ||||
|     def generate_proposals(self, *args, **kwargs): | ||||
|         """Forward all arguments to the model's generate_proposals().""" | ||||
|         return self.model.generate_proposals(*args, **kwargs) | ||||
							
								
								
									
										125
									
								
								vllm/worker/cpu_pooling_model_runner.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								vllm/worker/cpu_pooling_model_runner.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,125 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
|  | ||||
| import dataclasses | ||||
| from typing import Any, Dict, List, Optional, Tuple, Type, Union | ||||
|  | ||||
| import torch | ||||
|  | ||||
| from vllm.forward_context import set_forward_context | ||||
| from vllm.model_executor.pooling_metadata import PoolingMetadata | ||||
| from vllm.multimodal import MultiModalKwargs | ||||
| from vllm.pooling_params import PoolingParams | ||||
| from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, | ||||
|                            SequenceGroupMetadata) | ||||
| from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU, | ||||
|                                           ModelInputForCPUBuilder) | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass(frozen=True) | ||||
| class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU): | ||||
|     """ | ||||
|     Used by the CPUPoolingModelRunner. | ||||
|  | ||||
|     Extends ModelInputForCPU with the metadata for the pooling step. | ||||
|     """ | ||||
|     # Handed to the model's pooler in CPUPoolingModelRunner.execute_model. | ||||
|     pooling_metadata: Optional["PoolingMetadata"] = None | ||||
|  | ||||
|  | ||||
class CPUPoolingModelRunner(
        CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
    """CPU model runner for pooling models.

    Runs the model forward pass and, on the driver worker, passes the
    resulting hidden states through the model's pooler.
    """
    _model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
        ModelInputForCPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
        """Run one forward pass and pool the hidden states.

        Args:
            model_input: Prepared tensors, attention metadata and pooling
                metadata for this step.
            kv_caches: Not referenced by this method.
            intermediate_tensors: Forwarded to the model unchanged.
            num_steps: Must be 1.

        Returns:
            A one-element list holding the pooler output on the driver
            worker, or an empty list on any other worker.

        Raises:
            ValueError: If `num_steps` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        # Assemble the forward kwargs in the same precedence order as a
        # single dict literal would: base inputs, multi-modal kwargs,
        # optional token type ids, then intermediate tensors.
        forward_kwargs: Dict[str, Any] = {
            "input_ids": model_input.input_tokens,
            "positions": model_input.input_positions,
        }
        forward_kwargs.update(
            MultiModalKwargs.as_kwargs(
                model_input.multi_modal_kwargs or {},
                device=self.device,
            ))
        if model_input.token_type_ids is not None:
            forward_kwargs["token_type_ids"] = model_input.token_type_ids
        forward_kwargs["intermediate_tensors"] = intermediate_tensors

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
            hidden_states = self.model(**forward_kwargs)

        # Only perform pooling in the driver worker.
        if self.is_driver_worker:
            pooled = self.model.pooler(
                hidden_states=hidden_states,
                pooling_metadata=model_input.pooling_metadata)
            return [pooled]
        return []

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForCPUWithPoolingMetadata:
        """Rebuild a model input from a broadcast tensor dict."""
        input_cls = ModelInputForCPUWithPoolingMetadata
        return input_cls.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithPoolingMetadata:
        """Build the model input, including pooling metadata.

        Args:
            seq_group_metadata_list: Scheduled sequence groups; must not be
                None.
            virtual_engine: Stored on the returned model input.
            finished_requests_ids: Forwarded to
                `_prepare_model_input_tensors`.

        Returns:
            The prepared model input with `virtual_engine` and
            `pooling_metadata` filled in.
        """
        assert seq_group_metadata_list is not None
        base_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)

        # Prepare PoolingMetadata.
        assert base_input.seq_lens is not None
        metadata = self._prepare_pooling(seq_group_metadata_list,
                                         base_input.seq_lens)

        return dataclasses.replace(base_input,
                                   virtual_engine=virtual_engine,
                                   pooling_metadata=metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list."""
        # One (sequence ids, pooling params) pair per sequence group.
        seq_groups: List[Tuple[List[int], PoolingParams]] = [
            (list(group.seq_data.keys()), group.pooling_params)
            for group in seq_group_metadata_list
        ]

        # Merge the per-group sequence data into a single mapping.
        merged_seq_data: Dict[int, SequenceData] = {}
        for group in seq_group_metadata_list:
            merged_seq_data.update(group.seq_data)

        return PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=merged_seq_data,
            prompt_lens=prompt_lens,
        )
							
								
								
									
										452
									
								
								vllm/worker/cpu_worker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										452
									
								
								vllm/worker/cpu_worker.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,452 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
| """A CPU worker class.""" | ||||
| import os | ||||
| from importlib import util | ||||
| from typing import List, Optional, Set, Tuple, Type | ||||
|  | ||||
| import torch | ||||
| import torch.distributed | ||||
|  | ||||
| import vllm.envs as envs | ||||
| from vllm.attention import get_attn_backend | ||||
| from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, | ||||
|                          ParallelConfig, VllmConfig) | ||||
| from vllm.distributed import (ensure_model_parallel_initialized, | ||||
|                               init_distributed_environment) | ||||
| from vllm.logger import init_logger | ||||
| from vllm.lora.request import LoRARequest | ||||
| from vllm.model_executor import set_random_seed | ||||
| from vllm.sequence import ExecuteModelRequest | ||||
| from vllm.utils import bind_kv_cache | ||||
| from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner | ||||
| from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase | ||||
| from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner | ||||
| from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, | ||||
|                                      WorkerInput) | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
|  | ||||
| class CPUCacheEngine: | ||||
|     """Manages the KV cache for CPU backend. | ||||
|  | ||||
|     This class is responsible for initializing and managing CPU KV | ||||
|     caches. It also provides methods for performing KV cache operations, such | ||||
|     as copying. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, | ||||
|                  parallel_config: ParallelConfig, | ||||
|                  device_config: DeviceConfig) -> None: | ||||
|         assert device_config.device_type == "cpu" | ||||
|         self.cache_config = cache_config | ||||
|         self.model_config = model_config | ||||
|         self.parallel_config = parallel_config | ||||
|  | ||||
|         self.head_size = model_config.get_head_size() | ||||
|         self.num_layers = model_config.get_num_layers(parallel_config) | ||||
|         self.num_heads = model_config.get_num_kv_heads(parallel_config) | ||||
|  | ||||
|         self.block_size = cache_config.block_size | ||||
|         # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks | ||||
|         # for CPU backend, because we want to reuse KV cache management | ||||
|         # in the scheduler. | ||||
|         self.num_cpu_blocks = cache_config.num_gpu_blocks | ||||
|  | ||||
|         self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, | ||||
|                                                        model_config) | ||||
|  | ||||
|         # Get attention backend. | ||||
|         self.attn_backend = get_attn_backend( | ||||
|             self.model_config.get_head_size(), | ||||
|             self.model_config.dtype, | ||||
|             cache_config.cache_dtype, | ||||
|             self.block_size, | ||||
|             self.model_config.is_attention_free, | ||||
|             use_mla=self.model_config.use_mla, | ||||
|         ) | ||||
|  | ||||
|         # Initialize the cache. | ||||
|         self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) | ||||
|  | ||||
|     def _allocate_kv_cache( | ||||
|         self, | ||||
|         num_blocks: int, | ||||
|     ) -> List[torch.Tensor]: | ||||
|         """Allocates KV cache on CPU.""" | ||||
|         kv_cache_shape = self.attn_backend.get_kv_cache_shape( | ||||
|             num_blocks, self.block_size, self.num_heads, self.head_size) | ||||
|         kv_cache: List[torch.Tensor] = [] | ||||
|         for _ in range(self.num_layers): | ||||
|             kv_cache.append( | ||||
|                 torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu")) | ||||
|         return kv_cache | ||||
|  | ||||
|     def swap_in(self, src_to_dst: torch.Tensor) -> None: | ||||
|         raise NotImplementedError("Swap is not supported in CPUCacheEngine.") | ||||
|  | ||||
|     def swap_out(self, src_to_dst: torch.Tensor) -> None: | ||||
|         raise NotImplementedError("Swap is not supported in CPUCacheEngine.") | ||||
|  | ||||
|     def copy(self, src_to_dsts: torch.Tensor) -> None: | ||||
|         self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts) | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_kv_cache_dtype(cache_config: CacheConfig, | ||||
|                            model_config: ModelConfig): | ||||
|         if cache_config.cache_dtype == "auto": | ||||
|             return model_config.dtype | ||||
|         elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]: | ||||
|             return torch.float8_e5m2 | ||||
|         else: | ||||
|             raise NotImplementedError(f"Unsupported KV cache type " | ||||
|                                       f"{cache_config.cache_dtype}.") | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_cache_block_size( | ||||
|         cache_config: CacheConfig, | ||||
|         model_config: ModelConfig, | ||||
|         parallel_config: ParallelConfig, | ||||
|     ) -> int: | ||||
|         head_size = model_config.get_head_size() | ||||
|         num_heads = model_config.get_num_kv_heads(parallel_config) | ||||
|         num_layers = model_config.get_num_layers(parallel_config) | ||||
|  | ||||
|         key_cache_block = cache_config.block_size * num_heads * head_size | ||||
|         value_cache_block = key_cache_block if not model_config.use_mla else 0 | ||||
|         total = num_layers * (key_cache_block + value_cache_block) | ||||
|         dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config) | ||||
|         dtype_size = torch.tensor([], dtype=dtype).element_size() | ||||
|         return dtype_size * total | ||||
|  | ||||
|  | ||||
class CPUWorker(LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a CPU socket.

    Each worker is associated with a single CPU socket. The worker is
    responsible for maintaining the KV cache and executing the model on the
    CPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[CPUModelRunner]] = None,
    ) -> None:
        """Set up thread binding, the model runner and the profiler.

        Args:
            vllm_config: Engine configuration; its `parallel_config.rank` is
                overwritten with `rank`.
            local_rank: Stored on the worker.
            rank: Rank of this worker; the driver worker must have rank 0.
            distributed_init_method: Stored and later used by `init_device`.
            kv_cache_dtype: Forwarded to the model runner.
            is_driver_worker: Whether this worker is the driver.
            model_runner_cls: Optional wrapper class applied to the created
                model runner.

        Raises:
            ValueError: If `VLLM_CPU_OMP_THREADS_BIND` is an explicit
                '|'-separated list with no entry for this rank.
        """
        WorkerBase.__init__(self, vllm_config=vllm_config)

        self.local_rank = local_rank
        self.rank = rank
        vllm_config.parallel_config.rank = rank

        self.distributed_init_method = distributed_init_method

        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Setup OpenMP threads affinity. "all" means no binding (see
        # `init_device`), "auto" derives the binding from the NUMA layout,
        # anything else is a '|'-separated list with one entry per rank.
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        self.local_omp_cpuid = "all"
        if omp_cpuids == "auto":
            self.local_omp_cpuid = (
                self.get_cpus_id_binding_based_on_numa_nodes())
        elif omp_cpuids != "all":
            per_rank_cpuids = omp_cpuids.split("|")
            if rank >= len(per_rank_cpuids):
                # Fail with an actionable message instead of a bare
                # IndexError when the list is shorter than the world size.
                raise ValueError(
                    "VLLM_CPU_OMP_THREADS_BIND provides "
                    f"{len(per_rank_cpuids)} '|'-separated CPU list(s), "
                    f"but rank {rank} requires at least {rank + 1}.")
            self.local_omp_cpuid = per_rank_cpuids[rank]

        # Return hidden states from target model if the draft model is an
        # mlp_speculator
        speculative_config = self.speculative_config
        model_config = self.model_config
        speculative_args = {}
        if (speculative_config is not None
                and speculative_config.draft_model_config.model
                != model_config.model
                and speculative_config.draft_model_config.hf_config.model_type
                in ["medusa", "mlp_speculator", "eagle"]):
            speculative_args = {"return_hidden_states": True}

        ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
        if self.model_config.runner_type == "pooling":
            ModelRunnerClass = CPUPoolingModelRunner
        elif self.model_config.is_encoder_decoder:
            ModelRunnerClass = CPUEncoderDecoderModelRunner
        self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
            vllm_config=vllm_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            **speculative_args,
        )
        if model_runner_cls is not None:
            self.model_runner = model_runner_cls(self.model_runner)
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CPUCacheEngine]
        # Initialize cpu_cache as pooling models don't initialize kv_caches
        self.cpu_cache: Optional[List[List[torch.Tensor]]] = None

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def start_profile(self):
        """Start the torch profiler.

        Raises:
            RuntimeError: If profiling was not enabled at construction.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self):
        """Stop the torch profiler.

        Raises:
            RuntimeError: If profiling was not enabled at construction.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()

    def init_device(self) -> None:
        """Bind threads, set up the distributed environment and seed RNGs."""
        # "all" means no explicit binding was requested or derived.
        if self.local_omp_cpuid != "all":
            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
            if ret:
                logger.info(ret)

        # Note: unique identifier for creating allreduce shared memory
        os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
            ":")[-1]
        self.device = torch.device("cpu")
        self.init_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load the model through the model runner."""
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of blocks available for the KV cache.

        This determines how many KV blocks can fit into the configured CPU
        KV cache space.

        Note that since vLLM assumes a block resides on GPU if it can be
        modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
        This allows us to reuse the scheduler of vLLM without generalizing it
        to different devices.
        """
        # For CPU device, the block number will be calculated based on the
        # cpu_kvcache_space.
        cache_block_size = self.get_cache_block_size_bytes()
        num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
                             cache_block_size)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_gpu_blocks = num_cpu_blocks
        num_cpu_blocks = 0
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache. Currently, swappable CPU memory is not
        supported.

        Since this worker does not support GPUs, we use the num_gpu_blocks to
        determine how many non-swappable CPU blocks to allocate.
        """
        assert (num_cpu_blocks == 0
                ), f"{type(self)} does not support swappable cache"

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_cpu_blocks = num_gpu_blocks

        self._validate_num_cpu_blocks(num_cpu_blocks)
        self.cache_config.num_gpu_blocks = num_cpu_blocks
        self.cache_config.num_cpu_blocks = 0

        # Initialize the cache.
        self._init_cache_engine()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the model runner's `add_lora`."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner's `remove_lora`."""
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner's `pin_lora`."""
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Delegate to the model runner's `list_loras`."""
        return self.model_runner.list_loras()

    def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
        """Raise errors if the num_cpu_blocks is invalid.

        Raises:
            ValueError: If there are no blocks, or if the blocks cannot hold
                a sequence of the model's maximum length.
        """
        if num_cpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
                             "initializing the engine.")

        max_seq_len = self.cache_config.block_size * num_cpu_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
                "initializing the engine.")

    def _init_cache_engine(self) -> None:
        """Create one cache engine per pipeline stage and bind the caches."""
        self.cache_engine = [
            CPUCacheEngine(self.cache_config, self.model_config,
                           self.parallel_config, self.device_config)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.cpu_cache = [
            self.cache_engine[ve].cpu_cache
            for ve in range(self.parallel_config.pipeline_parallel_size)
        ]
        bind_kv_cache(self.compilation_config.static_forward_context,
                      self.cpu_cache)
        self.model_runner.block_size = self.cache_engine[0].block_size

        assert all(
            self.cpu_cache[ve] is not None
            for ve in range(self.parallel_config.pipeline_parallel_size))

        # Populate the cache to warmup the memory
        for ve in range(self.parallel_config.pipeline_parallel_size):
            for layer_cache in self.cpu_cache[ve]:
                layer_cache.fill_(0)

    @property
    def do_metadata_broadcast(self) -> bool:
        """Whether tensor parallelism is in use (size greater than 1)."""
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """The CPU KV caches, or None before `initialize_cache` runs."""
        return self.cpu_cache

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model runner."""
        return self.model_runner.vocab_size

    @property
    def max_model_len(self) -> int:
        """Maximum model length from the model config."""
        return self.model_config.max_model_len

    def execute_worker(
        self,
        worker_input: WorkerInput,
    ) -> None:
        """Perform the pending block copies for this step, if any."""
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine[worker_input.virtual_engine].copy(
                worker_input.blocks_to_copy)

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Translate an execute request into a `WorkerInput`.

        The blocks to copy are packed into an int64 CPU tensor with two
        columns. Swap-in and swap-out requests must be empty.
        """
        assert execute_model_req is not None
        virtual_engine: int = execute_model_req.virtual_engine
        num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device="cpu",
                                      dtype=torch.int64).view(-1, 2)
        assert len(execute_model_req.blocks_to_swap_in) == 0
        assert len(execute_model_req.blocks_to_swap_out) == 0
        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
        )

    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""

        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )

        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cpu())

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block.
        """
        return CPUCacheEngine.get_cache_block_size(self.cache_config,
                                                   self.model_config,
                                                   self.parallel_config)

    def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
        """Return CPUs id binding based on NUMA nodes.

        Returns:
            A comma-separated list of CPU ids for this rank, or the current
            value of `self.local_omp_cpuid` when no binding can be derived
            (the `numa` or `psutil` package is missing, or the world size
            exceeds the number of usable NUMA nodes).
        """
        rank_to_cpus = self.local_omp_cpuid
        # Setup OpenMP thread affinity based on NUMA nodes automatically
        world_size = self.vllm_config.parallel_config.world_size
        libnuma_found = util.find_spec("numa") is not None
        psutil_found = util.find_spec("psutil") is not None
        if libnuma_found and psutil_found:
            import psutil
            from numa import info
            cpu_count = psutil.cpu_count(logical=False)
            cpus_allow_list = psutil.Process().cpu_affinity()
            numa_size = info.get_num_configured_nodes()
            cpu_count_per_numa = cpu_count // numa_size
            num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                                      cpu_count_per_numa // 2)

            # check allow node_to_cpus list
            node_to_cpus = []
            for i in range(numa_size):
                node_intersect = set(
                    info.node_to_cpus(i)).intersection(cpus_allow_list)
                if node_intersect:
                    # Sort so the slice below picks a deterministic set of
                    # CPUs; set iteration order is not guaranteed.
                    node_to_cpus.append(sorted(node_intersect))

            if world_size > len(node_to_cpus):
                logger.error(
                    "Auto thread-binding failed due to "
                    "world size: %d is larger than "
                    "allowed NUMA nodes number: %d. "
                    "Please try to bind threads manually.", world_size,
                    len(node_to_cpus))
            else:
                # Leave the reserved CPUs of the node unbound.
                end = cpu_count_per_numa - num_of_reserved_cpu
                rank_to_cpus_list = node_to_cpus[self.rank][:end]
                rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
                logger.info("auto thread-binding list: %s", rank_to_cpus)
        else:
            logger.warning(
                "Auto thread-binding is not supported due to "
                "the lack of package numa and psutil, "
                "fallback to no thread-binding. To get better performance, "
                "please try to manually bind threads.")
        return rank_to_cpus
							
								
								
									
										108
									
								
								vllm/worker/multi_step_tpu_worker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								vllm/worker/multi_step_tpu_worker.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,108 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
|  | ||||
| import dataclasses | ||||
| from typing import Dict, Optional, Tuple | ||||
|  | ||||
| import torch | ||||
|  | ||||
| from vllm.distributed import broadcast_tensor_dict | ||||
| from vllm.sequence import ExecuteModelRequest | ||||
| from vllm.worker.tpu_model_runner import ModelInputForTPU | ||||
| from vllm.worker.tpu_worker import TPUWorker | ||||
| from vllm.worker.worker_base import WorkerInput | ||||
|  | ||||
|  | ||||
class MultiStepTPUWorker(TPUWorker):
    """TPU worker that caches the model input across multi-step execution.

    The full model input is prepared (and broadcast) only on the first step
    of a multi-step run; later steps reuse the cached input and update only
    the `is_first_multi_step` / `is_last_step` flags.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Model input of the current multi-step run; None until the first
        # step has been prepared or received.
        self.cached_model_input: Optional[ModelInputForTPU] = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]:
        """Build the driver's inputs for this step and broadcast them.

        On the first step the worker and model inputs are prepared from the
        request and broadcast in full; on later steps the cached model input
        is reused and only the two step flags are broadcast.
        """
        assert self.is_driver_worker
        assert execute_model_req.virtual_engine == 0

        first_step = execute_model_req.is_first_multi_step
        last_step = execute_model_req.is_last_step

        model_input: ModelInputForTPU
        worker_input: WorkerInput
        if not first_step:
            assert self.cached_model_input is not None
            model_input = self.cached_model_input
            worker_input = WorkerInput()
        else:
            worker_input = self.prepare_worker_input(
                execute_model_req=execute_model_req)
            worker_input = dataclasses.replace(
                worker_input,
                num_steps=execute_model_req.num_lookahead_slots + 1)
            model_input = self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids)

            if execute_model_req.async_callback:
                model_input = dataclasses.replace(
                    model_input,
                    async_callback=execute_model_req.async_callback)

        model_input = dataclasses.replace(model_input,
                                          is_first_multi_step=first_step,
                                          is_last_step=last_step)

        if self.do_metadata_broadcast:
            if first_step:
                payload = worker_input.as_broadcastable_tensor_dict()
                payload.update(model_input.as_broadcastable_tensor_dict())
            else:
                # Receivers recognize this two-key payload as a step update.
                payload = {
                    "is_first_multi_step": first_step,
                    "is_last_step": last_step,
                }
            broadcast_tensor_dict(payload, src=0)

        # Returning empty dict here to keep this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
        return model_input, worker_input, {}

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str,
                                                            torch.Tensor]]]:
        """Prepare the inputs for one step on a driver or non-driver worker.

        Returns:
            `(model_input, worker_input, {})`, or None when the driver was
            given no request or a non-driver received an empty broadcast.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    broadcast_tensor_dict({}, src=0)
                return None

            model_input, worker_input, _ = (
                self._get_driver_input_and_broadcast(execute_model_req))
            if model_input.is_first_multi_step:
                self.cached_model_input = model_input
            return model_input, worker_input, {}

        received = broadcast_tensor_dict(src=0)
        if not received:
            return None

        if len(received) != 2:
            # Full payload: rebuild both inputs and cache the model input.
            worker_input = WorkerInput.from_broadcasted_tensor_dict(received)
            model_input = (
                self.model_runner.
                make_model_input_from_broadcasted_tensor_dict(received))
            self.cached_model_input = model_input
            return model_input, worker_input, {}

        # Two-key payload: only the step flags changed.
        assert self.cached_model_input is not None
        self.cached_model_input = dataclasses.replace(
            self.cached_model_input,
            is_first_multi_step=received["is_first_multi_step"],
            is_last_step=received["is_last_step"])
        return self.cached_model_input, WorkerInput(), {}
							
								
								
									
										909
									
								
								vllm/worker/tpu_model_runner.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										909
									
								
								vllm/worker/tpu_model_runner.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,909 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
|  | ||||
| import enum | ||||
| import time | ||||
| from dataclasses import dataclass | ||||
| from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, | ||||
|                     Type, Union) | ||||
| from unittest.mock import patch | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch_xla.core.xla_model as xm | ||||
| import torch_xla.runtime as xr | ||||
|  | ||||
| from vllm.attention import AttentionMetadata, get_attn_backend | ||||
| from vllm.config import VllmConfig | ||||
| from vllm.forward_context import get_forward_context, set_forward_context | ||||
| from vllm.logger import init_logger | ||||
| from vllm.model_executor.layers.sampler import SamplerOutput | ||||
| from vllm.model_executor.model_loader import get_model | ||||
| from vllm.model_executor.sampling_metadata import SamplingMetadata | ||||
| from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, | ||||
|                            Logprob, SequenceGroupMetadata, SequenceOutput) | ||||
| from vllm.worker.model_runner_base import ( | ||||
|     ModelRunnerBase, ModelRunnerInputBase, | ||||
|     _add_attn_metadata_broadcastable_dict, | ||||
|     _init_attn_metadata_from_tensor_dict) | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from vllm.attention.backends.abstract import AttentionBackend | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
| # Here we utilize the behavior that out-of-bound index is ignored. | ||||
| # FIXME(woosuk): Find a more reliable way to prevent possible bugs. | ||||
| _PAD_SLOT_ID = 1_000_000_000 | ||||
| # FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. | ||||
| _ENABLE_TOP_P = False | ||||
| # FIXME(woosuk): A temporary hack to support `n > 1`. | ||||
| # This can significantly affect the performance if too large. | ||||
| _MAX_NUM_SAMPLES = 128 | ||||
|  | ||||
|  | ||||
| class ExecutionMode(enum.Enum): | ||||
|     """Kind of forward pass the runner prepares inputs for.""" | ||||
|     PREFILL = enum.auto() | ||||
|     DECODE = enum.auto() | ||||
|     PREFIX_PREFILL = enum.auto() | ||||
|  | ||||
|     def is_prefill(self) -> bool: | ||||
|         """Return True for `PREFILL` and `PREFIX_PREFILL`, False for `DECODE`.""" | ||||
|         return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) | ||||
|  | ||||
|  | ||||
| @dataclass(frozen=True) | ||||
| class ModelInputForTPU(ModelRunnerInputBase): | ||||
|     """Immutable bundle of per-step inputs produced by `TPUModelRunner`.""" | ||||
|     # Input token ids and positions, as built by `_prepare_prompt` / | ||||
|     # `_prepare_decode`. | ||||
|     token_ids: torch.Tensor | ||||
|     position_ids: torch.Tensor | ||||
|     attn_metadata: AttentionMetadata | ||||
|     # Prompt lengths for prefill; all ones for decode. | ||||
|     input_lens: torch.Tensor | ||||
|     # Sampling temperature (`t`) and top-p (`p`) values, padded with 1.0. | ||||
|     t: torch.Tensor | ||||
|     p: torch.Tensor | ||||
|     # `_MAX_NUM_SAMPLES` for prompts, 1 for decode. | ||||
|     num_samples: int | ||||
|     # `sampling_params.n`, repeated once per sequence of each group. | ||||
|     n: List[int] | ||||
|     # Sequence ids of every sequence group, in input order. | ||||
|     seq_groups: List[List[int]] | ||||
|     is_first_multi_step: bool = True | ||||
|     is_last_step: bool = True | ||||
|     virtual_engine: int = 0 | ||||
|     # Not included in `as_broadcastable_tensor_dict`. | ||||
|     async_callback: Optional[Callable] = None | ||||
|  | ||||
|     def as_broadcastable_tensor_dict( | ||||
|             self) -> Dict[str, Union[int, torch.Tensor]]: | ||||
|         """Return the fields as a dict suitable for broadcasting. | ||||
|  | ||||
|         `attn_metadata` is added through | ||||
|         `_add_attn_metadata_broadcastable_dict`; `async_callback` is | ||||
|         left out. | ||||
|         """ | ||||
|         tensor_dict = { | ||||
|             "token_ids": self.token_ids, | ||||
|             "position_ids": self.position_ids, | ||||
|             "input_lens": self.input_lens, | ||||
|             "t": self.t, | ||||
|             "p": self.p, | ||||
|             "num_samples": self.num_samples, | ||||
|             "n": self.n, | ||||
|             "seq_groups": self.seq_groups, | ||||
|             "is_first_multi_step": self.is_first_multi_step, | ||||
|             "is_last_step": self.is_last_step, | ||||
|             "virtual_engine": self.virtual_engine, | ||||
|         } | ||||
|         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) | ||||
|         return tensor_dict | ||||
|  | ||||
|     @classmethod | ||||
|     def from_broadcasted_tensor_dict( | ||||
|         cls: Type["ModelInputForTPU"], | ||||
|         tensor_dict: Dict[str, Any], | ||||
|         attn_backend: Optional["AttentionBackend"] = None, | ||||
|     ) -> "ModelInputForTPU": | ||||
|         """Rebuild an instance from a dict made by the method above. | ||||
|  | ||||
|         When `attn_backend` is given, the dict is first passed through | ||||
|         `_init_attn_metadata_from_tensor_dict`; the resulting entries are | ||||
|         then used as constructor keyword arguments. | ||||
|         """ | ||||
|         if attn_backend is not None: | ||||
|             tensor_dict = _init_attn_metadata_from_tensor_dict( | ||||
|                 attn_backend, tensor_dict) | ||||
|         return cls(**tensor_dict) | ||||
|  | ||||
|  | ||||
| class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         vllm_config: VllmConfig, | ||||
|         is_driver_worker: bool = False, | ||||
|     ): | ||||
|         """Set up the block-table buffer and the attention backend. | ||||
|  | ||||
|         Logs a warning when the block table is at least as large as the | ||||
|         512 KiB `smem_size` budget used below. | ||||
|         """ | ||||
|         ModelRunnerBase.__init__(self, vllm_config=vllm_config) | ||||
|         self.is_driver_worker = is_driver_worker | ||||
|  | ||||
|         self.block_size = self.cache_config.block_size | ||||
|         self.max_num_blocks_per_seq = (self.model_config.max_model_len // | ||||
|                                        self.block_size) | ||||
|         # Host-side (NumPy) block table with one row per sequence slot; the | ||||
|         # `_prepare_*` methods write into it and copy slices out as tensors. | ||||
|         self.block_tables = np.zeros( | ||||
|             (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq), | ||||
|             dtype=np.int32) | ||||
|         self.attn_backend = get_attn_backend( | ||||
|             self.model_config.get_head_size(), | ||||
|             self.model_config.dtype, | ||||
|             self.cache_config.cache_dtype, | ||||
|             self.block_size, | ||||
|             self.model_config.is_attention_free, | ||||
|             False, | ||||
|         ) | ||||
|         self.cached_step_outputs: List[torch.Tensor] = [] | ||||
|  | ||||
|         # 512 KiB budget. NOTE(review): presumably the TPU "smem" capacity | ||||
|         # mentioned in the warning text - confirm. | ||||
|         smem_size = 512 * 1024 | ||||
|         # Size in bytes: 4 bytes per int32 entry. | ||||
|         block_table_size = 4 * self.block_tables.size | ||||
|         if block_table_size >= smem_size: | ||||
|             # The suggested value scales max_model_len down by the factor | ||||
|             # by which the table exceeds the budget. | ||||
|             logger.warning( | ||||
|                 "The max_model_len (%d) is too large. This may degrade the " | ||||
|                 "performance due to the insufficient smem size. Consider " | ||||
|                 "setting --max-model-len to a smaller value, like %d.", | ||||
|                 self.model_config.max_model_len, | ||||
|                 self.model_config.max_model_len / | ||||
|                 (block_table_size / smem_size)) | ||||
|  | ||||
|     def load_model(self) -> None: | ||||
|         """Load the model, wrap it in `ModelWrapper` and compile it. | ||||
|  | ||||
|         Sets `self.device` and `self.model`. The model is compiled with | ||||
|         `torch.compile` using the "openxla" backend, `fullgraph=True` and | ||||
|         `dynamic=False`. | ||||
|         """ | ||||
|         self.device = self.device_config.device | ||||
|  | ||||
|         # NOTE(woosuk): While the executor assigns the TP ranks to the worker | ||||
|         # process, the ranks can be different from the ranks internally assigned | ||||
|         # by the xm runtime. Therefore, there is a mismatch in the rank | ||||
|         # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. | ||||
|         # This is not a problem in linear layers because all-reduce is | ||||
|         # rank-agnostic. However, it matters for all-gather as the ranks | ||||
|         # determine the order of concatenating the output tensors. | ||||
|         # As a workaround, we use the xm's rank assignment only when loading | ||||
|         # the embedding weights. | ||||
|         xm_tp_rank = xr.global_ordinal() | ||||
|         # The patch is active only while `get_model` runs. | ||||
|         with patch( | ||||
|                 "vllm.model_executor.layers.vocab_parallel_embedding." | ||||
|                 "get_tensor_model_parallel_rank", | ||||
|                 return_value=xm_tp_rank): | ||||
|             model = get_model(vllm_config=self.vllm_config) | ||||
|         model = model.eval() | ||||
|         # NOTE(review): presumably blocks until pending device work (weight | ||||
|         # loading) has finished before compiling - confirm. | ||||
|         xm.wait_device_ops() | ||||
|         model = ModelWrapper(model) | ||||
|         self.model = torch.compile(model, | ||||
|                                    backend="openxla", | ||||
|                                    fullgraph=True, | ||||
|                                    dynamic=False) | ||||
|  | ||||
|     def get_model(self) -> nn.Module: | ||||
|         """Return the `model` attribute of the compiled wrapper. | ||||
|  | ||||
|         NOTE(review): presumably the module that `load_model` wrapped in | ||||
|         `ModelWrapper` - confirm against `ModelWrapper`. | ||||
|         """ | ||||
|         return self.model.model | ||||
|  | ||||
|     def _dummy_run( | ||||
|         self, | ||||
|         batch_size: int, | ||||
|         seq_len: int, | ||||
|         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], | ||||
|         exec_mode: ExecutionMode, | ||||
|     ) -> None: | ||||
|         """Run the model once on placeholder inputs of the given shape. | ||||
|  | ||||
|         Called by `warmup_model` for each shape it wants compiled. | ||||
|  | ||||
|         Args: | ||||
|             batch_size: Number of rows in the placeholder batch. | ||||
|             seq_len: Tokens per row. Rounded up to a multiple of 16 for | ||||
|                 the prefill modes; must be 1 for decode. | ||||
|             kv_caches: KV caches passed straight to the model. | ||||
|             exec_mode: Which kind of attention metadata to build. | ||||
|         """ | ||||
|         exec_mode = ExecutionMode(exec_mode) | ||||
|         if exec_mode.is_prefill(): | ||||
|             # Round up to the next multiple of 16. | ||||
|             seq_len = (seq_len + 15) // 16 * 16 | ||||
|             token_ids = torch.zeros((batch_size, seq_len), | ||||
|                                     dtype=torch.int32, | ||||
|                                     device=self.device) | ||||
|             position_ids = torch.zeros((batch_size, seq_len), | ||||
|                                        dtype=torch.int32, | ||||
|                                        device=self.device) | ||||
|             slot_mapping = torch.zeros((batch_size, seq_len), | ||||
|                                        dtype=torch.int64, | ||||
|                                        device=self.device) | ||||
|             input_lens = torch.ones((batch_size, ), | ||||
|                                     dtype=torch.int32, | ||||
|                                     device=self.device) | ||||
|             if exec_mode == ExecutionMode.PREFILL: | ||||
|                 # Plain prefill: no block tables or context lengths. | ||||
|                 attn_metadata = self.attn_backend.make_metadata( | ||||
|                     num_prefills=batch_size, | ||||
|                     num_prefill_tokens=batch_size * seq_len, | ||||
|                     num_decode_tokens=0, | ||||
|                     slot_mapping=slot_mapping, | ||||
|                     multi_modal_placeholder_index_maps=None, | ||||
|                     enable_kv_scales_calculation=False, | ||||
|                     block_tables=None, | ||||
|                     context_lens=None, | ||||
|                     effective_query_lens=None, | ||||
|                 ) | ||||
|             else: | ||||
|                 # Prefix prefill: same shapes, plus block tables, context | ||||
|                 # lengths and effective query lengths. | ||||
|                 context_lens = torch.ones((batch_size, ), | ||||
|                                           dtype=torch.int32, | ||||
|                                           device=self.device) | ||||
|                 block_tables = torch.tensor(self.block_tables[:batch_size], | ||||
|                                             dtype=torch.int32, | ||||
|                                             device=self.device) | ||||
|                 effective_query_lens = torch.ones_like(context_lens) | ||||
|                 attn_metadata = self.attn_backend.make_metadata( | ||||
|                     num_prefills=batch_size, | ||||
|                     num_prefill_tokens=batch_size * seq_len, | ||||
|                     num_decode_tokens=0, | ||||
|                     slot_mapping=slot_mapping, | ||||
|                     multi_modal_placeholder_index_maps=None, | ||||
|                     enable_kv_scales_calculation=False, | ||||
|                     block_tables=block_tables, | ||||
|                     context_lens=context_lens, | ||||
|                     effective_query_lens=effective_query_lens, | ||||
|                 ) | ||||
|         else: | ||||
|             # Decode: exactly one token per sequence. | ||||
|             assert seq_len == 1 | ||||
|             token_ids = torch.zeros((batch_size, seq_len), | ||||
|                                     dtype=torch.int32, | ||||
|                                     device=self.device) | ||||
|             position_ids = torch.zeros((batch_size, seq_len), | ||||
|                                        dtype=torch.int32, | ||||
|                                        device=self.device) | ||||
|             slot_mapping = torch.zeros((batch_size, seq_len), | ||||
|                                        dtype=torch.int64, | ||||
|                                        device=self.device) | ||||
|             block_tables = torch.zeros( | ||||
|                 (batch_size, self.max_num_blocks_per_seq), | ||||
|                 dtype=torch.int32, | ||||
|                 device=self.device) | ||||
|             context_lens = torch.ones((batch_size, ), | ||||
|                                       dtype=torch.int32, | ||||
|                                       device=self.device) | ||||
|             input_lens = torch.ones((batch_size, ), | ||||
|                                     dtype=torch.int32, | ||||
|                                     device=self.device) | ||||
|             attn_metadata = self.attn_backend.make_metadata( | ||||
|                 num_prefills=0, | ||||
|                 num_prefill_tokens=0, | ||||
|                 num_decode_tokens=batch_size * seq_len, | ||||
|                 slot_mapping=slot_mapping, | ||||
|                 multi_modal_placeholder_index_maps=None, | ||||
|                 enable_kv_scales_calculation=False, | ||||
|                 block_tables=block_tables, | ||||
|                 context_lens=context_lens, | ||||
|             ) | ||||
|         # Placeholder temperature and top-p values (all ones). | ||||
|         t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) | ||||
|         p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) | ||||
|         # Same rule as in `prepare_model_input`. | ||||
|         num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 | ||||
|  | ||||
|         # NOTE(woosuk): There are two stages of compilation: torch.compile and | ||||
|         # XLA compilation. Using `mark_dynamic` can reduce the torch.compile | ||||
|         # overhead by reusing the FX graph for different shapes. | ||||
|         # However, the XLA graph will still require static shapes and needs to | ||||
|         # be re-compiled for every different shape. This overhead is inevitable | ||||
|         # in the first run, but can be skipped afterwards as we cache the XLA | ||||
|         # graphs on disk (VLLM_XLA_CACHE_PATH). | ||||
|         if exec_mode.is_prefill(): | ||||
|             # Prefill: the sequence-length dimension (dim 1) varies. | ||||
|             torch._dynamo.mark_dynamic(token_ids, 1) | ||||
|             torch._dynamo.mark_dynamic(position_ids, 1) | ||||
|             torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) | ||||
|         else: | ||||
|             # Decode: the batch dimension (dim 0) varies. | ||||
|             torch._dynamo.mark_dynamic(token_ids, 0) | ||||
|             torch._dynamo.mark_dynamic(position_ids, 0) | ||||
|             torch._dynamo.mark_dynamic(input_lens, 0) | ||||
|             torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) | ||||
|             torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) | ||||
|             torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) | ||||
|             torch._dynamo.mark_dynamic(t, 0) | ||||
|             torch._dynamo.mark_dynamic(p, 0) | ||||
|         # Dummy run. | ||||
|         with set_forward_context(attn_metadata, self.vllm_config, 0): | ||||
|             self.model(token_ids, position_ids, input_lens, t, p, num_samples, | ||||
|                        kv_caches) | ||||
|  | ||||
|     def warmup_model( | ||||
|         self, | ||||
|         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], | ||||
|     ) -> None: | ||||
|         """Run `_dummy_run` over the input shapes expected at serving time. | ||||
|  | ||||
|         Three phases, each timed and logged: prefill, prefix prefill (only | ||||
|         when prefix caching is enabled) and decode. | ||||
|         """ | ||||
|         # Prefill | ||||
|         logger.info("Compiling the model with different input shapes...") | ||||
|         start = time.time() | ||||
|         for batch_size in [1]: | ||||
|             # Sequence lengths 16, 32, 64, ... up to max_model_len, stopping | ||||
|             # early once max_num_batched_tokens is reached. | ||||
|             seq_len = 16 | ||||
|             while seq_len <= self.model_config.max_model_len: | ||||
|                 self._dummy_run(batch_size, | ||||
|                                 seq_len, | ||||
|                                 kv_caches, | ||||
|                                 exec_mode=ExecutionMode.PREFILL) | ||||
|                 xm.wait_device_ops() | ||||
|                 logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) | ||||
|                 num_tokens = batch_size * seq_len | ||||
|                 if num_tokens >= self.scheduler_config.max_num_batched_tokens: | ||||
|                     break | ||||
|                 seq_len = seq_len * 2 | ||||
|  | ||||
|         end = time.time() | ||||
|         logger.info("Compilation for prefill done in %.2f s.", end - start) | ||||
|  | ||||
|         # Prefix prefill | ||||
|         if self.cache_config.enable_prefix_caching: | ||||
|             logger.info("Compiling the model with different input shapes for " | ||||
|                         "prefix prefill...") | ||||
|             start = time.time() | ||||
|             for batch_size in [1]: | ||||
|                 # Same shape schedule as the plain prefill phase above. | ||||
|                 seq_len = 16 | ||||
|                 while seq_len <= self.model_config.max_model_len: | ||||
|                     self._dummy_run(batch_size, | ||||
|                                     seq_len, | ||||
|                                     kv_caches, | ||||
|                                     exec_mode=ExecutionMode.PREFIX_PREFILL) | ||||
|                     xm.wait_device_ops() | ||||
|                     logger.info("batch_size: %d, seq_len: %d", batch_size, | ||||
|                                 seq_len) | ||||
|                     num_tokens = batch_size * seq_len | ||||
|                     if (num_tokens | ||||
|                             >= self.scheduler_config.max_num_batched_tokens): | ||||
|                         break | ||||
|                     seq_len = seq_len * 2 | ||||
|             end = time.time() | ||||
|             logger.info("Compilation for prefix prefill done in %.2f s.", | ||||
|                         end - start) | ||||
|  | ||||
|         # Decode | ||||
|         start = time.time() | ||||
|         seq_len = 1 | ||||
|         batch_size = 8  # Must be in sync with _get_padded_batch_size() | ||||
|         while True: | ||||
|             self._dummy_run(batch_size, | ||||
|                             seq_len, | ||||
|                             kv_caches, | ||||
|                             exec_mode=ExecutionMode.DECODE) | ||||
|             xm.wait_device_ops() | ||||
|             logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) | ||||
|  | ||||
|             if batch_size >= self.scheduler_config.max_num_seqs: | ||||
|                 break | ||||
|             # Batch sizes go 8, 16, then grow in steps of 16 (32, 48, ...). | ||||
|             batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 | ||||
|  | ||||
|         end = time.time() | ||||
|         logger.info("Compilation for decode done in %.2f s.", end - start) | ||||
|  | ||||
|     def _prepare_prompt( | ||||
|         self, | ||||
|         seq_group_metadata_list: List[SequenceGroupMetadata], | ||||
|     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: | ||||
|         """Build CPU input tensors and attention metadata for prompts. | ||||
|  | ||||
|         Every sequence group must be a prompt with exactly one sequence. | ||||
|         Each prompt is padded individually and all prompts are concatenated | ||||
|         into flat 1-D token, position and slot-mapping tensors. | ||||
|  | ||||
|         Returns: | ||||
|             `(input_tokens, input_positions, attn_metadata, prompt_lens)`, | ||||
|             where `prompt_lens` holds the unpadded number of new tokens of | ||||
|             each prompt. | ||||
|         """ | ||||
|         assert len(seq_group_metadata_list) > 0 | ||||
|         input_tokens: List[int] = [] | ||||
|         input_positions: List[int] = [] | ||||
|         prompt_lens: List[int] = [] | ||||
|         context_lens: List[int] = [] | ||||
|         slot_mapping: List[int] = [] | ||||
|  | ||||
|         for batch_idx, seq_group_metadata in enumerate( | ||||
|                 seq_group_metadata_list): | ||||
|             assert seq_group_metadata.is_prompt | ||||
|             seq_ids = list(seq_group_metadata.seq_data.keys()) | ||||
|             assert len(seq_ids) == 1 | ||||
|             seq_id = seq_ids[0] | ||||
|  | ||||
|             seq_data = seq_group_metadata.seq_data[seq_id] | ||||
|             # Could include output tokens when a request is preempted. | ||||
|             prompt_tokens = seq_data.get_token_ids() | ||||
|             seq_len = len(prompt_tokens) | ||||
|  | ||||
|             # Tokens covered by already-computed blocks are dropped from | ||||
|             # the input. | ||||
|             num_computed_blocks = len(seq_group_metadata.computed_block_nums) | ||||
|             num_computed_tokens = num_computed_blocks * self.block_size | ||||
|             if num_computed_tokens > 0: | ||||
|                 prompt_tokens = prompt_tokens[num_computed_tokens:] | ||||
|                 # The context length is the full sequence length here, not | ||||
|                 # just the computed part. | ||||
|                 context_lens.append(seq_len) | ||||
|             else: | ||||
|                 context_lens.append(0) | ||||
|  | ||||
|             prompt_len = len(prompt_tokens) | ||||
|             prompt_lens.append(prompt_len) | ||||
|  | ||||
|             input_tokens.extend(prompt_tokens) | ||||
|             input_positions.extend(range(num_computed_tokens, seq_len)) | ||||
|  | ||||
|             assert seq_group_metadata.block_tables is not None | ||||
|             block_table = seq_group_metadata.block_tables[seq_id] | ||||
|             # Map each new token position to its slot: | ||||
|             # block_number * block_size + offset within the block. | ||||
|             for i in range(num_computed_tokens, seq_len): | ||||
|                 block_number = block_table[i // self.block_size] | ||||
|                 block_offset = i % self.block_size | ||||
|                 slot = block_number * self.block_size + block_offset | ||||
|                 slot_mapping.append(slot) | ||||
|             # The shared block-table buffer is written only when part of | ||||
|             # the prompt was already computed. | ||||
|             if num_computed_tokens > 0: | ||||
|                 self.block_tables[batch_idx, :len(block_table)] = block_table | ||||
|  | ||||
|             # Add paddings to EACH prompt to the smallest power of 2 that is | ||||
|             # greater than or equal to the prompt length. | ||||
|             # We pad the seq_len to reduce the compilation overhead. | ||||
|             # We execute each prompt individually (i.e., with batch_size 1) | ||||
|             # because the FlashAttention kernel does not support ragged inputs. | ||||
|             # TODO(woosuk): Use SplashAttention to support ragged inputs. | ||||
|             padded_prompt_len = _get_padded_prefill_len(prompt_len) | ||||
|             num_paddings = padded_prompt_len - prompt_len | ||||
|             input_tokens += [0] * num_paddings | ||||
|             input_positions += [0] * num_paddings | ||||
|             slot_mapping += [_PAD_SLOT_ID] * num_paddings | ||||
|  | ||||
|         assert len(prompt_lens) > 0 | ||||
|         num_prefills = len(prompt_lens) | ||||
|         input_tokens = torch.tensor(input_tokens, | ||||
|                                     dtype=torch.int32, | ||||
|                                     device="cpu") | ||||
|         input_positions = torch.tensor(input_positions, | ||||
|                                        dtype=torch.int32, | ||||
|                                        device="cpu") | ||||
|         slot_mapping = torch.tensor(slot_mapping, | ||||
|                                     dtype=torch.int64, | ||||
|                                     device="cpu") | ||||
|         prompt_lens = torch.tensor(prompt_lens, | ||||
|                                    dtype=torch.int32, | ||||
|                                    device="cpu") | ||||
|         context_lens = torch.tensor(context_lens, | ||||
|                                     dtype=torch.int32, | ||||
|                                     device="cpu") | ||||
|         block_tables = torch.tensor(self.block_tables[:num_prefills], | ||||
|                                     dtype=torch.int32, | ||||
|                                     device="cpu") | ||||
|         attn_metadata = self.attn_backend.make_metadata( | ||||
|             num_prefills=num_prefills, | ||||
|             num_prefill_tokens=0,  # NOTE: This is not used. | ||||
|             num_decode_tokens=0, | ||||
|             slot_mapping=slot_mapping, | ||||
|             multi_modal_placeholder_index_maps=None, | ||||
|             enable_kv_scales_calculation=False, | ||||
|             block_tables=block_tables, | ||||
|             context_lens=context_lens, | ||||
|             effective_query_lens=prompt_lens, | ||||
|         ) | ||||
|         return input_tokens, input_positions, attn_metadata, prompt_lens | ||||
|  | ||||
|     def _prepare_decode( | ||||
|         self, | ||||
|         seq_group_metadata_list: List[SequenceGroupMetadata], | ||||
|     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: | ||||
|         """Build CPU input tensors and attention metadata for a decode step. | ||||
|  | ||||
|         Each sequence contributes one row holding its last token. The batch | ||||
|         is padded to the size returned by `_get_padded_batch_size`. | ||||
|  | ||||
|         Returns: | ||||
|             `(input_tokens, input_positions, attn_metadata, input_lens)`, | ||||
|             where `input_lens` is all ones. | ||||
|         """ | ||||
|         assert len(seq_group_metadata_list) > 0 | ||||
|         input_tokens: List[List[int]] = [] | ||||
|         input_positions: List[List[int]] = [] | ||||
|         slot_mapping: List[List[int]] = [] | ||||
|         context_lens: List[int] = [] | ||||
|  | ||||
|         # Counts sequences across all groups; doubles as the row index into | ||||
|         # the shared block-table buffer. | ||||
|         batch_idx = 0 | ||||
|         for seq_group_metadata in seq_group_metadata_list: | ||||
|             assert not seq_group_metadata.is_prompt | ||||
|             seq_ids = list(seq_group_metadata.seq_data.keys()) | ||||
|             for seq_id in seq_ids: | ||||
|                 seq_data = seq_group_metadata.seq_data[seq_id] | ||||
|                 generation_token = seq_data.get_last_token_id() | ||||
|                 input_tokens.append([generation_token]) | ||||
|  | ||||
|                 seq_len = seq_data.get_len() | ||||
|                 position = seq_len - 1 | ||||
|                 input_positions.append([position]) | ||||
|                 context_lens.append(seq_len) | ||||
|  | ||||
|                 assert seq_group_metadata.block_tables is not None | ||||
|                 block_table = seq_group_metadata.block_tables[seq_id] | ||||
|                 self.block_tables[batch_idx, :len(block_table)] = block_table | ||||
|                 batch_idx += 1 | ||||
|  | ||||
|                 # Slot of the token being decoded: | ||||
|                 # block_number * block_size + offset within the block. | ||||
|                 block_number = block_table[position // self.block_size] | ||||
|                 block_offset = position % self.block_size | ||||
|                 slot = block_number * self.block_size + block_offset | ||||
|                 slot_mapping.append([slot]) | ||||
|  | ||||
|         # Pad with dummy rows: token 0, position 0, the out-of-range pad | ||||
|         # slot id and context length 0. | ||||
|         batch_size = _get_padded_batch_size(batch_idx) | ||||
|         num_paddings = batch_size - batch_idx | ||||
|         input_tokens = input_tokens + [[0]] * num_paddings | ||||
|         input_positions = input_positions + [[0]] * num_paddings | ||||
|         slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings | ||||
|         context_lens = context_lens + [0] * num_paddings | ||||
|  | ||||
|         input_tokens = torch.tensor(input_tokens, | ||||
|                                     dtype=torch.int32, | ||||
|                                     device="cpu") | ||||
|         input_positions = torch.tensor(input_positions, | ||||
|                                        dtype=torch.int32, | ||||
|                                        device="cpu") | ||||
|         slot_mapping = torch.tensor(slot_mapping, | ||||
|                                     dtype=torch.int64, | ||||
|                                     device="cpu") | ||||
|         context_lens = torch.tensor(context_lens, | ||||
|                                     dtype=torch.int32, | ||||
|                                     device="cpu") | ||||
|         # Rows past `batch_idx` hold whatever the buffer already contained. | ||||
|         block_tables = torch.tensor(self.block_tables[:batch_size], | ||||
|                                     dtype=torch.int32, | ||||
|                                     device="cpu") | ||||
|         input_lens = torch.tensor([1] * batch_size, | ||||
|                                   dtype=torch.int32, | ||||
|                                   device="cpu") | ||||
|         attn_metadata = self.attn_backend.make_metadata( | ||||
|             num_prefills=0, | ||||
|             num_prefill_tokens=0, | ||||
|             num_decode_tokens=batch_size, | ||||
|             slot_mapping=slot_mapping, | ||||
|             multi_modal_placeholder_index_maps=None, | ||||
|             enable_kv_scales_calculation=False, | ||||
|             block_tables=block_tables, | ||||
|             context_lens=context_lens, | ||||
|         ) | ||||
|         return input_tokens, input_positions, attn_metadata, input_lens | ||||
|  | ||||
|     def _prepare_sample( | ||||
|         self, | ||||
|         seq_group_metadata_list: List[SequenceGroupMetadata], | ||||
|         padded_batch_size: int, | ||||
|     ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: | ||||
|         """Collect per-sequence sampling parameters. | ||||
|  | ||||
|         Args: | ||||
|             seq_group_metadata_list: Sequence groups of the current batch. | ||||
|             padded_batch_size: Length the `t` and `p` tensors are padded to. | ||||
|  | ||||
|         Returns: | ||||
|             `(t, p, n)`: float32 CPU tensors of temperatures and top-p | ||||
|             values, and a plain list of `n` values. `n` is not padded. | ||||
|  | ||||
|         Raises: | ||||
|             NotImplementedError: For top-p != 1 (while `_ENABLE_TOP_P` is | ||||
|                 False), top-k > 0, `n` above `_MAX_NUM_SAMPLES`, or any | ||||
|                 request for logprobs or prompt logprobs. | ||||
|         """ | ||||
|         assert len(seq_group_metadata_list) > 0 | ||||
|         t = [] | ||||
|         p = [] | ||||
|         n = [] | ||||
|         for seq_group_metadata in seq_group_metadata_list: | ||||
|             sampling_params = seq_group_metadata.sampling_params | ||||
|             t.append(sampling_params.temperature) | ||||
|             if sampling_params.top_p != 1 and not _ENABLE_TOP_P: | ||||
|                 raise NotImplementedError( | ||||
|                     "Top-p sampling is currently disabled for the TPU backend " | ||||
|                     "due to performance issues.") | ||||
|             p.append(sampling_params.top_p) | ||||
|             if sampling_params.top_k > 0: | ||||
|                 raise NotImplementedError( | ||||
|                     "Top-k sampling is currently disabled for the TPU backend " | ||||
|                     "due to performance issues.") | ||||
|             if sampling_params.n > _MAX_NUM_SAMPLES: | ||||
|                 raise NotImplementedError( | ||||
|                     f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU " | ||||
|                     "backend.") | ||||
|             n.append(sampling_params.n) | ||||
|             if sampling_params.logprobs is not None: | ||||
|                 raise NotImplementedError( | ||||
|                     "logprobs is not currently supported by the TPU backend.") | ||||
|             if sampling_params.prompt_logprobs is not None: | ||||
|                 raise NotImplementedError( | ||||
|                     "prompt_logprobs is not currently supported by the TPU " | ||||
|                     "backend.") | ||||
|  | ||||
|             # Repeat the sampling params if the seq group has multiple seqs. | ||||
|             num_seqs = len(seq_group_metadata.seq_data) | ||||
|             t += [t[-1]] * (num_seqs - 1) | ||||
|             p += [p[-1]] * (num_seqs - 1) | ||||
|             n += [n[-1]] * (num_seqs - 1) | ||||
|  | ||||
|         # Padding rows get temperature 1.0 and top-p 1.0. | ||||
|         num_paddings = padded_batch_size - len(t) | ||||
|         t += [1.0] * num_paddings | ||||
|         p += [1.0] * num_paddings | ||||
|  | ||||
|         t = torch.tensor(t, dtype=torch.float32, device="cpu") | ||||
|         p = torch.tensor(p, dtype=torch.float32, device="cpu") | ||||
|         return t, p, n | ||||
|  | ||||
|     def prepare_model_input( | ||||
|         self, | ||||
|         seq_group_metadata_list: List[SequenceGroupMetadata], | ||||
|         virtual_engine: int = 0, | ||||
|         finished_requests_ids: Optional[List[str]] = None, | ||||
|     ) -> ModelInputForTPU: | ||||
|         """Build a `ModelInputForTPU` for one batch of sequence groups. | ||||
|  | ||||
|         The batch is treated as all prompts or all decodes, decided by the | ||||
|         first sequence group. `virtual_engine` must be 0 and | ||||
|         `finished_requests_ids` is ignored. | ||||
|         """ | ||||
|         del finished_requests_ids  # Unused. | ||||
|         assert virtual_engine == 0 | ||||
|         assert len(seq_group_metadata_list) > 0 | ||||
|         # NOTE: We assume that all sequences in the group are all prompts or | ||||
|         # all decodes. | ||||
|         is_prompt = seq_group_metadata_list[0].is_prompt | ||||
|         if is_prompt: | ||||
|             inputs = self._prepare_prompt(seq_group_metadata_list) | ||||
|         else: | ||||
|             inputs = self._prepare_decode(seq_group_metadata_list) | ||||
|         input_tokens, input_positions, attn_metadata, input_lens = inputs | ||||
|         # NOTE(review): on the prompt path `input_tokens` is a flat 1-D | ||||
|         # tensor, so shape[0] is the total padded token count rather than | ||||
|         # the number of prompts - confirm this is the intended padding | ||||
|         # size for `t` and `p`. | ||||
|         padded_batch_size = input_tokens.shape[0] | ||||
|         t, p, n = self._prepare_sample(seq_group_metadata_list, | ||||
|                                        padded_batch_size) | ||||
|         num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 | ||||
|  | ||||
|         # Sequence ids per group, in input order. | ||||
|         seq_groups = [ | ||||
|             list(metadata.seq_data.keys()) | ||||
|             for metadata in seq_group_metadata_list | ||||
|         ] | ||||
|         return ModelInputForTPU(input_tokens, input_positions, attn_metadata, | ||||
|                                 input_lens, t, p, num_samples, n, seq_groups) | ||||
|  | ||||
|     def make_model_input_from_broadcasted_tensor_dict( | ||||
|             self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU: | ||||
|         model_input = ModelInputForTPU.from_broadcasted_tensor_dict( | ||||
|             tensor_dict, attn_backend=self.attn_backend) | ||||
|         return model_input | ||||
|  | ||||
    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForTPU,
        kv_caches: Optional[List[Any]],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> List[SamplerOutput]:
        """Run a prefill batch or one or more decode steps on the TPU.

        Three paths exist:

        * Not the first multi-step call: nothing is executed. Unless this is
          the last step, ``[]`` is returned; on the last step the token ids
          accumulated in ``self.cached_step_outputs`` are drained and turned
          into sampler outputs.
        * Prefill (``attn_metadata.num_prefills > 0``): every prompt is run
          through the model on its own and a single ``SamplerOutput`` is
          returned. ``num_steps`` must be 1.
        * Decode: the model is run ``num_steps`` times, each result being
          appended to ``self.cached_step_outputs``. With ``num_steps > 1``
          the method returns ``[]`` (results are fetched by a later call);
          otherwise the single cached result is returned.

        Args:
            model_input: Inputs built by ``prepare_model_input``.
            kv_caches: KV caches forwarded unchanged to the model.
            intermediate_tensors: Must be ``None`` (asserted).
            num_steps: Number of decode iterations to run in this call.

        Returns:
            A list of ``SamplerOutput``; possibly empty, see above.
        """
        assert intermediate_tensors is None
        if not model_input.is_first_multi_step:
            if not model_input.is_last_step:
                return []

            # Last step of a multi-step run: convert every cached device
            # result into a SamplerOutput, oldest first.
            use_async_out_proc = model_input.async_callback is not None
            sampler_outputs = []
            num_outputs = len(self.cached_step_outputs)
            for i in range(num_outputs):
                next_token_ids = self.cached_step_outputs.pop(0)
                next_token_ids = next_token_ids.cpu().tolist()
                sampler_output = _make_decode_output(next_token_ids,
                                                     model_input.seq_groups)
                sampler_outputs.append(sampler_output)

                # With async output processing, all outputs except the final
                # one are pushed through the callback's context here; only
                # the final one is returned to the caller below.
                if i < num_outputs - 1 and use_async_out_proc:
                    assert model_input.async_callback is not None
                    ctx = model_input.async_callback.keywords[  # type: ignore
                        "ctx"]
                    ctx.append_output(
                        outputs=[sampler_output],
                        seq_group_metadata_list=ctx.seq_group_metadata_list,
                        scheduler_outputs=ctx.scheduler_outputs,
                        is_async=False,
                        is_last_step=False,
                        is_first_step_output=i == 0)
                    model_input.async_callback()
            if use_async_out_proc:
                return [sampler_outputs[-1]]
            else:
                return sampler_outputs

        is_prompt = model_input.attn_metadata.num_prefills > 0
        if is_prompt:
            assert num_steps == 1
            # NOTE(woosuk): Since the FlashAttention kernel does not support
            # ragged inputs, we split the prompts into different batches and
            # process them separately. This is a temporary hack that should be
            # optimized by using SplashAttention.
            # Keep the batch-wide metadata; the shared attn_metadata object is
            # overwritten with per-prompt slices inside the loop.
            orig_slot_mapping = model_input.attn_metadata.slot_mapping
            orig_block_tables = model_input.attn_metadata.block_tables
            orig_context_lens = model_input.attn_metadata.context_lens
            orig_effective_query_lens = \
                model_input.attn_metadata.effective_query_lens
            batch_size = model_input.input_lens.shape[0]
            start_idx = 0
            next_token_ids = []
            for i in range(batch_size):
                # Get the actual prefill_len.
                prefill_len = model_input.input_lens[i:i + 1].item()
                prefill_len = _get_padded_prefill_len(prefill_len)
                end_idx = start_idx + prefill_len

                # Prompt i occupies [start_idx, end_idx) of the flat token
                # buffer; `None` adds a leading batch dimension of size 1.
                token_ids = model_input.token_ids[None, start_idx:end_idx].to(
                    self.device)
                position_ids = model_input.position_ids[None,
                                                        start_idx:end_idx].to(
                                                            self.device)
                attn_metadata = model_input.attn_metadata
                attn_metadata.num_prefills = 1
                attn_metadata.slot_mapping = orig_slot_mapping[
                    None, start_idx:end_idx].to(self.device)
                # Context-related metadata is only populated when this prompt
                # has a non-zero context length; otherwise it is cleared.
                if orig_context_lens[i].item() > 0:
                    attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
                        self.device)
                    attn_metadata.block_tables = orig_block_tables[
                        i].unsqueeze(0).to(self.device)
                    attn_metadata.effective_query_lens = \
                        orig_effective_query_lens[i:i + 1].to(self.device)
                else:
                    attn_metadata.context_lens = None
                    attn_metadata.block_tables = None
                    attn_metadata.effective_query_lens = None
                input_lens = model_input.input_lens[i:i + 1].to(self.device)
                t = model_input.t[i:i + 1].to(self.device)
                p = model_input.p[i:i + 1].to(self.device)
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config,
                                         model_input.virtual_engine):
                    output_token_ids = self.model(token_ids, position_ids,
                                                  input_lens, t, p,
                                                  model_input.num_samples,
                                                  kv_caches)
                # The per-prompt batch has size 1, so keep only row 0.
                next_token_ids.append(output_token_ids[0])
                start_idx = end_idx

            if model_input.async_callback is not None:
                model_input.async_callback()
            # Retrieve the outputs to CPU.
            next_token_ids = [
                output_token_ids.cpu().tolist()
                for output_token_ids in next_token_ids
            ]

            # NOTE(woosuk): Minimal code to construct the sampler outputs.
            # The TPU backend does not reuse the sampler, since the TPU backend
            # does not support advanced sampling parameters such as logprobs.
            zero_logprob = Logprob(0.0)
            sampler_outputs = []
            for i, seq_group in enumerate(model_input.seq_groups):
                seq_ids = seq_group
                assert len(seq_ids) == 1
                seq_id = seq_ids[0]
                seq_outputs = []
                # Emit the first n[i] of the sampled tokens for this prompt,
                # all attributed to the same (single) sequence id.
                for j in range(model_input.n[i]):
                    next_token_id = next_token_ids[i][j]
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
                sampler_outputs.append(
                    CompletionSequenceGroupOutput(seq_outputs, None))
            return [SamplerOutput(sampler_outputs)]
        else:
            # Decode path: move the whole batch to the device once.
            token_ids = model_input.token_ids.to(self.device)
            position_ids = model_input.position_ids.to(self.device)
            attn_metadata = model_input.attn_metadata
            attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
                self.device)
            attn_metadata.block_tables = attn_metadata.block_tables.to(
                self.device)
            attn_metadata.context_lens = attn_metadata.context_lens.to(
                self.device)
            t = model_input.t.to(self.device)
            p = model_input.p.to(self.device)
            input_lens = model_input.input_lens.to(self.device)
            for i in range(num_steps):
                # Captured before the forward pass so that the padding mask
                # below can be derived from it.
                slot_mapping = attn_metadata.slot_mapping
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config,
                                         model_input.virtual_engine):
                    output_token_ids = self.model(token_ids, position_ids,
                                                  input_lens, t, p,
                                                  model_input.num_samples,
                                                  kv_caches)
                self.cached_step_outputs.append(output_token_ids)

                if i < num_steps - 1:
                    # Prepare the inputs for the next step.
                    token_ids = output_token_ids.unsqueeze(dim=1).int()
                    position_ids = position_ids + 1
                    attn_metadata.context_lens = attn_metadata.context_lens + 1

                    # Recompute the slot of the new position from the block
                    # table: slot = block_number * block_size + offset.
                    block_tables = attn_metadata.block_tables
                    block_number = block_tables.gather(
                        1,
                        position_ids.long() // self.block_size)
                    block_offset = position_ids % self.block_size

                    # Entries that were padding slots stay padding slots.
                    is_padding = slot_mapping == _PAD_SLOT_ID
                    slot_mapping = block_number * self.block_size + block_offset
                    slot_mapping = slot_mapping.long()
                    slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
                                               slot_mapping)
                    attn_metadata.slot_mapping = slot_mapping

            if model_input.async_callback is not None:
                model_input.async_callback()

            # In multi-step mode the cached results are returned by a later
            # call (see the branch at the top of this method).
            if num_steps > 1:
                return []
            # Retrieve the outputs to CPU.
            next_token_ids = self.cached_step_outputs.pop(0)
            next_token_ids = next_token_ids.cpu().tolist()
            sampler_output = _make_decode_output(next_token_ids,
                                                 model_input.seq_groups)
            return [sampler_output]
|  | ||||
|  | ||||
class ModelWrapper(nn.Module):
    """Wraps a model so that one call runs the forward pass and sampling.

    ``forward`` computes the hidden states, gathers the logits of the last
    real token of every row, and returns sampled (or greedy) token ids.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape [batch_size, seq_len].
            input_lens: The actual input lengths of shape [batch_size].
            t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size].
            num_samples: Number of samples to draw from each logits vector.
            kv_caches: The key and value caches. When the first key cache is
                empty (``numel() == 0``), the slot-mapping rewrite below is
                skipped; this is the memory-profiling case.

        Returns:
            The next token ids. The trailing sample dimension is squeezed
            away when ``num_samples == 1``.
        """
        batch_size, seq_len = token_ids.shape
        # Calculate the positions to sample from.
        # Row b starts at b * seq_len in the flattened hidden states, and its
        # last real token sits input_lens[b] - 1 positions after that.
        start_indices = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = start_indices + input_lens - 1
        attn_metadata = get_forward_context().attn_metadata

        # FIXME(woosuk): This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization.
        if kv_caches[0][0].numel() > 0:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            slot_mapping = attn_metadata.slot_mapping
            slot_mapping = slot_mapping.flatten()
            # Offset of each head's region in the flattened
            # [num_kv_heads * num_blocks * block_size] layout.
            head_indices = torch.arange(0,
                                        num_kv_heads,
                                        device=slot_mapping.device,
                                        dtype=slot_mapping.dtype)
            head_indices *= block_size * num_blocks
            # Replicate every slot once per head, then shift each copy into
            # its head's region.
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indices.view(1, -1)
            slot_mapping = slot_mapping.flatten()
            # NOTE: This overwrites the slot mapping on the shared forward
            # context metadata.
            attn_metadata.slot_mapping = slot_mapping

        hidden_states = self.model(token_ids, position_ids)
        hidden_states = hidden_states.flatten(0, 1)
        # presumably compute_logits selects the rows given by
        # selected_token_indices, yielding one logits row per sequence
        # -- TODO confirm against the model implementation.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Argmax sampling.
        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
        argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

        # Zero temperature means greedy decoding. Avoid division by zero.
        nonzero_t = torch.where(t != 0, t, 1.0)
        logits = logits / nonzero_t.unsqueeze(dim=1)
        if _ENABLE_TOP_P:
            logits = _apply_top_p(logits, p.unsqueeze(dim=1))

        # Random sampling.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        sampled_token_ids = torch.multinomial(probs,
                                              num_samples,
                                              replacement=True)
        if num_samples == 1:
            argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
            sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
        # Both candidates are computed for every row; the temperature decides
        # per row which one is returned (t == 0 selects the argmax).
        next_token_ids = torch.where(t != 0, sampled_token_ids,
                                     argmax_token_ids)
        return next_token_ids
|  | ||||
|  | ||||
def _get_padded_prefill_len(x: int) -> int:
    """Return the padded prompt length for a prompt of ``x`` tokens.

    Lengths up to 16 are padded to 16; longer lengths are rounded up to the
    next power of two (which is always a multiple of 16).
    """
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. Padding to 16 or a larger power of two
    # satisfies that. This is also good for performance.
    return 16 if x <= 16 else 1 << (x - 1).bit_length()
|  | ||||
|  | ||||
def _get_padded_batch_size(batch_size: int) -> int:
    """Return the padded batch size: 8 for up to 8 rows, else the next
    multiple of 16.
    """
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, the minimal batch size is
    # set to 8.
    if batch_size > 8:
        num_chunks = (batch_size + 15) // 16
        return num_chunks * 16
    return 8
|  | ||||
|  | ||||
def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Mask out logits that fall outside the top-p (nucleus) set.

    Args:
        logits: Logits whose last dimension is the vocabulary. The tensor is
            modified in place.
        p: Top-p thresholds, broadcastable against ``logits`` with a last
            dimension of size 1.

    Returns:
        ``logits`` with every entry below the cutoff logit set to ``-inf``.
    """
    logits_sorted = torch.sort(logits, dim=-1, descending=True).values
    sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
    # Number of leading tokens whose cumulative probability is still below p;
    # this is the sorted position of the last token that must be kept.
    cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
    # If every cumulative probability is below p (e.g. p == 1.0 and the
    # cumulative sum rounds to slightly less than 1), the count equals the
    # vocabulary size, which is not a valid index. Clamp it to the last
    # position so that nothing gets masked in that case.
    cutoff_index = cutoff_index.clamp(max=logits_sorted.shape[-1] - 1)
    cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
    logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
    return logits
|  | ||||
|  | ||||
def _make_decode_output(
    next_token_ids: List[int],
    seq_groups: List[List[int]],
) -> SamplerOutput:
    """Wrap flat decode token ids into a ``SamplerOutput``.

    ``next_token_ids`` is consumed in order, one token per sequence id, while
    walking ``seq_groups`` group by group. Any trailing token ids are
    ignored. Every token is reported with a placeholder logprob of 0.0.
    """
    placeholder_logprob = Logprob(0.0)
    group_outputs = []
    cursor = 0  # Position of the next unread entry in next_token_ids.
    for group_seq_ids in seq_groups:
        outputs_for_group = []
        for seq_id in group_seq_ids:
            token_id = next_token_ids[cursor]
            cursor += 1
            outputs_for_group.append(
                SequenceOutput(seq_id, token_id,
                               {token_id: placeholder_logprob}))
        group_outputs.append(
            CompletionSequenceGroupOutput(outputs_for_group, None))
    return SamplerOutput(group_outputs)
							
								
								
									
										337
									
								
								vllm/worker/tpu_worker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										337
									
								
								vllm/worker/tpu_worker.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,337 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
|  | ||||
| import os | ||||
| from typing import List, Optional, Tuple, Union | ||||
|  | ||||
| import torch | ||||
| import torch_xla.core.xla_model as xm | ||||
| import torch_xla.debug.profiler as xp | ||||
| import torch_xla.runtime as xr | ||||
|  | ||||
| import vllm.envs as envs | ||||
| from vllm.config import VllmConfig | ||||
| from vllm.distributed import (ensure_model_parallel_initialized, | ||||
|                               init_distributed_environment) | ||||
| from vllm.logger import init_logger | ||||
| from vllm.model_executor import set_random_seed | ||||
| from vllm.sequence import ExecuteModelRequest | ||||
| from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size | ||||
| from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner | ||||
| from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, | ||||
|                                      LoRANotSupportedWorkerBase, WorkerBase, | ||||
|                                      WorkerInput) | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
|  | ||||
| class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         vllm_config: VllmConfig, | ||||
|         local_rank: int, | ||||
|         rank: int, | ||||
|         distributed_init_method: str, | ||||
|         is_driver_worker: bool, | ||||
|     ) -> None: | ||||
|         WorkerBase.__init__(self, vllm_config=vllm_config) | ||||
|         self.parallel_config.rank = rank | ||||
|         self.local_rank = local_rank | ||||
|         self.rank = rank | ||||
|         self.distributed_init_method = distributed_init_method | ||||
|         self.is_driver_worker = is_driver_worker | ||||
|  | ||||
|         assert self.device_config.device_type == "tpu" | ||||
|         if self.cache_config.cache_dtype == "auto": | ||||
|             self.cache_dtype = self.model_config.dtype | ||||
|         else: | ||||
|             self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ | ||||
|                 self.cache_config.cache_dtype] | ||||
|  | ||||
|         self.model_runner: TPUModelRunner = TPUModelRunner( | ||||
|             vllm_config=vllm_config, is_driver_worker=is_driver_worker) | ||||
|  | ||||
|         if self.model_config.seed is None: | ||||
|             self.model_config.seed = 0 | ||||
|  | ||||
|         if vllm_config.lora_config is not None: | ||||
|             raise NotImplementedError( | ||||
|                 "The V0 TPU backend doesn't support LoRA serving") | ||||
|  | ||||
    def init_device(self) -> None:
        """Initialize the distributed runtime, the XLA device, the RNG, the
        XLA compilation cache and (on rank 0 only) the profiler server.
        """
        os.environ["PJRT_DEVICE"] = "TPU"
        torch.set_grad_enabled(False)
        torch.set_default_dtype(self.model_config.dtype)

        # NOTE(woosuk): This is just to initialize the TP group and broadcast
        # the input objects on CPU. The all-reduce and all-gather ops on TPU
        # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
        # own context.
        init_distributed_environment(
            world_size=self.parallel_config.world_size,
            rank=self.rank,
            local_rank=self.local_rank,
            distributed_init_method=self.distributed_init_method,
            backend="gloo",
        )
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)

        # Device initialization should happen after initializing the distributed
        # runtime.
        self.device = xm.xla_device()
        self.device_config.device = self.device

        # Set random seed.
        set_random_seed(self.model_config.seed)
        xm.set_rng_state(self.model_config.seed, self.device)

        # Increase the cache size limit, which is the maximum number of
        # dynamo graphs that can be compiled.
        # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
        # 30-40 graphs for decode. 128 is an arbitrary safe number.
        torch._dynamo.config.cache_size_limit = 128
        # Use persistent cache to avoid XLA recompilation.
        # NOTE(woosuk): Set per-rank cache path since different ranks
        # can have slightly different XLA graphs.
        world_size = self.parallel_config.world_size
        rank = xr.global_ordinal()
        # The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
        # Consequently, changes in optimization flags, which affect compilation
        # results, don't change the cache key. This can result in the wrong
        # compilation being used. To prevent this, disabling the XLA compilation
        # cache during development is recommended. It can be disabled with
        # `export VLLM_XLA_CACHE_PATH=`
        if envs.VLLM_XLA_CACHE_PATH:
            per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
                                         f"tp{world_size}_rank{rank}")
            xr.initialize_cache(per_rank_path, readonly=False)

        self.profiler = None
        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
            # For TPU, we can only have 1 active profiler session for 1 profiler
            # server. So we only profile on rank0.
            self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        self.profile_dir)
            self.profiler = xp.start_server(9012)
|  | ||||
|     def start_profile(self): | ||||
|         if self.rank < 1: | ||||
|             if self.profiler is None: | ||||
|                 raise RuntimeError("Profiler is not enabled.") | ||||
|             xp.start_trace(self.profile_dir) | ||||
|  | ||||
|     def stop_profile(self): | ||||
|         if self.rank < 1: | ||||
|             if self.profiler is None: | ||||
|                 raise RuntimeError("Profiler is not enabled.") | ||||
|             xp.stop_trace() | ||||
|  | ||||
|     def load_model(self): | ||||
|         self.model_runner.load_model() | ||||
|  | ||||
|     def determine_num_available_blocks(self) -> Tuple[int, int]: | ||||
|         num_layers = self.model_config.get_num_layers(self.parallel_config) | ||||
|         head_size = self.model_config.get_head_size() | ||||
|         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) | ||||
|  | ||||
|         # use an empty tensor instead of `None`` to force Dynamo to pass | ||||
|         # it by reference, rather by specializing on the value ``None``. | ||||
|         # the `dtype` argument does not matter, and we use `float32` as | ||||
|         # a placeholder (it has wide hardware support). | ||||
|         kv_caches = [(torch.tensor([], dtype=torch.float32, | ||||
|                                    device=self.device), | ||||
|                       torch.tensor([], dtype=torch.float32, | ||||
|                                    device=self.device)) | ||||
|                      for _ in range(num_layers)] | ||||
|         bind_kv_cache(self.compilation_config.static_forward_context, | ||||
|                       [kv_caches]) | ||||
|         self.model_runner._dummy_run( | ||||
|             batch_size=1, | ||||
|             seq_len=self.scheduler_config.max_num_batched_tokens, | ||||
|             kv_caches=kv_caches, | ||||
|             exec_mode=ExecutionMode.PREFILL, | ||||
|         ) | ||||
|         # Synchronize before measuring the memory usage. | ||||
|         xm.wait_device_ops() | ||||
|  | ||||
|         # Get the maximum amount of memory used by the model weights and | ||||
|         # intermediate activations. | ||||
|         m = xm.get_memory_info(self.device) | ||||
|         total_memory_size = m["bytes_limit"] | ||||
|         profiled = m["peak_bytes_used"]  # Weights + intermediate activations. | ||||
|  | ||||
|         # Calculate the TPU KV cache size based on profiling. | ||||
|         usable_memory_size = int(total_memory_size * | ||||
|                                  self.cache_config.gpu_memory_utilization) | ||||
|         tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) | ||||
|         dtype_bytes = get_dtype_size(self.cache_dtype) | ||||
|         block_size_bytes = (dtype_bytes * self.cache_config.block_size * | ||||
|                             num_layers * 2 * head_size * num_kv_heads) | ||||
|         num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes | ||||
|         num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8. | ||||
|  | ||||
|         # Calculate the CPU KV cache size based on the config. | ||||
|         num_cpu_blocks = int(self.cache_config.swap_space_bytes // | ||||
|                              block_size_bytes) | ||||
|         num_cpu_blocks = (num_cpu_blocks // 8) * 8  # Round down to 8. | ||||
|         return num_tpu_blocks, num_cpu_blocks | ||||
|  | ||||
|     def initialize_cache( | ||||
|         self, | ||||
|         num_gpu_blocks: int, | ||||
|         num_cpu_blocks: int, | ||||
|     ) -> None: | ||||
|         self.cache_config.num_gpu_blocks = num_gpu_blocks | ||||
|         self.cache_config.num_cpu_blocks = num_cpu_blocks | ||||
|         self.block_size = self.cache_config.block_size | ||||
|  | ||||
|         dtype = self.cache_dtype | ||||
|         num_layers = self.model_config.get_num_layers(self.parallel_config) | ||||
|         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) | ||||
|         head_size = self.model_config.get_head_size() | ||||
|  | ||||
|         self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] | ||||
|         self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] | ||||
|         tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape( | ||||
|             num_gpu_blocks, self.block_size, num_kv_heads, head_size) | ||||
|         cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape( | ||||
|             num_cpu_blocks, self.block_size, num_kv_heads, head_size) | ||||
|         for _ in range(num_layers): | ||||
|             tpu_k_cache = torch.zeros(tpu_cache_shape, | ||||
|                                       dtype=dtype, | ||||
|                                       device=self.device) | ||||
|             tpu_v_cache = torch.zeros_like(tpu_k_cache) | ||||
|             self.tpu_cache.append((tpu_k_cache, tpu_v_cache)) | ||||
|             cpu_k_cache = torch.zeros(cpu_cache_shape, | ||||
|                                       dtype=dtype, | ||||
|                                       device="cpu") | ||||
|             cpu_v_cache = torch.zeros_like(cpu_k_cache) | ||||
|             self.cpu_cache.append((cpu_k_cache, cpu_v_cache)) | ||||
|         bind_kv_cache(self.compilation_config.static_forward_context, | ||||
|                       [self.tpu_cache]) | ||||
|         self._warmup_model() | ||||
|  | ||||
|     def _warmup_model(self) -> None: | ||||
|         # FIXME(woosuk): Here we are abusing `enforce_eager` which is defined | ||||
|         # for CUDA graphs. We should refactor this part. | ||||
|         if not self.model_config.enforce_eager: | ||||
|             # Warm up the model with all possible input shapes so that | ||||
|             # compilation never happens during the actual execution. | ||||
|             # This may take ~30 mins for the first run and ~20 mins for the | ||||
|             # subsequent runs. | ||||
|             # If `enforce_eager` is True, the ahead-of-time compilation is | ||||
|             # skipped and the compilation happens during the actual execution, | ||||
|             # which is bad for performance but useful for development. | ||||
|             self.model_runner.warmup_model(self.tpu_cache) | ||||
|  | ||||
|     def get_cache_block_size_bytes(self) -> int: | ||||
|         head_size = self.model_config.get_head_size() | ||||
|         num_heads = self.model_config.get_num_kv_heads(self.parallel_config) | ||||
|         num_layers = self.model_config.get_num_layers(self.parallel_config) | ||||
|  | ||||
|         key_cache_block = self.cache_config.block_size * num_heads * head_size | ||||
|         value_cache_block = key_cache_block | ||||
|         total = num_layers * (key_cache_block + value_cache_block) | ||||
|         dtype_size = get_dtype_size(self.cache_dtype) | ||||
|         return dtype_size * total | ||||
|  | ||||
|     @property | ||||
|     def do_metadata_broadcast(self) -> bool: | ||||
|         return self.parallel_config.tensor_parallel_size > 1 | ||||
|  | ||||
|     @property | ||||
|     def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: | ||||
|         """Return the TPU cache wrapped in a one-element list. | ||||
|  | ||||
|         The outer list has a single entry, `self.tpu_cache`. | ||||
|         """ | ||||
|         # NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline | ||||
|         # parallelism. | ||||
|         return [self.tpu_cache] | ||||
|  | ||||
|     def prepare_worker_input( | ||||
|         self, | ||||
|         execute_model_req: ExecuteModelRequest, | ||||
|     ) -> WorkerInput: | ||||
|         virtual_engine = execute_model_req.virtual_engine | ||||
|         num_seq_groups = len(execute_model_req.seq_group_metadata_list) | ||||
|         blocks_to_swap_in = _make_src_to_dst( | ||||
|             execute_model_req.blocks_to_swap_in, "cpu", self.device) | ||||
|         blocks_to_swap_out = _make_src_to_dst( | ||||
|             execute_model_req.blocks_to_swap_out, self.device, "cpu") | ||||
|         blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy, | ||||
|                                           self.device, self.device) | ||||
|         return WorkerInput( | ||||
|             num_seq_groups=num_seq_groups, | ||||
|             blocks_to_swap_in=blocks_to_swap_in, | ||||
|             blocks_to_swap_out=blocks_to_swap_out, | ||||
|             blocks_to_copy=blocks_to_copy, | ||||
|             virtual_engine=virtual_engine, | ||||
|         ) | ||||
|  | ||||
|     def execute_worker(self, worker_input: WorkerInput) -> None: | ||||
|         """Apply the cache operations carried by `worker_input`. | ||||
|  | ||||
|         Swap-in copies blocks from the CPU cache into the TPU cache, swap-out | ||||
|         copies in the opposite direction, and block copies are delegated to | ||||
|         the attention backend. Each operation is skipped when its index pair | ||||
|         is None or holds no elements. | ||||
|         """ | ||||
|         virtual_engine = worker_input.virtual_engine | ||||
|         # Only a single virtual engine is handled by this worker. | ||||
|         assert virtual_engine == 0 | ||||
|         attn_backend = self.model_runner.attn_backend | ||||
|         num_layers = self.model_config.get_num_layers(self.parallel_config) | ||||
|  | ||||
|         # Issue cache operations. | ||||
|         if worker_input.blocks_to_swap_in is not None: | ||||
|             src_indices, dst_indices = worker_input.blocks_to_swap_in | ||||
|             if src_indices.numel() > 0: | ||||
|                 # Swap from CPU to TPU. | ||||
|                 for i in range(num_layers): | ||||
|                     tpu_k_cache, tpu_v_cache = self.tpu_cache[i] | ||||
|                     cpu_k_cache, cpu_v_cache = self.cpu_cache[i] | ||||
|                     # Gather the selected entries along the second dimension | ||||
|                     # of the CPU caches and move them to the device. | ||||
|                     k = cpu_k_cache[:, src_indices].to(self.device) | ||||
|                     v = cpu_v_cache[:, src_indices].to(self.device) | ||||
|                     _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache) | ||||
|  | ||||
|         if worker_input.blocks_to_swap_out is not None: | ||||
|             src_indices, dst_indices = worker_input.blocks_to_swap_out | ||||
|             if src_indices.numel() > 0: | ||||
|                 # Swap from TPU to CPU. | ||||
|                 for i in range(num_layers): | ||||
|                     tpu_k_cache, tpu_v_cache = self.tpu_cache[i] | ||||
|                     cpu_k_cache, cpu_v_cache = self.cpu_cache[i] | ||||
|                     # In-place write into the CPU caches. | ||||
|                     cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices] | ||||
|                     cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices] | ||||
|  | ||||
|         if worker_input.blocks_to_copy is not None: | ||||
|             src_indices, dst_indices = worker_input.blocks_to_copy | ||||
|             if src_indices.numel() > 0: | ||||
|                 # Device-to-device copies are handled by the attention backend. | ||||
|                 attn_backend.copy_blocks(self.tpu_cache, | ||||
|                                          (src_indices, dst_indices)) | ||||
|  | ||||
|  | ||||
| def _make_src_to_dst( | ||||
|     mapping: List[Tuple[int, int]], | ||||
|     src_device: Union[torch.device, str], | ||||
|     dst_device: Union[torch.device, str], | ||||
| ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: | ||||
|     if not mapping: | ||||
|         return None | ||||
|  | ||||
|     src_indices = [i for i, _ in mapping] | ||||
|     dst_indices = [i for _, i in mapping] | ||||
|     src_indices = torch.tensor(src_indices, | ||||
|                                device=src_device, | ||||
|                                dtype=torch.int64) | ||||
|     dst_indices = torch.tensor(dst_indices, | ||||
|                                device=dst_device, | ||||
|                                dtype=torch.int64) | ||||
|     return src_indices, dst_indices | ||||
|  | ||||
|  | ||||
| @torch.compile(backend="openxla") | ||||
| def _insert_kv( | ||||
|     k: torch.Tensor, | ||||
|     v: torch.Tensor, | ||||
|     indices: torch.Tensor, | ||||
|     tpu_k_cache: torch.Tensor, | ||||
|     tpu_v_cache: torch.Tensor, | ||||
| ) -> None: | ||||
|     """Write `k` and `v` into the TPU caches at `indices`, in place. | ||||
|  | ||||
|     `indices` selects positions along the second dimension of each cache. | ||||
|     The function is compiled with the "openxla" backend. | ||||
|     """ | ||||
|     # Flag both caches as buffer donors before the in-place writes below. | ||||
|     # NOTE(review): presumably this lets XLA reuse the cache buffers instead | ||||
|     # of allocating new ones -- confirm against the torch_xla documentation. | ||||
|     torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True) | ||||
|     torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True) | ||||
|     tpu_k_cache[:, indices] = k | ||||
|     tpu_v_cache[:, indices] = v | ||||
							
								
								
									
										606
									
								
								vllm/worker/xpu_model_runner.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										606
									
								
								vllm/worker/xpu_model_runner.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,606 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
|  | ||||
| import dataclasses | ||||
| import time | ||||
| import weakref | ||||
| from collections import defaultdict | ||||
| from dataclasses import dataclass | ||||
| from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, | ||||
|                     Type, TypeVar) | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from vllm.attention import get_attn_backend | ||||
| from vllm.config import VllmConfig | ||||
| from vllm.distributed import get_pp_group | ||||
| from vllm.forward_context import set_forward_context | ||||
| from vllm.inputs import INPUT_REGISTRY, InputRegistry | ||||
| from vllm.logger import init_logger | ||||
| from vllm.model_executor import SamplingMetadataCache | ||||
| from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler | ||||
| from vllm.model_executor.model_loader import get_model | ||||
| from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, | ||||
|                              MultiModalKwargs, MultiModalPlaceholderMap, | ||||
|                              MultiModalRegistry) | ||||
| from vllm.sampling_params import SamplingParams | ||||
| from vllm.sequence import IntermediateTensors, SequenceGroupMetadata | ||||
| from vllm.utils import DeviceMemoryProfiler, GiB_bytes, make_tensor_with_pad | ||||
| from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata | ||||
| from vllm.worker.model_runner_base import ( | ||||
|     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, | ||||
|     _add_attn_metadata_broadcastable_dict, | ||||
|     _add_sampling_metadata_broadcastable_dict, | ||||
|     _init_attn_metadata_from_tensor_dict, | ||||
|     _init_sampling_metadata_from_tensor_dict) | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from vllm.attention.backends.abstract import AttentionBackend | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
| _PAD_SLOT_ID = -1 | ||||
|  | ||||
| TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU") | ||||
|  | ||||
|  | ||||
| @dataclass(frozen=True) | ||||
| class ModelInputForXPU(ModelRunnerInputBase): | ||||
|     """ | ||||
|     Model input for the XPU model runner, without sampling metadata. | ||||
|     """ | ||||
|     input_tokens: Optional[torch.Tensor] = None | ||||
|     input_positions: Optional[torch.Tensor] = None | ||||
|     attn_metadata: Optional["AttentionMetadata"] = None | ||||
|     multi_modal_kwargs: Optional[BatchedTensorInputs] = None | ||||
|     virtual_engine: Optional[int] = None | ||||
|     seq_lens: Optional[List[int]] = None | ||||
|     query_lens: Optional[List[int]] = None | ||||
|     async_callback: Optional[Callable] = None | ||||
|  | ||||
|     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: | ||||
|         """Return the fields that are shared through a tensor-dict broadcast. | ||||
|  | ||||
|         Only the input tokens, the input positions and the attention | ||||
|         metadata are included; the remaining fields are not part of the | ||||
|         returned dict. | ||||
|         """ | ||||
|         tensor_dict = { | ||||
|             "input_tokens": self.input_tokens, | ||||
|             "input_positions": self.input_positions, | ||||
|         } | ||||
|         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) | ||||
|  | ||||
|         return tensor_dict | ||||
|  | ||||
|     @classmethod | ||||
|     def from_broadcasted_tensor_dict( | ||||
|         cls: Type[TModelInputForXPU], | ||||
|         tensor_dict: Dict[str, Any], | ||||
|         attn_backend: Optional["AttentionBackend"] = None, | ||||
|     ) -> TModelInputForXPU: | ||||
|         """Rebuild an instance from a broadcast tensor dict. | ||||
|  | ||||
|         When `attn_backend` is given, the attention metadata is first | ||||
|         reconstructed from the dict with that backend; the resulting entries | ||||
|         are then passed to the constructor as keyword arguments. | ||||
|         """ | ||||
|         if attn_backend is not None: | ||||
|             tensor_dict = _init_attn_metadata_from_tensor_dict( | ||||
|                 attn_backend, tensor_dict) | ||||
|         return cls(**tensor_dict) | ||||
|  | ||||
|  | ||||
| @dataclass(frozen=True) | ||||
| class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU): | ||||
|     """ | ||||
|     Used by the ModelRunner. | ||||
|     """ | ||||
|     sampling_metadata: Optional["SamplingMetadata"] = None | ||||
|  | ||||
|     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: | ||||
|         tensor_dict = { | ||||
|             "input_tokens": self.input_tokens, | ||||
|             "input_positions": self.input_positions, | ||||
|         } | ||||
|         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) | ||||
|         _add_sampling_metadata_broadcastable_dict(tensor_dict, | ||||
|                                                   self.sampling_metadata) | ||||
|         return tensor_dict | ||||
|  | ||||
|     @classmethod | ||||
|     def from_broadcasted_tensor_dict( | ||||
|         cls, | ||||
|         tensor_dict: Dict[str, Any], | ||||
|         attn_backend: Optional["AttentionBackend"] = None, | ||||
|     ) -> "ModelInputForXPUWithSamplingMetadata": | ||||
|         tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) | ||||
|         if attn_backend is not None: | ||||
|             tensor_dict = _init_attn_metadata_from_tensor_dict( | ||||
|                 attn_backend, tensor_dict) | ||||
|         return cls(**tensor_dict) | ||||
|  | ||||
|  | ||||
| class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): | ||||
|  | ||||
|     def __init__(self, | ||||
|                  runner: "XPUModelRunner", | ||||
|                  finished_requests_ids: Optional[List[str]] = None) -> None: | ||||
|         super().__init__() | ||||
|         self.runner = runner | ||||
|         self.model_input_cls = self.runner._model_input_cls | ||||
|         self.attn_backend = self.runner.attn_backend | ||||
|         self.sliding_window = self.runner.sliding_window | ||||
|         self.block_size = self.runner.block_size | ||||
|         self.device = self.runner.device | ||||
|  | ||||
|     def prepare(self, | ||||
|                 finished_requests_ids: Optional[List[str]] = None) -> None: | ||||
|         """Reset the builder by starting a new, empty sequence-group list. | ||||
|  | ||||
|         `finished_requests_ids` is accepted but not used here. | ||||
|         """ | ||||
|         self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] | ||||
|  | ||||
|     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): | ||||
|         """Queue one sequence group for the next `build()` call.""" | ||||
|         self.seq_group_metadata_list.append(seq_group_metadata) | ||||
|  | ||||
|     def build(self) -> ModelInputForXPU: | ||||
|         is_prompt = self.seq_group_metadata_list[0].is_prompt | ||||
|         # Prepare input tensors. | ||||
|         if is_prompt: | ||||
|             (input_tokens, input_positions, attn_metadata, seq_lens, | ||||
|              multi_modal_kwargs) = self._prepare_prompt( | ||||
|                  self.seq_group_metadata_list) | ||||
|         else: | ||||
|             (input_tokens, input_positions, | ||||
|              attn_metadata) = self._prepare_decode( | ||||
|                  self.seq_group_metadata_list) | ||||
|             seq_lens = None | ||||
|             multi_modal_kwargs = None | ||||
|  | ||||
|         return self.model_input_cls( | ||||
|             input_tokens=input_tokens, | ||||
|             input_positions=input_positions, | ||||
|             attn_metadata=attn_metadata, | ||||
|             multi_modal_kwargs=multi_modal_kwargs, | ||||
|             seq_lens=seq_lens, | ||||
|             query_lens=seq_lens, | ||||
|         ) | ||||
|  | ||||
|     def _prepare_prompt( | ||||
|         self, | ||||
|         seq_group_metadata_list: List[SequenceGroupMetadata], | ||||
|     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], | ||||
|                BatchedTensorInputs]: | ||||
|         """Build the input tensors and attention metadata for a prompt batch. | ||||
|  | ||||
|         Every sequence group must be a prompt with exactly one sequence. | ||||
|  | ||||
|         Returns: | ||||
|             A tuple (input_tokens, input_positions, attn_metadata, seq_lens, | ||||
|             multi_modal_kwargs). The tensors are flattened over all | ||||
|             sequences of the batch and live on `self.device`. | ||||
|         """ | ||||
|         assert len(seq_group_metadata_list) > 0 | ||||
|         input_tokens: List[int] = [] | ||||
|         input_positions: List[int] = [] | ||||
|         slot_mapping: List[int] = [] | ||||
|         seq_lens: List[int] = [] | ||||
|         multi_modal_kwargs_list: List[MultiModalKwargs] = [] | ||||
|         multi_modal_placeholder_maps: Dict[ | ||||
|             str, | ||||
|             MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) | ||||
|  | ||||
|         for seq_group_metadata in seq_group_metadata_list: | ||||
|             assert seq_group_metadata.is_prompt | ||||
|             seq_ids = list(seq_group_metadata.seq_data.keys()) | ||||
|             assert len(seq_ids) == 1 | ||||
|             seq_id = seq_ids[0] | ||||
|  | ||||
|             seq_data = seq_group_metadata.seq_data[seq_id] | ||||
|             prompt_tokens = seq_data.get_token_ids() | ||||
|             computed_len = seq_data.get_num_computed_tokens() | ||||
|             seq_len = len(prompt_tokens) | ||||
|  | ||||
|             # NOTE(review): all prompt tokens are appended here, while the | ||||
|             # positions and the slot mapping below only cover | ||||
|             # [computed_len, seq_len). The lengths agree only when | ||||
|             # computed_len == 0 -- confirm that partially computed prompts | ||||
|             # never reach this path. | ||||
|             seq_lens.append(seq_len)  # Prompt token num | ||||
|             input_tokens.extend(prompt_tokens)  # Token ids | ||||
|  | ||||
|             # Token position ids | ||||
|             # NOTE(woosuk): Here we assume that the first token in the prompt | ||||
|             # is always the first token in the sequence. | ||||
|             positions_range = range(computed_len, seq_len) | ||||
|             input_positions.extend(list(positions_range)) | ||||
|  | ||||
|             if seq_group_metadata.multi_modal_data: | ||||
|                 # NOTE: mm_kwargs only includes the subset of multi-modal items | ||||
|                 # that intersect with the current prefill positions. | ||||
|                 mm_kwargs, placeholder_maps = MultiModalPlaceholderMap \ | ||||
|                     .from_seq_group(seq_group_metadata, positions_range) | ||||
|  | ||||
|                 multi_modal_kwargs_list.append(mm_kwargs) | ||||
|  | ||||
|                 for modality, placeholder_map in placeholder_maps.items(): | ||||
|                     multi_modal_placeholder_maps[modality].extend( | ||||
|                         placeholder_map) | ||||
|  | ||||
|             if seq_group_metadata.block_tables is None: | ||||
|                 # During memory profiling, the block tables are not initialized | ||||
|                 # yet. In this case, we just use a dummy slot mapping. | ||||
|                 slot_mapping.extend([_PAD_SLOT_ID] * seq_len) | ||||
|                 continue | ||||
|  | ||||
|             # Compute the slot mapping. | ||||
|             block_table = seq_group_metadata.block_tables[seq_id] | ||||
|             # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, | ||||
|             # where start_idx is max(0, seq_len - sliding_window). | ||||
|             # For example, if the prompt len is 10, sliding window is 8, and | ||||
|             # block size is 4, the first two tokens are masked and the slot | ||||
|             # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. | ||||
|             start_idx = 0 | ||||
|             if self.sliding_window is not None: | ||||
|                 start_idx = max(0, seq_len - self.sliding_window) | ||||
|  | ||||
|             for i in range(computed_len, seq_len): | ||||
|                 if i < start_idx: | ||||
|                     slot_mapping.append(_PAD_SLOT_ID) | ||||
|                     continue | ||||
|  | ||||
|                 # Slot = start of the token's block plus its offset in it. | ||||
|                 block_number = block_table[i // | ||||
|                                            self.block_size]  # type: ignore | ||||
|                 block_offset = i % self.block_size  # type: ignore | ||||
|                 slot = block_number * self.block_size + block_offset | ||||
|                 slot_mapping.append(slot) | ||||
|  | ||||
|         num_prompt_tokens = len(input_tokens) | ||||
|  | ||||
|         input_tokens = torch.tensor(input_tokens, | ||||
|                                     dtype=torch.long, | ||||
|                                     device=self.device)  # type: ignore | ||||
|         input_positions = torch.tensor(input_positions, | ||||
|                                        dtype=torch.long, | ||||
|                                        device=self.device)  # type: ignore | ||||
|         slot_mapping = torch.tensor(slot_mapping, | ||||
|                                     dtype=torch.long, | ||||
|                                     device=self.device)  # type: ignore | ||||
|         placeholder_index_maps = { | ||||
|             modality: placeholder_map.index_map() | ||||
|             for modality, placeholder_map in | ||||
|             multi_modal_placeholder_maps.items() | ||||
|         } | ||||
|  | ||||
|         max_seqlen = max(seq_lens) | ||||
|         # Cumulative sequence lengths with a leading zero, i.e. the start | ||||
|         # offset of every sequence in the flattened batch plus the total. | ||||
|         tmp = [0] | ||||
|         tmp.extend(seq_lens) | ||||
|         seqlen = torch.tensor(tmp) | ||||
|         seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device) | ||||
|  | ||||
|         # The decode-only fields are filled with empty or zero placeholders. | ||||
|         attn_metadata = self.attn_backend.make_metadata( | ||||
|             is_prompt=True, | ||||
|             slot_mapping=slot_mapping, | ||||
|             multi_modal_placeholder_index_maps=placeholder_index_maps, | ||||
|             enable_kv_scales_calculation=False, | ||||
|             seq_lens=seq_lens, | ||||
|             seqlen_q=seqlen_q, | ||||
|             max_seqlen=max_seqlen, | ||||
|             seq_lens_tensor=torch.tensor([]), | ||||
|             max_decode_seq_len=0, | ||||
|             num_prefills=len(seq_lens), | ||||
|             num_prefill_tokens=num_prompt_tokens, | ||||
|             num_decode_tokens=0, | ||||
|             block_tables=torch.tensor([], device=self.device, dtype=torch.int), | ||||
|         ) | ||||
|  | ||||
|         multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) | ||||
|  | ||||
|         return (input_tokens, input_positions, attn_metadata, seq_lens, | ||||
|                 multi_modal_kwargs) | ||||
|  | ||||
|     def _prepare_decode( | ||||
|         self, | ||||
|         seq_group_metadata_list: List[SequenceGroupMetadata], | ||||
|     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: | ||||
|         """Build the input tensors and attention metadata for a decode batch. | ||||
|  | ||||
|         Every sequence group must be a non-prompt group with a token chunk | ||||
|         size of one; each of its sequences contributes its last token. | ||||
|  | ||||
|         Returns: | ||||
|             A tuple (input_tokens, input_positions, attn_metadata) whose | ||||
|             tensors live on `self.device`. | ||||
|         """ | ||||
|         assert len(seq_group_metadata_list) > 0 | ||||
|         input_tokens: List[int] = [] | ||||
|         input_positions: List[int] = [] | ||||
|         slot_mapping: List[int] = [] | ||||
|         seq_lens: List[int] = [] | ||||
|         block_tables: List[List[int]] = [] | ||||
|  | ||||
|         for seq_group_metadata in seq_group_metadata_list: | ||||
|             assert not seq_group_metadata.is_prompt | ||||
|             assert seq_group_metadata.token_chunk_size == 1 | ||||
|  | ||||
|             seq_ids = list(seq_group_metadata.seq_data.keys()) | ||||
|  | ||||
|             for seq_id in seq_ids: | ||||
|                 seq_data = seq_group_metadata.seq_data[seq_id] | ||||
|                 generation_token = seq_data.get_last_token_id() | ||||
|                 input_tokens.append(generation_token) | ||||
|  | ||||
|                 # The position is taken from the full sequence length, before | ||||
|                 # the length is clamped to the sliding window below. | ||||
|                 seq_len = seq_data.get_len() | ||||
|                 position = seq_len - 1 | ||||
|                 input_positions.append(position) | ||||
|  | ||||
|                 seq_len = seq_len if self.sliding_window is None else min( | ||||
|                     seq_len, self.sliding_window) | ||||
|                 seq_lens.append(seq_len) | ||||
|  | ||||
|                 # Slot of the token being decoded: start of its block plus | ||||
|                 # its offset inside the block. | ||||
|                 block_table = seq_group_metadata.block_tables[seq_id] | ||||
|                 block_number = block_table[position // self.block_size] | ||||
|                 block_offset = position % self.block_size | ||||
|                 slot = block_number * self.block_size + block_offset | ||||
|                 slot_mapping.append(slot) | ||||
|  | ||||
|                 if self.sliding_window is not None: | ||||
|                     # Keep only the trailing blocks that fit in the window. | ||||
|                     sliding_window_blocks = (self.sliding_window // | ||||
|                                              self.block_size) | ||||
|                     block_table = block_table[-sliding_window_blocks:] | ||||
|                 block_tables.append(block_table) | ||||
|  | ||||
|         max_decode_seq_len = max(seq_lens) | ||||
|  | ||||
|         input_tokens = torch.tensor(input_tokens, | ||||
|                                     dtype=torch.long, | ||||
|                                     device=self.device) | ||||
|         input_positions = torch.tensor(input_positions, | ||||
|                                        dtype=torch.long, | ||||
|                                        device=self.device) | ||||
|         slot_mapping = torch.tensor(slot_mapping, | ||||
|                                     dtype=torch.long, | ||||
|                                     device=self.device) | ||||
|         seq_lens_tensor = torch.tensor(seq_lens, | ||||
|                                        dtype=torch.int, | ||||
|                                        device=self.device) | ||||
|  | ||||
|         # Block tables may differ in length; they are padded with 0 into one | ||||
|         # tensor. | ||||
|         block_tables = make_tensor_with_pad( | ||||
|             block_tables, | ||||
|             pad=0, | ||||
|             dtype=torch.int, | ||||
|             device=self.device, | ||||
|         ) | ||||
|  | ||||
|         # The prefill-only fields are filled with empty or zero placeholders. | ||||
|         attn_metadata = self.attn_backend.make_metadata( | ||||
|             is_prompt=False, | ||||
|             slot_mapping=slot_mapping, | ||||
|             multi_modal_placeholder_index_maps=None, | ||||
|             enable_kv_scales_calculation=False, | ||||
|             seq_lens=seq_lens, | ||||
|             seqlen_q=torch.tensor([]), | ||||
|             max_seqlen=0, | ||||
|             seq_lens_tensor=seq_lens_tensor, | ||||
|             max_decode_seq_len=max_decode_seq_len, | ||||
|             num_prefill_tokens=0, | ||||
|             num_decode_tokens=len(input_tokens), | ||||
|             num_prefills=0, | ||||
|             block_tables=block_tables, | ||||
|         ) | ||||
|         return ( | ||||
|             input_tokens, | ||||
|             input_positions, | ||||
|             attn_metadata, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): | ||||
|     _model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = ( | ||||
|         ModelInputForXPUWithSamplingMetadata) | ||||
|     _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         vllm_config: VllmConfig, | ||||
|         kv_cache_dtype: Optional[str] = "auto", | ||||
|         is_driver_worker: bool = False, | ||||
|         return_hidden_states: bool = False, | ||||
|         input_registry: InputRegistry = INPUT_REGISTRY, | ||||
|         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, | ||||
|     ): | ||||
|         """Set up the runner state; the model itself is loaded later. | ||||
|  | ||||
|         The attention backend, the sampler, the sampling-metadata cache and | ||||
|         the model-input builder are created here. | ||||
|         """ | ||||
|  | ||||
|         ModelRunnerBase.__init__(self, vllm_config=vllm_config) | ||||
|         model_config = self.model_config | ||||
|         cache_config = self.cache_config | ||||
|         self.is_driver_worker = is_driver_worker | ||||
|         self.return_hidden_states = return_hidden_states | ||||
|  | ||||
|         self.device = self.device_config.device | ||||
|  | ||||
|         self.kv_cache_dtype = kv_cache_dtype | ||||
|         self.sliding_window = model_config.get_sliding_window() | ||||
|         self.block_size = cache_config.block_size | ||||
|  | ||||
|         self.attn_backend = get_attn_backend( | ||||
|             self.model_config.get_head_size(), | ||||
|             self.model_config.dtype, | ||||
|             self.kv_cache_dtype, | ||||
|             self.block_size, | ||||
|             self.model_config.is_attention_free, | ||||
|         ) | ||||
|  | ||||
|         # Multi-modal data support | ||||
|         self.input_registry = input_registry | ||||
|         self.mm_registry = mm_registry | ||||
|  | ||||
|         # Lazy initialization. | ||||
|         self.model: nn.Module  # Set by load_model() | ||||
|         self.sampler = get_sampler() | ||||
|  | ||||
|         # The cache is only created when pipeline parallelism is not used; | ||||
|         # otherwise the attribute is None. | ||||
|         self.sampling_metadata_cache: SamplingMetadataCache = \ | ||||
|               SamplingMetadataCache() \ | ||||
|                 if self.parallel_config.pipeline_parallel_size == 1 else None | ||||
|  | ||||
|         # The builder receives a weak proxy of this runner, not a strong | ||||
|         # reference. | ||||
|         self.builder = self._builder_cls(weakref.proxy(self)) | ||||
|  | ||||
|     def load_model(self) -> None: | ||||
|         """Load the model and record the memory consumed while loading.""" | ||||
|         # The profiler measures the memory consumed inside the `with` block. | ||||
|         with DeviceMemoryProfiler() as m: | ||||
|             self.model = get_model(vllm_config=self.vllm_config) | ||||
|  | ||||
|         self.model_memory_usage = m.consumed_memory | ||||
|         logger.info("Loading model weights took %.4f GiB", | ||||
|                     self.model_memory_usage / GiB_bytes) | ||||
|  | ||||
|     def get_model(self) -> nn.Module: | ||||
|         """Return the model stored by `load_model()`.""" | ||||
|         return self.model | ||||
|  | ||||
|     @property | ||||
|     def vocab_size(self) -> int: | ||||
|         """Vocabulary size as reported by the model config.""" | ||||
|         return self.model_config.get_vocab_size() | ||||
|  | ||||
|     @torch.inference_mode() | ||||
|     def profile_run(self) -> None: | ||||
|         """Run one forward pass on dummy prompts for memory profiling. | ||||
|  | ||||
|         The dummy batch holds `max_num_seqs` prompt sequences whose lengths | ||||
|         add up to `max_num_batched_tokens`. No KV caches and no block tables | ||||
|         are supplied. | ||||
|         """ | ||||
|         # Enable top-k sampling to reflect the accurate memory usage. | ||||
|         sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) | ||||
|         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens | ||||
|         max_num_seqs = self.scheduler_config.max_num_seqs | ||||
|  | ||||
|         # Profile memory usage with max_num_sequences sequences and the total | ||||
|         # number of tokens equal to max_num_batched_tokens. | ||||
|         seqs: List[SequenceGroupMetadata] = [] | ||||
|         # Additional device memory may be needed for multi-modal encoding, | ||||
|         # which needs to be accounted for when calculating the device blocks | ||||
|         # for the vLLM block manager. | ||||
|         # To exercise the worst scenario for device memory consumption, | ||||
|         # the number of seqs (batch_size) is chosen to maximize the number | ||||
|         # of images processed. | ||||
|         max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( | ||||
|             self.model_config) | ||||
|         if max_mm_tokens > 0: | ||||
|             max_num_seqs_orig = max_num_seqs | ||||
|             max_num_seqs = min(max_num_seqs, | ||||
|                                max_num_batched_tokens // max_mm_tokens) | ||||
|             if max_num_seqs < 1: | ||||
|                 expr = (f"min({max_num_seqs_orig}, " | ||||
|                         f"{max_num_batched_tokens} // {max_mm_tokens})") | ||||
|                 logger.warning( | ||||
|                     "Computed max_num_seqs (%s) to be less than 1. " | ||||
|                     "Setting it to the minimum value of 1.", expr) | ||||
|                 max_num_seqs = 1 | ||||
|  | ||||
|         batch_size = 0 | ||||
|         for group_id in range(max_num_seqs): | ||||
|             # Spread the token budget evenly; the first | ||||
|             # (max_num_batched_tokens % max_num_seqs) groups get one extra | ||||
|             # token. | ||||
|             seq_len = (max_num_batched_tokens // max_num_seqs + | ||||
|                        (group_id < max_num_batched_tokens % max_num_seqs)) | ||||
|             batch_size += seq_len | ||||
|  | ||||
|             dummy_data = self.input_registry \ | ||||
|                 .dummy_data_for_profiling(self.model_config, | ||||
|                                           seq_len, | ||||
|                                           self.mm_registry) | ||||
|  | ||||
|             seq = SequenceGroupMetadata( | ||||
|                 request_id=str(group_id), | ||||
|                 is_prompt=True, | ||||
|                 seq_data={group_id: dummy_data.seq_data}, | ||||
|                 sampling_params=sampling_params, | ||||
|                 block_tables=None, | ||||
|                 lora_request=None, | ||||
|                 multi_modal_data=dummy_data.multi_modal_data, | ||||
|                 multi_modal_placeholders=dummy_data.multi_modal_placeholders) | ||||
|             seqs.append(seq) | ||||
|  | ||||
|         finished_requests_ids = [seq.request_id for seq in seqs] | ||||
|         model_input = self.prepare_model_input( | ||||
|             seqs, finished_requests_ids=finished_requests_ids) | ||||
|         # Ranks other than the first pipeline stage get empty intermediate | ||||
|         # tensors created by the model. | ||||
|         intermediate_tensors = None | ||||
|         if not get_pp_group().is_first_rank: | ||||
|             intermediate_tensors = self.model.make_empty_intermediate_tensors( | ||||
|                 batch_size=batch_size, | ||||
|                 dtype=self.model_config.dtype, | ||||
|                 device=self.device) | ||||
|         # No KV caches are passed for the profiling run. | ||||
|         self.execute_model(model_input, None, intermediate_tensors) | ||||
|         torch.xpu.synchronize() | ||||
|         return | ||||
|  | ||||
|     def make_model_input_from_broadcasted_tensor_dict( | ||||
|             self, | ||||
|             tensor_dict: Dict[str, | ||||
|                               Any]) -> ModelInputForXPUWithSamplingMetadata: | ||||
|         return ( | ||||
|             ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict( | ||||
|                 tensor_dict, | ||||
|                 attn_backend=self.attn_backend, | ||||
|             )) | ||||
|  | ||||
|     def _prepare_model_input_tensors( | ||||
|         self, | ||||
|         seq_group_metadata_list: List[SequenceGroupMetadata], | ||||
|         finished_requests_ids: Optional[List[str]] = None | ||||
|     ) -> ModelInputForXPUWithSamplingMetadata: | ||||
|         """Helper method to prepare the model input based on a given sequence | ||||
|         group. Prepares metadata needed for the base model forward pass but not | ||||
|         metadata for possible additional steps, e.g., sampling. | ||||
|  | ||||
|         """ | ||||
|         builder = self.builder | ||||
|         builder.prepare(finished_requests_ids) | ||||
|         for seq_group_metadata in seq_group_metadata_list: | ||||
|             builder.add_seq_group(seq_group_metadata) | ||||
|  | ||||
|         return builder.build()  # type: ignore | ||||
|  | ||||
|     def prepare_model_input( | ||||
|         self, | ||||
|         seq_group_metadata_list: List[SequenceGroupMetadata], | ||||
|         virtual_engine: int = 0, | ||||
|         finished_requests_ids: Optional[List[str]] = None | ||||
|     ) -> ModelInputForXPUWithSamplingMetadata: | ||||
|         """Prepare the model input based on a given sequence group, including | ||||
|         metadata for the sampling step. | ||||
|  | ||||
|         """ | ||||
|         model_input = self._prepare_model_input_tensors( | ||||
|             seq_group_metadata_list, finished_requests_ids) | ||||
|         # Sampling metadata is only required for the final pp group | ||||
|         generators = self.get_generators(finished_requests_ids) | ||||
|         sampling_metadata = SamplingMetadata.prepare( | ||||
|             seq_group_metadata_list, | ||||
|             model_input.seq_lens, | ||||
|             model_input.query_lens, | ||||
|             self.device, | ||||
|             pin_memory=False, | ||||
|             generators=generators, | ||||
|             cache=self.sampling_metadata_cache) | ||||
|  | ||||
|         return dataclasses.replace(model_input, | ||||
|                                    sampling_metadata=sampling_metadata, | ||||
|                                    virtual_engine=virtual_engine) | ||||
|  | ||||
|     @torch.inference_mode() | ||||
|     def execute_model( | ||||
|         self, | ||||
|         model_input: ModelInputForXPUWithSamplingMetadata, | ||||
|         kv_caches: List[torch.Tensor], | ||||
|         intermediate_tensors: Optional[IntermediateTensors] = None, | ||||
|         num_steps: int = 1, | ||||
|     ) -> Optional[List[SamplerOutput]]: | ||||
|         if num_steps > 1: | ||||
|             raise ValueError( | ||||
|                 "XPUModelRunner does not support multi-step execution.") | ||||
|  | ||||
|         model_executable = self.model | ||||
|         if (self.observability_config is not None | ||||
|                 and self.observability_config.collect_model_forward_time): | ||||
|             model_forward_start_time = time.time() | ||||
|         with set_forward_context(model_input.attn_metadata, self.vllm_config, | ||||
|                                  model_input.virtual_engine): | ||||
|             hidden_or_intermediate_states = model_executable( | ||||
|                 input_ids=model_input.input_tokens, | ||||
|                 positions=model_input.input_positions, | ||||
|                 intermediate_tensors=intermediate_tensors, | ||||
|                 **MultiModalKwargs.as_kwargs( | ||||
|                     model_input.multi_modal_kwargs or {}, | ||||
|                     device=self.device, | ||||
|                 ), | ||||
|             ) | ||||
|         # Compute the logits in the last pipeline stage. | ||||
|         if not get_pp_group().is_last_rank: | ||||
|             return hidden_or_intermediate_states | ||||
|  | ||||
|         if (self.observability_config is not None | ||||
|                 and self.observability_config.collect_model_forward_time): | ||||
|             model_forward_end_time = time.time() | ||||
|  | ||||
|         # Compute the logits. | ||||
|         logits = self.model.compute_logits(hidden_or_intermediate_states, | ||||
|                                            model_input.sampling_metadata) | ||||
|  | ||||
|         # Only perform sampling in the driver worker. | ||||
|         if not self.is_driver_worker: | ||||
|             return [] | ||||
|  | ||||
|         if model_input.async_callback is not None: | ||||
|             model_input.async_callback() | ||||
|  | ||||
|         # Sample the next token. | ||||
|         output: SamplerOutput = self.sampler( | ||||
|             logits=logits, | ||||
|             sampling_metadata=model_input.sampling_metadata, | ||||
|         ) | ||||
|         if (self.observability_config is not None | ||||
|                 and self.observability_config.collect_model_forward_time | ||||
|                 and output is not None): | ||||
|             model_forward_time = (model_forward_end_time - | ||||
|                                   model_forward_start_time) | ||||
|             # If there are multiple workers, we are still tracking the latency | ||||
|             # from the start time of the driver worker to the end time of the | ||||
|             # driver worker. The model forward time will then end up covering | ||||
|             # the communication time as well. | ||||
|             output.model_forward_time = model_forward_time | ||||
|  | ||||
|         return [output] | ||||
							
								
								
									
										186
									
								
								vllm/worker/xpu_worker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								vllm/worker/xpu_worker.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,186 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
| """A XPU worker class.""" | ||||
| import gc | ||||
| import os | ||||
| from typing import List, Optional, Tuple | ||||
|  | ||||
| import intel_extension_for_pytorch  # noqa: F401 | ||||
| import oneccl_bindings_for_pytorch  # noqa: F401 | ||||
| import torch | ||||
| import torch.distributed | ||||
|  | ||||
| from vllm.config import VllmConfig | ||||
| from vllm.distributed import (ensure_model_parallel_initialized, | ||||
|                               init_distributed_environment) | ||||
| from vllm.distributed.parallel_state import get_pp_group | ||||
| from vllm.logger import init_logger | ||||
| from vllm.model_executor import set_random_seed | ||||
| from vllm.platforms import current_platform | ||||
| from vllm.worker.cache_engine import CacheEngine | ||||
| from vllm.worker.worker import Worker | ||||
| from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase | ||||
| from vllm.worker.xpu_model_runner import XPUModelRunner | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
|  | ||||
| class XPUWorker(LoRANotSupportedWorkerBase, Worker): | ||||
|     """A worker class that executes (a partition of) the model on a GPU. | ||||
|  | ||||
|     Each worker is associated with a single XPU device. The worker is  | ||||
|     responsible for maintaining the KV cache and executing the model on the  | ||||
|     XPU. In case of distributed inference, each worker is assigned a partition | ||||
|     of the model. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         vllm_config: VllmConfig, | ||||
|         local_rank: int, | ||||
|         rank: int, | ||||
|         distributed_init_method: str, | ||||
|         is_driver_worker: bool = False, | ||||
|     ) -> None: | ||||
|         WorkerBase.__init__(self, vllm_config=vllm_config) | ||||
|         device_config = self.device_config | ||||
|         parallel_config = self.parallel_config | ||||
|         assert device_config.device_type == "xpu" | ||||
|         assert current_platform.is_xpu() | ||||
|  | ||||
|         self.parallel_config.rank = rank | ||||
|  | ||||
|         self.local_rank = local_rank | ||||
|         self.rank = rank | ||||
|         self.distributed_init_method = distributed_init_method | ||||
|         self.is_driver_worker = is_driver_worker | ||||
|         if parallel_config and is_driver_worker: | ||||
|             assert rank % parallel_config.tensor_parallel_size == 0, \ | ||||
|                    "Driver worker should be rank 0 of tensor parallel group." | ||||
|  | ||||
|         self.model_runner = XPUModelRunner(  # type: ignore | ||||
|             vllm_config=vllm_config, | ||||
|             kv_cache_dtype=self.cache_config.cache_dtype, | ||||
|             is_driver_worker=is_driver_worker, | ||||
|         ) | ||||
|         # Uninitialized cache engine. Will be initialized by | ||||
|         # initialize_cache. | ||||
|         self.cache_engine: List[CacheEngine] | ||||
|         self.gpu_cache: Optional[List[List[torch.Tensor]]] | ||||
|  | ||||
|     def init_device(self) -> None: | ||||
|         if self.device_config.device.type == "xpu" and current_platform.is_xpu( | ||||
|         ): | ||||
|             self.device = torch.device(f"xpu:{self.local_rank}") | ||||
|             torch.xpu.set_device(self.device) | ||||
|             torch.xpu.empty_cache() | ||||
|             self.init_gpu_memory = torch.xpu.get_device_properties( | ||||
|                 self.local_rank).total_memory | ||||
|         else: | ||||
|             raise RuntimeError( | ||||
|                 f"Not support device type: {self.device_config.device}") | ||||
|         # Initialize the distributed environment. | ||||
|         self.init_worker_distributed_environment() | ||||
|         # Initialize the model. | ||||
|         set_random_seed(self.model_config.seed) | ||||
|  | ||||
|     # keep this method for `empty_cache` and `synchronize` api | ||||
|     @torch.inference_mode() | ||||
|     def determine_num_available_blocks(self) -> Tuple[int, int]: | ||||
|         """Profiles the peak memory usage of the model to determine how many | ||||
|         KV blocks may be allocated without OOMs. | ||||
|  | ||||
|         The engine will first conduct a profiling of the existing memory usage. | ||||
|         Then, it calculate the maximum possible number of GPU and CPU blocks | ||||
|         that can be allocated with the remaining free memory. | ||||
|  | ||||
|         Tip: | ||||
|             You may limit the usage of GPU memory | ||||
|             by adjusting the `gpu_memory_utilization` parameter. | ||||
|         """ | ||||
|         # Profile the memory usage of the model and get the maximum number of | ||||
|         # cache blocks that can be allocated with the remaining free memory. | ||||
|         torch.xpu.empty_cache() | ||||
|  | ||||
|         # Execute a forward pass with dummy inputs to profile the memory usage | ||||
|         # of the model. | ||||
|         self.model_runner.profile_run() | ||||
|  | ||||
|         # Calculate the number of blocks that can be allocated with the | ||||
|         # profiled peak memory. | ||||
|         torch.xpu.synchronize() | ||||
|         used_memory = torch.xpu.memory_allocated() | ||||
|         total_gpu_memory = torch.xpu.get_device_properties( | ||||
|             self.local_rank).total_memory | ||||
|         free_gpu_memory = total_gpu_memory - used_memory | ||||
|  | ||||
|         # NOTE(woosuk): Here we assume that the other processes using the same | ||||
|         # GPU did not change their memory usage during the profiling. | ||||
|         peak_memory = self.init_gpu_memory - free_gpu_memory | ||||
|         assert peak_memory > 0, ( | ||||
|             "Error in memory profiling. " | ||||
|             f"Initial free memory {self.init_gpu_memory}, current free memory" | ||||
|             f" {free_gpu_memory}. This happens when the GPU memory was " | ||||
|             "not properly cleaned up before initializing the vLLM instance.") | ||||
|  | ||||
|         cache_block_size = self.get_cache_block_size_bytes() | ||||
|         num_gpu_blocks = int( | ||||
|             (total_gpu_memory * self.cache_config.gpu_memory_utilization - | ||||
|              peak_memory) // cache_block_size) | ||||
|         num_cpu_blocks = int(self.cache_config.swap_space_bytes // | ||||
|                              cache_block_size) | ||||
|         num_gpu_blocks = max(num_gpu_blocks, 0) | ||||
|         num_cpu_blocks = max(num_cpu_blocks, 0) | ||||
|         gc.collect() | ||||
|         torch.xpu.empty_cache() | ||||
|         return num_gpu_blocks, num_cpu_blocks | ||||
|  | ||||
|     def _warm_up_model(self) -> None: | ||||
|         # IPEX don't support capture graph yet | ||||
|         pass | ||||
|  | ||||
|     def init_worker_distributed_environment(self) -> None: | ||||
|         """Initialize the distributed environment.""" | ||||
|  | ||||
|         parallel_config = self.parallel_config | ||||
|         rank = self.rank | ||||
|         distributed_init_method = self.distributed_init_method | ||||
|  | ||||
|         if torch.distributed.is_initialized(): | ||||
|             torch_world_size = torch.distributed.get_world_size() | ||||
|             if torch_world_size != parallel_config.world_size: | ||||
|                 raise RuntimeError( | ||||
|                     "torch.distributed is already initialized but the torch " | ||||
|                     "world size does not match parallel_config.world_size " | ||||
|                     f"({torch_world_size} vs. {parallel_config.world_size}).") | ||||
|         elif not distributed_init_method: | ||||
|             raise ValueError( | ||||
|                 "distributed_init_method must be set if torch.distributed " | ||||
|                 "is not already initialized") | ||||
|         else: | ||||
|             # use sockets as default Level zero IPC exchange backend. By | ||||
|             # default oneccl will use `drmfd` as mechanism which need extra | ||||
|             # dependency (libdrm and drm headers) on your system. | ||||
|             ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi") | ||||
|             ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE", | ||||
|                                              str(parallel_config.world_size)) | ||||
|             os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT | ||||
|             os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE | ||||
|             os.environ["LOCAL_RANK"] = str(self.local_rank) | ||||
|             init_distributed_environment( | ||||
|                 world_size=parallel_config.world_size, | ||||
|                 rank=rank, | ||||
|                 distributed_init_method=distributed_init_method, | ||||
|                 local_rank=self.local_rank, | ||||
|                 backend="ccl") | ||||
|  | ||||
|         ensure_model_parallel_initialized( | ||||
|             parallel_config.tensor_parallel_size, | ||||
|             parallel_config.pipeline_parallel_size) | ||||
|         # global all_reduce needed for overall oneccl warm up | ||||
|         torch.distributed.all_reduce(torch.zeros(1).xpu()) | ||||
|  | ||||
|         if parallel_config.pipeline_parallel_size > 1: | ||||
|             # Add pp group init to avoid | ||||
|             # p2p communication as the first call | ||||
|             get_pp_group().all_reduce(torch.zeros(1).xpu()) | ||||
		Reference in New Issue
	
	Block a user
	