Mirror of https://github.com/vllm-project/vllm.git

Add support for the Qwen3-Next model (a hybrid attention model) (#24526)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -1 +1,2 @@
 collect_env.py
+vllm/model_executor/layers/fla/ops/*.py
@@ -403,6 +403,7 @@ th {
 | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Qwen3NextForCausalLM` | Qwen3.5MoE | `Qwen/Qwen3-Next-80B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ |
 | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ |
@@ -228,6 +228,7 @@ fo = "fo"
 ba = "ba"
 
 [tool.typos.type.py.extend-words]
+ba = "ba"
 
 [tool.typos.type.cpp]
 extend-glob = ["*.cu"]
@@ -326,6 +326,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
     "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
     "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
+    "Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
+                                            min_transformers_version="4.56.2"),
     "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
     "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct",  # noqa: E501
                                           trust_remote_code=True,
@@ -640,7 +642,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
                                      is_available_online=False),
     "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
                                     trust_remote_code=True,
-                                    speculative_model="XiaomiMiMo/MiMo-7B-RL")
+                                    speculative_model="XiaomiMiMo/MiMo-7B-RL"),
+    "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
+                                    min_transformers_version="4.56.2"),
 }
 
 _TRANSFORMERS_BACKEND_MODELS = {
@@ -1508,7 +1508,8 @@ class ModelConfig:
         if (self.hf_text_config.model_type == "deepseek_mtp"
                 or self.hf_config.model_type == "mimo_mtp"
                 or self.hf_config.model_type == "glm4_moe_mtp"
-                or self.hf_config.model_type == "ernie_mtp"):
+                or self.hf_config.model_type == "ernie_mtp"
+                or self.hf_config.model_type == "qwen3_next_mtp"):
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_nextn_predict_layers", 0)
         else:
@@ -1571,15 +1572,28 @@ class ModelConfig:
             if attn_type_list:
                 return sum(t == 1 for t in attn_type_list[start:end])
 
-            if layers_block_type_value is None and attn_type_list is None:
+            # Hybrid model Qwen3Next
+            layer_types_value = getattr(self.hf_config, "layer_types", None)
+            if layer_types_value is not None:
+                if getattr(block_type, "value", block_type) == "attention":
+                    return sum(t == "full_attention"
+                               for t in layer_types_value[start:end])
+                elif getattr(block_type, "value",
+                             block_type) == "linear_attention":
+                    return sum(t == "linear_attention"
+                               for t in layer_types_value[start:end])
+                else:
+                    return sum(t == getattr(block_type, "value", block_type)
+                               for t in layer_types_value[start:end])
+
+            if (layers_block_type_value is None and attn_type_list is None
+                    and layer_types_value is None):
                 raise ValueError(
                     "The model is an hybrid without a"
-                    "layers_block_type or an attn_type_list in the hf_config,"
-                    "cannot determine the num of "
+                    "layers_block_type or an attn_type_list, or a layer_types "
+                    "in the hf_config, cannot determine the num of "
                     f"{block_type.value} layers")
 
-            return sum(t == 1 for t in attn_type_list[start:end])
-
     def get_mamba_chunk_size(self) -> Optional[int]:
         """
         Returns the mamba chunk size if it exists
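To make the new `layer_types` branch concrete, here is a minimal sketch (not part of the diff) of how counting behaves for a hybrid stack. The `SimpleNamespace` stand-in config and the eight-layer example pattern are assumptions for illustration only.

from types import SimpleNamespace

# Hypothetical hf_config carrying a Qwen3-Next style `layer_types` list:
# three linear-attention layers followed by one full-attention layer.
hf_config = SimpleNamespace(layer_types=[
    "linear_attention", "linear_attention", "linear_attention",
    "full_attention",
] * 2)

def count_layers(layer_types, block_type, start, end):
    # Mirrors the diff: "attention" is counted as "full_attention",
    # everything else is matched literally.
    target = "full_attention" if block_type == "attention" else block_type
    return sum(t == target for t in layer_types[start:end])

print(count_layers(hf_config.layer_types, "attention", 0, 8))         # 2
print(count_layers(hf_config.layer_types, "linear_attention", 0, 8))  # 6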
@@ -1866,7 +1880,7 @@ class DeviceConfig:
 
 SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
                             "mlp_speculator", "draft_model", "deepseek_mtp",
-                            "ernie_mtp"]
+                            "ernie_mtp", "qwen3_next_mtp"]
 
 
 @config
@@ -2007,7 +2021,15 @@ class SpeculativeConfig:
                 "n_predict": n_predict,
                 "architectures": ["ErnieMTPModel"]
             })
-            return hf_config
+
+        if hf_config.model_type == "qwen3_next":
+            hf_config.model_type = "qwen3_next_mtp"
+        if hf_config.model_type == "qwen3_next_mtp":
+            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
+            hf_config.update({
+                "n_predict": n_predict,
+                "architectures": ["Qwen3NextMTP"]
+            })
 
         return hf_config
 
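A rough sketch (not from the diff) of what this override does to a draft-model config. The `DummyConfig` class below is a stand-in for a transformers `PretrainedConfig`, used only to show the before/after fields.

class DummyConfig:
    """Stand-in for a transformers PretrainedConfig (illustrative only)."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def update(self, d):
        self.__dict__.update(d)

cfg = DummyConfig(model_type="qwen3_next", num_nextn_predict_layers=1)

# Same rewrite as in the hunk above: retarget the draft config at the MTP head.
if cfg.model_type == "qwen3_next":
    cfg.model_type = "qwen3_next_mtp"
if cfg.model_type == "qwen3_next_mtp":
    n_predict = getattr(cfg, "num_nextn_predict_layers", None)
    cfg.update({"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]})

print(cfg.model_type, cfg.n_predict, cfg.architectures)
# qwen3_next_mtp 1 ['Qwen3NextMTP']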
@@ -2028,9 +2050,13 @@ class SpeculativeConfig:
                     (self.target_model_config.hf_text_config.model_type \
                         == "deepseek_v3" or
                     self.target_model_config.hf_text_config.model_type in
-                        ("mimo","ernie4_5_moe")):
+                        ("mimo","ernie4_5_moe", "qwen3_next")):
                 # use the draft model from the same model:
                 self.model = self.target_model_config.model
+                # Align the quantization of draft model for cases such as
+                # --quantization fp8 with a bf16 checkpoint.
+                if not self.quantization:
+                    self.quantization = self.target_model_config.quantization
             elif self.method in ("ngram", "[ngram]"):
                 self.model = "ngram"
             else:
@@ -2140,6 +2166,15 @@ class SpeculativeConfig:
                         "one layer. Might need some code changes " \
                         "to support multiple layers."
                     )
+            elif (self.draft_model_config.hf_config.model_type ==
+                  "qwen3_next_mtp"):
+                self.method = "qwen3_next_mtp"
+                if self.num_speculative_tokens > 1:
+                    logger.warning(
+                        "All Qwen3Next MTP models only have " \
+                        "one layer. Might need some code changes " \
+                        "to support multiple layers."
+                    )
             else:
                 self.method = "draft_model"
                 raise NotImplementedError(
@@ -2355,7 +2390,8 @@ class SpeculativeConfig:
         return self.num_speculative_tokens
 
     def use_eagle(self) -> bool:
-        return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
+        return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
+                               "qwen3_next_mtp")
 
     def __repr__(self) -> str:
         method = self.method
@@ -341,6 +341,7 @@ class CompilationConfig:
             "vllm.short_conv",
             "vllm.linear_attention",
             "vllm.plamo2_mamba_mixer",
+            "vllm.gdn_attention",
         ]
 
     def compute_hash(self) -> str:
@@ -14,7 +14,7 @@ import torch
 from vllm.triton_utils import tl, triton
 
 from .index import prepare_chunk_indices, prepare_chunk_offsets
-from .op import exp, safe_exp
+from .op import exp
 from .utils import is_nvidia_hopper, use_cuda_graph
 
 NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
@@ -175,12 +175,13 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
                               boundary_check=(0, 1))
 
         if USE_G:
+            m_t = (i_t * BT + tl.arange(0, BT)) < T
             last_idx = min((i_t + 1) * BT, T) - 1
             b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
             p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ),
                                     (i_t * BT, ), (BT, ), (0, ))
             b_g = tl.load(p_g, boundary_check=(0, ))
-            b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
+            b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
             b_g_last = exp(b_g_last)
             b_h1 = b_h1 * b_g_last
             if K > 64:
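The same masking idea can be shown outside Triton. Below is a PyTorch sketch (not from the diff) of why rows past the true sequence length T must be zeroed by position once `safe_exp` is gone; the chunk size and gate values are made up for illustration.

import torch

BT, T = 4, 3                       # chunk of 4 slots, only 3 valid timesteps
g = torch.tensor([-0.1, -0.3, -0.6, 7.5])   # last slot is uninitialized padding
g_last = g[T - 1]

# safe_exp(x) returned exp(x) for x <= 0 and 0 otherwise; it silenced padded
# slots only when their garbage value happened to make g_last - g positive.
# The new code keys the mask off the position instead, which is always valid:
m_t = torch.arange(BT) < T
decay = torch.where(m_t, torch.exp(g_last - g), torch.zeros(()))
print(decay)   # tensor([0.6065, 0.7408, 1.0000, 0.0000])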
@@ -16,7 +16,7 @@ import torch
 from vllm.triton_utils import tl, triton
 
 from .index import prepare_chunk_indices
-from .op import exp, safe_exp
+from .op import exp
 from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
 
 BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
@@ -112,10 +112,11 @@ def chunk_fwd_kernel_o(
         p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
         b_g = tl.load(p_g, boundary_check=(0, ))
         b_o = b_o * exp(b_g)[:, None]
-        b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
+        b_A = b_A * exp(b_g[:, None] - b_g[None, :])
 
-    o_i = tl.arange(0, BT)
-    m_A = o_i[:, None] >= o_i[None, :]
+    o_t = i_t * BT + tl.arange(0, BT)
+    m_t = o_t < T
+    m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
     b_A = tl.where(m_A, b_A, 0)
 
     p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
@@ -14,7 +14,7 @@ import torch
 from vllm.triton_utils import tl, triton
 
 from .index import prepare_chunk_indices
-from .op import safe_exp
+from .op import exp
 
 
 @triton.heuristics({
@@ -56,7 +56,8 @@ def chunk_scaled_dot_kkt_fwd_kernel(
         T = eos - bos
     else:
         bos, eos = i_b * T, i_b * T + T
-    o_t = tl.arange(0, BT)
+    o_t = i_t * BT + tl.arange(0, BT)
+    m_t = o_t < T
 
     p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
                                (i_t * BT, ), (BT, ), (0, ))
@@ -76,9 +77,10 @@ def chunk_scaled_dot_kkt_fwd_kernel(
                                 (i_t * BT, ), (BT, ), (0, ))
         b_g = tl.load(p_g, boundary_check=(0, ))
         b_g_diff = b_g[:, None] - b_g[None, :]
-        b_A = b_A * safe_exp(b_g_diff)
+        b_A = b_A * exp(b_g_diff)
 
-    b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
+    m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
+    b_A = tl.where(m_A, b_A, 0)
     p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
                             (i_t * BT, 0), (BT, BT), (1, 0))
     tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
@@ -116,8 +116,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
         b_g = tl.load(p_g).to(tl.float32)
 
         if USE_QK_L2NORM_IN_KERNEL:
-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
+            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
         b_q = b_q * scale
         # [BK, BV]
         b_h *= exp(b_g)
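This change moves the epsilon inside the square root, i.e. from q / (sqrt(sum(q*q)) + eps) to q / sqrt(sum(q*q) + eps). A tiny PyTorch check (values chosen arbitrarily, not from the diff) of how the two behave for small inputs:

import torch

q = torch.tensor([3e-4, 4e-4])
eps = 1e-6

old = q / (torch.sqrt(torch.sum(q * q)) + eps)   # eps added to the norm
new = q / torch.sqrt(torch.sum(q * q) + eps)     # eps added under the sqrt

print(old)  # ~[0.5988, 0.7984]  close to a unit vector
print(new)  # ~[0.2683, 0.3578]  damped much more strongly near zero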
@@ -78,7 +78,7 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
     row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
     xmask = row_idx < M
     rindex = tl.arange(0, N)[None, :]
-    xs = tl.load(X + (rindex + N * row_idx), None).to(tl.float32)
+    xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
     square = tl.broadcast_to(xs * xs, [MBLOCK, N])
     square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
     rsqrt = tl.rsqrt(square_sum + eps)
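The `xs` load previously passed `None` as the mask, so rows at or beyond M appear to have been read out of bounds even though they were excluded from the reduction afterwards; passing `xmask` makes the load itself safe. A hedged PyTorch sketch of the row-masked reduction (block and tensor sizes are arbitrary):

import torch

MBLOCK, M, N = 4, 3, 8            # block of 4 rows, only 3 real rows
x = torch.randn(M, N)

row_idx = torch.arange(MBLOCK)[:, None]
xmask = row_idx < M               # rows >= M must not contribute (or be read)

# Emulate the masked load: out-of-range rows come back as zeros.
xs = torch.zeros(MBLOCK, N)
xs[:M] = x

square_sum = (xs * xs).sum(dim=1, keepdim=True)
rsqrt = torch.where(xmask, torch.rsqrt(square_sum + 1e-6), torch.zeros(()))
y = xs * rsqrt                    # rows 0..M-1 are l2-normalized, rest are 0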
@@ -28,11 +28,6 @@ else:
     log2 = tl.log2
 
 
-@triton.jit
-def safe_exp(x):
-    return exp(tl.where(x <= 0, x, float('-inf')))
-
-
 if not hasattr(tl, 'gather'):
 
     @triton.jit
@@ -70,6 +70,15 @@ class MambaStateDtypeCalculator:
                                                      model_dtype)
         return (conv_state_dtype, )
 
+    @classmethod
+    def gated_delta_net_state_dtype(
+        cls,
+        model_dtype: Union[ModelDType, torch.dtype],
+        mamba_cache_dtype: MambaDType,
+    ) -> tuple[torch.dtype, torch.dtype]:
+        state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
+        return (state_dtype, state_dtype)
+
 
 class MambaStateShapeCalculator:
 
@@ -163,3 +172,31 @@ class MambaStateShapeCalculator:
 
         # for n_groups == 1, this is exactly tp_size - n_groups
         return tp_size - ngroups
+
+    @classmethod
+    def gated_delta_net_state_shape(
+        cls,
+        tp_world_size: int,
+        num_k_heads: int,
+        num_v_heads: int,
+        head_k_dim: int,
+        head_v_dim: int,
+        conv_kernel_size: int,
+        num_spec: int = 0,
+        use_v1: bool = True,
+    ):
+        conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads)
+        conv_state_shape = (
+            divide(conv_dim, tp_world_size),
+            conv_kernel_size - 1 + num_spec,
+        )
+
+        # In V0, the conv_state shape was swapped during allocation in
+        # MambaCacheManager, but in V1 it needs to be determined here at the
+        # calculation level
+        if use_v1:
+            conv_state_shape = conv_state_shape[1], conv_state_shape[0]
+
+        temporal_state_shape = (divide(num_v_heads,
+                                       tp_world_size), head_k_dim, head_v_dim)
+        return conv_state_shape, temporal_state_shape
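For intuition, a worked example of the new shape helper, plugging in the Qwen3-Next defaults from the config file added later in this diff (linear_num_key_heads=16, linear_num_value_heads=32, linear_key_head_dim=128, linear_value_head_dim=128, linear_conv_kernel_dim=4). The standalone `divide` helper and the single-GPU, no-speculation setting are assumptions for the sketch.

def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0
    return numerator // denominator

tp, num_k_heads, num_v_heads = 1, 16, 32
head_k_dim, head_v_dim, conv_kernel, num_spec = 128, 128, 4, 0

# Conv state covers q/k (2 * head_k_dim * num_k_heads) plus the v projection.
conv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads
conv_state_shape = (conv_kernel - 1 + num_spec, divide(conv_dim, tp))  # V1 order
temporal_state_shape = (divide(num_v_heads, tp), head_k_dim, head_v_dim)

print(conv_state_shape)      # (3, 8192)
print(temporal_state_shape)  # (32, 128, 128)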
@@ -464,7 +464,9 @@ def causal_conv1d_fn(
     #   3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
     #   4. computation can be skipped if cache_indices[idx] == pad_slot_id
     num_cache_lines = conv_states.size(0)
-    assert (num_cache_lines, dim, width - 1) == conv_states.shape
+    assert (num_cache_lines == conv_states.shape[0]
+            and dim == conv_states.shape[1]
+            and width - 1 <= conv_states.shape[2])
     stride_istate_seq = conv_states.stride(0)
     stride_istate_dim = conv_states.stride(1)
     stride_istate_token = conv_states.stride(2)
@@ -623,6 +625,7 @@ def _causal_conv1d_update_kernel(
     conv_state_ptr,
     cache_seqlens_ptr,  # circular buffer
     conv_state_indices_ptr,
+    num_accepted_tokens_ptr,
     o_ptr,  # (batch, dim, seqlen)
     # Matrix dimensions
     batch: int,
@@ -639,6 +642,7 @@ def _causal_conv1d_update_kernel(
     stride_conv_state_seq: tl.constexpr,
     stride_conv_state_dim: tl.constexpr,
     stride_conv_state_tok: tl.constexpr,
+    stride_state_indices: tl.constexpr,
     stride_o_seq: tl.constexpr,
     stride_o_dim: tl.constexpr,
     stride_o_token: tl.constexpr,
@@ -649,6 +653,7 @@ def _causal_conv1d_update_kernel(
     KERNEL_WIDTH: tl.constexpr,
     SILU_ACTIVATION: tl.constexpr,
     IS_CONTINUOUS_BATCHING: tl.constexpr,
+    IS_SPEC_DECODING: tl.constexpr,
     NP2_STATELEN: tl.constexpr,
     USE_PAD_SLOT: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -663,8 +668,9 @@ def _causal_conv1d_update_kernel(
 
     if IS_CONTINUOUS_BATCHING:
         # mask = idx_seq < batch
-        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
-            tl.int64)
+        conv_state_batch_coord = tl.load(conv_state_indices_ptr +
+                                         idx_seq * stride_state_indices).to(
+                                             tl.int64)
     else:
         conv_state_batch_coord = idx_seq
     if USE_PAD_SLOT:  # noqa
@@ -672,13 +678,32 @@ def _causal_conv1d_update_kernel(
             # not processing as this is not the actual sequence
             return
 
+    if IS_SPEC_DECODING:
+        # The rolling of conv state:
+        #
+        # Before forward, the conv_state is:
+        #   [history1, history2, ..., historyM].
+        #
+        # After forward, the conv_state becomes:
+        #   [history2, ..., historyM, draft1, draft2, ..., draftN].
+        #
+        # After acceptance, it becomes:
+        #
+        # - accept 1 tokens: [history2, ..., historyM, draft1]
+        # - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
+        # - and so on.
+        conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq) -
+                                   1)
+    else:
+        conv_state_token_offset = 0
+
     # STEP 1: READ init_state data
     conv_states_base = (conv_state_ptr +
                         (conv_state_batch_coord * stride_conv_state_seq) +
                         (idx_feats * stride_conv_state_dim))
     mask_w = idx_feats < dim
 
-    prior_tokens = conv_states_base
+    prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
     if KERNEL_WIDTH >= 2:
         conv_states_ptrs = prior_tokens  # [BLOCK_N]
         col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
@@ -695,10 +720,14 @@ def _causal_conv1d_update_kernel(
     # STEP 2: assume state_len > seqlen
     idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
 
+    # The conv_state update works in a sliding-window manner: at each forward
+    # pass the tokens are shifted by 1, so we load starting from
+    # idx_tokens + 1.
     conv_state_ptrs_source = (
         conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
+        conv_state_token_offset * stride_conv_state_tok +
         (idx_feats * stride_conv_state_dim)[None, :] +
-        ((idx_tokens + seqlen) * stride_conv_state_tok)[:, None]
+        ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
     )  # [BLOCK_M, BLOCK_N]
     mask = ((conv_state_batch_coord < num_cache_lines)
             & ((idx_tokens + seqlen) < state_len)[:, None]
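The rolling and acceptance behaviour described in the kernel comments above can be modelled on the host side with a plain tensor; this is only a sketch of the bookkeeping (sizes are made up), not the Triton code path. The kernel's conv_state_token_offset = num_accepted_tokens - 1 selects exactly this kind of window relative to a state that has already been shifted by one during the forward pass.

import torch

M, N = 3, 2                                   # history length, draft tokens
history = torch.tensor([1., 2., 3.])
drafts = torch.tensor([10., 11.])
stored = torch.cat([history, drafts])         # what the extended state can hold

def state_after_accepting(k: int) -> torch.Tensor:
    # Accepting k draft tokens keeps the last M tokens ending at draft k,
    # i.e. [history_{k+1}, ..., history_M, draft_1, ..., draft_k].
    return stored[k:k + M]

print(state_after_accepting(1))   # tensor([ 2.,  3., 10.])
print(state_after_accepting(2))   # tensor([ 3., 10., 11.])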
@@ -820,6 +849,7 @@ def causal_conv1d_update(
     activation: Union[bool, str, None] = None,
     cache_seqlens: Optional[torch.Tensor] = None,
     conv_state_indices: Optional[torch.Tensor] = None,
+    num_accepted_tokens: Optional[torch.Tensor] = None,
     pad_slot_id: int = PAD_SLOT_ID,
     metadata=None,
     validate_data=False,
@@ -890,10 +920,11 @@ def causal_conv1d_update(
     )  # X (batch, dim, seqlen)
 
     stride_o_seq, stride_o_dim, stride_o_token = out.stride()
 
     stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
     )
-    state_len = width - 1
+    stride_state_indices = conv_state_indices.stride(
+        0) if conv_state_indices is not None else 0
+    state_len = width - 1 + (seqlen - 1)  # effective state_len needed
     np2_statelen = triton.next_power_of_2(state_len)
 
     def grid(META):
@@ -910,6 +941,7 @@ def causal_conv1d_update(
         conv_state,
         cache_seqlens,
         conv_state_indices,
+        num_accepted_tokens,
         out,
         # Matrix dimensions
         batch,
@@ -926,6 +958,7 @@ def causal_conv1d_update(
         stride_istate_seq,
         stride_istate_dim,
         stride_istate_token,
+        stride_state_indices,
         stride_o_seq,
         stride_o_dim,
         stride_o_token,
@@ -936,6 +969,7 @@ def causal_conv1d_update(
         KERNEL_WIDTH=width,
         SILU_ACTIVATION=activation in ["silu", "swish"],
         IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
+        IS_SPEC_DECODING=num_accepted_tokens is not None,
         NP2_STATELEN=np2_statelen,
         USE_PAD_SLOT=pad_slot_id is not None,
         BLOCK_N=256,
@@ -312,7 +312,8 @@ class MambaModelConfig(VerifyAndUpdateConfig):
 
         # TODO(tdoublep): remove as full cuda graph support is added
         FCG_NOT_SUPPORTED_MODELS = [
-            "Lfm2ForCausalLM", "MiniMaxText01ForCausalLM"
+            "Lfm2ForCausalLM",
+            "MiniMaxText01ForCausalLM",
         ]
 
         if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS
vllm/model_executor/models/qwen3_next.py (new file, 1294 lines)
File diff suppressed because it is too large.

vllm/model_executor/models/qwen3_next_mtp.py (new file, 285 lines)
@@ -0,0 +1,285 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next MTP model."""
from collections.abc import Iterable
from typing import Optional

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3_next import (Qwen3NextDecoderLayer,
                                                   Qwen3NextRMSNorm)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig

from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, maybe_prefix)

logger = init_logger(__name__)

KVCache = tuple[torch.Tensor, torch.Tensor]


@support_torch_compile
class Qwen3NextMultiTokenPredictor(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        config: Qwen3NextConfig = model_config.hf_config

        self.config = config
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.fc = ColumnParallelLinear(self.config.hidden_size * 2,
                                       self.config.hidden_size,
                                       gather_output=True,
                                       bias=False,
                                       return_bias=False)

        self.layers = torch.nn.ModuleList(
            Qwen3NextDecoderLayer(
                config,
                layer_type="full_attention",
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
            ) for idx in range(self.num_mtp_layers))

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        self.norm = Qwen3NextRMSNorm(config.hidden_size,
                                     eps=config.rms_norm_eps)
        self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size,
                                                   eps=config.rms_norm_eps)
        self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size,
                                                      eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
            inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
            hidden_states = self.pre_fc_norm_hidden(hidden_states)
            hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
            hidden_states = self.fc(hidden_states)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        current_step_idx = (spec_step_idx % self.num_mtp_layers)
        hidden_states, residual = self.layers[current_step_idx](
            positions=positions,
            hidden_states=hidden_states,
            residual=residual,
        )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue

                if "mlp.experts" in name:
                    continue

                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


@support_torch_compile
class Qwen3NextMTP(nn.Module, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["up_proj", "down_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        cache_config = vllm_config.cache_config
        assert not cache_config.enable_prefix_caching, \
            "Qwen3NextMTP currently does not support prefix caching"

        self.quant_config = vllm_config.quant_config

        super().__init__()
        self.config = config
        self.model = Qwen3NextMultiTokenPredictor(vllm_config=vllm_config,
                                                  prefix=maybe_prefix(
                                                      prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
                                      config.hidden_size,
                                      org_num_embeddings=config.vocab_size,
                                      padding_size=DEFAULT_VOCAB_PADDING_SIZE)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        hidden_states = self.model(input_ids, positions, hidden_states,
                                   intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        spec_step_idx: int = 0,
    ) -> Optional[torch.Tensor]:
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        shared_weight_names = ["embed_tokens", "lm_head"]

        def remap_weight_names(weights):
            for name, weight in weights:
                if name.startswith("mtp."):
                    name = name.replace("mtp.", "model.")
                elif not any(key in name for key in shared_weight_names):
                    continue
                yield name, weight

        loader = AutoWeightsLoader(self)
        return loader.load_weights(remap_weight_names(weights))
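A hedged usage sketch (not part of the diff) of how the new MTP head is meant to be exercised from the public API. The speculative_config keys follow the pattern used for other MTP models in vLLM; the tensor-parallel size and token counts below are illustrative assumptions, not requirements.

from vllm import LLM, SamplingParams

# Qwen3-Next with its multi-token-prediction head used as the draft model.
llm = LLM(
    model="Qwen/Qwen3-Next-80B-A3B-Instruct",
    tensor_parallel_size=4,                      # illustrative
    speculative_config={
        "method": "qwen3_next_mtp",
        "num_speculative_tokens": 2,
    },
)

outputs = llm.generate(["Briefly explain hybrid attention."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)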
@@ -74,6 +74,7 @@ _TEXT_GENERATION_MODELS = {
     "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
     "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
     "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
+    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
     "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
     "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
@@ -285,6 +286,7 @@ _SPECULATIVE_DECODING_MODELS = {
     "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
     "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
     "MedusaModel": ("medusa", "Medusa"),
+    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
     # Temporarily disabled.
     # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
     # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
@@ -79,7 +79,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
     ultravox="UltravoxConfig",
     step3_vl="Step3VLConfig",
     step3_text="Step3TextConfig",
-)
+    qwen3_next="Qwen3NextConfig")
 
 _CONFIG_ATTRS_MAPPING: dict[str, str] = {
     "llm_config": "text_config",
@@ -24,6 +24,7 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
 from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
 from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
 from vllm.transformers_utils.configs.ovis import OvisConfig
+from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
 from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
 from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
                                                       Step3VisionEncoderConfig,
@@ -50,4 +51,5 @@ __all__ = [
     "Step3VLConfig",
     "Step3VisionEncoderConfig",
     "Step3TextConfig",
+    "Qwen3NextConfig",
 ]
275
vllm/transformers_utils/configs/qwen3_next.py
Normal file
275
vllm/transformers_utils/configs/qwen3_next.py
Normal file
@ -0,0 +1,275 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Qwen3-Next model configuration"""
|
||||||
|
|
||||||
|
from transformers.configuration_utils import (PretrainedConfig,
|
||||||
|
layer_type_validation)
|
||||||
|
from transformers.modeling_rope_utils import rope_config_validation
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3NextConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a
|
||||||
|
Qwen3-Next model according to the specified arguments, defining the model architecture.
|
||||||
|
Instantiating a configuration with the defaults will yield a similar configuration to that of
|
||||||
|
Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct).
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 151936):
|
||||||
|
Vocabulary size of the model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids`.
|
||||||
|
hidden_size (`int`, *optional*, defaults to 2048):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 5632):
|
||||||
|
Dimension of the MLP representations.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 48):
|
||||||
|
Number of hidden layers in the Transformer encoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||||
|
Number of attention heads for each attention layer in the Transformer encoder.
|
||||||
|
num_key_value_heads (`int`, *optional*, defaults to 2):
|
||||||
|
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||||
|
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||||
|
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||||
|
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||||
|
by meanpooling all the original heads within that group. For more details checkout [this
|
||||||
|
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
||||||
|
hidden_act (`str`, *optional*, defaults to `"silu"`):
|
||||||
|
The non-linear activation function in the decoder.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||||
|
The maximum sequence length that this model might ever be used with.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether the model's input and output word embeddings should be tied.
|
||||||
|
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||||
|
The base period of the RoPE embeddings.
|
||||||
|
rope_scaling (`Dict`, *optional*):
|
||||||
|
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||||
|
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||||
|
accordingly.
|
||||||
|
Expected contents:
|
||||||
|
`rope_type` (`str`):
|
||||||
|
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||||
|
'llama3'], with 'default' being the original RoPE implementation.
|
||||||
|
`factor` (`float`, *optional*):
|
||||||
|
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||||
|
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||||
|
original maximum pre-trained length.
|
||||||
|
`original_max_position_embeddings` (`int`, *optional*):
|
||||||
|
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||||
|
pretraining.
|
||||||
|
`attention_factor` (`float`, *optional*):
|
||||||
|
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||||
|
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||||
|
`factor` field to infer the suggested value.
|
||||||
|
`beta_fast` (`float`, *optional*):
|
||||||
|
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||||
|
ramp function. If unspecified, it defaults to 32.
|
||||||
|
`beta_slow` (`float`, *optional*):
|
||||||
|
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||||
|
ramp function. If unspecified, it defaults to 1.
|
||||||
|
`short_factor` (`List[float]`, *optional*):
|
||||||
|
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||||
|
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||||
|
size divided by the number of attention heads divided by 2
|
||||||
|
`long_factor` (`List[float]`, *optional*):
|
||||||
|
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||||
|
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||||
|
size divided by the number of attention heads divided by 2
|
||||||
|
`low_freq_factor` (`float`, *optional*):
|
||||||
|
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||||
|
`high_freq_factor` (`float`, *optional*):
|
||||||
|
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||||
|
partial_rotary_factor (`float`, *optional*, defaults to 0.25):
|
||||||
|
Percentage of the query and keys which will have rotary embedding.
|
||||||
|
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
head_dim (`int`, *optional*, defaults to 256):
|
||||||
|
Projection weights dimension in multi-head attention.
|
||||||
|
linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
|
||||||
|
Kernel size of the convolution used in linear attention layers.
|
||||||
|
linear_key_head_dim (`int`, *optional*, defaults to 128):
|
||||||
|
Dimension of each key head in linear attention.
|
||||||
|
linear_value_head_dim (`int`, *optional*, defaults to 128):
|
||||||
|
Dimension of each value head in linear attention.
|
||||||
|
linear_num_key_heads (`int`, *optional*, defaults to 16):
|
||||||
|
Number of key heads used in linear attention layers.
|
||||||
|
linear_num_value_heads (`int`, *optional*, defaults to 32):
|
||||||
|
Number of value heads used in linear attention layers.
|
||||||
|
decoder_sparse_step (`int`, *optional*, defaults to 1):
|
||||||
|
The frequency of the MoE layer.
|
||||||
|
moe_intermediate_size (`int`, *optional*, defaults to 512):
|
||||||
|
Intermediate size of the routed expert.
|
||||||
|
shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
|
||||||
|
Intermediate size of the shared expert.
|
||||||
|
num_experts_per_tok (`int`, *optional*, defaults to 10):
|
||||||
|
Number of selected experts.
|
||||||
|
num_experts (`int`, *optional*, defaults to 512):
|
||||||
|
Number of routed experts.
|
||||||
|
norm_topk_prob (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to normalize the topk probabilities.
|
||||||
|
output_router_logits (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not the router logits should be returned by the model. Enabling this will also
|
||||||
|
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
|
||||||
|
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
|
||||||
|
The aux loss factor for the total loss.
|
||||||
|
mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):
|
||||||
|
Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock
|
||||||
|
The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
|
||||||
|
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
|
||||||
|
layer_types (`list[str]`, *optional*):
|
||||||
|
Types of each layer (attention or linear).
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import Qwen3NextModel, Qwen3NextConfig
|
||||||
|
|
||||||
|
>>> # Initializing a Qwen3Next style configuration
|
||||||
|
>>> configuration = Qwen3NextConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration
|
||||||
|
>>> model = Qwen3NextModel(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
    model_type = "qwen3_next"
    keys_to_ignore_at_inference = ["past_key_values"]

    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.experts.*.gate_proj": "colwise",
        "layers.*.mlp.experts.*.up_proj": "colwise",
        "layers.*.mlp.experts.*.down_proj": "rowwise",
        "layers.*.mlp.shared_experts.gate_proj": "colwise",
        "layers.*.mlp.shared_experts.up_proj": "colwise",
        "layers.*.mlp.shared_experts.down_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=5632,
        num_hidden_layers=48,
        num_attention_heads=16,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        partial_rotary_factor=0.25,
        attention_bias=False,
        attention_dropout=0.0,
        head_dim=256,
        linear_conv_kernel_dim=4,
        linear_key_head_dim=128,
        linear_value_head_dim=128,
        linear_num_key_heads=16,
        linear_num_value_heads=32,
        decoder_sparse_step=1,
        moe_intermediate_size=512,
        shared_expert_intermediate_size=512,
        num_experts_per_tok=10,
        num_experts=512,
        norm_topk_prob=True,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        mlp_only_layers=None,
        layer_types=None,
        **kwargs,
    ):
        if mlp_only_layers is None:
            mlp_only_layers = []
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.partial_rotary_factor = partial_rotary_factor
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.head_dim = head_dim
        rope_config_validation(self)

        self.layer_types = layer_types
        if self.layer_types is None:
            self.layer_types = [
                "linear_attention" if bool((i + 1) % 4) else "full_attention"
                for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types)

        # linear attention part
        self.linear_conv_kernel_dim = linear_conv_kernel_dim
        self.linear_key_head_dim = linear_key_head_dim
        self.linear_value_head_dim = linear_value_head_dim
        self.linear_num_key_heads = linear_num_key_heads
        self.linear_num_value_heads = linear_num_value_heads

        # MoE arguments
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.shared_expert_intermediate_size = shared_expert_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.mlp_only_layers = mlp_only_layers


__all__ = ["Qwen3NextConfig"]
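The default `layer_types` built in `__init__` above alternates three `linear_attention` layers with one `full_attention` layer. A minimal sketch (illustrative layer count, not part of the diff) of the pattern that list comprehension produces:

```python
# Sketch: reproduce the default layer_types pattern from Qwen3NextConfig.
# Every 4th layer is full attention; the rest use linear (GatedDeltaNet) attention.
num_hidden_layers = 8  # illustrative; the real default is 48
layer_types = [
    "linear_attention" if bool((i + 1) % 4) else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['linear_attention', 'linear_attention', 'linear_attention', 'full_attention',
#  'linear_attention', 'linear_attention', 'linear_attention', 'full_attention']
```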
vllm/v1/attention/backends/gdn_attn.py (new file, 319 lines)
@ -0,0 +1,319 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""
from dataclasses import dataclass
from typing import ClassVar, Optional

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              AttentionMetadataBuilder,
                                              CommonAttentionMetadata,
                                              split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec

class GDNAttentionBackend(AttentionBackend):

    @staticmethod
    def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
        return GDNAttentionMetadataBuilder

@dataclass
class GDNAttentionMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    num_spec_decodes: int
    num_spec_decode_tokens: int

    has_initial_state: Optional[torch.Tensor] = None

    spec_query_start_loc: Optional[
        torch.Tensor] = None  # shape: [num_spec_decodes + 1,]
    non_spec_query_start_loc: Optional[
        torch.Tensor] = None  # shape: [batch - num_spec_decodes + 1,]

    spec_state_indices_tensor: Optional[
        torch.Tensor] = None  # shape: [batch, num_spec]
    non_spec_state_indices_tensor: Optional[
        torch.Tensor] = None  # shape: [batch - num_spec_decodes,]
    spec_sequence_masks: Optional[torch.Tensor] = None  # shape: [batch,]
    spec_token_masks: Optional[
        torch.Tensor] = None  # shape: [num_prefill_tokens + num_decode_tokens,]
    num_accepted_tokens: Optional[torch.Tensor] = None  # shape: [batch,]

class GDNAttentionMetadataBuilder(
        AttentionMetadataBuilder[GDNAttentionMetadata]):

    cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: ClassVar[int] = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        assert isinstance(kv_cache_spec, MambaSpec)
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.speculative_config = vllm_config.speculative_config
        self.kv_cache_spec = kv_cache_spec
        if self.speculative_config:
            self.num_spec = self.speculative_config.num_speculative_tokens  # noqa: E501
        else:
            self.num_spec = 0
        self.use_spec_decode = self.num_spec > 0
        self.reorder_batch_threshold = self.num_spec + 1  # type: ignore[misc]

        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        self.decode_cudagraph_max_bs = min(
            self.vllm_config.scheduler_config.max_num_seqs,
            self.compilation_config.max_capture_size)

        self.spec_state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, self.num_spec + 1),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, ),
            dtype=torch.int32,
            device=device,
        )
        self.spec_sequence_masks = torch.empty(
            (self.decode_cudagraph_max_bs, ),
            dtype=torch.bool,
            device=device,
        )
        self.spec_token_masks = torch.empty(
            (self.decode_cudagraph_max_bs * (self.num_spec + 1), ),
            dtype=torch.bool,
            device=device,
        )
        self.spec_query_start_loc = torch.empty(
            (self.decode_cudagraph_max_bs + 1, ),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_query_start_loc = torch.empty(
            (self.decode_cudagraph_max_bs + 1, ),
            dtype=torch.int32,
            device=device,
        )
        self.num_accepted_tokens = torch.empty(
            (self.decode_cudagraph_max_bs, ),
            dtype=torch.int32,
            device=device,
        )

    def build(  # type: ignore[override]
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        num_accepted_tokens: Optional[torch.Tensor] = None,
        num_draft_tokens: Optional[torch.Tensor] = None,
        fast_build: bool = False,
    ) -> GDNAttentionMetadata:
        m = common_attn_metadata

        query_start_loc = m.query_start_loc
        context_lens = m.num_computed_tokens_cpu
        context_lens_tensor = context_lens.to(query_start_loc.device)
        seq_lens_tensor = m.seq_lens

        if (not self.use_spec_decode or num_draft_tokens is None
                or num_draft_tokens.sum().item() == 0):
            spec_sequence_masks = None
        else:
            spec_sequence_masks = (num_draft_tokens > 0) & (
                context_lens_tensor +
                (num_draft_tokens + 1) == seq_lens_tensor)
            if spec_sequence_masks.sum().item() == 0:
                spec_sequence_masks = None

        if spec_sequence_masks is None:
            num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
                split_decodes_and_prefills(m, decode_threshold=1))
            num_spec_decodes = 0
            num_spec_decode_tokens = 0
            spec_token_masks = None
            spec_state_indices_tensor = None
            non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
            spec_query_start_loc = None
            non_spec_query_start_loc = query_start_loc
            num_accepted_tokens = None
        else:
            num_spec_decodes = spec_sequence_masks.sum().item()
            query_lens = query_start_loc[1:] - query_start_loc[:-1]

            non_spec_query_lens = query_lens[~spec_sequence_masks]
            num_decodes = (non_spec_query_lens == 1).sum().item()
            num_prefills = non_spec_query_lens.size(0) - num_decodes
            num_decode_tokens = num_decodes
            num_prefill_tokens = non_spec_query_lens.sum().item(
            ) - num_decode_tokens

            if num_prefills == 0 and num_decodes == 0:
                spec_token_masks = torch.ones(
                    (min(num_spec_decodes *
                         (self.num_spec + 1), query_start_loc[-1].item())),
                    dtype=torch.bool,
                    device=query_start_loc.device)
                spec_state_indices_tensor = m.block_table_tensor[:, :self.
                                                                 num_spec + 1]
                non_spec_state_indices_tensor = None
                spec_query_start_loc = query_start_loc
                non_spec_query_start_loc = None
            else:
                spec_token_masks = torch.repeat_interleave(
                    spec_sequence_masks, query_lens)
                spec_state_indices_tensor = m.block_table_tensor[
                    spec_sequence_masks, :self.num_spec + 1]
                non_spec_state_indices_tensor = \
                    m.block_table_tensor[~spec_sequence_masks, 0]

                spec_query_start_loc = torch.zeros(
                    num_spec_decodes + 1,
                    dtype=torch.int32,
                    device=query_start_loc.device)
                torch.cumsum(query_lens[spec_sequence_masks],
                             dim=0,
                             out=spec_query_start_loc[1:])
                non_spec_query_start_loc = torch.zeros(
                    query_lens.size(0) - num_spec_decodes + 1,
                    dtype=torch.int32,
                    device=query_start_loc.device)
                torch.cumsum(query_lens[~spec_sequence_masks],
                             dim=0,
                             out=non_spec_query_start_loc[1:])

            num_spec_decode_tokens = min(
                num_spec_decodes * (self.num_spec + 1),
                spec_token_masks.size(0))
            assert num_accepted_tokens is not None
            num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

        if num_prefills > 0:
            has_initial_state = context_lens_tensor > 0
            if spec_sequence_masks is not None:
                has_initial_state = has_initial_state[~spec_sequence_masks]
        else:
            has_initial_state = None

        # prepare tensors for cudagraph
        if (self.use_full_cuda_graph and num_prefills == 0
                and num_decodes == 0
                and num_spec_decodes <= self.decode_cudagraph_max_bs):
            num_total_tokens = self.vllm_config.pad_for_cudagraph(
                m.num_actual_tokens)
            batch_size = num_total_tokens // (self.num_spec + 1)

            self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                spec_state_indices_tensor, non_blocking=True)
            spec_state_indices_tensor = \
                self.spec_state_indices_tensor[:batch_size]
            spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)

            self.spec_sequence_masks[:num_spec_decodes].copy_(
                spec_sequence_masks, non_blocking=True)
            spec_sequence_masks = self.spec_sequence_masks[:batch_size]
            spec_sequence_masks[num_spec_decodes:].fill_(False)

            assert spec_token_masks is not None
            self.spec_token_masks[:spec_token_masks.size(0)].copy_(
                spec_token_masks, non_blocking=True)
            spec_token_masks = self.spec_token_masks[:m.num_actual_tokens]
            spec_token_masks[spec_token_masks.size(0):].fill_(False)

            self.spec_query_start_loc[:num_spec_decodes + 1].copy_(
                spec_query_start_loc, non_blocking=True)
            spec_num_query_tokens = spec_query_start_loc[
                -1]  # type: ignore[index]
            spec_query_start_loc = self.spec_query_start_loc[:batch_size + 1]
            spec_query_start_loc[num_spec_decodes +
                                 1:].fill_(spec_num_query_tokens)

            self.num_accepted_tokens[:num_spec_decodes].copy_(
                num_accepted_tokens, non_blocking=True)
            num_accepted_tokens = self.num_accepted_tokens[:batch_size]
            num_accepted_tokens[num_spec_decodes:].fill_(1)

        if (self.use_full_cuda_graph and num_prefills == 0
                and num_spec_decodes == 0
                and num_decodes <= self.decode_cudagraph_max_bs):
            num_total_tokens = self.vllm_config.pad_for_cudagraph(
                m.num_actual_tokens)
            batch_size = num_total_tokens

            self.non_spec_state_indices_tensor[:num_decodes].copy_(
                non_spec_state_indices_tensor, non_blocking=True)
            non_spec_state_indices_tensor = \
                self.non_spec_state_indices_tensor[:batch_size]
            non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)

            self.non_spec_query_start_loc[:num_decodes + 1].copy_(
                non_spec_query_start_loc, non_blocking=True)
            non_spec_num_query_tokens = non_spec_query_start_loc[
                -1]  # type: ignore[index]
            non_spec_query_start_loc = \
                self.non_spec_query_start_loc[:batch_size + 1]
            non_spec_query_start_loc[num_decodes +
                                     1:].fill_(non_spec_num_query_tokens)

        attn_metadata = GDNAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_spec_decodes=num_spec_decodes,
            num_spec_decode_tokens=num_spec_decode_tokens,
            has_initial_state=has_initial_state,
            spec_query_start_loc=spec_query_start_loc,
            non_spec_query_start_loc=non_spec_query_start_loc,
            spec_state_indices_tensor=spec_state_indices_tensor,
            non_spec_state_indices_tensor=non_spec_state_indices_tensor,
            spec_sequence_masks=spec_sequence_masks,
            spec_token_masks=spec_token_masks,
            num_accepted_tokens=num_accepted_tokens,
        )
        return attn_metadata

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata):
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with Mamba.
        """
        m = common_attn_metadata

        assert (m.num_reqs * (self.num_spec + 1) <= m.num_actual_tokens
                and ((m.num_reqs + 1) * (self.num_spec + 1)
                     >= m.num_actual_tokens)), \
            "GDN only supports decode-only full CUDAGraph capture. " \
            "Make sure all cudagraph capture sizes <= max_num_seq."

        num_accepted_tokens = torch.full((m.num_reqs, ),
                                         m.max_query_len,
                                         dtype=torch.int32,
                                         device=m.query_start_loc.device)
        num_drafted_tokens = torch.full((m.num_reqs, ),
                                        self.num_spec,
                                        dtype=torch.int32,
                                        device=m.query_start_loc.device)

        # Fixes query-start loc for spec-sequence-indices.
        m.query_start_loc = torch.arange(0,
                                         m.num_actual_tokens + 1,
                                         step=m.max_query_len,
                                         device=m.query_start_loc.device,
                                         dtype=torch.int32)
        m.num_computed_tokens_cpu = (m.seq_lens_cpu - torch.full(
            (m.num_reqs, ), m.max_query_len, dtype=torch.int32, device='cpu'))

        return self.build(0, m, num_accepted_tokens, num_drafted_tokens)
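The `spec_sequence_masks` computed in `build()` above selects the requests that are pure speculative-decode steps: they have draft tokens, and everything scheduled for them this step is exactly the bonus token plus the drafts. A minimal sketch with made-up values (not part of the diff):

```python
# Sketch: the spec-sequence selection rule from GDNAttentionMetadataBuilder.build().
import torch

num_draft_tokens = torch.tensor([2, 0, 2])   # drafts scheduled per request
context_lens     = torch.tensor([10, 7, 4])  # tokens already computed
seq_lens         = torch.tensor([13, 8, 9])  # tokens after this step

spec_sequence_masks = (num_draft_tokens > 0) & (
    context_lens + (num_draft_tokens + 1) == seq_lens)
print(spec_sequence_masks)  # tensor([ True, False, False])
# Request 0: 10 + (2 + 1) == 13 -> speculative decode step.
# Request 1: no drafts -> ordinary decode.
# Request 2: still has prefill tokens outstanding -> not a spec step.
```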
@ -559,12 +559,48 @@ class MambaManager(SingleTypeKVCacheManager):
                              num_running_requests: int) -> int:
         return 0
 
+    def get_num_blocks_to_allocate(
+            self, request_id: str, num_tokens: int,
+            new_computed_blocks: list[KVCacheBlock]) -> int:
+        """
+        Get the number of blocks needed to be allocated for the request.
+
+        Args:
+            request_id: The request ID.
+            num_tokens: The total number of tokens that need a slot (including
+                tokens that are already allocated).
+            new_computed_blocks: The new computed blocks just hitting the
+                prefix caching.
+
+        Returns:
+            The number of blocks.
+        """
+
+        assert isinstance(self.kv_cache_spec, MambaSpec)
+        if self.kv_cache_spec.num_speculative_blocks > 0:
+            num_tokens += (self.kv_cache_spec.block_size *
+                           self.kv_cache_spec.num_speculative_blocks)
+        num_required_blocks = cdiv(num_tokens, self.block_size)
+        num_new_blocks = (num_required_blocks - len(new_computed_blocks) -
+                          len(self.req_to_blocks[request_id]))
+        # If a computed block of a request is an eviction candidate (in the
+        # free queue and ref_cnt == 0), it will be changed from a free block
+        # to a computed block when the request is allocated, so we also count
+        # it as needed to be allocated.
+        num_evictable_computed_blocks = sum(
+            blk.ref_cnt == 0 and not blk.is_null
+            for blk in new_computed_blocks)
+        return num_new_blocks + num_evictable_computed_blocks
+
     def allocate_new_blocks(self, request_id: str,
                             num_tokens: int) -> list[KVCacheBlock]:
-        new_blocks = super().allocate_new_blocks(request_id, num_tokens)
-        assert len(self.req_to_blocks[request_id]) == 1, (
-            "MambaManager should only allocate 1 block for each request.")
-        return new_blocks
+        # Allocate extra `num_speculative_blocks` blocks for
+        # speculative decoding (MTP/EAGLE) with linear attention.
+        assert isinstance(self.kv_cache_spec, MambaSpec)
+        if self.kv_cache_spec.num_speculative_blocks > 0:
+            num_tokens += (self.kv_cache_spec.block_size *
+                           self.kv_cache_spec.num_speculative_blocks)
+        return super().allocate_new_blocks(request_id, num_tokens)
 
 
 class CrossAttentionManager(SingleTypeKVCacheManager):
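As a worked illustration of the allocation arithmetic above (purely illustrative numbers, not part of the diff), for a Mamba/linear-attention layer the block size equals `max_model_len`, so a request normally needs a single block and the speculative reservation adds one block per draft token:

```python
# Sketch: the block-count arithmetic from MambaManager.get_num_blocks_to_allocate.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)

block_size = 4096              # MambaSpec uses block_size = max_model_len
num_speculative_blocks = 2     # e.g. MTP with 2 speculative tokens
num_tokens = 100               # tokens that need a slot for this request

num_tokens += block_size * num_speculative_blocks   # reserve draft-state room
num_required_blocks = cdiv(num_tokens, block_size)
# Against an empty request (no cached or already-owned blocks) this is
# one block for the main recurrent state plus one per speculative token.
print(num_required_blocks)     # 3
```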
@ -194,6 +194,7 @@ class MambaSpec(KVCacheSpec):
     dtypes: tuple[torch.dtype]
     page_size_padded: Optional[int] = None
     mamba_type: str = "mamba2"
+    num_speculative_blocks: int = 0
 
     @property
     def page_size_bytes(self) -> int:
@ -218,7 +218,7 @@ class EagleProposer:
                 hidden_states=self.hidden_states[:num_input_tokens],
                 inputs_embeds=inputs_embeds,
             )
-        if self.method in ("deepseek_mtp", "ernie_mtp"):
+        if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp"):
             last_hidden_states = ret_hidden_states
             hidden_states = last_hidden_states
         else:
@ -322,12 +322,18 @@ class EagleProposer:
             with set_forward_context(per_layer_attn_metadata,
                                      self.vllm_config,
                                      num_tokens=input_batch_size):
-                last_hidden_states, hidden_states = self.model(
+                ret_hidden_states = self.model(
                     input_ids=input_ids,
                     positions=self.positions[:input_batch_size],
                     hidden_states=self.hidden_states[:input_batch_size],
                     inputs_embeds=inputs_embeds,
                 )
+                if self.method in ("deepseek_mtp", "ernie_mtp",
+                                   "qwen3_next_mtp"):
+                    last_hidden_states = ret_hidden_states
+                    hidden_states = ret_hidden_states
+                else:
+                    last_hidden_states, hidden_states = ret_hidden_states
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                                None)
@ -156,9 +156,14 @@ class BlockTable:
 class MultiGroupBlockTable:
     """The BlockTables for each KV cache group."""
 
-    def __init__(self, max_num_reqs: int, max_model_len: int,
-                 max_num_batched_tokens: int, pin_memory: bool,
-                 device: torch.device, block_sizes: list[int]) -> None:
+    def __init__(self,
+                 max_num_reqs: int,
+                 max_model_len: int,
+                 max_num_batched_tokens: int,
+                 pin_memory: bool,
+                 device: torch.device,
+                 block_sizes: list[int],
+                 num_speculative_tokens: int = 0) -> None:
         # Note(hc): each dcp rank only store
         # (max_model_len//dcp_world_size) tokens in kvcache,
         # so the block_size which used for calc max_num_blocks_per_req
@ -170,10 +175,11 @@ class MultiGroupBlockTable:
             dcp_world_size = 1
 
         self.block_tables = [
-            BlockTable(block_size, max_num_reqs,
-                       cdiv(max_model_len, block_size * dcp_world_size),
-                       max_num_batched_tokens, pin_memory, device)
-            for block_size in block_sizes
+            BlockTable(
+                block_size, max_num_reqs,
+                max(cdiv(max_model_len, block_size * dcp_world_size),
+                    1 + num_speculative_tokens), max_num_batched_tokens,
+                pin_memory, device) for block_size in block_sizes
         ]
 
     def append_row(self, block_ids: tuple[list[int], ...],
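The `max(..., 1 + num_speculative_tokens)` guard above matters for the linear-attention KV cache group, whose block size equals `max_model_len` and which would otherwise get a one-column block table with no room for per-draft states. A sketch with illustrative numbers (not part of the diff):

```python
# Sketch: per-request block-table width after this change.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)

max_model_len = 4096
block_size = 4096            # Mamba/linear-attention group: one block per request
dcp_world_size = 1
num_speculative_tokens = 2   # e.g. MTP with 2 drafts

max_num_blocks_per_req = max(
    cdiv(max_model_len, block_size * dcp_world_size),
    1 + num_speculative_tokens)
print(max_num_blocks_per_req)  # 3: one main state slot plus one per draft
```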
@ -83,6 +83,7 @@ class InputBatch:
         logitsprocs: Optional[LogitsProcessors] = None,
         is_spec_decode: bool = False,
         is_pooling_model: bool = False,
+        num_speculative_tokens: int = 0,
     ):
         self.is_pooling_model = is_pooling_model
         self.is_spec_decode = is_spec_decode
@ -127,6 +128,7 @@ class InputBatch:
             pin_memory=pin_memory,
             device=device,
             block_sizes=block_sizes,
+            num_speculative_tokens=num_speculative_tokens,
         )
 
         # Sampling-related.
@ -202,6 +204,14 @@ class InputBatch:
             self.repetition_penalties_cpu_tensor.numpy()
         self.repetition_penalties_reqs: set[str] = set()
 
+        # Speculative decoding
+        self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ),
+                                                         dtype=torch.int64,
+                                                         device="cpu",
+                                                         pin_memory=pin_memory)
+        self.num_accepted_tokens_cpu = \
+            self.num_accepted_tokens_cpu_tensor.numpy()
+
         # lora related
         self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                              dtype=np.int32)
@ -394,6 +404,9 @@ class InputBatch:
         else:
             raise NotImplementedError("Unrecognized request type")
 
+        # Speculative decoding: by default 1 token is generated.
+        self.num_accepted_tokens_cpu[req_index] = 1
+
         # Add request lora ID
         if request.lora_request:
             lora_id = request.lora_request.lora_int_id
@ -515,6 +528,8 @@ class InputBatch:
                 self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
         self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
             self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
+        self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = \
+            self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1]
 
         swap_dict_values(self.generators, i1, i2)
         swap_dict_values(self.bad_words_token_ids, i1, i2)
@ -609,6 +624,8 @@ class InputBatch:
                 empty_index] = self.presence_penalties_cpu[last_req_index]
             self.repetition_penalties_cpu[
                 empty_index] = self.repetition_penalties_cpu[last_req_index]
+            self.num_accepted_tokens_cpu[
+                empty_index] = self.num_accepted_tokens_cpu[last_req_index]
             generator = self.generators.pop(last_req_index, None)
             if generator is not None:
                 self.generators[empty_index] = generator
@ -53,9 +53,9 @@ from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, LazyLoader, cdiv, check_use_alibi,
-                        get_dtype_size, is_pin_memory_available, round_up,
-                        supports_dynamo)
+                        GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
+                        is_pin_memory_available, round_up, supports_dynamo)
+from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     create_fast_prefill_custom_backend,
@ -324,6 +324,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                                self.hidden_size,
                                                dtype=self.dtype,
                                                numpy=False)
+        self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
+                                                  dtype=torch.int32)
+        self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
+                                                     dtype=torch.int64)
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@ -663,6 +667,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Refresh batch metadata with any pending updates.
         self.input_batch.refresh_metadata()
 
+    def _update_states_after_model_execute(
+            self, output_token_ids: torch.Tensor) -> None:
+        """Update the cached states after model execution.
+
+        This is used for MTP/EAGLE with hybrid models: in linear attention,
+        only the last token's state is kept. In MTP/EAGLE, the states for
+        draft tokens are kept until we decide how many tokens are accepted
+        for each sequence, and a shift is done during the next iteration
+        based on the number of accepted tokens.
+        """
+        if not self.model_config.is_hybrid or not self.speculative_config:
+            return
+
+        # Find the number of accepted tokens for each sequence.
+        num_accepted_tokens = (torch.cat(
+            [
+                output_token_ids,
+                torch.full((output_token_ids.size(0), 1),
+                           -1,
+                           device=output_token_ids.device),
+            ],
+            dim=1) == -1).int().argmax(-1).cpu().numpy()
+        for i, num_tokens in enumerate(num_accepted_tokens):
+            self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
+
     def _init_mrope_positions(self, req_state: CachedRequestState):
         image_grid_thw = []
         video_grid_thw = []
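The `argmax` expression above counts accepted tokens by locating the first rejected position, which the sampler marks with -1; appending a sentinel -1 column handles the case where every draft is accepted. A minimal sketch with made-up token ids (not part of the diff):

```python
# Sketch: accepted-token counting from _update_states_after_model_execute.
import torch

output_token_ids = torch.tensor([
    [11, 12, -1],   # bonus token + 1 accepted draft -> 2
    [21, -1, -1],   # only the bonus token accepted  -> 1
    [31, 32, 33],   # every draft accepted           -> 3
])
sentinel = torch.full((output_token_ids.size(0), 1), -1)
num_accepted = (torch.cat([output_token_ids, sentinel],
                          dim=1) == -1).int().argmax(-1)
print(num_accepted)  # tensor([2, 1, 3])
```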
@ -936,6 +965,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # We will ignore the sampled tokens from the partial requests.
             # TODO: Support prompt logprobs.
             logits_indices = query_start_loc[1:] - 1
+            num_draft_tokens = None
             spec_decode_metadata = None
         else:
             # Get the number of draft tokens for each request.
@ -950,6 +980,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             spec_decode_metadata = self._calc_spec_decode_metadata(
                 num_draft_tokens, cu_num_tokens)
             logits_indices = spec_decode_metadata.logits_indices
+            self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
+            self.num_draft_tokens.np[num_reqs:].fill(0)
+            self.num_draft_tokens.copy_to_gpu()
 
         logits_indices_padded = None
         if self.cache_config.kv_sharing_fast_prefill:
@ -964,6 +997,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
         spec_decode_common_attn_metadata = None
+        if use_spec_decode:
+            self.num_accepted_tokens.np[:num_reqs] = (
+                self.input_batch.num_accepted_tokens_cpu[:num_reqs])
+            self.num_accepted_tokens.np[num_reqs:].fill(1)
+            self.num_accepted_tokens.copy_to_gpu()
 
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
@ -1034,10 +1072,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     builder,
                 )
 
+                extra_attn_metadata_args = {}
+                if use_spec_decode and isinstance(builder,
+                                                  GDNAttentionMetadataBuilder):
+                    extra_attn_metadata_args = dict(
+                        num_accepted_tokens=self.num_accepted_tokens.
+                        gpu[:num_reqs],
+                        num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
+                    )
+
                 attn_metadata_i = builder.build(
                     common_prefix_len=common_prefix_len,
                     common_attn_metadata=common_attn_metadata,
-                )
+                    **extra_attn_metadata_args)
 
                 for layer_name in attn_group.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i
@ -1814,6 +1861,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 sampling_metadata,
             )
             sampler_output.sampled_token_ids = output_token_ids
+            self._update_states_after_model_execute(output_token_ids)
 
         return sampler_output
 
@ -2644,13 +2692,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # Note: Overriding max_query_len to be the prefill tokens
             max_query_len = num_prefill_tokens
         elif uniform_decode:
-            assert not create_mixed_batch
-            num_reqs = cdiv(num_tokens, max_query_len)
+            num_reqs = num_tokens // max_query_len
             assert num_reqs <= max_num_reqs, \
                 "Do not capture num_reqs > max_num_reqs for uniform batch"
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
-                num_scheduled_tokens_list[-1] = num_tokens % max_query_len
+                num_scheduled_tokens_list[-1] += num_tokens % max_query_len
         else:
             num_reqs = min(num_tokens, max_num_reqs)
             min_tokens_per_req = num_tokens // num_reqs
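The switch from `=` to `+=` above (together with flooring `num_reqs`) keeps the scheduled-token list summing to `num_tokens` when the capture size is not a multiple of `max_query_len`. A sketch with made-up sizes (not part of the diff):

```python
# Sketch: splitting a uniform-decode capture batch after this change.
num_tokens = 10       # padded cudagraph capture size
max_query_len = 3     # 1 bonus token + 2 speculative tokens per request

num_reqs = num_tokens // max_query_len                   # 3
num_scheduled_tokens_list = [max_query_len] * num_reqs   # [3, 3, 3]
if num_tokens % max_query_len != 0:
    num_scheduled_tokens_list[-1] += num_tokens % max_query_len
print(num_scheduled_tokens_list)  # [3, 3, 4], which sums to num_tokens
# The previous `=` assignment would have produced [3, 3, 1], summing to 7.
```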
@ -3297,6 +3344,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             is_spec_decode=bool(self.vllm_config.speculative_config),
             logitsprocs=self.input_batch.logitsprocs,
             is_pooling_model=self.is_pooling_model,
+            num_speculative_tokens=(
+                self.vllm_config.speculative_config.num_speculative_tokens
+                if self.vllm_config.speculative_config else 0),
         )
 
     def _allocate_kv_cache_tensors(
@ -3647,7 +3697,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
         if len(mamba_layers) > 0:
-            if self.vllm_config.speculative_config is not None:
+            if (self.vllm_config.speculative_config is not None
+                    and self.vllm_config.model_config.hf_config.model_type
+                    not in ["qwen3_next"]):
                 raise NotImplementedError(
                     "Mamba with speculative decoding is not supported yet.")
             if self.vllm_config.cache_config.enable_prefix_caching:
@ -3666,7 +3718,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     dtypes=mamba_module.get_state_dtype(),
                     block_size=max_model_len,
                     page_size_padded=page_size_padded,
-                    mamba_type=mamba_module.mamba_type)
+                    mamba_type=mamba_module.mamba_type,
+                    num_speculative_blocks=(
+                        self.speculative_config.num_speculative_tokens
+                        if self.speculative_config else 0),
+                )
 
         return kv_cache_spec
 
@ -78,7 +78,8 @@ class Worker(LocalOrDistributedWorkerBase):
                                     "deepseek_mtp",
                                     "glm4_moe_mtp",
                                     "mimo_mtp",
-                                    "ernie_mtp")) \
+                                    "ernie_mtp",
+                                    "qwen3_next_mtp")) \
             else {"return_hidden_states": True}
 
         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner