Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
Compare commits
75 Commits
SHA1
---
c00ddd6834
881b884046
98a3df0f8d
3f6288cc89
408ff4950c
278e8a1adc
07be6ed3eb
f6637dba18
707a5f6473
57690a9c09
b15db234ba
d1591f0f1f
85d4488458
8d072dbfbd
d830766c0c
5ae2f81c2b
4ea41d01a9
d16a348477
aa092834bb
d2c6a32c0c
21f35c2289
2aa9831dd3
028f528aad
fa5bacd5b0
b62170e4e3
98eda57899
81b8b813f1
e2c7dedb3a
5323969fcf
f42b4c27d8
620e7646d3
d5fb1c20c1
092e3d6d6d
84284302d8
743695f586
62b870fa07
7e3a230c38
186c88c497
ef762cb110
756c4e78d3
4880de35d2
0fb07c08d0
e4377dd698
5cb213c85e
25bbc21ef6
b25fcc06c2
6661c030c4
8888d1c474
cedb67028a
91b47e3f2f
6d62e4c6aa
de82e95787
b3b89cf755
6692a30266
eb0a0466a9
c59c1e7b2c
d4adf92beb
363e6a950f
696b653193
0d6402ddfd
60ff6b8c5c
d899009a63
6894d3efef
38e3d33a62
02e614d922
46b31ed98d
31d05f7edb
4cdb732cef
27c592b97b
5083aa9092
824521c987
3b8f43024f
d148c2ef00
86f073edd6
52a1e908e4
benchmarks/bench_cache_write.py (new file, 148 lines)
@@ -0,0 +1,148 @@
import functools
import time
from typing import Tuple

import chex
import jax
import jax.numpy as jnp

_PAD_SLOT_ID = -1


@jax.jit
def write_to_kv_cache1(
    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
    num_heads = key.shape[-2]
    head_size = key.shape[-1]

    key = key.reshape(-1, num_heads, head_size)
    key = key.transpose((1, 0, 2))
    value = value.reshape(-1, num_heads, head_size)
    value = value.transpose((1, 0, 2))

    k_cache = k_cache.at[:, slot_mapping.reshape(-1), :].set(key)
    v_cache = v_cache.at[:, slot_mapping.reshape(-1), :].set(value)
    return k_cache, v_cache


@functools.partial(jax.jit, donate_argnums=(2, 3))
def write_to_kv_cache2(
    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
    batch_size = slot_mapping.shape[0]

    def cond(val: _IteratorState):
        return val.idx < batch_size

    def body(val: _IteratorState):
        k_cache, v_cache = _write_seq_to_kv_cache(
            key[val.idx],
            value[val.idx],
            val.k_cache,
            val.v_cache,
            slot_mapping[val.idx],
        )
        val.k_cache = k_cache
        val.v_cache = v_cache
        val.idx += 1
        return val

    iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
    iterator = jax.lax.while_loop(cond, body, iterator)
    return iterator.k_cache, iterator.v_cache


@functools.partial(jax.jit, donate_argnums=(2, 3))
def _write_seq_to_kv_cache(
    key: jax.Array,  # [seq_len, num_heads, head_size]
    value: jax.Array,  # [seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [seq_len]
) -> Tuple[jax.Array, jax.Array]:
    seq_len = slot_mapping.shape[0]
    num_heads, _, head_size = k_cache.shape
    # Reshape to match the rank of kv_cache.
    key = key.reshape(seq_len, num_heads, 1, head_size)
    value = value.reshape(seq_len, num_heads, 1, head_size)

    def cond(val: _IteratorState):
        return jnp.logical_and(
            val.idx < seq_len, slot_mapping[val.idx] != _PAD_SLOT_ID)

    def body(val: _IteratorState):
        slot_idx = slot_mapping[val.idx]
        val.k_cache = jax.lax.dynamic_update_slice(
            val.k_cache,
            key[val.idx],
            (0, slot_idx, 0),
        )
        val.v_cache = jax.lax.dynamic_update_slice(
            val.v_cache,
            value[val.idx],
            (0, slot_idx, 0),
        )
        val.idx += 1
        return val

    iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
    iterator = jax.lax.while_loop(cond, body, iterator)
    return iterator.k_cache, iterator.v_cache


@chex.dataclass
class _IteratorState:

    idx: jnp.int32
    k_cache: jnp.ndarray  # [num_heads, num_blocks, block_size, head_size]
    v_cache: jnp.ndarray  # [num_heads, num_blocks, block_size, head_size]


def benchmark_write_to_kv_cache(
    batch_size: int,
    seq_len: int,
    num_kv_heads: int,
    head_size: int,
    num_blocks: int,
    block_size: int,
    version: int = 1,
):
    if version == 1:
        f = write_to_kv_cache1
    elif version == 2:
        f = write_to_kv_cache2
    else:
        raise ValueError(f"Invalid version: {version}")

    rng_key = jax.random.PRNGKey(0)
    key = jax.random.normal(rng_key, (batch_size, seq_len, num_kv_heads, head_size), dtype=jnp.bfloat16)
    value = jax.random.normal(rng_key, (batch_size, seq_len, num_kv_heads, head_size), dtype=jnp.bfloat16)
    k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
    v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
    slot_mapping = jax.random.randint(rng_key, (batch_size, seq_len), 0, num_blocks * block_size, dtype=jnp.int32)

    # For JIT compilation.
    k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
    k_cache.block_until_ready()

    start = time.time()
    for _ in range(100):
        k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
    k_cache.block_until_ready()
    end = time.time()
    print(f"Time taken: {(end - start) * 10:.2f} ms")


if __name__ == "__main__":
    for num_blocks in [16, 256, 512, 1024, 2048, 8192, 16384]:
        print(f"Benchmarking Write to KV Cache w/ {num_blocks} blocks")
        benchmark_write_to_kv_cache(16, 256, 16, 256, num_blocks, 16, version=1)
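Run as a script, the benchmark only sweeps num_blocks with version 1 (the scatter-based writer); comparing it against the while_loop-based version 2, which donates the cache buffers via donate_argnums so XLA can update them in place, means calling the helper directly. A minimal sketch, assuming it is run from the benchmarks/ directory so the module is importable; the sizes mirror the script's own defaults except for a fixed num_blocks:

# Sketch: compare the two KV-cache write strategies defined above.
from bench_cache_write import benchmark_write_to_kv_cache

for version in (1, 2):
    print(f"--- version {version} ---")
    benchmark_write_to_kv_cache(batch_size=16, seq_len=256, num_kv_heads=16,
                                head_size=256, num_blocks=1024, block_size=16,
                                version=version)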
benchmarks/bench_paged_attn.py (new file, 101 lines)
@@ -0,0 +1,101 @@
import argparse
import functools
import time

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention

BLOCK_SIZE = 16
MAX_NUM_BLOCKS_PER_SEQ = 512


@functools.partial(jax.jit, static_argnums=(6, 7))
def paged_attn(
    q: jax.Array,  # [batch, 1, num_heads, head_size]
    k_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
    sm_scale: float,
    block_tables: jax.Array,  # [batch, max_num_blocks_per_batch]
    context_lens: jax.Array,  # [batch]
    block_size: int,
    pages_per_compute_block: int,
) -> jax.Array:  # [batch, 1, num_heads, head_size]
    q = q.squeeze(1)
    q = q * sm_scale

    head_size = q.shape[-1]
    num_slots = k_cache.shape[-2]
    k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size)
    v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size)

    output = paged_attention(
        q,
        k_cache,
        v_cache,
        context_lens,
        block_tables,
        pages_per_compute_block=pages_per_compute_block,
    )
    return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])


def benchmark_paged_attn(
    batch_size: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    context_len: int,
    num_blocks: int,
    block_size: int,
    pages_per_compute_block: int,
):
    rng_key = jax.random.PRNGKey(0)
    query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16)
    k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
    v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
    sm_scale = head_size ** -0.5
    block_tables = jax.random.randint(rng_key, (batch_size, MAX_NUM_BLOCKS_PER_SEQ), 0, num_blocks, dtype=jnp.int32)
    context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32)

    # For JIT compilation.
    output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
    output.block_until_ready()

    start = time.time()
    for _ in range(100):
        output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
    output.block_until_ready()
    end = time.time()

    print(f"Time taken: {(end - start) * 10000:.2f} us")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-heads", type=int, default=16)
    parser.add_argument("--num-kv-heads", type=int, default=16)
    parser.add_argument("--head-size", type=int, default=256)
    parser.add_argument("--context-len", type=int, default=512)
    parser.add_argument("--num-blocks", type=int, default=2048)
    args = parser.parse_args()
    print(args)

    for block_size in [16, 32, 64, 128]:
        for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128]:
            if pages_per_compute_block > MAX_NUM_BLOCKS_PER_SEQ:
                continue
            if block_size * pages_per_compute_block > 1024:
                continue
            print(f"block_size {block_size}, pages_per_compute_block: {pages_per_compute_block}")
            benchmark_paged_attn(
                args.batch_size,
                args.num_heads,
                args.num_kv_heads,
                args.head_size,
                args.context_len,
                args.num_blocks,
                block_size,
                pages_per_compute_block,
            )
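The __main__ block sweeps block_size and pages_per_compute_block, skipping combinations whose product exceeds 1024 tokens per compute block. A single configuration can also be timed directly; the sketch below uses the script's argparse defaults and one point from the sweep, and needs an actual TPU since paged_attention is a Pallas TPU kernel:

# Sketch: time one paged-attention configuration from the sweep above.
from bench_paged_attn import benchmark_paged_attn

benchmark_paged_attn(batch_size=8, num_heads=16, num_kv_heads=16,
                     head_size=256, context_len=512, num_blocks=2048,
                     block_size=16, pages_per_compute_block=8)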
@@ -335,7 +335,7 @@ if __name__ == "__main__":
         "--device",
         type=str,
         default="cuda",
-        choices=["cuda", "cpu"],
+        choices=["cuda", "cpu", "tpu"],
         help='device type for vLLM execution, supporting CUDA and CPU.')
     parser.add_argument(
         "--enable-prefix-caching",
requirements-tpu.txt (new file, 6 lines)
@@ -0,0 +1,6 @@
# Common dependencies
-r requirements-common.txt

torch
jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
flax >= 0.8
setup.py (22 changed lines)
@@ -188,9 +188,9 @@ class cmake_build_ext(build_ext):


 def _is_cuda() -> bool:
-    return VLLM_TARGET_DEVICE == "cuda" \
-        and torch.version.cuda is not None \
-        and not _is_neuron()
+    has_cuda = torch.version.cuda is not None
+    return (VLLM_TARGET_DEVICE == "cuda" and has_cuda
+            and not (_is_neuron() or _is_tpu()))


 def _is_hip() -> bool:
@@ -207,10 +207,18 @@ def _is_neuron() -> bool:
     return torch_neuronx_installed


+def _is_tpu() -> bool:
+    return True  # FIXME
+
+
 def _is_cpu() -> bool:
     return VLLM_TARGET_DEVICE == "cpu"


+def _build_custom_ops() -> bool:
+    return _is_cuda() or _is_hip() or _is_cpu()
+
+
 def _install_punica() -> bool:
     return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0")))
@@ -307,6 +315,8 @@ def get_vllm_version() -> str:
         if neuron_version != MAIN_CUDA_VERSION:
             neuron_version_str = neuron_version.replace(".", "")[:3]
             version += f"+neuron{neuron_version_str}"
+    elif _is_tpu():
+        version += "+tpu"
     elif _is_cpu():
         version += "+cpu"
     else:
@@ -353,6 +363,8 @@ def get_requirements() -> List[str]:
         requirements = _read_requirements("requirements-rocm.txt")
     elif _is_neuron():
         requirements = _read_requirements("requirements-neuron.txt")
+    elif _is_tpu():
+        requirements = _read_requirements("requirements-tpu.txt")
     elif _is_cpu():
         requirements = _read_requirements("requirements-cpu.txt")
     else:
@@ -369,7 +381,7 @@ if _is_cuda():
     if _install_punica():
         ext_modules.append(CMakeExtension(name="vllm._punica_C"))

-if not _is_neuron():
+if _build_custom_ops():
     ext_modules.append(CMakeExtension(name="vllm._C"))

 package_data = {
@@ -408,6 +420,6 @@ setup(
     extras_require={
         "tensorizer": ["tensorizer==2.9.0a1"],
     },
-    cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
+    cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
     package_data=package_data,
 )
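The `return True  # FIXME` stub above makes every non-CUDA, non-Neuron environment look like a TPU build. A more selective check could mirror the runtime `is_tpu()` helper added to vllm/utils.py later in this same diff and probe for the libtpu package; the sketch below is an assumption about how the FIXME might be resolved, not part of the change set:

def _is_tpu() -> bool:
    # Sketch only: detect a TPU environment by probing for libtpu,
    # the same check vllm.utils.is_tpu() performs in this diff.
    try:
        import libtpu  # noqa: F401
    except ImportError:
        return False
    return True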
@@ -7,7 +7,7 @@ import torch

 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
-from vllm.utils import is_cpu, is_hip
+from vllm.utils import is_cpu, is_hip, is_tpu

 logger = init_logger(__name__)

@@ -19,6 +19,7 @@ class _Backend(enum.Enum):
     XFORMERS = enum.auto()
     ROCM_FLASH = enum.auto()
     TORCH_SDPA = enum.auto()
+    PALLAS = enum.auto()


 @lru_cache(maxsize=None)
@@ -49,6 +50,8 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:

 def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
     """Returns which flash attention backend to use."""
+    if is_tpu():
+        return _Backend.PALLAS
     if is_cpu():
         return _Backend.TORCH_SDPA
@@ -13,7 +13,7 @@ from transformers import PretrainedConfig
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
-                        is_neuron)
+                        is_neuron, is_tpu)

 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -620,6 +620,8 @@ class DeviceConfig:
             # Automated device type detection
             if is_neuron():
                 self.device_type = "neuron"
+            elif is_tpu():
+                self.device_type = "tpu"
             elif is_cpu():
                 self.device_type = "cpu"
             else:
@@ -633,6 +635,8 @@ class DeviceConfig:
         # Some device types require processing inputs on CPU
         if self.device_type in ["neuron"]:
             self.device = torch.device("cpu")
+        elif self.device_type in ["tpu"]:
+            self.device = None
         else:
             # Set device with device type
             self.device = torch.device(self.device_type)
@@ -598,6 +598,13 @@ class Scheduler:

         leftover_waiting_sequences: Deque[SequenceGroup] = deque()
         while self._passed_delay(time.time()) and waiting_queue:
+            # FIXME(woosuk): The TPU backend only supports up to 4 sequence
+            # groups in a single batch.
+            MAX_BATCH_SIZE = 1
+            if len(seq_groups) == MAX_BATCH_SIZE:
+                break
+            assert len(seq_groups) < MAX_BATCH_SIZE
+
             seq_group = waiting_queue[0]

             waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
@@ -221,6 +221,9 @@ class LLMEngine:
         if engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutor
             executor_class = NeuronExecutor
+        elif engine_config.device_config.device_type == "tpu":
+            from vllm.executor.tpu_executor import TPUExecutor
+            executor_class = TPUExecutor
         elif engine_config.device_config.device_type == "cpu":
             from vllm.executor.cpu_executor import CPUExecutor
             executor_class = CPUExecutor
vllm/executor/tpu_executor.py (new file, 98 lines)
@@ -0,0 +1,98 @@
from typing import Dict, List, Set, Tuple

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import make_async

logger = init_logger(__name__)


class TPUExecutor(ExecutorBase):

    def _init_executor(self) -> None:
        assert not self.speculative_config, (
            "Speculative decoding not yet supported for TPU backend")
        # Instantiate the worker and load the model to the device.
        self._init_worker()

    def _init_worker(self):
        from vllm.worker.tpu_worker import TPUWorker

        assert self.parallel_config.world_size == 1, (
            "TPUExecutor currently only supports a single TPU chip.")
        self.driver_worker = TPUWorker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            self.device_config,
            self.cache_config,
            self.vision_language_config,
        )
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Initialize the KV cache by invoking the underlying worker."""
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info(f"# TPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
        """
        return self.driver_worker.determine_num_available_blocks()

    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        output = self.driver_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError("LoRA is not implemented for TPU backend.")

    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError("LoRA is not implemented for TPU backend.")

    def list_loras(self) -> Set[int]:
        raise NotImplementedError("LoRA is not implemented for TPU backend.")

    def check_health(self) -> None:
        # TPUExecutor will always be healthy as long as it's running.
        return


class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase):

    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        output = await make_async(self.driver_worker.execute_model)(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy)
        return output
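With the executor and the DeviceConfig/LLMEngine plumbing above, the TPU path is reached through the existing `device` engine argument. The sketch below is hypothetical: at this point in the series only the JAX Gemma model is wired up, several pieces are still FIXME-stubbed, and the exact model name is an assumption, so this illustrates the intended wiring rather than a working command:

# Hypothetical sketch: device="tpu" makes DeviceConfig report "tpu",
# which makes LLMEngine select TPUExecutor per the diff above.
from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-7b", device="tpu")
print(llm.generate(["Hello"], SamplingParams(max_tokens=16)))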
vllm/model_executor/models/jax/__init__.py (new empty file)
vllm/model_executor/models/jax/gemma.py (new file, 328 lines)
@@ -0,0 +1,328 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Gemma transformer."""
from typing import List, Tuple

import jax
import jax.numpy as jnp
from flax import linen as nn
from transformers import GemmaConfig

from vllm.model_executor.models.jax.ops.flash_attn import flash_attn
from vllm.model_executor.models.jax.ops.paged_attn import paged_attn
from vllm.model_executor.models.jax.ops.write_to_cache import write_to_kv_cache

K_MASK = -2.3819763e38  # Set to a large negative number.


class Einsum(nn.Module):
    """Einsum is a convenience module for parameterized tensor multiplication."""
    shape: tuple[int, ...]

    @nn.compact
    def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
        w = self.param('w', nn.initializers.normal(), self.shape)
        return jnp.einsum(eqn, x, w)


class RMSNorm(nn.Module):
    """RMSNorm layer."""

    @nn.compact
    def __call__(self, x):
        scale = self.param('scale', nn.initializers.zeros_init(),
                           (x.shape[-1]))
        var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))
        # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
        # a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
        # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
        scale = jnp.expand_dims(scale, axis=range(len(x.shape) - 1))
        normed_inputs = normed_inputs * (1 + scale)
        return normed_inputs


def apply_rope(
    inputs: jax.Array,  # [B, L]
    positions: jax.Array,  # [B, L]
    head_dim: int,
    max_wavelength: int = 10_000,
) -> jax.Array:
    """Applies RoPE."""
    fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
    timescale = max_wavelength**fraction

    sinusoid_inp = (positions[..., jnp.newaxis] /
                    timescale[jnp.newaxis, jnp.newaxis, :])
    sinusoid_inp = sinusoid_inp[..., jnp.newaxis, :]
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)

    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    out = jnp.concatenate([first_part, second_part], axis=-1)
    return out.astype(inputs.dtype)


class Embedder(nn.Module):
    """Embedder module."""

    vocab_size: int
    embed_dim: int

    def setup(self):
        self.input_embedding_table = self.param(
            'input_embedding',
            nn.initializers.normal(),
            (self.vocab_size, self.embed_dim),
        )

    def encode(self, x: jax.Array) -> jax.Array:
        x = self.input_embedding_table[(x, )]
        x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
        return x

    def decode(self, x: jax.Array) -> jax.Array:
        return jnp.dot(x, self.input_embedding_table.T)


class Attention(nn.Module):
    """Attention module."""

    num_heads: int
    num_kv_heads: int
    features: int
    head_dim: int

    @property
    def use_qkv_einsum(self):
        return self.num_kv_heads == self.num_heads

    def setup(self):
        self.attn_vec_einsum = Einsum(shape=(self.num_heads, self.head_dim,
                                             self.features), )

        if self.use_qkv_einsum:
            self.qkv_einsum = Einsum(shape=(3, self.num_heads, self.features,
                                            self.head_dim), )
        else:
            self.q_einsum = Einsum(shape=(self.num_heads, self.features,
                                          self.head_dim), )
            self.kv_einsum = Einsum(shape=(2, self.num_kv_heads, self.features,
                                           self.head_dim), )
        self.sm_scale = self.head_dim**-0.5

    def __call__(
        self,
        x: jax.Array,
        segment_pos: jax.Array,
        slot_mapping: jax.Array,
        block_tables: jax.Array | None,
        context_lens: jax.Array | None,
        cache: Tuple[jax.Array, jax.Array],
    ) -> tuple[jax.Array, jax.Array]:
        if self.use_qkv_einsum:
            query_proj, key_proj, value_proj = self.qkv_einsum(
                'BTD,SNDH->SBTNH', x)
        else:
            query_proj = self.q_einsum('BTD,NDH->BTNH', x)
            key_proj, value_proj = self.kv_einsum('BSD,CKDH->CBSKH', x)

        query_proj = apply_rope(
            query_proj,
            segment_pos,
            head_dim=self.head_dim,
        )
        key_proj = apply_rope(
            key_proj,
            segment_pos,
            head_dim=self.head_dim,
        )

        # Write the incoming keys and values to KV cache.
        k_cache, v_cache = cache
        # FIXME(woosuk): Uncomment this.
        # k_cache, v_cache = write_to_kv_cache(
        #     key_proj, value_proj, k_cache, v_cache, slot_mapping)

        if block_tables is None:
            # Prompt attention.
            if not self.use_qkv_einsum:
                # MQA/GQA.
                value_proj = jnp.repeat(value_proj, self.num_heads, axis=-2)
                key_proj = jnp.repeat(key_proj, self.num_heads, axis=-2)

            if True:
                # FlashAttention.
                output = flash_attn(
                    query_proj,
                    key_proj,
                    value_proj,
                    self.sm_scale,
                )
            else:
                # Naive attention with masking.
                seq_len = query_proj.shape[1]
                attn_mask = jnp.tril(
                    jnp.ones((seq_len, seq_len), dtype=jnp.bool_))

                query_scaled = query_proj * self.sm_scale
                logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj)
                masked_logits = jnp.where((jnp.expand_dims(attn_mask, -2)),
                                          logits, K_MASK)
                probs = jax.nn.softmax(masked_logits,
                                       axis=-1).astype(key_proj.dtype)
                output = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
        else:
            # Decode attention.
            output = paged_attn(
                query_proj,
                k_cache,
                v_cache,
                self.sm_scale,
                block_tables,
                context_lens,
            )

        attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', output)
        return (k_cache, v_cache), attn_output


class FeedForward(nn.Module):
    """Feed forward module."""

    features: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        w_gating = self.param(
            'gating_einsum',
            nn.initializers.zeros_init(),
            ((2, self.features, self.hidden_dim)),
        )
        ff_gate = jnp.dot(x, w_gating[0])
        gate_value = nn.gelu(ff_gate)

        ff1 = jnp.dot(x, w_gating[1])
        activations = gate_value * ff1

        w_linear = self.param(
            'linear',
            nn.initializers.zeros_init(),
            (self.hidden_dim, self.features),
        )
        outputs = jnp.dot(activations, w_linear)
        return outputs


class Block(nn.Module):
    """Transformer block."""

    num_heads: int
    num_kv_heads: int
    embed_dim: int
    head_dim: int
    hidden_dim: int

    def setup(self):
        self.pre_attention_norm = RMSNorm()
        self.attn = Attention(
            num_heads=self.num_heads,
            features=self.embed_dim,
            head_dim=self.head_dim,
            num_kv_heads=self.num_kv_heads,
        )
        self.pre_ffw_norm = RMSNorm()
        self.mlp = FeedForward(features=self.embed_dim,
                               hidden_dim=self.hidden_dim)

    def __call__(
        self,
        x: jax.Array,
        segment_pos: jax.Array,
        slot_mapping: jax.Array,
        block_tables: jax.Array | None,
        context_lens: jax.Array | None,
        cache: Tuple[jax.Array, jax.Array],
    ) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]:
        inputs_normalized = self.pre_attention_norm(x)
        cache, attn_output = self.attn(
            inputs_normalized,
            segment_pos,
            slot_mapping,
            block_tables,
            context_lens,
            cache,
        )
        attn_output += x
        residual = attn_output
        attn_output = self.pre_ffw_norm(attn_output)
        outputs = self.mlp(attn_output)
        outputs = residual + outputs
        return outputs, cache


class Transformer(nn.Module):
    """Gemma transformer."""

    config: GemmaConfig

    def setup(self):
        self.embedder = Embedder(
            vocab_size=256128,  # != self.config.vocab_size
            embed_dim=self.config.hidden_size,
        )
        self.blocks = [
            Block(
                name=f'layer_{i}',
                num_heads=self.config.num_attention_heads,
                num_kv_heads=self.config.num_key_value_heads,
                embed_dim=self.config.hidden_size,
                head_dim=self.config.head_dim,
                hidden_dim=self.config.intermediate_size,
            ) for i in range(self.config.num_hidden_layers)
        ]
        self.final_norm = RMSNorm()

    def __call__(
        self,
        token_ids: jax.Array,
        positions: jax.Array,
        slot_mapping: jax.Array,
        block_tables: jax.Array | None,
        context_lens: jax.Array | None,
        kv_caches: List[Tuple[jax.Array, jax.Array]],
        logits_indices: jax.Array,
    ) -> tuple[jax.Array, List[Tuple[jax.Array, jax.Array]]]:
        x = self.embedder.encode(token_ids)
        new_caches = []
        for i, block in enumerate(self.blocks):
            x, new_cache = block(
                x,
                positions,
                slot_mapping,
                block_tables,
                context_lens,
                kv_caches[i],
            )
            new_caches.append(new_cache)

        x = self.final_norm(x)
        x = x.reshape(-1, x.shape[-1])
        hidden_states = x[logits_indices]
        logits = self.embedder.decode(hidden_states)
        return logits, new_caches
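As a quick sanity check of the rotary-embedding helper above, the shapes below are my own assumptions, consistent with how Attention.__call__ passes query/key projections ([batch, seq, heads, head_dim]) and segment positions ([batch, seq]); this assumes apply_rope from the module above is in scope:

import jax.numpy as jnp

# Minimal shape check for apply_rope: rotation preserves the input shape.
batch, seq, heads, head_dim = 2, 8, 4, 16
inputs = jnp.ones((batch, seq, heads, head_dim), dtype=jnp.float32)
positions = jnp.broadcast_to(jnp.arange(seq), (batch, seq))
out = apply_rope(inputs, positions, head_dim=head_dim)
assert out.shape == (batch, seq, heads, head_dim)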
vllm/model_executor/models/jax/ops/__init__.py (new empty file)
vllm/model_executor/models/jax/ops/flash_attn.py (new file, 29 lines)
@@ -0,0 +1,29 @@
import jax
from jax.experimental.pallas.ops.tpu.flash_attention import BlockSizes, flash_attention

_DEFAULT_BLOCK_SIZES = {
    "block_q": 512,
    "block_k_major": 512,
    "block_k": 512,
    "block_b": 2,
}


def flash_attn(
    q: jax.Array,  # [batch, seq_len, num_heads, head_size]
    k: jax.Array,  # [batch, seq_len, num_heads, head_size]
    v: jax.Array,  # [batch, seq_len, num_heads, head_size]
    sm_scale: float,
) -> jax.Array:  # [batch, seq_len, num_heads, head_size]
    return flash_attention(
        q.transpose(0, 2, 1, 3),
        k.transpose(0, 2, 1, 3),
        v.transpose(0, 2, 1, 3),
        causal=True,
        sm_scale=sm_scale,
        block_sizes=BlockSizes(
            min(_DEFAULT_BLOCK_SIZES["block_q"], q.shape[1]),
            min(_DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[1]),
            min(_DEFAULT_BLOCK_SIZES["block_k"], k.shape[1]),
            min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0])),
    ).transpose(0, 2, 1, 3)
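The min() clamps above simply cap the default Pallas block sizes at the actual problem size, so short prompts still compile. For a batch-1, 128-token prompt the effective sizes work out as follows (plain Python, values from _DEFAULT_BLOCK_SIZES above; the example sizes are illustrative):

defaults = {"block_q": 512, "block_k_major": 512, "block_k": 512, "block_b": 2}
batch, seq_len = 1, 128
print(min(defaults["block_q"], seq_len),        # 128
      min(defaults["block_k_major"], seq_len),  # 128
      min(defaults["block_k"], seq_len),        # 128
      min(defaults["block_b"], batch))          # 1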
vllm/model_executor/models/jax/ops/paged_attn.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import jax
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention


def paged_attn(
    q: jax.Array,  # [batch, 1, num_heads, head_size]
    k_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
    sm_scale: float,
    block_tables: jax.Array,  # [batch, max_num_blocks_per_batch]
    context_lens: jax.Array,  # [batch]
    block_size: int = 16,
) -> jax.Array:  # [batch, 1, num_heads, head_size]
    q = q * sm_scale
    q = q.squeeze(1)

    head_size = q.shape[-1]
    num_slots = k_cache.shape[-2]
    k_cache = k_cache.reshape(-1, num_slots // block_size, block_size,
                              head_size)
    v_cache = v_cache.reshape(-1, num_slots // block_size, block_size,
                              head_size)

    output = paged_attention(
        q,
        k_cache,
        v_cache,
        context_lens,
        block_tables,
        pages_per_compute_block=512 // 16,  # TODO(woosuk): Tune this value.
    )
    return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])
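The hard-coded pages_per_compute_block=512 // 16 gives 32 pages per compute block, i.e. 512 tokens of context per grid step at the default 16-slot block size, and the reshape converts the flat cache layout into the paged layout the Pallas kernel expects. A small CPU-runnable shape check (sizes are illustrative assumptions):

import jax.numpy as jnp

# The flat [num_kv_heads, num_blocks * block_size, head_size] cache becomes
# [num_kv_heads, num_blocks, block_size, head_size] before the kernel call.
num_kv_heads, num_blocks, block_size, head_size = 8, 64, 16, 128
k_cache = jnp.zeros((num_kv_heads, num_blocks * block_size, head_size))
paged = k_cache.reshape(-1, num_blocks, block_size, head_size)
assert paged.shape == (num_kv_heads, num_blocks, block_size, head_size)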
vllm/model_executor/models/jax/ops/write_to_cache.py (new file, 102 lines)
@@ -0,0 +1,102 @@
from typing import Tuple

import chex
import jax
import jax.numpy as jnp

_PAD_SLOT_ID = -1


def write_to_kv_cache(
    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
    num_heads = key.shape[-2]
    head_size = key.shape[-1]

    key = key.reshape(-1, num_heads, head_size)
    key = key.transpose((1, 0, 2))
    value = value.reshape(-1, num_heads, head_size)
    value = value.transpose((1, 0, 2))

    k_cache = k_cache.at[:, slot_mapping.reshape(-1), :].set(key)
    v_cache = v_cache.at[:, slot_mapping.reshape(-1), :].set(value)
    return k_cache, v_cache


def _write_to_kv_cache(
    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
    batch_size = slot_mapping.shape[0]

    def cond(val: _IteratorState):
        return val.idx < batch_size

    def body(val: _IteratorState):
        k_cache, v_cache = _write_seq_to_kv_cache(
            key[val.idx],
            value[val.idx],
            val.k_cache,
            val.v_cache,
            slot_mapping[val.idx],
        )
        val.k_cache = k_cache
        val.v_cache = v_cache
        val.idx += 1
        return val

    iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
    iterator = jax.lax.while_loop(cond, body, iterator)
    return iterator.k_cache, iterator.v_cache


def _write_seq_to_kv_cache(
    key: jax.Array,  # [seq_len, num_heads, head_size]
    value: jax.Array,  # [seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [seq_len]
) -> Tuple[jax.Array, jax.Array]:
    seq_len = slot_mapping.shape[0]
    num_heads, _, head_size = k_cache.shape
    # Reshape to match the rank of kv_cache.
    key = key.reshape(seq_len, num_heads, 1, head_size)
    value = value.reshape(seq_len, num_heads, 1, head_size)

    def cond(val: _IteratorState):
        return jnp.logical_and(
            val.idx < seq_len, slot_mapping[val.idx] != _PAD_SLOT_ID)

    def body(val: _IteratorState):
        slot_idx = slot_mapping[val.idx]
        val.k_cache = jax.lax.dynamic_update_slice(
            val.k_cache,
            key[val.idx],
            (0, slot_idx, 0),
        )
        val.v_cache = jax.lax.dynamic_update_slice(
            val.v_cache,
            value[val.idx],
            (0, slot_idx, 0),
        )
        val.idx += 1
        return val

    iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
    iterator = jax.lax.while_loop(cond, body, iterator)
    return iterator.k_cache, iterator.v_cache


@chex.dataclass
class _IteratorState:

    idx: jnp.int32
    k_cache: jnp.ndarray  # [num_heads, num_blocks, block_size, head_size]
    v_cache: jnp.ndarray  # [num_heads, num_blocks, block_size, head_size]
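A minimal CPU-runnable sketch of the public write_to_kv_cache scatter above (assuming that function is in scope; the tiny sizes and the all-ones data are my own assumptions, and the example uses one distinct slot per token so no padded slots are involved):

import jax.numpy as jnp

# Illustrative call of write_to_kv_cache with made-up sizes.
batch, seq, heads, head, blocks, block_size = 2, 4, 2, 8, 4, 16
key = jnp.ones((batch, seq, heads, head))
value = jnp.ones((batch, seq, heads, head))
k_cache = jnp.zeros((heads, blocks * block_size, head))
v_cache = jnp.zeros((heads, blocks * block_size, head))
# One distinct slot per (batch, seq) token.
slot_mapping = jnp.arange(batch * seq, dtype=jnp.int32).reshape(batch, seq)

k_cache, v_cache = write_to_kv_cache(key, value, k_cache, v_cache, slot_mapping)
# The first batch * seq slots now hold the written keys.
assert bool((k_cache[:, :batch * seq] == 1).all())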
@@ -138,6 +138,15 @@ def is_neuron() -> bool:
     return transformers_neuronx is not None


+@lru_cache(maxsize=None)
+def is_tpu() -> bool:
+    try:
+        import libtpu
+    except ImportError:
+        libtpu = None
+    return libtpu is not None
+
+
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
@@ -490,6 +499,11 @@ def maybe_expand_dim(tensor: torch.Tensor,
     return tensor


+def get_dtype_size(dtype: torch.dtype) -> int:
+    """Get the size of the data type in bytes."""
+    return torch.tensor([], dtype=dtype).element_size()
+
+
 def merge_dicts(dict1: Dict[Any, List[Any]],
                 dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
     """Merge 2 dicts that have key -> List of items.
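The new get_dtype_size helper simply asks PyTorch for the element size of the dtype, which is what the cache-size accounting below relies on; for example:

import torch

# get_dtype_size as added above: element size in bytes.
assert torch.tensor([], dtype=torch.float16).element_size() == 2
assert torch.tensor([], dtype=torch.bfloat16).element_size() == 2
assert torch.tensor([], dtype=torch.float32).element_size() == 4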
@@ -6,7 +6,8 @@ import torch
 from vllm.attention import get_attn_backend
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available,
+                        get_dtype_size)

 logger = init_logger(__name__)

@@ -97,9 +98,5 @@ class CacheEngine:
             dtype = model_config.dtype
         else:
             dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
-        dtype_size = _get_dtype_size(dtype)
+        dtype_size = get_dtype_size(dtype)
         return dtype_size * total
-
-
-def _get_dtype_size(dtype: torch.dtype) -> int:
-    return torch.tensor([], dtype=dtype).element_size()
392
vllm/worker/tpu_model_runner.py
Normal file
392
vllm/worker/tpu_model_runner.py
Normal file
@ -0,0 +1,392 @@
|
|||||||
|
import time
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
|
||||||
|
SchedulerConfig, VisionLanguageConfig)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||||
|
from vllm.utils import pad_to_max_length
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
_PAD_SLOT_ID = -1
|
||||||
|
_MAX_NUM_SEQS = 256
|
||||||
|
_MAX_NUM_BLOCKS_PER_SEQ = 8192 // 16
|
||||||
|
|
||||||
|
|
||||||
|
class TPUModelRunner:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
scheduler_config: SchedulerConfig,
|
||||||
|
device_config: DeviceConfig,
|
||||||
|
vision_language_config: Optional[VisionLanguageConfig],
|
||||||
|
):
|
||||||
|
self.model_config = model_config
|
||||||
|
self.parallel_config = parallel_config
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.device_config = device_config
|
||||||
|
self.vision_language_config = vision_language_config
|
||||||
|
|
||||||
|
if model_config is not None and model_config.get_sliding_window():
|
||||||
|
logger.warning("Sliding window is not supported on TPU. "
|
||||||
|
"The model will run without sliding window.")
|
||||||
|
self.model = None
|
||||||
|
self.block_size = None
|
||||||
|
self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7, ))
|
||||||
|
# FIXME(woosuk)
|
||||||
|
self.block_tables = np.zeros((_MAX_NUM_SEQS, _MAX_NUM_BLOCKS_PER_SEQ),
|
||||||
|
dtype=np.int32)
|
||||||
|
|
||||||
|
def load_model(self) -> None:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from vllm.model_executor.models.jax.gemma import Transformer
|
||||||
|
|
||||||
|
assert self.model_config.hf_config.model_type == "gemma"
|
||||||
|
self.model = Transformer(self.model_config.hf_config)
|
||||||
|
|
||||||
|
model_name = "google/gemma-7b-flax"
|
||||||
|
model_dir = snapshot_download(model_name)
|
||||||
|
params = load_and_format_params(model_dir + "/7b/")["transformer"]
|
||||||
|
self.params = {"params": params}
|
||||||
|
self.cpu_device = jax.devices("cpu")[0]
|
||||||
|
|
||||||
|
def warmup_model(
|
||||||
|
self,
|
||||||
|
tpu_caches: List[Tuple[jax.Array, jax.Array]],
|
||||||
|
) -> List[Tuple[jax.Array, jax.Array]]:
|
||||||
|
# Prefill
|
||||||
|
logger.info("Compiling the model with different input shapes...")
|
||||||
|
start = time.time()
|
||||||
|
for batch_size in [1]:
|
||||||
|
for seq_len in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]:
|
||||||
|
if batch_size * seq_len > 8192:
|
||||||
|
continue
|
||||||
|
token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||||
|
position_ids = jnp.zeros((batch_size, seq_len),
|
||||||
|
dtype=jnp.int32)
|
||||||
|
slot_mapping = jnp.zeros((batch_size, seq_len),
|
||||||
|
dtype=jnp.int32)
|
||||||
|
block_tables = None
|
||||||
|
context_lens = None
|
||||||
|
prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
|
||||||
|
|
||||||
|
# Dummy run.
|
||||||
|
_, tpu_caches = self.compiled_fn(self.params, token_ids,
|
||||||
|
position_ids, slot_mapping,
|
||||||
|
block_tables, context_lens,
|
||||||
|
prompt_lens, tpu_caches)
|
||||||
|
end = time.time()
|
||||||
|
logger.info(f"Compilation for prefill done in {(end - start):.2f} s.")
|
||||||
|
|
||||||
|
# Decode
|
||||||
|
start = time.time()
|
||||||
|
for batch_size in [1, 2, 4, 8] + [16 * i for i in range(1, 17)]:
|
||||||
|
seq_len = 1
|
||||||
|
token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||||
|
position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||||
|
slot_mapping = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||||
|
block_tables = jnp.zeros((batch_size, _MAX_NUM_BLOCKS_PER_SEQ),
|
||||||
|
dtype=jnp.int32)
|
||||||
|
context_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
|
||||||
|
prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
|
||||||
|
|
||||||
|
_, tpu_caches = self.compiled_fn(self.params, token_ids,
|
||||||
|
position_ids, slot_mapping,
|
||||||
|
block_tables, context_lens,
|
||||||
|
prompt_lens, tpu_caches)
|
||||||
|
end = time.time()
|
||||||
|
logger.info(f"Compilation for decode done in {(end - start):.2f} s.")
|
||||||
|
return tpu_caches
|
||||||
|
|
||||||
|
def _prepare_prompt(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
):
|
||||||
|
assert len(seq_group_metadata_list) > 0
|
||||||
|
input_tokens: List[List[int]] = []
|
||||||
|
input_positions: List[List[int]] = []
|
||||||
|
prompt_lens: List[int] = []
|
||||||
|
slot_mapping: List[List[int]] = []
|
||||||
|
|
||||||
|
for seq_group_metadata in seq_group_metadata_list:
|
||||||
|
assert seq_group_metadata.is_prompt
|
||||||
|
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||||
|
assert len(seq_ids) == 1
|
||||||
|
seq_id = seq_ids[0]
|
||||||
|
|
||||||
|
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||||
|
prompt_tokens = seq_data.get_token_ids()
|
||||||
|
prompt_len = len(prompt_tokens)
|
||||||
|
prompt_lens.append(prompt_len)
|
||||||
|
|
||||||
|
input_tokens.append(prompt_tokens)
|
||||||
|
input_positions.append(list(range(prompt_len)))
|
||||||
|
|
||||||
|
assert seq_group_metadata.block_tables is not None
|
||||||
|
block_table = seq_group_metadata.block_tables[seq_id]
|
||||||
|
slot_mapping.append([])
|
||||||
|
for i in range(prompt_len):
|
||||||
|
block_number = block_table[i //
|
||||||
|
self.block_size] # type: ignore
|
||||||
|
block_offset = i % self.block_size # type: ignore
|
||||||
|
slot = block_number * self.block_size + block_offset
|
||||||
|
slot_mapping[-1].append(slot)
|
||||||
|
|
||||||
|
max_prompt_len = max(prompt_lens)
|
||||||
|
assert max_prompt_len > 0
|
||||||
|
max_prompt_len = _get_padded_prefill_len(max_prompt_len)
|
||||||
|
|
||||||
|
input_tokens = _make_array_with_pad(input_tokens,
|
||||||
|
max_prompt_len,
|
||||||
|
pad=0,
|
||||||
|
dtype=jnp.int32)
|
||||||
|
input_positions = _make_array_with_pad(input_positions,
|
||||||
|
max_prompt_len,
|
||||||
|
pad=0,
|
||||||
|
dtype=jnp.int32)
|
||||||
|
slot_mapping = _make_array_with_pad(slot_mapping,
|
||||||
|
max_prompt_len,
|
||||||
|
pad=_PAD_SLOT_ID,
|
||||||
|
dtype=jnp.int32)
|
||||||
|
prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
|
||||||
|
return (input_tokens, input_positions, slot_mapping, None, None,
|
||||||
|
prompt_lens)
|
||||||
|
|
||||||
|
def _prepare_decode(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
):
|
||||||
|
assert len(seq_group_metadata_list) > 0
|
||||||
|
input_tokens: List[List[int]] = []
|
||||||
|
input_positions: List[List[int]] = []
|
||||||
|
slot_mapping: List[List[int]] = []
|
||||||
|
context_lens: List[int] = []
|
||||||
|
num_seq_groups = len(seq_group_metadata_list)
|
||||||
|
batch_size = _get_padded_batch_size(num_seq_groups)
|
||||||
|
|
||||||
|
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||||
|
assert not seq_group_metadata.is_prompt
|
||||||
|
|
||||||
|
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||||
|
|
||||||
|
for seq_id in seq_ids:
|
||||||
|
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||||
|
generation_token = seq_data.get_last_token_id()
|
||||||
|
input_tokens.append([generation_token])
|
||||||
|
|
||||||
|
seq_len = seq_data.get_len()
|
||||||
|
position = seq_len - 1
|
||||||
|
input_positions.append([position])
|
||||||
|
context_lens.append(seq_len)
|
||||||
|
|
||||||
|
assert seq_group_metadata.block_tables is not None
|
||||||
|
block_table = seq_group_metadata.block_tables[seq_id]
|
||||||
|
self.block_tables[i, :len(block_table)] = block_table
|
||||||
|
|
||||||
|
block_number = block_table[position // self.block_size]
|
||||||
|
block_offset = position % self.block_size
|
||||||
|
slot = block_number * self.block_size + block_offset
|
||||||
|
slot_mapping.append([slot])
|
||||||
|
|
||||||
|
num_paddings = batch_size - num_seq_groups
|
||||||
|
input_tokens = input_tokens + [[0]] * num_paddings
|
||||||
|
input_positions = input_positions + [[0]] * num_paddings
|
||||||
|
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
|
||||||
|
context_lens = context_lens + [0] * num_paddings
|
||||||
|
|
||||||
|
input_tokens = jnp.asarray(input_tokens, dtype=jnp.int32)
|
||||||
|
input_positions = jnp.asarray(input_positions, dtype=jnp.int32)
|
||||||
|
slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
|
||||||
|
context_lens = jnp.asarray(context_lens, dtype=jnp.int32)
|
||||||
|
|
||||||
|
block_tables = jnp.asarray(self.block_tables[:batch_size],
|
||||||
|
dtype=jnp.int32)
|
||||||
|
input_lens = jnp.asarray([1] * batch_size, dtype=jnp.int32)
|
||||||
|
return (input_tokens, input_positions, slot_mapping, block_tables,
|
||||||
|
context_lens, input_lens)
|
||||||
|
|
||||||
|
def prepare_input_arrays(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||||
|
):
|
||||||
|
# NOTE: We assume that all sequences in the group are all prompts or
|
||||||
|
# all decodes.
|
||||||
|
is_prompt = seq_group_metadata_list[0].is_prompt
|
||||||
|
# Prepare input tensors.
|
||||||
|
if is_prompt:
|
||||||
|
return self._prepare_prompt(seq_group_metadata_list)
|
||||||
|
else:
|
||||||
|
return self._prepare_decode(seq_group_metadata_list)
|
||||||
|
|
||||||
|

    def _execute_step(
        self,
        params: Dict[str, Any],
        token_ids: jax.Array,
        position_ids: jax.Array,
        slot_mapping: jax.Array,
        block_tables: Optional[jax.Array],
        context_lens: Optional[jax.Array],
        input_lens: jax.Array,
        kv_caches: List[jax.Array],
    ) -> Tuple[jax.Array, List[jax.Array]]:
        batch_size, seq_len = token_ids.shape
        # Index of the last real (non-padding) token of each sequence in the
        # flattened (batch_size * seq_len) logits.
        base_indices = jnp.arange(batch_size, dtype=jnp.int32) * seq_len
        logits_indices = base_indices + input_lens - 1

        logits, new_kv_caches = self.model.apply(
            params,
            token_ids,
            position_ids,
            slot_mapping,
            block_tables,
            context_lens,
            kv_caches,
            logits_indices,
        )
        # TODO(woosuk): Support sampling with temperature and top_p.
        next_token_ids = jnp.argmax(logits, axis=-1)
        return next_token_ids, new_kv_caches
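
(A rough, self-contained illustration of the index arithmetic above; the batch shape and lengths are assumptions for the example, not values from this diff.)

import jax.numpy as jnp

# Hypothetical padded batch: 2 sequences padded to seq_len = 8, true lengths 5 and 3.
batch_size, seq_len = 2, 8
input_lens = jnp.asarray([5, 3], dtype=jnp.int32)
base_indices = jnp.arange(batch_size, dtype=jnp.int32) * seq_len   # [0, 8]
logits_indices = base_indices + input_lens - 1                     # [4, 10]
# Rows 4 and 10 of the flattened (batch_size * seq_len, vocab) logits are the
# last real tokens of sequence 0 and sequence 1, respectively.
print(logits_indices)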

    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[Tuple[jax.Array, jax.Array]],
    ) -> Tuple[Optional[SamplerOutput], List[Tuple[jax.Array, jax.Array]]]:
        from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob

        start = time.time()
        inputs = self.prepare_input_arrays(seq_group_metadata_list)
        end = time.time()
        # print(inputs[0].shape)
        # print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")

        start = time.time()
        next_token_ids, new_kv_caches = self.compiled_fn(
            self.params, *inputs, kv_caches)
        next_token_ids.block_until_ready()
        end = time.time()
        # print(f"compiled_fn: {(end - start) * 1000:.2f} ms")

        start = time.time()
        next_token_ids = jax.device_put(next_token_ids, self.cpu_device)
        end = time.time()
        # print(f"jax.device_put: {(end - start) * 1000:.2f} ms")

        next_token_ids = next_token_ids.tolist()
        i = 0

        sampler_outputs = []
        for seq_group_metadata in seq_group_metadata_list:
            seq_outputs = []

            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                next_token_id = next_token_ids[i]
                seq_outputs.append(
                    SequenceOutput(seq_id, next_token_id,
                                   {next_token_id: Logprob(0.0)}))
                i += 1

            sampler_outputs.append(SequenceGroupOutput(seq_outputs, None))
        return SamplerOutput(sampler_outputs), new_kv_caches

def _make_array_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: jnp.dtype,
) -> jax.Array:
    padded_x = [pad_to_max_length(x_i, max_len, pad) for x_i in x]
    return jnp.asarray(padded_x, dtype)

def _get_padded_prefill_len(x: int) -> int:
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. We pad the prompt length up to the next
    # power of two (with a minimum of 16), which is always a multiple of 16
    # and is also good for performance.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()

def _get_padded_batch_size(batch_size: int) -> int:
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    elif batch_size <= 8:
        return 8
    else:
        return ((batch_size + 15) // 16) * 16
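
(A quick hand-computed sanity check of the two padding helpers above, as a sketch; it assumes both functions are in scope.)

# Prefill lengths are padded to the next power of two, with a floor of 16.
assert _get_padded_prefill_len(7) == 16
assert _get_padded_prefill_len(17) == 32
assert _get_padded_prefill_len(100) == 128
# Batch sizes are padded to 1-2, 4, 8, or the next multiple of 16.
assert _get_padded_batch_size(3) == 4
assert _get_padded_batch_size(9) == 16
assert _get_padded_batch_size(33) == 48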

import functools
from typing import Any, Mapping

import orbax.checkpoint

Params = Mapping[str, Any]

def load_and_format_params(path: str) -> Params:
    """Loads parameters and formats them for compatibility."""
    params = load_params(path)
    param_state = jax.tree_util.tree_map(jnp.array, params)
    remapped_params = param_remapper(param_state)
    nested_params = nest_params(remapped_params)
    return nested_params

@functools.cache
def load_params(path: str) -> Params:
    """Loads parameters from a checkpoint path."""
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    params = checkpointer.restore(path)
    return params

def param_remapper(orig_params: Params) -> Params:
    """Remaps params to new module layout.

    This is needed here because the model definition does not have a separate
    `mlp` module.

    Args:
      orig_params: original dict of parameters in Gemma format.

    Returns:
      dict of params with different names.
    """
    new_params = {}
    for k, v in orig_params.items():
        if 'mlp/' in k:
            layer_name, param = k.rsplit('/', maxsplit=1)
            if layer_name not in new_params:
                new_params[layer_name] = {}
            if 'w' in v:
                new_params[layer_name][param] = v['w']
        else:
            new_params[k] = v
    return new_params

def nest_params(params: Params) -> Params:
    """Nests params as a dict of dicts rather than a flat dict."""
    nested_params = {}
    for path, param in params.items():
        *path, leaf = path.split('/')
        subdict = nested_params
        for key in path:
            subdict = subdict.setdefault(key, {})
        subdict[leaf] = param
    return nested_params
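
(To make the two checkpoint transforms above concrete, a small sketch with invented parameter names; these are not entries from a real Gemma checkpoint. It assumes both functions above are in scope.)

flat = {
    'transformer/layer_0/mlp/linear': {'w': [[1.0]]},
    'transformer/layer_0/attn/q_proj': [[2.0]],
}
remapped = param_remapper(flat)
# {'transformer/layer_0/mlp': {'linear': [[1.0]]},
#  'transformer/layer_0/attn/q_proj': [[2.0]]}
nested = nest_params(remapped)
# {'transformer': {'layer_0': {'mlp': {'linear': [[1.0]]},
#                              'attn': {'q_proj': [[2.0]]}}}}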
155
vllm/worker/tpu_worker.py
Normal file
155
vllm/worker/tpu_worker.py
Normal file
@ -0,0 +1,155 @@
import os
from typing import Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import torch

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.tpu_model_runner import TPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.utils import get_dtype_size, STR_DTYPE_TO_TORCH_DTYPE

logger = init_logger(__name__)

class TPUWorker(LoraNotSupportedWorkerBase):
    """A worker class that executes (a partition of) the model on a TPU device.

    Each worker is associated with a single TPU device. The worker is
    responsible for maintaining the KV cache and executing the model on the
    TPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        vision_language_config: Optional[VisionLanguageConfig],
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.vision_language_config = vision_language_config
        assert self.device_config.device_type == "tpu"

        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        self.model_runner = TPUModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            vision_language_config=vision_language_config)
        self.tpu_cache = None

    def init_device(self) -> None:
        # Set random seed.
        # TODO: Set random seed for JAX
        set_random_seed(self.model_config.seed)

        # Use persistent cache to avoid recompilation.
        jax.config.update("jax_compilation_cache_dir",
                          os.path.expanduser("~/.vllm/jax_cache"))

        # DELETE
        # from jax_smi import initialise_tracking
        # initialise_tracking()

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        num_tpu_blocks = 2000
        return num_tpu_blocks, 0

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self.block_size = self.cache_config.block_size

        dtype = _torch_dtype_to_jax(self.cache_dtype)
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()

        self.tpu_cache = []
        for _ in range(num_layers):
            key_cache = jnp.zeros(
                (num_kv_heads, num_gpu_blocks * self.block_size, head_size),
                dtype=dtype)
            value_cache = jnp.zeros_like(key_cache)
            self.tpu_cache.append((key_cache, value_cache))
        self.model_runner.block_size = self.block_size
        self._warmup_model()
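
(For reference, a rough, self-contained sketch of the per-layer cache shapes the allocation above produces; the config numbers are assumptions for illustration, not values from this diff.)

import jax.numpy as jnp

# Hypothetical config: 8 KV heads, 2000 blocks of 16 slots each, head size 128.
num_kv_heads, num_blocks, block_size, head_size = 8, 2000, 16, 128
key_cache = jnp.zeros((num_kv_heads, num_blocks * block_size, head_size),
                      dtype=jnp.bfloat16)
value_cache = jnp.zeros_like(key_cache)
print(key_cache.shape)   # (8, 32000, 128); one (key, value) pair per layer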

    def _warmup_model(self) -> None:
        # NOTE(woosuk): Because of buffer donation, the reference to the cache
        # should be updated after the warmup.
        self.tpu_cache = self.model_runner.warmup_model(self.tpu_cache)

    def get_cache_block_size_bytes(self) -> int:
        head_size = self.model_config.get_head_size()
        num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        key_cache_block = self.cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.cache_dtype)
        return dtype_size * total
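
(A hand-worked instance of the block-size formula above; the model dimensions are illustrative assumptions, not values from this diff.)

# Assume block_size = 16, 8 KV heads, head size 128, 32 layers, bfloat16 (2 bytes).
block_size, num_heads, head_size, num_layers, dtype_size = 16, 8, 128, 32, 2
key_cache_block = block_size * num_heads * head_size        # 16384 elements
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)  # 1048576 elements
print(dtype_size * total)                                   # 2097152 bytes (2 MiB) per block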

    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
    ) -> Optional[SamplerOutput]:
        assert seq_group_metadata_list is not None
        num_seq_groups = len(seq_group_metadata_list)
        assert blocks_to_swap_in is not None
        assert blocks_to_swap_out is not None
        assert blocks_to_copy is not None

        # Currently, TPUWorker does not support swapping.
        # TODO(woosuk): Support block copying.
        assert len(blocks_to_swap_in) == 0
        assert len(blocks_to_swap_out) == 0
        assert len(blocks_to_copy) == 0

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return {}

        output, kv_caches = self.model_runner.execute_model(
            seq_group_metadata_list, self.tpu_cache)
        self.tpu_cache = kv_caches
        return output

def _torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
    mapping = {
        torch.float32: jnp.float32,
        torch.float16: jnp.float16,
        torch.bfloat16: jnp.bfloat16,
    }
    return mapping[dtype]