[BugFix] Fix Qwen3-Next PP (#24709)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-09-11 23:35:04 -07:00
committed by GitHub
parent 7920de0a2a
commit f592b3174b

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional
import torch
@ -917,8 +918,11 @@ class Qwen3NextModel(nn.Module):
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.norm = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
if get_pp_group().is_last_rank:
self.norm = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@ -941,7 +945,7 @@ class Qwen3NextModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,