Add the support for the qwen3 next model (a hybrid attention model). (#24526)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Tao He
2025-09-11 15:32:09 +08:00
committed by GitHub
parent 2048c4e379
commit e93f4cc9e3
29 changed files with 2476 additions and 61 deletions

View File

@ -1 +1,2 @@
collect_env.py
vllm/model_executor/layers/fla/ops/*.py

View File

@ -403,6 +403,7 @@ th {
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3NextForCausalLM` | Qwen3.5MoE | `Qwen/Qwen3-Next-80B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ |
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ |

View File

@ -228,6 +228,7 @@ fo = "fo"
ba = "ba"
[tool.typos.type.py.extend-words]
ba = "ba"
[tool.typos.type.cpp]
extend-glob = ["*.cu"]

View File

@ -326,6 +326,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.2"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
trust_remote_code=True,
@ -640,7 +642,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
is_available_online=False),
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL")
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.2"),
}
_TRANSFORMERS_BACKEND_MODELS = {

View File

@ -1508,7 +1508,8 @@ class ModelConfig:
if (self.hf_text_config.model_type == "deepseek_mtp"
or self.hf_config.model_type == "mimo_mtp"
or self.hf_config.model_type == "glm4_moe_mtp"
or self.hf_config.model_type == "ernie_mtp"):
or self.hf_config.model_type == "ernie_mtp"
or self.hf_config.model_type == "qwen3_next_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
else:
@ -1571,15 +1572,28 @@ class ModelConfig:
if attn_type_list:
return sum(t == 1 for t in attn_type_list[start:end])
if layers_block_type_value is None and attn_type_list is None:
# Hybrid model Qwen3Next
layer_types_value = getattr(self.hf_config, "layer_types", None)
if layer_types_value is not None:
if getattr(block_type, "value", block_type) == "attention":
return sum(t == "full_attention"
for t in layer_types_value[start:end])
elif getattr(block_type, "value",
block_type) == "linear_attention":
return sum(t == "linear_attention"
for t in layer_types_value[start:end])
else:
return sum(t == getattr(block_type, "value", block_type)
for t in layer_types_value[start:end])
if (layers_block_type_value is None and attn_type_list is None
and layer_types_value is None):
raise ValueError(
"The model is an hybrid without a"
"layers_block_type or an attn_type_list in the hf_config,"
"cannot determine the num of "
"layers_block_type or an attn_type_list, or a layer_types "
"in the hf_config, cannot determine the num of "
f"{block_type.value} layers")
return sum(t == 1 for t in attn_type_list[start:end])
def get_mamba_chunk_size(self) -> Optional[int]:
"""
Returns the mamba chunk size if it exists
@ -1866,7 +1880,7 @@ class DeviceConfig:
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp"]
"ernie_mtp", "qwen3_next_mtp"]
@config
@ -2007,7 +2021,15 @@ class SpeculativeConfig:
"n_predict": n_predict,
"architectures": ["ErnieMTPModel"]
})
return hf_config
if hf_config.model_type == "qwen3_next":
hf_config.model_type = "qwen3_next_mtp"
if hf_config.model_type == "qwen3_next_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["Qwen3NextMTP"]
})
return hf_config
@ -2028,9 +2050,13 @@ class SpeculativeConfig:
(self.target_model_config.hf_text_config.model_type \
== "deepseek_v3" or
self.target_model_config.hf_text_config.model_type in
("mimo","ernie4_5_moe")):
("mimo","ernie4_5_moe", "qwen3_next")):
# use the draft model from the same model:
self.model = self.target_model_config.model
# Align the quantization of draft model for cases such as
# --quantization fp8 with a bf16 checkpoint.
if not self.quantization:
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
else:
@ -2140,6 +2166,15 @@ class SpeculativeConfig:
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"qwen3_next_mtp"):
self.method = "qwen3_next_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Qwen3Next MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"
raise NotImplementedError(
@ -2355,7 +2390,8 @@ class SpeculativeConfig:
return self.num_speculative_tokens
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp")
def __repr__(self) -> str:
method = self.method

View File

@ -341,6 +341,7 @@ class CompilationConfig:
"vllm.short_conv",
"vllm.linear_attention",
"vllm.plamo2_mamba_mixer",
"vllm.gdn_attention",
]
def compute_hash(self) -> str:

View File

@ -14,7 +14,7 @@ import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices, prepare_chunk_offsets
from .op import exp, safe_exp
from .op import exp
from .utils import is_nvidia_hopper, use_cuda_graph
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
@ -175,12 +175,13 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
boundary_check=(0, 1))
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)) < T
last_idx = min((i_t + 1) * BT, T) - 1
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ),
(i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
b_g_last = exp(b_g_last)
b_h1 = b_h1 * b_g_last
if K > 64:

View File

@ -16,7 +16,7 @@ import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import exp, safe_exp
from .op import exp
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
@ -112,10 +112,11 @@ def chunk_fwd_kernel_o(
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
b_o = b_o * exp(b_g)[:, None]
b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
b_A = b_A * exp(b_g[:, None] - b_g[None, :])
o_i = tl.arange(0, BT)
m_A = o_i[:, None] >= o_i[None, :]
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
b_A = tl.where(m_A, b_A, 0)
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),

View File

@ -14,7 +14,7 @@ import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import safe_exp
from .op import exp
@triton.heuristics({
@ -56,7 +56,8 @@ def chunk_scaled_dot_kkt_fwd_kernel(
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_t = tl.arange(0, BT)
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
(i_t * BT, ), (BT, ), (0, ))
@ -76,9 +77,10 @@ def chunk_scaled_dot_kkt_fwd_kernel(
(i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
b_g_diff = b_g[:, None] - b_g[None, :]
b_A = b_A * safe_exp(b_g_diff)
b_A = b_A * exp(b_g_diff)
b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
b_A = tl.where(m_A, b_A, 0)
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
(i_t * BT, 0), (BT, BT), (1, 0))
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

View File

@ -116,8 +116,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
b_g = tl.load(p_g).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
b_q = b_q * scale
# [BK, BV]
b_h *= exp(b_g)

View File

@ -78,7 +78,7 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
xmask = row_idx < M
rindex = tl.arange(0, N)[None, :]
xs = tl.load(X + (rindex + N * row_idx), None).to(tl.float32)
xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
square = tl.broadcast_to(xs * xs, [MBLOCK, N])
square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
rsqrt = tl.rsqrt(square_sum + eps)

View File

@ -28,11 +28,6 @@ else:
log2 = tl.log2
@triton.jit
def safe_exp(x):
return exp(tl.where(x <= 0, x, float('-inf')))
if not hasattr(tl, 'gather'):
@triton.jit

View File

@ -70,6 +70,15 @@ class MambaStateDtypeCalculator:
model_dtype)
return (conv_state_dtype, )
@classmethod
def gated_delta_net_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, torch.dtype]:
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, state_dtype)
class MambaStateShapeCalculator:
@ -163,3 +172,31 @@ class MambaStateShapeCalculator:
# for n_groups == 1, this is exactly tp_size - n_groups
return tp_size - ngroups
@classmethod
def gated_delta_net_state_shape(
cls,
tp_world_size: int,
num_k_heads: int,
num_v_heads: int,
head_k_dim: int,
head_v_dim: int,
conv_kernel_size: int,
num_spec: int = 0,
use_v1: bool = True,
):
conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads)
conv_state_shape = (
divide(conv_dim, tp_world_size),
conv_kernel_size - 1 + num_spec,
)
# In V0, the conv_state shape was swapped during allocation in
# MambaCacheManager, but in V1 it needs to be determined here at the
# calculation level
if use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
temporal_state_shape = (divide(num_v_heads,
tp_world_size), head_k_dim, head_v_dim)
return conv_state_shape, temporal_state_shape

View File

@ -464,7 +464,9 @@ def causal_conv1d_fn(
# 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
# 4. computation can be skipped if cache_indices[idx] == pad_slot_id
num_cache_lines = conv_states.size(0)
assert (num_cache_lines, dim, width - 1) == conv_states.shape
assert (num_cache_lines == conv_states.shape[0]
and dim == conv_states.shape[1]
and width - 1 <= conv_states.shape[2])
stride_istate_seq = conv_states.stride(0)
stride_istate_dim = conv_states.stride(1)
stride_istate_token = conv_states.stride(2)
@ -623,6 +625,7 @@ def _causal_conv1d_update_kernel(
conv_state_ptr,
cache_seqlens_ptr, # circular buffer
conv_state_indices_ptr,
num_accepted_tokens_ptr,
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
batch: int,
@ -639,6 +642,7 @@ def _causal_conv1d_update_kernel(
stride_conv_state_seq: tl.constexpr,
stride_conv_state_dim: tl.constexpr,
stride_conv_state_tok: tl.constexpr,
stride_state_indices: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
@ -649,6 +653,7 @@ def _causal_conv1d_update_kernel(
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
NP2_STATELEN: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
BLOCK_N: tl.constexpr,
@ -663,7 +668,8 @@ def _causal_conv1d_update_kernel(
if IS_CONTINUOUS_BATCHING:
# mask = idx_seq < batch
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices).to(
tl.int64)
else:
conv_state_batch_coord = idx_seq
@ -672,13 +678,32 @@ def _causal_conv1d_update_kernel(
# not processing as this is not the actual sequence
return
if IS_SPEC_DECODING:
# The rolling of conv state:
#
# Before forward, the conv_state is:
# [history1, history2, ..., historyM].
#
# After forward, the conv_state becomes:
# [history2, ..., historyM, draft1, draft2, ..., draftN].
#
# After acceptance, it becomes:
#
# - accept 1 tokens: [history2, ..., historyM, draft1]
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
# - and so on.
conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq) -
1)
else:
conv_state_token_offset = 0
# STEP 1: READ init_state data
conv_states_base = (conv_state_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim))
mask_w = idx_feats < dim
prior_tokens = conv_states_base
prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
if KERNEL_WIDTH >= 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
@ -695,10 +720,14 @@ def _causal_conv1d_update_kernel(
# STEP 2: assume state_len > seqlen
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
# The conv_state updates works in a sliding window manner,
# at each forward pass, the tokens are shift by 1, so we
# load since idx_tokens + 1.
conv_state_ptrs_source = (
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
conv_state_token_offset * stride_conv_state_tok +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens + seqlen) * stride_conv_state_tok)[:, None]
((idx_tokens + 1) * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = ((conv_state_batch_coord < num_cache_lines)
& ((idx_tokens + seqlen) < state_len)[:, None]
@ -820,6 +849,7 @@ def causal_conv1d_update(
activation: Union[bool, str, None] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
pad_slot_id: int = PAD_SLOT_ID,
metadata=None,
validate_data=False,
@ -890,10 +920,11 @@ def causal_conv1d_update(
) # X (batch, dim, seqlen)
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
)
state_len = width - 1
stride_state_indices = conv_state_indices.stride(
0) if conv_state_indices is not None else 0
state_len = width - 1 + (seqlen - 1) # effective state_len needed
np2_statelen = triton.next_power_of_2(state_len)
def grid(META):
@ -910,6 +941,7 @@ def causal_conv1d_update(
conv_state,
cache_seqlens,
conv_state_indices,
num_accepted_tokens,
out,
# Matrix dimensions
batch,
@ -926,6 +958,7 @@ def causal_conv1d_update(
stride_istate_seq,
stride_istate_dim,
stride_istate_token,
stride_state_indices,
stride_o_seq,
stride_o_dim,
stride_o_token,
@ -936,6 +969,7 @@ def causal_conv1d_update(
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None,
NP2_STATELEN=np2_statelen,
USE_PAD_SLOT=pad_slot_id is not None,
BLOCK_N=256,

View File

@ -312,7 +312,8 @@ class MambaModelConfig(VerifyAndUpdateConfig):
# TODO(tdoublep): remove as full cuda graph support is added
FCG_NOT_SUPPORTED_MODELS = [
"Lfm2ForCausalLM", "MiniMaxText01ForCausalLM"
"Lfm2ForCausalLM",
"MiniMaxText01ForCausalLM",
]
if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,285 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next MTP model."""
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3_next import (Qwen3NextDecoderLayer,
Qwen3NextRMSNorm)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig
from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, maybe_prefix)
logger = init_logger(__name__)
KVCache = tuple[torch.Tensor, torch.Tensor]
@support_torch_compile
class Qwen3NextMultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
config: Qwen3NextConfig = model_config.hf_config
self.config = config
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.fc = ColumnParallelLinear(self.config.hidden_size * 2,
self.config.hidden_size,
gather_output=True,
bias=False,
return_bias=False)
self.layers = torch.nn.ModuleList(
Qwen3NextDecoderLayer(
config,
layer_type="full_attention",
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
) for idx in range(self.num_mtp_layers))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.norm = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
hidden_states = self.pre_fc_norm_hidden(hidden_states)
hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
hidden_states = self.fc(hidden_states)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
current_step_idx = (spec_step_idx % self.num_mtp_layers)
hidden_states, residual = self.layers[current_step_idx](
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
@support_torch_compile
class Qwen3NextMTP(nn.Module, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["up_proj", "down_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
cache_config = vllm_config.cache_config
assert not cache_config.enable_prefix_caching, \
"Qwen3NextMTP currently does not support prefix caching"
self.quant_config = vllm_config.quant_config
super().__init__()
self.config = config
self.model = Qwen3NextMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
hidden_states = self.model(input_ids, positions, hidden_states,
intermediate_tensors, inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0,
) -> Optional[torch.Tensor]:
return self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
shared_weight_names = ["embed_tokens", "lm_head"]
def remap_weight_names(weights):
for name, weight in weights:
if name.startswith("mtp."):
name = name.replace("mtp.", "model.")
elif not any(key in name for key in shared_weight_names):
continue
yield name, weight
loader = AutoWeightsLoader(self)
return loader.load_weights(remap_weight_names(weights))

View File

@ -74,6 +74,7 @@ _TEXT_GENERATION_MODELS = {
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
"Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
"Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
"Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
@ -285,6 +286,7 @@ _SPECULATIVE_DECODING_MODELS = {
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"MedusaModel": ("medusa", "Medusa"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
# Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),

View File

@ -79,7 +79,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
ultravox="UltravoxConfig",
step3_vl="Step3VLConfig",
step3_text="Step3TextConfig",
)
qwen3_next="Qwen3NextConfig")
_CONFIG_ATTRS_MAPPING: dict[str, str] = {
"llm_config": "text_config",

View File

@ -24,6 +24,7 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
from vllm.transformers_utils.configs.ovis import OvisConfig
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
Step3VisionEncoderConfig,
@ -50,4 +51,5 @@ __all__ = [
"Step3VLConfig",
"Step3VisionEncoderConfig",
"Step3TextConfig",
"Qwen3NextConfig",
]

View File

@ -0,0 +1,275 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3-Next model configuration"""
from transformers.configuration_utils import (PretrainedConfig,
layer_type_validation)
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Qwen3NextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a
Qwen3-Next model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of
Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the model. Defines the number of different tokens that can be represented by the
`inputs_ids`.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 5632):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 48):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 2):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str`, *optional*, defaults to `"silu"`):
The non-linear activation function in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
partial_rotary_factor (`float`, *optional*, defaults to 0.25):
Percentage of the query and keys which will have rotary embedding.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
head_dim (`int`, *optional*, defaults to 256):
Projection weights dimension in multi-head attention.
linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
Kernel size of the convolution used in linear attention layers.
linear_key_head_dim (`int`, *optional*, defaults to 128):
Dimension of each key head in linear attention.
linear_value_head_dim (`int`, *optional*, defaults to 128):
Dimension of each value head in linear attention.
linear_num_key_heads (`int`, *optional*, defaults to 16):
Number of key heads used in linear attention layers.
linear_num_value_heads (`int`, *optional*, defaults to 32):
Number of value heads used in linear attention layers.
decoder_sparse_step (`int`, *optional*, defaults to 1):
The frequency of the MoE layer.
moe_intermediate_size (`int`, *optional*, defaults to 512):
Intermediate size of the routed expert.
shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
Intermediate size of the shared expert.
num_experts_per_tok (`int`, *optional*, defaults to 10):
Number of selected experts.
num_experts (`int`, *optional*, defaults to 512):
Number of routed experts.
norm_topk_prob (`bool`, *optional*, defaults to `True`):
Whether to normalize the topk probabilities.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.
mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):
Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock
The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
layer_types (`list[str]`, *optional*):
Types of each layer (attention or linear).
```python
>>> from transformers import Qwen3NextModel, Qwen3NextConfig
>>> # Initializing a Qwen3Next style configuration
>>> configuration = Qwen3NextConfig()
>>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration
>>> model = Qwen3NextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
""" # noqa: E501
model_type = "qwen3_next"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.experts.*.gate_proj": "colwise",
"layers.*.mlp.experts.*.up_proj": "colwise",
"layers.*.mlp.experts.*.down_proj": "rowwise",
"layers.*.mlp.shared_experts.gate_proj": "colwise",
"layers.*.mlp.shared_experts.up_proj": "colwise",
"layers.*.mlp.shared_experts.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=151936,
hidden_size=2048,
intermediate_size=5632,
num_hidden_layers=48,
num_attention_heads=16,
num_key_value_heads=2,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
partial_rotary_factor=0.25,
attention_bias=False,
attention_dropout=0.0,
head_dim=256,
linear_conv_kernel_dim=4,
linear_key_head_dim=128,
linear_value_head_dim=128,
linear_num_key_heads=16,
linear_num_value_heads=32,
decoder_sparse_step=1,
moe_intermediate_size=512,
shared_expert_intermediate_size=512,
num_experts_per_tok=10,
num_experts=512,
norm_topk_prob=True,
output_router_logits=False,
router_aux_loss_coef=0.001,
mlp_only_layers=None,
layer_types=None,
**kwargs,
):
if mlp_only_layers is None:
mlp_only_layers = []
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.partial_rotary_factor = partial_rotary_factor
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.head_dim = head_dim
rope_config_validation(self)
self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = [
"linear_attention" if bool((i + 1) % 4) else "full_attention"
for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types)
# linear attention part
self.linear_conv_kernel_dim = linear_conv_kernel_dim
self.linear_key_head_dim = linear_key_head_dim
self.linear_value_head_dim = linear_value_head_dim
self.linear_num_key_heads = linear_num_key_heads
self.linear_num_value_heads = linear_num_value_heads
# MoE arguments
self.decoder_sparse_step = decoder_sparse_step
self.moe_intermediate_size = moe_intermediate_size
self.shared_expert_intermediate_size = shared_expert_intermediate_size
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.norm_topk_prob = norm_topk_prob
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.mlp_only_layers = mlp_only_layers
__all__ = ["Qwen3NextConfig"]

View File

@ -0,0 +1,319 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""
from dataclasses import dataclass
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class GDNAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
return GDNAttentionMetadataBuilder
@dataclass
class GDNAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_spec_decodes: int
num_spec_decode_tokens: int
has_initial_state: Optional[torch.Tensor] = None
spec_query_start_loc: Optional[
torch.Tensor] = None # shape: [num_spec_decodes + 1,]
non_spec_query_start_loc: Optional[
torch.Tensor] = None # shape: [batch - num_spec_decodes + 1,]
spec_state_indices_tensor: Optional[
torch.Tensor] = None # shape: [batch, num_spec]
non_spec_state_indices_tensor: Optional[
torch.Tensor] = None # shape: [batch - num_spec_decodes,]
spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,]
spec_token_masks: Optional[
torch.
Tensor] = None # shape: [num_prefill_tokens + num_decode_tokens,]
num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,]
class GDNAttentionMetadataBuilder(
AttentionMetadataBuilder[GDNAttentionMetadata]):
cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: ClassVar[int] = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
assert isinstance(kv_cache_spec, MambaSpec)
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.speculative_config = vllm_config.speculative_config
self.kv_cache_spec = kv_cache_spec
if self.speculative_config:
self.num_spec = self.speculative_config.num_speculative_tokens # noqa: E501
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self.reorder_batch_threshold = self.num_spec + 1 # type: ignore[misc]
self.use_full_cuda_graph = \
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs,
self.compilation_config.max_capture_size)
self.spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs, self.num_spec + 1),
dtype=torch.int32,
device=device,
)
self.non_spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs, ),
dtype=torch.int32,
device=device,
)
self.spec_sequence_masks = torch.empty(
(self.decode_cudagraph_max_bs, ),
dtype=torch.bool,
device=device,
)
self.spec_token_masks = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1), ),
dtype=torch.bool,
device=device,
)
self.spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1, ),
dtype=torch.int32,
device=device,
)
self.non_spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1, ),
dtype=torch.int32,
device=device,
)
self.num_accepted_tokens = torch.empty(
(self.decode_cudagraph_max_bs, ),
dtype=torch.int32,
device=device,
)
def build( # type: ignore[override]
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
num_accepted_tokens: Optional[torch.Tensor] = None,
num_draft_tokens: Optional[torch.Tensor] = None,
fast_build: bool = False,
) -> GDNAttentionMetadata:
m = common_attn_metadata
query_start_loc = m.query_start_loc
context_lens = m.num_computed_tokens_cpu
context_lens_tensor = context_lens.to(query_start_loc.device)
seq_lens_tensor = m.seq_lens
if (not self.use_spec_decode or num_draft_tokens is None
or num_draft_tokens.sum().item() == 0):
spec_sequence_masks = None
else:
spec_sequence_masks = (num_draft_tokens > 0) & (
context_lens_tensor +
(num_draft_tokens + 1) == seq_lens_tensor)
if spec_sequence_masks.sum().item() == 0:
spec_sequence_masks = None
if spec_sequence_masks is None:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(m, decode_threshold=1))
num_spec_decodes = 0
num_spec_decode_tokens = 0
spec_token_masks = None
spec_state_indices_tensor = None
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
spec_query_start_loc = None
non_spec_query_start_loc = query_start_loc
num_accepted_tokens = None
else:
num_spec_decodes = spec_sequence_masks.sum().item()
query_lens = query_start_loc[1:] - query_start_loc[:-1]
non_spec_query_lens = query_lens[~spec_sequence_masks]
num_decodes = (non_spec_query_lens == 1).sum().item()
num_prefills = non_spec_query_lens.size(0) - num_decodes
num_decode_tokens = num_decodes
num_prefill_tokens = non_spec_query_lens.sum().item(
) - num_decode_tokens
if num_prefills == 0 and num_decodes == 0:
spec_token_masks = torch.ones(
(min(num_spec_decodes *
(self.num_spec + 1), query_start_loc[-1].item())),
dtype=torch.bool,
device=query_start_loc.device)
spec_state_indices_tensor = m.block_table_tensor[:, :self.
num_spec + 1]
non_spec_state_indices_tensor = None
spec_query_start_loc = query_start_loc
non_spec_query_start_loc = None
else:
spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens)
spec_state_indices_tensor = m.block_table_tensor[
spec_sequence_masks, :self.num_spec + 1]
non_spec_state_indices_tensor = \
m.block_table_tensor[~spec_sequence_masks, 0]
spec_query_start_loc = torch.zeros(
num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device)
torch.cumsum(query_lens[spec_sequence_masks],
dim=0,
out=spec_query_start_loc[1:])
non_spec_query_start_loc = torch.zeros(
query_lens.size(0) - num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device)
torch.cumsum(query_lens[~spec_sequence_masks],
dim=0,
out=non_spec_query_start_loc[1:])
num_spec_decode_tokens = min(
num_spec_decodes * (self.num_spec + 1),
spec_token_masks.size(0))
assert num_accepted_tokens is not None
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
if num_prefills > 0:
has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None:
has_initial_state = has_initial_state[~spec_sequence_masks]
else:
has_initial_state = None
# prepare tensors for cudagraph
if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
and num_spec_decodes <= self.decode_cudagraph_max_bs):
num_total_tokens = self.vllm_config.pad_for_cudagraph(
m.num_actual_tokens)
batch_size = num_total_tokens // (self.num_spec + 1)
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
spec_state_indices_tensor, non_blocking=True)
spec_state_indices_tensor = self.spec_state_indices_tensor[:
batch_size]
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
self.spec_sequence_masks[:num_spec_decodes].copy_(
spec_sequence_masks, non_blocking=True)
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
spec_sequence_masks[num_spec_decodes:].fill_(False)
assert spec_token_masks is not None
self.spec_token_masks[:spec_token_masks.size(0)].copy_(
spec_token_masks, non_blocking=True)
spec_token_masks = self.spec_token_masks[:m.num_actual_tokens]
spec_token_masks[spec_token_masks.size(0):].fill_(False)
self.spec_query_start_loc[:num_spec_decodes + 1].copy_(
spec_query_start_loc, non_blocking=True)
spec_num_query_tokens = spec_query_start_loc[
-1] # type: ignore[index]
spec_query_start_loc = self.spec_query_start_loc[:batch_size + 1]
spec_query_start_loc[num_spec_decodes +
1:].fill_(spec_num_query_tokens)
self.num_accepted_tokens[:num_spec_decodes].copy_(
num_accepted_tokens, non_blocking=True)
num_accepted_tokens = self.num_accepted_tokens[:batch_size]
num_accepted_tokens[num_spec_decodes:].fill_(1)
if (self.use_full_cuda_graph and num_prefills == 0
and num_spec_decodes == 0
and num_decodes <= self.decode_cudagraph_max_bs):
num_total_tokens = self.vllm_config.pad_for_cudagraph(
m.num_actual_tokens)
batch_size = num_total_tokens
self.non_spec_state_indices_tensor[:num_decodes].copy_(
non_spec_state_indices_tensor, non_blocking=True)
non_spec_state_indices_tensor = \
self.non_spec_state_indices_tensor[:batch_size]
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)
self.non_spec_query_start_loc[:num_decodes + 1].copy_(
non_spec_query_start_loc, non_blocking=True)
non_spec_num_query_tokens = non_spec_query_start_loc[
-1] # type: ignore[index]
non_spec_query_start_loc = \
self.non_spec_query_start_loc[:batch_size + 1]
non_spec_query_start_loc[num_decodes +
1:].fill_(non_spec_num_query_tokens)
attn_metadata = GDNAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_spec_decodes=num_spec_decodes,
num_spec_decode_tokens=num_spec_decode_tokens,
has_initial_state=has_initial_state,
spec_query_start_loc=spec_query_start_loc,
non_spec_query_start_loc=non_spec_query_start_loc,
spec_state_indices_tensor=spec_state_indices_tensor,
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
spec_sequence_masks=spec_sequence_masks,
spec_token_masks=spec_token_masks,
num_accepted_tokens=num_accepted_tokens,
)
return attn_metadata
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with Mamba.
"""
m = common_attn_metadata
assert (m.num_reqs * (self.num_spec + 1) <= m.num_actual_tokens
and ((m.num_reqs + 1) * (self.num_spec + 1)
>= m.num_actual_tokens)), \
"GDN only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
num_accepted_tokens = torch.full((m.num_reqs, ),
m.max_query_len,
dtype=torch.int32,
device=m.query_start_loc.device)
num_drafted_tokens = torch.full((m.num_reqs, ),
self.num_spec,
dtype=torch.int32,
device=m.query_start_loc.device)
# Fixes query-start loc for spec-sequence-indices.
m.query_start_loc = torch.arange(0,
m.num_actual_tokens + 1,
step=m.max_query_len,
device=m.query_start_loc.device,
dtype=torch.int32)
m.num_computed_tokens_cpu = (m.seq_lens_cpu - torch.full(
(m.num_reqs, ), m.max_query_len, dtype=torch.int32, device='cpu'))
return self.build(0, m, num_accepted_tokens, num_drafted_tokens)

View File

@ -559,12 +559,48 @@ class MambaManager(SingleTypeKVCacheManager):
num_running_requests: int) -> int:
return 0
def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
new_computed_blocks: list[KVCacheBlock]) -> int:
"""
Get the number of blocks needed to be allocated for the request.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
Returns:
The number of blocks
"""
assert isinstance(self.kv_cache_spec, MambaSpec)
if self.kv_cache_spec.num_speculative_blocks > 0:
num_tokens += (self.kv_cache_spec.block_size *
self.kv_cache_spec.num_speculative_blocks)
num_required_blocks = cdiv(num_tokens, self.block_size)
num_new_blocks = (num_required_blocks - len(new_computed_blocks) -
len(self.req_to_blocks[request_id]))
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it will be changed from a free block
# to a computed block when the request is allocated, so we also count
# it as needed to be allocated.
num_evictable_computed_blocks = sum(
blk.ref_cnt == 0 and not blk.is_null
for blk in new_computed_blocks)
return num_new_blocks + num_evictable_computed_blocks
def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> list[KVCacheBlock]:
new_blocks = super().allocate_new_blocks(request_id, num_tokens)
assert len(self.req_to_blocks[request_id]) == 1, (
"MambaManager should only allocate 1 block for each request.")
return new_blocks
# Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention.
assert isinstance(self.kv_cache_spec, MambaSpec)
if self.kv_cache_spec.num_speculative_blocks > 0:
num_tokens += (self.kv_cache_spec.block_size *
self.kv_cache_spec.num_speculative_blocks)
return super().allocate_new_blocks(request_id, num_tokens)
class CrossAttentionManager(SingleTypeKVCacheManager):

View File

@ -194,6 +194,7 @@ class MambaSpec(KVCacheSpec):
dtypes: tuple[torch.dtype]
page_size_padded: Optional[int] = None
mamba_type: str = "mamba2"
num_speculative_blocks: int = 0
@property
def page_size_bytes(self) -> int:

View File

@ -218,7 +218,7 @@ class EagleProposer:
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
if self.method in ("deepseek_mtp", "ernie_mtp"):
if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp"):
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
else:
@ -322,12 +322,18 @@ class EagleProposer:
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(
ret_hidden_states = self.model(
input_ids=input_ids,
positions=self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size],
inputs_embeds=inputs_embeds,
)
if self.method in ("deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp"):
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)

View File

@ -156,9 +156,14 @@ class BlockTable:
class MultiGroupBlockTable:
"""The BlockTables for each KV cache group."""
def __init__(self, max_num_reqs: int, max_model_len: int,
max_num_batched_tokens: int, pin_memory: bool,
device: torch.device, block_sizes: list[int]) -> None:
def __init__(self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
block_sizes: list[int],
num_speculative_tokens: int = 0) -> None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
@ -170,10 +175,11 @@ class MultiGroupBlockTable:
dcp_world_size = 1
self.block_tables = [
BlockTable(block_size, max_num_reqs,
cdiv(max_model_len, block_size * dcp_world_size),
max_num_batched_tokens, pin_memory, device)
for block_size in block_sizes
BlockTable(
block_size, max_num_reqs,
max(cdiv(max_model_len, block_size * dcp_world_size),
1 + num_speculative_tokens), max_num_batched_tokens,
pin_memory, device) for block_size in block_sizes
]
def append_row(self, block_ids: tuple[list[int], ...],

View File

@ -83,6 +83,7 @@ class InputBatch:
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
@ -127,6 +128,7 @@ class InputBatch:
pin_memory=pin_memory,
device=device,
block_sizes=block_sizes,
num_speculative_tokens=num_speculative_tokens,
)
# Sampling-related.
@ -202,6 +204,14 @@ class InputBatch:
self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: set[str] = set()
# Speculative decoding
self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ),
dtype=torch.int64,
device="cpu",
pin_memory=pin_memory)
self.num_accepted_tokens_cpu = \
self.num_accepted_tokens_cpu_tensor.numpy()
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32)
@ -394,6 +404,9 @@ class InputBatch:
else:
raise NotImplementedError("Unrecognized request type")
# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
@ -515,6 +528,8 @@ class InputBatch:
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\
self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1]
swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
@ -609,6 +624,8 @@ class InputBatch:
empty_index] = self.presence_penalties_cpu[last_req_index]
self.repetition_penalties_cpu[
empty_index] = self.repetition_penalties_cpu[last_req_index]
self.num_accepted_tokens_cpu[
empty_index] = self.num_accepted_tokens_cpu[last_req_index]
generator = self.generators.pop(last_req_index, None)
if generator is not None:
self.generators[empty_index] = generator

View File

@ -53,9 +53,9 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
get_dtype_size, is_pin_memory_available, round_up,
supports_dynamo)
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up, supports_dynamo)
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
create_fast_prefill_custom_backend,
@ -324,6 +324,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.hidden_size,
dtype=self.dtype,
numpy=False)
self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int64)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
@ -663,6 +667,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
def _update_states_after_model_execute(
self, output_token_ids: torch.Tensor) -> None:
"""Update the cached states after model execution.
This is used for MTP/EAGLE for hybrid models, as in linear attention,
only the last token's state is kept. In MTP/EAGLE, for draft tokens
the state are kept util we decide how many tokens are accepted for
each sequence, and a shifting is done during the next iteration
based on the number of accepted tokens.
"""
if not self.model_config.is_hybrid or not self.speculative_config:
return
# Find the number of accepted tokens for each sequence.
num_accepted_tokens = (torch.cat(
[
output_token_ids,
torch.full((output_token_ids.size(0), 1),
-1,
device=output_token_ids.device),
],
dim=1) == -1).int().argmax(-1).cpu().numpy()
for i, num_tokens in enumerate(num_accepted_tokens):
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
def _init_mrope_positions(self, req_state: CachedRequestState):
image_grid_thw = []
video_grid_thw = []
@ -936,6 +965,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
num_draft_tokens = None
spec_decode_metadata = None
else:
# Get the number of draft tokens for each request.
@ -950,6 +980,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
logits_indices = spec_decode_metadata.logits_indices
self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
self.num_draft_tokens.np[num_reqs:].fill(0)
self.num_draft_tokens.copy_to_gpu()
logits_indices_padded = None
if self.cache_config.kv_sharing_fast_prefill:
@ -964,6 +997,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
spec_decode_common_attn_metadata = None
if use_spec_decode:
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
@ -1034,10 +1072,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
builder,
)
extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder,
GDNAttentionMetadataBuilder):
extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.
gpu[:num_reqs],
num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
)
attn_metadata_i = builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
)
**extra_attn_metadata_args)
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
@ -1814,6 +1861,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
self._update_states_after_model_execute(output_token_ids)
return sampler_output
@ -2644,13 +2692,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Note: Overriding max_query_len to be the prefill tokens
max_query_len = num_prefill_tokens
elif uniform_decode:
assert not create_mixed_batch
num_reqs = cdiv(num_tokens, max_query_len)
num_reqs = num_tokens // max_query_len
assert num_reqs <= max_num_reqs, \
"Do not capture num_reqs > max_num_reqs for uniform batch"
num_scheduled_tokens_list = [max_query_len] * num_reqs
if num_tokens % max_query_len != 0:
num_scheduled_tokens_list[-1] = num_tokens % max_query_len
num_scheduled_tokens_list[-1] += num_tokens % max_query_len
else:
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
@ -3297,6 +3344,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=self.input_batch.logitsprocs,
is_pooling_model=self.is_pooling_model,
num_speculative_tokens=(
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config else 0),
)
def _allocate_kv_cache_tensors(
@ -3647,7 +3697,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
if len(mamba_layers) > 0:
if self.vllm_config.speculative_config is not None:
if (self.vllm_config.speculative_config is not None
and self.vllm_config.model_config.hf_config.model_type
not in ["qwen3_next"]):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet.")
if self.vllm_config.cache_config.enable_prefix_caching:
@ -3666,7 +3718,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtypes=mamba_module.get_state_dtype(),
block_size=max_model_len,
page_size_padded=page_size_padded,
mamba_type=mamba_module.mamba_type)
mamba_type=mamba_module.mamba_type,
num_speculative_blocks=(
self.speculative_config.num_speculative_tokens
if self.speculative_config else 0),
)
return kv_cache_spec

View File

@ -78,7 +78,8 @@ class Worker(LocalOrDistributedWorkerBase):
"deepseek_mtp",
"glm4_moe_mtp",
"mimo_mtp",
"ernie_mtp")) \
"ernie_mtp",
"qwen3_next_mtp")) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner