Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: v0.11.0rc5...jax-tpu (75 commits)
Commits in this range (SHA1 only; the author and date columns are empty in this view):

c00ddd6834, 881b884046, 98a3df0f8d, 3f6288cc89, 408ff4950c, 278e8a1adc,
07be6ed3eb, f6637dba18, 707a5f6473, 57690a9c09, b15db234ba, d1591f0f1f,
85d4488458, 8d072dbfbd, d830766c0c, 5ae2f81c2b, 4ea41d01a9, d16a348477,
aa092834bb, d2c6a32c0c, 21f35c2289, 2aa9831dd3, 028f528aad, fa5bacd5b0,
b62170e4e3, 98eda57899, 81b8b813f1, e2c7dedb3a, 5323969fcf, f42b4c27d8,
620e7646d3, d5fb1c20c1, 092e3d6d6d, 84284302d8, 743695f586, 62b870fa07,
7e3a230c38, 186c88c497, ef762cb110, 756c4e78d3, 4880de35d2, 0fb07c08d0,
e4377dd698, 5cb213c85e, 25bbc21ef6, b25fcc06c2, 6661c030c4, 8888d1c474,
cedb67028a, 91b47e3f2f, 6d62e4c6aa, de82e95787, b3b89cf755, 6692a30266,
eb0a0466a9, c59c1e7b2c, d4adf92beb, 363e6a950f, 696b653193, 0d6402ddfd,
60ff6b8c5c, d899009a63, 6894d3efef, 38e3d33a62, 02e614d922, 46b31ed98d,
31d05f7edb, 4cdb732cef, 27c592b97b, 5083aa9092, 824521c987, 3b8f43024f,
d148c2ef00, 86f073edd6, 52a1e908e4
benchmarks/bench_cache_write.py (new file, 148 lines)
@@ -0,0 +1,148 @@
import functools
import time
from typing import Tuple

import chex
import jax
import jax.numpy as jnp

_PAD_SLOT_ID = -1


@jax.jit
def write_to_kv_cache1(
    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
    num_heads = key.shape[-2]
    head_size = key.shape[-1]

    key = key.reshape(-1, num_heads, head_size)
    key = key.transpose((1, 0, 2))
    value = value.reshape(-1, num_heads, head_size)
    value = value.transpose((1, 0, 2))

    k_cache = k_cache.at[:, slot_mapping.reshape(-1), :].set(key)
    v_cache = v_cache.at[:, slot_mapping.reshape(-1), :].set(value)
    return k_cache, v_cache


@functools.partial(jax.jit, donate_argnums=(2, 3))
def write_to_kv_cache2(
    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
    batch_size = slot_mapping.shape[0]

    def cond(val: _IteratorState):
        return val.idx < batch_size

    def body(val: _IteratorState):
        k_cache, v_cache = _write_seq_to_kv_cache(
            key[val.idx],
            value[val.idx],
            val.k_cache,
            val.v_cache,
            slot_mapping[val.idx],
        )
        val.k_cache = k_cache
        val.v_cache = v_cache
        val.idx += 1
        return val

    iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
    iterator = jax.lax.while_loop(cond, body, iterator)
    return iterator.k_cache, iterator.v_cache


@functools.partial(jax.jit, donate_argnums=(2, 3))
def _write_seq_to_kv_cache(
    key: jax.Array,  # [seq_len, num_heads, head_size]
    value: jax.Array,  # [seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [seq_len]
) -> Tuple[jax.Array, jax.Array]:
    seq_len = slot_mapping.shape[0]
    num_heads, _, head_size = k_cache.shape
    # Reshape to match the rank of kv_cache.
    key = key.reshape(seq_len, num_heads, 1, head_size)
    value = value.reshape(seq_len, num_heads, 1, head_size)

    def cond(val: _IteratorState):
        return jnp.logical_and(
            val.idx < seq_len, slot_mapping[val.idx] != _PAD_SLOT_ID)

    def body(val: _IteratorState):
        slot_idx = slot_mapping[val.idx]
        val.k_cache = jax.lax.dynamic_update_slice(
            val.k_cache,
            key[val.idx],
            (0, slot_idx, 0),
        )
        val.v_cache = jax.lax.dynamic_update_slice(
            val.v_cache,
            value[val.idx],
            (0, slot_idx, 0),
        )
        val.idx += 1
        return val

    iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
    iterator = jax.lax.while_loop(cond, body, iterator)
    return iterator.k_cache, iterator.v_cache


@chex.dataclass
class _IteratorState:

    idx: jnp.int32
    k_cache: jnp.ndarray  # [num_heads, num_blocks, block_size, head_size]
    v_cache: jnp.ndarray  # [num_heads, num_blocks, block_size, head_size]


def benchmark_write_to_kv_cache(
    batch_size: int,
    seq_len: int,
    num_kv_heads: int,
    head_size: int,
    num_blocks: int,
    block_size: int,
    version: int = 1,
):
    if version == 1:
        f = write_to_kv_cache1
    elif version == 2:
        f = write_to_kv_cache2
    else:
        raise ValueError(f"Invalid version: {version}")

    rng_key = jax.random.PRNGKey(0)
    key = jax.random.normal(rng_key, (batch_size, seq_len, num_kv_heads, head_size), dtype=jnp.bfloat16)
    value = jax.random.normal(rng_key, (batch_size, seq_len, num_kv_heads, head_size), dtype=jnp.bfloat16)
    k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
    v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
    slot_mapping = jax.random.randint(rng_key, (batch_size, seq_len), 0, num_blocks * block_size, dtype=jnp.int32)

    # For JIT compilation.
    k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
    k_cache.block_until_ready()

    start = time.time()
    for _ in range(100):
        k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
    k_cache.block_until_ready()
    end = time.time()
    print(f"Time taken: {(end - start) * 10:.2f} ms")


if __name__ == "__main__":
    for num_blocks in [16, 256, 512, 1024, 2048, 8192, 16384]:
        print(f"Benchmarking Write to KV Cache w/ {num_blocks} blocks")
        benchmark_write_to_kv_cache(16, 256, 16, 256, num_blocks, 16, version=1)
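A minimal usage sketch (ours, not part of the diff): both cache-write variants above can be timed on the same configuration for a side-by-side comparison. It assumes the script is run from the repository root so that benchmarks/ is importable; the helper name compare_versions is hypothetical.

# Hypothetical comparison driver for the two KV-cache write strategies above.
from benchmarks.bench_cache_write import benchmark_write_to_kv_cache

def compare_versions(num_blocks: int = 2048) -> None:
    for version in (1, 2):
        print(f"version={version}")
        # Same configuration as the __main__ block above:
        # batch_size=16, seq_len=256, num_kv_heads=16, head_size=256, block_size=16.
        benchmark_write_to_kv_cache(16, 256, 16, 256, num_blocks, 16, version=version)

if __name__ == "__main__":
    compare_versions()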
benchmarks/bench_paged_attn.py (new file, 101 lines)
@@ -0,0 +1,101 @@
import argparse
import functools
import time

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention

BLOCK_SIZE = 16
MAX_NUM_BLOCKS_PER_SEQ = 512


@functools.partial(jax.jit, static_argnums=(6, 7))
def paged_attn(
    q: jax.Array,  # [batch, 1, num_heads, head_size]
    k_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
    sm_scale: float,
    block_tables: jax.Array,  # [batch, max_num_blocks_per_batch]
    context_lens: jax.Array,  # [batch]
    block_size: int,
    pages_per_compute_block: int,
) -> jax.Array:  # [batch, 1, num_heads, head_size]
    q = q.squeeze(1)
    q = q * sm_scale

    head_size = q.shape[-1]
    num_slots = k_cache.shape[-2]
    k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size)
    v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size)

    output = paged_attention(
        q,
        k_cache,
        v_cache,
        context_lens,
        block_tables,
        pages_per_compute_block=pages_per_compute_block,
    )
    return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])


def benchmark_paged_attn(
    batch_size: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    context_len: int,
    num_blocks: int,
    block_size: int,
    pages_per_compute_block: int,
):
    rng_key = jax.random.PRNGKey(0)
    query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16)
    k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
    v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
    sm_scale = head_size ** -0.5
    block_tables = jax.random.randint(rng_key, (batch_size, MAX_NUM_BLOCKS_PER_SEQ), 0, num_blocks, dtype=jnp.int32)
    context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32)

    # For JIT compilation.
    output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
    output.block_until_ready()

    start = time.time()
    for _ in range(100):
        output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
    output.block_until_ready()
    end = time.time()

    print(f"Time taken: {(end - start) * 10000:.2f} us")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-heads", type=int, default=16)
    parser.add_argument("--num-kv-heads", type=int, default=16)
    parser.add_argument("--head-size", type=int, default=256)
    parser.add_argument("--context-len", type=int, default=512)
    parser.add_argument("--num-blocks", type=int, default=2048)
    args = parser.parse_args()
    print(args)

    for block_size in [16, 32, 64, 128]:
        for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128]:
            if pages_per_compute_block > MAX_NUM_BLOCKS_PER_SEQ:
                continue
            if block_size * pages_per_compute_block > 1024:
                continue
            print(f"block_size {block_size}, pages_per_compute_block: {pages_per_compute_block}")
            benchmark_paged_attn(
                args.batch_size,
                args.num_heads,
                args.num_kv_heads,
                args.head_size,
                args.context_len,
                args.num_blocks,
                block_size,
                pages_per_compute_block,
            )
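A single configuration can also be timed directly, bypassing the sweep in __main__. This is our sketch, assuming the file is importable as benchmarks.bench_paged_attn from the repository root:

from benchmarks.bench_paged_attn import benchmark_paged_attn

# batch_size=8, 16 query heads, 16 KV heads, head_size=256, context_len=512,
# 2048 KV-cache blocks, block_size=16, 32 pages per compute block
# (16 * 32 = 512 <= 1024, so the sweep above would also accept this point).
benchmark_paged_attn(8, 16, 16, 256, 512, 2048, 16, 32)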
Changes to an existing benchmark script (file name not shown in this view):

@@ -335,7 +335,7 @@ if __name__ == "__main__":
         "--device",
         type=str,
         default="cuda",
-        choices=["cuda", "cpu"],
+        choices=["cuda", "cpu", "tpu"],
         help='device type for vLLM execution, supporting CUDA and CPU.')
     parser.add_argument(
         "--enable-prefix-caching",
requirements-tpu.txt (new file, 6 lines)
@@ -0,0 +1,6 @@
# Common dependencies
-r requirements-common.txt

torch
jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
flax >= 0.8
setup.py (22 lines changed)

@@ -188,9 +188,9 @@ class cmake_build_ext(build_ext):
 
 
 def _is_cuda() -> bool:
-    return VLLM_TARGET_DEVICE == "cuda" \
-        and torch.version.cuda is not None \
-        and not _is_neuron()
+    has_cuda = torch.version.cuda is not None
+    return (VLLM_TARGET_DEVICE == "cuda" and has_cuda
+            and not (_is_neuron() or _is_tpu()))
 
 
 def _is_hip() -> bool:
@@ -207,10 +207,18 @@ def _is_neuron() -> bool:
     return torch_neuronx_installed
 
 
+def _is_tpu() -> bool:
+    return True  # FIXME
+
+
 def _is_cpu() -> bool:
     return VLLM_TARGET_DEVICE == "cpu"
 
 
+def _build_custom_ops() -> bool:
+    return _is_cuda() or _is_hip() or _is_cpu()
+
+
 def _install_punica() -> bool:
     return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0")))
 
@@ -307,6 +315,8 @@ def get_vllm_version() -> str:
         if neuron_version != MAIN_CUDA_VERSION:
             neuron_version_str = neuron_version.replace(".", "")[:3]
             version += f"+neuron{neuron_version_str}"
+    elif _is_tpu():
+        version += "+tpu"
     elif _is_cpu():
         version += "+cpu"
     else:
@@ -353,6 +363,8 @@ def get_requirements() -> List[str]:
         requirements = _read_requirements("requirements-rocm.txt")
     elif _is_neuron():
         requirements = _read_requirements("requirements-neuron.txt")
+    elif _is_tpu():
+        requirements = _read_requirements("requirements-tpu.txt")
     elif _is_cpu():
         requirements = _read_requirements("requirements-cpu.txt")
     else:
@@ -369,7 +381,7 @@ if _is_cuda():
     if _install_punica():
         ext_modules.append(CMakeExtension(name="vllm._punica_C"))
 
-if not _is_neuron():
+if _build_custom_ops():
     ext_modules.append(CMakeExtension(name="vllm._C"))
 
 package_data = {
@@ -408,6 +420,6 @@ setup(
     extras_require={
         "tensorizer": ["tensorizer==2.9.0a1"],
     },
-    cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
+    cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
     package_data=package_data,
 )
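The `_is_tpu()` placeholder above returns True unconditionally, so every build currently resolves to the TPU target. A hedged sketch of what the check could look like instead, keying off the same VLLM_TARGET_DEVICE variable that the neighbouring `_is_cuda()` and `_is_cpu()` helpers use (our suggestion, not part of the diff):

def _is_tpu() -> bool:
    # Assumption: the TPU target is selected the same way as the CPU target,
    # via the VLLM_TARGET_DEVICE environment/setup variable.
    return VLLM_TARGET_DEVICE == "tpu"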
vllm/attention/selector.py

@@ -7,7 +7,7 @@ import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
-from vllm.utils import is_cpu, is_hip
+from vllm.utils import is_cpu, is_hip, is_tpu
 
 logger = init_logger(__name__)
 
@@ -19,6 +19,7 @@ class _Backend(enum.Enum):
     XFORMERS = enum.auto()
     ROCM_FLASH = enum.auto()
     TORCH_SDPA = enum.auto()
+    PALLAS = enum.auto()
 
 
 @lru_cache(maxsize=None)
@@ -49,6 +50,8 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
 
 def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
     """Returns which flash attention backend to use."""
+    if is_tpu():
+        return _Backend.PALLAS
     if is_cpu():
         return _Backend.TORCH_SDPA
 
vllm/config.py

@@ -13,7 +13,7 @@ from transformers import PretrainedConfig
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
-                        is_neuron)
+                        is_neuron, is_tpu)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -620,6 +620,8 @@ class DeviceConfig:
             # Automated device type detection
             if is_neuron():
                 self.device_type = "neuron"
+            elif is_tpu():
+                self.device_type = "tpu"
             elif is_cpu():
                 self.device_type = "cpu"
             else:
@@ -633,6 +635,8 @@ class DeviceConfig:
         # Some device types require processing inputs on CPU
         if self.device_type in ["neuron"]:
             self.device = torch.device("cpu")
+        elif self.device_type in ["tpu"]:
+            self.device = None
         else:
             # Set device with device type
             self.device = torch.device(self.device_type)
vllm/core/scheduler.py

@@ -598,6 +598,13 @@ class Scheduler:
 
         leftover_waiting_sequences: Deque[SequenceGroup] = deque()
         while self._passed_delay(time.time()) and waiting_queue:
+            # FIXME(woosuk): The TPU backend only supports up to 4 sequence
+            # groups in a single batch.
+            MAX_BATCH_SIZE = 1
+            if len(seq_groups) == MAX_BATCH_SIZE:
+                break
+            assert len(seq_groups) < MAX_BATCH_SIZE
+
             seq_group = waiting_queue[0]
 
             waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
vllm/engine/llm_engine.py

@@ -221,6 +221,9 @@ class LLMEngine:
         if engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutor
             executor_class = NeuronExecutor
+        elif engine_config.device_config.device_type == "tpu":
+            from vllm.executor.tpu_executor import TPUExecutor
+            executor_class = TPUExecutor
         elif engine_config.device_config.device_type == "cpu":
             from vllm.executor.cpu_executor import CPUExecutor
             executor_class = CPUExecutor
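With the executor wired in here, a TPU run could be requested through the public entrypoint roughly as follows. This is our sketch, not part of the diff; it assumes device="tpu" is accepted by the engine arguments the same way "neuron" and "cpu" already are, and note that the TPU model runner below ignores sampling parameters and always takes the argmax token.

from vllm import LLM, SamplingParams

# Hypothetical invocation; the branch targets the Gemma Flax checkpoint.
llm = LLM(model="google/gemma-7b", device="tpu")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)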
vllm/executor/tpu_executor.py (new file, 98 lines)
@@ -0,0 +1,98 @@
from typing import Dict, List, Set, Tuple

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import make_async

logger = init_logger(__name__)


class TPUExecutor(ExecutorBase):

    def _init_executor(self) -> None:
        assert not self.speculative_config, (
            "Speculative decoding not yet supported for TPU backend")
        # Instantiate the worker and load the model to the device.
        self._init_worker()

    def _init_worker(self):
        from vllm.worker.tpu_worker import TPUWorker

        assert self.parallel_config.world_size == 1, (
            "TPUExecutor currently only supports a single TPU chip.")
        self.driver_worker = TPUWorker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            self.device_config,
            self.cache_config,
            self.vision_language_config,
        )
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Initialize the KV cache by invoking the underlying worker."""
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info(f"# TPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
        """
        return self.driver_worker.determine_num_available_blocks()

    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        output = self.driver_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError("LoRA is not implemented for TPU backend.")

    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError("LoRA is not implemented for TPU backend.")

    def list_loras(self) -> Set[int]:
        raise NotImplementedError("LoRA is not implemented for TPU backend.")

    def check_health(self) -> None:
        # TPUExecutor will always be healthy as long as it's running.
        return


class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase):

    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        output = await make_async(self.driver_worker.execute_model)(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy)
        return output
vllm/model_executor/models/jax/__init__.py (new empty file)
vllm/model_executor/models/jax/gemma.py (new file, 328 lines)
@@ -0,0 +1,328 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Gemma transformer."""
from typing import List, Tuple

import jax
import jax.numpy as jnp
from flax import linen as nn
from transformers import GemmaConfig

from vllm.model_executor.models.jax.ops.flash_attn import flash_attn
from vllm.model_executor.models.jax.ops.paged_attn import paged_attn
from vllm.model_executor.models.jax.ops.write_to_cache import write_to_kv_cache

K_MASK = -2.3819763e38  # Set to a large negative number.


class Einsum(nn.Module):
    """Einsum is a convenience module for parameterized tensor multiplication."""
    shape: tuple[int, ...]

    @nn.compact
    def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
        w = self.param('w', nn.initializers.normal(), self.shape)
        return jnp.einsum(eqn, x, w)


class RMSNorm(nn.Module):
    """RMSNorm layer."""

    @nn.compact
    def __call__(self, x):
        scale = self.param('scale', nn.initializers.zeros_init(),
                           (x.shape[-1]))
        var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))
        # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
        # a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
        # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
        scale = jnp.expand_dims(scale, axis=range(len(x.shape) - 1))
        normed_inputs = normed_inputs * (1 + scale)
        return normed_inputs


def apply_rope(
    inputs: jax.Array,  # [B, L]
    positions: jax.Array,  # [B, L]
    head_dim: int,
    max_wavelength: int = 10_000,
) -> jax.Array:
    """Applies RoPE."""
    fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
    timescale = max_wavelength**fraction

    sinusoid_inp = (positions[..., jnp.newaxis] /
                    timescale[jnp.newaxis, jnp.newaxis, :])
    sinusoid_inp = sinusoid_inp[..., jnp.newaxis, :]
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)

    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    out = jnp.concatenate([first_part, second_part], axis=-1)
    return out.astype(inputs.dtype)


class Embedder(nn.Module):
    """Embedder module."""

    vocab_size: int
    embed_dim: int

    def setup(self):
        self.input_embedding_table = self.param(
            'input_embedding',
            nn.initializers.normal(),
            (self.vocab_size, self.embed_dim),
        )

    def encode(self, x: jax.Array) -> jax.Array:
        x = self.input_embedding_table[(x, )]
        x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
        return x

    def decode(self, x: jax.Array) -> jax.Array:
        return jnp.dot(x, self.input_embedding_table.T)


class Attention(nn.Module):
    """Attention module."""

    num_heads: int
    num_kv_heads: int
    features: int
    head_dim: int

    @property
    def use_qkv_einsum(self):
        return self.num_kv_heads == self.num_heads

    def setup(self):
        self.attn_vec_einsum = Einsum(shape=(self.num_heads, self.head_dim,
                                             self.features), )

        if self.use_qkv_einsum:
            self.qkv_einsum = Einsum(shape=(3, self.num_heads, self.features,
                                            self.head_dim), )
        else:
            self.q_einsum = Einsum(shape=(self.num_heads, self.features,
                                          self.head_dim), )
            self.kv_einsum = Einsum(shape=(2, self.num_kv_heads, self.features,
                                           self.head_dim), )
        self.sm_scale = self.head_dim**-0.5

    def __call__(
        self,
        x: jax.Array,
        segment_pos: jax.Array,
        slot_mapping: jax.Array,
        block_tables: jax.Array | None,
        context_lens: jax.Array | None,
        cache: Tuple[jax.Array, jax.Array],
    ) -> tuple[jax.Array, jax.Array]:
        if self.use_qkv_einsum:
            query_proj, key_proj, value_proj = self.qkv_einsum(
                'BTD,SNDH->SBTNH', x)
        else:
            query_proj = self.q_einsum('BTD,NDH->BTNH', x)
            key_proj, value_proj = self.kv_einsum('BSD,CKDH->CBSKH', x)

        query_proj = apply_rope(
            query_proj,
            segment_pos,
            head_dim=self.head_dim,
        )
        key_proj = apply_rope(
            key_proj,
            segment_pos,
            head_dim=self.head_dim,
        )

        # Write the incoming keys and values to KV cache.
        k_cache, v_cache = cache
        # FIXME(woosuk): Uncomment this.
        # k_cache, v_cache = write_to_kv_cache(
        #     key_proj, value_proj, k_cache, v_cache, slot_mapping)

        if block_tables is None:
            # Prompt attention.
            if not self.use_qkv_einsum:
                # MQA/GQA.
                value_proj = jnp.repeat(value_proj, self.num_heads, axis=-2)
                key_proj = jnp.repeat(key_proj, self.num_heads, axis=-2)

            if True:
                # FlashAttention.
                output = flash_attn(
                    query_proj,
                    key_proj,
                    value_proj,
                    self.sm_scale,
                )
            else:
                # Naive attention with masking.
                seq_len = query_proj.shape[1]
                attn_mask = jnp.tril(
                    jnp.ones((seq_len, seq_len), dtype=jnp.bool_))

                query_scaled = query_proj * self.sm_scale
                logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj)
                masked_logits = jnp.where((jnp.expand_dims(attn_mask, -2)),
                                          logits, K_MASK)
                probs = jax.nn.softmax(masked_logits,
                                       axis=-1).astype(key_proj.dtype)
                output = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
        else:
            # Decode attention.
            output = paged_attn(
                query_proj,
                k_cache,
                v_cache,
                self.sm_scale,
                block_tables,
                context_lens,
            )

        attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', output)
        return (k_cache, v_cache), attn_output


class FeedForward(nn.Module):
    """Feed forward module."""

    features: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        w_gating = self.param(
            'gating_einsum',
            nn.initializers.zeros_init(),
            ((2, self.features, self.hidden_dim)),
        )
        ff_gate = jnp.dot(x, w_gating[0])
        gate_value = nn.gelu(ff_gate)

        ff1 = jnp.dot(x, w_gating[1])
        activations = gate_value * ff1

        w_linear = self.param(
            'linear',
            nn.initializers.zeros_init(),
            (self.hidden_dim, self.features),
        )
        outputs = jnp.dot(activations, w_linear)
        return outputs


class Block(nn.Module):
    """Transformer block."""

    num_heads: int
    num_kv_heads: int
    embed_dim: int
    head_dim: int
    hidden_dim: int

    def setup(self):
        self.pre_attention_norm = RMSNorm()
        self.attn = Attention(
            num_heads=self.num_heads,
            features=self.embed_dim,
            head_dim=self.head_dim,
            num_kv_heads=self.num_kv_heads,
        )
        self.pre_ffw_norm = RMSNorm()
        self.mlp = FeedForward(features=self.embed_dim,
                               hidden_dim=self.hidden_dim)

    def __call__(
        self,
        x: jax.Array,
        segment_pos: jax.Array,
        slot_mapping: jax.Array,
        block_tables: jax.Array | None,
        context_lens: jax.Array | None,
        cache: Tuple[jax.Array, jax.Array],
    ) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]:
        inputs_normalized = self.pre_attention_norm(x)
        cache, attn_output = self.attn(
            inputs_normalized,
            segment_pos,
            slot_mapping,
            block_tables,
            context_lens,
            cache,
        )
        attn_output += x
        residual = attn_output
        attn_output = self.pre_ffw_norm(attn_output)
        outputs = self.mlp(attn_output)
        outputs = residual + outputs
        return outputs, cache


class Transformer(nn.Module):
    """Gemma transformer."""

    config: GemmaConfig

    def setup(self):
        self.embedder = Embedder(
            vocab_size=256128,  # != self.config.vocab_size
            embed_dim=self.config.hidden_size,
        )
        self.blocks = [
            Block(
                name=f'layer_{i}',
                num_heads=self.config.num_attention_heads,
                num_kv_heads=self.config.num_key_value_heads,
                embed_dim=self.config.hidden_size,
                head_dim=self.config.head_dim,
                hidden_dim=self.config.intermediate_size,
            ) for i in range(self.config.num_hidden_layers)
        ]
        self.final_norm = RMSNorm()

    def __call__(
        self,
        token_ids: jax.Array,
        positions: jax.Array,
        slot_mapping: jax.Array,
        block_tables: jax.Array | None,
        context_lens: jax.Array | None,
        kv_caches: List[Tuple[jax.Array, jax.Array]],
        logits_indices: jax.Array,
    ) -> tuple[jax.Array, List[Tuple[jax.Array, jax.Array]]]:
        x = self.embedder.encode(token_ids)
        new_caches = []
        for i, block in enumerate(self.blocks):
            x, new_cache = block(
                x,
                positions,
                slot_mapping,
                block_tables,
                context_lens,
                kv_caches[i],
            )
            new_caches.append(new_cache)

        x = self.final_norm(x)
        x = x.reshape(-1, x.shape[-1])
        hidden_states = x[logits_indices]
        logits = self.embedder.decode(hidden_states)
        return logits, new_caches
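As a quick shape check (ours, not part of the diff), apply_rope above uses only plain jax.numpy, so it can be exercised on CPU:

import jax.numpy as jnp
from vllm.model_executor.models.jax.gemma import apply_rope  # path as added by this diff

B, L, N, H = 1, 4, 2, 8  # batch, seq_len, num_heads, head_dim
x = jnp.ones((B, L, N, H), dtype=jnp.float32)
positions = jnp.arange(L, dtype=jnp.int32)[None, :]  # [B, L]
out = apply_rope(x, positions, head_dim=H)
print(out.shape)  # (1, 4, 2, 8); position 0 is left unrotated (sin=0, cos=1).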
vllm/model_executor/models/jax/ops/__init__.py (new empty file)
vllm/model_executor/models/jax/ops/flash_attn.py (new file, 29 lines)
@@ -0,0 +1,29 @@
import jax
from jax.experimental.pallas.ops.tpu.flash_attention import BlockSizes, flash_attention

_DEFAULT_BLOCK_SIZES = {
    "block_q": 512,
    "block_k_major": 512,
    "block_k": 512,
    "block_b": 2,
}


def flash_attn(
    q: jax.Array,  # [batch, seq_len, num_heads, head_size]
    k: jax.Array,  # [batch, seq_len, num_heads, head_size]
    v: jax.Array,  # [batch, seq_len, num_heads, head_size]
    sm_scale: float,
) -> jax.Array:  # [batch, seq_len, num_heads, head_size]
    return flash_attention(
        q.transpose(0, 2, 1, 3),
        k.transpose(0, 2, 1, 3),
        v.transpose(0, 2, 1, 3),
        causal=True,
        sm_scale=sm_scale,
        block_sizes=BlockSizes(
            min(_DEFAULT_BLOCK_SIZES["block_q"], q.shape[1]),
            min(_DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[1]),
            min(_DEFAULT_BLOCK_SIZES["block_k"], k.shape[1]),
            min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0])),
    ).transpose(0, 2, 1, 3)
vllm/model_executor/models/jax/ops/paged_attn.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import jax
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention


def paged_attn(
    q: jax.Array,  # [batch, 1, num_heads, head_size]
    k_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
    sm_scale: float,
    block_tables: jax.Array,  # [batch, max_num_blocks_per_batch]
    context_lens: jax.Array,  # [batch]
    block_size: int = 16,
) -> jax.Array:  # [batch, 1, num_heads, head_size]
    q = q * sm_scale
    q = q.squeeze(1)

    head_size = q.shape[-1]
    num_slots = k_cache.shape[-2]
    k_cache = k_cache.reshape(-1, num_slots // block_size, block_size,
                              head_size)
    v_cache = v_cache.reshape(-1, num_slots // block_size, block_size,
                              head_size)

    output = paged_attention(
        q,
        k_cache,
        v_cache,
        context_lens,
        block_tables,
        pages_per_compute_block=512 // 16,  # TODO(woosuk): Tune this value.
    )
    return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])
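This wrapper hard-codes 512 // 16 = 32 pages per compute block and defaults to block_size=16. A hedged call sketch (ours) with shapes that satisfy those constraints; the underlying Pallas kernel needs a TPU device:

import jax
import jax.numpy as jnp
from vllm.model_executor.models.jax.ops.paged_attn import paged_attn

batch, num_heads, head_size, num_blocks, block_size = 4, 16, 256, 128, 16
rng = jax.random.PRNGKey(0)
q = jax.random.normal(rng, (batch, 1, num_heads, head_size), dtype=jnp.bfloat16)
k_cache = jax.random.normal(rng, (num_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
v_cache = jax.random.normal(rng, (num_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
# 32 blocks per sequence = one compute block; context of 512 tokens fills them exactly.
block_tables = jax.random.randint(rng, (batch, 32), 0, num_blocks, dtype=jnp.int32)
context_lens = jnp.full((batch,), 512, dtype=jnp.int32)
out = paged_attn(q, k_cache, v_cache, head_size**-0.5, block_tables, context_lens)
print(out.shape)  # (4, 1, 16, 256)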
vllm/model_executor/models/jax/ops/write_to_cache.py (new file, 102 lines)
@@ -0,0 +1,102 @@
from typing import Tuple

import chex
import jax
import jax.numpy as jnp

_PAD_SLOT_ID = -1


def write_to_kv_cache(
    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
    num_heads = key.shape[-2]
    head_size = key.shape[-1]

    key = key.reshape(-1, num_heads, head_size)
    key = key.transpose((1, 0, 2))
    value = value.reshape(-1, num_heads, head_size)
    value = value.transpose((1, 0, 2))

    k_cache = k_cache.at[:, slot_mapping.reshape(-1), :].set(key)
    v_cache = v_cache.at[:, slot_mapping.reshape(-1), :].set(value)
    return k_cache, v_cache


def _write_to_kv_cache(
    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
    batch_size = slot_mapping.shape[0]

    def cond(val: _IteratorState):
        return val.idx < batch_size

    def body(val: _IteratorState):
        k_cache, v_cache = _write_seq_to_kv_cache(
            key[val.idx],
            value[val.idx],
            val.k_cache,
            val.v_cache,
            slot_mapping[val.idx],
        )
        val.k_cache = k_cache
        val.v_cache = v_cache
        val.idx += 1
        return val

    iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
    iterator = jax.lax.while_loop(cond, body, iterator)
    return iterator.k_cache, iterator.v_cache


def _write_seq_to_kv_cache(
    key: jax.Array,  # [seq_len, num_heads, head_size]
    value: jax.Array,  # [seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [seq_len]
) -> Tuple[jax.Array, jax.Array]:
    seq_len = slot_mapping.shape[0]
    num_heads, _, head_size = k_cache.shape
    # Reshape to match the rank of kv_cache.
    key = key.reshape(seq_len, num_heads, 1, head_size)
    value = value.reshape(seq_len, num_heads, 1, head_size)

    def cond(val: _IteratorState):
        return jnp.logical_and(
            val.idx < seq_len, slot_mapping[val.idx] != _PAD_SLOT_ID)

    def body(val: _IteratorState):
        slot_idx = slot_mapping[val.idx]
        val.k_cache = jax.lax.dynamic_update_slice(
            val.k_cache,
            key[val.idx],
            (0, slot_idx, 0),
        )
        val.v_cache = jax.lax.dynamic_update_slice(
            val.v_cache,
            value[val.idx],
            (0, slot_idx, 0),
        )
        val.idx += 1
        return val

    iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
    iterator = jax.lax.while_loop(cond, body, iterator)
    return iterator.k_cache, iterator.v_cache


@chex.dataclass
class _IteratorState:

    idx: jnp.int32
    k_cache: jnp.ndarray  # [num_heads, num_blocks, block_size, head_size]
    v_cache: jnp.ndarray  # [num_heads, num_blocks, block_size, head_size]
vllm/utils.py

@@ -138,6 +138,15 @@ def is_neuron() -> bool:
     return transformers_neuronx is not None
 
 
+@lru_cache(maxsize=None)
+def is_tpu() -> bool:
+    try:
+        import libtpu
+    except ImportError:
+        libtpu = None
+    return libtpu is not None
+
+
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
@@ -490,6 +499,11 @@ def maybe_expand_dim(tensor: torch.Tensor,
     return tensor
 
 
+def get_dtype_size(dtype: torch.dtype) -> int:
+    """Get the size of the data type in bytes."""
+    return torch.tensor([], dtype=dtype).element_size()
+
+
 def merge_dicts(dict1: Dict[Any, List[Any]],
                 dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
     """Merge 2 dicts that have key -> List of items.
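A quick illustration (ours) of the new helper: get_dtype_size returns the element size of a torch dtype in bytes, which the TPU worker below uses to size its KV cache.

import torch
from vllm.utils import get_dtype_size

assert get_dtype_size(torch.bfloat16) == 2
assert get_dtype_size(torch.float32) == 4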
vllm/worker/cache_engine.py

@@ -6,7 +6,8 @@ import torch
 from vllm.attention import get_attn_backend
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available,
+                        get_dtype_size)
 
 logger = init_logger(__name__)
 
@@ -97,9 +98,5 @@ class CacheEngine:
             dtype = model_config.dtype
         else:
             dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
-        dtype_size = _get_dtype_size(dtype)
+        dtype_size = get_dtype_size(dtype)
         return dtype_size * total
-
-
-def _get_dtype_size(dtype: torch.dtype) -> int:
-    return torch.tensor([], dtype=dtype).element_size()
vllm/worker/tpu_model_runner.py (new file, 392 lines)
@@ -0,0 +1,392 @@
import time
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import jax
import jax.numpy as jnp

from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import pad_to_max_length

logger = init_logger(__name__)

_PAD_SLOT_ID = -1
_MAX_NUM_SEQS = 256
_MAX_NUM_BLOCKS_PER_SEQ = 8192 // 16


class TPUModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        vision_language_config: Optional[VisionLanguageConfig],
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.vision_language_config = vision_language_config

        if model_config is not None and model_config.get_sliding_window():
            logger.warning("Sliding window is not supported on TPU. "
                           "The model will run without sliding window.")
        self.model = None
        self.block_size = None
        self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7, ))
        # FIXME(woosuk)
        self.block_tables = np.zeros((_MAX_NUM_SEQS, _MAX_NUM_BLOCKS_PER_SEQ),
                                     dtype=np.int32)

    def load_model(self) -> None:
        from huggingface_hub import snapshot_download

        from vllm.model_executor.models.jax.gemma import Transformer

        assert self.model_config.hf_config.model_type == "gemma"
        self.model = Transformer(self.model_config.hf_config)

        model_name = "google/gemma-7b-flax"
        model_dir = snapshot_download(model_name)
        params = load_and_format_params(model_dir + "/7b/")["transformer"]
        self.params = {"params": params}
        self.cpu_device = jax.devices("cpu")[0]

    def warmup_model(
        self,
        tpu_caches: List[Tuple[jax.Array, jax.Array]],
    ) -> List[Tuple[jax.Array, jax.Array]]:
        # Prefill
        logger.info("Compiling the model with different input shapes...")
        start = time.time()
        for batch_size in [1]:
            for seq_len in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]:
                if batch_size * seq_len > 8192:
                    continue
                token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
                position_ids = jnp.zeros((batch_size, seq_len),
                                         dtype=jnp.int32)
                slot_mapping = jnp.zeros((batch_size, seq_len),
                                         dtype=jnp.int32)
                block_tables = None
                context_lens = None
                prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)

                # Dummy run.
                _, tpu_caches = self.compiled_fn(self.params, token_ids,
                                                 position_ids, slot_mapping,
                                                 block_tables, context_lens,
                                                 prompt_lens, tpu_caches)
        end = time.time()
        logger.info(f"Compilation for prefill done in {(end - start):.2f} s.")

        # Decode
        start = time.time()
        for batch_size in [1, 2, 4, 8] + [16 * i for i in range(1, 17)]:
            seq_len = 1
            token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
            position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
            slot_mapping = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
            block_tables = jnp.zeros((batch_size, _MAX_NUM_BLOCKS_PER_SEQ),
                                     dtype=jnp.int32)
            context_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
            prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)

            _, tpu_caches = self.compiled_fn(self.params, token_ids,
                                             position_ids, slot_mapping,
                                             block_tables, context_lens,
                                             prompt_lens, tpu_caches)
        end = time.time()
        logger.info(f"Compilation for decode done in {(end - start):.2f} s.")
        return tpu_caches

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ):
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        prompt_lens: List[int] = []
        slot_mapping: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(prompt_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            slot_mapping.append([])
            for i in range(prompt_len):
                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

        max_prompt_len = max(prompt_lens)
        assert max_prompt_len > 0
        max_prompt_len = _get_padded_prefill_len(max_prompt_len)

        input_tokens = _make_array_with_pad(input_tokens,
                                            max_prompt_len,
                                            pad=0,
                                            dtype=jnp.int32)
        input_positions = _make_array_with_pad(input_positions,
                                               max_prompt_len,
                                               pad=0,
                                               dtype=jnp.int32)
        slot_mapping = _make_array_with_pad(slot_mapping,
                                            max_prompt_len,
                                            pad=_PAD_SLOT_ID,
                                            dtype=jnp.int32)
        prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
        return (input_tokens, input_positions, slot_mapping, None, None,
                prompt_lens)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ):
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []
        num_seq_groups = len(seq_group_metadata_list)
        batch_size = _get_padded_batch_size(num_seq_groups)

        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                self.block_tables[i, :len(block_table)] = block_table

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

        num_paddings = batch_size - num_seq_groups
        input_tokens = input_tokens + [[0]] * num_paddings
        input_positions = input_positions + [[0]] * num_paddings
        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
        context_lens = context_lens + [0] * num_paddings

        input_tokens = jnp.asarray(input_tokens, dtype=jnp.int32)
        input_positions = jnp.asarray(input_positions, dtype=jnp.int32)
        slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
        context_lens = jnp.asarray(context_lens, dtype=jnp.int32)

        block_tables = jnp.asarray(self.block_tables[:batch_size],
                                   dtype=jnp.int32)
        input_lens = jnp.asarray([1] * batch_size, dtype=jnp.int32)
        return (input_tokens, input_positions, slot_mapping, block_tables,
                context_lens, input_lens)

    def prepare_input_arrays(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ):
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            return self._prepare_prompt(seq_group_metadata_list)
        else:
            return self._prepare_decode(seq_group_metadata_list)

    def _execute_step(
        self,
        params: Dict[str, Any],
        token_ids: jax.Array,
        position_ids: jax.Array,
        slot_mapping: jax.Array,
        block_tables: Optional[jax.Array],
        context_lens: Optional[jax.Array],
        input_lens: jax.Array,
        kv_caches: List[jax.Array],
    ) -> tuple[jax.Array, List[jax.Array]]:
        batch_size, seq_len = token_ids.shape
        base_indicies = jnp.arange(batch_size, dtype=jnp.int32) * seq_len
        logits_indices = base_indicies + input_lens - 1

        logits, new_kv_caches = self.model.apply(
            params,
            token_ids,
            position_ids,
            slot_mapping,
            block_tables,
            context_lens,
            kv_caches,
            logits_indices,
        )
        # TODO(woosuk): Support sampling with temperature and top_p.
        next_token_ids = jnp.argmax(logits, axis=-1)
        return next_token_ids, new_kv_caches

    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[Tuple[jax.Array, jax.Array]],
    ) -> Tuple[Optional[SamplerOutput], List[Tuple[jax.Array, jax.Array]]]:
        from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob

        start = time.time()
        inputs = self.prepare_input_arrays(seq_group_metadata_list)
        end = time.time()
        # print(inputs[0].shape)
        # print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")

        start = time.time()
        next_token_ids, new_kv_caches = self.compiled_fn(
            self.params, *inputs, kv_caches)
        next_token_ids.block_until_ready()
        end = time.time()
        # print(f"compiled_fn: {(end - start) * 1000:.2f} ms")

        start = time.time()
        next_token_ids = jax.device_put(next_token_ids, self.cpu_device)
        end = time.time()
        # print(f"jax.device_put: {(end - start) * 1000:.2f} ms")

        next_token_ids = next_token_ids.tolist()
        i = 0

        sampler_outputs = []
        for seq_group_metadata in seq_group_metadata_list:
            seq_outputs = []

            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                next_token_id = next_token_ids[i]
                seq_outputs.append(
                    SequenceOutput(seq_id, next_token_id,
                                   {next_token_id: Logprob(0.0)}))
                i += 1

            sampler_outputs.append(SequenceGroupOutput(seq_outputs, None))
        return SamplerOutput(sampler_outputs), new_kv_caches


def _make_array_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: jnp.dtype,
) -> jax.Array:
    padded_x = [pad_to_max_length(x_i, max_len, pad) for x_i in x]
    return jnp.asarray(padded_x, dtype)


def _get_padded_prefill_len(x: int) -> int:
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. We pad the prompt length to the nearest
    # multiple of 16. This is also good for performance.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()


def _get_padded_batch_size(batch_size: int) -> int:
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    elif batch_size <= 8:
        return 8
    else:
        return ((batch_size + 15) // 16) * 16


import functools
from typing import Any, Mapping

import orbax.checkpoint

Params = Mapping[str, Any]


def load_and_format_params(path: str) -> Params:
    """Loads parameters and formats them for compatibility."""
    params = load_params(path)
    param_state = jax.tree_util.tree_map(jnp.array, params)
    remapped_params = param_remapper(param_state)
    nested_params = nest_params(remapped_params)
    return nested_params


@functools.cache
def load_params(path: str) -> Params:
    """Loads parameters from a checkpoint path."""
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    params = checkpointer.restore(path)
    return params


def param_remapper(orig_params: Params) -> Params:
    """Remaps params to new module layout.

    This is needed here because the model definition does not have a separate
    `mlp` module.

    Args:
        orig_params: original dict of parameters in Gemma format.

    Returns:
        dict of params with different names.
    """
    new_params = {}
    for k, v in orig_params.items():
        if 'mlp/' in k:
            layer_name, param = k.rsplit('/', maxsplit=1)
            if layer_name not in new_params:
                new_params[layer_name] = {}
            if 'w' in v:
                new_params[layer_name][param] = v['w']
        else:
            new_params[k] = v
    return new_params


def nest_params(params: Params) -> Params:
    """Nests params as a dict of dicts rather than a flat dict."""
    nested_params = {}
    for path, param in params.items():
        *path, leaf = path.split('/')
        subdict = nested_params
        for key in path:
            subdict = subdict.setdefault(key, {})
        subdict[leaf] = param
    return nested_params
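The padding helpers at the bottom of this file determine the static shapes compiled during warmup. A few values they produce, derived directly from the code above (our check, assuming vllm is installed with these TPU files):

from vllm.worker.tpu_model_runner import (_get_padded_batch_size,
                                          _get_padded_prefill_len)

assert _get_padded_prefill_len(1) == 16    # minimum prompt padding
assert _get_padded_prefill_len(17) == 32   # next power of two
assert _get_padded_batch_size(3) == 4      # small batches round up to 4 or 8
assert _get_padded_batch_size(33) == 48    # larger batches round up to a multiple of 16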
vllm/worker/tpu_worker.py (new file, 155 lines)
@@ -0,0 +1,155 @@
import os
from typing import Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import torch

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.tpu_model_runner import TPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.utils import get_dtype_size, STR_DTYPE_TO_TORCH_DTYPE

logger = init_logger(__name__)


class TPUWorker(LoraNotSupportedWorkerBase):
    """A worker class that executes (a partition of) the model on a TPU device.

    Each worker is associated with a single TPU device. The worker is
    responsible for maintaining the KV cache and executing the model on the
    TPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        vision_language_config: Optional[VisionLanguageConfig],
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.vision_language_config = vision_language_config
        assert self.device_config.device_type == "tpu"

        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        self.model_runner = TPUModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            vision_language_config=vision_language_config)
        self.tpu_cache = None

    def init_device(self) -> None:
        # Set random seed.
        # TODO: Set random seed for JAX
        set_random_seed(self.model_config.seed)

        # Use persistent cache to avoid recompilation.
        jax.config.update("jax_compilation_cache_dir",
                          os.path.expanduser("~/.vllm/jax_cache"))

        # DELETE
        # from jax_smi import initialise_tracking
        # initialise_tracking()

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        num_tpu_blocks = 2000
        return num_tpu_blocks, 0

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self.block_size = self.cache_config.block_size

        dtype = _torch_dtype_to_jax(self.cache_dtype)
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()

        self.tpu_cache = []
        for _ in range(num_layers):
            key_cache = jnp.zeros(
                (num_kv_heads, num_gpu_blocks * self.block_size, head_size),
                dtype=dtype)
            value_cache = jnp.zeros_like(key_cache)
            self.tpu_cache.append((key_cache, value_cache))
        self.model_runner.block_size = self.block_size
        self._warmup_model()

    def _warmup_model(self) -> None:
        # NOTE(woosuk): Because of buffer donation, the reference to the cache
        # should be updated after the warmup.
        self.tpu_cache = self.model_runner.warmup_model(self.tpu_cache)

    def get_cache_block_size_bytes(self) -> int:
        head_size = self.model_config.get_head_size()
        num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        key_cache_block = self.cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.cache_dtype)
        return dtype_size * total

    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
    ) -> Optional[SamplerOutput]:
        assert seq_group_metadata_list is not None
        num_seq_groups = len(seq_group_metadata_list)
        assert blocks_to_swap_in is not None
        assert blocks_to_swap_out is not None
        assert blocks_to_copy is not None

        # Currently, TPUWorker does not support swapping.
        # TODO(woosuk): Support block copying.
        assert len(blocks_to_swap_in) == 0
        assert len(blocks_to_swap_out) == 0
        assert len(blocks_to_copy) == 0

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return {}

        output, kv_caches = self.model_runner.execute_model(
            seq_group_metadata_list, self.tpu_cache)
        self.tpu_cache = kv_caches
        return output


def _torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
    mapping = {
        torch.float32: jnp.float32,
        torch.float16: jnp.float16,
        torch.bfloat16: jnp.bfloat16,
    }
    return mapping[dtype]
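As a small sanity check (ours, not part of the diff), the private torch-to-JAX dtype mapping at the end of the file behaves as expected:

import torch
import jax.numpy as jnp
from vllm.worker.tpu_worker import _torch_dtype_to_jax

assert _torch_dtype_to_jax(torch.bfloat16) == jnp.bfloat16
assert _torch_dtype_to_jax(torch.float32) == jnp.float32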