From e93f4cc9e37484009f74e15d3111a1f335c532a5 Mon Sep 17 00:00:00 2001 From: Tao He Date: Thu, 11 Sep 2025 15:32:09 +0800 Subject: [PATCH] Add the support for the qwen3 next model (a hybrid attention model). (#24526) Signed-off-by: Tao He Co-authored-by: Jee Jee Li --- .yapfignore | 1 + docs/models/supported_models.md | 1 + pyproject.toml | 1 + tests/models/registry.py | 6 +- vllm/config/__init__.py | 56 +- vllm/config/compilation.py | 1 + .../layers/fla/ops/chunk_delta_h.py | 5 +- vllm/model_executor/layers/fla/ops/chunk_o.py | 9 +- .../layers/fla/ops/chunk_scaled_dot_kkt.py | 10 +- .../layers/fla/ops/fused_recurrent.py | 4 +- vllm/model_executor/layers/fla/ops/l2norm.py | 2 +- vllm/model_executor/layers/fla/ops/op.py | 5 - .../layers/mamba/mamba_utils.py | 37 + .../layers/mamba/ops/causal_conv1d.py | 48 +- vllm/model_executor/models/config.py | 3 +- vllm/model_executor/models/qwen3_next.py | 1294 +++++++++++++++++ vllm/model_executor/models/qwen3_next_mtp.py | 285 ++++ vllm/model_executor/models/registry.py | 2 + vllm/transformers_utils/config.py | 2 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/qwen3_next.py | 275 ++++ vllm/v1/attention/backends/gdn_attn.py | 319 ++++ vllm/v1/core/single_type_kv_cache_manager.py | 44 +- vllm/v1/kv_cache_interface.py | 1 + vllm/v1/spec_decode/eagle.py | 10 +- vllm/v1/worker/block_table.py | 20 +- vllm/v1/worker/gpu_input_batch.py | 17 + vllm/v1/worker/gpu_model_runner.py | 74 +- vllm/worker/worker.py | 3 +- 29 files changed, 2476 insertions(+), 61 deletions(-) create mode 100644 vllm/model_executor/models/qwen3_next.py create mode 100644 vllm/model_executor/models/qwen3_next_mtp.py create mode 100644 vllm/transformers_utils/configs/qwen3_next.py create mode 100644 vllm/v1/attention/backends/gdn_attn.py diff --git a/.yapfignore b/.yapfignore index 2d6dcf8380..3815825903 100644 --- a/.yapfignore +++ b/.yapfignore @@ -1 +1,2 @@ collect_env.py +vllm/model_executor/layers/fla/ops/*.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 851b1b58d4..a3cc1cda5f 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -403,6 +403,7 @@ th { | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen3NextForCausalLM` | Qwen3.5MoE | `Qwen/Qwen3-Next-80B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ | | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ | diff --git a/pyproject.toml b/pyproject.toml index e63f8aeae2..7aa0371f38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -228,6 +228,7 @@ fo = "fo" ba = "ba" [tool.typos.type.py.extend-words] +ba = "ba" [tool.typos.type.cpp] extend-glob = ["*.cu"] diff --git a/tests/models/registry.py b/tests/models/registry.py index dfeaee1acf..0c77ec5ef1 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -326,6 +326,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), + "Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", + min_transformers_version="4.56.2"), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501 trust_remote_code=True, @@ -640,7 +642,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { is_available_online=False), "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True, - speculative_model="XiaomiMiMo/MiMo-7B-RL") + speculative_model="XiaomiMiMo/MiMo-7B-RL"), + "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", + min_transformers_version="4.56.2"), } _TRANSFORMERS_BACKEND_MODELS = { diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index f62e4468ef..587cfab355 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1508,7 +1508,8 @@ class ModelConfig: if (self.hf_text_config.model_type == "deepseek_mtp" or self.hf_config.model_type == "mimo_mtp" or self.hf_config.model_type == "glm4_moe_mtp" - or self.hf_config.model_type == "ernie_mtp"): + or self.hf_config.model_type == "ernie_mtp" + or self.hf_config.model_type == "qwen3_next_mtp"): total_num_hidden_layers = getattr(self.hf_text_config, "num_nextn_predict_layers", 0) else: @@ -1571,15 +1572,28 @@ class ModelConfig: if attn_type_list: return sum(t == 1 for t in attn_type_list[start:end]) - if layers_block_type_value is None and attn_type_list is None: + # Hybrid model Qwen3Next + layer_types_value = getattr(self.hf_config, "layer_types", None) + if layer_types_value is not None: + if getattr(block_type, "value", block_type) == "attention": + return sum(t == "full_attention" + for t in layer_types_value[start:end]) + elif getattr(block_type, "value", + block_type) == "linear_attention": + return sum(t == "linear_attention" + for t in layer_types_value[start:end]) + else: + return sum(t == getattr(block_type, "value", block_type) + for t in layer_types_value[start:end]) + + if (layers_block_type_value is None and attn_type_list is None + and layer_types_value is None): raise ValueError( "The model is an hybrid without a" - "layers_block_type or an attn_type_list in the hf_config," - "cannot determine the num of " + "layers_block_type or an attn_type_list, or a layer_types " + "in the hf_config, cannot determine the num of " f"{block_type.value} layers") - return sum(t == 1 for t in attn_type_list[start:end]) - def get_mamba_chunk_size(self) -> Optional[int]: """ Returns the mamba chunk size if it exists @@ -1866,7 +1880,7 @@ class DeviceConfig: SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", "mlp_speculator", "draft_model", "deepseek_mtp", - "ernie_mtp"] + "ernie_mtp", "qwen3_next_mtp"] @config @@ -2007,7 +2021,15 @@ class SpeculativeConfig: "n_predict": n_predict, "architectures": ["ErnieMTPModel"] }) - return hf_config + + if hf_config.model_type == "qwen3_next": + hf_config.model_type = "qwen3_next_mtp" + if hf_config.model_type == "qwen3_next_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["Qwen3NextMTP"] + }) return hf_config @@ -2028,9 +2050,13 @@ class SpeculativeConfig: (self.target_model_config.hf_text_config.model_type \ == "deepseek_v3" or self.target_model_config.hf_text_config.model_type in - ("mimo","ernie4_5_moe")): + ("mimo","ernie4_5_moe", "qwen3_next")): # use the draft model from the same model: self.model = self.target_model_config.model + # Align the quantization of draft model for cases such as + # --quantization fp8 with a bf16 checkpoint. + if not self.quantization: + self.quantization = self.target_model_config.quantization elif self.method in ("ngram", "[ngram]"): self.model = "ngram" else: @@ -2140,6 +2166,15 @@ class SpeculativeConfig: "one layer. Might need some code changes " \ "to support multiple layers." ) + elif (self.draft_model_config.hf_config.model_type == + "qwen3_next_mtp"): + self.method = "qwen3_next_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Qwen3Next MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) else: self.method = "draft_model" raise NotImplementedError( @@ -2355,7 +2390,8 @@ class SpeculativeConfig: return self.num_speculative_tokens def use_eagle(self) -> bool: - return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp") + return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp", + "qwen3_next_mtp") def __repr__(self) -> str: method = self.method diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 09600e96a1..f8ccc20222 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -341,6 +341,7 @@ class CompilationConfig: "vllm.short_conv", "vllm.linear_attention", "vllm.plamo2_mamba_mixer", + "vllm.gdn_attention", ] def compute_hash(self) -> str: diff --git a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py index eac56ef352..34006f87f4 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py +++ b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py @@ -14,7 +14,7 @@ import torch from vllm.triton_utils import tl, triton from .index import prepare_chunk_indices, prepare_chunk_offsets -from .op import exp, safe_exp +from .op import exp from .utils import is_nvidia_hopper, use_cuda_graph NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] @@ -175,12 +175,13 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( boundary_check=(0, 1)) if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)) < T last_idx = min((i_t + 1) * BT, T) - 1 b_g_last = tl.load(g + bos * H + last_idx * H + i_h) p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, )) b_g = tl.load(p_g, boundary_check=(0, )) - b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None] + b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] b_g_last = exp(b_g_last) b_h1 = b_h1 * b_g_last if K > 64: diff --git a/vllm/model_executor/layers/fla/ops/chunk_o.py b/vllm/model_executor/layers/fla/ops/chunk_o.py index 5a36d31332..332751a186 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_o.py +++ b/vllm/model_executor/layers/fla/ops/chunk_o.py @@ -16,7 +16,7 @@ import torch from vllm.triton_utils import tl, triton from .index import prepare_chunk_indices -from .op import exp, safe_exp +from .op import exp from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] @@ -112,10 +112,11 @@ def chunk_fwd_kernel_o( p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, )) b_g = tl.load(p_g, boundary_check=(0, )) b_o = b_o * exp(b_g)[:, None] - b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) + b_A = b_A * exp(b_g[:, None] - b_g[None, :]) - o_i = tl.arange(0, BT) - m_A = o_i[:, None] >= o_i[None, :] + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) b_A = tl.where(m_A, b_A, 0) p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), diff --git a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py index 9938eae52d..d1adc6978f 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py +++ b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py @@ -14,7 +14,7 @@ import torch from vllm.triton_utils import tl, triton from .index import prepare_chunk_indices -from .op import safe_exp +from .op import exp @triton.heuristics({ @@ -56,7 +56,8 @@ def chunk_scaled_dot_kkt_fwd_kernel( T = eos - bos else: bos, eos = i_b * T, i_b * T + T - o_t = tl.arange(0, BT) + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, )) @@ -76,9 +77,10 @@ def chunk_scaled_dot_kkt_fwd_kernel( (i_t * BT, ), (BT, ), (0, )) b_g = tl.load(p_g, boundary_check=(0, )) b_g_diff = b_g[:, None] - b_g[None, :] - b_A = b_A * safe_exp(b_g_diff) + b_A = b_A * exp(b_g_diff) - b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0) + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) diff --git a/vllm/model_executor/layers/fla/ops/fused_recurrent.py b/vllm/model_executor/layers/fla/ops/fused_recurrent.py index 25a615fe12..b278e37415 100644 --- a/vllm/model_executor/layers/fla/ops/fused_recurrent.py +++ b/vllm/model_executor/layers/fla/ops/fused_recurrent.py @@ -116,8 +116,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( b_g = tl.load(p_g).to(tl.float32) if USE_QK_L2NORM_IN_KERNEL: - b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6) - b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6) + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) b_q = b_q * scale # [BK, BV] b_h *= exp(b_g) diff --git a/vllm/model_executor/layers/fla/ops/l2norm.py b/vllm/model_executor/layers/fla/ops/l2norm.py index b89c67871d..ef9788ceaf 100644 --- a/vllm/model_executor/layers/fla/ops/l2norm.py +++ b/vllm/model_executor/layers/fla/ops/l2norm.py @@ -78,7 +78,7 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): row_idx = xoffset + tl.arange(0, MBLOCK)[:, None] xmask = row_idx < M rindex = tl.arange(0, N)[None, :] - xs = tl.load(X + (rindex + N * row_idx), None).to(tl.float32) + xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) square = tl.broadcast_to(xs * xs, [MBLOCK, N]) square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] rsqrt = tl.rsqrt(square_sum + eps) diff --git a/vllm/model_executor/layers/fla/ops/op.py b/vllm/model_executor/layers/fla/ops/op.py index 05c424b437..8c29434ca1 100644 --- a/vllm/model_executor/layers/fla/ops/op.py +++ b/vllm/model_executor/layers/fla/ops/op.py @@ -28,11 +28,6 @@ else: log2 = tl.log2 -@triton.jit -def safe_exp(x): - return exp(tl.where(x <= 0, x, float('-inf'))) - - if not hasattr(tl, 'gather'): @triton.jit diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 1dc4663964..a6c1af91de 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -70,6 +70,15 @@ class MambaStateDtypeCalculator: model_dtype) return (conv_state_dtype, ) + @classmethod + def gated_delta_net_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, torch.dtype]: + state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) + return (state_dtype, state_dtype) + class MambaStateShapeCalculator: @@ -163,3 +172,31 @@ class MambaStateShapeCalculator: # for n_groups == 1, this is exactly tp_size - n_groups return tp_size - ngroups + + @classmethod + def gated_delta_net_state_shape( + cls, + tp_world_size: int, + num_k_heads: int, + num_v_heads: int, + head_k_dim: int, + head_v_dim: int, + conv_kernel_size: int, + num_spec: int = 0, + use_v1: bool = True, + ): + conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads) + conv_state_shape = ( + divide(conv_dim, tp_world_size), + conv_kernel_size - 1 + num_spec, + ) + + # In V0, the conv_state shape was swapped during allocation in + # MambaCacheManager, but in V1 it needs to be determined here at the + # calculation level + if use_v1: + conv_state_shape = conv_state_shape[1], conv_state_shape[0] + + temporal_state_shape = (divide(num_v_heads, + tp_world_size), head_k_dim, head_v_dim) + return conv_state_shape, temporal_state_shape diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index b8d4bbc371..709794429c 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -464,7 +464,9 @@ def causal_conv1d_fn( # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] # 4. computation can be skipped if cache_indices[idx] == pad_slot_id num_cache_lines = conv_states.size(0) - assert (num_cache_lines, dim, width - 1) == conv_states.shape + assert (num_cache_lines == conv_states.shape[0] + and dim == conv_states.shape[1] + and width - 1 <= conv_states.shape[2]) stride_istate_seq = conv_states.stride(0) stride_istate_dim = conv_states.stride(1) stride_istate_token = conv_states.stride(2) @@ -623,6 +625,7 @@ def _causal_conv1d_update_kernel( conv_state_ptr, cache_seqlens_ptr, # circular buffer conv_state_indices_ptr, + num_accepted_tokens_ptr, o_ptr, # (batch, dim, seqlen) # Matrix dimensions batch: int, @@ -639,6 +642,7 @@ def _causal_conv1d_update_kernel( stride_conv_state_seq: tl.constexpr, stride_conv_state_dim: tl.constexpr, stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, stride_o_seq: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, @@ -649,6 +653,7 @@ def _causal_conv1d_update_kernel( KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, BLOCK_N: tl.constexpr, @@ -663,8 +668,9 @@ def _causal_conv1d_update_kernel( if IS_CONTINUOUS_BATCHING: # mask = idx_seq < batch - conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to( - tl.int64) + conv_state_batch_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices).to( + tl.int64) else: conv_state_batch_coord = idx_seq if USE_PAD_SLOT: # noqa @@ -672,13 +678,32 @@ def _causal_conv1d_update_kernel( # not processing as this is not the actual sequence return + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. + # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. + conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq) - + 1) + else: + conv_state_token_offset = 0 + # STEP 1: READ init_state data conv_states_base = (conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim)) mask_w = idx_feats < dim - prior_tokens = conv_states_base + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok if KERNEL_WIDTH >= 2: conv_states_ptrs = prior_tokens # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0) @@ -695,10 +720,14 @@ def _causal_conv1d_update_kernel( # STEP 2: assume state_len > seqlen idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + # The conv_state updates works in a sliding window manner, + # at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. conv_state_ptrs_source = ( conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens + seqlen) * stride_conv_state_tok)[:, None] + ((idx_tokens + 1) * stride_conv_state_tok)[:, None] ) # [BLOCK_M, BLOCK_N] mask = ((conv_state_batch_coord < num_cache_lines) & ((idx_tokens + seqlen) < state_len)[:, None] @@ -820,6 +849,7 @@ def causal_conv1d_update( activation: Union[bool, str, None] = None, cache_seqlens: Optional[torch.Tensor] = None, conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, pad_slot_id: int = PAD_SLOT_ID, metadata=None, validate_data=False, @@ -890,10 +920,11 @@ def causal_conv1d_update( ) # X (batch, dim, seqlen) stride_o_seq, stride_o_dim, stride_o_token = out.stride() - stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( ) - state_len = width - 1 + stride_state_indices = conv_state_indices.stride( + 0) if conv_state_indices is not None else 0 + state_len = width - 1 + (seqlen - 1) # effective state_len needed np2_statelen = triton.next_power_of_2(state_len) def grid(META): @@ -910,6 +941,7 @@ def causal_conv1d_update( conv_state, cache_seqlens, conv_state_indices, + num_accepted_tokens, out, # Matrix dimensions batch, @@ -926,6 +958,7 @@ def causal_conv1d_update( stride_istate_seq, stride_istate_dim, stride_istate_token, + stride_state_indices, stride_o_seq, stride_o_dim, stride_o_token, @@ -936,6 +969,7 @@ def causal_conv1d_update( KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, BLOCK_N=256, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index f38e7fc202..687af7a189 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -312,7 +312,8 @@ class MambaModelConfig(VerifyAndUpdateConfig): # TODO(tdoublep): remove as full cuda graph support is added FCG_NOT_SUPPORTED_MODELS = [ - "Lfm2ForCausalLM", "MiniMaxText01ForCausalLM" + "Lfm2ForCausalLM", + "MiniMaxText01ForCausalLM", ] if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py new file mode 100644 index 0000000000..12db3d9461 --- /dev/null +++ b/vllm/model_executor/models/qwen3_next.py @@ -0,0 +1,1294 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Qwen3Next model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn +from transformers.activations import ACT2FN + +from vllm import envs +from vllm.attention import Attention, AttentionBackend, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, + VllmConfig, get_current_vllm_config) +from vllm.distributed import (divide, get_ep_group, get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.fla.ops import ( + RMSNormGated, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule) +from vllm.model_executor.layers.fused_moe import FusedMoE +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.layers.layernorm import ( + GemmaRMSNorm as Qwen3NextRMSNorm) +# yapf: enable +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + mamba_v2_sharded_weight_loader) +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.mamba_cache import MambaCacheParams +from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Qwen3NextConfig +from vllm.triton_utils import tl, triton +from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata + +from .interfaces import (HasInnerState, IsHybrid, MixtureOfExperts, + SupportsLoRA, SupportsPP) +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + +KVCache = tuple[torch.Tensor, torch.Tensor] + + +class Qwen3NextSparseMoeBlock(nn.Module): + + def __init__( + self, + config: Qwen3NextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + # Load balancing settings. + vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + + self.experts = FusedMoE(num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=f"{prefix}.gate") + + if config.shared_expert_intermediate_size > 0: + self.shared_expert = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.shared_expert_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), + ) + else: + self.shared_expert = None + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, + 1, + bias=False) + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid gate quantization. + # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4 + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + shared_output = None + if self.shared_expert is not None: + shared_output = self.shared_expert(hidden_states) + if self.shared_expert_gate is not None: + shared_output = F.sigmoid( + self.shared_expert_gate(hidden_states)) * shared_output + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): + + @property + def mamba_type(self) -> str: + return "linear_attention" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend + return GDNAttentionBackend + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + self.model_config.dtype, self.cache_config.mamba_cache_dtype) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.gated_delta_net_state_shape( + self.tp_size, + self.num_k_heads, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + self.conv_kernel_size, + self.num_spec, + use_v1=True) + + def __init__( + self, + config: Qwen3NextConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = extract_layer_index(prefix) + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + self.prefix = prefix + + self.config = config + self.model_config = model_config + self.cache_config = cache_config + self.quant_config = quant_config + self.speculative_config = speculative_config + self.num_spec = (self.speculative_config.num_speculative_tokens + if self.speculative_config else 0) + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.conv_dim, + bias=False, + prefix=f"{prefix}.conv1d", + ) + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + # projection of the input hidden states + self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + self.projection_size_ba = self.num_v_heads * 2 + self.in_proj = MergedColumnParallelLinear( + input_size=self.hidden_size, + output_sizes=[self.projection_size_qkvz, self.projection_size_ba], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", + ) + + query_key_settings = (self.key_dim, 0, False) + value_settings = (self.value_dim, 0, False) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, { + "weight_loader": + mamba_v2_sharded_weight_loader([ + query_key_settings, + query_key_settings, + value_settings, + ], self.tp_size, self.tp_rank) + }) + + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter( + torch.ones(self.num_v_heads // self.tp_size), ) + self.A_log = nn.Parameter( + torch.empty( + divide(self.num_v_heads, self.tp_size), + dtype=torch.float32, + )) + + set_weight_attrs(self.A_log, + {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) + + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + group_size=None, + norm_before_gate=True, + device=torch.cuda.current_device(), + dtype=config.torch_dtype, + ) + + self.out_proj = RowParallelLinear(self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj") + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def fix_query_key_value_ordering( + self, + mixed_qkvz, + mixed_ba, + ): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. + """ + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self.num_k_heads // self.tp_size, + (self.head_k_dim + self.head_k_dim + + (self.head_v_dim + self.head_v_dim) * self.num_v_heads // + self.num_k_heads), + ) + new_tensor_shape_ba = mixed_qkvz.size()[:-1] + ( + self.num_k_heads // self.tp_size, + 2 * self.num_v_heads // self.num_k_heads, + ) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + + split_arg_list_qkvz = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + ] + split_arg_list_ba = [ + self.num_v_heads // self.num_k_heads, + self.num_v_heads // self.num_k_heads + ] + + # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] + # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], + # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] + (query, key, value, z) = torch.split(mixed_qkvz, + split_arg_list_qkvz, + dim=2) + (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) + + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.size(0), -1, self.head_v_dim) + z = z.reshape(z.size(0), -1, self.head_v_dim) + b = b.reshape(b.size(0), self.num_v_heads // self.tp_size) + a = a.reshape(a.size(0), self.num_v_heads // self.tp_size) + + return query, key, value, z, b, a + + def rearrange_mixed_qkv(self, mixed_qkv): + if mixed_qkv is None: + return None, None, None + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim // self.tp_size, + self.key_dim // self.tp_size, + self.value_dim // self.tp_size, + ], + dim=-1, + ) + query, key = map( + lambda x: rearrange(x, 'l (h d) -> 1 l h d', d=self.head_k_dim), + (query, key)) + value = rearrange(value, 'l (h d) -> 1 l h d', d=self.head_v_dim) + return query, key, value + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + cache_params: Optional[MambaCacheParams] = None, + ): + return torch.ops.vllm.gdn_attention( + hidden_states, + output, + self.prefix, + ) + + def _forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + spec_query_start_loc = attn_metadata.spec_query_start_loc + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_token_masks = attn_metadata.spec_token_masks + spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + num_actual_tokens = (attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens + + attn_metadata.num_spec_decode_tokens) + num_accepted_tokens = attn_metadata.num_accepted_tokens + + # 1. Set up dimensions for reshapes later + projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens]) + if spec_token_masks is not None: + spec_token_masks = spec_token_masks[:num_actual_tokens] + projected_states_qkvz, projected_states_ba = torch.split( + projected_states, + [ + self.projection_size_qkvz // self.tp_size, + self.projection_size_ba // self.tp_size + ], + dim=-1, + ) + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba) + query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'), + (query, key, value)) + mixed_qkv = torch.cat((query, key, value), dim=-1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if spec_sequence_masks is not None: + if (attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes == 0): + mixed_qkv_spec = mixed_qkv + mixed_qkv_non_spec = None + else: + mixed_qkv_spec = mixed_qkv[spec_token_masks] + mixed_qkv_non_spec = mixed_qkv[~spec_token_masks] + else: + mixed_qkv_spec = None + mixed_qkv_non_spec = mixed_qkv + + # 2.1: process the mutli-query part + if spec_sequence_masks is not None: + mixed_qkv_spec = mixed_qkv_spec.view( + attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1)) + mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l') + mixed_qkv_spec = causal_conv1d_update( + mixed_qkv_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=spec_state_indices_tensor[:, 0] + [:attn_metadata.num_spec_decodes], + num_accepted_tokens=num_accepted_tokens, + validate_data=False, + ) + mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d') + + # 2.2: process the remaining part + if attn_metadata.num_prefills > 0: + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + mixed_qkv_non_spec = causal_conv1d_fn( + mixed_qkv_non_spec.transpose(0, 1), + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + ).transpose(0, 1) + elif attn_metadata.num_decodes > 0: + mixed_qkv_non_spec = causal_conv1d_update( + mixed_qkv_non_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[:attn_metadata + .num_decodes], + validate_data=True, + ) + else: + mixed_qkv_non_spec = None + + query_spec, key_spec, value_spec = self.rearrange_mixed_qkv( + mixed_qkv_spec) + query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( + mixed_qkv_non_spec) + + beta = b.sigmoid() + # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + g = fused_gdn_gating(self.A_log, a, self.dt_bias) + g, beta = map(lambda x: rearrange(x, 'l d -> 1 l d'), (g, beta)) + + if spec_sequence_masks is not None: + if (attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes == 0): + g_spec = g + beta_spec = beta + g_non_spec = None + beta_non_spec = None + else: + g_spec = g[:, spec_token_masks] + beta_spec = beta[:, spec_token_masks] + g_non_spec = g[:, ~spec_token_masks] + beta_non_spec = beta[:, ~spec_token_masks] + else: + g_spec = None + beta_spec = None + g_non_spec = g + beta_non_spec = beta + + # 3. Recurrent attention + + # 3.1: process the mutlti-query part + if spec_sequence_masks is not None: + core_attn_out_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[:attn_metadata. + num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + )) + else: + core_attn_out_spec, last_recurrent_state = None, None + + # 3.2: process the remaining part + if attn_metadata.num_prefills > 0: + initial_state = ssm_state[ + non_spec_state_indices_tensor].contiguous() + initial_state[~has_initial_state, ...] = 0 + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = chunk_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=non_spec_query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + # Init cache + ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( + ssm_state.dtype) + elif attn_metadata.num_decodes > 0: + core_attn_out_non_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[:attn_metadata. + num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + )) + else: + core_attn_out_non_spec, last_recurrent_state = None, None + + # Merge core attention output + if (spec_sequence_masks is not None + and core_attn_out_non_spec is not None): + core_attn_out = torch.empty( + (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), + dtype=core_attn_out_non_spec.dtype, + device=core_attn_out_non_spec.device, + ) + core_attn_out[:, spec_token_masks] = core_attn_out_spec + core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec + elif spec_sequence_masks is not None: + core_attn_out = core_attn_out_spec + else: + core_attn_out = core_attn_out_non_spec + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, '... h d -> ... (h d)') + + output[:num_actual_tokens], _ = self.out_proj(core_attn_out) + + +class Qwen3NextAttention(nn.Module): + + def __init__( + self, + config: Qwen3NextConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.head_dim or (self.hidden_size // self.num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None) + self.attn_output_gate = getattr(config, "attn_output_gate", True) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads * (1 + self.attn_output_gate), + self.total_num_kv_heads, + bias=getattr(config, "qkv_bias", False), + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + rope_scaling=config.rope_scaling, + partial_rotary_factor=config.partial_rotary_factor, + dual_chunk_attention_config=self.dual_chunk_attention_config, + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": + self.dual_chunk_attention_config, + } if self.dual_chunk_attention_config else {}, + ) + + self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + output: torch.Tensor, + hidden_states: torch.Tensor, + ): + qkv, _ = self.qkv_proj(hidden_states) + + if self.attn_output_gate: + q_gate, k, v = qkv.split( + [self.q_size * 2, self.kv_size, self.kv_size], dim=-1) + orig_shape = q_gate.shape[:-1] + q_gate = q_gate.view(*orig_shape, self.num_heads, -1) + q, gate = torch.chunk(q_gate, 2, dim=-1) + q = q.reshape(*orig_shape, -1) + gate = gate.reshape(*orig_shape, -1) + else: + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + + q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view( + -1, self.num_heads * self.head_dim) + k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view( + -1, self.num_kv_heads * self.head_dim) + + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn(q, k, v) + + if self.attn_output_gate: + gate = torch.sigmoid(gate) + attn_output = attn_output * gate + + output[:], _ = self.o_proj(attn_output) + + +class Qwen3NextDecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen3NextConfig, + layer_type: str, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.config = config + + self.layer_type = layer_type + self.layer_idx = extract_layer_index(prefix) + + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3NextGatedDeltaNet( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=f'{prefix}.linear_attn') + elif self.layer_type == "full_attention": + self.self_attn = Qwen3NextAttention( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f'{prefix}.self_attn', + ) + else: + raise ValueError(f"Invalid layer_type {self.layer_type}") + + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (self.layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (self.layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen3NextSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, + ) + else: + self.mlp = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + + self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + + self.layer_scale = getattr(config, "layer_scale", False) + if self.layer_scale: + self.attn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + self.config.hidden_size, + dtype=config.torch_dtype, + ), ) + self.ffn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + self.config.hidden_size, + dtype=config.torch_dtype, + ), ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + positions: torch.Tensor = None, + **kwargs: object, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + self_attention_output = torch.empty_like(hidden_states) + if self.layer_type == "linear_attention": + self.linear_attn( + hidden_states=hidden_states, + output=self_attention_output, + ) + elif self.layer_type == "full_attention": + self.self_attn( + hidden_states=hidden_states, + output=self_attention_output, + positions=positions, + ) + else: + raise ValueError("Invalid layer_type") + hidden_states = self_attention_output + + if self.layer_scale: + if len(hidden_states.shape) == 2: + hidden_states = hidden_states * ( + self.attn_layer_scale.to(hidden_states.dtype)[0] + 1) + else: + hidden_states = hidden_states * ( + self.attn_layer_scale.to(hidden_states.dtype) + 1) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + if self.layer_scale: + if len(hidden_states.shape) == 2: + hidden_states = hidden_states * ( + self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1) + else: + assert len(hidden_states.shape) == len( + self.ffn_layer_scale.shape + ), f'shape must be the same {len(hidden_states.shape)}, {len(self.ffn_layer_scale.shape)}' # noqa: E501 + hidden_states = hidden_states * ( + self.ffn_layer_scale.to(hidden_states.dtype) + 1) + + return hidden_states, residual + + +@support_torch_compile +class Qwen3NextModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: Qwen3NextConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + lora_config = vllm_config.lora_config + speculative_config = vllm_config.speculative_config + enable_eplb = parallel_config.enable_eplb + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + def get_layer(prefix: str): + return Qwen3NextDecoderLayer( + config, + layer_type=config.layer_types[extract_layer_index(prefix)], + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=prefix, + enable_eplb=enable_eplb, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.norm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers: + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + num_redundant_experts=self.num_redundant_experts) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("in_proj", "in_proj_qkvz", 0), + ("in_proj", "in_proj_ba", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if name.startswith("mtp."): + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # name = apply_attn_prefix(name, params_dict) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + MixtureOfExperts, IsHybrid): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["up_proj", "down_proj"], + "in_proj": ["in_proj_qkvz", "in_proj_ba"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Qwen3Next currently does not support prefix caching" + assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1" + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = Qwen3NextModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + # Set MoE hyperparameters + self.expert_weights = [] + + self.moe_layers: list[FusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, Qwen3NextDecoderLayer) + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_layer is None: + raise RuntimeError("No Qwen3Next layer found in the model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = example_layer.n_physical_experts + self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = (num_physical_experts - + self.num_logical_experts) + for layer in self.model.layers: + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + + return hidden_states + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype) + + @classmethod + def get_mamba_state_shape_from_config( + cls, vllm_config: "VllmConfig" + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + tp_size = parallel_config.tensor_parallel_size + num_spec = (vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config else 0) + return MambaStateShapeCalculator.gated_delta_net_state_shape( + tp_size, + hf_config.linear_num_key_heads, + hf_config.linear_num_value_heads, + hf_config.linear_key_head_dim, + hf_config.linear_value_head_dim, + hf_config.linear_conv_kernel_dim, + num_spec, + use_v1=True) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["mtp."], + ) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() + + +def gdn_attention( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._forward(hidden_states=hidden_states, output=output) + + +def gdn_attention_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="gdn_attention", + op_func=gdn_attention, + mutates_args=["output"], + fake_impl=gdn_attention_fake, + dispatch_key=current_platform.dispatch_key, +) + + +# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) +@triton.jit +def fused_gdn_gating_kernel( + g, + A_log, + a, + dt_bias, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + # If the model is loaded in fp16, without the .float() here, A might be -inf + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where(beta * x <= threshold, + (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + + +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> torch.Tensor: + batch, num_heads = a.shape + seq_len = 1 + grid = (batch, seq_len, triton.cdiv(num_heads, 8)) + g = torch.empty_like(a, dtype=torch.float32) + fused_gdn_gating_kernel[grid](g, + A_log, + a, + dt_bias, + seq_len, + num_heads, + beta, + threshold, + 8, + num_warps=1) + return g diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py new file mode 100644 index 0000000000..e7aff377e9 --- /dev/null +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -0,0 +1,285 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Qwen3Next MTP model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.qwen3_next import (Qwen3NextDecoderLayer, + Qwen3NextRMSNorm) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Qwen3NextConfig + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, maybe_prefix) + +logger = init_logger(__name__) + +KVCache = tuple[torch.Tensor, torch.Tensor] + + +@support_torch_compile +class Qwen3NextMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + config: Qwen3NextConfig = model_config.hf_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1) + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.fc = ColumnParallelLinear(self.config.hidden_size * 2, + self.config.hidden_size, + gather_output=True, + bias=False, + return_bias=False) + + self.layers = torch.nn.ModuleList( + Qwen3NextDecoderLayer( + config, + layer_type="full_attention", + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}', + ) for idx in range(self.num_mtp_layers)) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.norm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + assert hidden_states.shape[-1] == inputs_embeds.shape[-1] + inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds) + hidden_states = self.pre_fc_norm_hidden(hidden_states) + hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1) + hidden_states = self.fc(hidden_states) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + current_step_idx = (spec_step_idx % self.num_mtp_layers) + hidden_states, residual = self.layers[current_step_idx]( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +@support_torch_compile +class Qwen3NextMTP(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["up_proj", "down_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + cache_config = vllm_config.cache_config + assert not cache_config.enable_prefix_caching, \ + "Qwen3NextMTP currently does not support prefix caching" + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.model = Qwen3NextMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead(self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + hidden_states = self.model(input_ids, positions, hidden_states, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + shared_weight_names = ["embed_tokens", "lm_head"] + + def remap_weight_names(weights): + for name, weight in weights: + if name.startswith("mtp."): + name = name.replace("mtp.", "model.") + elif not any(key in name for key in shared_weight_names): + continue + yield name, weight + + loader = AutoWeightsLoader(self) + return loader.load_weights(remap_weight_names(weights)) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 8fe8402c15..7d7654e846 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -74,6 +74,7 @@ _TEXT_GENERATION_MODELS = { "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"), + "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"), @@ -285,6 +286,7 @@ _SPECULATIVE_DECODING_MODELS = { "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), + "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"), # Temporarily disabled. # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index d6ebcdf805..2852d16ec5 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -79,7 +79,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( ultravox="UltravoxConfig", step3_vl="Step3VLConfig", step3_text="Step3TextConfig", -) + qwen3_next="Qwen3NextConfig") _CONFIG_ATTRS_MAPPING: dict[str, str] = { "llm_config": "text_config", diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index f651ecb078..cdae59ccc2 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -24,6 +24,7 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config from vllm.transformers_utils.configs.ovis import OvisConfig +from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig, Step3VisionEncoderConfig, @@ -50,4 +51,5 @@ __all__ = [ "Step3VLConfig", "Step3VisionEncoderConfig", "Step3TextConfig", + "Qwen3NextConfig", ] diff --git a/vllm/transformers_utils/configs/qwen3_next.py b/vllm/transformers_utils/configs/qwen3_next.py new file mode 100644 index 0000000000..c7af26acd1 --- /dev/null +++ b/vllm/transformers_utils/configs/qwen3_next.py @@ -0,0 +1,275 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen3-Next model configuration""" + +from transformers.configuration_utils import (PretrainedConfig, + layer_type_validation) +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class Qwen3NextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a + Qwen3-Next model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of + Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the model. Defines the number of different tokens that can be represented by the + `inputs_ids`. + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 5632): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + partial_rotary_factor (`float`, *optional*, defaults to 0.25): + Percentage of the query and keys which will have rotary embedding. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + head_dim (`int`, *optional*, defaults to 256): + Projection weights dimension in multi-head attention. + linear_conv_kernel_dim (`int`, *optional*, defaults to 4): + Kernel size of the convolution used in linear attention layers. + linear_key_head_dim (`int`, *optional*, defaults to 128): + Dimension of each key head in linear attention. + linear_value_head_dim (`int`, *optional*, defaults to 128): + Dimension of each value head in linear attention. + linear_num_key_heads (`int`, *optional*, defaults to 16): + Number of key heads used in linear attention layers. + linear_num_value_heads (`int`, *optional*, defaults to 32): + Number of value heads used in linear attention layers. + decoder_sparse_step (`int`, *optional*, defaults to 1): + The frequency of the MoE layer. + moe_intermediate_size (`int`, *optional*, defaults to 512): + Intermediate size of the routed expert. + shared_expert_intermediate_size (`int`, *optional*, defaults to 512): + Intermediate size of the shared expert. + num_experts_per_tok (`int`, *optional*, defaults to 10): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 512): + Number of routed experts. + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the topk probabilities. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + mlp_only_layers (`list[int]`, *optional*, defaults to `[]`): + Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock + The list contains layer index, from 0 to num_layers-1 if we have num_layers layers + If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity. + layer_types (`list[str]`, *optional*): + Types of each layer (attention or linear). + + ```python + >>> from transformers import Qwen3NextModel, Qwen3NextConfig + + >>> # Initializing a Qwen3Next style configuration + >>> configuration = Qwen3NextConfig() + + >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration + >>> model = Qwen3NextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ # noqa: E501 + + model_type = "qwen3_next" + keys_to_ignore_at_inference = ["past_key_values"] + + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.experts.*.gate_proj": "colwise", + "layers.*.mlp.experts.*.up_proj": "colwise", + "layers.*.mlp.experts.*.down_proj": "rowwise", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=5632, + num_hidden_layers=48, + num_attention_heads=16, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=0.25, + attention_bias=False, + attention_dropout=0.0, + head_dim=256, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + decoder_sparse_step=1, + moe_intermediate_size=512, + shared_expert_intermediate_size=512, + num_experts_per_tok=10, + num_experts=512, + norm_topk_prob=True, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=None, + layer_types=None, + **kwargs, + ): + if mlp_only_layers is None: + mlp_only_layers = [] + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.head_dim = head_dim + rope_config_validation(self) + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "linear_attention" if bool((i + 1) % 4) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + + # linear attention part + self.linear_conv_kernel_dim = linear_conv_kernel_dim + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_num_key_heads = linear_num_key_heads + self.linear_num_value_heads = linear_num_value_heads + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.mlp_only_layers = mlp_only_layers + + +__all__ = ["Qwen3NextConfig"] diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py new file mode 100644 index 0000000000..12233af057 --- /dev/null +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -0,0 +1,319 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Backend for GatedDeltaNet attention.""" +from dataclasses import dataclass +from typing import ClassVar, Optional + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + + +class GDNAttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]: + return GDNAttentionMetadataBuilder + + +@dataclass +class GDNAttentionMetadata: + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + num_spec_decodes: int + num_spec_decode_tokens: int + + has_initial_state: Optional[torch.Tensor] = None + + spec_query_start_loc: Optional[ + torch.Tensor] = None # shape: [num_spec_decodes + 1,] + non_spec_query_start_loc: Optional[ + torch.Tensor] = None # shape: [batch - num_spec_decodes + 1,] + + spec_state_indices_tensor: Optional[ + torch.Tensor] = None # shape: [batch, num_spec] + non_spec_state_indices_tensor: Optional[ + torch.Tensor] = None # shape: [batch - num_spec_decodes,] + spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,] + spec_token_masks: Optional[ + torch. + Tensor] = None # shape: [num_prefill_tokens + num_decode_tokens,] + num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,] + + +class GDNAttentionMetadataBuilder( + AttentionMetadataBuilder[GDNAttentionMetadata]): + + cudagraph_support = AttentionCGSupport.UNIFORM_BATCH + + reorder_batch_threshold: ClassVar[int] = 1 + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + assert isinstance(kv_cache_spec, MambaSpec) + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.speculative_config = vllm_config.speculative_config + self.kv_cache_spec = kv_cache_spec + if self.speculative_config: + self.num_spec = self.speculative_config.num_speculative_tokens # noqa: E501 + else: + self.num_spec = 0 + self.use_spec_decode = self.num_spec > 0 + self.reorder_batch_threshold = self.num_spec + 1 # type: ignore[misc] + + self.use_full_cuda_graph = \ + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + self.decode_cudagraph_max_bs = min( + self.vllm_config.scheduler_config.max_num_seqs, + self.compilation_config.max_capture_size) + + self.spec_state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, self.num_spec + 1), + dtype=torch.int32, + device=device, + ) + self.non_spec_state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) + self.spec_sequence_masks = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.bool, + device=device, + ) + self.spec_token_masks = torch.empty( + (self.decode_cudagraph_max_bs * (self.num_spec + 1), ), + dtype=torch.bool, + device=device, + ) + self.spec_query_start_loc = torch.empty( + (self.decode_cudagraph_max_bs + 1, ), + dtype=torch.int32, + device=device, + ) + self.non_spec_query_start_loc = torch.empty( + (self.decode_cudagraph_max_bs + 1, ), + dtype=torch.int32, + device=device, + ) + self.num_accepted_tokens = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) + + def build( # type: ignore[override] + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + num_accepted_tokens: Optional[torch.Tensor] = None, + num_draft_tokens: Optional[torch.Tensor] = None, + fast_build: bool = False, + ) -> GDNAttentionMetadata: + m = common_attn_metadata + + query_start_loc = m.query_start_loc + context_lens = m.num_computed_tokens_cpu + context_lens_tensor = context_lens.to(query_start_loc.device) + seq_lens_tensor = m.seq_lens + + if (not self.use_spec_decode or num_draft_tokens is None + or num_draft_tokens.sum().item() == 0): + spec_sequence_masks = None + else: + spec_sequence_masks = (num_draft_tokens > 0) & ( + context_lens_tensor + + (num_draft_tokens + 1) == seq_lens_tensor) + if spec_sequence_masks.sum().item() == 0: + spec_sequence_masks = None + + if spec_sequence_masks is None: + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(m, decode_threshold=1)) + num_spec_decodes = 0 + num_spec_decode_tokens = 0 + spec_token_masks = None + spec_state_indices_tensor = None + non_spec_state_indices_tensor = m.block_table_tensor[:, 0] + spec_query_start_loc = None + non_spec_query_start_loc = query_start_loc + num_accepted_tokens = None + else: + num_spec_decodes = spec_sequence_masks.sum().item() + query_lens = query_start_loc[1:] - query_start_loc[:-1] + + non_spec_query_lens = query_lens[~spec_sequence_masks] + num_decodes = (non_spec_query_lens == 1).sum().item() + num_prefills = non_spec_query_lens.size(0) - num_decodes + num_decode_tokens = num_decodes + num_prefill_tokens = non_spec_query_lens.sum().item( + ) - num_decode_tokens + + if num_prefills == 0 and num_decodes == 0: + spec_token_masks = torch.ones( + (min(num_spec_decodes * + (self.num_spec + 1), query_start_loc[-1].item())), + dtype=torch.bool, + device=query_start_loc.device) + spec_state_indices_tensor = m.block_table_tensor[:, :self. + num_spec + 1] + non_spec_state_indices_tensor = None + spec_query_start_loc = query_start_loc + non_spec_query_start_loc = None + else: + spec_token_masks = torch.repeat_interleave( + spec_sequence_masks, query_lens) + spec_state_indices_tensor = m.block_table_tensor[ + spec_sequence_masks, :self.num_spec + 1] + non_spec_state_indices_tensor = \ + m.block_table_tensor[~spec_sequence_masks, 0] + + spec_query_start_loc = torch.zeros( + num_spec_decodes + 1, + dtype=torch.int32, + device=query_start_loc.device) + torch.cumsum(query_lens[spec_sequence_masks], + dim=0, + out=spec_query_start_loc[1:]) + non_spec_query_start_loc = torch.zeros( + query_lens.size(0) - num_spec_decodes + 1, + dtype=torch.int32, + device=query_start_loc.device) + torch.cumsum(query_lens[~spec_sequence_masks], + dim=0, + out=non_spec_query_start_loc[1:]) + + num_spec_decode_tokens = min( + num_spec_decodes * (self.num_spec + 1), + spec_token_masks.size(0)) + assert num_accepted_tokens is not None + num_accepted_tokens = num_accepted_tokens[spec_sequence_masks] + + if num_prefills > 0: + has_initial_state = context_lens_tensor > 0 + if spec_sequence_masks is not None: + has_initial_state = has_initial_state[~spec_sequence_masks] + else: + has_initial_state = None + + # prepare tensors for cudagraph + if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0 + and num_spec_decodes <= self.decode_cudagraph_max_bs): + num_total_tokens = self.vllm_config.pad_for_cudagraph( + m.num_actual_tokens) + batch_size = num_total_tokens // (self.num_spec + 1) + + self.spec_state_indices_tensor[:num_spec_decodes].copy_( + spec_state_indices_tensor, non_blocking=True) + spec_state_indices_tensor = self.spec_state_indices_tensor[: + batch_size] + spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID) + + self.spec_sequence_masks[:num_spec_decodes].copy_( + spec_sequence_masks, non_blocking=True) + spec_sequence_masks = self.spec_sequence_masks[:batch_size] + spec_sequence_masks[num_spec_decodes:].fill_(False) + + assert spec_token_masks is not None + self.spec_token_masks[:spec_token_masks.size(0)].copy_( + spec_token_masks, non_blocking=True) + spec_token_masks = self.spec_token_masks[:m.num_actual_tokens] + spec_token_masks[spec_token_masks.size(0):].fill_(False) + + self.spec_query_start_loc[:num_spec_decodes + 1].copy_( + spec_query_start_loc, non_blocking=True) + spec_num_query_tokens = spec_query_start_loc[ + -1] # type: ignore[index] + spec_query_start_loc = self.spec_query_start_loc[:batch_size + 1] + spec_query_start_loc[num_spec_decodes + + 1:].fill_(spec_num_query_tokens) + + self.num_accepted_tokens[:num_spec_decodes].copy_( + num_accepted_tokens, non_blocking=True) + num_accepted_tokens = self.num_accepted_tokens[:batch_size] + num_accepted_tokens[num_spec_decodes:].fill_(1) + + if (self.use_full_cuda_graph and num_prefills == 0 + and num_spec_decodes == 0 + and num_decodes <= self.decode_cudagraph_max_bs): + num_total_tokens = self.vllm_config.pad_for_cudagraph( + m.num_actual_tokens) + batch_size = num_total_tokens + + self.non_spec_state_indices_tensor[:num_decodes].copy_( + non_spec_state_indices_tensor, non_blocking=True) + non_spec_state_indices_tensor = \ + self.non_spec_state_indices_tensor[:batch_size] + non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID) + + self.non_spec_query_start_loc[:num_decodes + 1].copy_( + non_spec_query_start_loc, non_blocking=True) + non_spec_num_query_tokens = non_spec_query_start_loc[ + -1] # type: ignore[index] + non_spec_query_start_loc = \ + self.non_spec_query_start_loc[:batch_size + 1] + non_spec_query_start_loc[num_decodes + + 1:].fill_(non_spec_num_query_tokens) + + attn_metadata = GDNAttentionMetadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_spec_decodes=num_spec_decodes, + num_spec_decode_tokens=num_spec_decode_tokens, + has_initial_state=has_initial_state, + spec_query_start_loc=spec_query_start_loc, + non_spec_query_start_loc=non_spec_query_start_loc, + spec_state_indices_tensor=spec_state_indices_tensor, + non_spec_state_indices_tensor=non_spec_state_indices_tensor, + spec_sequence_masks=spec_sequence_masks, + spec_token_masks=spec_token_masks, + num_accepted_tokens=num_accepted_tokens, + ) + return attn_metadata + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata): + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with Mamba. + """ + m = common_attn_metadata + + assert (m.num_reqs * (self.num_spec + 1) <= m.num_actual_tokens + and ((m.num_reqs + 1) * (self.num_spec + 1) + >= m.num_actual_tokens)), \ + "GDN only supports decode-only full CUDAGraph capture. " \ + "Make sure all cudagraph capture sizes <= max_num_seq." + + num_accepted_tokens = torch.full((m.num_reqs, ), + m.max_query_len, + dtype=torch.int32, + device=m.query_start_loc.device) + num_drafted_tokens = torch.full((m.num_reqs, ), + self.num_spec, + dtype=torch.int32, + device=m.query_start_loc.device) + + # Fixes query-start loc for spec-sequence-indices. + m.query_start_loc = torch.arange(0, + m.num_actual_tokens + 1, + step=m.max_query_len, + device=m.query_start_loc.device, + dtype=torch.int32) + m.num_computed_tokens_cpu = (m.seq_lens_cpu - torch.full( + (m.num_reqs, ), m.max_query_len, dtype=torch.int32, device='cpu')) + + return self.build(0, m, num_accepted_tokens, num_drafted_tokens) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8159349e46..d27239164b 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -559,12 +559,48 @@ class MambaManager(SingleTypeKVCacheManager): num_running_requests: int) -> int: return 0 + def get_num_blocks_to_allocate( + self, request_id: str, num_tokens: int, + new_computed_blocks: list[KVCacheBlock]) -> int: + """ + Get the number of blocks needed to be allocated for the request. + + Args: + request_id: The request ID. + num_tokens: The total number of tokens that need a slot (including + tokens that are already allocated). + new_computed_blocks: The new computed blocks just hitting the + prefix caching. + + Returns: + The number of blocks + """ + + assert isinstance(self.kv_cache_spec, MambaSpec) + if self.kv_cache_spec.num_speculative_blocks > 0: + num_tokens += (self.kv_cache_spec.block_size * + self.kv_cache_spec.num_speculative_blocks) + num_required_blocks = cdiv(num_tokens, self.block_size) + num_new_blocks = (num_required_blocks - len(new_computed_blocks) - + len(self.req_to_blocks[request_id])) + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it will be changed from a free block + # to a computed block when the request is allocated, so we also count + # it as needed to be allocated. + num_evictable_computed_blocks = sum( + blk.ref_cnt == 0 and not blk.is_null + for blk in new_computed_blocks) + return num_new_blocks + num_evictable_computed_blocks + def allocate_new_blocks(self, request_id: str, num_tokens: int) -> list[KVCacheBlock]: - new_blocks = super().allocate_new_blocks(request_id, num_tokens) - assert len(self.req_to_blocks[request_id]) == 1, ( - "MambaManager should only allocate 1 block for each request.") - return new_blocks + # Allocate extra `num_speculative_blocks` blocks for + # speculative decoding (MTP/EAGLE) with linear attention. + assert isinstance(self.kv_cache_spec, MambaSpec) + if self.kv_cache_spec.num_speculative_blocks > 0: + num_tokens += (self.kv_cache_spec.block_size * + self.kv_cache_spec.num_speculative_blocks) + return super().allocate_new_blocks(request_id, num_tokens) class CrossAttentionManager(SingleTypeKVCacheManager): diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 6467fcfe40..6e8f569fff 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -194,6 +194,7 @@ class MambaSpec(KVCacheSpec): dtypes: tuple[torch.dtype] page_size_padded: Optional[int] = None mamba_type: str = "mamba2" + num_speculative_blocks: int = 0 @property def page_size_bytes(self) -> int: diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index bf25c91d83..7132d507c7 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -218,7 +218,7 @@ class EagleProposer: hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) - if self.method in ("deepseek_mtp", "ernie_mtp"): + if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp"): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states else: @@ -322,12 +322,18 @@ class EagleProposer: with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size): - last_hidden_states, hidden_states = self.model( + ret_hidden_states = self.model( input_ids=input_ids, positions=self.positions[:input_batch_size], hidden_states=self.hidden_states[:input_batch_size], inputs_embeds=inputs_embeds, ) + if self.method in ("deepseek_mtp", "ernie_mtp", + "qwen3_next_mtp"): + last_hidden_states = ret_hidden_states + hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size], None) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 0e509b7453..1901de6d2e 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -156,9 +156,14 @@ class BlockTable: class MultiGroupBlockTable: """The BlockTables for each KV cache group.""" - def __init__(self, max_num_reqs: int, max_model_len: int, - max_num_batched_tokens: int, pin_memory: bool, - device: torch.device, block_sizes: list[int]) -> None: + def __init__(self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + block_sizes: list[int], + num_speculative_tokens: int = 0) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, # so the block_size which used for calc max_num_blocks_per_req @@ -170,10 +175,11 @@ class MultiGroupBlockTable: dcp_world_size = 1 self.block_tables = [ - BlockTable(block_size, max_num_reqs, - cdiv(max_model_len, block_size * dcp_world_size), - max_num_batched_tokens, pin_memory, device) - for block_size in block_sizes + BlockTable( + block_size, max_num_reqs, + max(cdiv(max_model_len, block_size * dcp_world_size), + 1 + num_speculative_tokens), max_num_batched_tokens, + pin_memory, device) for block_size in block_sizes ] def append_row(self, block_ids: tuple[list[int], ...], diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index bf9b16575e..1cf56656d7 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -83,6 +83,7 @@ class InputBatch: logitsprocs: Optional[LogitsProcessors] = None, is_spec_decode: bool = False, is_pooling_model: bool = False, + num_speculative_tokens: int = 0, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode @@ -127,6 +128,7 @@ class InputBatch: pin_memory=pin_memory, device=device, block_sizes=block_sizes, + num_speculative_tokens=num_speculative_tokens, ) # Sampling-related. @@ -202,6 +204,14 @@ class InputBatch: self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: set[str] = set() + # Speculative decoding + self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ), + dtype=torch.int64, + device="cpu", + pin_memory=pin_memory) + self.num_accepted_tokens_cpu = \ + self.num_accepted_tokens_cpu_tensor.numpy() + # lora related self.request_lora_mapping = np.zeros((self.max_num_reqs, ), dtype=np.int32) @@ -394,6 +404,9 @@ class InputBatch: else: raise NotImplementedError("Unrecognized request type") + # Speculative decoding: by default 1 token is generated. + self.num_accepted_tokens_cpu[req_index] = 1 + # Add request lora ID if request.lora_request: lora_id = request.lora_request.lora_int_id @@ -515,6 +528,8 @@ class InputBatch: self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \ self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] + self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\ + self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1] swap_dict_values(self.generators, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2) @@ -609,6 +624,8 @@ class InputBatch: empty_index] = self.presence_penalties_cpu[last_req_index] self.repetition_penalties_cpu[ empty_index] = self.repetition_penalties_cpu[last_req_index] + self.num_accepted_tokens_cpu[ + empty_index] = self.num_accepted_tokens_cpu[last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ce53154896..1b785af96a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -53,9 +53,9 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, cdiv, check_use_alibi, - get_dtype_size, is_pin_memory_available, round_up, - supports_dynamo) + GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, + is_pin_memory_available, round_up, supports_dynamo) +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, create_fast_prefill_custom_backend, @@ -324,6 +324,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.hidden_size, dtype=self.dtype, numpy=False) + self.num_draft_tokens = self._make_buffer(self.max_num_reqs, + dtype=torch.int32) + self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, + dtype=torch.int64) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -663,6 +667,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() + def _update_states_after_model_execute( + self, output_token_ids: torch.Tensor) -> None: + """Update the cached states after model execution. + + This is used for MTP/EAGLE for hybrid models, as in linear attention, + only the last token's state is kept. In MTP/EAGLE, for draft tokens + the state are kept util we decide how many tokens are accepted for + each sequence, and a shifting is done during the next iteration + based on the number of accepted tokens. + """ + if not self.model_config.is_hybrid or not self.speculative_config: + return + + # Find the number of accepted tokens for each sequence. + num_accepted_tokens = (torch.cat( + [ + output_token_ids, + torch.full((output_token_ids.size(0), 1), + -1, + device=output_token_ids.device), + ], + dim=1) == -1).int().argmax(-1).cpu().numpy() + for i, num_tokens in enumerate(num_accepted_tokens): + self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + def _init_mrope_positions(self, req_state: CachedRequestState): image_grid_thw = [] video_grid_thw = [] @@ -936,6 +965,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. logits_indices = query_start_loc[1:] - 1 + num_draft_tokens = None spec_decode_metadata = None else: # Get the number of draft tokens for each request. @@ -950,6 +980,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices + self.num_draft_tokens.np[:num_reqs] = num_draft_tokens + self.num_draft_tokens.np[num_reqs:].fill(0) + self.num_draft_tokens.copy_to_gpu() logits_indices_padded = None if self.cache_config.kv_sharing_fast_prefill: @@ -964,6 +997,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_computed_tokens_cpu = ( self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) spec_decode_common_attn_metadata = None + if use_spec_decode: + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs]) + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. @@ -1034,10 +1072,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): builder, ) + extra_attn_metadata_args = {} + if use_spec_decode and isinstance(builder, + GDNAttentionMetadataBuilder): + extra_attn_metadata_args = dict( + num_accepted_tokens=self.num_accepted_tokens. + gpu[:num_reqs], + num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs], + ) + attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, - ) + **extra_attn_metadata_args) for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -1814,6 +1861,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): sampling_metadata, ) sampler_output.sampled_token_ids = output_token_ids + self._update_states_after_model_execute(output_token_ids) return sampler_output @@ -2644,13 +2692,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Note: Overriding max_query_len to be the prefill tokens max_query_len = num_prefill_tokens elif uniform_decode: - assert not create_mixed_batch - num_reqs = cdiv(num_tokens, max_query_len) + num_reqs = num_tokens // max_query_len assert num_reqs <= max_num_reqs, \ "Do not capture num_reqs > max_num_reqs for uniform batch" num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: - num_scheduled_tokens_list[-1] = num_tokens % max_query_len + num_scheduled_tokens_list[-1] += num_tokens % max_query_len else: num_reqs = min(num_tokens, max_num_reqs) min_tokens_per_req = num_tokens // num_reqs @@ -3297,6 +3344,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=self.input_batch.logitsprocs, is_pooling_model=self.is_pooling_model, + num_speculative_tokens=( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config else 0), ) def _allocate_kv_cache_tensors( @@ -3647,7 +3697,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) if len(mamba_layers) > 0: - if self.vllm_config.speculative_config is not None: + if (self.vllm_config.speculative_config is not None + and self.vllm_config.model_config.hf_config.model_type + not in ["qwen3_next"]): raise NotImplementedError( "Mamba with speculative decoding is not supported yet.") if self.vllm_config.cache_config.enable_prefix_caching: @@ -3666,7 +3718,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dtypes=mamba_module.get_state_dtype(), block_size=max_model_len, page_size_padded=page_size_padded, - mamba_type=mamba_module.mamba_type) + mamba_type=mamba_module.mamba_type, + num_speculative_blocks=( + self.speculative_config.num_speculative_tokens + if self.speculative_config else 0), + ) return kv_cache_spec diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b4a67e2899..142b1afce8 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -78,7 +78,8 @@ class Worker(LocalOrDistributedWorkerBase): "deepseek_mtp", "glm4_moe_mtp", "mimo_mtp", - "ernie_mtp")) \ + "ernie_mtp", + "qwen3_next_mtp")) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner