Compare commits


75 Commits

Author SHA1 Message Date
c00ddd6834 Add buffer donation to benchmark 2024-04-30 21:58:47 +00:00
881b884046 Add block size 2024-04-27 22:35:28 +00:00
98a3df0f8d Disable memory tracking 2024-04-26 08:56:26 +00:00
3f6288cc89 Fix for binary cache 2024-04-26 08:56:12 +00:00
408ff4950c Tune pages_per_compute_block 2024-04-26 08:55:23 +00:00
278e8a1adc Add tpu 2024-04-26 08:54:52 +00:00
07be6ed3eb Improve benchmark 2024-04-26 08:54:41 +00:00
f6637dba18 Use persistent cache 2024-04-26 07:09:44 +00:00
707a5f6473 Move JAX-smi to worker 2024-04-26 07:05:51 +00:00
57690a9c09 Fix bucketing 2024-04-26 07:05:27 +00:00
b15db234ba Add precompilation step 2024-04-26 05:43:08 +00:00
d1591f0f1f Add op benchmark scripts 2024-04-26 05:35:19 +00:00
85d4488458 yapf 2024-04-26 05:31:31 +00:00
8d072dbfbd yapf 2024-04-26 05:30:25 +00:00
d830766c0c yapf 2024-04-26 05:30:08 +00:00
5ae2f81c2b Add warmup + formatting 2024-04-26 05:28:09 +00:00
4ea41d01a9 yapf 2024-04-26 05:27:38 +00:00
d16a348477 Add comment 2024-04-26 05:27:27 +00:00
aa092834bb Format gemma.py 2024-04-26 05:26:38 +00:00
d2c6a32c0c Fix is_tpu 2024-04-26 05:26:24 +00:00
21f35c2289 Change version 2024-04-26 05:00:26 +00:00
2aa9831dd3 Minor 2024-04-25 23:40:44 +00:00
028f528aad Fix KV cache shape 2024-04-25 23:38:07 +00:00
fa5bacd5b0 Add warmup 2024-04-25 05:06:41 +00:00
b62170e4e3 Fix scheduler 2024-04-25 05:06:22 +00:00
98eda57899 Add timer 2024-04-25 05:06:11 +00:00
81b8b813f1 Pad to avoid recompilation 2024-04-25 04:43:33 +00:00
e2c7dedb3a Minor 2024-04-25 03:28:53 +00:00
5323969fcf Increase #blocks 2024-04-24 08:56:58 +00:00
f42b4c27d8 Include argmax to jit 2024-04-24 08:56:45 +00:00
620e7646d3 Fix cache write 2024-04-24 08:56:30 +00:00
d5fb1c20c1 Fix JAX jit OOM 2024-04-24 07:52:56 +00:00
092e3d6d6d Remove hardcoded path 2024-04-19 08:18:10 +00:00
84284302d8 Minor 2024-04-19 08:08:25 +00:00
743695f586 Fix write_to_kv_cache 2024-04-19 07:51:54 +00:00
62b870fa07 Use FlashAttention kernel 2024-04-17 20:24:45 +00:00
7e3a230c38 Fix paged_attn 2024-04-17 20:06:26 +00:00
186c88c497 explictly return new_kv_caches 2024-04-17 18:42:34 +00:00
ef762cb110 Write kV 2024-04-17 18:21:39 +00:00
756c4e78d3 Add write_to_cache ops 2024-04-17 18:20:55 +00:00
4880de35d2 Add attn_mask 2024-04-17 18:12:20 +00:00
0fb07c08d0 Minor 2024-04-17 18:08:33 +00:00
e4377dd698 Add model runner 2024-04-17 18:04:54 +00:00
5cb213c85e Add flash-attn op 2024-04-17 18:02:28 +00:00
25bbc21ef6 Minor 2024-04-17 18:02:16 +00:00
b25fcc06c2 Minor 2024-04-17 18:02:13 +00:00
6661c030c4 Add paged_attn op 2024-04-17 18:02:00 +00:00
8888d1c474 Fix logit indices 2024-04-17 18:01:43 +00:00
cedb67028a Add gemma 2024-04-17 17:00:10 +00:00
91b47e3f2f JAX-based TPU worker 2024-04-16 17:37:11 +00:00
6d62e4c6aa Add torch to dependencies 2024-04-16 17:06:35 +00:00
de82e95787 Minor 2024-04-16 17:04:46 +00:00
b3b89cf755 Renew TPU executor 2024-04-16 09:42:15 +00:00
6692a30266 Minor 2024-04-16 09:41:53 +00:00
eb0a0466a9 Add JAX requirements 2024-04-16 08:05:54 +00:00
c59c1e7b2c Remove 2024-04-16 08:05:36 +00:00
d4adf92beb Merge branch 'main' into woosuk-tpu 2024-04-16 07:56:53 +00:00
363e6a950f Fix flashattn 2024-04-10 08:02:40 +00:00
696b653193 yapf 2024-04-10 08:02:21 +00:00
0d6402ddfd Fix requirements 2024-04-10 07:52:45 +00:00
60ff6b8c5c Merge branch 'main' into woosuk-tpu 2024-04-10 07:51:35 +00:00
d899009a63 [WIP] Add TPU worker 2024-04-01 08:24:23 +00:00
6894d3efef Add JAX to requirements.txt 2024-04-01 08:23:59 +00:00
38e3d33a62 Add TPU to device config 2024-04-01 08:23:44 +00:00
02e614d922 [WIP] Add Pallas backend 2024-04-01 08:23:32 +00:00
46b31ed98d Fix RoPE output shape 2024-04-01 08:22:47 +00:00
31d05f7edb yapf 2024-04-01 07:07:57 +00:00
4cdb732cef Add TPU to setup 2024-04-01 07:07:38 +00:00
27c592b97b Add get_dtype_size 2024-04-01 06:33:06 +00:00
5083aa9092 Add TPUExecutor 2024-04-01 03:24:07 +00:00
824521c987 Add TPU to DeviceConfig 2024-04-01 03:19:17 +00:00
3b8f43024f Add is_tpu 2024-04-01 03:18:36 +00:00
d148c2ef00 Add requirements 2024-04-01 03:17:43 +00:00
86f073edd6 Add reference 2024-04-01 02:02:13 +00:00
52a1e908e4 Add TPU gemma 2024-04-01 02:01:28 +00:00
20 changed files with 1445 additions and 14 deletions


@@ -0,0 +1,148 @@
import functools
import time
from typing import Tuple
import chex
import jax
import jax.numpy as jnp
_PAD_SLOT_ID = -1
@jax.jit
def write_to_kv_cache1(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
slot_mapping: jax.Array, # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
num_heads = key.shape[-2]
head_size = key.shape[-1]
key = key.reshape(-1, num_heads, head_size)
key = key.transpose((1, 0, 2))
value = value.reshape(-1, num_heads, head_size)
value = value.transpose((1, 0, 2))
k_cache = k_cache.at[:, slot_mapping.reshape(-1), :].set(key)
v_cache = v_cache.at[:, slot_mapping.reshape(-1), :].set(value)
return k_cache, v_cache
@functools.partial(jax.jit, donate_argnums=(2, 3))
def write_to_kv_cache2(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
slot_mapping: jax.Array, # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
batch_size = slot_mapping.shape[0]
def cond(val: _IteratorState):
return val.idx < batch_size
def body(val: _IteratorState):
k_cache, v_cache = _write_seq_to_kv_cache(
key[val.idx],
value[val.idx],
val.k_cache,
val.v_cache,
slot_mapping[val.idx],
)
val.k_cache = k_cache
val.v_cache = v_cache
val.idx += 1
return val
iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
iterator = jax.lax.while_loop(cond, body, iterator)
return iterator.k_cache, iterator.v_cache
@functools.partial(jax.jit, donate_argnums=(2, 3))
def _write_seq_to_kv_cache(
key: jax.Array, # [seq_len, num_heads, head_size]
value: jax.Array, # [seq_len, num_heads, head_size]
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
slot_mapping: jax.Array, # [seq_len]
) -> Tuple[jax.Array, jax.Array]:
seq_len = slot_mapping.shape[0]
num_heads, _, head_size = k_cache.shape
# Reshape to match the rank of kv_cache.
key = key.reshape(seq_len, num_heads, 1, head_size)
value = value.reshape(seq_len, num_heads, 1, head_size)
def cond(val: _IteratorState):
return jnp.logical_and(
val.idx < seq_len, slot_mapping[val.idx] != _PAD_SLOT_ID)
def body(val: _IteratorState):
slot_idx = slot_mapping[val.idx]
val.k_cache = jax.lax.dynamic_update_slice(
val.k_cache,
key[val.idx],
(0, slot_idx, 0),
)
val.v_cache = jax.lax.dynamic_update_slice(
val.v_cache,
value[val.idx],
(0, slot_idx, 0),
)
val.idx += 1
return val
iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
iterator = jax.lax.while_loop(cond, body, iterator)
return iterator.k_cache, iterator.v_cache
@chex.dataclass
class _IteratorState:
idx: jnp.int32
k_cache: jnp.ndarray # [num_heads, num_blocks, block_size, head_size]
v_cache: jnp.ndarray # [num_heads, num_blocks, block_size, head_size]
def benchmark_write_to_kv_cache(
batch_size: int,
seq_len: int,
num_kv_heads: int,
head_size: int,
num_blocks: int,
block_size: int,
version: int = 1,
):
if version == 1:
f = write_to_kv_cache1
elif version == 2:
f = write_to_kv_cache2
else:
raise ValueError(f"Invalid version: {version}")
rng_key = jax.random.PRNGKey(0)
key = jax.random.normal(rng_key, (batch_size, seq_len, num_kv_heads, head_size), dtype=jnp.bfloat16)
value = jax.random.normal(rng_key, (batch_size, seq_len, num_kv_heads, head_size), dtype=jnp.bfloat16)
k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
slot_mapping = jax.random.randint(rng_key, (batch_size, seq_len), 0, num_blocks * block_size, dtype=jnp.int32)
# For JIT compilation.
k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
k_cache.block_until_ready()
start = time.time()
for _ in range(100):
k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
k_cache.block_until_ready()
end = time.time()
print(f"Time taken: {(end - start) * 10:.2f} ms")
if __name__ == "__main__":
for num_blocks in [16, 256, 512, 1024, 2048, 8192, 16384]:
print(f"Benchmarking Write to KV Cache w/ {num_blocks} blocks")
benchmark_write_to_kv_cache(16, 256, 16, 256, num_blocks, 16, version=1)
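
The two kernels above differ in two ways: write_to_kv_cache1 performs a single scatter (.at[...].set) and returns freshly allocated caches, while write_to_kv_cache2 walks the batch with jax.lax.while_loop and donates the cache buffers via donate_argnums so XLA may update them in place. Donation is what matters for memory once the KV cache is large. Below is a minimal sketch of the idea on a toy array; it is illustrative only (the function and shapes are not from this diff, and donation only takes effect on backends that support it, such as TPU).

# Illustrative sketch (not part of the diff): buffer donation on a toy update.
import jax
import jax.numpy as jnp

def update(cache, new_row, idx):
    # Without donation, XLA keeps `cache` alive and materializes a new output buffer.
    return cache.at[idx].set(new_row)

# Donating argument 0 tells XLA the caller no longer needs the input `cache`,
# so the jitted update may reuse its memory for the output.
update_donated = jax.jit(update, donate_argnums=(0,))

cache = jnp.zeros((1024, 128), dtype=jnp.bfloat16)
cache = update_donated(cache, jnp.ones((128,), dtype=jnp.bfloat16), 3)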


@@ -0,0 +1,101 @@
import argparse
import functools
import time
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
BLOCK_SIZE = 16
MAX_NUM_BLOCKS_PER_SEQ = 512
@functools.partial(jax.jit, static_argnums=(6, 7))
def paged_attn(
q: jax.Array, # [batch, 1, num_heads, head_size]
k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
sm_scale: float,
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
context_lens: jax.Array, # [batch]
block_size: int,
pages_per_compute_block: int,
) -> jax.Array: # [batch, 1, num_heads, head_size]
q = q.squeeze(1)
q = q * sm_scale
head_size = q.shape[-1]
num_slots = k_cache.shape[-2]
k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size)
v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size)
output = paged_attention(
q,
k_cache,
v_cache,
context_lens,
block_tables,
pages_per_compute_block=pages_per_compute_block,
)
return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])
def benchmark_paged_attn(
batch_size: int,
num_heads: int,
num_kv_heads: int,
head_size: int,
context_len: int,
num_blocks: int,
block_size: int,
pages_per_compute_block: int,
):
rng_key = jax.random.PRNGKey(0)
query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16)
k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
sm_scale = head_size ** -0.5
block_tables = jax.random.randint(rng_key, (batch_size, MAX_NUM_BLOCKS_PER_SEQ), 0, num_blocks, dtype=jnp.int32)
context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32)
# For JIT compilation.
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
output.block_until_ready()
start = time.time()
for _ in range(100):
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
output.block_until_ready()
end = time.time()
print(f"Time taken: {(end - start) * 10000:.2f} us")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--num-heads", type=int, default=16)
parser.add_argument("--num-kv-heads", type=int, default=16)
parser.add_argument("--head-size", type=int, default=256)
parser.add_argument("--context-len", type=int, default=512)
parser.add_argument("--num-blocks", type=int, default=2048)
args = parser.parse_args()
print(args)
for block_size in [16, 32, 64, 128]:
for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128]:
if pages_per_compute_block > MAX_NUM_BLOCKS_PER_SEQ:
continue
if block_size * pages_per_compute_block > 1024:
continue
print(f"block_size {block_size}, pages_per_compute_block: {pages_per_compute_block}")
benchmark_paged_attn(
args.batch_size,
args.num_heads,
args.num_kv_heads,
args.head_size,
args.context_len,
args.num_blocks,
block_size,
pages_per_compute_block,
)
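
The benchmark above fills block_tables with random physical block ids; in the TPU model runner added later in this diff, each sequence's block table maps its i-th logical block to a physical block of the paged KV cache, and a token's slot is derived from it. A small sketch of that slot computation, matching the arithmetic in _prepare_prompt and _prepare_decode (the concrete numbers are made up):

# Illustrative only: how a token position maps to a physical KV-cache slot.
BLOCK_SIZE = 16
block_table = [7, 2, 9]   # logical block i of this sequence -> physical block id

def slot_for_position(pos: int) -> int:
    block_number = block_table[pos // BLOCK_SIZE]   # which physical block
    block_offset = pos % BLOCK_SIZE                 # offset inside that block
    return block_number * BLOCK_SIZE + block_offset

print(slot_for_position(0))    # 112  (physical block 7, offset 0)
print(slot_for_position(17))   # 33   (physical block 2, offset 1)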


@@ -335,7 +335,7 @@ if __name__ == "__main__":
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
choices=["cuda", "cpu", "tpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument(
"--enable-prefix-caching",

requirements-tpu.txt (new file, 6 lines)

@@ -0,0 +1,6 @@
# Common dependencies
-r requirements-common.txt
torch
jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
flax >= 0.8


@@ -188,9 +188,9 @@ class cmake_build_ext(build_ext):
def _is_cuda() -> bool:
return VLLM_TARGET_DEVICE == "cuda" \
and torch.version.cuda is not None \
and not _is_neuron()
has_cuda = torch.version.cuda is not None
return (VLLM_TARGET_DEVICE == "cuda" and has_cuda
and not (_is_neuron() or _is_tpu()))
def _is_hip() -> bool:
@@ -207,10 +207,18 @@ def _is_neuron() -> bool:
return torch_neuronx_installed
def _is_tpu() -> bool:
return True # FIXME
def _is_cpu() -> bool:
return VLLM_TARGET_DEVICE == "cpu"
def _build_custom_ops() -> bool:
return _is_cuda() or _is_hip() or _is_cpu()
def _install_punica() -> bool:
return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0")))
@@ -307,6 +315,8 @@ def get_vllm_version() -> str:
if neuron_version != MAIN_CUDA_VERSION:
neuron_version_str = neuron_version.replace(".", "")[:3]
version += f"+neuron{neuron_version_str}"
elif _is_tpu():
version += "+tpu"
elif _is_cpu():
version += "+cpu"
else:
@@ -353,6 +363,8 @@ def get_requirements() -> List[str]:
requirements = _read_requirements("requirements-rocm.txt")
elif _is_neuron():
requirements = _read_requirements("requirements-neuron.txt")
elif _is_tpu():
requirements = _read_requirements("requirements-tpu.txt")
elif _is_cpu():
requirements = _read_requirements("requirements-cpu.txt")
else:
@@ -369,7 +381,7 @@ if _is_cuda():
if _install_punica():
ext_modules.append(CMakeExtension(name="vllm._punica_C"))
if not _is_neuron():
if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))
package_data = {
@@ -408,6 +420,6 @@ setup(
extras_require={
"tensorizer": ["tensorizer==2.9.0a1"],
},
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
package_data=package_data,
)


@@ -7,7 +7,7 @@ import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip
from vllm.utils import is_cpu, is_hip, is_tpu
logger = init_logger(__name__)
@@ -19,6 +19,7 @@ class _Backend(enum.Enum):
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
PALLAS = enum.auto()
@lru_cache(maxsize=None)
@@ -49,6 +50,8 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
"""Returns which flash attention backend to use."""
if is_tpu():
return _Backend.PALLAS
if is_cpu():
return _Backend.TORCH_SDPA


@@ -13,7 +13,7 @@ from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
is_neuron)
is_neuron, is_tpu)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -620,6 +620,8 @@ class DeviceConfig:
# Automated device type detection
if is_neuron():
self.device_type = "neuron"
elif is_tpu():
self.device_type = "tpu"
elif is_cpu():
self.device_type = "cpu"
else:
@@ -633,6 +635,8 @@ class DeviceConfig:
# Some device types require processing inputs on CPU
if self.device_type in ["neuron"]:
self.device = torch.device("cpu")
elif self.device_type in ["tpu"]:
self.device = None
else:
# Set device with device type
self.device = torch.device(self.device_type)


@@ -598,6 +598,13 @@ class Scheduler:
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue:
# FIXME(woosuk): The TPU backend only supports up to 4 sequence
# groups in a single batch.
MAX_BATCH_SIZE = 1
if len(seq_groups) == MAX_BATCH_SIZE:
break
assert len(seq_groups) < MAX_BATCH_SIZE
seq_group = waiting_queue[0]
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)


@@ -221,6 +221,9 @@ class LLMEngine:
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "tpu":
from vllm.executor.tpu_executor import TPUExecutor
executor_class = TPUExecutor
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor


@@ -0,0 +1,98 @@
from typing import Dict, List, Set, Tuple
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import make_async
logger = init_logger(__name__)
class TPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
assert not self.speculative_config, (
"Speculative decoding not yet supported for TPU backend")
# Instantiate the worker and load the model to the device.
self._init_worker()
def _init_worker(self):
from vllm.worker.tpu_worker import TPUWorker
assert self.parallel_config.world_size == 1, (
"TPUExecutor currently only supports a single TPU chip.")
self.driver_worker = TPUWorker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
self.cache_config,
self.vision_language_config,
)
self.driver_worker.init_device()
self.driver_worker.load_model()
def initialize_cache(
self,
num_gpu_blocks: int,
num_cpu_blocks: int,
) -> None:
"""Initialize the KV cache by invoking the underlying worker."""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info(f"# TPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}")
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput:
output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError("LoRA is not implemented for TPU backend.")
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError("LoRA is not implemented for TPU backend.")
def list_loras(self) -> Set[int]:
raise NotImplementedError("LoRA is not implemented for TPU backend.")
def check_health(self) -> None:
# TPUExecutor will always be healthy as long as it's running.
return
class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy)
return output


@@ -0,0 +1,328 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Gemma transformer."""
from typing import List, Tuple
import jax
import jax.numpy as jnp
from flax import linen as nn
from transformers import GemmaConfig
from vllm.model_executor.models.jax.ops.flash_attn import flash_attn
from vllm.model_executor.models.jax.ops.paged_attn import paged_attn
from vllm.model_executor.models.jax.ops.write_to_cache import write_to_kv_cache
K_MASK = -2.3819763e38 # Set to a large negative number.
class Einsum(nn.Module):
"""Einsum is a convenience module for parameterized tensor multiplication."""
shape: tuple[int, ...]
@nn.compact
def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
w = self.param('w', nn.initializers.normal(), self.shape)
return jnp.einsum(eqn, x, w)
class RMSNorm(nn.Module):
"""RMSNorm layer."""
@nn.compact
def __call__(self, x):
scale = self.param('scale', nn.initializers.zeros_init(),
(x.shape[-1]))
var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))
# normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
# a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
# a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
scale = jnp.expand_dims(scale, axis=range(len(x.shape) - 1))
normed_inputs = normed_inputs * (1 + scale)
return normed_inputs
def apply_rope(
inputs: jax.Array, # [B, L]
positions: jax.Array, # [B, L]
head_dim: int,
max_wavelength: int = 10_000,
) -> jax.Array:
"""Applies RoPE."""
fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
timescale = max_wavelength**fraction
sinusoid_inp = (positions[..., jnp.newaxis] /
timescale[jnp.newaxis, jnp.newaxis, :])
sinusoid_inp = sinusoid_inp[..., jnp.newaxis, :]
sin = jnp.sin(sinusoid_inp)
cos = jnp.cos(sinusoid_inp)
first_half, second_half = jnp.split(inputs, 2, axis=-1)
first_part = first_half * cos - second_half * sin
second_part = second_half * cos + first_half * sin
out = jnp.concatenate([first_part, second_part], axis=-1)
return out.astype(inputs.dtype)
class Embedder(nn.Module):
"""Embedder module."""
vocab_size: int
embed_dim: int
def setup(self):
self.input_embedding_table = self.param(
'input_embedding',
nn.initializers.normal(),
(self.vocab_size, self.embed_dim),
)
def encode(self, x: jax.Array) -> jax.Array:
x = self.input_embedding_table[(x, )]
x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
return x
def decode(self, x: jax.Array) -> jax.Array:
return jnp.dot(x, self.input_embedding_table.T)
class Attention(nn.Module):
"""Attention module."""
num_heads: int
num_kv_heads: int
features: int
head_dim: int
@property
def use_qkv_einsum(self):
return self.num_kv_heads == self.num_heads
def setup(self):
self.attn_vec_einsum = Einsum(shape=(self.num_heads, self.head_dim,
self.features), )
if self.use_qkv_einsum:
self.qkv_einsum = Einsum(shape=(3, self.num_heads, self.features,
self.head_dim), )
else:
self.q_einsum = Einsum(shape=(self.num_heads, self.features,
self.head_dim), )
self.kv_einsum = Einsum(shape=(2, self.num_kv_heads, self.features,
self.head_dim), )
self.sm_scale = self.head_dim**-0.5
def __call__(
self,
x: jax.Array,
segment_pos: jax.Array,
slot_mapping: jax.Array,
block_tables: jax.Array | None,
context_lens: jax.Array | None,
cache: Tuple[jax.Array, jax.Array],
) -> tuple[jax.Array, jax.Array]:
if self.use_qkv_einsum:
query_proj, key_proj, value_proj = self.qkv_einsum(
'BTD,SNDH->SBTNH', x)
else:
query_proj = self.q_einsum('BTD,NDH->BTNH', x)
key_proj, value_proj = self.kv_einsum('BSD,CKDH->CBSKH', x)
query_proj = apply_rope(
query_proj,
segment_pos,
head_dim=self.head_dim,
)
key_proj = apply_rope(
key_proj,
segment_pos,
head_dim=self.head_dim,
)
# Write the incoming keys and values to KV cache.
k_cache, v_cache = cache
# FIXME(woosuk): Uncomment this.
# k_cache, v_cache = write_to_kv_cache(
# key_proj, value_proj, k_cache, v_cache, slot_mapping)
if block_tables is None:
# Prompt attention.
if not self.use_qkv_einsum:
# MQA/GQA.
value_proj = jnp.repeat(value_proj, self.num_heads, axis=-2)
key_proj = jnp.repeat(key_proj, self.num_heads, axis=-2)
if True:
# FlashAttention.
output = flash_attn(
query_proj,
key_proj,
value_proj,
self.sm_scale,
)
else:
# Naive attention with masking.
seq_len = query_proj.shape[1]
attn_mask = jnp.tril(
jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
query_scaled = query_proj * self.sm_scale
logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj)
masked_logits = jnp.where((jnp.expand_dims(attn_mask, -2)),
logits, K_MASK)
probs = jax.nn.softmax(masked_logits,
axis=-1).astype(key_proj.dtype)
output = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
else:
# Decode attention.
output = paged_attn(
query_proj,
k_cache,
v_cache,
self.sm_scale,
block_tables,
context_lens,
)
attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', output)
return (k_cache, v_cache), attn_output
class FeedForward(nn.Module):
"""Feed forward module."""
features: int
hidden_dim: int
@nn.compact
def __call__(self, x):
w_gating = self.param(
'gating_einsum',
nn.initializers.zeros_init(),
((2, self.features, self.hidden_dim)),
)
ff_gate = jnp.dot(x, w_gating[0])
gate_value = nn.gelu(ff_gate)
ff1 = jnp.dot(x, w_gating[1])
activations = gate_value * ff1
w_linear = self.param(
'linear',
nn.initializers.zeros_init(),
(self.hidden_dim, self.features),
)
outputs = jnp.dot(activations, w_linear)
return outputs
class Block(nn.Module):
"""Transformer block."""
num_heads: int
num_kv_heads: int
embed_dim: int
head_dim: int
hidden_dim: int
def setup(self):
self.pre_attention_norm = RMSNorm()
self.attn = Attention(
num_heads=self.num_heads,
features=self.embed_dim,
head_dim=self.head_dim,
num_kv_heads=self.num_kv_heads,
)
self.pre_ffw_norm = RMSNorm()
self.mlp = FeedForward(features=self.embed_dim,
hidden_dim=self.hidden_dim)
def __call__(
self,
x: jax.Array,
segment_pos: jax.Array,
slot_mapping: jax.Array,
block_tables: jax.Array | None,
context_lens: jax.Array | None,
cache: Tuple[jax.Array, jax.Array],
) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]:
inputs_normalized = self.pre_attention_norm(x)
cache, attn_output = self.attn(
inputs_normalized,
segment_pos,
slot_mapping,
block_tables,
context_lens,
cache,
)
attn_output += x
residual = attn_output
attn_output = self.pre_ffw_norm(attn_output)
outputs = self.mlp(attn_output)
outputs = residual + outputs
return outputs, cache
class Transformer(nn.Module):
"""Gemma transformer."""
config: GemmaConfig
def setup(self):
self.embedder = Embedder(
vocab_size=256128, # != self.config.vocab_size
embed_dim=self.config.hidden_size,
)
self.blocks = [
Block(
name=f'layer_{i}',
num_heads=self.config.num_attention_heads,
num_kv_heads=self.config.num_key_value_heads,
embed_dim=self.config.hidden_size,
head_dim=self.config.head_dim,
hidden_dim=self.config.intermediate_size,
) for i in range(self.config.num_hidden_layers)
]
self.final_norm = RMSNorm()
def __call__(
self,
token_ids: jax.Array,
positions: jax.Array,
slot_mapping: jax.Array,
block_tables: jax.Array | None,
context_lens: jax.Array | None,
kv_caches: List[Tuple[jax.Array, jax.Array]],
logits_indices: jax.Array,
) -> tuple[jax.Array, List[Tuple[jax.Array, jax.Array]]]:
x = self.embedder.encode(token_ids)
new_caches = []
for i, block in enumerate(self.blocks):
x, new_cache = block(
x,
positions,
slot_mapping,
block_tables,
context_lens,
kv_caches[i],
)
new_caches.append(new_cache)
x = self.final_norm(x)
x = x.reshape(-1, x.shape[-1])
hidden_states = x[logits_indices]
logits = self.embedder.decode(hidden_states)
return logits, new_caches
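
The final reshape and gather above (x.reshape(-1, x.shape[-1])[logits_indices]) expects the caller to pass flat indices of the last real token of each padded sequence; the TPU model runner further down computes them as arange(batch_size) * seq_len + input_lens - 1. A small numeric sketch of that indexing (the batch and lengths are made up):

# Illustrative only: selecting the last real token of each padded sequence
# after the hidden states are flattened to [batch_size * seq_len, hidden].
import jax.numpy as jnp

batch_size, seq_len = 2, 8
input_lens = jnp.array([5, 3])            # real lengths before padding
logits_indices = jnp.arange(batch_size) * seq_len + input_lens - 1
print(logits_indices)                     # [ 4 10] -> rows 4 and 10 of the flat states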


@@ -0,0 +1,29 @@
import jax
from jax.experimental.pallas.ops.tpu.flash_attention import BlockSizes, flash_attention
_DEFAULT_BLOCK_SIZES = {
"block_q": 512,
"block_k_major": 512,
"block_k": 512,
"block_b": 2,
}
def flash_attn(
q: jax.Array, # [batch, seq_len, num_heads, head_size]
k: jax.Array, # [batch, seq_len, num_heads, head_size]
v: jax.Array, # [batch, seq_len, num_heads, head_size]
sm_scale: float,
) -> jax.Array: # [batch, seq_len, num_heads, head_size]
return flash_attention(
q.transpose(0, 2, 1, 3),
k.transpose(0, 2, 1, 3),
v.transpose(0, 2, 1, 3),
causal=True,
sm_scale=sm_scale,
block_sizes=BlockSizes(
min(_DEFAULT_BLOCK_SIZES["block_q"], q.shape[1]),
min(_DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[1]),
min(_DEFAULT_BLOCK_SIZES["block_k"], k.shape[1]),
min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0])),
).transpose(0, 2, 1, 3)
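
A hedged usage sketch of the wrapper above. It assumes a TPU runtime (the Pallas kernel does not run elsewhere) and uses the import path that gemma.py uses; the input shapes follow the comments in the signature, and the wrapper transposes to the kernel's [batch, num_heads, seq_len, head_size] layout and back.

# Illustrative only; requires a TPU backend for the Pallas flash-attention kernel.
import jax.numpy as jnp
from vllm.model_executor.models.jax.ops.flash_attn import flash_attn

batch, seq_len, num_heads, head_size = 1, 128, 16, 256
q = jnp.zeros((batch, seq_len, num_heads, head_size), dtype=jnp.bfloat16)
k = jnp.zeros_like(q)
v = jnp.zeros_like(q)
out = flash_attn(q, k, v, head_size ** -0.5)   # same [batch, seq_len, heads, head_size] layout as q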


@@ -0,0 +1,32 @@
import jax
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
def paged_attn(
q: jax.Array, # [batch, 1, num_heads, head_size]
k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
sm_scale: float,
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
context_lens: jax.Array, # [batch]
block_size: int = 16,
) -> jax.Array: # [batch, 1, num_heads, head_size]
q = q * sm_scale
q = q.squeeze(1)
head_size = q.shape[-1]
num_slots = k_cache.shape[-2]
k_cache = k_cache.reshape(-1, num_slots // block_size, block_size,
head_size)
v_cache = v_cache.reshape(-1, num_slots // block_size, block_size,
head_size)
output = paged_attention(
q,
k_cache,
v_cache,
context_lens,
block_tables,
pages_per_compute_block=512 // 16, # TODO(woosuk): Tune this value.
)
return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])


@@ -0,0 +1,102 @@
from typing import Tuple
import chex
import jax
import jax.numpy as jnp
_PAD_SLOT_ID = -1
def write_to_kv_cache(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
slot_mapping: jax.Array, # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
num_heads = key.shape[-2]
head_size = key.shape[-1]
key = key.reshape(-1, num_heads, head_size)
key = key.transpose((1, 0, 2))
value = value.reshape(-1, num_heads, head_size)
value = value.transpose((1, 0, 2))
k_cache = k_cache.at[:, slot_mapping.reshape(-1), :].set(key)
v_cache = v_cache.at[:, slot_mapping.reshape(-1), :].set(value)
return k_cache, v_cache
def _write_to_kv_cache(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
slot_mapping: jax.Array, # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
batch_size = slot_mapping.shape[0]
def cond(val: _IteratorState):
return val.idx < batch_size
def body(val: _IteratorState):
k_cache, v_cache = _write_seq_to_kv_cache(
key[val.idx],
value[val.idx],
val.k_cache,
val.v_cache,
slot_mapping[val.idx],
)
val.k_cache = k_cache
val.v_cache = v_cache
val.idx += 1
return val
iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
iterator = jax.lax.while_loop(cond, body, iterator)
return iterator.k_cache, iterator.v_cache
def _write_seq_to_kv_cache(
key: jax.Array, # [seq_len, num_heads, head_size]
value: jax.Array, # [seq_len, num_heads, head_size]
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
slot_mapping: jax.Array, # [seq_len]
) -> Tuple[jax.Array, jax.Array]:
seq_len = slot_mapping.shape[0]
num_heads, _, head_size = k_cache.shape
# Reshape to match the rank of kv_cache.
key = key.reshape(seq_len, num_heads, 1, head_size)
value = value.reshape(seq_len, num_heads, 1, head_size)
def cond(val: _IteratorState):
return jnp.logical_and(
val.idx < seq_len, slot_mapping[val.idx] != _PAD_SLOT_ID)
def body(val: _IteratorState):
slot_idx = slot_mapping[val.idx]
val.k_cache = jax.lax.dynamic_update_slice(
val.k_cache,
key[val.idx],
(0, slot_idx, 0),
)
val.v_cache = jax.lax.dynamic_update_slice(
val.v_cache,
value[val.idx],
(0, slot_idx, 0),
)
val.idx += 1
return val
iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
iterator = jax.lax.while_loop(cond, body, iterator)
return iterator.k_cache, iterator.v_cache
@chex.dataclass
class _IteratorState:
idx: jnp.int32
k_cache: jnp.ndarray # [num_heads, num_blocks, block_size, head_size]
v_cache: jnp.ndarray # [num_heads, num_blocks, block_size, head_size]
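
write_to_kv_cache above relies on a single scatter: after transposing keys and values to [num_heads, batch_size * seq_len, head_size], the .at[:, slot_mapping, :].set(...) call writes each token's per-head vectors into the flat slot dimension of the cache. A toy-sized sketch of that indexing (shapes made up):

# Illustrative only: the .at[...].set scatter used by write_to_kv_cache.
import jax.numpy as jnp

k_cache = jnp.zeros((2, 8, 4))        # [num_heads=2, num_slots=8, head_size=4]
key = jnp.ones((2, 3, 4))             # 3 tokens, already [num_heads, tokens, head_size]
slots = jnp.array([5, 0, 6])          # slot_mapping for the 3 tokens
k_cache = k_cache.at[:, slots, :].set(key)
print(k_cache[:, 5, :])               # the first token's key, all ones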


@@ -138,6 +138,15 @@ def is_neuron() -> bool:
return transformers_neuronx is not None
@lru_cache(maxsize=None)
def is_tpu() -> bool:
try:
import libtpu
except ImportError:
libtpu = None
return libtpu is not None
@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
@@ -490,6 +499,11 @@ def maybe_expand_dim(tensor: torch.Tensor,
return tensor
def get_dtype_size(dtype: torch.dtype) -> int:
"""Get the size of the data type in bytes."""
return torch.tensor([], dtype=dtype).element_size()
def merge_dicts(dict1: Dict[Any, List[Any]],
dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
"""Merge 2 dicts that have key -> List of items.


@@ -6,7 +6,8 @@ import torch
from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available,
get_dtype_size)
logger = init_logger(__name__)
@@ -97,9 +98,5 @@ class CacheEngine:
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
dtype_size = _get_dtype_size(dtype)
dtype_size = get_dtype_size(dtype)
return dtype_size * total
def _get_dtype_size(dtype: torch.dtype) -> int:
return torch.tensor([], dtype=dtype).element_size()


@@ -0,0 +1,392 @@
import time
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import jax
import jax.numpy as jnp
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import pad_to_max_length
logger = init_logger(__name__)
_PAD_SLOT_ID = -1
_MAX_NUM_SEQS = 256
_MAX_NUM_BLOCKS_PER_SEQ = 8192 // 16
class TPUModelRunner:
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
vision_language_config: Optional[VisionLanguageConfig],
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
if model_config is not None and model_config.get_sliding_window():
logger.warning("Sliding window is not supported on TPU. "
"The model will run without sliding window.")
self.model = None
self.block_size = None
self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7, ))
# FIXME(woosuk)
self.block_tables = np.zeros((_MAX_NUM_SEQS, _MAX_NUM_BLOCKS_PER_SEQ),
dtype=np.int32)
def load_model(self) -> None:
from huggingface_hub import snapshot_download
from vllm.model_executor.models.jax.gemma import Transformer
assert self.model_config.hf_config.model_type == "gemma"
self.model = Transformer(self.model_config.hf_config)
model_name = "google/gemma-7b-flax"
model_dir = snapshot_download(model_name)
params = load_and_format_params(model_dir + "/7b/")["transformer"]
self.params = {"params": params}
self.cpu_device = jax.devices("cpu")[0]
def warmup_model(
self,
tpu_caches: List[Tuple[jax.Array, jax.Array]],
) -> List[Tuple[jax.Array, jax.Array]]:
# Prefill
logger.info("Compiling the model with different input shapes...")
start = time.time()
for batch_size in [1]:
for seq_len in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]:
if batch_size * seq_len > 8192:
continue
token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
position_ids = jnp.zeros((batch_size, seq_len),
dtype=jnp.int32)
slot_mapping = jnp.zeros((batch_size, seq_len),
dtype=jnp.int32)
block_tables = None
context_lens = None
prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
# Dummy run.
_, tpu_caches = self.compiled_fn(self.params, token_ids,
position_ids, slot_mapping,
block_tables, context_lens,
prompt_lens, tpu_caches)
end = time.time()
logger.info(f"Compilation for prefill done in {(end - start):.2f} s.")
# Decode
start = time.time()
for batch_size in [1, 2, 4, 8] + [16 * i for i in range(1, 17)]:
seq_len = 1
token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
slot_mapping = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
block_tables = jnp.zeros((batch_size, _MAX_NUM_BLOCKS_PER_SEQ),
dtype=jnp.int32)
context_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
_, tpu_caches = self.compiled_fn(self.params, token_ids,
position_ids, slot_mapping,
block_tables, context_lens,
prompt_lens, tpu_caches)
end = time.time()
logger.info(f"Compilation for decode done in {(end - start):.2f} s.")
return tpu_caches
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
):
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
prompt_lens: List[int] = []
slot_mapping: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.append(prompt_tokens)
input_positions.append(list(range(prompt_len)))
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
slot_mapping.append([])
for i in range(prompt_len):
block_number = block_table[i //
self.block_size] # type: ignore
block_offset = i % self.block_size # type: ignore
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)
max_prompt_len = max(prompt_lens)
assert max_prompt_len > 0
max_prompt_len = _get_padded_prefill_len(max_prompt_len)
input_tokens = _make_array_with_pad(input_tokens,
max_prompt_len,
pad=0,
dtype=jnp.int32)
input_positions = _make_array_with_pad(input_positions,
max_prompt_len,
pad=0,
dtype=jnp.int32)
slot_mapping = _make_array_with_pad(slot_mapping,
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=jnp.int32)
prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
return (input_tokens, input_positions, slot_mapping, None, None,
prompt_lens)
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
):
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []
num_seq_groups = len(seq_group_metadata_list)
batch_size = _get_padded_batch_size(num_seq_groups)
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
context_lens.append(seq_len)
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
self.block_tables[i, :len(block_table)] = block_table
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
num_paddings = batch_size - num_seq_groups
input_tokens = input_tokens + [[0]] * num_paddings
input_positions = input_positions + [[0]] * num_paddings
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
context_lens = context_lens + [0] * num_paddings
input_tokens = jnp.asarray(input_tokens, dtype=jnp.int32)
input_positions = jnp.asarray(input_positions, dtype=jnp.int32)
slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
context_lens = jnp.asarray(context_lens, dtype=jnp.int32)
block_tables = jnp.asarray(self.block_tables[:batch_size],
dtype=jnp.int32)
input_lens = jnp.asarray([1] * batch_size, dtype=jnp.int32)
return (input_tokens, input_positions, slot_mapping, block_tables,
context_lens, input_lens)
def prepare_input_arrays(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
):
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
return self._prepare_prompt(seq_group_metadata_list)
else:
return self._prepare_decode(seq_group_metadata_list)
def _execute_step(
self,
params: Dict[str, Any],
token_ids: jax.Array,
position_ids: jax.Array,
slot_mapping: jax.Array,
block_tables: Optional[jax.Array],
context_lens: Optional[jax.Array],
input_lens: jax.Array,
kv_caches: List[jax.Array],
) -> tuple[jax.Array, List[jax.Array]]:
batch_size, seq_len = token_ids.shape
base_indicies = jnp.arange(batch_size, dtype=jnp.int32) * seq_len
logits_indices = base_indicies + input_lens - 1
logits, new_kv_caches = self.model.apply(
params,
token_ids,
position_ids,
slot_mapping,
block_tables,
context_lens,
kv_caches,
logits_indices,
)
# TODO(woosuk): Support sampling with temperature and top_p.
next_token_ids = jnp.argmax(logits, axis=-1)
return next_token_ids, new_kv_caches
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[jax.Array, jax.Array]],
) -> Tuple[Optional[SamplerOutput], List[Tuple[jax.Array, jax.Array]]]:
from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
start = time.time()
inputs = self.prepare_input_arrays(seq_group_metadata_list)
end = time.time()
# print(inputs[0].shape)
# print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")
start = time.time()
next_token_ids, new_kv_caches = self.compiled_fn(
self.params, *inputs, kv_caches)
next_token_ids.block_until_ready()
end = time.time()
# print(f"compiled_fn: {(end - start) * 1000:.2f} ms")
start = time.time()
next_token_ids = jax.device_put(next_token_ids, self.cpu_device)
end = time.time()
# print(f"jax.device_put: {(end - start) * 1000:.2f} ms")
next_token_ids = next_token_ids.tolist()
i = 0
sampler_outputs = []
for seq_group_metadata in seq_group_metadata_list:
seq_outputs = []
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
next_token_id = next_token_ids[i]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: Logprob(0.0)}))
i += 1
sampler_outputs.append(SequenceGroupOutput(seq_outputs, None))
return SamplerOutput(sampler_outputs), new_kv_caches
def _make_array_with_pad(
x: List[List[int]],
max_len: int,
pad: int,
dtype: jnp.dtype,
) -> jax.Array:
padded_x = [pad_to_max_length(x_i, max_len, pad) for x_i in x]
return jnp.asarray(padded_x, dtype)
def _get_padded_prefill_len(x: int) -> int:
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
# length to be a multiple of 16. We pad the prompt length to the nearest
# multiple of 16. This is also good for performance.
if x <= 16:
return 16
return 1 << (x - 1).bit_length()
def _get_padded_batch_size(batch_size: int) -> int:
if batch_size <= 2:
return batch_size
elif batch_size <= 4:
return 4
elif batch_size <= 8:
return 8
else:
return ((batch_size + 15) // 16) * 16
import functools
from typing import Any, Mapping
import orbax.checkpoint
Params = Mapping[str, Any]
def load_and_format_params(path: str) -> Params:
"""Loads parameters and formats them for compatibility."""
params = load_params(path)
param_state = jax.tree_util.tree_map(jnp.array, params)
remapped_params = param_remapper(param_state)
nested_params = nest_params(remapped_params)
return nested_params
@functools.cache
def load_params(path: str) -> Params:
"""Loads parameters from a checkpoint path."""
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
params = checkpointer.restore(path)
return params
def param_remapper(orig_params: Params) -> Params:
"""Remaps params to new module layout.
This is needed here because the model definition does not have a separate
`mlp` module.
Args:
orig_params: original dict of parameters in Gemma format.
Returns:
dict of params with different names.
"""
new_params = {}
for k, v in orig_params.items():
if 'mlp/' in k:
layer_name, param = k.rsplit('/', maxsplit=1)
if layer_name not in new_params:
new_params[layer_name] = {}
if 'w' in v:
new_params[layer_name][param] = v['w']
else:
new_params[k] = v
return new_params
def nest_params(params: Params) -> Params:
"""Nests params as a dict of dicts rather than a flat dict."""
nested_params = {}
for path, param in params.items():
*path, leaf = path.split('/')
subdict = nested_params
for key in path:
subdict = subdict.setdefault(key, {})
subdict[leaf] = param
return nested_params
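
The shape bucketing in this file is what keeps the number of XLA compilations finite: warmup_model compiles one program per (batch size, padded length) bucket, _get_padded_prefill_len rounds prompt lengths up to 16 or the next power of two, and _get_padded_batch_size rounds decode batches up to 1, 2, 4, 8 or the next multiple of 16. A few worked values as a sketch (assuming the module imports cleanly in your environment; the helpers are private to this file):

# Illustrative only: the padding buckets that bound recompilation.
from vllm.worker.tpu_model_runner import (_get_padded_batch_size,
                                          _get_padded_prefill_len)

print([_get_padded_prefill_len(x) for x in (1, 16, 17, 100, 1000)])
# [16, 16, 32, 128, 1024] -- 16, or the next power of two
print([_get_padded_batch_size(n) for n in (1, 3, 7, 9, 33)])
# [1, 4, 8, 16, 48]       -- 1/2/4/8, then multiples of 16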

vllm/worker/tpu_worker.py (new file, 155 lines)

@@ -0,0 +1,155 @@
import os
from typing import Dict, List, Optional, Tuple
import jax
import jax.numpy as jnp
import torch
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.tpu_model_runner import TPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.utils import get_dtype_size, STR_DTYPE_TO_TORCH_DTYPE
logger = init_logger(__name__)
class TPUWorker(LoraNotSupportedWorkerBase):
"""A worker class that executes (a partition of) the model on a CPU socket.
Each worker is associated with a single CPU socket. The worker is
responsible for maintaining the KV cache and executing the model on the
CPU. In case of distributed inference, each worker is assigned a partition
of the model.
"""
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
vision_language_config: Optional[VisionLanguageConfig],
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.vision_language_config = vision_language_config
assert self.device_config.device_type == "tpu"
if self.cache_config.cache_dtype == "auto":
self.cache_dtype = self.model_config.dtype
else:
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype]
self.model_runner = TPUModelRunner(
model_config,
parallel_config,
scheduler_config,
device_config,
vision_language_config=vision_language_config)
self.tpu_cache = None
def init_device(self) -> None:
# Set random seed.
# TODO: Set random seed for JAX
set_random_seed(self.model_config.seed)
# Use persistent cache to avoid recompilation.
jax.config.update("jax_compilation_cache_dir",
os.path.expanduser("~/.vllm/jax_cache"))
# DELETE
# from jax_smi import initialise_tracking
# initialise_tracking()
def load_model(self):
self.model_runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
num_tpu_blocks = 2000
return num_tpu_blocks, 0
def initialize_cache(
self,
num_gpu_blocks: int,
num_cpu_blocks: int,
) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self.block_size = self.cache_config.block_size
dtype = _torch_dtype_to_jax(self.cache_dtype)
num_layers = self.model_config.get_num_layers(self.parallel_config)
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
head_size = self.model_config.get_head_size()
self.tpu_cache = []
for _ in range(num_layers):
key_cache = jnp.zeros(
(num_kv_heads, num_gpu_blocks * self.block_size, head_size),
dtype=dtype)
value_cache = jnp.zeros_like(key_cache)
self.tpu_cache.append((key_cache, value_cache))
self.model_runner.block_size = self.block_size
self._warmup_model()
def _warmup_model(self) -> None:
# NOTE(woosuk): Because of buffer donation, the reference to the cache
# should be updated after the warmup.
self.tpu_cache = self.model_runner.warmup_model(self.tpu_cache)
def get_cache_block_size_bytes(self) -> int:
head_size = self.model_config.get_head_size()
num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
num_layers = self.model_config.get_num_layers(self.parallel_config)
key_cache_block = self.cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
dtype_size = get_dtype_size(self.cache_dtype)
return dtype_size * total
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> Optional[SamplerOutput]:
assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
# Currently, TPUWorker does not support swapping.
# TODO(woosuk): Support block copying.
assert len(blocks_to_swap_in) == 0
assert len(blocks_to_swap_out) == 0
assert len(blocks_to_copy) == 0
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return {}
output, kv_caches = self.model_runner.execute_model(
seq_group_metadata_list, self.tpu_cache)
self.tpu_cache = kv_caches
return output
def _torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
mapping = {
torch.float32: jnp.float32,
torch.float16: jnp.float16,
torch.bfloat16: jnp.bfloat16,
}
return mapping[dtype]
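
For a sense of scale of the hardcoded numbers above: with the Gemma-7B-like dimensions used elsewhere in this change (16 KV heads, head size 256, bfloat16 cache) and assuming 28 layers, get_cache_block_size_bytes comes to roughly 7 MiB per block, so the 2000 blocks returned by determine_num_available_blocks pin on the order of 14 GiB of HBM for the KV cache. The arithmetic, as a sketch:

# Illustrative arithmetic only (assumes 28 layers, 16 KV heads, head size 256,
# block size 16, bfloat16); mirrors get_cache_block_size_bytes above.
block_size, num_kv_heads, head_size = 16, 16, 256
num_layers, dtype_size = 28, 2                                 # bfloat16 = 2 bytes

key_cache_block = block_size * num_kv_heads * head_size        # 65,536 elements
block_bytes = num_layers * 2 * key_cache_block * dtype_size    # key + value
print(block_bytes)                    # 7,340,032 bytes, ~7 MiB per block
print(block_bytes * 2000 / 2**30)     # ~13.7 GiB for 2000 blocks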