Compare commits

...

13 Commits

Author SHA1 Message Date
8ce5d3198d [P/D] NIXL Updates (#25844)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
Signed-off-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Chenheli Hua <huachenheli@outlook.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-28 22:55:33 -07:00
09c2cbc04a [Bugfix] fix Qwen3VLMoe load when pp > 1 (#25838)
Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com>
Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-28 22:55:17 -07:00
4c347044c9 [VLM] Update Qwen3-VL max_num_video_tokens calculation for configurable video profiling (#25557)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:12 -07:00
19e7ab7315 [Bugfix] Fix Qwen3-VL regression from #24982 (#25814)
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:11 -07:00
6de3d431d9 [MM] Optimize memory profiling for scattered multimodal embeddings (#25810)
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:11 -07:00
b14773bd64 [Bugfix][NIXL] Fix Async Scheduler timeout issue (#25808)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:11 -07:00
26a7a33b88 [Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982)
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:03 -07:00
5aa5811a16 [CI] Fix FlashInfer AOT in release docker image (#25730)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
c2fa2d4dc9 [Bugfix] Allow Only SDPA Backend for ViT on B200 for Qwen3-VL (#25788)
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
32335c8b34 Add option to restrict media domains (#25783)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
04c2b26972 Add filtering for chat template kwargs (#25794)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
ee10d7e6ff Validate API tokens in constant time (#25781)
Signed-off-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
bb79c4da2f Reduce the Cuda Graph memory footprint when running with DBO (#25779)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
48 changed files with 1047 additions and 485 deletions

View File

@ -76,7 +76,7 @@ steps:
queue: arm64_cpu_queue_postmerge queue: arm64_cpu_queue_postmerge
commands: commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)" - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)"
# Add job to create multi-arch manifest # Add job to create multi-arch manifest

View File

@ -404,6 +404,9 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
fi fi
echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}" echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
# HACK: We need these to run flashinfer.aot before installing flashinfer; in the future, get them from the package
uv pip install --system cuda-python==$(echo $CUDA_VERSION | cut -d. -f1,2) pynvml==$(echo $CUDA_VERSION | cut -d. -f1) nvidia-nvshmem-cu$(echo $CUDA_VERSION | cut -d. -f1)
# Build AOT kernels # Build AOT kernels
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
python3 -m flashinfer.aot python3 -m flashinfer.aot

View File

@ -6,6 +6,10 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup
We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes, We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes,
and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests. and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests.
!!! tip
When serving multi-modal models, consider setting `--allowed-media-domains` to restrict the domains that vLLM can fetch media from. This prevents vLLM from reaching arbitrary endpoints and being abused for Server-Side Request Forgery (SSRF) attacks. The argument accepts a list of domains, for example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`
This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.
## Offline Inference ## Offline Inference
To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]: To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:
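The same restriction is available in the offline API: per the `LLM` changes later in this diff, the constructor gains an `allowed_media_domains` parameter. A minimal sketch, assuming a locally available Qwen2.5-VL checkpoint (model name and domain list are illustrative):

```python
from vllm import LLM

# Minimal sketch: only fetch media URLs from an explicit allow-list of domains.
llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    allowed_media_domains=["upload.wikimedia.org", "github.com"],
)
# Requests that reference media hosted on any other domain are rejected
# when the URL is fetched.
```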

View File

@ -60,6 +60,12 @@ Key points from the PyTorch security guide:
- Implement proper authentication and authorization for management interfaces - Implement proper authentication and authorization for management interfaces
- Follow the principle of least privilege for all system components - Follow the principle of least privilege for all system components
### 4. **Restrict Domain Access for Media URLs:**
Restrict the domains that vLLM can fetch media URLs from by setting
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`).
## Security and Firewalls: Protecting Exposed vLLM Systems ## Security and Firewalls: Protecting Exposed vLLM Systems
While vLLM is designed to allow unsafe network services to be isolated to While vLLM is designed to allow unsafe network services to be isolated to

View File

@ -45,6 +45,7 @@ class MockModelConfig:
logits_processor_pattern: Optional[str] = None logits_processor_pattern: Optional[str] = None
diff_sampling_param: Optional[dict] = None diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = "" allowed_local_media_path: str = ""
allowed_media_domains: Optional[list[str]] = None
encoder_config = None encoder_config = None
generation_config: str = "auto" generation_config: str = "auto"
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False

View File

@ -240,6 +240,7 @@ class MockModelConfig:
logits_processor_pattern = None logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = "" allowed_local_media_path: str = ""
allowed_media_domains: Optional[list[str]] = None
encoder_config = None encoder_config = None
generation_config: str = "auto" generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)

View File

@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
parse_chat_messages, parse_chat_messages,
parse_chat_messages_futures, parse_chat_messages_futures,
resolve_chat_template_content_format, resolve_chat_template_content_format,
resolve_chat_template_kwargs,
resolve_hf_chat_template) resolve_hf_chat_template)
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64, from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
@ -37,6 +38,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct" QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B" QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
QWEN3_MODEL_ID = "Qwen/Qwen3-8B"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B" HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
assert isinstance(chat_template, str) assert isinstance(chat_template, str)
@pytest.mark.parametrize(
"model, expected_kwargs",
[
(
QWEN2VL_MODEL_ID,
{
"add_vision_id", "add_generation_prompt",
"continue_final_message", "tools"
},
),
(
QWEN3_MODEL_ID,
{
"enable_thinking", "add_generation_prompt",
"continue_final_message", "tools"
},
),
],
)
def test_resolve_hf_chat_template_kwargs(sample_json_schema, model,
expected_kwargs):
"""checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
tools = ([{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema,
},
}])
chat_template_kwargs = {
# both unused
"unsed_kwargs_1": 123,
"unsed_kwargs_2": "abc",
# should not appear
"chat_template": "{% Hello world! %}",
# used by tokenizer
"continue_final_message": True,
"tools": tools,
# both used by Qwen2-VL and Qwen3
"add_generation_prompt": True,
# only used by Qwen2-VL
"add_vision_id": True,
# only used by Qwen3
"enable_thinking": True,
}
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
# Build the tokenizer
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=tools,
model_config=model_config,
)
resolved_chat_template_kwargs = resolve_chat_template_kwargs(
tokenizer,
chat_template=chat_template,
chat_template_kwargs=chat_template_kwargs,
)
assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs
# NOTE: Qwen2-Audio default chat template is specially defined inside # NOTE: Qwen2-Audio default chat template is specially defined inside
# processor class instead of using `tokenizer_config.json` # processor class instead of using `tokenizer_config.json`
# yapf: disable # yapf: disable

View File

@ -66,7 +66,12 @@ async def test_fetch_image_http(image_url: str):
@pytest.mark.parametrize("suffix", get_supported_suffixes()) @pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: dict[str, Image.Image], async def test_fetch_image_base64(url_images: dict[str, Image.Image],
raw_image_url: str, suffix: str): raw_image_url: str, suffix: str):
connector = MediaConnector() connector = MediaConnector(
# Domain restriction should not apply to data URLs.
allowed_media_domains=[
"www.bogotobogo.com",
"github.com",
])
url_image = url_images[raw_image_url] url_image = url_images[raw_image_url]
try: try:
@ -387,3 +392,29 @@ def test_argsort_mm_positions(case):
modality_idxs = argsort_mm_positions(mm_positions) modality_idxs = argsort_mm_positions(mm_positions)
assert modality_idxs == expected_modality_idxs assert modality_idxs == expected_modality_idxs
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_allowed_media_domains(video_url: str, num_frames: int):
connector = MediaConnector(
media_io_kwargs={"video": {
"num_frames": num_frames,
}},
allowed_media_domains=[
"www.bogotobogo.com",
"github.com",
])
video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(video_url)
assert np.array_equal(video_sync, video_async)
assert metadata_sync == metadata_async
disallowed_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
with pytest.raises(ValueError):
_, _ = connector.fetch_video(disallowed_url)
with pytest.raises(ValueError):
_, _ = await connector.fetch_video_async(disallowed_url)

View File

@ -137,6 +137,9 @@ class ModelConfig:
"""Allowing API requests to read local images or videos from directories """Allowing API requests to read local images or videos from directories
specified by the server file system. This is a security risk. Should only specified by the server file system. This is a security risk. Should only
be enabled in trusted environments.""" be enabled in trusted environments."""
allowed_media_domains: Optional[list[str]] = None
"""If set, only media URLs that belong to this domain can be used for
multi-modal inputs. """
revision: Optional[str] = None revision: Optional[str] = None
"""The specific model version to use. It can be a branch name, a tag name, """The specific model version to use. It can be a branch name, a tag name,
or a commit id. If unspecified, will use the default version.""" or a commit id. If unspecified, will use the default version."""

View File

@ -279,6 +279,24 @@ class ParallelConfig:
assert last_exc is not None assert last_exc is not None
raise last_exc raise last_exc
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens result in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
return (envs.VLLM_ALL2ALL_BACKEND
in ("allgather_reducescatter", "naive",
"deepep_high_throughput", "deepep_low_latency")
and self.enable_expert_parallel
and self.tensor_parallel_size > 1
and self.data_parallel_size > 1)
@staticmethod @staticmethod
def has_unfinished_dp(dp_group: ProcessGroup, def has_unfinished_dp(dp_group: ProcessGroup,
has_unfinished: bool) -> bool: has_unfinished: bool) -> bool:
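To make the condition above concrete, here is a standalone restatement of the predicate with two example configurations (the helper name and values are illustrative, not part of the diff):

```python
# Illustrative restatement of ParallelConfig.use_sequence_parallel_moe.
SP_CAPABLE_BACKENDS = {
    "allgather_reducescatter", "naive",
    "deepep_high_throughput", "deepep_low_latency",
}

def uses_sequence_parallel_moe(all2all_backend: str,
                               enable_expert_parallel: bool,
                               tensor_parallel_size: int,
                               data_parallel_size: int) -> bool:
    return (all2all_backend in SP_CAPABLE_BACKENDS
            and enable_expert_parallel
            and tensor_parallel_size > 1
            and data_parallel_size > 1)

# TP=2, DP=2 with DeepEP high-throughput kernels -> MoE input is sequence parallel.
assert uses_sequence_parallel_moe("deepep_high_throughput", True, 2, 2)
# pplx-kernels tolerate duplicate input tokens -> no sequence parallelism needed.
assert not uses_sequence_parallel_moe("pplx", True, 2, 2)
```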

View File

@ -281,6 +281,8 @@ class SpeculativeConfig:
trust_remote_code, trust_remote_code,
allowed_local_media_path=self.target_model_config. allowed_local_media_path=self.target_model_config.
allowed_local_media_path, allowed_local_media_path,
allowed_media_domains=self.target_model_config.
allowed_media_domains,
dtype=self.target_model_config.dtype, dtype=self.target_model_config.dtype,
seed=self.target_model_config.seed, seed=self.target_model_config.seed,
revision=self.revision, revision=self.revision,

View File

@ -6,7 +6,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed import get_dp_group from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx from vllm.utils import has_deep_ep, has_pplx
@ -34,41 +34,60 @@ class NaiveAll2AllManager(All2AllManagerBase):
super().__init__(cpu_group) super().__init__(cpu_group)
def naive_multicast(self, x: torch.Tensor, def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor): cu_tokens_across_sp_cpu: torch.Tensor,
is_sequence_parallel: bool) -> torch.Tensor:
assert (len(x.shape) == 2) assert (len(x.shape) == 2)
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
device=x.device, device=x.device,
dtype=x.dtype) dtype=x.dtype)
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ rank = self.rank if is_sequence_parallel else self.dp_rank
self.dp_rank - 1] world_size = (self.world_size
end = cu_tokens_across_dp_cpu[self.dp_rank] if is_sequence_parallel else self.dp_world_size)
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
end = cu_tokens_across_sp_cpu[rank]
buffer[start:end, :].copy_(x) buffer[start:end, :].copy_(x)
for idx in range(self.dp_world_size): for idx in range(world_size):
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
end = cu_tokens_across_dp_cpu[idx] end = cu_tokens_across_sp_cpu[idx]
self.dp_group.broadcast(buffer[start:end, :], idx) get_ep_group().broadcast(buffer[start:end, :], idx)
return buffer return buffer
def dispatch(self, hidden_states: torch.Tensor, def dispatch(
router_logits: torch.Tensor): self,
sizes = get_forward_context( hidden_states: torch.Tensor,
).dp_metadata.get_chunk_sizes_across_dp_rank() router_logits: torch.Tensor,
hidden_states, router_logits = get_dp_group().all_gatherv( is_sequence_parallel: bool = False
[hidden_states, router_logits], ) -> tuple[torch.Tensor, torch.Tensor]:
dim=0, sp_size = self.tp_group.world_size if is_sequence_parallel else 1
sizes=sizes, dp_metadata = get_forward_context().dp_metadata
) cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_sp_cpu,
is_sequence_parallel)
router_logits = self.naive_multicast(router_logits,
cu_tokens_across_sp_cpu,
is_sequence_parallel)
return hidden_states, router_logits return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def combine(self,
sizes = get_forward_context( hidden_states: torch.Tensor,
).dp_metadata.get_chunk_sizes_across_dp_rank() is_sequence_parallel: bool = False) -> torch.Tensor:
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
dim=0, ep_rank = self.rank if is_sequence_parallel else self.dp_rank
sizes=sizes)
dp_metadata = get_forward_context().dp_metadata
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
end = cu_tokens_across_sp_cpu[ep_rank]
all_hidden_states = get_ep_group().all_reduce(hidden_states)
hidden_states = all_hidden_states[start:end, :]
return hidden_states return hidden_states
def destroy(self): def destroy(self):
@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group): def __init__(self, cpu_group):
super().__init__(cpu_group) super().__init__(cpu_group)
def dispatch(self, hidden_states: torch.Tensor, def dispatch(
router_logits: torch.Tensor): self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Gather hidden_states and router_logits from all dp ranks. Gather hidden_states and router_logits from all dp ranks.
""" """
sizes = get_forward_context( sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank() ).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states, router_logits = get_dp_group().all_gatherv(
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
hidden_states, router_logits = dist_group.all_gatherv(
[hidden_states, router_logits], [hidden_states, router_logits],
dim=0, dim=0,
sizes=sizes, sizes=sizes,
) )
return hidden_states, router_logits return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
""" """
Reduce-scatter hidden_states across all dp ranks. Reduce-scatter hidden_states across all dp ranks.
""" """
sizes = get_forward_context( sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank() ).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
dim=0, dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
sizes=sizes) hidden_states = dist_group.reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
return hidden_states return hidden_states
def destroy(self): def destroy(self):
@ -148,11 +178,17 @@ class PPLXAll2AllManager(All2AllManagerBase):
kwargs, pplx.AllToAll.internode kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode) if self.internode else pplx.AllToAll.intranode)
def dispatch(self, hidden_states: torch.Tensor, def dispatch(
router_logits: torch.Tensor): self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def destroy(self): def destroy(self):
@ -184,11 +220,17 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def get_handle(self, kwargs): def get_handle(self, kwargs):
raise NotImplementedError raise NotImplementedError
def dispatch(self, hidden_states: torch.Tensor, def dispatch(
router_logits: torch.Tensor): self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def destroy(self): def destroy(self):
@ -395,4 +437,4 @@ class FlashInferAllToAllManager(All2AllManagerBase):
self.workspace_tensor = None self.workspace_tensor = None
self.prepare_workspace_tensor = None self.prepare_workspace_tensor = None
self.mapping = None self.mapping = None
self.initialized = False self.initialized = False
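The slicing in `naive_multicast` and `combine` above is driven by `cu_tokens_across_sp`, whose computation is added to `DPMetadata` further down. A worked example with made-up token counts:

```python
import torch

# Made-up counts: 2 DP ranks contributing 5 and 8 tokens, sp_size (TP degree) = 2.
num_tokens_across_dp_cpu = torch.tensor([5, 8])
sp_size = 2

# Each DP rank's tokens are split evenly (rounded up) across its SP ranks,
# then accumulated, matching DPMetadata.cu_tokens_across_sp.
num_tokens_across_sp = (num_tokens_across_dp_cpu - 1 + sp_size) // sp_size
num_tokens_across_sp = num_tokens_across_sp.repeat_interleave(sp_size)
cu_tokens_across_sp = torch.cumsum(num_tokens_across_sp, dim=0)

print(num_tokens_across_sp)  # tensor([3, 3, 4, 4])
print(cu_tokens_across_sp)   # tensor([ 3,  6, 10, 14])

# Rank r owns rows [cu[r-1], cu[r]) of the multicast buffer (with cu[-1] == 0),
# e.g. SP rank 2 owns rows 6..10.
```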

View File

@ -28,6 +28,8 @@ class Cache:
class All2AllManagerBase: class All2AllManagerBase:
rank: int
world_size: int
def __init__(self, cpu_group): def __init__(self, cpu_group):
self.cpu_group = cpu_group self.cpu_group = cpu_group
@ -40,6 +42,7 @@ class All2AllManagerBase:
# all2all lives in ep group, which is merged from dp and tp group # all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group() self.dp_group = get_dp_group()
self.tp_group = get_tp_group() self.tp_group = get_tp_group()
# no self.ep_group since self.ep_group is still in construction # no self.ep_group since self.ep_group is still in construction
# when we create this object # when we create this object
self.dp_rank = self.dp_group.rank_in_group self.dp_rank = self.dp_group.rank_in_group
@ -60,17 +63,21 @@ class All2AllManagerBase:
# and reuse it for the same config. # and reuse it for the same config.
raise NotImplementedError raise NotImplementedError
def dispatch(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False):
raise NotImplementedError
def set_num_sms(self, num_sms: int): def set_num_sms(self, num_sms: int):
pass pass
def max_sms_used(self) -> Optional[int]: def max_sms_used(self) -> Optional[int]:
return None # None means it could use the whole GPU return None # None means it could use the whole GPU
def dispatch(self, hidden_states: torch.Tensor, def combine(self,
router_logits: torch.Tensor): hidden_states: torch.Tensor,
raise NotImplementedError is_sequence_parallel: bool = False):
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def destroy(self): def destroy(self):
@ -267,15 +274,20 @@ class DeviceCommunicatorBase:
module.quant_method.init_prepare_finalize(module) module.quant_method.init_prepare_finalize(module)
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Dispatch the hidden states and router logits to the appropriate device. Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class. This is a no-op in the base class.
""" """
return hidden_states, router_logits return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
""" """
Combine the hidden states and router logits from the appropriate device. Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class. This is a no-op in the base class.

View File

@ -39,10 +39,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
# ep does not use pynccl
use_pynccl = "ep" not in unique_name
self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem self.use_torch_symm_mem = use_torch_symm_mem
@ -57,7 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
SymmMemCommunicator) SymmMemCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1: if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator( self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group, group=self.cpu_group,
device=self.device, device=self.device,
@ -308,14 +304,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list return output_list
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch( hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits) hidden_states, router_logits, is_sequence_parallel)
return hidden_states, router_logits return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states) hidden_states = self.all2all_manager.combine(hidden_states,
is_sequence_parallel)
return hidden_states return hidden_states

View File

@ -75,14 +75,20 @@ class XpuCommunicator(DeviceCommunicatorBase):
dist.broadcast(input_, src=src, group=self.device_group) dist.broadcast(input_, src=src, group=self.device_group)
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch( hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits) hidden_states, router_logits, is_sequence_parallel)
return hidden_states, router_logits return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states) hidden_states = self.all2all_manager.combine(hidden_states,
is_sequence_parallel)
return hidden_states return hidden_states

View File

@ -105,6 +105,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self.reqs_to_recv: dict[ReqId, ReqMeta] = {} self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
self.reqs_to_save: dict[ReqId, ReqMeta] = {} self.reqs_to_save: dict[ReqId, ReqMeta] = {}
self.reqs_to_send: dict[ReqId, float] = {} self.reqs_to_send: dict[ReqId, float] = {}
self.reqs_in_batch: set[ReqId] = set()
def add_new_req( def add_new_req(
self, self,
@ -278,6 +279,7 @@ class NixlConnectorScheduler:
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
# Reqs to send and their expiration time # Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {} self._reqs_need_send: dict[ReqId, float] = {}
self._reqs_in_batch: set[ReqId] = set()
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, request: "Request", self, request: "Request",
@ -324,6 +326,9 @@ class NixlConnectorScheduler:
if not params: if not params:
return return
if params.get("do_remote_decode"):
self._reqs_in_batch.add(request.request_id)
if self.use_host_buffer and params.get("do_remote_decode"): if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl, # NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer. # prefilled blocks need to be saved to host memory before transfer.
@ -373,6 +378,8 @@ class NixlConnectorScheduler:
request_id=req_id, request_id=req_id,
local_block_ids=block_ids, local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params, kv_transfer_params=req.kv_transfer_params,
load_remote_cache=True,
save_to_host=False,
) )
for req_id, (req, block_ids) in self._reqs_need_save.items(): for req_id, (req, block_ids) in self._reqs_need_save.items():
@ -386,10 +393,12 @@ class NixlConnectorScheduler:
) )
meta.reqs_to_send = self._reqs_need_send meta.reqs_to_send = self._reqs_need_send
meta.reqs_in_batch = self._reqs_in_batch
# Clear the list once workers start the transfers # Clear the list once workers start the transfers
self._reqs_need_recv.clear() self._reqs_need_recv.clear()
self._reqs_need_save.clear() self._reqs_need_save.clear()
self._reqs_in_batch = set()
self._reqs_need_send = {} self._reqs_need_send = {}
return meta return meta
@ -465,8 +474,11 @@ class NixlConnectorWorker:
"backends", ["UCX"]) "backends", ["UCX"])
# Agent. # Agent.
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
config = nixl_agent_config(backends=self.nixl_backends) if len( if nixl_agent_config is None:
non_ucx_backends) > 0 and nixl_agent_config is not None else None config = None
else:
config = nixl_agent_config(backends=self.nixl_backends) if len(
non_ucx_backends) > 0 else nixl_agent_config(num_threads=8)
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config) self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
@ -546,6 +558,8 @@ class NixlConnectorWorker:
self._recving_transfers = defaultdict[ReqId, list[Transfer]](list) self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
# Track the expiration time of requests that are waiting to be sent. # Track the expiration time of requests that are waiting to be sent.
self._reqs_to_send: dict[ReqId, float] = {} self._reqs_to_send: dict[ReqId, float] = {}
# Set of requests that have been part of a batch, regardless of status.
self._reqs_to_process: set[ReqId] = set()
# Background thread for handling new handshake requests. # Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: Optional[threading.Thread] = None self._nixl_handshake_listener_t: Optional[threading.Thread] = None
@ -1082,6 +1096,7 @@ class NixlConnectorWorker:
"Releasing expired KV blocks for request %s which were " "Releasing expired KV blocks for request %s which were "
"retrieved by %d decode worker(s) within %d seconds.", req_id, "retrieved by %d decode worker(s) within %d seconds.", req_id,
count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT) count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
self._reqs_to_process.remove(req_id)
del self._reqs_to_send[req_id] del self._reqs_to_send[req_id]
done_sending.add(req_id) done_sending.add(req_id)
@ -1097,7 +1112,8 @@ class NixlConnectorWorker:
for notifs in self.nixl_wrapper.get_new_notifs().values(): for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs: for notif in notifs:
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1) req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
if req_id not in self._reqs_to_send: if (req_id not in self._reqs_to_send
and req_id not in self._reqs_to_process):
logger.error( logger.error(
"Potentially invalid KV blocks for " "Potentially invalid KV blocks for "
"unrecognized request %s were retrieved by " "unrecognized request %s were retrieved by "
@ -1110,7 +1126,8 @@ class NixlConnectorWorker:
tp_ratio): tp_ratio):
notified_req_ids.add(req_id) notified_req_ids.add(req_id)
del self.consumer_notification_counts_by_req[req_id] del self.consumer_notification_counts_by_req[req_id]
del self._reqs_to_send[req_id] self._reqs_to_process.remove(req_id)
self._reqs_to_send.pop(req_id, None)
return notified_req_ids return notified_req_ids
def _pop_done_transfers( def _pop_done_transfers(
@ -1171,8 +1188,19 @@ class NixlConnectorWorker:
while not self._ready_requests.empty(): while not self._ready_requests.empty():
self._read_blocks_for_req(*self._ready_requests.get_nowait()) self._read_blocks_for_req(*self._ready_requests.get_nowait())
# Keep around the requests that have been part of a batch. This is
# needed because async scheduling widens the gap between the moment a
# request's expiration is set (P side) and the moment its blocks are
# read from D. As P can now more easily lag behind D while processing
# the next batch, we make sure to only set an expiration for requests
# that have not been read from D yet.
for req_id in metadata.reqs_in_batch:
self._reqs_to_process.add(req_id)
# Add to requests that are waiting to be read and track expiration. # Add to requests that are waiting to be read and track expiration.
self._reqs_to_send.update(metadata.reqs_to_send) for req_id, expiration_time in metadata.reqs_to_send.items():
if req_id in self._reqs_to_process:
self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
logger.debug( logger.debug(

View File

@ -871,17 +871,24 @@ class GroupCoordinator:
model) model)
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
if self.device_communicator is not None: if self.device_communicator is not None:
return self.device_communicator.dispatch(hidden_states, return self.device_communicator.dispatch(hidden_states,
router_logits) router_logits,
is_sequence_parallel)
else: else:
return hidden_states, router_logits return hidden_states, router_logits
def combine(self, hidden_states) -> torch.Tensor: def combine(self,
hidden_states,
is_sequence_parallel: bool = False) -> torch.Tensor:
if self.device_communicator is not None: if self.device_communicator is not None:
return self.device_communicator.combine(hidden_states) return self.device_communicator.combine(hidden_states,
is_sequence_parallel)
else: else:
return hidden_states return hidden_states

View File

@ -297,6 +297,8 @@ class EngineArgs:
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
trust_remote_code: bool = ModelConfig.trust_remote_code trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path allowed_local_media_path: str = ModelConfig.allowed_local_media_path
allowed_media_domains: Optional[
list[str]] = ModelConfig.allowed_media_domains
download_dir: Optional[str] = LoadConfig.download_dir download_dir: Optional[str] = LoadConfig.download_dir
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
load_format: Union[str, LoadFormats] = LoadConfig.load_format load_format: Union[str, LoadFormats] = LoadConfig.load_format
@ -531,6 +533,8 @@ class EngineArgs:
**model_kwargs["hf_config_path"]) **model_kwargs["hf_config_path"])
model_group.add_argument("--allowed-local-media-path", model_group.add_argument("--allowed-local-media-path",
**model_kwargs["allowed_local_media_path"]) **model_kwargs["allowed_local_media_path"])
model_group.add_argument("--allowed-media-domains",
**model_kwargs["allowed_media_domains"])
model_group.add_argument("--revision", **model_kwargs["revision"]) model_group.add_argument("--revision", **model_kwargs["revision"])
model_group.add_argument("--code-revision", model_group.add_argument("--code-revision",
**model_kwargs["code_revision"]) **model_kwargs["code_revision"])
@ -997,6 +1001,7 @@ class EngineArgs:
tokenizer_mode=self.tokenizer_mode, tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code, trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path, allowed_local_media_path=self.allowed_local_media_path,
allowed_media_domains=self.allowed_media_domains,
dtype=self.dtype, dtype=self.dtype,
seed=self.seed, seed=self.seed,
revision=self.revision, revision=self.revision,

View File

@ -11,7 +11,12 @@ from pathlib import Path
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union, from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
cast) cast)
import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils import transformers.utils.chat_template_utils as hf_chat_utils
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
@ -50,7 +55,7 @@ from vllm.transformers_utils.chat_templates import (
# yapf: enable # yapf: enable
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid, supports_kw
logger = init_logger(__name__) logger = init_logger(__name__)
@ -632,6 +637,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def allowed_local_media_path(self): def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path return self._model_config.allowed_local_media_path
@property
def allowed_media_domains(self):
return self._model_config.allowed_media_domains
@property @property
def mm_registry(self): def mm_registry(self):
return MULTIMODAL_REGISTRY return MULTIMODAL_REGISTRY
@ -832,6 +841,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._connector = MediaConnector( self._connector = MediaConnector(
media_io_kwargs=media_io_kwargs, media_io_kwargs=media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains,
) )
def parse_image( def parse_image(
@ -916,6 +926,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._connector = MediaConnector( self._connector = MediaConnector(
media_io_kwargs=media_io_kwargs, media_io_kwargs=media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains,
) )
def parse_image( def parse_image(
@ -1548,6 +1559,46 @@ def parse_chat_messages_futures(
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
tags = {"generation"}
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
call = self.call_method("_generation_support")
call_block = jinja2.nodes.CallBlock(call, [], [], body)
return call_block.set_lineno(lineno)
def resolve_chat_template_kwargs(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: str,
chat_template_kwargs: dict[str, Any],
) -> dict[str, Any]:
fn_kw = {
k for k in chat_template_kwargs
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
}
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
)
parsed_content = env.parse(chat_template)
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
# We exclude chat_template from kwargs here, because
# chat template has been already resolved at this stage
unexpected_vars = {"chat_template"}
accept_vars = (fn_kw | template_vars) - unexpected_vars
return {
k: v for k, v in chat_template_kwargs.items() if k in accept_vars
}
def apply_hf_chat_template( def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
@ -1573,12 +1624,17 @@ def apply_hf_chat_template(
) )
try: try:
resolved_kwargs = resolve_chat_template_kwargs(
tokenizer=tokenizer,
chat_template=hf_chat_template,
chat_template_kwargs=kwargs,
)
return tokenizer.apply_chat_template( return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type] conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type] tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template, chat_template=hf_chat_template,
tokenize=tokenize, tokenize=tokenize,
**kwargs, **resolved_kwargs,
) )
# External library exceptions can sometimes occur despite the framework's # External library exceptions can sometimes occur despite the framework's
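The filter above keeps a kwarg if either `tokenizer.apply_chat_template` accepts it as a named parameter or the Jinja template itself references it. The template side relies on `jinja2.meta.find_undeclared_variables`; a minimal sketch with a made-up template string:

```python
import jinja2.meta
import jinja2.sandbox

# Made-up template text, only to illustrate variable discovery.
template_src = (
    "{% if enable_thinking %}<think>{% endif %}"
    "{{ messages[-1]['content'] }}"
)

env = jinja2.sandbox.ImmutableSandboxedEnvironment(
    trim_blocks=True, lstrip_blocks=True)
parsed = env.parse(template_src)

# Variables the template reads but never defines are kwarg candidates.
print(jinja2.meta.find_undeclared_variables(parsed))
# {'enable_thinking', 'messages'}
```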

View File

@ -86,6 +86,8 @@ class LLM:
or videos from directories specified by the server file system. or videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted This is a security risk. Should only be enabled in trusted
environments. environments.
allowed_media_domains: If set, only media URLs that belong to these
domains can be used for multi-modal inputs.
tensor_parallel_size: The number of GPUs to use for distributed tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism. execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently, dtype: The data type for the model weights and activations. Currently,
@ -169,6 +171,7 @@ class LLM:
skip_tokenizer_init: bool = False, skip_tokenizer_init: bool = False,
trust_remote_code: bool = False, trust_remote_code: bool = False,
allowed_local_media_path: str = "", allowed_local_media_path: str = "",
allowed_media_domains: Optional[list[str]] = None,
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
dtype: ModelDType = "auto", dtype: ModelDType = "auto",
quantization: Optional[QuantizationMethods] = None, quantization: Optional[QuantizationMethods] = None,
@ -264,6 +267,7 @@ class LLM:
skip_tokenizer_init=skip_tokenizer_init, skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path, allowed_local_media_path=allowed_local_media_path,
allowed_media_domains=allowed_media_domains,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
dtype=dtype, dtype=dtype,
quantization=quantization, quantization=quantization,

View File

@ -3,12 +3,14 @@
import asyncio import asyncio
import gc import gc
import hashlib
import importlib import importlib
import inspect import inspect
import json import json
import multiprocessing import multiprocessing
import multiprocessing.forkserver as forkserver import multiprocessing.forkserver as forkserver
import os import os
import secrets
import signal import signal
import socket import socket
import tempfile import tempfile
@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
class AuthenticationMiddleware: class AuthenticationMiddleware:
""" """
Pure ASGI middleware that authenticates each request by checking Pure ASGI middleware that authenticates each request by checking
if the Authorization header exists and equals "Bearer {api_key}". if the Authorization Bearer token exists and equals any of "{api_key}".
Notes Notes
----- -----
@ -1263,7 +1265,26 @@ class AuthenticationMiddleware:
def __init__(self, app: ASGIApp, tokens: list[str]) -> None: def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
self.app = app self.app = app
self.api_tokens = {f"Bearer {token}" for token in tokens} self.api_tokens = [
hashlib.sha256(t.encode("utf-8")).digest() for t in tokens
]
def verify_token(self, headers: Headers) -> bool:
authorization_header_value = headers.get("Authorization")
if not authorization_header_value:
return False
scheme, _, param = authorization_header_value.partition(" ")
if scheme.lower() != "bearer":
return False
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
token_match = False
for token_hash in self.api_tokens:
token_match |= secrets.compare_digest(param_hash, token_hash)
return token_match
def __call__(self, scope: Scope, receive: Receive, def __call__(self, scope: Scope, receive: Receive,
send: Send) -> Awaitable[None]: send: Send) -> Awaitable[None]:
@ -1276,8 +1297,7 @@ class AuthenticationMiddleware:
url_path = URL(scope=scope).path.removeprefix(root_path) url_path = URL(scope=scope).path.removeprefix(root_path)
headers = Headers(scope=scope) headers = Headers(scope=scope)
# Type narrow to satisfy mypy. # Type narrow to satisfy mypy.
if url_path.startswith("/v1") and headers.get( if url_path.startswith("/v1") and not self.verify_token(headers):
"Authorization") not in self.api_tokens:
response = JSONResponse(content={"error": "Unauthorized"}, response = JSONResponse(content={"error": "Unauthorized"},
status_code=401) status_code=401)
return response(scope, receive, send) return response(scope, receive, send)
@ -1696,6 +1716,7 @@ async def init_app_state(
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids, return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice, enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args. exclude_tools_when_tool_choice_none=args.
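The middleware change above hashes both sides so that `secrets.compare_digest` always sees equal-length inputs, and ORs the results so the loop never exits early on a match. A minimal standalone sketch of the same idea (function name is illustrative):

```python
import hashlib
import secrets

def verify_bearer_token(presented: str, valid_tokens: list[str]) -> bool:
    """Constant-time membership check mirroring AuthenticationMiddleware."""
    presented_hash = hashlib.sha256(presented.encode("utf-8")).digest()
    match = False
    for token in valid_tokens:
        token_hash = hashlib.sha256(token.encode("utf-8")).digest()
        # No early exit: every configured token is compared on every call.
        match |= secrets.compare_digest(presented_hash, token_hash)
    return match

assert verify_bearer_token("s3cret", ["other-key", "s3cret"])
assert not verify_bearer_token("wrong", ["other-key", "s3cret"])
```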

View File

@ -103,9 +103,13 @@ class FrontendArgs:
chat_template_content_format: ChatTemplateContentFormatOption = "auto" chat_template_content_format: ChatTemplateContentFormatOption = "auto"
"""The format to render message content within a chat template. """The format to render message content within a chat template.
* "string" will render the content as a string. Example: `"Hello World"` * "string" will render the content as a string. Example: `"Hello World"`
* "openai" will render the content as a list of dictionaries, similar to OpenAI * "openai" will render the content as a list of dictionaries, similar to
schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
trust_request_chat_template: bool = False
"""Whether to trust the chat template provided in the request. If False,
the server will always use the chat template specified by `--chat-template`
or the ones from the tokenizer."""
response_role: str = "assistant" response_role: str = "assistant"
"""The role name to return if `request.add_generation_prompt=true`.""" """The role name to return if `request.add_generation_prompt=true`."""
ssl_keyfile: Optional[str] = None ssl_keyfile: Optional[str] = None

View File

@ -68,6 +68,7 @@ class OpenAIServingChat(OpenAIServing):
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
chat_template: Optional[str], chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption, chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
reasoning_parser: str = "", reasoning_parser: str = "",
enable_auto_tools: bool = False, enable_auto_tools: bool = False,
@ -89,6 +90,7 @@ class OpenAIServingChat(OpenAIServing):
self.response_role = response_role self.response_role = response_role
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
self.enable_log_outputs = enable_log_outputs self.enable_log_outputs = enable_log_outputs
# set up tool use # set up tool use
@ -220,6 +222,16 @@ class OpenAIServingChat(OpenAIServing):
if not self.use_harmony: if not self.use_harmony:
# Common case. # Common case.
request_chat_template = request.chat_template
chat_template_kwargs = request.chat_template_kwargs
if not self.trust_request_chat_template and (
request_chat_template is not None or
(chat_template_kwargs and
chat_template_kwargs.get("chat_template") is not None)):
return self.create_error_response(
"Chat template is passed with request, but "
"--trust-request-chat-template is not set. "
"Refused request with untrusted chat template.")
( (
conversation, conversation,
request_prompts, request_prompts,
@ -228,7 +240,7 @@ class OpenAIServingChat(OpenAIServing):
request, request,
tokenizer, tokenizer,
request.messages, request.messages,
chat_template=request.chat_template or self.chat_template, chat_template=request_chat_template or self.chat_template,
chat_template_content_format=self. chat_template_content_format=self.
chat_template_content_format, chat_template_content_format,
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,
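From the client side, the new guard means a request that supplies its own chat template is refused unless the server was launched with `--trust-request-chat-template`. A hedged illustration using the OpenAI-compatible API (server URL, model name, and template text are placeholders):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# vLLM accepts extra request fields such as `chat_template` via extra_body.
# Without --trust-request-chat-template on the server, this request is now
# rejected instead of silently overriding the configured template.
completion = client.chat.completions.create(
    model="Qwen/Qwen3-8B",
    messages=[{"role": "user", "content": "Hello!"}],
    extra_body={"chat_template": "{{ messages[-1]['content'] }}"},
)
```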

View File

@ -49,16 +49,29 @@ class BatchDescriptor(NamedTuple):
return BatchDescriptor(self.num_tokens, uniform_decode=False) return BatchDescriptor(self.num_tokens, uniform_decode=False)
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int], def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
sequence_parallel_size: int) -> list[int]:
sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
sequence_parallel_size)
sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
return sp_tokens.tolist()
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
sequence_parallel_size: int,
max_num_tokens: int, max_num_tokens: int,
chunk_idx: int) -> list[int]: chunk_idx: int) -> list[int]:
dp_size = len(num_tokens_across_dp_cpu)
local_size = [-1] * dp_size sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu,
for i in range(dp_size): sequence_parallel_size)
dp_tokens = num_tokens_across_dp_cpu[i] sp_size = len(sp_tokens)
local_size = [-1] * sp_size
for i in range(sp_size):
# Take into account sharding if MoE activation is sequence parallel.
local_size[i] = min(max_num_tokens, local_size[i] = min(max_num_tokens,
dp_tokens - (max_num_tokens * chunk_idx)) sp_tokens[i] - (max_num_tokens * chunk_idx))
if local_size[i] <= 0: if local_size[i] <= 0:
local_size[i] = 1 # ensure lockstep even if done local_size[i] = 1 # ensure lockstep even if done
return local_size return local_size
@ -67,7 +80,9 @@ def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
@dataclass @dataclass
class DPMetadata: class DPMetadata:
max_tokens_across_dp_cpu: torch.Tensor max_tokens_across_dp_cpu: torch.Tensor
cu_tokens_across_dp_cpu: torch.Tensor num_tokens_across_dp_cpu: torch.Tensor
# NOTE: local_sizes should only be set by the chunked_sizes context manager
local_sizes: Optional[list[int]] = None local_sizes: Optional[list[int]] = None
@staticmethod @staticmethod
@ -98,6 +113,17 @@ class DPMetadata:
dist.all_reduce(num_tokens_tensor, group=group) dist.all_reduce(num_tokens_tensor, group=group)
return num_tokens_tensor.cpu() return num_tokens_tensor.cpu()
# Get the cumulative tokens across sequence parallel ranks.
# In this case the input to the MoEs will be distributed w.r.t. both
# DP and TP rank.
# When sp_size==1, this is just the cumulative num tokens across DP.
def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
num_tokens_across_sp_cpu = (
(self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size)
num_tokens_across_sp_cpu = (
num_tokens_across_sp_cpu.repeat_interleave(sp_size))
return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
@staticmethod @staticmethod
def should_ubatch_across_dp( def should_ubatch_across_dp(
should_ubatch: bool, orig_num_tokens_per_ubatch: int, should_ubatch: bool, orig_num_tokens_per_ubatch: int,
@ -147,10 +173,10 @@ class DPMetadata:
@staticmethod @staticmethod
def make( def make(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
attn_metadata: Any, attn_metadata: Any,
num_tokens: int, num_tokens: int,
num_tokens_across_dp: Optional[torch.Tensor] = None num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
) -> "DPMetadata": ) -> "DPMetadata":
assert parallel_config.data_parallel_size > 1 assert parallel_config.data_parallel_size > 1
@ -167,18 +193,18 @@ class DPMetadata:
# If num_tokens_across_dp is None, it will be computed by all_reduce # If num_tokens_across_dp is None, it will be computed by all_reduce
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank] assert (num_tokens_across_dp_cpu is None
== batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}" or num_tokens_across_dp_cpu[dp_rank] == batchsize
if num_tokens_across_dp is None: ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
num_tokens_across_dp = DPMetadata.num_tokens_across_dp( if num_tokens_across_dp_cpu is None:
num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
batchsize, dp_size, dp_rank) batchsize, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp) max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0) return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
num_tokens_across_dp)
@contextmanager @contextmanager
def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int): def chunked_sizes(self, sequence_parallel_size: int,
max_chunk_size_per_rank: int, chunk_idx: int):
""" """
Context manager to compute and temporarily set the per-rank local token Context manager to compute and temporarily set the per-rank local token
sizes for a specific chunk during chunked forward execution. sizes for a specific chunk during chunked forward execution.
@ -192,31 +218,40 @@ class DPMetadata:
`chunk_idx`, this context manager sets `self.local_sizes` to the number `chunk_idx`, this context manager sets `self.local_sizes` to the number
of tokens to process in that chunk on each rank. of tokens to process in that chunk on each rank.
It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
to determine the chunk-wise split.
`self.local_sizes` is only valid inside the context. `self.local_sizes` is only valid inside the context.
Args: Args:
sequence_parallel_size: When Attn is TP and MoE layers are EP,
we use SP between the layers to avoid
redundant ops. We need this value to
compute the chunked sizes.
max_chunk_size_per_rank: The max number of tokens each rank is max_chunk_size_per_rank: The max number of tokens each rank is
allowed to process in this chunk. allowed to process in this chunk.
chunk_idx: The index of the chunk to compute sizes for. chunk_idx: The index of the chunk to compute sizes for.
""" """
cu_sizes = self.cu_tokens_across_dp_cpu
num_tokens_across_dp_cpu = [
(cu_sizes[i] -
cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
for i in range(len(cu_sizes))
]
self.local_sizes = _compute_chunked_local_num_tokens( self.local_sizes = _compute_chunked_local_num_tokens(
num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx) self.num_tokens_across_dp_cpu, sequence_parallel_size,
max_chunk_size_per_rank, chunk_idx)
try:
yield self.local_sizes
finally:
self.local_sizes = None
@contextmanager
def sp_local_sizes(self, sequence_parallel_size: int):
"""
Context mamager for setting self.local_sizes. Same as self.chunked_sizes
but without any chunking.
"""
self.local_sizes = _compute_sp_num_tokens(
self.num_tokens_across_dp_cpu, sequence_parallel_size)
try: try:
yield self.local_sizes yield self.local_sizes
finally: finally:
self.local_sizes = None self.local_sizes = None
def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]: def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
assert self.local_sizes is not None
return self.local_sizes return self.local_sizes
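As a quick sanity check of the token-split arithmetic introduced above, here is a standalone toy run of the ceil-divide / repeat_interleave logic (plain torch, mirroring _compute_sp_num_tokens rather than importing it from vLLM):

    import torch

    num_tokens_across_dp_cpu = torch.tensor([7, 10])   # tokens on each DP rank
    sp_size = 2                                        # sequence-parallel width

    sp_tokens = (num_tokens_across_dp_cpu + sp_size - 1) // sp_size   # ceil-div -> [4, 5]
    sp_tokens = sp_tokens.repeat_interleave(sp_size)                  # -> [4, 4, 5, 5]
    print(sp_tokens.tolist())                        # per-SP-rank token counts
    print(torch.cumsum(sp_tokens, dim=0).tolist())   # [4, 8, 13, 18], cf. cu_tokens_across_sp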

View File

@@ -3,6 +3,7 @@

 from abc import abstractmethod
 from collections.abc import Iterable
+from contextlib import nullcontext
 from enum import Enum
 from typing import Callable, Literal, Optional, Union, get_args, overload

@@ -983,8 +984,7 @@ class FusedMoE(CustomOp):
                         if dp_size is not None else get_dp_group().world_size)

         self.is_sequence_parallel = is_sequence_parallel
-        if self.is_sequence_parallel:
-            self.sp_size = tp_size_
+        self.sp_size = tp_size_ if is_sequence_parallel else 1

         self.moe_parallel_config: FusedMoEParallelConfig = (
             FusedMoEParallelConfig.make(
@@ -1966,7 +1966,8 @@ class FusedMoE(CustomOp):
             # clamp start and end
             chunk_start = min(chunk_start, num_tokens - 1)
             chunk_end = min(chunk_end, num_tokens)
-            with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
+            with ctx.dp_metadata.chunked_sizes(self.sp_size,
+                                               moe_dp_chunk_size_per_rank,
                                                chunk_idx):
                 process_chunk(chunk_start,
                               chunk_end,
@@ -2011,65 +2012,73 @@ class FusedMoE(CustomOp):
         else:
             shared_output = None

-        if do_naive_dispatch_combine:
-            hidden_states, router_logits = get_ep_group().dispatch(
-                hidden_states, router_logits)
-
-        # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
-            layer=self,
-            x=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            renormalize=self.renormalize,
-            use_grouped_topk=self.use_grouped_topk,
-            global_num_experts=self.global_num_experts,
-            expert_map=self.expert_map,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            custom_routing_function=self.custom_routing_function,
-            scoring_func=self.scoring_func,
-            routed_scaling_factor=self.routed_scaling_factor,
-            e_score_correction_bias=self.e_score_correction_bias,
-            activation=self.activation,
-            apply_router_weight_on_input=self.apply_router_weight_on_input,
-            enable_eplb=self.enable_eplb,
-            expert_load_view=self.expert_load_view,
-            logical_to_physical_map=self.logical_to_physical_map,
-            logical_replica_count=self.logical_replica_count,
-        )
-
-        if shared_output is not None:
-            assert not isinstance(final_hidden_states, tuple)
-            assert self.shared_experts is not None
-            final_hidden_states = (
-                shared_output,
-                final_hidden_states,
-            )
-        elif self.zero_expert_num is not None and self.zero_expert_num > 0:
-            assert isinstance(final_hidden_states, tuple)
-            final_hidden_states, zero_expert_result = final_hidden_states
-
-        def reduce_output(states: torch.Tensor,
-                          do_combine: bool = True) -> torch.Tensor:
-            if do_naive_dispatch_combine and do_combine:
-                states = get_ep_group().combine(states)
-
-            if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
-                states = self.maybe_all_reduce_tensor_model_parallel(states)
-
-            return states
-
-        if self.shared_experts is not None:
-            return (
-                reduce_output(final_hidden_states[0], do_combine=False),
-                reduce_output(final_hidden_states[1]),
-            )
-        elif self.zero_expert_num is not None and self.zero_expert_num > 0:
-            assert isinstance(final_hidden_states, torch.Tensor)
-            return reduce_output(final_hidden_states) + zero_expert_result
-        else:
-            return reduce_output(final_hidden_states)
+        ctx = get_forward_context()
+        sp_ctx = ctx.dp_metadata.sp_local_sizes(
+            self.sp_size) if ctx.dp_metadata else nullcontext()
+        with sp_ctx:
+            if do_naive_dispatch_combine:
+                hidden_states, router_logits = get_ep_group().dispatch(
+                    hidden_states, router_logits, self.is_sequence_parallel)
+
+            # Matrix multiply.
+            final_hidden_states = self.quant_method.apply(
+                layer=self,
+                x=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                renormalize=self.renormalize,
+                use_grouped_topk=self.use_grouped_topk,
+                global_num_experts=self.global_num_experts,
+                expert_map=self.expert_map,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                custom_routing_function=self.custom_routing_function,
+                scoring_func=self.scoring_func,
+                routed_scaling_factor=self.routed_scaling_factor,
+                e_score_correction_bias=self.e_score_correction_bias,
+                activation=self.activation,
+                apply_router_weight_on_input=self.apply_router_weight_on_input,
+                enable_eplb=self.enable_eplb,
+                expert_load_view=self.expert_load_view,
+                logical_to_physical_map=self.logical_to_physical_map,
+                logical_replica_count=self.logical_replica_count,
+            )

+            if shared_output is not None:
+                assert not isinstance(final_hidden_states, tuple)
+                assert self.shared_experts is not None
+                final_hidden_states = (
+                    shared_output,
+                    final_hidden_states,
+                )
+            elif self.zero_expert_num is not None and self.zero_expert_num > 0:
+                assert isinstance(final_hidden_states, tuple)
+                final_hidden_states, zero_expert_result = final_hidden_states
+
+            def reduce_output(states: torch.Tensor,
+                              do_combine: bool = True) -> torch.Tensor:
+                if do_naive_dispatch_combine and do_combine:
+                    states = get_ep_group().combine(states,
+                                                    self.is_sequence_parallel)
+
+                if (not self.is_sequence_parallel and self.reduce_results
+                        and (self.tp_size > 1 or self.ep_size > 1)):
+                    states = self.maybe_all_reduce_tensor_model_parallel(
+                        states)
+
+                return states
+
+            if self.shared_experts is not None:
+                return (
+                    reduce_output(final_hidden_states[0], do_combine=False),
+                    reduce_output(final_hidden_states[1]),
+                )
+            elif self.zero_expert_num is not None and self.zero_expert_num > 0:
+                assert isinstance(final_hidden_states, torch.Tensor)
+                return reduce_output(final_hidden_states) + zero_expert_result
+            else:
+                return reduce_output(final_hidden_states)

     @classmethod
     def make_expert_params_mapping(
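The forward path above picks its context manager at runtime: sp_local_sizes when DP metadata is present, otherwise a no-op. A minimal, self-contained sketch of that pattern (the names here are illustrative stand-ins, not vLLM APIs):

    from contextlib import contextmanager, nullcontext

    @contextmanager
    def sp_local_sizes(sp_size):
        print(f"set local sizes for sp_size={sp_size}")
        try:
            yield
        finally:
            print("reset local sizes")

    dp_metadata = None  # e.g. a single-rank run without data parallelism
    ctx = sp_local_sizes(2) if dp_metadata else nullcontext()
    with ctx:
        pass  # dispatch / expert matmul / combine would run here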

View File

@@ -9,7 +9,7 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
 from transformers.models.aria.modeling_aria import AriaCrossAttention
 from transformers.models.aria.processing_aria import AriaProcessor

-from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
+from vllm.config import QuantizationConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -298,14 +298,12 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
     Experts (MoE) Layer.
     """

-    def __init__(
-        self,
-        config: AriaTextConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__(config, cache_config, quant_config, prefix)
+    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
+        super().__init__(vllm_config, prefix)
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
         self.mlp = AriaTextMoELayer(config,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.mlp")

View File

@@ -32,7 +32,6 @@ import torch
 from torch import nn
 from transformers import DeepseekV2Config, DeepseekV3Config

-import vllm.envs as envs
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ParallelConfig, VllmConfig
@@ -56,8 +55,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
-from vllm.utils import cdiv, direct_register_custom_op

 from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
@@ -108,43 +107,6 @@ class DeepseekV2MLP(nn.Module):
         return x


-# Chunk x along the num_tokens axis for sequence parallelism
-# NOTE: This is wrapped in a torch custom op to work around the following issue:
-# The output tensor can have a sequence length 0 at small input sequence lengths
-# even though we explicitly pad to avoid this.
-def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
-    tp_size = get_tensor_model_parallel_world_size()
-    tp_rank = get_tensor_model_parallel_rank()
-
-    # all_gather needs the sequence length to be divisible by tp_size
-    seq_len = x.size(0)
-    remainder = seq_len % tp_size
-    if remainder != 0:
-        pad_len = tp_size - remainder
-        x = nn.functional.pad(x, (0, 0, 0, pad_len))
-
-    chunk = x.shape[0] // tp_size
-    start = tp_rank * chunk
-    return torch.narrow(x, 0, start, chunk)
-
-
-def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
-    tp_size = get_tensor_model_parallel_world_size()
-    seq_len = cdiv(x.size(0), tp_size)
-    shape = list(x.shape)
-    shape[0] = seq_len
-    out = torch.empty(shape, dtype=x.dtype, device=x.device)
-    return out
-
-
-direct_register_custom_op(
-    op_name="sequence_parallel_chunk",
-    op_func=sequence_parallel_chunk,
-    fake_impl=sequence_parallel_chunk_fake,
-    tags=(torch.Tag.needs_fixed_stride_order, ),
-)
-
-
 class DeepseekV2MoE(nn.Module):

     def __init__(
@@ -166,20 +128,7 @@ class DeepseekV2MoE(nn.Module):
         self.n_routed_experts: int = config.n_routed_experts
         self.n_shared_experts: int = config.n_shared_experts

-        # The all_reduce at the end of attention (during o_proj) means that
-        # inputs are replicated across each rank of the tensor parallel group.
-        # If using expert-parallelism with DeepEP All2All ops, replicated
-        # tokens results in useless duplicate computation and communication.
-        #
-        # In this case, ensure the input to the experts is sequence parallel
-        # to avoid the excess work.
-        #
-        # Not needed for pplx-kernels as it can handle duplicate input tokens.
-        self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
-                                     in ("deepep_high_throughput",
-                                         "deepep_low_latency")
-                                     and parallel_config.enable_expert_parallel
-                                     and self.tp_size > 1)
+        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@@ -278,8 +227,7 @@ class DeepseekV2MoE(nn.Module):
         # TODO: We can replace the all_reduce at the end of attn with a
         # reduce_scatter instead of chunking here.
         if self.is_sequence_parallel:
-            hidden_states = torch.ops.vllm.sequence_parallel_chunk(
-                hidden_states)
+            hidden_states = sequence_parallel_chunk(hidden_states)

         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
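The helper removed here (now imported from vllm.model_executor.models.utils) pads the token dimension up to a multiple of the TP size and takes one contiguous slice per rank, so a later all_gather can undo the split. A toy reproduction of that math (standalone torch; tp_size and tp_rank are hard-coded for illustration):

    import torch
    import torch.nn as nn

    x = torch.randn(7, 4)      # 7 tokens, hidden size 4
    tp_size, tp_rank = 4, 2    # pretend we are rank 2 of 4

    remainder = x.size(0) % tp_size
    if remainder != 0:
        x = nn.functional.pad(x, (0, 0, 0, tp_size - remainder))  # pad tokens: 7 -> 8

    chunk = x.shape[0] // tp_size                     # 2 tokens per rank
    local = torch.narrow(x, 0, tp_rank * chunk, chunk)
    print(local.shape)                                # torch.Size([2, 4])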

View File

@@ -29,10 +29,9 @@ import torch
 import torch.nn as nn
 from transformers import PretrainedConfig

-from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -47,13 +46,11 @@ class ErnieMultiTokenPredictorLayer(nn.Module):

     def __init__(
         self,
-        config: PretrainedConfig,
+        vllm_config: VllmConfig,
         prefix: str,
-        model_config: ModelConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config

         self.mtp_emb_norm = RMSNorm(config.hidden_size,
                                     eps=config.rms_norm_eps)
@@ -62,8 +59,7 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
         self.mtp_linear_proj = nn.Linear(config.hidden_size * 2,
                                          config.hidden_size,
                                          bias=False)
-        self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config,
-                                           prefix)
+        self.mtp_block = LlamaDecoderLayer(vllm_config, prefix)

     def forward(
         self,
@@ -102,10 +98,8 @@ class ErnieMultiTokenPredictor(nn.Module):
         self.layers = torch.nn.ModuleDict({
             str(idx):
             ErnieMultiTokenPredictorLayer(
-                config,
+                vllm_config,
                 f"{prefix}.layers.{idx}",
-                model_config=vllm_config.model_config,
-                cache_config=vllm_config.cache_config,
             )
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)

View File

@@ -136,14 +136,16 @@ class Glm4Attention(nn.Module):

 class Glm4DecoderLayer(nn.Module):

-    def __init__(
-        self,
-        config: Glm4Config,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 config: Optional[Glm4Config] = None) -> None:
         super().__init__()
+        config = config or vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)

View File

@@ -13,7 +13,8 @@ from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_ep_group, get_pp_group,
                               get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -24,6 +25,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
 from vllm.utils import cdiv
@@ -132,12 +134,18 @@ class MLPBlock(torch.nn.Module):

     def __init__(
         self,
-        config: GptOssConfig,
+        vllm_config: VllmConfig,
         layer_idx: int,
-        quant_config: QuantizationConfig,
         prefix: str = "",
     ):
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config
+        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
+
         self.layer_idx = layer_idx
         self.num_experts = config.num_local_experts
         self.experts_per_token = config.num_experts_per_tok
@@ -155,11 +163,20 @@ class MLPBlock(torch.nn.Module):
                                 prefix=f"{prefix}.experts",
                                 apply_router_weight_on_input=False,
                                 has_bias=True,
-                                activation="swigluoai")
+                                activation="swigluoai",
+                                is_sequence_parallel=self.is_sequence_parallel)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        num_tokens = x.shape[0]
+        if self.is_sequence_parallel:
+            x = sequence_parallel_chunk(x)
+
         g = self.router(x)
         x = self.experts(hidden_states=x, router_logits=g)
+
+        if self.is_sequence_parallel:
+            x = tensor_model_parallel_all_gather(x.contiguous(), 0)
+            x = x[:num_tokens]
         return x

@@ -167,19 +184,20 @@ class TransformerBlock(torch.nn.Module):

     def __init__(
         self,
-        config: GptOssConfig,
-        cache_config: CacheConfig,
-        quant_config: QuantizationConfig,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
         self.layer_idx = extract_layer_index(prefix)
         self.attn = OAIAttention(config,
                                  prefix=f"{prefix}.attn",
                                  cache_config=cache_config)
-        self.mlp = MLPBlock(config,
+        self.mlp = MLPBlock(vllm_config,
                             self.layer_idx,
-                            quant_config=quant_config,
                             prefix=f"{prefix}.mlp")
         self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -216,8 +234,6 @@ class GptOssModel(nn.Module):
     ):
         super().__init__()
         self.config = vllm_config.model_config.hf_config
-        self.cache_config = vllm_config.cache_config
-        self.quant_config = vllm_config.quant_config
         self.parallel_config = vllm_config.parallel_config
         self.config.hidden_size = self.config.hidden_size
         self.embedding = VocabParallelEmbedding(
@@ -227,9 +243,7 @@ class GptOssModel(nn.Module):
         self.start_layer, self.end_layer, self.layers = make_layers(
             self.config.num_hidden_layers,
             lambda prefix: TransformerBlock(
-                self.config,
-                cache_config=self.cache_config,
-                quant_config=self.quant_config,
+                vllm_config,
                 prefix=prefix,
             ),
             prefix=f"{prefix}.layers",

View File

@@ -29,12 +29,13 @@ from typing import Any, Optional

 import torch
 from torch import nn
-from transformers.models.granitemoe import GraniteMoeConfig

 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -48,6 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -71,9 +73,11 @@ class GraniteMoeMoE(nn.Module):
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  tp_size: Optional[int] = None,
+                 is_sequence_parallel=False,
                  prefix: str = ""):
         super().__init__()
         self.hidden_size = hidden_size
+        self.is_sequence_parallel = is_sequence_parallel

         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(hidden_size,
@@ -92,15 +96,27 @@ class GraniteMoeMoE(nn.Module):
                                 renormalize=True,
                                 quant_config=quant_config,
                                 tp_size=tp_size,
-                                prefix=f"{prefix}.experts")
+                                prefix=f"{prefix}.experts",
+                                is_sequence_parallel=self.is_sequence_parallel)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
+
+        if self.is_sequence_parallel:
+            hidden_states = sequence_parallel_chunk(hidden_states)
+
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
+
+        if self.is_sequence_parallel:
+            final_hidden_states = tensor_model_parallel_all_gather(
+                final_hidden_states, 0)
+            num_tokens = orig_shape[0]
+            final_hidden_states = final_hidden_states[:num_tokens]
+
         return final_hidden_states.view(orig_shape)
@@ -191,12 +207,16 @@ class GraniteMoeDecoderLayer(nn.Module):

     def __init__(
         self,
-        config: GraniteMoeConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ) -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config
+
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -218,6 +238,7 @@ class GraniteMoeDecoderLayer(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
+            is_sequence_parallel=parallel_config.use_sequence_parallel_moe,
             prefix=f"{prefix}.block_sparse_moe")

         self.input_layernorm = RMSNorm(config.hidden_size,
@@ -255,7 +276,6 @@ class GraniteMoeModel(nn.Module):
         super().__init__()

         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config

@@ -275,9 +295,7 @@ class GraniteMoeModel(nn.Module):
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: GraniteMoeDecoderLayer(
-                config, cache_config, quant_config=quant_config, prefix=prefix
-            ),
+            lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix),
             prefix=f"{prefix}.layers")

         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@@ -68,6 +68,7 @@ class LlamaMLP(nn.Module):
         bias: bool = False,
         prefix: str = "",
         reduce_results: bool = True,
+        disable_tp: bool = False,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -75,6 +76,7 @@ class LlamaMLP(nn.Module):
             output_sizes=[intermediate_size] * 2,
             bias=bias,
             quant_config=quant_config,
+            disable_tp=disable_tp,
             prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
@@ -83,6 +85,7 @@ class LlamaMLP(nn.Module):
             bias=bias,
             quant_config=quant_config,
             reduce_results=reduce_results,
+            disable_tp=disable_tp,
             prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
@@ -237,14 +240,16 @@ class LlamaAttention(nn.Module):

 class LlamaDecoderLayer(nn.Module):

-    def __init__(
-        self,
-        config: LlamaConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 config: Optional[LlamaConfig] = None) -> None:
         super().__init__()
+        config = config or vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -335,7 +340,6 @@ class LlamaModel(nn.Module):
         super().__init__()

         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config

@@ -357,10 +361,7 @@ class LlamaModel(nn.Module):
             self.embed_tokens = PPMissingLayer()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: layer_type(config=config,
-                                      cache_config=cache_config,
-                                      quant_config=quant_config,
-                                      prefix=prefix),
+            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
             prefix=f"{prefix}.layers",
         )
         if get_pp_group().is_last_rank:

View File

@@ -28,7 +28,8 @@ from vllm.attention import Attention
 from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -39,6 +40,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.utils import sequence_parallel_chunk

 from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
 from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
@@ -59,13 +61,16 @@ class Llama4MoE(nn.Module):
         router_scores = torch.sigmoid(router_scores.float())
         return (router_scores, router_indices.to(torch.int32))

-    def __init__(self,
-                 config: Llama4TextConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        parallel_config = vllm_config.parallel_config
+        quant_config = vllm_config.quant_config
+
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
+        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(config.hidden_size,
@@ -82,6 +87,7 @@ class Llama4MoE(nn.Module):
             bias=False,
             prefix=f"{prefix}.shared_expert",
             reduce_results=False,
+            disable_tp=self.is_sequence_parallel,
         )

         self.experts = SharedFusedMoE(
@@ -96,9 +102,14 @@ class Llama4MoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
+            is_sequence_parallel=self.is_sequence_parallel,
         )

     def forward(self, hidden_states):
+        num_tokens = hidden_states.shape[0]
+        if self.is_sequence_parallel:
+            hidden_states = sequence_parallel_chunk(hidden_states)
+
         router_logits, _ = self.router(hidden_states)
         shared_out, routed_out = self.experts(
@@ -107,7 +118,10 @@ class Llama4MoE(nn.Module):
         )
         experts_out = routed_out + shared_out

-        if self.tp_size > 1:
+        if self.is_sequence_parallel:
+            experts_out = tensor_model_parallel_all_gather(experts_out, 0)
+            experts_out = experts_out[:num_tokens]
+        elif self.tp_size > 1:
             experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
                 experts_out)

@@ -257,15 +271,16 @@ class Llama4Attention(nn.Module):

 class Llama4DecoderLayer(nn.Module):

-    def __init__(
-        self,
-        config: Llama4TextConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 config: Optional[Llama4TextConfig] = None) -> None:
         super().__init__()
+        config = config or vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.layer_idx = extract_layer_index(prefix)
         self.global_layer = config.no_rope_layers[self.layer_idx] == 0
         self.hidden_size = config.hidden_size
@@ -291,8 +306,7 @@ class Llama4DecoderLayer(nn.Module):
             self.layer_idx + 1) % config.interleave_moe_layer_step == 0
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
-                config=config,
-                quant_config=quant_config,
+                vllm_config=vllm_config,
                 prefix=f"{prefix}.feed_forward",
             )
         else:

View File

@@ -68,9 +68,9 @@ class LlamaModel(nn.Module):
         self.layers = nn.ModuleList([
             Llama4DecoderLayer(
-                self.config,
-                quant_config=quant_config,
+                vllm_config=vllm_config,
                 prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+                config=self.config,
             ) for i in range(self.config.num_hidden_layers)
         ])
         self.fc = torch.nn.Linear(self.config.hidden_size * 2,

View File

@@ -28,11 +28,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer):

     def __init__(
         self,
-        config: LlamaConfig,
+        vllm_config: VllmConfig,
         disable_input_layernorm: bool,
         prefix: str = "",
+        config: Optional[LlamaConfig] = None,
     ) -> None:
-        super().__init__(config, prefix=prefix)
+        super().__init__(vllm_config, prefix=prefix, config=config)

         # Skip the input_layernorm
         # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
@@ -64,9 +65,10 @@ class LlamaModel(nn.Module):
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
-                self.config,
+                vllm_config,
                 i == 0,
                 prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+                config=self.config,
             ) for i in range(self.config.num_hidden_layers)
         ])
         self.fc = torch.nn.Linear(self.config.hidden_size * 2,

View File

@@ -9,13 +9,11 @@ import torch.nn as nn
 from transformers import LlamaConfig

 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -29,17 +27,14 @@ logger = init_logger(__name__)

 class LlamaDecoderLayer(LlamaDecoderLayer):

-    def __init__(
-        self,
-        config: LlamaConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__(config,
-                         cache_config=cache_config,
-                         quant_config=quant_config,
-                         prefix=prefix)
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 config: Optional[LlamaConfig] = None) -> None:
+        super().__init__(vllm_config, prefix=prefix, config=config)
+
+        config = config or vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config

         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
@@ -127,9 +122,9 @@ class LlamaModel(nn.Module):
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
-                config=self.config,
-                cache_config=current_vllm_config.cache_config,
+                current_vllm_config,
                 prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
+                config=self.config,
             )
         ])
         if hasattr(self.config, "target_hidden_size"):

View File

@@ -274,6 +274,8 @@ class Qwen2_5_VisionAttention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -300,25 +302,8 @@ class Qwen2_5_VisionAttention(nn.Module):
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.proj",
                                 disable_tp=use_data_parallel)
-
-        # Detect attention implementation.
-        self.attn_backend = get_vit_attn_backend(
-            head_size=self.hidden_size_per_attention_head,
-            dtype=torch.get_default_dtype())
-        self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
-
-        if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
-                _Backend.ROCM_AITER_FA
-        }:
-            raise RuntimeError(
-                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
-            )
+        self.attn_backend = attn_backend
+        self.use_upstream_fa = use_upstream_fa
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }
@@ -443,6 +428,8 @@ class Qwen2_5_VisionBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -455,7 +442,9 @@ class Qwen2_5_VisionBlock(nn.Module):
                                           projection_size=dim,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.attn",
-                                          use_data_parallel=use_data_parallel)
+                                          use_data_parallel=use_data_parallel,
+                                          attn_backend=attn_backend,
+                                          use_upstream_fa=use_upstream_fa)
         self.mlp = Qwen2_5_VisionMLP(dim,
                                      mlp_hidden_dim,
                                      act_fn=act_fn,
@@ -627,17 +616,35 @@ class Qwen2_5_VisionTransformer(nn.Module):
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

+        use_upstream_fa = False
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
+        }:
+            raise RuntimeError(
+                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
+            )
+
         self.blocks = nn.ModuleList([
-            Qwen2_5_VisionBlock(dim=self.hidden_size,
-                                num_heads=self.num_heads,
-                                mlp_hidden_dim=vision_config.intermediate_size,
-                                act_fn=get_act_and_mul_fn(
-                                    vision_config.hidden_act),
-                                norm_layer=norm_layer,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.blocks.{layer_idx}",
-                                use_data_parallel=use_data_parallel)
-            for layer_idx in range(depth)
+            Qwen2_5_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=get_act_and_mul_fn(vision_config.hidden_act),
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+                attn_backend=self.attn_backend,
+                use_upstream_fa=use_upstream_fa) for layer_idx in range(depth)
         ])
         self.merger = Qwen2_5_VisionPatchMerger(
             d_model=vision_config.out_hidden_size,
@@ -648,12 +655,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
             prefix=f"{prefix}.merger",
             use_data_parallel=use_data_parallel,
         )
-        self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype())
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN

     @property
     def dtype(self) -> torch.dtype:

View File

@@ -79,7 +79,7 @@ from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
 logger = init_logger(__name__)

 # For profile run
-_MAX_FRAMES_PER_VIDEO = 32
+_MAX_FRAMES_PER_VIDEO = 14

 # === Vision Inputs === #

@@ -932,6 +932,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         _, num_image_tokens = self._get_vision_info(
             image_width=image_width,
             image_height=image_height,
+            num_frames=1,
             image_processor=image_processor,
         )
         return num_image_tokens
@@ -956,6 +957,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         max_image_size, _ = self._get_vision_info(
             image_width=9999999,
             image_height=9999999,
+            num_frames=1,
             image_processor=None,
         )
         return max_image_size
@@ -969,10 +971,12 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
             image_processor=None,
         )

-    def _get_max_video_frames(self, max_tokens: int) -> int:
+    def _get_max_video_frames(self,
+                              max_tokens: int,
+                              start_num_frames: int = 1) -> int:
         target_width, target_height = self.get_image_size_with_most_features()

-        num_frames = 0
+        num_frames = start_num_frames

         while True:
             next_num_frames = num_frames + 1
@@ -994,12 +998,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
+        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
     ) -> int:
         max_videos = mm_counts.get("video", 0)

         max_total_frames = self._get_max_video_frames(seq_len)
         max_frames_per_video = min(max_total_frames // max(max_videos, 1),
-                                   _MAX_FRAMES_PER_VIDEO)
+                                   max_frames_per_video)

         return max(max_frames_per_video, 1)
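With the profiling cap lowered to 14 frames and the new max_frames_per_video argument, the per-video frame budget is the smaller of the split total budget and the cap. A toy sketch of that clamping (illustrative numbers only, not taken from a real profile run):

    _MAX_FRAMES_PER_VIDEO = 14

    def frames_per_video(max_total_frames: int, max_videos: int,
                         max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO) -> int:
        # split the total frame budget across videos, then clamp to the cap
        per_video = min(max_total_frames // max(max_videos, 1), max_frames_per_video)
        return max(per_video, 1)

    print(frames_per_video(max_total_frames=100, max_videos=2))  # 14 (cap wins)
    print(frames_per_video(max_total_frames=10, max_videos=2))   # 5  (budget wins)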

View File

@ -29,13 +29,13 @@ from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
from transformers import Qwen3MoeConfig
from vllm.attention import Attention from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group, from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
@ -101,12 +102,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen3MoeConfig, vllm_config: VllmConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
enable_eplb: bool = False,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_text_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
@@ -114,6 +118,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self.ep_size = self.ep_group.size()
         self.n_routed_experts = config.num_experts

+        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
+
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -122,7 +128,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         # Load balancing settings.
         vllm_config = get_current_vllm_config()
         eplb_config = vllm_config.parallel_config.eplb_config
-        self.enable_eplb = enable_eplb
+        self.enable_eplb = parallel_config.enable_eplb

         self.n_logical_experts = self.n_routed_experts
         self.n_redundant_experts = eplb_config.num_redundant_experts
@@ -144,7 +150,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
             enable_eplb=self.enable_eplb,
-            num_redundant_experts=self.n_redundant_experts)
+            num_redundant_experts=self.n_redundant_experts,
+            is_sequence_parallel=self.is_sequence_parallel)

         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
@@ -156,14 +163,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         assert hidden_states.dim(
         ) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
         is_input_1d = hidden_states.dim() == 1
-        hidden_dim = hidden_states.shape[-1]
+        num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

+        if self.is_sequence_parallel:
+            hidden_states = sequence_parallel_chunk(hidden_states)
+
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)

+        if self.is_sequence_parallel:
+            final_hidden_states = tensor_model_parallel_all_gather(
+                final_hidden_states, 0)
+            final_hidden_states = final_hidden_states[:num_tokens]
+
         # return to 1d if input is 1d
         return final_hidden_states.squeeze(0) if is_input_1d else \
             final_hidden_states
@@ -275,15 +290,13 @@ class Qwen3MoeAttention(nn.Module):

 class Qwen3MoeDecoderLayer(nn.Module):

-    def __init__(
-        self,
-        config: Qwen3MoeConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        enable_eplb: bool = False,
-    ) -> None:
+    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_text_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -315,10 +328,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
             (layer_idx + 1) % config.decoder_sparse_step == 0):
-            self.mlp = Qwen3MoeSparseMoeBlock(config=config,
-                                              quant_config=quant_config,
-                                              prefix=f"{prefix}.mlp",
-                                              enable_eplb=enable_eplb)
+            self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
+                                              prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
@@ -361,11 +372,9 @@ class Qwen3MoeModel(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

-        config = vllm_config.model_config.hf_config.get_text_config()
-        cache_config = vllm_config.cache_config
+        config = vllm_config.model_config.hf_text_config
         quant_config = vllm_config.quant_config
         parallel_config = vllm_config.parallel_config
-        enable_eplb = parallel_config.enable_eplb
         eplb_config = parallel_config.eplb_config
         self.num_redundant_experts = eplb_config.num_redundant_experts
@@ -379,11 +388,8 @@ class Qwen3MoeModel(nn.Module):
             prefix=f"{prefix}.embed_tokens")
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: Qwen3MoeDecoderLayer(config=config,
-                                                cache_config=cache_config,
-                                                quant_config=quant_config,
-                                                prefix=prefix,
-                                                enable_eplb=enable_eplb),
+            lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config,
+                                                prefix=prefix),
             prefix=f"{prefix}.layers",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -580,7 +586,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        config = vllm_config.model_config.hf_text_config
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
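
The sequence-parallel MoE path added above follows a chunk → route/compute → all-gather → trim pattern: each TP rank keeps only its slice of the (padded) token dimension, and `tensor_model_parallel_all_gather` along dim 0 plus a trim back to `num_tokens` restores the full sequence. A minimal single-process sketch of the same bookkeeping, with the collective replaced by `torch.cat` over per-rank chunks (hypothetical shapes; not the vLLM distributed path):

import torch

def simulate_sequence_parallel(hidden_states: torch.Tensor, tp_size: int):
    """Mimic chunk -> per-rank compute -> all_gather -> trim on one process."""
    num_tokens, hidden_dim = hidden_states.shape
    # Pad so the token count divides evenly across ranks (what
    # sequence_parallel_chunk does in the real code).
    pad = (-num_tokens) % tp_size
    padded = torch.nn.functional.pad(hidden_states, (0, 0, 0, pad))
    chunk = padded.shape[0] // tp_size
    # Each "rank" processes only its slice (identity used here as a
    # stand-in for the fused-MoE call).
    per_rank_out = [padded[r * chunk:(r + 1) * chunk] for r in range(tp_size)]
    # all_gather along dim 0 re-assembles the padded sequence ...
    gathered = torch.cat(per_rank_out, dim=0)
    # ... and trimming recovers the original token count.
    return gathered[:num_tokens]

x = torch.randn(7, 16)
assert torch.equal(simulate_sequence_parallel(x, tp_size=4), x)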

View File

@@ -17,7 +17,8 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
                          VllmConfig, get_current_vllm_config)
 from vllm.distributed import (divide, get_ep_group, get_pp_group,
                               get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fla.ops import (
@@ -47,6 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
+from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -69,14 +71,13 @@ KVCache = tuple[torch.Tensor, torch.Tensor]

 class Qwen3NextSparseMoeBlock(nn.Module):

-    def __init__(
-        self,
-        config: Qwen3NextConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        enable_eplb: bool = False,
-    ):
+    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        parallel_config = vllm_config.parallel_config
+        quant_config = vllm_config.quant_config
+
         self.tp_size = get_tensor_model_parallel_world_size()

         self.ep_group = get_ep_group().device_group
@@ -84,6 +85,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
         self.ep_size = self.ep_group.size()
         self.n_routed_experts = config.num_experts

+        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
+
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -92,7 +95,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
         # Load balancing settings.
         vllm_config = get_current_vllm_config()
         eplb_config = vllm_config.parallel_config.eplb_config
-        self.enable_eplb = enable_eplb
+        self.enable_eplb = parallel_config.enable_eplb

         self.n_logical_experts = self.n_routed_experts
         self.n_redundant_experts = eplb_config.num_redundant_experts
@@ -114,7 +117,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
             enable_eplb=self.enable_eplb,
-            num_redundant_experts=self.n_redundant_experts)
+            num_redundant_experts=self.n_redundant_experts,
+            is_sequence_parallel=self.is_sequence_parallel)

         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
@@ -141,9 +145,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
-        hidden_dim = hidden_states.shape[-1]
+        num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

+        if self.is_sequence_parallel:
+            hidden_states = sequence_parallel_chunk(hidden_states)
+
         shared_output = None
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
@@ -158,7 +165,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output

-        if self.tp_size > 1:
+        if self.is_sequence_parallel:
+            final_hidden_states = tensor_model_parallel_all_gather(
+                final_hidden_states, 0)
+            final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.tp_size > 1:
             final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                 final_hidden_states)

@@ -719,17 +731,17 @@ class Qwen3NextDecoderLayer(nn.Module):
     def __init__(
         self,
-        config: Qwen3NextConfig,
+        vllm_config: VllmConfig,
         layer_type: str,
-        model_config: Optional[ModelConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        speculative_config: Optional[SpeculativeConfig] = None,
         prefix: str = "",
-        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
-        self.config = config
+        config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        speculative_config = vllm_config.speculative_config
         self.layer_type = layer_type
         self.layer_idx = extract_layer_index(prefix)
@@ -759,10 +771,8 @@ class Qwen3NextDecoderLayer(nn.Module):
                 config.num_experts > 0 and
             (self.layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen3NextSparseMoeBlock(
-                config=config,
-                quant_config=quant_config,
+                vllm_config=vllm_config,
                 prefix=f"{prefix}.mlp",
-                enable_eplb=enable_eplb,
             )
         else:
             self.mlp = Qwen3NextMLP(
@@ -783,14 +793,14 @@ class Qwen3NextDecoderLayer(nn.Module):
             torch.zeros(
                 1,
                 1,
-                self.config.hidden_size,
+                config.hidden_size,
                 dtype=config.torch_dtype,
             ), )
         self.ffn_layer_scale = torch.nn.Parameter(
             torch.zeros(
                 1,
                 1,
-                self.config.hidden_size,
+                config.hidden_size,
                 dtype=config.torch_dtype,
             ), )
@@ -858,13 +868,8 @@ class Qwen3NextModel(nn.Module):
         super().__init__()

         config: Qwen3NextConfig = vllm_config.model_config.hf_config
-        model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
         parallel_config = vllm_config.parallel_config
         lora_config = vllm_config.lora_config
-        speculative_config = vllm_config.speculative_config
-        enable_eplb = parallel_config.enable_eplb
         eplb_config = parallel_config.eplb_config
         self.num_redundant_experts = eplb_config.num_redundant_experts
@@ -881,14 +886,9 @@ class Qwen3NextModel(nn.Module):
         def get_layer(prefix: str):
             return Qwen3NextDecoderLayer(
-                config,
+                vllm_config,
                 layer_type=config.layer_types[extract_layer_index(prefix)],
-                model_config=model_config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                speculative_config=speculative_config,
                 prefix=prefix,
-                enable_eplb=enable_eplb,
             )

         self.start_layer, self.end_layer, self.layers = make_layers(
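
A note on why the new `is_sequence_parallel` branch above uses an all-gather while the existing `tp_size > 1` branch keeps `maybe_all_reduce_tensor_model_parallel`: under sequence parallelism each rank holds a disjoint slice of tokens with complete hidden vectors, so the combine step concatenates along the token axis; under plain TP every rank holds all tokens but only a partial expert sum, so the combine step is an element-wise sum. A toy illustration (made-up tensors, no distributed setup):

import torch

# Two ranks' worth of toy MoE output for 4 tokens, hidden size 8.
full = torch.randn(4, 8)

# Sequence-parallel: each rank owns a disjoint slice of tokens, so the
# combine is a concatenation (all_gather along dim 0) plus a trim.
sp_rank_outputs = [full[:2], full[2:]]
assert torch.equal(torch.cat(sp_rank_outputs, dim=0), full)

# Plain tensor-parallel MoE: every rank sees all tokens but only a partial
# sum over its local experts, so the combine is an element-wise
# all-reduce (sum) instead.
tp_rank_outputs = [0.25 * full, 0.75 * full]
assert torch.allclose(sum(tp_rank_outputs), full)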

View File

@@ -38,7 +38,6 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
         super().__init__()

         model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         config: Qwen3NextConfig = model_config.hf_config
@@ -68,11 +67,8 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
         self.layers = torch.nn.ModuleList(
             Qwen3NextDecoderLayer(
-                config,
+                vllm_config,
                 layer_type="full_attention",
-                model_config=model_config,
-                cache_config=cache_config,
-                quant_config=quant_config,
                 prefix=f'{prefix}.layers.{idx}',
             ) for idx in range(self.num_mtp_layers))

View File

@@ -33,11 +33,14 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import BatchFeature
 from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
-from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
+    smart_resize as image_smart_resize)
 from transformers.models.qwen3_vl import (Qwen3VLProcessor,
                                           Qwen3VLVideoProcessor)
 from transformers.models.qwen3_vl.configuration_qwen3_vl import (
     Qwen3VLConfig, Qwen3VLVisionConfig)
+from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
+    smart_resize as video_smart_resize)
 from transformers.video_utils import VideoMetadata

 from vllm.attention.layer import check_upstream_fa_availability
@@ -63,7 +66,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
+from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
 from vllm.utils import is_list_of
@@ -84,6 +87,9 @@ from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model

 logger = init_logger(__name__)

+# Official recommended max pixels is 24576 * 32 * 32
+_MAX_FRAMES_PER_VIDEO = 24576
+

 class Qwen3_VisionPatchEmbed(nn.Module):
@@ -158,6 +164,8 @@ class Qwen3_VisionBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -170,7 +178,9 @@ class Qwen3_VisionBlock(nn.Module):
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            use_data_parallel=use_data_parallel)
+            use_data_parallel=use_data_parallel,
+            attn_backend=attn_backend,
+            use_upstream_fa=use_upstream_fa)
         self.mlp = Qwen3_VisionMLP(dim,
                                    mlp_hidden_dim,
                                    act_fn=act_fn,
@@ -287,19 +297,6 @@ class Qwen3_VisionTransformer(nn.Module):
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

-        self.blocks = nn.ModuleList([
-            Qwen3_VisionBlock(
-                dim=self.hidden_size,
-                num_heads=self.num_heads,
-                mlp_hidden_dim=vision_config.intermediate_size,
-                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
-                norm_layer=norm_layer,
-                quant_config=quant_config,
-                prefix=f"{prefix}.blocks.{layer_idx}",
-                use_data_parallel=use_data_parallel)
-            for layer_idx in range(vision_config.depth)
-        ])
-
         self.merger = Qwen3_VisionPatchMerger(
             d_model=vision_config.out_hidden_size,
             context_dim=self.hidden_size,
@@ -325,10 +322,42 @@ class Qwen3_VisionTransformer(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim, dtype=torch.get_default_dtype())
+        use_upstream_fa = False
         if self.attn_backend != _Backend.FLASH_ATTN and \
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
+        }:
+            raise RuntimeError(
+                f"Qwen3-VL does not support {self.attn_backend} backend now.")
+        if current_platform.is_device_capability(
+                100) and self.attn_backend != _Backend.TORCH_SDPA:
+            # TODO(Roger/Wentao): remove this after FA
+            # or XFORMERS's issue fixed on Blackwell
+            logger.info_once("Qwen3-VL vision attention does not support "
+                             f"{self.attn_backend} backend on Blackwell now. "
+                             "Vision attention backend is set to TORCH_SDPA.")
+            self.attn_backend = _Backend.TORCH_SDPA
+
+        self.blocks = nn.ModuleList([
+            Qwen3_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+                attn_backend=self.attn_backend,
+                use_upstream_fa=use_upstream_fa)
+            for layer_idx in range(vision_config.depth)
+        ])

     @property
     def dtype(self) -> torch.dtype:
@@ -569,11 +598,16 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
         image_height: int,
         num_frames: int = 2,
         do_resize: bool = True,
-        image_processor: Optional[Qwen2VLImageProcessorFast],
+        image_processor: Optional[Union[Qwen2VLImageProcessorFast,
+                                        Qwen3VLVideoProcessor]],
     ) -> tuple[ImageSize, int]:
-        if image_processor is None:
+        if image_processor is None and num_frames > 1:
+            image_processor = self.get_video_processor()
+        elif image_processor is None:
             image_processor = self.get_image_processor()
+        is_video = isinstance(image_processor, Qwen3VLVideoProcessor)

         hf_config = self.get_hf_config()
         vision_config = hf_config.vision_config
         patch_size = vision_config.patch_size
@@ -581,12 +615,22 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
         temporal_patch_size = vision_config.temporal_patch_size

         if do_resize:
+            if is_video:
+                smart_resize = video_smart_resize
+                extra_kwargs = {
+                    "num_frames": num_frames,
+                    "temporal_factor": temporal_patch_size
+                }
+            else:
+                smart_resize = image_smart_resize
+                extra_kwargs = {}
             resized_height, resized_width = smart_resize(
                 height=image_height,
                 width=image_width,
                 factor=patch_size * merge_size,
                 min_pixels=image_processor.size["shortest_edge"],
                 max_pixels=image_processor.size["longest_edge"],
+                **extra_kwargs,
             )
             preprocessed_size = ImageSize(width=resized_width,
                                           height=resized_height)
@@ -605,6 +649,39 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):

         return preprocessed_size, num_vision_tokens

+    def _get_max_video_frames(self,
+                              max_tokens: int,
+                              start_num_frames: int = 2) -> int:
+        return super()._get_max_video_frames(max_tokens,
+                                             start_num_frames=start_num_frames)
+
+    def get_num_frames_with_most_features(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        return super().get_num_frames_with_most_features(
+            seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO)
+
+    def get_max_video_tokens(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+        video_soft_tokens = self.get_num_video_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=self.get_num_frames_with_most_features(
+                seq_len, mm_counts),
+            image_processor=None,
+        )
+
+        # NOTE: By default in Qwen3-VL, one video token is converted to
+        # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token  # noqa: E501
+        formatted_video_soft_tokens = video_soft_tokens * 12.5
+        return int(formatted_video_soft_tokens)
+
     def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
                               video_fps: float, merge_size: int):
         if not isinstance(indices, list):
@@ -674,6 +751,12 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
             self.info.get_image_size_with_most_features())
         target_num_frames = self.info.get_num_frames_with_most_features(
             seq_len, mm_counts)
+        target_video_size, _ = self.info._get_vision_info(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=target_num_frames,
+            image_processor=self.info.get_video_processor(),
+        )

         return {
             "image":
             self._get_dummy_images(width=target_width,
@@ -681,8 +764,8 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
                                    num_images=num_images),
             "video":
             self._get_dummy_videos(
-                width=target_width,
-                height=target_height,
+                width=target_video_size.width,
+                height=target_video_size.height,
                 num_frames=target_num_frames,
                 num_videos=num_videos,
             ),
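
The new `get_max_video_tokens` above budgets prompt tokens for a video by scaling the raw vision token count by 12.5, i.e. the comment's average 9.5 text tokens for "<{timestamp} seconds>" plus the vision start, video, and end tokens (9.5 + 3 = 12.5). A worked example of that arithmetic with an assumed, hypothetical soft-token count:

# Hypothetical numbers, only to illustrate the 12.5x budgeting factor above.
video_soft_tokens = 6144            # raw vision tokens for the profiled video
per_token_overhead = 9.5 + 3        # "<{t} seconds>" text (avg 9.5) + start/video/end tokens
max_video_tokens = int(video_soft_tokens * per_token_overhead)
assert per_token_overhead == 12.5
assert max_video_tokens == 76800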

View File

@@ -212,6 +212,8 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
                     # attempted to load as other weights later
                     is_expert_weight = True
                     name_mapped = name.replace(weight_name, param_name)
+                    if is_pp_missing_parameter(name_mapped, self):
+                        continue
                     if is_fused_expert:
                         loaded_weight = loaded_weight.transpose(-1,
                                                                 -2)  # no bias
@@ -230,8 +232,6 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
                                 name_mapped, params_dict, loaded_weight,
                                 shard_id, num_experts)
                     else:
-                        if is_pp_missing_parameter(name_mapped, self):
-                            continue
                         # Skip loading extra parameters for GPTQ/modelopt models
                         if name_mapped.endswith(
                                 ignore_suffixes

View File

@@ -13,11 +13,14 @@ from transformers import PretrainedConfig

 import vllm.envs as envs
 from vllm.config import VllmConfig
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import NestedTensors
 from vllm.sequence import IntermediateTensors
-from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
+from vllm.utils import (cdiv, direct_register_custom_op,
+                        get_cuda_view_from_cpu_tensor, is_pin_memory_available,
                         is_uva_available)

 logger = init_logger(__name__)
@@ -760,3 +763,46 @@ def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
         return hf_config.hidden_size
     text_config = hf_config.get_text_config()
     return text_config.hidden_size
+
+
+# Chunk x along the num_tokens axis for sequence parallelism
+# NOTE: This is wrapped in a torch custom op to work around the following issue:
+# The output tensor can have a sequence length 0 at small input sequence lengths
+# even though we explicitly pad to avoid this.
+def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
+    return torch.ops.vllm.sequence_parallel_chunk_impl(x)
+
+
+def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    tp_rank = get_tensor_model_parallel_rank()
+
+    # all_gather needs the sequence length to be divisible by tp_size
+    seq_len = x.size(0)
+    remainder = seq_len % tp_size
+    if remainder != 0:
+        pad_len = tp_size - remainder
+        y = nn.functional.pad(x, (0, 0, 0, pad_len))
+    else:
+        y = x
+
+    chunk = y.shape[0] // tp_size
+    start = tp_rank * chunk
+    return torch.narrow(y, 0, start, chunk)
+
+
+def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    seq_len = cdiv(x.size(0), tp_size)
+    shape = list(x.shape)
+    shape[0] = seq_len
+    out = torch.empty(shape, dtype=x.dtype, device=x.device)
+    return out
+
+
+direct_register_custom_op(
+    op_name="sequence_parallel_chunk_impl",
+    op_func=sequence_parallel_chunk_impl,
+    fake_impl=sequence_parallel_chunk_impl_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
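
`sequence_parallel_chunk` above pads the token dimension up to a multiple of `tp_size` and returns this rank's contiguous slice, while the fake impl reports the chunk length as `cdiv(seq_len, tp_size)` so compilation sees a consistent shape. A small sketch of the shape arithmetic it guarantees (assumed values, no distributed initialization, hypothetical helper name):

import math

def chunk_bounds(seq_len: int, tp_size: int, tp_rank: int) -> tuple[int, int]:
    """Start/length of the slice a rank receives after padding, mirroring
    the arithmetic in sequence_parallel_chunk_impl."""
    pad_len = (-seq_len) % tp_size
    chunk = (seq_len + pad_len) // tp_size      # == cdiv(seq_len, tp_size)
    return tp_rank * chunk, chunk

# seq_len=10, tp_size=4 -> padded to 12, so every rank gets exactly 3 rows.
for rank in range(4):
    start, length = chunk_bounds(10, 4, rank)
    assert length == math.ceil(10 / 4) == 3
    assert start == rank * 3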

View File

@@ -50,6 +50,7 @@ class MediaConnector:
         connection: HTTPConnection = global_http_connection,
         *,
         allowed_local_media_path: str = "",
+        allowed_media_domains: Optional[list[str]] = None,
     ) -> None:
         """
         Args:
@@ -82,6 +83,9 @@ class MediaConnector:
             allowed_local_media_path_ = None

         self.allowed_local_media_path = allowed_local_media_path_
+        if allowed_media_domains is None:
+            allowed_media_domains = []
+        self.allowed_media_domains = allowed_media_domains

     def _load_data_url(
         self,
@@ -115,6 +119,14 @@ class MediaConnector:

         return media_io.load_file(filepath)

+    def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
+        if self.allowed_media_domains and url_spec.hostname not in \
+            self.allowed_media_domains:
+            raise ValueError(
+                f"The URL must be from one of the allowed domains: "
+                f"{self.allowed_media_domains}. Input URL domain: "
+                f"{url_spec.hostname}")
+
     def load_from_url(
         self,
         url: str,
@@ -125,6 +137,8 @@ class MediaConnector:
         url_spec = urlparse(url)

         if url_spec.scheme.startswith("http"):
+            self._assert_url_in_allowed_media_domains(url_spec)
+
             connection = self.connection
             data = connection.get_bytes(url, timeout=fetch_timeout)

@@ -150,6 +164,8 @@ class MediaConnector:
         loop = asyncio.get_running_loop()

         if url_spec.scheme.startswith("http"):
+            self._assert_url_in_allowed_media_domains(url_spec)
+
             connection = self.connection
             data = await connection.async_get_bytes(url, timeout=fetch_timeout)
             future = loop.run_in_executor(global_thread_pool,
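
The new `allowed_media_domains` check above is an exact hostname match against the parsed URL, and an empty or unset list keeps the previous permissive behaviour. A standalone sketch of the same matching rule (hypothetical helper and domain values, not the MediaConnector API itself):

from urllib.parse import urlparse

def check_domain(url: str, allowed_media_domains: list[str]) -> None:
    hostname = urlparse(url).hostname
    if allowed_media_domains and hostname not in allowed_media_domains:
        raise ValueError(
            f"The URL must be from one of the allowed domains: "
            f"{allowed_media_domains}. Input URL domain: {hostname}")

check_domain("https://example.com/cat.png", ["example.com"])   # accepted
check_domain("https://example.com/cat.png", [])                # accepted: empty list disables the check
try:
    check_domain("https://evil.test/cat.png", ["example.com"])
except ValueError:
    pass  # rejected: hostname is not in the allow-list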

View File

@@ -1288,4 +1288,9 @@ class Scheduler(SchedulerInterface):
             self.finished_recving_kv_req_ids.add(req_id)
         for req_id in (kv_connector_output.finished_sending or ()):
             logger.debug("Finished sending KV transfer for request %s", req_id)
-            self._free_blocks(self.requests[req_id])
+            if req_id not in self.requests:
+                logger.warning(
+                    "Got finished sending KV transfer for request %s, "
+                    "but the request is already freed.", req_id)
+            else:
+                self._free_blocks(self.requests[req_id])

View File

@@ -3351,6 +3351,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 expected_num_items=max_mm_items_per_batch,
             )

+            # NOTE: This happens when encoder cache needs to store
+            # the embeddings that encoder outputs are scattered onto.
+            # In this case we create dummy embeddings of size
+            # (encode_budget, hidden_size) and scatter encoder
+            # output into it.
+            encoder_output_shape = dummy_encoder_outputs[0].shape
+            if encoder_output_shape[0] < encoder_budget:
+                expanded_outputs = []
+                for output in dummy_encoder_outputs:
+                    expanded = output.new_zeros(
+                        (encoder_budget, encoder_output_shape[-1]))
+                    num_tokens = output.shape[0]
+                    expanded[:num_tokens].copy_(output)
+                    expanded_outputs.append(expanded)
+                dummy_encoder_outputs = expanded_outputs
+
             # Cache the dummy encoder outputs.
             self.encoder_cache["tmp"] = dict(
                 enumerate(dummy_encoder_outputs))
@@ -3468,8 +3485,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # We skip EPLB here since we don't want to record dummy metrics
             for num_tokens in compilation_cases:
                 # We currently only capture ubatched graphs when its a FULL
-                # cudagraph and for uniform decode batches.
-                capture_ubatched_graph = self.parallel_config.enable_dbo \
+                # cudagraph, a uniform decode batch, and the number of tokens
+                # is above the threshold. Otherwise we just capture a non-ubatched
+                # version of the graph
+                allow_microbatching = self.parallel_config.enable_dbo \
                     and cudagraph_runtime_mode == CUDAGraphMode.FULL \
                     and uniform_decode \
                     and check_ubatch_thresholds(
@@ -3478,37 +3497,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     uniform_decode=uniform_decode,
                 )

-                # Currently we capture both microbatched and non-microbatched
-                # graphs when capture_ubatched_graph is True, this is because
-                # occasionally we will be forced out of microbatching due to other
-                # DP ranks not microbatching (usually caused by an empty second
-                # microbatch; once we resolve this, we can remove the
-                # non-microbatched graph capture).
-                allow_microbatching_options = [True, False] if \
-                    capture_ubatched_graph else [False]
-                for allow_microbatching in allow_microbatching_options:
-                    for _ in range(
-                            self.compilation_config.cudagraph_num_of_warmups):
-                        # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
-                        # But be careful, warm up with `NONE`is orthogonal to
-                        # if we want to warm up attention or not. This is
-                        # different from the case where `FULL` implies capture
-                        # attention while `PIECEWISE` implies no attention.
-                        force_attention = (
-                            cudagraph_runtime_mode == CUDAGraphMode.FULL)
-                        self._dummy_run(num_tokens,
-                                        cudagraph_runtime_mode=CUDAGraphMode.NONE,
-                                        force_attention=force_attention,
-                                        uniform_decode=uniform_decode,
-                                        allow_microbatching=allow_microbatching,
-                                        skip_eplb=True,
-                                        remove_lora=False)
-                    self._dummy_run(num_tokens,
-                                    cudagraph_runtime_mode=cudagraph_runtime_mode,
-                                    uniform_decode=uniform_decode,
-                                    allow_microbatching=allow_microbatching,
-                                    skip_eplb=True,
-                                    remove_lora=False)
+                for _ in range(self.compilation_config.cudagraph_num_of_warmups):
+                    # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
+                    # But be careful, warm up with `NONE`is orthogonal to
+                    # if we want to warm up attention or not. This is
+                    # different from the case where `FULL` implies capture
+                    # attention while `PIECEWISE` implies no attention.
+                    force_attention = (
+                        cudagraph_runtime_mode == CUDAGraphMode.FULL)
+                    self._dummy_run(num_tokens,
+                                    cudagraph_runtime_mode=CUDAGraphMode.NONE,
+                                    force_attention=force_attention,
+                                    uniform_decode=uniform_decode,
+                                    allow_microbatching=allow_microbatching,
+                                    skip_eplb=True,
+                                    remove_lora=False)
+                self._dummy_run(num_tokens,
+                                cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                uniform_decode=uniform_decode,
+                                allow_microbatching=allow_microbatching,
+                                skip_eplb=True,
+                                remove_lora=False)

         self.maybe_remove_all_loras(self.lora_config)

     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
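
The profiling change above pads each dummy encoder output up to (encoder_budget, hidden_size) so the encoder cache is sized for the case where encoder outputs are scattered into a larger embedding buffer. A tiny standalone sketch of that expansion with made-up sizes:

import torch

encoder_budget, hidden_size = 16, 8
dummy_encoder_outputs = [torch.randn(5, hidden_size), torch.randn(9, hidden_size)]

expanded_outputs = []
for output in dummy_encoder_outputs:
    expanded = output.new_zeros((encoder_budget, hidden_size))
    expanded[:output.shape[0]].copy_(output)   # place the real tokens at the front
    expanded_outputs.append(expanded)

assert all(o.shape == (encoder_budget, hidden_size) for o in expanded_outputs)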

View File

@@ -330,6 +330,18 @@ class UBatchWrapper:

         # If there's no ubatching, just run the runnable object
         if ubatch_slices is None:
+
+            # This is to account for the case where ubatching was aborted.
+            # When we capture full graphs we only capture one graph per shape,
+            # meaning that if we have a ubatched cudagraph for the current
+            # num_tokens, we don't have a non-ubatched one. Without this
+            # check, the cudagraph wrapper will try to capture a cudagraph
+            # for this shape during a normal run.
+            if cudagraph_runtime_mode is CUDAGraphMode.FULL:
+                assert batch_descriptor is not None
+                if batch_descriptor.num_tokens in self.cudagraphs:
+                    cudagraph_runtime_mode = CUDAGraphMode.NONE
+
             if cudagraph_runtime_mode in (CUDAGraphMode.NONE,
                                           CUDAGraphMode.PIECEWISE):
                 return self.runnable(*args, **kwargs)
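
The new guard above downgrades the cudagraph mode when microbatching was aborted but a FULL graph was already captured for this token count, so the wrapper does not try to capture a second, non-ubatched graph for the same shape. A minimal sketch of that decision, with `CUDAGraphMode` stubbed as an enum and the captured-graph table assumed to be a dict keyed by token count:

from enum import Enum

class CUDAGraphMode(Enum):   # stand-in for vLLM's enum, for illustration only
    NONE = 0
    PIECEWISE = 1
    FULL = 2

def resolve_mode(mode: CUDAGraphMode, num_tokens: int,
                 captured_ubatched_graphs: dict[int, object]) -> CUDAGraphMode:
    # Ubatching was aborted: if a ubatched FULL graph already exists for this
    # shape, run eagerly instead of capturing a duplicate non-ubatched graph.
    if mode is CUDAGraphMode.FULL and num_tokens in captured_ubatched_graphs:
        return CUDAGraphMode.NONE
    return mode

assert resolve_mode(CUDAGraphMode.FULL, 256, {256: object()}) is CUDAGraphMode.NONE
assert resolve_mode(CUDAGraphMode.FULL, 128, {256: object()}) is CUDAGraphMode.FULL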