Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: add-nixl-t ... v0.10.0rc1 (7 commits)

- d1fb65bde3
- 3a1d8940ae
- 2b504eb770
- 10eb24cc91
- 2e8cbb58f3
- 752c6ade2e
- 881e3cbe3b
@ -108,7 +108,6 @@ fi
if [[ $commands == *" kernels/attention"* ]]; then
commands="${commands} \
--ignore=kernels/attention/test_attention_selector.py \
--ignore=kernels/attention/test_blocksparse_attention.py \
--ignore=kernels/attention/test_encoder_decoder_attn.py \
--ignore=kernels/attention/test_flash_attn.py \
--ignore=kernels/attention/test_flashinfer.py \
@ -264,6 +264,7 @@ steps:
- pytest -v -s v1/structured_output
- pytest -v -s v1/spec_decode
- pytest -v -s v1/kv_connector/unit
- pytest -v -s v1/metrics
- pytest -v -s v1/test_serial_utils.py
- pytest -v -s v1/test_utils.py
- pytest -v -s v1/test_oracle.py
@ -576,7 +576,11 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
elif config.architectures[0] in (
"DeepseekV3ForCausalLM",
"DeepseekV2ForCausalLM",
"Glm4MoeForCausalLM",
):
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
@ -318,6 +318,7 @@ def main(args: argparse.Namespace):
elif (
config.architectures[0] == "DeepseekV3ForCausalLM"
or config.architectures[0] == "DeepseekV2ForCausalLM"
or config.architectures[0] == "Glm4MoeForCausalLM"
):
E = config.n_routed_experts
topk = config.num_experts_per_tok
@ -376,7 +376,6 @@ Specified using `--task generate`.
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ |
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/Phi-4-mini-flash-reasoning`, etc. | | | |
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
@ -577,6 +576,7 @@ Specified using `--task generate`.
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `THUDM/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4MoeForCausalLM` | GLM-4.5 | T + I<sup>E+</sup> + V<sup>E+</sup> | `THUDM/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
@ -107,12 +107,11 @@ to enable simultaneous generation and embedding using the same engine instance i
Models using selective state-space mechanisms instead of standard transformer attention are partially supported.
Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers
(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet supported. Please note that these models currently require
enforcing eager mode and disabling prefix caching in V1.
disabling prefix caching in V1.

Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention
backend in V1.
these models currently require disabling prefix caching and using the FlashInfer attention backend in V1.
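As a minimal illustration of the constraint above (the checkpoint name is only an example of a hybrid Mamba-2 + attention model; the attention-backend environment variable and the `enable_prefix_caching` flag are the same knobs the tests in this change exercise), offline inference could be configured roughly as follows:

```python
import os

from vllm import LLM, SamplingParams

# Select the FlashInfer attention backend before the engine is created
# (same mechanism used by the V1 hybrid-model tests in this diff).
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

llm = LLM(
    model="ibm-ai-platform/Bamba-9B",  # illustrative hybrid Mamba-2 + attention checkpoint
    enable_prefix_caching=False,       # prefix caching is not yet supported for these models
)

params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```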

#### Encoder-Decoder Models

@ -15,15 +15,18 @@ import pytest
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_NAMES = [
|
||||
"Qwen/Qwen2-1.5B-Instruct",
|
||||
"Qwen/Qwen3-1.7B",
|
||||
"google/gemma-3-1b-it",
|
||||
]
|
||||
FP8_KV_MODEL_NAMES = [
|
||||
"Qwen/Qwen3-1.7B",
|
||||
]
|
||||
NUM_CONCURRENT = 500
|
||||
TASK = "gsm8k"
|
||||
FILTER = "exact_match,strict-match"
|
||||
RTOL = 0.03
|
||||
EXPECTED_VALUES = {
|
||||
"Qwen/Qwen2-1.5B-Instruct": 0.58,
|
||||
"Qwen/Qwen3-1.7B": 0.68,
|
||||
"google/gemma-3-1b-it": 0.25,
|
||||
}
|
||||
|
||||
@ -70,10 +73,9 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
|
||||
if current_platform.is_tpu():
|
||||
# Limit compilation time for TPU V1
|
||||
|
||||
if model == "google/gemma-3-1b-it":
|
||||
# TPU + google/gemma-3-1b-it + xet doesn't work well.
|
||||
m.setenv("HF_HUB_DISABLE_XET", "1")
|
||||
|
||||
# xet doesn't work well for both Qwen/Qwen3-1.7B and
|
||||
# google/gemma-3-1b-it
|
||||
m.setenv("HF_HUB_DISABLE_XET", "1")
|
||||
more_args = "max_model_len=2048,max_num_seqs=64"
|
||||
|
||||
# Add TP test (if provided)
|
||||
@ -83,9 +85,27 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
|
||||
run_test(model, more_args)
|
||||
|
||||
|
||||
def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Run with the V0 Engine."""
|
||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
||||
and not current_platform.is_tpu(),
|
||||
reason="V1 is currently only supported on CUDA and TPU")
|
||||
@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES)
|
||||
def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
|
||||
model, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Run with the V1 Engine."""
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
run_test("Qwen/Qwen2-1.5B-Instruct")
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
more_args = None
|
||||
if current_platform.is_tpu():
|
||||
# Limit compilation time for TPU V1
|
||||
|
||||
# xet doesn't work well for Qwen/Qwen3-1.7B
|
||||
m.setenv("HF_HUB_DISABLE_XET", "1")
|
||||
more_args = "max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8"
|
||||
|
||||
# Add TP test (if provided)
|
||||
if TPU_TP_TEST_STR:
|
||||
more_args += ",{}".format(TPU_TP_TEST_STR)
|
||||
|
||||
run_test(model, more_args)
|
||||
|
@ -1,441 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops.blocksparse_attention.interface import (
|
||||
LocalStridedBlockSparseAttn)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import get_max_shared_memory_bytes
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
# This will change depending on the compute capability.
|
||||
# - 512 as a buffer
|
||||
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
|
||||
# MAX_SEQ_LEN = 2771
|
||||
|
||||
# There may not be enough gpu memory due to large NUM_BLOCKS.
|
||||
# Reduce NUM_BLOCKS when it happens.
|
||||
NUM_BLOCKS = 4321 # Arbitrary values for testing
|
||||
PARTITION_SIZE = 512
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
NUM_GEN_SEQS = [3] # Arbitrary values for testing
|
||||
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
|
||||
NUM_HEADS = [(40, 40)] # Arbitrary values for testing
|
||||
|
||||
HEAD_SIZES = [64, 112]
|
||||
BLOCK_SIZES = [16]
|
||||
USE_ALIBI = [False, True]
|
||||
KV_CACHE_DTYPE = ["auto", "fp8"]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = ['cuda:0']
|
||||
BLOCKSPARSE_LOCAL_BLOCKS = [16]
|
||||
BLOCKSPARSE_VERT_STRIDES = [8]
|
||||
|
||||
BLOCKSPARSE_BLOCK_SIZES = [64]
|
||||
BLOCKSPARSE_HEADS_SLIDINGS = [2, -1]
|
||||
BLOCKSPARSE_HOMO_HEADS = [True, False]
|
||||
|
||||
|
||||
def ref_masked_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
|
||||
if attn_mask is not None:
|
||||
attn_weights = attn_weights + attn_mask.float()
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
||||
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
|
||||
return out
|
||||
|
||||
|
||||
def ref_single_query_cached_kv_attention(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
num_queries_per_kv: int,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
tp_rank: int = 0,
|
||||
blocksparse_local_blocks: int = 0,
|
||||
blocksparse_vert_stride: int = 1,
|
||||
blocksparse_block_size: int = 64,
|
||||
blocksparse_head_sliding_step: int = 0,
|
||||
) -> None:
|
||||
num_query_heads = query.shape[1]
|
||||
num_kv_heads = value_cache.shape[1]
|
||||
head_size = value_cache.shape[2]
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs = query.shape[0]
|
||||
|
||||
block_tables_lst = block_tables.cpu().tolist()
|
||||
seq_lens_lst = seq_lens.cpu().tolist()
|
||||
for i in range(num_seqs):
|
||||
q = query[i].unsqueeze(0)
|
||||
block_table = block_tables_lst[i]
|
||||
seq_len = int(seq_lens_lst[i])
|
||||
|
||||
keys_lst: list[torch.Tensor] = []
|
||||
values_lst: list[torch.Tensor] = []
|
||||
for j in range(seq_len):
|
||||
block_number = int(block_table[j // block_size])
|
||||
block_offset = j % block_size
|
||||
|
||||
k = key_cache[block_number, :, :, block_offset, :]
|
||||
k = k.reshape(num_kv_heads, head_size)
|
||||
keys_lst.append(k)
|
||||
|
||||
v = value_cache[block_number, :, :, block_offset]
|
||||
values_lst.append(v)
|
||||
keys = torch.stack(keys_lst, dim=0)
|
||||
values = torch.stack(values_lst, dim=0)
|
||||
if num_queries_per_kv > 1:
|
||||
# Handle MQA and GQA
|
||||
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
|
||||
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
|
||||
|
||||
alibi_bias = None
|
||||
if alibi_slopes is not None:
|
||||
# Create the ALiBi bias used in the paged attention kernel.
|
||||
position_ids = torch.arange(seq_len).int()
|
||||
alibi_bias = (position_ids - seq_len + 1).float()
|
||||
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
|
||||
1, 1, -1)
|
||||
|
||||
if blocksparse_vert_stride >= 1:
|
||||
bsize = blocksparse_block_size
|
||||
hsliding = blocksparse_head_sliding_step
|
||||
vert = blocksparse_vert_stride
|
||||
locals = blocksparse_local_blocks
|
||||
qb = (seq_len - 1) // bsize
|
||||
attn_mask = q.new_zeros(
|
||||
(num_query_heads, 1, seq_len)).float() - torch.inf
|
||||
for h in range(num_query_heads):
|
||||
if hsliding >= 0: # slide with q heads
|
||||
bs_offset = (tp_rank * num_query_heads + h) * hsliding + 1
|
||||
else: # slide with kv heads
|
||||
bs_offset = (tp_rank * num_kv_heads +
|
||||
h // num_queries_per_kv) * (-hsliding) + 1
|
||||
for kb in range(qb + 1):
|
||||
kj = kb * bsize
|
||||
if (qb - kb) < locals or \
|
||||
(kb + bs_offset) % vert == 0:
|
||||
attn_mask[h, 0, kj:min(kj + bsize, seq_len)] = 0
|
||||
if alibi_bias is not None:
|
||||
attn_mask += alibi_bias
|
||||
else:
|
||||
attn_mask = alibi_bias
|
||||
|
||||
out = ref_masked_attention(q, keys, values, scale, attn_mask=attn_mask)
|
||||
out = out.view(num_query_heads, head_size)
|
||||
output[i].copy_(out, non_blocking=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("version", ["v1", "v2"])
|
||||
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS)
|
||||
@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES)
|
||||
@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("blocksparse_head_sliding_step",
|
||||
BLOCKSPARSE_HEADS_SLIDINGS)
|
||||
def test_paged_attention(
|
||||
kv_cache_factory,
|
||||
version: str,
|
||||
num_seqs: int,
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
use_alibi: bool,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
seed: int,
|
||||
device: str,
|
||||
blocksparse_local_blocks: int,
|
||||
blocksparse_vert_stride: int,
|
||||
blocksparse_block_size: int,
|
||||
blocksparse_head_sliding_step: int,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.rand(num_query_heads, dtype=torch.float)
|
||||
|
||||
seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
seq_lens[-1] = MAX_SEQ_LEN
|
||||
max_seq_len = max(seq_lens)
|
||||
seq_lens = torch.tensor(seq_lens, dtype=torch.int)
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int)
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
|
||||
num_kv_heads, head_size,
|
||||
kv_cache_dtype, dtype, seed,
|
||||
device)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
tp_rank = 0
|
||||
|
||||
# Call the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
if version == "v1":
|
||||
ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
tp_rank=tp_rank,
|
||||
blocksparse_local_blocks=blocksparse_local_blocks,
|
||||
blocksparse_vert_stride=blocksparse_vert_stride,
|
||||
blocksparse_block_size=blocksparse_block_size,
|
||||
blocksparse_head_sliding_step=blocksparse_head_sliding_step,
|
||||
)
|
||||
elif version == "v2":
|
||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||
assert PARTITION_SIZE % block_size == 0
|
||||
num_seqs, num_heads, head_size = output.shape
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, num_partitions),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
tp_rank=tp_rank,
|
||||
blocksparse_local_blocks=blocksparse_local_blocks,
|
||||
blocksparse_vert_stride=blocksparse_vert_stride,
|
||||
blocksparse_block_size=blocksparse_block_size,
|
||||
blocksparse_head_sliding_step=blocksparse_head_sliding_step,
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"Unknown version: {version}")
|
||||
|
||||
# Run the reference implementation.
|
||||
if kv_cache_dtype == "fp8":
|
||||
# Convert cache data back to dtype.
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
|
||||
block_size, x)
|
||||
dequantized_key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
ops.convert_fp8(dequantized_key_cache, key_cache)
|
||||
key_cache = dequantized_key_cache
|
||||
|
||||
value_cache_shape = value_cache.shape
|
||||
dequantized_value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
ops.convert_fp8(dequantized_value_cache, value_cache)
|
||||
value_cache = dequantized_value_cache
|
||||
|
||||
ref_output = torch.empty_like(query)
|
||||
ref_single_query_cached_kv_attention(
|
||||
ref_output,
|
||||
query,
|
||||
num_queries_per_kv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
scale,
|
||||
alibi_slopes,
|
||||
tp_rank,
|
||||
blocksparse_local_blocks,
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Due to the kernel-level differences in the two
|
||||
# implementations, there is a small numerical difference in the two
|
||||
# outputs. Thus, we use a relaxed tolerance for the test.
|
||||
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
|
||||
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
|
||||
|
||||
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
|
||||
# so we use a relaxed tolerance for the test.
|
||||
atol, rtol = 1e-3, 1e-5
|
||||
if kv_cache_dtype == "fp8":
|
||||
atol, rtol = 1e-2, 1e-5
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
def ref_multi_query_kv_attention(
|
||||
cu_seq_lens: list[int],
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = len(cu_seq_lens) - 1
|
||||
ref_outputs = []
|
||||
for i in range(num_seqs):
|
||||
start_idx = cu_seq_lens[i]
|
||||
end_idx = cu_seq_lens[i + 1]
|
||||
seq_len = end_idx - start_idx
|
||||
|
||||
# Create attention mask.
|
||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
|
||||
diagonal=1)
|
||||
attn_mask = attn_mask * torch.finfo(dtype).min
|
||||
attn_mask = attn_mask.to(dtype=dtype)
|
||||
|
||||
ref_output = ref_masked_attention(
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
scale,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
ref_outputs.append(ref_output)
|
||||
ref_output = torch.cat(ref_outputs, dim=0)
|
||||
return ref_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS)
|
||||
@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES)
|
||||
@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("blocksparse_homo_heads", BLOCKSPARSE_HOMO_HEADS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_varlen_blocksparse_attention_prefill(
|
||||
num_seqs: int,
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
blocksparse_local_blocks: int,
|
||||
blocksparse_vert_stride: int,
|
||||
blocksparse_block_size: int,
|
||||
blocksparse_homo_heads: bool,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
|
||||
# As the xformers library is already tested with its own tests, we can use
|
||||
# a smaller MAX_SEQ_LEN here.
|
||||
max_len = min(MAX_SEQ_LEN, 4096)
|
||||
seq_lens = random.sample(range(1, max_len), num_seqs)
|
||||
cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0)
|
||||
num_tokens = sum(seq_lens)
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
|
||||
qkv = torch.empty(num_tokens,
|
||||
num_query_heads + 2 * num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
qkv.uniform_(-scale, scale)
|
||||
query, key, value = qkv.split(
|
||||
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
|
||||
|
||||
bs_attn_op = LocalStridedBlockSparseAttn(
|
||||
num_query_heads,
|
||||
max_len,
|
||||
local_blocks=blocksparse_local_blocks,
|
||||
vert_stride=blocksparse_vert_stride,
|
||||
block_size=blocksparse_block_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
homo_head=blocksparse_homo_heads)
|
||||
|
||||
output = bs_attn_op(query,
|
||||
key,
|
||||
value,
|
||||
cu_seq_lens.to(device),
|
||||
sm_scale=scale)
|
||||
|
||||
if num_queries_per_kv > 1:
|
||||
# Handle MQA and GQA
|
||||
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
|
||||
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
|
||||
|
||||
ref_output = ref_multi_query_kv_attention(
|
||||
cu_seq_lens.tolist(),
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
scale,
|
||||
dtype,
|
||||
)
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
|
@ -33,8 +33,12 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
# change the attention backend to triton MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
|
||||
False, True)
|
||||
backend = get_attn_backend(576,
|
||||
torch.bfloat16,
|
||||
"auto",
|
||||
16,
|
||||
False,
|
||||
use_mla=True)
|
||||
assert (backend.get_name() == "TRITON_MLA"
|
||||
or backend.get_name() == "TRITON_MLA_VLLM_V1")
|
||||
|
||||
@ -42,15 +46,23 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
# If use_mla is true
|
||||
# The selected backend is triton MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, None)
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
|
||||
False, True)
|
||||
backend = get_attn_backend(576,
|
||||
torch.bfloat16,
|
||||
"auto",
|
||||
16,
|
||||
False,
|
||||
use_mla=True)
|
||||
assert (backend.get_name() == "TRITON_MLA"
|
||||
or backend.get_name() == "TRITON_MLA_VLLM_V1")
|
||||
|
||||
# change the attention backend to AITER MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
|
||||
False, True)
|
||||
backend = get_attn_backend(576,
|
||||
torch.bfloat16,
|
||||
"auto",
|
||||
1,
|
||||
False,
|
||||
use_mla=True)
|
||||
assert (backend.get_name() == "ROCM_AITER_MLA"
|
||||
or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
|
||||
|
||||
@ -60,7 +72,11 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
# The selected backend is ROCM_AITER_MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, None)
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
|
||||
False, True)
|
||||
backend = get_attn_backend(576,
|
||||
torch.bfloat16,
|
||||
"auto",
|
||||
1,
|
||||
False,
|
||||
use_mla=True)
|
||||
assert (backend.get_name() == "ROCM_AITER_MLA"
|
||||
or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
|
||||
|
@ -104,7 +104,6 @@ def test_models(
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enforce_eager=True,
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
@ -247,10 +247,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
# Blocksparse attention not supported in V1 yet
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
trust_remote_code=True,
v0_only=True),
"Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
trust_remote_code=True,
v0_only=True,
@ -364,6 +360,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking", min_transformers_version="4.53"), # noqa: E501
"Glm4MoeForCausalLM": _HfExamplesInfo("THUDM/GLM-4.5",
min_transformers_version="4.54",
is_available_online=False), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
@ -489,6 +488,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
is_available_online=False,
speculative_model="openbmb/MiniCPM-2B-sft-bf16",
tokenizer="openbmb/MiniCPM-2B-sft-bf16"),
"Glm4MoeMTPModel": _HfExamplesInfo("THUDM/GLM-4.5",
speculative_model="THUDM/GLM-4.5",
min_transformers_version="4.54",
is_available_online=False),
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL")
tests/tool_use/test_glm4_moe_tool_parser.py (new file, 410 lines)
@ -0,0 +1,410 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
||||
from vllm.entrypoints.openai.tool_parsers import Glm4MoeModelToolParser
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
pytest.skip("skip glm4_moe parser test", allow_module_level=True)
|
||||
# Use a common model that is likely to be available
|
||||
MODEL = "THUDM/GLM-4.5"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def glm4_moe_tokenizer():
|
||||
return get_tokenizer(tokenizer_name=MODEL)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def glm4_moe_tool_parser(glm4_moe_tokenizer):
|
||||
return Glm4MoeModelToolParser(glm4_moe_tokenizer)
|
||||
|
||||
|
||||
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
expected_tool_calls: list[ToolCall]):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||
expected_tool_calls):
|
||||
assert isinstance(actual_tool_call.id, str)
|
||||
assert len(actual_tool_call.id) > 0
|
||||
|
||||
assert actual_tool_call.type == "function"
|
||||
assert actual_tool_call.function.name == expected_tool_call.function.name
|
||||
# Compare arguments as JSON objects to handle formatting differences
|
||||
actual_args = json.loads(actual_tool_call.function.arguments)
|
||||
expected_args = json.loads(expected_tool_call.function.arguments)
|
||||
assert actual_args == expected_args
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
|
||||
model_output = "This is a test"
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"single_tool_call",
|
||||
"multiple_tool_calls",
|
||||
"tool_call_with_content_before",
|
||||
"tool_call_with_mixed_args",
|
||||
"tool_call_with_chinese_content",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""<tool_call>get_current_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Dallas</arg_value>
|
||||
<arg_key>state</arg_key>
|
||||
<arg_value>TX</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>fahrenheit</arg_value>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""<tool_call>get_current_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Dallas</arg_value>
|
||||
<arg_key>state</arg_key>
|
||||
<arg_value>TX</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>fahrenheit</arg_value>
|
||||
</tool_call>
|
||||
<tool_call>get_current_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Orlando</arg_value>
|
||||
<arg_key>state</arg_key>
|
||||
<arg_value>FL</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>fahrenheit</arg_value>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
)),
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""I'll help you check the weather. <tool_call>get_current_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Seattle</arg_value>
|
||||
<arg_key>state</arg_key>
|
||||
<arg_value>WA</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>celsius</arg_value>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Seattle",
|
||||
"state": "WA",
|
||||
"unit": "celsius",
|
||||
}),
|
||||
))
|
||||
],
|
||||
"I'll help you check the weather.",
|
||||
),
|
||||
(
|
||||
"""<tool_call>get_current_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>New York</arg_value>
|
||||
<arg_key>state</arg_key>
|
||||
<arg_value>NY</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>celsius</arg_value>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "New York",
|
||||
"state": "NY",
|
||||
"unit": "celsius",
|
||||
}),
|
||||
))
|
||||
],
|
||||
None,
|
||||
),
|
||||
("""I will help you get the weather.<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
<arg_key>date</arg_key>
|
||||
<arg_value>2025-08-01</arg_value>
|
||||
</tool_call>""", [
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Beijing",
|
||||
"date": "2025-08-01",
|
||||
}),
|
||||
))
|
||||
], "I will help you get the weather."),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(glm4_moe_tool_parser, model_output,
|
||||
expected_tool_calls, expected_content):
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_thinking_tags(glm4_moe_tool_parser):
|
||||
"""Test tool extraction when thinking tags are present."""
|
||||
model_output = """<think>I want to get the weather.</think>
|
||||
|
||||
I will help you get the weather.
|
||||
<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
<arg_key>date</arg_key>
|
||||
<arg_value>2025-08-01</arg_value>
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
assert extracted_tool_calls.tool_calls[0].function.name == "get_weather"
|
||||
|
||||
expected_content = """<think>I want to get the weather.</think>
|
||||
|
||||
I will help you get the weather."""
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
def test_extract_tool_calls_malformed_xml(glm4_moe_tool_parser):
|
||||
"""Test that malformed XML is handled gracefully."""
|
||||
model_output = """<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Seattle</arg_value>
|
||||
<arg_key>incomplete_arg
|
||||
<arg_value>value</arg_value>
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
# Should handle malformed XML gracefully
|
||||
# The parser should either extract what it can or return no tool calls
|
||||
# depending on how robust we want the parsing to be
|
||||
assert isinstance(extracted_tool_calls.tools_called, bool)
|
||||
assert isinstance(extracted_tool_calls.tool_calls, list)
|
||||
|
||||
|
||||
def test_extract_tool_calls_empty_arguments(glm4_moe_tool_parser):
|
||||
"""Test tool calls with no arguments."""
|
||||
model_output = """<tool_call>get_current_time
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
assert extracted_tool_calls.tool_calls[
|
||||
0].function.name == "get_current_time"
|
||||
# Empty arguments should result in empty JSON object
|
||||
assert extracted_tool_calls.tool_calls[0].function.arguments == "{}"
|
||||
|
||||
|
||||
def test_extract_tool_calls_mixed_content(glm4_moe_tool_parser):
|
||||
"""Test extraction with mixed content and multiple tool calls."""
|
||||
model_output = """I will help you get the weather info.
|
||||
|
||||
<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
<arg_key>date</arg_key>
|
||||
<arg_value>2025-08-01</arg_value>
|
||||
</tool_call>
|
||||
|
||||
meanwhile, I will also check the weather in Shanghai.
|
||||
|
||||
<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Shanghai</arg_value>
|
||||
<arg_key>date</arg_key>
|
||||
<arg_value>2025-08-01</arg_value>
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 2
|
||||
|
||||
# Check first tool call
|
||||
assert extracted_tool_calls.tool_calls[0].function.name == "get_weather"
|
||||
args1 = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||
assert args1["city"] == "Beijing"
|
||||
assert args1["date"] == "2025-08-01"
|
||||
|
||||
# Check second tool call
|
||||
assert extracted_tool_calls.tool_calls[1].function.name == "get_weather"
|
||||
args2 = json.loads(extracted_tool_calls.tool_calls[1].function.arguments)
|
||||
assert args2["city"] == "Shanghai"
|
||||
assert args2["date"] == "2025-08-01"
|
||||
|
||||
# Content should be everything before the first tool call
|
||||
assert extracted_tool_calls.content == "I will help you get the weather info."
|
||||
|
||||
|
||||
def test_streaming_basic_functionality(glm4_moe_tool_parser):
|
||||
"""Test basic streaming functionality."""
|
||||
# Reset streaming state
|
||||
glm4_moe_tool_parser.current_tool_name_sent = False
|
||||
glm4_moe_tool_parser.prev_tool_call_arr = []
|
||||
glm4_moe_tool_parser.current_tool_id = -1
|
||||
glm4_moe_tool_parser.streamed_args_for_tool = []
|
||||
|
||||
# Test with a simple tool call
|
||||
current_text = """<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
</tool_call>"""
|
||||
|
||||
# Mock token IDs for testing
|
||||
tool_call_start_id = glm4_moe_tool_parser.tool_call_start_token_id or 12345
|
||||
tool_call_end_id = glm4_moe_tool_parser.tool_call_end_token_id or 12346
|
||||
|
||||
result = glm4_moe_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text=current_text,
|
||||
delta_text="</tool_call>",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[tool_call_start_id, tool_call_end_id],
|
||||
delta_token_ids=[tool_call_end_id],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# The result behavior depends on the streaming state
|
||||
# This test mainly ensures no exceptions are thrown
|
||||
assert result is None or hasattr(result, 'tool_calls') or hasattr(
|
||||
result, 'content')
|
||||
|
||||
|
||||
def test_streaming_no_tool_calls(glm4_moe_tool_parser):
|
||||
"""Test streaming when there are no tool calls."""
|
||||
current_text = "This is just regular text without any tool calls."
|
||||
|
||||
result = glm4_moe_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="This is just regular text",
|
||||
current_text=current_text,
|
||||
delta_text=" without any tool calls.",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# Should return the delta text as content
|
||||
assert result is not None
|
||||
assert hasattr(result, 'content')
|
||||
assert result.content == " without any tool calls."
|
||||
|
||||
|
||||
def test_streaming_with_content_before_tool_calls(glm4_moe_tool_parser):
|
||||
"""Test streaming when there's content before tool calls."""
|
||||
# Reset streaming state
|
||||
glm4_moe_tool_parser.current_tool_name_sent = False
|
||||
glm4_moe_tool_parser.prev_tool_call_arr = []
|
||||
glm4_moe_tool_parser.current_tool_id = -1
|
||||
glm4_moe_tool_parser.streamed_args_for_tool = []
|
||||
|
||||
current_text = "I will help you get the weather<tool_call>"
|
||||
|
||||
result = glm4_moe_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="I will help you",
|
||||
current_text=current_text,
|
||||
delta_text="get the weather.<tool_call>",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# Should return content when no tool call tokens are detected
|
||||
assert result is not None
|
||||
assert hasattr(result, 'content')
|
||||
assert result.content == "get the weather.<tool_call>"
|
||||
|
||||
|
||||
def test_extract_tool_calls_special_characters(glm4_moe_tool_parser):
|
||||
"""Test tool calls with special characters and unicode."""
|
||||
model_output = """<tool_call>send_message
|
||||
<arg_key>recipient</arg_key>
|
||||
<arg_value>Amy</arg_value>
|
||||
<arg_key>message</arg_key>
|
||||
<arg_value>It is a nice day</arg_value>
|
||||
<arg_key>priority</arg_key>
|
||||
<arg_value>high</arg_value>
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
assert extracted_tool_calls.tool_calls[0].function.name == "send_message"
|
||||
|
||||
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||
assert args["recipient"] == "Amy"
|
||||
assert args["message"] == "It is a nice day"
|
||||
assert args["priority"] == "high"
|
||||
|
||||
|
||||
def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser):
|
||||
"""Test incomplete tool calls (missing closing tag)."""
|
||||
model_output = """<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
<arg_key>date</arg_key>
|
||||
<arg_value>2025-08-01</arg_value>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
# Incomplete tool calls should not be extracted
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
@ -1,8 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
|
||||
from vllm.config import ModelDType
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
|
||||
from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger
|
||||
@ -27,7 +30,7 @@ MODELS = [
|
||||
def test_engine_log_metrics_ray(
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
dtype: ModelDType,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
""" Simple smoke test, verifying this can be used without exceptions.
|
||||
@ -37,11 +40,14 @@ def test_engine_log_metrics_ray(
|
||||
class EngineTestActor:
|
||||
|
||||
async def run(self):
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model,
|
||||
dtype=dtype,
|
||||
disable_log_stats=False,
|
||||
)
|
||||
# Set environment variable inside the Ray actor since environment
|
||||
# variables from pytest fixtures don't propagate to Ray actors
|
||||
os.environ['VLLM_USE_V1'] = '1'
|
||||
|
||||
engine_args = AsyncEngineArgs(model=model,
|
||||
dtype=dtype,
|
||||
disable_log_stats=False,
|
||||
enforce_eager=True)
|
||||
|
||||
engine = AsyncLLM.from_engine_args(
|
||||
engine_args, stat_loggers=[RayPrometheusStatLogger])
|
||||
|
@ -95,4 +95,6 @@ def test_ragged_paged_attention():
sm_scale=scale,
sliding_window=sliding_window,
soft_cap=logits_soft_cap,
k_scale=1.0,
v_scale=1.0,
)
@ -269,7 +269,6 @@ class AttentionImpl(ABC, Generic[T]):
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
@ -1,466 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import (CommonAttentionState,
|
||||
CommonMetadataBuilder)
|
||||
from vllm.attention.ops.blocksparse_attention.interface import (
|
||||
LocalStridedBlockSparseAttn, get_head_sliding_step)
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlocksparseParams:
|
||||
max_seqlen: int
|
||||
|
||||
# Num q heads per tensor-parallel rank/partition
|
||||
num_heads: int # per TP partition
|
||||
# Num kv heads per tensor-parallel rank/partition
|
||||
num_kv_heads: int
|
||||
|
||||
# block size used for blocksparse attention.
|
||||
# This is the block_size used in `local_blocks`, `vert_stride`.
|
||||
block_size: int
|
||||
|
||||
# Number of blocks for local attention, i.e., number of
|
||||
# local attended tokens / `sparse_block_size`
|
||||
local_blocks: int
|
||||
|
||||
# Attend to one block per every `vert_stride` blocks.
|
||||
# Controlling the sparsity
|
||||
vert_stride: int
|
||||
"""
|
||||
If to use the same vertical stride offset for all heads,
|
||||
i.e., attend to the same block of tokens on all heads.
|
||||
By default, it is False, i.e., attention on the non-local
|
||||
blocks depends on the `head_idx`, that is on
|
||||
blocks satisfying
|
||||
`(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
|
||||
where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
|
||||
`block_idx = position_id // sparse_block_size`.
|
||||
See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
|
||||
for more detail.
|
||||
"""
|
||||
homo_head: bool = False
|
||||
|
||||
# If within a group, the kv offsets that each q attends is the same or no.
|
||||
homo_head_group: bool = False
|
||||
|
||||
# Decided by homo_head and homo_head group
|
||||
head_sliding_step: int = field(init=False)
|
||||
|
||||
# range of q heads to for a TP rank
|
||||
active_head_range: Tuple = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.block_size > 0
|
||||
assert self.local_blocks >= 0
|
||||
assert self.vert_stride >= 1
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
total_heads = tp_size * self.num_heads
|
||||
total_kv_heads = tp_size * self.num_kv_heads
|
||||
|
||||
if self.homo_head:
|
||||
self.head_sliding_step = 0
|
||||
elif self.homo_head_group:
|
||||
head_sliding_step = get_head_sliding_step(total_kv_heads,
|
||||
self.vert_stride)
|
||||
# negative indicates sliding along kv heads, i.e., homo q group
|
||||
self.head_sliding_step = -head_sliding_step
|
||||
else:
|
||||
self.head_sliding_step = get_head_sliding_step(
|
||||
total_heads, self.vert_stride)
|
||||
|
||||
self.active_head_range = (
|
||||
tp_rank * self.num_heads,
|
||||
(tp_rank + 1) * self.num_heads,
|
||||
)
|
||||
|
||||
|
||||
class BlocksparseFlashAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "BLOCK_SPARSE_FLASH_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
|
||||
return BlocksparseFlashAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return BlocksparseFlashAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
|
||||
return BlocksparseFlashAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||
num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: Dict[int, int],
|
||||
) -> None:
|
||||
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: Dict[int, List[int]],
|
||||
) -> None:
|
||||
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlocksparseFlashAttentionMetadata(AttentionMetadata):
|
||||
"""A copy of Metadata for FlashAttentionBackend,
|
||||
to avoid having to install flash_attn.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ----------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int]
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor]
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor]
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
# Max number of query tokens for among request in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
_cached_prefill_metadata: Optional[
|
||||
"BlocksparseFlashAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional[
|
||||
"BlocksparseFlashAttentionMetadata"] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(
|
||||
self) -> Optional["BlocksparseFlashAttentionMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.query_start_loc is not None
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_start_loc is not None
|
||||
|
||||
self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
seq_lens=self.seq_lens[:self.num_prefills],
|
||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
||||
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
|
||||
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
||||
block_tables=self.block_tables[:self.num_prefills],
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
self._cached_decode_metadata = BlocksparseFlashAttentionMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||
max_query_len=None,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=self.block_tables[self.num_prefills:],
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class BlocksparseFlashAttentionMetadataBuilder(
|
||||
CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
|
||||
|
||||
_metadata_cls = BlocksparseFlashAttentionMetadata
|
||||
|
||||
|
||||
class BlocksparseFlashAttentionImpl(AttentionImpl):
|
||||
"""
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
|<--------------- num_prompt_tokens -------------->|
|
||||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
|
||||
|
||||
Otherwise, the layout is as follows:
|
||||
|<------------------ num_generation_tokens (M) ----------------->|
|
||||
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
|
||||
|
||||
Generation tokens can contain padding when cuda-graph is used.
|
||||
Currently, prompt tokens don't contain any padding.
|
||||
|
||||
The prompts might have different lengths, while the generation tokens
|
||||
always have length 1.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"BLOCK_SPARSE_FLASH_ATTN Backend.")
|
||||
assert blocksparse_params is not None
|
||||
assert alibi_slopes is None, ValueError(
|
||||
"Alibi not support for blocksparse flash attention.")
|
||||
assert sliding_window is None, ValueError(
|
||||
"sliding_window is invalid for blocksparse attention.")
|
||||
assert logits_soft_cap is None, ValueError(
|
||||
"logits_soft_cap is invalid for blocksparse attention.")
|
||||
|
||||
if "num_heads" not in blocksparse_params:
|
||||
blocksparse_params["num_heads"] = num_heads
|
||||
if "num_kv_heads" not in blocksparse_params:
|
||||
blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads
|
||||
self.blocksparse_params = BlocksparseParams(**blocksparse_params)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.num_kv_heads = num_kv_heads
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.local_blocks = self.blocksparse_params.local_blocks
|
||||
self.vert_stride = self.blocksparse_params.vert_stride
|
||||
self.sparse_block_size = self.blocksparse_params.block_size
|
||||
self.head_sliding_step = self.blocksparse_params.head_sliding_step
|
||||
|
||||
supported_head_sizes = PagedAttention.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by PagedAttention. "
|
||||
f"Supported head sizes are: {supported_head_sizes}.")
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
total_num_heads = num_heads * self.tp_size
|
||||
self.bs_attn = LocalStridedBlockSparseAttn(
|
||||
total_num_heads,
|
||||
self.blocksparse_params.max_seqlen,
|
||||
self.blocksparse_params.local_blocks,
|
||||
self.blocksparse_params.vert_stride,
|
||||
self.blocksparse_params.block_size,
|
||||
homo_head=self.blocksparse_params.homo_head,
|
||||
active_head_range=self.blocksparse_params.active_head_range,
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"BlocksparseFlashAttentionImpl")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: BlocksparseFlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
if output_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for BlocksparseFlashAttentionImpl")
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory profiling run.
|
||||
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
|
||||
# Prompt run.
|
||||
# normal attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
|
||||
assert kv_cache.numel() == 0 \
|
||||
or prefill_meta.block_tables is None \
|
||||
or prefill_meta.block_tables.numel() == 0, \
|
||||
"Does not support prefix-enabled attention."
|
||||
|
||||
output = self.bs_attn(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
sm_scale=self.scale,
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
# Decoding run.
|
||||
output = PagedAttention.forward_decode(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.seq_lens_tensor,
|
||||
self.blocksparse_params.max_seqlen,
|
||||
self.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
self.alibi_slopes,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
tp_rank=self.tp_rank,
|
||||
blocksparse_local_blocks=self.local_blocks,
|
||||
blocksparse_vert_stride=self.vert_stride,
|
||||
blocksparse_block_size=self.sparse_block_size,
|
||||
blocksparse_head_sliding_step=self.head_sliding_step,
|
||||
)
|
||||
|
||||
assert output is not None
|
||||
# Reshape the output tensor.
|
||||
return output.view(num_tokens, hidden_size)
|
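Reviewer note: the V0 backend touched here keeps prefill and decode tokens in one flat batch and slices every metadata field by `num_prefill_tokens` / `num_prefills`, exactly as the two cached-metadata properties above do. A minimal sketch of that layout with made-up sizes (not values from this patch):

```python
import torch

# Hypothetical batch: two prompts (3 and 2 tokens) followed by two decode tokens.
num_prefills, num_prefill_tokens, num_decode_tokens = 2, 5, 2

slot_mapping = torch.arange(num_prefill_tokens + num_decode_tokens)
seq_lens = [3, 2, 7, 9]  # first num_prefills entries are the prompt lengths

# prefill_metadata sees the leading tokens/sequences ...
prefill_slots = slot_mapping[:num_prefill_tokens]   # tensor([0, 1, 2, 3, 4])
prefill_seq_lens = seq_lens[:num_prefills]          # [3, 2]

# ... and decode_metadata the trailing ones.
decode_slots = slot_mapping[num_prefill_tokens:]    # tensor([5, 6])
decode_seq_lens = seq_lens[num_prefills:]           # [7, 9]
```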
@ -667,7 +667,6 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@ -680,9 +679,6 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
|
||||
differential_flash_attention_config
|
||||
self.used_shared_kv_cache = kv_sharing_target_layer_name is not None
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
if use_irope:
|
||||
logger.warning(
|
||||
"Using irope in V0 is not supported yet, it will fall back "
|
||||
|
@ -287,7 +287,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
|
@ -4,7 +4,7 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
@ -615,7 +615,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@ -624,9 +623,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"FLASH_ATTN backend.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
if use_irope:
|
||||
logger.warning(
|
||||
"Using irope in V0 is not supported yet, it will fall back "
|
||||
|
@ -999,7 +999,6 @@ class FlashInferImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
@ -181,7 +181,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@ -189,20 +188,17 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
assert is_flashmla_supported(), \
|
||||
"FlashMLA is not supported on this device"
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
|
@ -997,7 +997,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type, Union
|
||||
from typing import TYPE_CHECKING, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -367,7 +367,6 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
@ -375,17 +374,14 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"Aiter MLA does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
from aiter import flash_attn_varlen_func
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
|
@ -4,7 +4,7 @@
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
@ -494,7 +494,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@ -507,9 +506,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
logger.warning_once(
|
||||
"Using irope in ROCm Flash Attention is not supported yet, it "
|
||||
"will fail back to global attention for long context.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"ROCmFlashAttention does not support blocksparse attention.")
|
||||
if use_irope:
|
||||
logger.warning(
|
||||
"Using irope in V0 is not supported yet, it will fall back "
|
||||
|
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import torch
|
||||
|
||||
@ -35,7 +35,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
@ -43,17 +42,14 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"TritonMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
|
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with xFormers and PagedAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from xformers import ops as xops
|
||||
@ -387,7 +387,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@ -396,9 +395,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"XFORMERS backend.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"XFormers does not support block-sparse attention.")
|
||||
if logits_soft_cap is not None:
|
||||
logger.warning_once("XFormers does not support logits soft cap. "
|
||||
"Outputs may be slightly off.")
|
||||
|
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -74,7 +74,6 @@ class Attention(nn.Module):
|
||||
alibi_slopes: Optional[List[float]] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
per_layer_sliding_window: Optional[int] = None,
|
||||
use_mla: bool = False,
|
||||
@ -163,12 +162,11 @@ class Attention(nn.Module):
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
is_attention_free,
|
||||
blocksparse_params is not None,
|
||||
use_mla=use_mla)
|
||||
impl_cls = attn_backend.get_impl_cls()
|
||||
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **extra_impl_args)
|
||||
self.backend = backend_name_to_enum(attn_backend.get_name())
|
||||
self.dtype = dtype
|
||||
|
@ -1,433 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
def blocksparse_flash_attn_varlen_fwd(
|
||||
q,
|
||||
k,
|
||||
v, # (#tokens, n_heads, head_size)
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q,
|
||||
sm_scale,
|
||||
sparse_layout,
|
||||
*,
|
||||
block_size=64,
|
||||
q_block_size=None,
|
||||
max_seqlen=None):
|
||||
# split q to blocks
|
||||
|
||||
assert isinstance(sparse_layout, (list, tuple))
|
||||
|
||||
_, n_heads, head_size = q.shape
|
||||
batch_size = cu_seqlens_k.size(0) - 1
|
||||
q_block_size = q_block_size or block_size
|
||||
|
||||
assert q.dim() == k.dim() == v.dim() == 3
|
||||
assert q.size(1) % k.size(1) == 0
|
||||
assert q.size(2) == k.size(2)
|
||||
# TODO(linxihui): allow k, v to have different head_size
|
||||
assert k.shape == v.shape
|
||||
assert cu_seqlens_k.dim() == 1
|
||||
|
||||
q_k_ratio = q.size(1) // k.size(1)
|
||||
|
||||
if cu_seqlens_q is None:
|
||||
if q.size(0) == batch_size: # decoding only
|
||||
cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
batch_size + 1,
|
||||
dtype=cu_seqlens_k.dtype,
|
||||
device=cu_seqlens_k.device,
|
||||
)
|
||||
elif q.size(0) == k.size(0):
|
||||
cu_seqlens_q = cu_seqlens_k
|
||||
else:
|
||||
raise ValueError("cu_seqlens_q must be specified\
|
||||
if it mix of prefilling and decoding.")
|
||||
else:
|
||||
assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
|
||||
|
||||
# switch to use cpu to avoid too many kernel launches when iterated over
|
||||
q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
|
||||
k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
|
||||
|
||||
assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
|
||||
"length of q should either be 1 (decoding) or same as k (prefilling).")
|
||||
|
||||
if max_seqlen:
|
||||
assert k_lens.max() <= max_seqlen
|
||||
|
||||
n_blocks = (q_lens + q_block_size - 1) // q_block_size
|
||||
|
||||
q_batch_ids = torch.tensor(
|
||||
[i for i, n in enumerate(n_blocks) for _ in range(n)],
|
||||
dtype=cu_seqlens_q.dtype,
|
||||
device=cu_seqlens_q.device,
|
||||
)
|
||||
q_start_sids = torch.tensor(
|
||||
[i * q_block_size for n in n_blocks for i in range(n)],
|
||||
dtype=cu_seqlens_q.dtype,
|
||||
device=cu_seqlens_q.device,
|
||||
)
|
||||
|
||||
out = q.new_empty(q.shape)
|
||||
cu_seqlens_q = cu_seqlens_q.contiguous()
|
||||
cu_seqlens_k = cu_seqlens_k.contiguous()
|
||||
|
||||
layout_crow_indices, layout_col_indices = sparse_layout
|
||||
block_d = triton.next_power_of_2(head_size)
|
||||
|
||||
decoding_only = (q_lens == 1).all().item()
|
||||
grid = (len(q_start_sids), n_heads, 1)
|
||||
|
||||
_fwd_kernel_batch_inference[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
sm_scale,
|
||||
cu_seqlens_q[:-1],
|
||||
cu_seqlens_q[1:],
|
||||
cu_seqlens_k[:-1],
|
||||
cu_seqlens_k[1:],
|
||||
q_batch_ids,
|
||||
q_start_sids,
|
||||
0,
|
||||
*q.stride(),
|
||||
0,
|
||||
*k.stride(),
|
||||
0,
|
||||
*v.stride(),
|
||||
0,
|
||||
*out.stride(),
|
||||
layout_crow_indices,
|
||||
layout_col_indices,
|
||||
*layout_crow_indices.stride(),
|
||||
*layout_col_indices.stride(),
|
||||
q_k_ratio,
|
||||
HAS_BATCH_DIM=False,
|
||||
D_HEAD=head_size,
|
||||
BLOCK_M=q_block_size,
|
||||
BLOCK_N=block_size,
|
||||
BLOCK_D=block_d,
|
||||
BLOCK_M_LOADING=(16 if decoding_only else
|
||||
q_block_size), # smaller for decoding
|
||||
EVEN_D=block_d == head_size,
|
||||
num_warps=1 if decoding_only else 4,
|
||||
num_stages=3)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
Q,
|
||||
k_block_col_idx,
|
||||
layout_col_ptr,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
k_ptrs,
|
||||
v_ptrs,
|
||||
off_h,
|
||||
offs_m,
|
||||
offs_n,
|
||||
offs_d,
|
||||
stride_kt,
|
||||
stride_vt,
|
||||
sm_scale,
|
||||
k_seqlen,
|
||||
past_len,
|
||||
LAST_K_BLOCK: tl.constexpr,
|
||||
BLOCK_M_LOADING: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
D_HEAD: tl.constexpr,
|
||||
EVEN_D: tl.constexpr,
|
||||
M_LT_N: tl.constexpr,
|
||||
):
|
||||
k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
|
||||
k_block_col_idx * layout_col_stride_m).to(tl.int32)
|
||||
start_n = k_block_id * BLOCK_N
|
||||
if LAST_K_BLOCK:
|
||||
if EVEN_D:
|
||||
k = tl.load(
|
||||
k_ptrs + start_n * stride_kt,
|
||||
mask=offs_n[None, :] + start_n < k_seqlen,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
k = tl.load(
|
||||
k_ptrs + start_n * stride_kt,
|
||||
mask=(offs_n[None, :] + start_n < k_seqlen) &
|
||||
(offs_d[:, None] < D_HEAD),
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
if EVEN_D:
|
||||
k = tl.load(k_ptrs + start_n * stride_kt)
|
||||
else:
|
||||
k = tl.load(k_ptrs + start_n * stride_kt,
|
||||
mask=offs_d[:, None] < D_HEAD,
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
|
||||
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
|
||||
if LAST_K_BLOCK | M_LT_N:
|
||||
qk += tl.where(
|
||||
offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
|
||||
0,
|
||||
float("-inf"),
|
||||
)
|
||||
|
||||
# flash-attn2
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||
p = tl.math.exp2(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
alpha = tl.math.exp2(m_i - m_ij)
|
||||
acc = acc * alpha[:, None]
|
||||
# update m_i
|
||||
m_i = m_ij
|
||||
l_i = l_i * alpha + l_ij
|
||||
|
||||
p = p.to(Q.dtype.element_ty)
|
||||
# update acc
|
||||
if LAST_K_BLOCK:
|
||||
if EVEN_D:
|
||||
v = tl.load(
|
||||
v_ptrs + start_n * stride_vt,
|
||||
mask=offs_n[:, None] + start_n < k_seqlen,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
v = tl.load(
|
||||
v_ptrs + start_n * stride_vt,
|
||||
mask=(offs_n[:, None] + start_n < k_seqlen) &
|
||||
(offs_d[None, :] < D_HEAD),
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
if EVEN_D:
|
||||
v = tl.load(v_ptrs + start_n * stride_vt)
|
||||
else:
|
||||
v = tl.load(v_ptrs + start_n * stride_vt,
|
||||
mask=offs_d[None, :] < D_HEAD,
|
||||
other=0.0)
|
||||
|
||||
acc += tl.dot(p, v)
|
||||
|
||||
return acc, l_i, m_i
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"M_LT_N":
|
||||
lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
|
||||
})
|
||||
@triton.jit
|
||||
def _fwd_kernel_batch_inference(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
Out,
|
||||
sm_scale,
|
||||
q_batch_starts,
|
||||
q_batch_ends,
|
||||
k_batch_starts,
|
||||
k_batch_ends,
|
||||
q_batch_ids,
|
||||
q_start_sids,
|
||||
stride_qb,
|
||||
stride_qt,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kb,
|
||||
stride_kt,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vb,
|
||||
stride_vt,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_ob,
|
||||
stride_ot,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
layout_crow_ptr,
|
||||
layout_col_ptr,
|
||||
layout_crow_stride_h,
|
||||
layout_crow_stride_m,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
q_k_ratio,
|
||||
HAS_BATCH_DIM: tl.constexpr,
|
||||
D_HEAD: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_D: tl.constexpr,
|
||||
BLOCK_M_LOADING: tl.constexpr,
|
||||
EVEN_D: tl.constexpr,
|
||||
M_LT_N: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
NOTATION:
|
||||
pid: position id
|
||||
sid: storage id
|
||||
sbid: storage block id
|
||||
pbid: position block id
|
||||
offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
|
||||
|
||||
TODO(linxihui):
|
||||
Optimize grouped-attn
|
||||
"""
|
||||
off_zm = tl.program_id(0)
|
||||
off_h = tl.program_id(1)
|
||||
|
||||
off_h_for_kv = off_h // q_k_ratio
|
||||
|
||||
if HAS_BATCH_DIM:
|
||||
off_z = tl.program_id(2)
|
||||
Q += off_z * stride_qb
|
||||
K += off_z * stride_kb
|
||||
V += off_z * stride_vb
|
||||
Out += off_z * stride_ob
|
||||
start_m = off_zm
|
||||
q_start_sid = start_m * BLOCK_M # always 0 for decoding
|
||||
else:
|
||||
off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1]
|
||||
q_start_sid = tl.load(q_start_sids + off_zm)
|
||||
start_m = q_start_sid // BLOCK_M # q_sbid
|
||||
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_D)
|
||||
|
||||
q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
|
||||
q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
|
||||
k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
|
||||
k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
|
||||
past_len = k_seqlen - q_seqlen
|
||||
|
||||
Q += q_cu_start * stride_qt + off_h * stride_qh
|
||||
K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
|
||||
V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
|
||||
Out += q_cu_start * stride_ot + off_h * stride_oh
|
||||
|
||||
q_pbid = (past_len + q_start_sid) // BLOCK_M
|
||||
|
||||
if EVEN_D:
|
||||
q = tl.load(
|
||||
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
||||
mask=offs_m[:, None] < q_seqlen,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
q = tl.load(
|
||||
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
||||
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
|
||||
q_pbid * layout_crow_stride_m)
|
||||
|
||||
# TODO(linxihui): load at once, with any Triton version
|
||||
# that supports `tl.split`, e.g., Triton 3.0
|
||||
k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
|
||||
k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
|
||||
|
||||
m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
|
||||
|
||||
k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
|
||||
v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
|
||||
|
||||
sm_scale *= (
|
||||
1.44269504 # 1/log2 as we use base2 for exponential and logarithm
|
||||
)
|
||||
|
||||
for k_block_col_idx in range(k_block_start, k_block_end - 1):
|
||||
acc, l_i, m_i = _fwd_kernel_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
Q,
|
||||
k_block_col_idx,
|
||||
layout_col_ptr,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
k_ptrs,
|
||||
v_ptrs,
|
||||
off_h,
|
||||
offs_m,
|
||||
offs_n,
|
||||
offs_d,
|
||||
stride_kt,
|
||||
stride_vt,
|
||||
sm_scale,
|
||||
k_seqlen,
|
||||
past_len,
|
||||
False,
|
||||
BLOCK_M_LOADING,
|
||||
BLOCK_N,
|
||||
D_HEAD,
|
||||
EVEN_D,
|
||||
M_LT_N,
|
||||
)
|
||||
|
||||
acc, l_i, m_i = _fwd_kernel_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
Q,
|
||||
k_block_end - 1,
|
||||
layout_col_ptr,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
k_ptrs,
|
||||
v_ptrs,
|
||||
off_h,
|
||||
offs_m,
|
||||
offs_n,
|
||||
offs_d,
|
||||
stride_kt,
|
||||
stride_vt,
|
||||
sm_scale,
|
||||
k_seqlen,
|
||||
past_len,
|
||||
True,
|
||||
BLOCK_M_LOADING,
|
||||
BLOCK_N,
|
||||
D_HEAD,
|
||||
EVEN_D,
|
||||
M_LT_N,
|
||||
)
|
||||
|
||||
# flash-attn 2
|
||||
m_i += tl.math.log2(l_i)
|
||||
acc = acc / l_i[:, None]
|
||||
|
||||
# write output
|
||||
if EVEN_D:
|
||||
tl.store(
|
||||
Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
|
||||
acc,
|
||||
mask=offs_m[:, None] < q_seqlen,
|
||||
)
|
||||
else:
|
||||
tl.store(
|
||||
Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
|
||||
acc,
|
||||
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
||||
)
|
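Reviewer note: the kernel deleted above folds the softmax scale together with 1/log(2) and carries a running row maximum (`m_i`), normalizer (`l_i`) and output accumulator (`acc`) per query block, in the usual flash-attention-2 fashion. A small PyTorch sketch of that online-softmax update (block size and shapes are made up), checked against ordinary softmax attention:

```python
import torch

torch.manual_seed(0)
q = torch.randn(4, 16)           # one query block, head_size 16
k = torch.randn(32, 16)
v = torch.randn(32, 16)
scale = 16 ** -0.5

# Reference: full softmax attention over all keys at once.
ref = torch.softmax(q @ k.t() * scale, dim=-1) @ v

# Online accumulation over key blocks, using base-2 exponentials with the
# scale multiplied by 1/log(2), as the kernel does.
log2e = 1.44269504
m_i = torch.full((4,), float("-inf"))
l_i = torch.zeros(4)
acc = torch.zeros(4, 16)
for start in range(0, 32, 8):                 # BLOCK_N = 8
    kb, vb = k[start:start + 8], v[start:start + 8]
    qk = (q @ kb.t()) * scale * log2e
    m_new = torch.maximum(m_i, qk.max(dim=-1).values)
    p = torch.exp2(qk - m_new[:, None])
    alpha = torch.exp2(m_i - m_new)           # rescale previous partial results
    acc = acc * alpha[:, None] + p @ vb
    l_i = l_i * alpha + p.sum(dim=-1)
    m_i = m_new
out = acc / l_i[:, None]
assert torch.allclose(out, ref, atol=1e-4)
```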
@ -1,239 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import (dense_to_crow_col, get_head_sliding_step,
|
||||
get_sparse_attn_mask)
|
||||
|
||||
IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)
|
||||
|
||||
if IS_COMPUTE_8_OR_ABOVE:
|
||||
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
|
||||
|
||||
|
||||
class LocalStridedBlockSparseAttn(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_heads,
|
||||
max_seqlen,
|
||||
local_blocks,
|
||||
vert_stride,
|
||||
block_size,
|
||||
device=None,
|
||||
dtype=None,
|
||||
homo_head=False,
|
||||
active_head_range=None,
|
||||
q_block_size=None,
|
||||
use_spda=None,
|
||||
):
|
||||
super().__init__()
|
||||
if use_spda is None:
|
||||
use_spda = current_platform.is_rocm() or \
|
||||
current_platform.is_cpu() or not \
|
||||
IS_COMPUTE_8_OR_ABOVE
|
||||
device = device or (torch.cuda.current_device()
|
||||
if current_platform.is_cuda_alike() else "cpu")
|
||||
device = torch.device(device)
|
||||
# NOTE: vllm CPU backend support BF16 instead of FP16.
|
||||
dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
|
||||
or device.type == "cpu" else torch.half)
|
||||
|
||||
self.n_heads = n_heads
|
||||
self.max_seqlen = max_seqlen
|
||||
self.local_blocks = local_blocks
|
||||
self.vert_stride = vert_stride
|
||||
self.use_spda = use_spda
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.block_size = block_size
|
||||
self.q_block_size = q_block_size
|
||||
self.homo_head = homo_head
|
||||
self.active_head_range = active_head_range
|
||||
self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
|
||||
homo_head)
|
||||
|
||||
sparse_layout, sparse_pattern, self.dense_attn_mask = (
|
||||
self.get_attn_pattern(dtype, device))
|
||||
|
||||
if q_block_size is not None and q_block_size != block_size:
|
||||
if q_block_size > block_size:
|
||||
assert q_block_size % block_size == 0
|
||||
blocks_to_merge = q_block_size // block_size
|
||||
shape = sparse_pattern.shape
|
||||
sparse_pattern = sparse_pattern.view(shape[0], -1,
|
||||
blocks_to_merge,
|
||||
shape[-1])
|
||||
sparse_pattern = sparse_pattern.sum(2)
|
||||
sparse_layout = dense_to_crow_col(sparse_pattern)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Does not support smaller q_block_size. It will be slower."
|
||||
)
|
||||
|
||||
self.sparse_layout = sparse_layout
|
||||
|
||||
def get_attn_pattern(self, dtype, device):
|
||||
sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
|
||||
self.n_heads,
|
||||
self.max_seqlen,
|
||||
self.max_seqlen,
|
||||
dtype,
|
||||
device,
|
||||
block_size=self.block_size,
|
||||
local_blocks=self.local_blocks,
|
||||
vert_stride=self.vert_stride,
|
||||
homo_head=self.homo_head,
|
||||
return_dense=self.use_spda,
|
||||
dense_mask_type="bias",
|
||||
)
|
||||
if (not self.homo_head) and (self.active_head_range is not None):
|
||||
assert isinstance(self.active_head_range, tuple)
|
||||
assert (len(self.active_head_range) == 2)
|
||||
h_start, h_end = self.active_head_range
|
||||
sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
|
||||
if self.use_spda:
|
||||
dense_attn_mask = dense_attn_mask[h_start:h_end]
|
||||
return sparse_layout, sparse_pattern, dense_attn_mask
|
||||
|
||||
def varlen_attn(self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q=None,
|
||||
sm_scale=None):
|
||||
"""
|
||||
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
|
||||
Support grouped attention, with `q[:, i*r:(i*r + r)]`
|
||||
is correspondent to `k[:, i]`, where `r` is the q/k ratio.
|
||||
cu_seqlens_k: shape=(batch_size + 1,),
|
||||
indicating segment of samples,
|
||||
e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
|
||||
cu_seqlens_q: shape=(batch_size + 1, ).
|
||||
Default None: same as cu_seqlens_k for prefilling or
|
||||
[0, 1, .., batch_size] for decoding.
|
||||
The only case you need to specify is when q is a mix of
|
||||
prefilling and decoding.
|
||||
sm_scale: softmax scale, default to 1/sqrt(head_size).
|
||||
|
||||
return: tensor of shape as q.
|
||||
"""
|
||||
assert (
|
||||
IS_COMPUTE_8_OR_ABOVE
|
||||
), "Requires compute capability of 8 or above (Ampere or newer) to use \
|
||||
Triton kernel."
|
||||
|
||||
sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
|
||||
|
||||
return blocksparse_flash_attn_varlen_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q,
|
||||
sm_scale,
|
||||
self.sparse_layout,
|
||||
block_size=self.block_size,
|
||||
q_block_size=self.q_block_size,
|
||||
max_seqlen=self.max_seqlen,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
|
||||
"""
|
||||
:param x: (total_tokens, n_heads, head_size)
|
||||
:return: (batch, n_heads, length, head_size)
|
||||
"""
|
||||
x_padded = x.new_empty(
|
||||
len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
|
||||
cu_seqlens = cu_seqlens.cpu()
|
||||
for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
|
||||
x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
|
||||
1).unsqueeze(1))
|
||||
return x_padded.flatten(1, 2)
|
||||
|
||||
@staticmethod
|
||||
def transpose_and_unpad(x_padded, cu_seqlens):
|
||||
"""
|
||||
:param x_padded: (batch, n_heads, length, head_size)
|
||||
:return: (total_tokens, n_heads, head_size)
|
||||
"""
|
||||
cu_seqlens = cu_seqlens.cpu()
|
||||
total_n_tokens = cu_seqlens[-1]
|
||||
x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
|
||||
x_padded.size(3))
|
||||
for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
|
||||
x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
|
||||
return x
|
||||
|
||||
def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
|
||||
"""For CPU, V100 or other older GPUs.
|
||||
NOTE: torch SPDA supports nested tensor,
|
||||
but seems extremely slow. Choose to pad instead.
|
||||
"""
|
||||
assert (cu_seqlens_q is None or
|
||||
(cu_seqlens_q
|
||||
== cu_seqlens_k).all()), "Can only handle prompt with SPDA."
|
||||
assert q.size(0) == k.size(0), "can only handle prompt with SPDA."
|
||||
|
||||
assert q.size(1) % k.size(1) == 0
|
||||
q_k_ratio = q.size(1) // k.size(1)
|
||||
sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
|
||||
cu_seqlens = cu_seqlens_k.cpu()
|
||||
maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
|
||||
if (self.dense_attn_mask.dtype != q.dtype
|
||||
or self.dense_attn_mask.device != q.device):
|
||||
_, _, self.dense_attn_mask = self.get_attn_pattern(
|
||||
q.dtype, q.device)
|
||||
attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]
|
||||
|
||||
q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
|
||||
k2, v2 = (self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
|
||||
for x in [k, v])
|
||||
spda_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
|
||||
return self.transpose_and_unpad(spda_output, cu_seqlens)
|
||||
|
||||
def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
|
||||
"""Dispatch to `varlen_attn` (Ampere or newer) or
|
||||
`self.spda`(cpu, Volta, Turing or older)based on
|
||||
the type of device used and cuda compute capability.
|
||||
|
||||
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
|
||||
Support grouped attention, with `q[:, i*r:(i*r + r)]`
|
||||
is correspondent to `k[:, i]`, where `r` is the q/k ratio.
|
||||
cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples,
|
||||
e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
|
||||
cu_seqlens_q: shape=(batch_size + 1, ).
|
||||
Default None: same as cu_seqlens_k for prefilling or
|
||||
[0, 1, .., batch_size] for decoding.
|
||||
The only case you need to specify
|
||||
is when q is a mix of prefilling
|
||||
and decoding.
|
||||
sm_scale: softmax scale, default to 1/sqrt(head_size).
|
||||
|
||||
return: tensor of shape as q.
|
||||
"""
|
||||
assert k.dim() == 3
|
||||
if self.use_spda:
|
||||
return self.spda(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
sm_scale=sm_scale,
|
||||
)
|
||||
return self.varlen_attn(q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
sm_scale=sm_scale)
|
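Reviewer note: the docstrings above describe `cu_seqlens_k` / `cu_seqlens_q` as prefix sums over per-sample lengths in a packed (varlen) batch. A small illustration of that convention, independent of the removed module; the lengths and shapes below are made up:

```python
import torch

lengths = [5, 3, 4]                                   # hypothetical per-sample token counts
cu_seqlens = torch.tensor([0] + lengths).cumsum(0)    # tensor([ 0,  5,  8, 12])

tokens = torch.randn(sum(lengths), 8, 64)             # (num_tokens, n_heads, head_size)
# Sample i occupies tokens[cu_seqlens[i]:cu_seqlens[i + 1]].
sample_1 = tokens[cu_seqlens[1]:cu_seqlens[2]]        # shape (3, 8, 64)

# Decode-only batches have one query token per sample, so
# cu_seqlens_q becomes [0, 1, ..., batch_size].
cu_seqlens_q_decode = torch.arange(len(lengths) + 1)
```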
@ -1,246 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Helper functions for 3D sparse pattern
|
||||
# These function are not optimized and very inefficient.
|
||||
# Avoid calling them too frequent or use a cache mechanism.
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
class csr_matrix:
|
||||
"""Simple implementation of CSR matrix conversion without scipy.
|
||||
This replaced scipy.sparse.csr_matrix() previously used."""
|
||||
|
||||
def __init__(self, input_array):
|
||||
if not isinstance(input_array, np.ndarray):
|
||||
raise ValueError("Input must be a NumPy array")
|
||||
|
||||
self.shape = input_array.shape
|
||||
rows, cols = self.shape
|
||||
data = []
|
||||
indices = []
|
||||
indptr = [0]
|
||||
|
||||
for i in range(rows):
|
||||
for j in range(cols):
|
||||
if input_array[i, j]:
|
||||
data.append(input_array[i, j])
|
||||
indices.append(j)
|
||||
indptr.append(len(indices))
|
||||
|
||||
self.data = np.array(data)
|
||||
self.indices = np.array(indices)
|
||||
self.indptr = np.array(indptr)
|
||||
|
||||
|
||||
def dense_to_crow_col(x: torch.Tensor):
|
||||
"""Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
|
||||
NOTE: col_indices padded -1
|
||||
"""
|
||||
device = x.device
|
||||
pad = -1
|
||||
dim = x.dim()
|
||||
assert x.dim() in (2, 3)
|
||||
if x.dim() == 2:
|
||||
x = x[None]
|
||||
x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
|
||||
crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
|
||||
cols = [torch.from_numpy(xi.indices) for xi in x]
|
||||
max_cols = max(len(xi) for xi in cols)
|
||||
cols = [
|
||||
torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])])
|
||||
for xi in cols
|
||||
]
|
||||
cols = torch.vstack(cols)
|
||||
if dim == 2:
|
||||
crows = crows[0]
|
||||
cols = cols[0]
|
||||
return crows.to(device), cols.to(device)
|
||||
|
||||
|
||||
def crow_col_to_dense(crows: torch.Tensor,
|
||||
cols: torch.Tensor,
|
||||
dtype: torch.dtype = torch.float16):
|
||||
dim = crows.dim()
|
||||
if dim == 1:
|
||||
crows = crows[None]
|
||||
cols = cols[None]
|
||||
device = crows.device
|
||||
crows, cols = crows.cpu(), cols.cpu() # faster in cpu
|
||||
shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
|
||||
x = torch.zeros(shape, dtype=dtype)
|
||||
for i in range(shape[0]):
|
||||
for j in range(shape[1]):
|
||||
x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1
|
||||
if dim == 1:
|
||||
x = x[0]
|
||||
return x.to(device)
|
||||
|
||||
|
||||
def dense_to_ccol_row(x: torch.Tensor):
|
||||
"""Similar, but to CSC format"""
|
||||
x = x.transpose(-2, -1)
|
||||
return dense_to_crow_col(x)
|
||||
|
||||
|
||||
def ccol_row_to_dense(ccol: torch.Tensor,
|
||||
rows: torch.Tensor,
|
||||
dtype: torch.dtype = torch.float16):
|
||||
return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
|
||||
|
||||
|
||||
def _get_sparse_attn_mask_homo_head(
|
||||
q_len: int,
|
||||
max_seqlen: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
block_size: int = 128,
|
||||
local_blocks: int = 4,
|
||||
vert_stride: int = 4,
|
||||
return_dense: bool = False,
|
||||
):
|
||||
"""
|
||||
:return: a tuple of 3:
|
||||
- tuple of crow_indices, col_indices representation
|
||||
of CSR format.
|
||||
- block dense mask
|
||||
- all token dense mask (be aware that it can be
|
||||
OOM if it is too big) if `return_dense==True`,
|
||||
otherwise, None
|
||||
"""
|
||||
with torch.no_grad():
|
||||
num_blocks = triton.cdiv(max_seqlen, block_size)
|
||||
q_pos = torch.arange(num_blocks)[:, None]
|
||||
k_pos = torch.arange(num_blocks)[None]
|
||||
mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0
|
||||
block_mask_dense = (((q_pos >= k_pos)
|
||||
& ((q_pos - k_pos < local_blocks)
|
||||
| mask_vert_strided)).to(device).to(dtype))
|
||||
num_blocks_q = triton.cdiv(q_len, block_size)
|
||||
block_mask_dense_output = (dense_to_crow_col(
|
||||
block_mask_dense[-num_blocks_q:].contiguous()))
|
||||
if return_dense:
|
||||
mask_dense = torch.kron(
|
||||
block_mask_dense,
|
||||
block_mask_dense.new_ones((block_size, block_size)),
|
||||
)
|
||||
causal_mask = torch.tril(torch.ones(
|
||||
max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
|
||||
mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
|
||||
return (
|
||||
block_mask_dense_output,
|
||||
block_mask_dense,
|
||||
mask_dense,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
block_mask_dense_output,
|
||||
block_mask_dense,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def binary_mask_to_bias(mask_dense: torch.Tensor):
|
||||
mask_dense = 1 - mask_dense
|
||||
mask_dense.masked_fill_(mask_dense.bool(), -torch.inf)
|
||||
return mask_dense
|
||||
|
||||
|
||||
def get_head_sliding_step(n_heads: int,
|
||||
vert_stride: int,
|
||||
homo_head: bool = False):
|
||||
if homo_head:
|
||||
return 0
|
||||
return max(1, int(vert_stride / n_heads))
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_sparse_attn_mask(
|
||||
n_heads: int,
|
||||
q_len: int,
|
||||
max_seqlen: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
block_size: int = 64,
|
||||
local_blocks: int = 4,
|
||||
vert_stride: int = 4,
|
||||
homo_head: bool = True,
|
||||
return_dense: bool = False,
|
||||
dense_mask_type: str = "binary",
|
||||
):
|
||||
"""
|
||||
:param dense_mask_type: "binary" (0 for skip token, 1 for others)
|
||||
or "bias" (-inf for skip token, 0 or others)
|
||||
:return: a tuple of 3:
|
||||
- tuple of crow_indices, col_indices representation
|
||||
of CSR format.
|
||||
- block dense mask
|
||||
- all token dense mask (be aware that it can be OOM if it
|
||||
is too big) if `return_dense==True`, otherwise, None
|
||||
"""
|
||||
assert dense_mask_type in ("binary", "bias")
|
||||
if homo_head:
|
||||
with torch.no_grad():
|
||||
(crow, col), block_mask_dense, mask_dense = (
|
||||
_get_sparse_attn_mask_homo_head(
|
||||
q_len,
|
||||
max_seqlen,
|
||||
dtype,
|
||||
device,
|
||||
block_size,
|
||||
local_blocks,
|
||||
vert_stride,
|
||||
return_dense,
|
||||
))
|
||||
crow = crow[None].expand(n_heads, crow.shape[0])
|
||||
col = col[None].expand(n_heads, col.shape[0])
|
||||
if return_dense:
|
||||
mask_dense = mask_dense[None].expand(n_heads,
|
||||
*mask_dense.shape)
|
||||
if dense_mask_type == "bias":
|
||||
mask_dense = binary_mask_to_bias(mask_dense)
|
||||
return (crow, col), block_mask_dense, mask_dense
|
||||
|
||||
with torch.no_grad():
|
||||
num_blocks = triton.cdiv(max_seqlen, block_size)
|
||||
q_pos = torch.arange(num_blocks)[None, :, None]
|
||||
k_pos = torch.arange(num_blocks)[None, None]
|
||||
head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
|
||||
mask_vert_strided = [
|
||||
(torch.arange(num_blocks) + h * head_sliding_step + 1) %
|
||||
vert_stride == 0 for h in range(n_heads)
|
||||
]
|
||||
mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
|
||||
block_mask_dense = (((q_pos >= k_pos)
|
||||
& ((q_pos - k_pos < local_blocks)
|
||||
| mask_vert_strided)).to(device).to(dtype))
|
||||
num_blocks_q = triton.cdiv(q_len, block_size)
|
||||
block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]
|
||||
if return_dense:
|
||||
mask_dense = torch.kron(
|
||||
block_mask_dense,
|
||||
block_mask_dense.new_ones((block_size, block_size)),
|
||||
)
|
||||
causal_mask = torch.tril(torch.ones(
|
||||
max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
|
||||
mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
|
||||
if dense_mask_type == "bias":
|
||||
mask_dense = binary_mask_to_bias(mask_dense)
|
||||
|
||||
return (
|
||||
dense_to_crow_col(block_mask_dense_output),
|
||||
block_mask_dense,
|
||||
mask_dense,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
dense_to_crow_col(block_mask_dense_output),
|
||||
block_mask_dense,
|
||||
None,
|
||||
)
|
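Reviewer note: the deleted helpers build a per-block mask that is causal plus a local band and vertically strided columns, then convert it to CSR row/col indexing. A rough single-head sketch with toy parameters (the real `dense_to_crow_col` additionally pads `col_indices` with -1 when stacking heads):

```python
import torch

# Hypothetical tiny configuration, not vLLM's defaults.
num_blocks, local_blocks, vert_stride = 8, 2, 4

q_pos = torch.arange(num_blocks)[:, None]
k_pos = torch.arange(num_blocks)[None, :]
mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0

# Causal, plus "local" (within local_blocks) or vertically strided columns,
# mirroring _get_sparse_attn_mask_homo_head above.
block_mask = (q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)

# CSR-style (crow, col) layout for a single head.
crow = torch.zeros(num_blocks + 1, dtype=torch.int64)
crow[1:] = block_mask.sum(dim=1).cumsum(0)
col = block_mask.nonzero(as_tuple=True)[1]

print(block_mask.int())
print(crow, col)
```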
@ -143,7 +143,6 @@ def get_attn_backend(
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
is_blocksparse: bool = False,
|
||||
use_mla: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
@ -157,7 +156,6 @@ def get_attn_backend(
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
block_size=block_size,
|
||||
is_attention_free=is_attention_free,
|
||||
is_blocksparse=is_blocksparse,
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
)
|
||||
@ -170,16 +168,9 @@ def _cached_get_attn_backend(
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
is_blocksparse: bool = False,
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
if is_blocksparse:
|
||||
logger.info("Using BlocksparseFlashAttention backend.")
|
||||
from vllm.attention.backends.blocksparse_attn import (
|
||||
BlocksparseFlashAttentionBackend)
|
||||
return BlocksparseFlashAttentionBackend
|
||||
|
||||
# If there are no attention layers (e.g. we are running Mamba),
|
||||
# use the placeholder NO_ATTENTION
|
||||
if is_attention_free:
|
||||
|
@ -1333,7 +1333,8 @@ class ModelConfig:
|
||||
self, parallel_config: "ParallelConfig") -> tuple[int, int]:
|
||||
from vllm.distributed.utils import get_pp_indices
|
||||
if (self.hf_text_config.model_type == "deepseek_mtp"
|
||||
or self.hf_config.model_type == "mimo_mtp"):
|
||||
or self.hf_config.model_type == "mimo_mtp"
|
||||
or self.hf_config.model_type == "glm4_moe_mtp"):
|
||||
total_num_hidden_layers = getattr(self.hf_text_config,
|
||||
"num_nextn_predict_layers", 0)
|
||||
else:
|
||||
@ -2663,7 +2664,15 @@ class SpeculativeConfig:
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["MiMoMTPModel"]
|
||||
})
|
||||
return hf_config
|
||||
|
||||
if hf_config.architectures[0] == "Glm4MoeForCausalLM":
|
||||
hf_config.model_type = "glm4_moe_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["Glm4MoeMTPModel"]
|
||||
})
|
||||
|
||||
return hf_config
|
||||
|
||||
@ -2774,7 +2783,7 @@ class SpeculativeConfig:
|
||||
"mlp_speculator"):
|
||||
self.method = "mlp_speculator"
|
||||
elif (self.draft_model_config.hf_config.model_type
|
||||
in ("deepseek_mtp", "mimo_mtp")):
|
||||
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
|
||||
self.method = "deepseek_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
@ -4312,6 +4321,7 @@ class CompilationConfig:
|
||||
self.splitting_ops = [] if self.full_cuda_graph else [
|
||||
"vllm.unified_attention",
|
||||
"vllm.unified_attention_with_output",
|
||||
"vllm.mamba_mixer2",
|
||||
]
|
||||
|
||||
|
||||
|
@ -1358,10 +1358,10 @@ class EngineArgs:
|
||||
and not envs.is_set("VLLM_ATTENTION_BACKEND")
|
||||
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
|
||||
supported = False
|
||||
if current_platform.is_rocm() or (
|
||||
current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100)
|
||||
): # handle hpu also for OOT platform
|
||||
if (current_platform.is_rocm()
|
||||
or (current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100))
|
||||
or current_platform.is_tpu()):
|
||||
supported = True
|
||||
elif fp8_attention and will_use_fa:
|
||||
from vllm.attention.utils.fa_utils import (
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
from .abstract_tool_parser import ToolParser, ToolParserManager
|
||||
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
|
||||
from .glm4_moe_tool_parser import Glm4MoeModelToolParser
|
||||
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
|
||||
from .granite_tool_parser import GraniteToolParser
|
||||
from .hermes_tool_parser import Hermes2ProToolParser
|
||||
@ -19,10 +20,22 @@ from .pythonic_tool_parser import PythonicToolParser
|
||||
from .xlam_tool_parser import xLAMToolParser
|
||||
|
||||
__all__ = [
|
||||
"ToolParser", "ToolParserManager", "Granite20bFCToolParser",
|
||||
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
|
||||
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
|
||||
"Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
|
||||
"DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser",
|
||||
"KimiK2ToolParser", "HunyuanA13BToolParser"
|
||||
"ToolParser",
|
||||
"ToolParserManager",
|
||||
"Granite20bFCToolParser",
|
||||
"GraniteToolParser",
|
||||
"Hermes2ProToolParser",
|
||||
"MistralToolParser",
|
||||
"Internlm2ToolParser",
|
||||
"Llama3JsonToolParser",
|
||||
"JambaToolParser",
|
||||
"Llama4PythonicToolParser",
|
||||
"PythonicToolParser",
|
||||
"Phi4MiniJsonToolParser",
|
||||
"DeepSeekV3ToolParser",
|
||||
"xLAMToolParser",
|
||||
"MinimaxToolParser",
|
||||
"KimiK2ToolParser",
|
||||
"HunyuanA13BToolParser",
|
||||
"Glm4MoeModelToolParser",
|
||||
]
|
||||
|
402
vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
Normal file
@ -0,0 +1,402 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# code modified from deepseekv3_tool_parser.py
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("glm4_moe")
|
||||
class Glm4MoeModelToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.current_tool_name_sent = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id = -1
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
self.tool_call_start_token = "<tool_call>"
|
||||
self.tool_call_end_token = "</tool_call>"
|
||||
|
||||
self.tool_calls_start_token = self.tool_call_start_token
|
||||
|
||||
# Updated regex for the XML-based format
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<tool_call>\s*"
|
||||
r"(?P<function_name>[^\n<]+)\s*" # 函数名(到换行或 <)
|
||||
r"(?P<arguments>(?:\s*<arg_key>[^<]+</arg_key>\s*"
|
||||
r"<arg_value>[^<]*</arg_value>\s*)*)\s*"
|
||||
r"</tool_call>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
# Regex for parsing individual arguments
|
||||
self.arg_regex = re.compile(
|
||||
r"<arg_key>(?P<key>[^<]+)</arg_key>\s*<arg_value>(?P<value>[^<]*)</arg_value>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
# Streaming regex
|
||||
self.stream_tool_call_portion_regex = re.compile(
|
||||
r"(?P<function_name>[^\n<]+)\s*"
|
||||
r"(?P<arguments>(?:\s*<arg_key>[^<]+</arg_key>\s*"
|
||||
r"<arg_value>[^<]*</arg_value>\s*)*)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
# For streaming, we also need a regex to match just the function name
|
||||
self.stream_tool_call_name_regex = re.compile(
|
||||
r"(?P<function_name>[^\n<]+)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
|
||||
def _parse_arguments(self, args_text: str) -> str:
|
||||
"""Parse XML-based arguments into JSON format."""
|
||||
if not args_text or not args_text.strip():
|
||||
return "{}"
|
||||
|
||||
args_dict = {}
|
||||
matches = self.arg_regex.findall(args_text)
|
||||
|
||||
for key, value in matches:
|
||||
args_dict[key.strip()] = value.strip()
|
||||
|
||||
import json
|
||||
return json.dumps(args_dict, ensure_ascii=False)
|
||||
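Reviewer note: GLM-4.5 emits tool calls as an XML-ish block that this parser converts to OpenAI-style JSON arguments. A standalone illustration using the same argument regex (the standard-library `re` module suffices here; the tool name and arguments are made up):

```python
import json
import re

arg_regex = re.compile(
    r"<arg_key>(?P<key>[^<]+)</arg_key>\s*<arg_value>(?P<value>[^<]*)</arg_value>",
    re.DOTALL)

model_output = (
    "<tool_call>get_weather\n"
    "<arg_key>city</arg_key><arg_value>Beijing</arg_value>\n"
    "<arg_key>unit</arg_key><arg_value>celsius</arg_value>\n"
    "</tool_call>")

args = {k.strip(): v.strip() for k, v in arg_regex.findall(model_output)}
print(json.dumps(args, ensure_ascii=False))
# {"city": "Beijing", "unit": "celsius"}
```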
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_calls_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
try:
|
||||
# Find all tool calls in the output
|
||||
function_call_matches = self.tool_call_regex.findall(model_output)
|
||||
|
||||
logger.debug("function_call_matches: %s", function_call_matches)
|
||||
|
||||
if not function_call_matches:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
for i, match in enumerate(function_call_matches):
|
||||
function_name, function_args_xml = match
|
||||
function_name = function_name.strip()
|
||||
|
||||
# Parse XML arguments to JSON
|
||||
function_args_json = self._parse_arguments(function_args_xml)
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=f"call_{i}",
|
||||
type='function',
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=function_args_json),
|
||||
))
|
||||
|
||||
# Extract content before the first tool call
|
||||
content = model_output[:model_output.find(self.
|
||||
tool_calls_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=bool(tool_calls),
|
||||
tool_calls=tool_calls,
|
||||
content=content.strip() if content.strip() else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
# check to see if we should be streaming a tool call - is there a tool call start token in the tokens generated so far?
|
||||
if self.tool_call_start_token_id not in current_token_ids:
|
||||
logger.debug("No tool call tokens found!")
|
||||
return DeltaMessage(content=delta_text)
|
||||
delta_text = delta_text.replace(self.tool_calls_start_token,
|
||||
"").replace(self.tool_call_end_token,
|
||||
"")
|
||||
try:
|
||||
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
prev_tool_start_count = previous_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
prev_tool_end_count = previous_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
cur_tool_start_count = current_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
cur_tool_end_count = current_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
tool_call_portion = None
|
||||
text_portion = None
|
||||
|
||||
# case: if we're generating text, OR rounding out a tool call
|
||||
if (cur_tool_start_count == cur_tool_end_count
|
||||
and prev_tool_end_count == cur_tool_end_count
|
||||
and self.tool_call_end_token not in delta_text):
|
||||
logger.debug("Generating text content! skipping tool parsing.")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
if self.tool_call_end_token in delta_text:
|
||||
logger.debug("tool_call_end_token in delta_text")
|
||||
full_text = current_text + delta_text
|
||||
tool_call_portion = full_text.split(
|
||||
self.tool_call_start_token)[-1].split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
delta_text = delta_text.split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
text_portion = delta_text.split(
|
||||
self.tool_call_end_token)[-1].lstrip()
|
||||
|
||||
# case -- we're starting a new tool call
|
||||
if (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count > prev_tool_start_count):
|
||||
if len(delta_token_ids) > 1:
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
else:
|
||||
tool_call_portion = None
|
||||
delta = None
|
||||
|
||||
text_portion = None
|
||||
|
||||
# set cursors and state appropriately
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("Starting on a new tool %s", self.current_tool_id)
|
||||
|
||||
# case -- we're updating an existing tool call
|
||||
elif (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count == prev_tool_start_count):
|
||||
|
||||
# get the portion of the text that's the tool call
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
text_portion = None
|
||||
|
||||
# case -- the current tool call is being closed.
|
||||
elif (cur_tool_start_count == cur_tool_end_count
|
||||
and cur_tool_end_count >= prev_tool_end_count):
|
||||
if self.prev_tool_call_arr is None or len(
|
||||
self.prev_tool_call_arr) == 0:
|
||||
logger.debug(
|
||||
"attempting to close tool call, but no tool call")
|
||||
return None
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
if diff:
|
||||
diff = (diff.encode("utf-8").decode("unicode_escape")
|
||||
if isinstance(diff, str) else diff)
|
||||
if '"}' not in delta_text:
|
||||
return None
|
||||
end_loc = delta_text.rindex('"}')
|
||||
diff = delta_text[:end_loc] + '"}'
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not "
|
||||
"been streamed yet: %s",
|
||||
diff,
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += diff
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(exclude_none=True),
|
||||
)
|
||||
])
|
||||
|
||||
# case -- otherwise we're just generating text
|
||||
else:
|
||||
text = delta_text.replace(self.tool_call_start_token, "")
|
||||
text = text.replace(self.tool_call_end_token, "")
|
||||
delta = DeltaMessage(tool_calls=[], content=text)
|
||||
return delta
|
||||
|
||||
current_tool_call = dict()
|
||||
if tool_call_portion:
|
||||
current_tool_call_matches = (
|
||||
self.stream_tool_call_portion_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_matches:
|
||||
tool_id, tool_args = (current_tool_call_matches.groups())
|
||||
tool_name = tool_id.split('.')[1].split(':')[0]
|
||||
current_tool_call['id'] = tool_id
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = tool_args
|
||||
else:
|
||||
current_tool_call_name_matches = (
|
||||
self.stream_tool_call_name_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_name_matches:
|
||||
tool_id_str, = current_tool_call_name_matches.groups()
|
||||
tool_name = tool_id_str.split('.')[1].split(':')[0]
|
||||
current_tool_call['id'] = tool_id_str
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = ""
|
||||
else:
|
||||
logger.debug("Not enough tokens to parse the tool call yet")
|
||||
return None
|
||||
|
||||
# case - we haven't sent the tool name yet. If it's available, send
|
||||
# it. otherwise, wait until it's available.
|
||||
if not self.current_tool_name_sent:
|
||||
if current_tool_call is None:
|
||||
return None
|
||||
function_name: Union[str, None] = current_tool_call.get("name")
|
||||
tool_id = current_tool_call.get("id")
|
||||
if function_name:
|
||||
self.current_tool_name_sent = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
type="function",
|
||||
id=tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
else:
|
||||
return None
|
||||
|
||||
# case -- otherwise, send the tool call delta
|
||||
|
||||
# if the tool call portion is None, send the delta as text
|
||||
if tool_call_portion is None:
|
||||
# if there's text but not tool calls, send that -
|
||||
# otherwise None to skip chunk
|
||||
delta = (DeltaMessage(
|
||||
content=delta_text) if text_portion is not None else None)
|
||||
return delta
|
||||
|
||||
# now, the nitty-gritty of tool calls
|
||||
# now we have the portion to parse as tool call.
|
||||
|
||||
logger.debug("Trying to parse current tool call with ID %s",
|
||||
self.current_tool_id)
|
||||
|
||||
# if we're starting a new tool call, push an empty object in as
|
||||
# a placeholder for the arguments
|
||||
if len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
|
||||
# main logic for tool parsing here - compare prev. partially-parsed
|
||||
# JSON to the current partially-parsed JSON
|
||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
||||
logger.debug("against new ones: %s", cur_arguments)
|
||||
|
||||
# case -- no arguments have been created yet. skip sending a delta.
|
||||
if not cur_arguments and not prev_arguments:
|
||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
||||
delta = None
|
||||
|
||||
# case -- prev arguments are defined, but none are now.
|
||||
# probably impossible, but not a fatal error - just keep going
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error("should be impossible to have arguments reset "
|
||||
"mid-call. skipping streaming anything.")
|
||||
delta = None
|
||||
|
||||
# case -- we now have the first info about arguments available from
|
||||
# autocompleting the JSON
|
||||
elif cur_arguments and not prev_arguments:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=cur_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
if (isinstance(delta_text, str)
|
||||
and cur_arguments != prev_arguments
|
||||
and len(cur_arguments) > len(prev_arguments)
|
||||
and cur_arguments.startswith(prev_arguments)):
|
||||
delta_arguments = cur_arguments[len(prev_arguments):]
|
||||
logger.debug("got diff %s", delta_text)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
||||
self.prev_tool_call_arr[
|
||||
self.current_tool_id] = current_tool_call
|
||||
else:
|
||||
self.prev_tool_call_arr.append(current_tool_call)
|
||||
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
return None # do not stream a delta. skip this token ID.
|
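# Illustrative sketch (hypothetical helper, not from the diff above): the core
# of the streaming path is a prefix diff between the arguments streamed so far
# and the newly parsed arguments; only the unsent suffix is emitted as a delta.
from typing import Optional


def argument_delta(prev_arguments: str, cur_arguments: str) -> Optional[str]:
    # Mirrors the "update to existing arguments" case above.
    if (cur_arguments != prev_arguments
            and len(cur_arguments) > len(prev_arguments)
            and cur_arguments.startswith(prev_arguments)):
        return cur_arguments[len(prev_arguments):]
    return None


assert argument_delta('{"city": "Pa', '{"city": "Paris"}') == 'ris"}'
assert argument_delta('{"city": "Paris"}', '{"city": "Paris"}') is None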
@ -13,7 +13,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@ -33,6 +33,8 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
LoaderFunction, composed_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.mamba_cache import MambaCacheParams
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata
|
||||
|
||||
# Added by the IBM Team, 2024
|
||||
@ -424,14 +426,36 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
conv_state: torch.Tensor,
|
||||
ssm_state: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
):
|
||||
pass
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if not envs.VLLM_USE_V1:
|
||||
CustomOp.forward(self, hidden_states, output, mamba_cache_params,
|
||||
mamba2_metadata, mup_vector)
|
||||
else:
|
||||
torch.ops.vllm.mamba_mixer2(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
mup_vector,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
@ -517,6 +541,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
|
||||
has_prefill = num_prefills > 0
|
||||
has_decode = num_decodes > 0
|
||||
num_actual_tokens = num_prefill_tokens + num_decodes
|
||||
|
||||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||
# Separate prefill and decode by splitting varlen input
|
||||
@ -524,18 +549,18 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||
if envs.VLLM_USE_V1:
|
||||
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
|
||||
hidden_states_B_C,
|
||||
hidden_states_B_C[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
dt_d, dt_p = torch.split(
|
||||
dt,
|
||||
dt[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor,
|
||||
state_indices_tensor[:num_actual_tokens],
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
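# Illustrative sketch (hypothetical sizes, not from the diff above): in V1 the
# batch is laid out decode-first, so slicing the first num_actual_tokens rows
# and splitting by [num_decodes, num_prefill_tokens] separates the two groups.
import torch

num_decodes, num_prefill_tokens = 2, 3
x = torch.arange(5).unsqueeze(-1).float()  # 5 == num_actual_tokens rows
decode_rows, prefill_rows = torch.split(x, [num_decodes, num_prefill_tokens],
                                        dim=0)
print(decode_rows.shape, prefill_rows.shape)  # torch.Size([2, 1]) torch.Size([3, 1])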
@ -696,11 +721,10 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
# GatedRMSNorm internally applying SiLU to the gate
|
||||
# SiLU is applied internally before normalization, unlike standard
|
||||
# norm usage
|
||||
hidden_states = self.norm(hidden_states, gate)
|
||||
hidden_states = self.norm(hidden_states, gate[:num_actual_tokens])
|
||||
|
||||
# 5. Final linear projection
|
||||
out, _ = self.out_proj(hidden_states)
|
||||
return out
|
||||
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
return get_mamba_state_shape(
|
||||
@ -712,3 +736,36 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
state_size=self.ssm_state_size,
|
||||
conv_kernel=self.conv_kernel_size,
|
||||
)
|
||||
|
||||
|
||||
def mamba_mixer2(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.forward_cuda(hidden_states=hidden_states,
|
||||
output=output,
|
||||
mamba_cache_params=None,
|
||||
mamba2_metadata=None,
|
||||
mup_vector=mup_vector)
|
||||
|
||||
|
||||
def mamba_mixer2_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="mamba_mixer2",
|
||||
op_func=mamba_mixer2,
|
||||
mutates_args=["output"],
|
||||
fake_impl=mamba_mixer2_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
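# Illustrative sketch (hypothetical toy op, not from the diff above): the same
# registration pattern used for mamba_mixer2. The real op above mutates
# `output` in place, so a fake impl that returns nothing is enough for
# torch.compile tracing, and the op becomes callable as torch.ops.vllm.<name>.
import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def scale_into(x: torch.Tensor, out: torch.Tensor, factor: float) -> None:
    out.copy_(x * factor)  # write the result into the preallocated buffer


def scale_into_fake(x: torch.Tensor, out: torch.Tensor, factor: float) -> None:
    return  # shape/dtype of `out` are already known at trace time


direct_register_custom_op(
    op_name="scale_into_example",  # hypothetical op name
    op_func=scale_into,
    mutates_args=["out"],
    fake_impl=scale_into_fake,
    dispatch_key=current_platform.dispatch_key,
)
# Usage: torch.ops.vllm.scale_into_example(x, out, 2.0)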
@ -11,6 +11,7 @@ from transformers import BambaConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
@ -122,11 +123,10 @@ class BambaMixerDecoderLayer(nn.Module):
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
hidden_states = self.mamba(hidden_states, mamba_cache_params,
|
||||
mamba2_metadata)
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.pre_ff_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states, residual = self.pre_ff_layernorm(output, residual)
|
||||
hidden_states = self.feed_forward(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
@ -169,7 +169,7 @@ class BambaAttentionDecoderLayer(nn.Module):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
if hasattr(config, "partial_rotary_factor"):
|
||||
rotary_dim = self.head_dim * config.partial_rotary_factor
|
||||
rotary_dim = int(self.head_dim * config.partial_rotary_factor)
|
||||
elif hasattr(config, "attn_rotary_emb"):
|
||||
rotary_dim = config.attn_rotary_emb # for backward compatibility
|
||||
else:
|
||||
@ -258,6 +258,7 @@ ALL_DECODER_LAYER_TYPES = {
|
||||
}
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class BambaModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
@ -10,6 +10,7 @@ from transformers import FalconH1Config
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
@ -179,13 +180,15 @@ class FalconH1SSMDecoderLayer(nn.Module):
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
**kwargs,
|
||||
):
|
||||
hidden_states = self.mamba(
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mamba(
|
||||
hidden_states,
|
||||
output,
|
||||
mamba_cache_params,
|
||||
mamba2_metadata=mamba2_metadata,
|
||||
mup_vector=self.mup_vector,
|
||||
)
|
||||
return hidden_states, residual
|
||||
return output, residual
|
||||
|
||||
|
||||
class FalconH1AttentionDecoderLayer(nn.Module):
|
||||
@ -398,6 +401,7 @@ class FalconH1ParallelHybrid(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class FalconH1Model(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
vllm/model_executor/models/glm4_moe.py (new file, 685 lines)
@ -0,0 +1,685 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2025 The ZhipuAI Team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GLM-4.5 model compatible with HuggingFace weights."""
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Glm4MoeMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Glm4MoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
|
||||
self.ep_group = get_ep_group().device_group
|
||||
self.ep_rank = self.ep_group.rank()
|
||||
self.ep_size = self.ep_group.size()
|
||||
self.n_routed_experts: int = config.n_routed_experts
|
||||
self.n_shared_experts: int = config.n_shared_experts
|
||||
|
||||
if config.hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
config.n_routed_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate")
|
||||
|
||||
# noaux_tc is not yet set in the new transformers config
|
||||
self.gate.e_score_correction_bias = (nn.Parameter(
|
||||
torch.empty(config.n_routed_experts)))
|
||||
|
||||
# Load balancing settings.
|
||||
vllm_config = get_current_vllm_config()
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.enable_eplb = enable_eplb
|
||||
|
||||
self.n_redundant_experts = parallel_config.num_redundant_experts
|
||||
self.n_logical_experts = self.n_routed_experts
|
||||
self.n_physical_experts = (self.n_logical_experts +
|
||||
self.n_redundant_experts)
|
||||
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
|
||||
|
||||
self.physical_expert_start = (self.ep_rank *
|
||||
self.n_local_physical_experts)
|
||||
self.physical_expert_end = (self.physical_expert_start +
|
||||
self.n_local_physical_experts)
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func="sigmoid",
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = (config.moe_intermediate_size *
|
||||
config.n_shared_experts)
|
||||
self.shared_experts = Glm4MoeMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=self.experts.must_reduce_shared_expert_outputs(
|
||||
),
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits) * self.routed_scaling_factor
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = (
|
||||
self.experts.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states))
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
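# Illustrative sketch (hypothetical sizes, not from the diff above): a worked
# example of the expert-parallel bookkeeping in Glm4MoE.__init__, assuming
# 128 routed experts, 16 redundant experts, and an EP world size of 8.
n_logical_experts = 128
n_redundant_experts = 16
ep_size = 8
n_physical_experts = n_logical_experts + n_redundant_experts  # 144
n_local_physical_experts = n_physical_experts // ep_size  # 18 per rank
ep_rank = 2
physical_expert_start = ep_rank * n_local_physical_experts  # 36
physical_expert_end = physical_expert_start + n_local_physical_experts  # 54
# Rank 2 therefore owns physical experts [36, 54).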
class Glm4MoeAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 131072,
|
||||
head_dim: Optional[int] = None,
|
||||
rms_norm_eps: float = 1e-05,
|
||||
qkv_bias: bool = False,
|
||||
use_qk_norm: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.use_qk_norm = use_qk_norm
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj")
|
||||
|
||||
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
|
||||
partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
partial_rotary_factor=partial_rotary_factor,
|
||||
)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.use_qk_norm:
|
||||
q = self.q_norm(q.reshape(-1, self.num_heads,
|
||||
self.head_dim)).reshape(q.shape)
|
||||
k = self.k_norm(k.reshape(-1, self.num_kv_heads,
|
||||
self.head_dim)).reshape(k.shape)
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class Glm4MoeDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
131072)
|
||||
# DecoderLayers are created with `make_layers` which passes the prefix
|
||||
# with the layer's index.
|
||||
layer_idx = int(prefix.split(sep='.')[-1])
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.self_attn = Glm4MoeAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
head_dim=config.head_dim,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
qkv_bias=config.attention_bias,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
use_qk_norm=config.use_qk_norm,
|
||||
)
|
||||
|
||||
if (config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace):
|
||||
self.mlp = Glm4MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
enable_eplb=enable_eplb,
|
||||
)
|
||||
else:
|
||||
self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(positions=positions,
|
||||
hidden_states=hidden_states)
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Glm4MoeModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
enable_eplb = vllm_config.parallel_config.enable_eplb
|
||||
self.config = config
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Glm4MoeDecoderLayer(
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
enable_eplb=enable_eplb,
|
||||
),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype,
|
||||
device: torch.device) -> IntermediateTensors:
|
||||
return IntermediateTensors({
|
||||
"hidden_states":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
"residual":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is not None:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if (("mlp.experts." in name) and name not in params_dict):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
is_expert_weight = False
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
# This is an expert weight and should not be loaded
# as a regular (non-expert) weight later.
|
||||
is_expert_weight = True
|
||||
|
||||
# Do not modify `name` since the loop may continue here
|
||||
# Instead, create a new variable
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
|
||||
if is_pp_missing_parameter(name_mapped, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name_mapped]
|
||||
# We should ask the weight loader to return success or not
|
||||
# here since otherwise we may skip experts with other
|
||||
# available replicas.
|
||||
weight_loader = typing.cast(Callable[..., bool],
|
||||
param.weight_loader)
|
||||
success = weight_loader(param,
|
||||
loaded_weight,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True)
|
||||
if success:
|
||||
name = name_mapped
|
||||
break
|
||||
else:
|
||||
if is_expert_weight:
|
||||
# We've checked that this is an expert weight
|
||||
# However it's not mapped locally to this rank
|
||||
# So we simply skip it
|
||||
continue
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Glm4MoeForCausalLM(nn.Module, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Glm4MoeModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
self.expert_weights = []
|
||||
|
||||
# Set MoE hyperparameters
|
||||
self.num_moe_layers = (config.num_hidden_layers -
|
||||
config.first_k_dense_replace)
|
||||
self.num_expert_groups = config.n_group
|
||||
|
||||
self.moe_layers: list[FusedMoE] = []
|
||||
for layer in self.model.layers:
|
||||
assert isinstance(layer, Glm4MoeDecoderLayer)
|
||||
if isinstance(layer.mlp, Glm4MoE):
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
# Pick the last layer, since the earlier ones may be dense layers.
|
||||
example_moe = typing.cast(
|
||||
Glm4MoE, self.model.layers[config.num_hidden_layers - 1].mlp)
|
||||
self.num_logical_experts = example_moe.n_logical_experts
|
||||
self.num_physical_experts = example_moe.n_physical_experts
|
||||
self.num_local_physical_experts = example_moe.n_local_physical_experts
|
||||
self.num_routed_experts = example_moe.n_routed_experts
|
||||
self.num_shared_experts = example_moe.n_shared_experts
|
||||
self.num_redundant_experts = example_moe.n_redundant_experts
|
||||
|
||||
def set_eplb_state(
|
||||
self,
|
||||
expert_load_view: torch.Tensor,
|
||||
logical_to_physical_map: torch.Tensor,
|
||||
logical_replica_count: torch.Tensor,
|
||||
) -> None:
|
||||
for layer_idx, layer in enumerate(self.moe_layers):
|
||||
# Register the expert weights.
|
||||
self.expert_weights.append(layer.get_expert_weights())
|
||||
layer.set_eplb_state(
|
||||
moe_layer_idx=layer_idx,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
|
||||
weight_name: str) -> Optional[int]:
|
||||
if hasattr(config,
|
||||
"num_nextn_predict_layers") and (config.num_nextn_predict_layers
|
||||
> 0):
|
||||
layer_idx = config.num_hidden_layers
|
||||
for i in range(config.num_nextn_predict_layers):
|
||||
if f"layers.{layer_idx+i}." in weight_name:
|
||||
return layer_idx + i
|
||||
return None
|
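# Illustrative sketch (hypothetical config values, not from the diff above):
# with num_hidden_layers=46 and num_nextn_predict_layers=1, the helper above
# flags MTP weights, which live one layer index past the last decoder layer;
# SimpleNamespace stands in for the HF config object here.
from types import SimpleNamespace

cfg = SimpleNamespace(num_hidden_layers=46, num_nextn_predict_layers=1)
assert get_spec_layer_idx_from_weight_name(
    cfg, "model.layers.46.eh_proj.weight") == 46
assert get_spec_layer_idx_from_weight_name(
    cfg, "model.layers.10.mlp.gate.weight") is None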
vllm/model_executor/models/glm4_moe_mtp.py (new file, 307 lines)
@ -0,0 +1,307 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2025 The ZhipuAI Team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GLM-4.5 MTP model compatible with HuggingFace weights."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import maybe_prefix
|
||||
|
||||
|
||||
class SharedHead(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return self.norm(hidden_states)
|
||||
|
||||
|
||||
class Glm4MoeMultiTokenPredictorLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.shared_head = SharedHead(config=config, quant_config=quant_config)
|
||||
self.mtp_block = Glm4MoeDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds[positions == 0] = 0
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
hidden_states = self.eh_proj(
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
hidden_states, residual = self.mtp_block(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Glm4MoeMultiTokenPredictor(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
self.num_mtp_layers = config.num_nextn_predict_layers
|
||||
# to map the exact layer index from weights
|
||||
self.layers = torch.nn.ModuleDict({
|
||||
str(idx):
|
||||
Glm4MoeMultiTokenPredictorLayer(
|
||||
config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=vllm_config.quant_config,
|
||||
)
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
})
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
|
||||
input_ids,
|
||||
positions,
|
||||
previous_hidden_states,
|
||||
inputs_embeds,
|
||||
current_step_idx,
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
mtp_layer = self.layers[str(self.mtp_start_layer_idx +
|
||||
current_step_idx)]
|
||||
logits = self.logits_processor(mtp_layer.shared_head.head,
|
||||
mtp_layer.shared_head(hidden_states),
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
||||
class Glm4MoeMTP(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.model = Glm4MoeMultiTokenPredictor(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions,
|
||||
previous_hidden_states, inputs_embeds,
|
||||
spec_step_idx)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
spec_step_idx: int = 0,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.model.compute_logits(hidden_states, sampling_metadata,
|
||||
spec_step_idx)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is None:
|
||||
continue
|
||||
name = self._rewrite_spec_layer_name(spec_layer, name)
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if (("mlp.experts." in name) and name not in params_dict):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
# According to the DeepSeek-V3 technical report, MTP modules
# share the embedding layer. We only load the first set of weights.
|
||||
if (spec_layer != self.model.mtp_start_layer_idx
|
||||
and ".layers" not in name):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
|
||||
"""
|
||||
Rewrite the weight name to match the format of the original model.
|
||||
Add .mtp_block for modules in transformer layer block for spec layer
|
||||
and rename shared layer weights to be top level.
|
||||
"""
|
||||
spec_layer_weight_names = [
|
||||
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
|
||||
]
|
||||
shared_weight_names = ["embed_tokens"]
|
||||
spec_layer_weight = False
|
||||
shared_weight = False
|
||||
for weight_name in spec_layer_weight_names:
|
||||
if weight_name in name:
|
||||
spec_layer_weight = True
|
||||
if weight_name in shared_weight_names:
|
||||
shared_weight = True
|
||||
break
|
||||
if not spec_layer_weight:
|
||||
# treat rest weights as weights for transformer layer block
|
||||
name = name.replace(f"model.layers.{spec_layer}.",
|
||||
f"model.layers.{spec_layer}.mtp_block.")
|
||||
elif shared_weight:
|
||||
# treat shared weights as top level weights
|
||||
name = name.replace(f"model.layers.{spec_layer}.", "model.")
|
||||
return name
|
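# Illustrative sketch (hypothetical weight names, not from the diff above):
# what _rewrite_spec_layer_name produces for spec layer 46. `self` is unused
# by the method, so None suffices for the demonstration.
rewrite = Glm4MoeMTP._rewrite_spec_layer_name
assert rewrite(None, 46, "model.layers.46.self_attn.qkv_proj.weight") == \
    "model.layers.46.mtp_block.self_attn.qkv_proj.weight"
assert rewrite(None, 46, "model.layers.46.embed_tokens.weight") == \
    "model.embed_tokens.weight"
assert rewrite(None, 46, "model.layers.46.shared_head.norm.weight") == \
    "model.layers.46.shared_head.norm.weight"  # spec-layer-only, unchanged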
@ -11,6 +11,7 @@ from transformers import GraniteMoeHybridConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
@ -104,9 +105,9 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.mamba(hidden_states, mamba_cache_params,
|
||||
mamba2_metadata)
|
||||
hidden_states = residual + hidden_states * self.residual_multiplier
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
|
||||
hidden_states = residual + output * self.residual_multiplier
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
@ -307,6 +308,7 @@ ALL_DECODER_LAYER_TYPES = {
|
||||
}
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class GraniteMoeHybridModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
@ -10,6 +10,7 @@ from transformers import MambaConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
@ -79,11 +80,12 @@ class Mamba2DecoderLayer(nn.Module):
|
||||
else:
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
hidden_states = self.mixer(hidden_states, mamba_cache_params,
|
||||
mamba2_metadata)
|
||||
return hidden_states, residual
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
|
||||
return output, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Mamba2Model(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
@ -25,6 +25,7 @@ from torch import nn
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
@ -172,9 +173,9 @@ class NemotronHMambaDecoderLayer(nn.Module):
|
||||
else:
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
hidden_states = self.mixer(hidden_states, mamba_cache_params,
|
||||
mamba2_metadata)
|
||||
return hidden_states, residual
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
|
||||
return output, residual
|
||||
|
||||
|
||||
class NemotronHAttention(nn.Module):
|
||||
@ -292,6 +293,7 @@ ALL_DECODER_LAYER_TYPES = {
|
||||
}
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class NemotronHModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
@ -1,465 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
def load_column_parallel_weight(param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor):
|
||||
tp = get_tensor_model_parallel_world_size()
|
||||
rk = get_tensor_model_parallel_rank()
|
||||
assert param.size(0) * tp == loaded_weight.size(0)
|
||||
s = rk * param.size(0)
|
||||
e = (rk + 1) * param.size(0)
|
||||
loaded_weight = loaded_weight[s:e]
|
||||
assert param.shape == loaded_weight.shape
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class HeadMajorQKVParallelLinear(QKVParallelLinear):
|
||||
|
||||
def weight_loader(self, param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor):
|
||||
return load_column_parallel_weight(param, loaded_weight)
|
||||
|
||||
|
||||
class HeadMajorColumnParallelLinear(MergedColumnParallelLinear):
|
||||
|
||||
def weight_loader(self, param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor):
|
||||
return load_column_parallel_weight(param, loaded_weight)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
|
||||
def quick_gelu(x):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
|
||||
def gegelu(input, limit: Optional[float] = None):
|
||||
a_gelu, a_linear = input[..., ::2], input[..., 1::2]
|
||||
if limit is not None:
|
||||
a_gelu = torch.where(torch.isinf(a_gelu), a_gelu,
|
||||
a_gelu.clamp(min=None, max=limit))
|
||||
a_linear = torch.where(
|
||||
torch.isinf(a_linear),
|
||||
a_linear,
|
||||
a_linear.clamp(min=-limit, max=limit),
|
||||
)
|
||||
out_gelu = quick_gelu(a_gelu)
|
||||
return out_gelu * (a_linear + 1)
|
||||
|
||||
|
||||
class Phi3SmallMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert (self.config.hidden_act == "gegelu"
|
||||
), "Only `gegelu` is supported for the 4.7 series of models .."
|
||||
self.hidden_size = config.hidden_size
|
||||
self.gegelu_limit = config.gegelu_limit
|
||||
self.intermediate_size = config.intermediate_size
|
||||
|
||||
self.up_proj = HeadMajorColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
2 * [self.intermediate_size],
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
self.intermediate_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.up_proj(x)
|
||||
x = gegelu(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Phi3SmallSelfAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.config = config
|
||||
self.sparse_block_size = config.blocksparse_block_size
|
||||
self.homo_heads = config.blocksparse_homo_head_pattern
|
||||
self.local_blocks = config.blocksparse_num_local_blocks
|
||||
self.vert_stride = config.blocksparse_vert_stride
|
||||
|
||||
assert (config.blocksparse_block_size ==
|
||||
config.blocksparse_triton_kernel_block_size)
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
# Number of Query Heads
|
||||
self.num_heads = config.num_attention_heads
|
||||
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
# Number of total Key Value Heads before tensor parallel
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_q_per_kv = self.num_heads // self.num_key_value_heads
|
||||
if self.tp_size > 1:
|
||||
assert self.num_key_value_heads % self.tp_size == 0
|
||||
self.num_kv_heads_per_partition = max(
|
||||
1, self.num_key_value_heads // self.tp_size)
|
||||
self.num_heads_per_partition = self.num_heads // self.tp_size
|
||||
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_embedding_base = config.rope_embedding_base
|
||||
self.rope_position_scale = config.rope_position_scale
|
||||
self.is_causal = True
|
||||
|
||||
norm_factor = None
|
||||
if config.mup_use_scaling:
|
||||
norm_factor = self.head_dim / config.mup_attn_multiplier
|
||||
else:
|
||||
norm_factor = math.sqrt(self.head_dim)
|
||||
self.scale = 1 / norm_factor
|
||||
|
||||
self.query_key_value = HeadMajorQKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
self.num_key_value_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.dense = RowParallelLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
|
||||
if getattr(self.config, "rope_scaling", None) is not None:
|
||||
rope_scaling = self.config.rope_scaling
|
||||
for key in rope_scaling:
|
||||
if isinstance(rope_scaling[key], list):
|
||||
rope_scaling[key] = tuple(rope_scaling[key])
|
||||
|
||||
if "factor" not in rope_scaling:
|
||||
rope_scaling["factor"] = self.rope_position_scale
|
||||
else:
|
||||
rope_scaling = {
|
||||
"rope_type": "linear",
|
||||
"factor": self.rope_position_scale,
|
||||
}
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=self.rope_embedding_base,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
|
||||
# blocksparse params
|
||||
self.blocksparse_block_size = config.blocksparse_block_size
|
||||
self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks
|
||||
self.blocksparse_vert_stride = config.blocksparse_vert_stride
|
||||
|
||||
use_dense_attn = (getattr(self.config,
|
||||
"dense_attention_every_n_layers", None)
|
||||
and (self.layer_idx + 1) %
|
||||
self.config.dense_attention_every_n_layers == 0)
|
||||
|
||||
bs_params = None
|
||||
if not use_dense_attn:
|
||||
bs_params = {
|
||||
'max_seqlen': self.max_position_embeddings,
|
||||
'num_heads': self.num_heads_per_partition,
|
||||
"num_kv_heads": self.num_kv_heads_per_partition,
|
||||
"block_size": self.sparse_block_size,
|
||||
"local_blocks": self.local_blocks,
|
||||
"vert_stride": self.vert_stride,
|
||||
"homo_head": self.homo_heads
|
||||
}
|
||||
|
||||
self.attn = Attention(self.num_heads_per_partition,
|
||||
self.head_dim,
|
||||
self.scale,
|
||||
num_kv_heads=self.num_kv_heads_per_partition,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
blocksparse_params=bs_params,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[tuple[torch.Tensor]]]:
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
|
||||
qkv = qkv.view(qkv.shape[:-1] +
|
||||
(-1, (self.num_q_per_kv + 2), self.head_dim))
|
||||
q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2)
|
||||
|
||||
# NOTE: flattening to 2D is currently required by the rotary embedding,
|
||||
# though it does not strictly have to be.
|
||||
# TODO: allow 3D QK for rotary forward
|
||||
q = q.reshape(-1, self.head_dim * self.num_heads_per_partition)
|
||||
k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partition)
|
||||
v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partition)
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.dense(attn_output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Phi3SmallDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = Phi3SmallSelfAttention(config,
|
||||
layer_idx,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.mlp = Phi3SmallMLP(config, quant_config)
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Phi3SmallModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||
config.hidden_size)
|
||||
self.mup_embedding_multiplier = config.mup_embedding_multiplier
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Phi3SmallDecoderLayer(config,
|
||||
int(prefix.split('.')[-1]),
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(["hidden_states"],
|
||||
config.hidden_size))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: Optional[torch.LongTensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
if (self.mup_embedding_multiplier is not None
|
||||
and self.mup_embedding_multiplier > 0.0):
|
||||
hidden_states = hidden_states * self.mup_embedding_multiplier
|
||||
else:
|
||||
assert intermediate_tensors
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Phi3SmallForCausalLM(nn.Module, SupportsPP):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_suffix={"rotary_emb.inv_freq": None})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Phi3SmallModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.vocab_size = config.vocab_size
|
||||
self.mup_width_multiplier = config.mup_width_multiplier
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
# tokens in tiktoken but not used
|
||||
if hasattr(config, 'dummy_token_indices'):
|
||||
device = self.lm_head.weight.device
|
||||
self.register_buffer('dummy_token_indices',
|
||||
torch.LongTensor(
|
||||
config.dummy_token_indices).to(device),
|
||||
persistent=False)
|
||||
else:
|
||||
self.dummy_token_indices = None
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, value):
|
||||
self.lm_head = value
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
if self.dummy_token_indices is not None and logits is not None:
|
||||
logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
|
||||
logits = logits / self.mup_width_multiplier
|
||||
return logits
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: Optional[torch.LongTensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
output_hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
output_hidden_states = output_hidden_states
|
||||
return output_hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head.weight"]
|
||||
if self.config.tie_word_embeddings else None))
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
@ -67,6 +67,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
|
||||
"Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
|
||||
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
|
||||
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
|
||||
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
|
||||
@ -110,7 +111,6 @@ _TEXT_GENERATION_MODELS = {
|
||||
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
|
||||
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
||||
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
||||
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
|
||||
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
|
||||
"Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
|
||||
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
|
||||
@ -245,6 +245,7 @@ _SPECULATIVE_DECODING_MODELS = {
|
||||
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
|
||||
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
||||
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
|
||||
"MedusaModel": ("medusa", "Medusa"),
|
||||
# Temporarily disabled.
|
||||
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
|
||||
|
@ -17,6 +17,7 @@ from transformers import Zamba2Config
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import get_forward_context
|
||||
@ -548,14 +549,16 @@ class Zamba2MambaDecoderLayer(nn.Module):
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Process through Mamba mixer
|
||||
hidden_states = self.mamba(
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mamba(
|
||||
hidden_states,
|
||||
output,
|
||||
mamba_cache_params=mamba_cache_params,
|
||||
mamba2_metadata=mamba2_metadata,
|
||||
)
|
||||
|
||||
# residual connection after mamba
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = residual + output
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -646,6 +649,7 @@ class Zamba2HybridLayer(nn.Module):
|
||||
return layer_outputs
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Zamba2Model(nn.Module):
|
||||
"""Core Zamba2 model combining transformer and Mamba architectures.
|
||||
|
||||
|
@ -57,7 +57,6 @@ class _Backend(enum.Enum):
|
||||
PALLAS = enum.auto()
|
||||
PALLAS_VLLM_V1 = enum.auto()
|
||||
IPEX = enum.auto()
|
||||
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
|
||||
DUAL_CHUNK_FLASH_ATTN = enum.auto()
|
||||
DIFFERENTIAL_FLASH_ATTN = enum.auto()
|
||||
NO_ATTENTION = enum.auto()
|
||||
|
@ -35,7 +35,9 @@ class TpuPlatform(Platform):
|
||||
device_control_env_var: str = "TPU_VISIBLE_CHIPS"
|
||||
simple_compile_backend: str = "openxla"
|
||||
|
||||
supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]
|
||||
supported_quantization: list[str] = [
|
||||
"fp8", "tpu_int8", "compressed-tensors"
|
||||
]
|
||||
|
||||
additional_env_vars: list[str] = [
|
||||
"TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
|
||||
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
|
||||
from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
|
||||
from .granite_reasoning_parser import GraniteReasoningParser
|
||||
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
|
||||
from .qwen3_reasoning_parser import Qwen3ReasoningParser
|
||||
@ -14,4 +15,5 @@ __all__ = [
|
||||
"GraniteReasoningParser",
|
||||
"HunyuanA13BReasoningParser",
|
||||
"Qwen3ReasoningParser",
|
||||
"Glm4MoeModelReasoningParser",
|
||||
]
|
||||
|
151
vllm/reasoning/glm4_moe_reasoning_parser.py
Normal file
@ -0,0 +1,151 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ReasoningParserManager.register_module("glm4_moe")
|
||||
class Glm4MoeModelReasoningParser(ReasoningParser):
|
||||
"""
|
||||
Reasoning parser for the Glm4MoeModel model.
|
||||
|
||||
The Glm4MoeModel model uses <think>...</think> tokens to denote reasoning
|
||||
text within its output. The model provides a strict switch to disable
|
||||
reasoning output via the 'enable_thinking=False' parameter. This parser
|
||||
extracts the reasoning content enclosed by <think> and </think> tokens
|
||||
from the model's output.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
||||
super().__init__(tokenizer)
|
||||
self.think_start_token = "<think>"
|
||||
self.think_end_token = "</think>"
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ReasoningParser "
|
||||
"constructor during construction.")
|
||||
|
||||
self.think_start_token_id = self.vocab.get(self.think_start_token)
|
||||
self.think_end_token_id = self.vocab.get(self.think_end_token)
|
||||
if (self.think_start_token_id is None
|
||||
or self.think_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"Glm4MoeModel reasoning parser could not locate "
|
||||
"think start/end tokens in the tokenizer!")
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
return self.think_end_token_id in input_ids
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
"""
|
||||
Extract the content after the end tokens
|
||||
"""
|
||||
if self.think_end_token_id not in input_ids[:-1]:
|
||||
return []
|
||||
else:
|
||||
return input_ids[input_ids.index(self.think_end_token_id) + 1:]
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Extract reasoning content from a delta message.
|
||||
Handles streaming output where previous + delta = current.
|
||||
Uses token IDs for faster processing.
|
||||
For text <think>abc</think>xyz:
|
||||
- 'abc' goes to reasoning_content
|
||||
- 'xyz' goes to content
|
||||
"""
|
||||
# Skip single special tokens
|
||||
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
|
||||
self.think_start_token_id, self.think_end_token_id
|
||||
]):
|
||||
return None
|
||||
|
||||
if self.think_start_token_id in previous_token_ids:
|
||||
if self.think_end_token_id in delta_token_ids:
|
||||
# <think> in previous, </think> in delta,
|
||||
# extract reasoning content
|
||||
end_index = delta_text.find(self.think_end_token)
|
||||
reasoning_content = delta_text[:end_index]
|
||||
content = delta_text[end_index + len(self.think_end_token):]
|
||||
return DeltaMessage(reasoning_content=reasoning_content,
|
||||
content=content if content else None)
|
||||
elif self.think_end_token_id in previous_token_ids:
|
||||
# <think> in previous, </think> in previous,
|
||||
# reasoning content continues
|
||||
return DeltaMessage(content=delta_text)
|
||||
else:
|
||||
# <think> in previous, no </think> in previous or delta,
|
||||
# reasoning content continues
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
elif self.think_start_token_id in delta_token_ids:
|
||||
if self.think_end_token_id in delta_token_ids:
|
||||
# <think> in delta, </think> in delta, extract reasoning content
|
||||
start_index = delta_text.find(self.think_start_token)
|
||||
end_index = delta_text.find(self.think_end_token)
|
||||
reasoning_content = delta_text[start_index +
|
||||
len(self.think_start_token
|
||||
):end_index]
|
||||
content = delta_text[end_index + len(self.think_end_token):]
|
||||
return DeltaMessage(reasoning_content=reasoning_content,
|
||||
content=content if content else None)
|
||||
else:
|
||||
# <think> in delta, no </think> in delta,
|
||||
# reasoning content continues
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
else:
|
||||
# thinking is disabled, just content
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Extract reasoning content from the model output.
|
||||
|
||||
For text <think>abc</think>xyz:
|
||||
- 'abc' goes to reasoning_content
|
||||
- 'xyz' goes to content
|
||||
|
||||
Returns:
|
||||
tuple[Optional[str], Optional[str]]: reasoning content and content
|
||||
"""
|
||||
|
||||
# Check if the model output contains the <think> and </think> tokens.
|
||||
if (self.think_start_token not in model_output
|
||||
or self.think_end_token not in model_output):
|
||||
return None, model_output
|
||||
# Check if the <think> is present in the model output, remove it
|
||||
# if it is present.
|
||||
model_output_parts = model_output.partition(self.think_start_token)
|
||||
model_output = model_output_parts[2] if model_output_parts[
|
||||
1] else model_output_parts[0]
|
||||
# Check if the model output contains the </think> tokens.
|
||||
# If the end token is not found, return the model output as is.
|
||||
if self.think_end_token not in model_output:
|
||||
return None, model_output
|
||||
|
||||
# Extract reasoning content from the model output.
|
||||
reasoning_content, _, content = model_output.partition(
|
||||
self.think_end_token)
|
||||
|
||||
final_content = content or None
|
||||
return reasoning_content, final_content
|
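Illustrative note (not part of the upstream diff): the non-streaming path of the parser above boils down to two string partitions around the think tokens. A standalone sketch of that behaviour; the sample text is made up.

# Standalone sketch of the non-streaming extraction rule used above.
from typing import Optional

THINK_START = "<think>"
THINK_END = "</think>"

def split_reasoning(model_output: str) -> tuple[Optional[str], Optional[str]]:
    if THINK_START not in model_output or THINK_END not in model_output:
        return None, model_output
    # Drop everything up to and including <think> if it is present.
    _, sep, rest = model_output.partition(THINK_START)
    model_output = rest if sep else model_output
    if THINK_END not in model_output:
        return None, model_output
    reasoning, _, content = model_output.partition(THINK_END)
    return reasoning, content or None

# Examples:
# split_reasoning("<think>abc</think>xyz") -> ("abc", "xyz")
# split_reasoning("no reasoning here")     -> (None, "no reasoning here")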
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -443,7 +443,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@ -451,9 +450,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"Torch SPDA does not support block-sparse attention.")
|
||||
if logits_soft_cap is not None:
|
||||
logger.warning_once("Torch SPDA does not support logits soft cap. "
|
||||
"Outputs may be slightly off.")
|
||||
|
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with FlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -349,15 +349,11 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
|
@ -4,7 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
|
||||
@ -490,7 +490,6 @@ class FlashInferImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
|
@ -3,7 +3,7 @@
|
||||
"""Attention layer with FlashAttention."""
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
|
||||
@ -342,15 +342,10 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
# TODO we should support this :think
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
|
@ -190,7 +190,7 @@ return curr_o @ W_O
|
||||
import functools
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -754,7 +754,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
|
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -74,7 +74,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
@ -82,17 +81,14 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"CutlassMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
|
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -119,7 +119,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
@ -127,20 +126,17 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
assert is_flashmla_supported(), \
|
||||
"FlashMLA is not supported on this device"
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
|
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -167,7 +167,6 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
@ -175,20 +174,17 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
assert (num_heads == 16 or num_heads == 128), (
|
||||
f"Aiter MLA only supports 16 or 128 number of heads.\n"
|
||||
f"Provided {num_heads} number of heads.\n"
|
||||
"Try adjusting tensor_parallel_size value.")
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"Aiter MLA does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
from aiter import flash_attn_varlen_func
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
|
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -42,7 +42,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
@ -50,17 +49,14 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"TritonMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
|
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_xla.core.xla_builder as xb
|
||||
@ -24,6 +24,19 @@ logger = init_logger(__name__)
|
||||
# TPU requires the head size to be a multiple of 128.
|
||||
TPU_HEAD_SIZE_ALIGNMENT = 128
|
||||
|
||||
# Note: TPU can use fp8 as the KV-cache storage dtype, but it cannot convert
|
||||
# from uint8 to fp32 directly, which is why its dtype mapping differs from the GPU one.
|
||||
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float": torch.float,
|
||||
"fp8": torch.float8_e4m3fn,
|
||||
"fp8_e4m3": torch.float8_e4m3fn,
|
||||
"fp8_e5m2": torch.float8_e5m2,
|
||||
"int8": torch.int8,
|
||||
"uint8": torch.uint8,
|
||||
}
|
||||
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
|
||||
|
||||
@ -132,7 +145,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
@ -142,9 +154,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
logger.warning_once(
|
||||
"Using irope in Pallas is not supported yet, it will fall back "
|
||||
"to global attention for long context.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError("Paged attention Pallas kernel does "
|
||||
"not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
@ -156,10 +165,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError("Alibi slopes is not supported.")
|
||||
if kv_cache_dtype != "auto":
|
||||
raise NotImplementedError("FP8 KV cache dtype is not supported.")
|
||||
if blocksparse_params is not None:
|
||||
raise NotImplementedError("Blocksparse is not supported.")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
@ -167,6 +172,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
|
||||
self.kv_cache_quantized_dtype = None
|
||||
if kv_cache_dtype != "auto":
|
||||
self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
|
||||
kv_cache_dtype.lower().strip())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
@ -200,7 +210,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
output = torch.ones_like(query)
|
||||
return output
|
||||
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
num_tokens, hidden_size = query.shape
|
||||
query = query.view(num_tokens, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
@ -221,10 +230,21 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
write_to_kv_cache(
|
||||
key, value, kv_cache, slot_mapping,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
slot_mapping,
|
||||
attn_metadata.num_slices_per_kv_cache_update_block,
|
||||
attn_metadata.num_kv_update_slices)
|
||||
attn_metadata.num_kv_update_slices,
|
||||
self.kv_cache_quantized_dtype,
|
||||
layer._k_scale_float,
|
||||
layer._v_scale_float,
|
||||
)
|
||||
|
||||
if self.kv_cache_quantized_dtype is not None and (
|
||||
layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0):
|
||||
raise ValueError(
|
||||
"k_scale_float and v_scale_float must be non-zero")
|
||||
output = torch.ops.xla.ragged_paged_attention(
|
||||
query,
|
||||
kv_cache,
|
||||
@ -242,6 +262,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
sm_scale=self.scale,
|
||||
sliding_window=self.sliding_window,
|
||||
soft_cap=self.logits_soft_cap,
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
)
|
||||
|
||||
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
|
||||
@ -257,18 +279,32 @@ def write_to_kv_cache(
|
||||
slot_mapping: torch.Tensor,
|
||||
num_slices_per_kv_cache_update_block: int,
|
||||
num_kv_update_slices: torch.Tensor,
|
||||
kv_cache_quantized_dtype: Optional[torch.dtype] = None,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
) -> None:
|
||||
""" Write the key and values to the KV cache.
|
||||
|
||||
Args:
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
|
||||
num_slices_per_kv_cache_update_block: int
|
||||
"""
|
||||
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape
|
||||
head_size = cdiv(head_size,
|
||||
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
|
||||
if kv_cache_quantized_dtype is not None:
|
||||
dtype_info = torch.finfo(kv_cache_quantized_dtype)
|
||||
key = key.to(torch.float32) / k_scale
|
||||
# NOTE: clamp is added here to avoid out of range of quantized dtype
|
||||
key = torch.clamp(key, dtype_info.min, dtype_info.max)
|
||||
key = key.to(kv_cache_quantized_dtype)
|
||||
value = value.to(torch.float32) / v_scale
|
||||
value = torch.clamp(value, dtype_info.min, dtype_info.max)
|
||||
value = value.to(kv_cache_quantized_dtype)
|
||||
|
||||
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
|
||||
head_size)
|
||||
|
||||
|
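Illustrative note (not part of the upstream diff): the fp8 KV-cache write above quantizes by dividing by the per-layer scale, clamping to the fp8 representable range, then casting. A small self-contained sketch of that round trip; the tensor shape and scale value are arbitrary examples, and float8 dtypes require a recent PyTorch build.

import torch

def quantize_kv(x: torch.Tensor, scale: float,
                dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    # Requires a PyTorch build with float8 dtypes (>= 2.1).
    info = torch.finfo(dtype)
    x = x.to(torch.float32) / scale
    # Clamp to avoid values outside the fp8 representable range.
    x = torch.clamp(x, info.min, info.max)
    return x.to(dtype)

def dequantize_kv(x: torch.Tensor, scale: float) -> torch.Tensor:
    return x.to(torch.float32) * scale

key = torch.randn(4, 2, 128)
k_scale = 0.05  # arbitrary example scale
q = quantize_kv(key, k_scale)
approx = dequantize_kv(q, k_scale)  # lossy reconstruction of `key`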
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with AiterFlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -334,15 +334,11 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"AiterFlashAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
|
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -205,15 +205,11 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"TritonAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
|
@ -51,7 +51,13 @@ class RayGaugeWrapper(RayPrometheusMetric):
|
||||
def __init__(self,
|
||||
name: str,
|
||||
documentation: Optional[str] = "",
|
||||
labelnames: Optional[list[str]] = None):
|
||||
labelnames: Optional[list[str]] = None,
|
||||
multiprocess_mode: Optional[str] = ""):
|
||||
|
||||
# All Ray metrics are keyed by WorkerId, so multiprocess modes like
|
||||
# "mostrecent", "all", "sum" do not apply. This logic can be manually
|
||||
# implemented at the observability layer (Prometheus/Grafana).
|
||||
del multiprocess_mode
|
||||
labelnames_tuple = tuple(labelnames) if labelnames else None
|
||||
self.metric = ray_metrics.Gauge(name=name,
|
||||
description=documentation,
|
||||
|
@ -2079,7 +2079,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
block_table_tensor=self.input_batch.block_table[
|
||||
kv_cache_group_id].get_device_tensor()[:num_reqs],
|
||||
slot_mapping=self.input_batch.
|
||||
block_table[kv_cache_group_id].slot_mapping[:num_reqs])
|
||||
block_table[kv_cache_group_id].slot_mapping[:num_tokens])
|
||||
|
||||
attn_metadata_i = self.attn_metadata_builders[
|
||||
kv_cache_group_id].build_for_cudagraph_capture(
|
||||
@ -2753,9 +2753,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.vllm_config.speculative_config is not None:
|
||||
raise NotImplementedError(
|
||||
"Mamba with speculative decoding is not supported yet.")
|
||||
if not self.vllm_config.model_config.enforce_eager:
|
||||
raise NotImplementedError(
|
||||
"Mamba with cuda graph is not supported yet.")
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
raise NotImplementedError(
|
||||
"Prefix caching is not supported for Mamba yet.")
|
||||
|
@ -32,9 +32,10 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
|
||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
from vllm.pooling_params import PoolingTask
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
|
||||
is_pin_memory_available, prev_power_of_2)
|
||||
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
|
||||
from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
|
||||
prev_power_of_2)
|
||||
from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
|
||||
PallasAttentionBackend,
|
||||
PallasMetadata,
|
||||
get_page_size_bytes)
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
@ -142,11 +143,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
if cache_config.cache_dtype == "auto":
|
||||
model_dtype = self.dtype
|
||||
if isinstance(model_dtype, str):
|
||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
|
||||
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
|
||||
else:
|
||||
self.kv_cache_dtype = model_dtype
|
||||
else:
|
||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
|
||||
cache_config.cache_dtype]
|
||||
self._hidden_states_dtype = self.dtype
|
||||
|
||||
|
@ -77,7 +77,8 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
"mlp_speculator",
|
||||
"eagle",
|
||||
"deepseek_mtp",
|
||||
"mimo_mtp")) \
|
||||
"glm4_moe_mtp",
|
||||
"mimo_mtp")) \
|
||||
else {"return_hidden_states": True}
|
||||
|
||||
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
|
||||
|