mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00

Compare commits: zhuohan/mo...v0.11.0rc3 (13 commits)

Commits (SHA1):
- 8ce5d3198d
- 09c2cbc04a
- 4c347044c9
- 19e7ab7315
- 6de3d431d9
- b14773bd64
- 26a7a33b88
- 5aa5811a16
- c2fa2d4dc9
- 32335c8b34
- 04c2b26972
- ee10d7e6ff
- bb79c4da2f
@@ -76,7 +76,7 @@ steps:
queue: arm64_cpu_queue_postmerge
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ."
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)"

# Add job to create multi-arch manifest
@@ -404,6 +404,9 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
fi
echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
# HACK: We need these to run flashinfer.aot before installing flashinfer, get from the package in the future
uv pip install --system cuda-python==$(echo $CUDA_VERSION | cut -d. -f1,2) pynvml==$(echo $CUDA_VERSION | cut -d. -f1) nvidia-nvshmem-cu$(echo $CUDA_VERSION | cut -d. -f1)
# Build AOT kernels
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
python3 -m flashinfer.aot
@@ -6,6 +6,10 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup
We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes,
and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests.

!!! tip
    When serving multi-modal models, consider setting `--allowed-media-domains` to restrict the domains that vLLM can access, so it cannot be used to reach arbitrary endpoints via Server-Side Request Forgery (SSRF). The argument accepts a list of domains, for example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`
    This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.
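
The same restriction is available for offline inference through the `allowed_media_domains` argument that this change adds to the `LLM` class. A minimal sketch (the model name is only an example):

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    # Only media URLs from these domains will be fetched.
    allowed_media_domains=["upload.wikimedia.org", "github.com"],
)
```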
## Offline Inference

To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:
@@ -60,6 +60,12 @@ Key points from the PyTorch security guide:
- Implement proper authentication and authorization for management interfaces
- Follow the principle of least privilege for all system components

### 4. **Restrict Domain Access for Media URLs:**

Restrict the domains that vLLM can access for media URLs by setting
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`).
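
Media URL fetching goes through `MediaConnector`, which raises a `ValueError` for any URL whose domain is not in the allowlist (the behavior exercised by the tests in this change). A minimal sketch, assuming `MediaConnector` is imported from `vllm.multimodal.utils`:

```python
from vllm.multimodal.utils import MediaConnector

connector = MediaConnector(
    allowed_media_domains=["upload.wikimedia.org", "github.com"])

# Allowed domain: the image is fetched as usual.
image = connector.fetch_image(
    "https://upload.wikimedia.org/wikipedia/commons/4/47/"
    "PNG_transparency_demonstration_1.png")

# Any other domain is rejected with a ValueError.
connector.fetch_image("https://example.com/image.png")
```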
## Security and Firewalls: Protecting Exposed vLLM Systems

While vLLM is designed to allow unsafe network services to be isolated to
@@ -45,6 +45,7 @@ class MockModelConfig:
logits_processor_pattern: Optional[str] = None
diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""
allowed_media_domains: Optional[list[str]] = None
encoder_config = None
generation_config: str = "auto"
skip_tokenizer_init: bool = False
@@ -240,6 +240,7 @@ class MockModelConfig:
logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""
allowed_media_domains: Optional[list[str]] = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
@@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
parse_chat_messages,
parse_chat_messages_futures,
resolve_chat_template_content_format,
resolve_chat_template_kwargs,
resolve_hf_chat_template)
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
@@ -37,6 +38,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
QWEN3_MODEL_ID = "Qwen/Qwen3-8B"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
@@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
assert isinstance(chat_template, str)


@pytest.mark.parametrize(
"model, expected_kwargs",
[
(
QWEN2VL_MODEL_ID,
{
"add_vision_id", "add_generation_prompt",
"continue_final_message", "tools"
},
),
(
QWEN3_MODEL_ID,
{
"enable_thinking", "add_generation_prompt",
"continue_final_message", "tools"
},
),
],
)
def test_resolve_hf_chat_template_kwargs(sample_json_schema, model,
expected_kwargs):
"""Check that resolve_chat_template_kwargs keeps only recognized kwargs."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")

tools = ([{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema,
},
}])

chat_template_kwargs = {
# both unused
"unsed_kwargs_1": 123,
"unsed_kwargs_2": "abc",
# should not appear
"chat_template": "{% Hello world! %}",
# used by tokenizer
"continue_final_message": True,
"tools": tools,
# both used by Qwen2-VL and Qwen3
"add_generation_prompt": True,
# only used by Qwen2-VL
"add_vision_id": True,
# only used by Qwen3
"enable_thinking": True,
}

model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)

# Build the tokenizer
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)

# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=tools,
model_config=model_config,
)
resolved_chat_template_kwargs = resolve_chat_template_kwargs(
tokenizer,
chat_template=chat_template,
chat_template_kwargs=chat_template_kwargs,
)
assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs


# NOTE: Qwen2-Audio default chat template is specially defined inside
# processor class instead of using `tokenizer_config.json`
# yapf: disable
@@ -66,7 +66,12 @@ async def test_fetch_image_http(image_url: str):
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: dict[str, Image.Image],
raw_image_url: str, suffix: str):
connector = MediaConnector()
connector = MediaConnector(
# Domain restriction should not apply to data URLs.
allowed_media_domains=[
"www.bogotobogo.com",
"github.com",
])
url_image = url_images[raw_image_url]

try:
@@ -387,3 +392,29 @@ def test_argsort_mm_positions(case):
modality_idxs = argsort_mm_positions(mm_positions)

assert modality_idxs == expected_modality_idxs


@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_allowed_media_domains(video_url: str, num_frames: int):
connector = MediaConnector(
media_io_kwargs={"video": {
"num_frames": num_frames,
}},
allowed_media_domains=[
"www.bogotobogo.com",
"github.com",
])

video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(video_url)
assert np.array_equal(video_sync, video_async)
assert metadata_sync == metadata_async

disallowed_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
with pytest.raises(ValueError):
_, _ = connector.fetch_video(disallowed_url)

with pytest.raises(ValueError):
_, _ = await connector.fetch_video_async(disallowed_url)
@@ -137,6 +137,9 @@ class ModelConfig:
"""Allowing API requests to read local images or videos from directories
specified by the server file system. This is a security risk. Should only
be enabled in trusted environments."""
allowed_media_domains: Optional[list[str]] = None
"""If set, only media URLs that belong to one of these domains can be used
for multi-modal inputs."""
revision: Optional[str] = None
"""The specific model version to use. It can be a branch name, a tag name,
or a commit id. If unspecified, will use the default version."""
@@ -279,6 +279,24 @@ class ParallelConfig:
assert last_exc is not None
raise last_exc

# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens result in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
return (envs.VLLM_ALL2ALL_BACKEND
in ("allgather_reducescatter", "naive",
"deepep_high_throughput", "deepep_low_latency")
and self.enable_expert_parallel
and self.tensor_parallel_size > 1
and self.data_parallel_size > 1)

@staticmethod
def has_unfinished_dp(dp_group: ProcessGroup,
has_unfinished: bool) -> bool:
@@ -281,6 +281,8 @@ class SpeculativeConfig:
trust_remote_code,
allowed_local_media_path=self.target_model_config.
allowed_local_media_path,
allowed_media_domains=self.target_model_config.
allowed_media_domains,
dtype=self.target_model_config.dtype,
seed=self.target_model_config.seed,
revision=self.revision,
@ -6,7 +6,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group
|
||||
from vllm.distributed import get_dp_group, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import has_deep_ep, has_pplx
|
||||
@ -34,41 +34,60 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
cu_tokens_across_sp_cpu: torch.Tensor,
|
||||
is_sequence_parallel: bool) -> torch.Tensor:
|
||||
assert (len(x.shape) == 2)
|
||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
||||
buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
world_size = (self.world_size
|
||||
if is_sequence_parallel else self.dp_world_size)
|
||||
|
||||
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
|
||||
end = cu_tokens_across_sp_cpu[rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
for idx in range(self.dp_world_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
||||
end = cu_tokens_across_dp_cpu[idx]
|
||||
self.dp_group.broadcast(buffer[start:end, :], idx)
|
||||
for idx in range(world_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
|
||||
end = cu_tokens_across_sp_cpu[idx]
|
||||
get_ep_group().broadcast(buffer[start:end, :], idx)
|
||||
|
||||
return buffer
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states, router_logits = get_dp_group().all_gatherv(
|
||||
[hidden_states, router_logits],
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
hidden_states = self.naive_multicast(hidden_states,
|
||||
cu_tokens_across_sp_cpu,
|
||||
is_sequence_parallel)
|
||||
router_logits = self.naive_multicast(router_logits,
|
||||
cu_tokens_across_sp_cpu,
|
||||
is_sequence_parallel)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
|
||||
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
|
||||
end = cu_tokens_across_sp_cpu[ep_rank]
|
||||
|
||||
all_hidden_states = get_ep_group().all_reduce(hidden_states)
|
||||
hidden_states = all_hidden_states[start:end, :]
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Gather hidden_states and router_logits from all dp ranks.
|
||||
"""
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states, router_logits = get_dp_group().all_gatherv(
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||
hidden_states, router_logits = dist_group.all_gatherv(
|
||||
[hidden_states, router_logits],
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Reduce-scatter hidden_states across all dp ranks.
|
||||
"""
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
hidden_states = dist_group.reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
@ -148,11 +178,17 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
kwargs, pplx.AllToAll.internode
|
||||
if self.internode else pplx.AllToAll.intranode)
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
@ -184,11 +220,17 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
def get_handle(self, kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
@ -395,4 +437,4 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
self.workspace_tensor = None
|
||||
self.prepare_workspace_tensor = None
|
||||
self.mapping = None
|
||||
self.initialized = False
|
||||
self.initialized = False
|
||||
|
@ -28,6 +28,8 @@ class Cache:
|
||||
|
||||
|
||||
class All2AllManagerBase:
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
self.cpu_group = cpu_group
|
||||
@ -40,6 +42,7 @@ class All2AllManagerBase:
|
||||
# all2all lives in ep group, which is merged from dp and tp group
|
||||
self.dp_group = get_dp_group()
|
||||
self.tp_group = get_tp_group()
|
||||
|
||||
# no self.ep_group since self.ep_group is still in construction
|
||||
# when we create this object
|
||||
self.dp_rank = self.dp_group.rank_in_group
|
||||
@ -60,17 +63,21 @@ class All2AllManagerBase:
|
||||
# and reuse it for the same config.
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
pass
|
||||
|
||||
def max_sms_used(self) -> Optional[int]:
|
||||
return None # None means it could use the whole GPU
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
@ -267,15 +274,20 @@ class DeviceCommunicatorBase:
|
||||
module.quant_method.init_prepare_finalize(module)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
|
@ -39,10 +39,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
|
||||
|
||||
# ep does not use pynccl
|
||||
use_pynccl = "ep" not in unique_name
|
||||
|
||||
self.use_pynccl = use_pynccl
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
self.use_torch_symm_mem = use_torch_symm_mem
|
||||
|
||||
@ -57,7 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
SymmMemCommunicator)
|
||||
|
||||
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
||||
if use_pynccl and self.world_size > 1:
|
||||
if self.world_size > 1:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
@ -308,14 +304,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
return output_list
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, router_logits)
|
||||
hidden_states, router_logits, is_sequence_parallel)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(hidden_states)
|
||||
hidden_states = self.all2all_manager.combine(hidden_states,
|
||||
is_sequence_parallel)
|
||||
return hidden_states
|
||||
|
@ -75,14 +75,20 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
||||
dist.broadcast(input_, src=src, group=self.device_group)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, router_logits)
|
||||
hidden_states, router_logits, is_sequence_parallel)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(hidden_states)
|
||||
hidden_states = self.all2all_manager.combine(hidden_states,
|
||||
is_sequence_parallel)
|
||||
return hidden_states
|
||||
|
@ -105,6 +105,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
|
||||
self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
|
||||
self.reqs_to_save: dict[ReqId, ReqMeta] = {}
|
||||
self.reqs_to_send: dict[ReqId, float] = {}
|
||||
self.reqs_in_batch: set[ReqId] = set()
|
||||
|
||||
def add_new_req(
|
||||
self,
|
||||
@ -278,6 +279,7 @@ class NixlConnectorScheduler:
|
||||
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
|
||||
# Reqs to send and their expiration time
|
||||
self._reqs_need_send: dict[ReqId, float] = {}
|
||||
self._reqs_in_batch: set[ReqId] = set()
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
@ -324,6 +326,9 @@ class NixlConnectorScheduler:
|
||||
|
||||
if not params:
|
||||
return
|
||||
|
||||
if params.get("do_remote_decode"):
|
||||
self._reqs_in_batch.add(request.request_id)
|
||||
if self.use_host_buffer and params.get("do_remote_decode"):
|
||||
# NOTE: when accelerator is not directly supported by Nixl,
|
||||
# prefilled blocks need to be saved to host memory before transfer.
|
||||
@ -373,6 +378,8 @@ class NixlConnectorScheduler:
|
||||
request_id=req_id,
|
||||
local_block_ids=block_ids,
|
||||
kv_transfer_params=req.kv_transfer_params,
|
||||
load_remote_cache=True,
|
||||
save_to_host=False,
|
||||
)
|
||||
|
||||
for req_id, (req, block_ids) in self._reqs_need_save.items():
|
||||
@ -386,10 +393,12 @@ class NixlConnectorScheduler:
|
||||
)
|
||||
|
||||
meta.reqs_to_send = self._reqs_need_send
|
||||
meta.reqs_in_batch = self._reqs_in_batch
|
||||
|
||||
# Clear the list once workers start the transfers
|
||||
self._reqs_need_recv.clear()
|
||||
self._reqs_need_save.clear()
|
||||
self._reqs_in_batch = set()
|
||||
self._reqs_need_send = {}
|
||||
|
||||
return meta
|
||||
@ -465,8 +474,11 @@ class NixlConnectorWorker:
|
||||
"backends", ["UCX"])
|
||||
# Agent.
|
||||
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
|
||||
config = nixl_agent_config(backends=self.nixl_backends) if len(
|
||||
non_ucx_backends) > 0 and nixl_agent_config is not None else None
|
||||
if nixl_agent_config is None:
|
||||
config = None
|
||||
else:
|
||||
config = nixl_agent_config(backends=self.nixl_backends) if len(
|
||||
non_ucx_backends) > 0 else nixl_agent_config(num_threads=8)
|
||||
|
||||
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
|
||||
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
|
||||
@ -546,6 +558,8 @@ class NixlConnectorWorker:
|
||||
self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
|
||||
# Track the expiration time of requests that are waiting to be sent.
|
||||
self._reqs_to_send: dict[ReqId, float] = {}
|
||||
# Set of requests that have been part of a batch, regardless of status.
|
||||
self._reqs_to_process: set[ReqId] = set()
|
||||
|
||||
# Background thread for handling new handshake requests.
|
||||
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
|
||||
@ -1082,6 +1096,7 @@ class NixlConnectorWorker:
|
||||
"Releasing expired KV blocks for request %s which were "
|
||||
"retrieved by %d decode worker(s) within %d seconds.", req_id,
|
||||
count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
|
||||
self._reqs_to_process.remove(req_id)
|
||||
del self._reqs_to_send[req_id]
|
||||
done_sending.add(req_id)
|
||||
|
||||
@ -1097,7 +1112,8 @@ class NixlConnectorWorker:
|
||||
for notifs in self.nixl_wrapper.get_new_notifs().values():
|
||||
for notif in notifs:
|
||||
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
|
||||
if req_id not in self._reqs_to_send:
|
||||
if (req_id not in self._reqs_to_send
|
||||
and req_id not in self._reqs_to_process):
|
||||
logger.error(
|
||||
"Potentially invalid KV blocks for "
|
||||
"unrecognized request %s were retrieved by "
|
||||
@ -1110,7 +1126,8 @@ class NixlConnectorWorker:
|
||||
tp_ratio):
|
||||
notified_req_ids.add(req_id)
|
||||
del self.consumer_notification_counts_by_req[req_id]
|
||||
del self._reqs_to_send[req_id]
|
||||
self._reqs_to_process.remove(req_id)
|
||||
self._reqs_to_send.pop(req_id, None)
|
||||
return notified_req_ids
|
||||
|
||||
def _pop_done_transfers(
|
||||
@ -1171,8 +1188,19 @@ class NixlConnectorWorker:
|
||||
while not self._ready_requests.empty():
|
||||
self._read_blocks_for_req(*self._ready_requests.get_nowait())
|
||||
|
||||
# Keep around the requests that have been part of a batch. This is
# needed because async scheduling widens the gap between the moment
# a request's expiration is set (P side) and the moment its blocks
# are read from D. As P can now more easily lag behind D while
# processing the next batch, we make sure to only set an expiration
# for requests that have not yet been read from D.
|
||||
for req_id in metadata.reqs_in_batch:
|
||||
self._reqs_to_process.add(req_id)
|
||||
|
||||
# Add to requests that are waiting to be read and track expiration.
|
||||
self._reqs_to_send.update(metadata.reqs_to_send)
|
||||
for req_id, expiration_time in metadata.reqs_to_send.items():
|
||||
if req_id in self._reqs_to_process:
|
||||
self._reqs_to_send[req_id] = expiration_time
|
||||
|
||||
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
|
||||
logger.debug(
|
||||
|
@ -871,17 +871,24 @@ class GroupCoordinator:
|
||||
model)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.device_communicator is not None:
|
||||
return self.device_communicator.dispatch(hidden_states,
|
||||
router_logits)
|
||||
router_logits,
|
||||
is_sequence_parallel)
|
||||
else:
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
if self.device_communicator is not None:
|
||||
return self.device_communicator.combine(hidden_states)
|
||||
return self.device_communicator.combine(hidden_states,
|
||||
is_sequence_parallel)
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
@ -297,6 +297,8 @@ class EngineArgs:
|
||||
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
|
||||
trust_remote_code: bool = ModelConfig.trust_remote_code
|
||||
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
|
||||
allowed_media_domains: Optional[
|
||||
list[str]] = ModelConfig.allowed_media_domains
|
||||
download_dir: Optional[str] = LoadConfig.download_dir
|
||||
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
|
||||
load_format: Union[str, LoadFormats] = LoadConfig.load_format
|
||||
@ -531,6 +533,8 @@ class EngineArgs:
|
||||
**model_kwargs["hf_config_path"])
|
||||
model_group.add_argument("--allowed-local-media-path",
|
||||
**model_kwargs["allowed_local_media_path"])
|
||||
model_group.add_argument("--allowed-media-domains",
|
||||
**model_kwargs["allowed_media_domains"])
|
||||
model_group.add_argument("--revision", **model_kwargs["revision"])
|
||||
model_group.add_argument("--code-revision",
|
||||
**model_kwargs["code_revision"])
|
||||
@ -997,6 +1001,7 @@ class EngineArgs:
|
||||
tokenizer_mode=self.tokenizer_mode,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
allowed_local_media_path=self.allowed_local_media_path,
|
||||
allowed_media_domains=self.allowed_media_domains,
|
||||
dtype=self.dtype,
|
||||
seed=self.seed,
|
||||
revision=self.revision,
|
||||
|
@ -11,7 +11,12 @@ from pathlib import Path
|
||||
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
|
||||
cast)
|
||||
|
||||
import jinja2
|
||||
import jinja2.ext
|
||||
import jinja2.meta
|
||||
import jinja2.nodes
|
||||
import jinja2.parser
|
||||
import jinja2.sandbox
|
||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@ -50,7 +55,7 @@ from vllm.transformers_utils.chat_templates import (
|
||||
# yapf: enable
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils import random_uuid, supports_kw
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -632,6 +637,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
def allowed_local_media_path(self):
|
||||
return self._model_config.allowed_local_media_path
|
||||
|
||||
@property
|
||||
def allowed_media_domains(self):
|
||||
return self._model_config.allowed_media_domains
|
||||
|
||||
@property
|
||||
def mm_registry(self):
|
||||
return MULTIMODAL_REGISTRY
|
||||
@ -832,6 +841,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._connector = MediaConnector(
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
)
|
||||
|
||||
def parse_image(
|
||||
@ -916,6 +926,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._connector = MediaConnector(
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
)
|
||||
|
||||
def parse_image(
|
||||
@ -1548,6 +1559,46 @@ def parse_chat_messages_futures(
|
||||
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||
|
||||
|
||||
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
|
||||
# only preserve the parse function used to resolve chat template kwargs
|
||||
class AssistantTracker(jinja2.ext.Extension):
|
||||
tags = {"generation"}
|
||||
|
||||
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
|
||||
lineno = next(parser.stream).lineno
|
||||
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
|
||||
call = self.call_method("_generation_support")
|
||||
call_block = jinja2.nodes.CallBlock(call, [], [], body)
|
||||
return call_block.set_lineno(lineno)
|
||||
|
||||
|
||||
def resolve_chat_template_kwargs(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
chat_template: str,
|
||||
chat_template_kwargs: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
fn_kw = {
|
||||
k for k in chat_template_kwargs
|
||||
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
|
||||
}
|
||||
|
||||
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
|
||||
)
|
||||
parsed_content = env.parse(chat_template)
|
||||
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
|
||||
|
||||
# We exclude chat_template from kwargs here, because
|
||||
# chat template has been already resolved at this stage
|
||||
unexpected_vars = {"chat_template"}
|
||||
accept_vars = (fn_kw | template_vars) - unexpected_vars
|
||||
return {
|
||||
k: v for k, v in chat_template_kwargs.items() if k in accept_vars
|
||||
}
|
||||
|
||||
|
||||
def apply_hf_chat_template(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
conversation: list[ConversationMessage],
|
||||
@ -1573,12 +1624,17 @@ def apply_hf_chat_template(
|
||||
)
|
||||
|
||||
try:
|
||||
resolved_kwargs = resolve_chat_template_kwargs(
|
||||
tokenizer=tokenizer,
|
||||
chat_template=hf_chat_template,
|
||||
chat_template_kwargs=kwargs,
|
||||
)
|
||||
return tokenizer.apply_chat_template(
|
||||
conversation=conversation, # type: ignore[arg-type]
|
||||
tools=tools, # type: ignore[arg-type]
|
||||
chat_template=hf_chat_template,
|
||||
tokenize=tokenize,
|
||||
**kwargs,
|
||||
**resolved_kwargs,
|
||||
)
|
||||
|
||||
# External library exceptions can sometimes occur despite the framework's
|
||||
|
@ -86,6 +86,8 @@ class LLM:
|
||||
or videos from directories specified by the server file system.
|
||||
This is a security risk. Should only be enabled in trusted
|
||||
environments.
|
||||
allowed_media_domains: If set, only media URLs that belong to this
|
||||
domain can be used for multi-modal inputs.
|
||||
tensor_parallel_size: The number of GPUs to use for distributed
|
||||
execution with tensor parallelism.
|
||||
dtype: The data type for the model weights and activations. Currently,
|
||||
@ -169,6 +171,7 @@ class LLM:
|
||||
skip_tokenizer_init: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
allowed_local_media_path: str = "",
|
||||
allowed_media_domains: Optional[list[str]] = None,
|
||||
tensor_parallel_size: int = 1,
|
||||
dtype: ModelDType = "auto",
|
||||
quantization: Optional[QuantizationMethods] = None,
|
||||
@ -264,6 +267,7 @@ class LLM:
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
trust_remote_code=trust_remote_code,
|
||||
allowed_local_media_path=allowed_local_media_path,
|
||||
allowed_media_domains=allowed_media_domains,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
dtype=dtype,
|
||||
quantization=quantization,
|
||||
|
@ -3,12 +3,14 @@
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import hashlib
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import multiprocessing
|
||||
import multiprocessing.forkserver as forkserver
|
||||
import os
|
||||
import secrets
|
||||
import signal
|
||||
import socket
|
||||
import tempfile
|
||||
@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
|
||||
class AuthenticationMiddleware:
|
||||
"""
|
||||
Pure ASGI middleware that authenticates each request by checking
|
||||
if the Authorization header exists and equals "Bearer {api_key}".
|
||||
if the Authorization Bearer token exists and equals any of the configured API keys.
|
||||
|
||||
Notes
|
||||
-----
|
||||
@ -1263,7 +1265,26 @@ class AuthenticationMiddleware:
|
||||
|
||||
def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
|
||||
self.app = app
|
||||
self.api_tokens = {f"Bearer {token}" for token in tokens}
|
||||
self.api_tokens = [
|
||||
hashlib.sha256(t.encode("utf-8")).digest() for t in tokens
|
||||
]
|
||||
|
||||
def verify_token(self, headers: Headers) -> bool:
|
||||
authorization_header_value = headers.get("Authorization")
|
||||
if not authorization_header_value:
|
||||
return False
|
||||
|
||||
scheme, _, param = authorization_header_value.partition(" ")
|
||||
if scheme.lower() != "bearer":
|
||||
return False
|
||||
|
||||
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
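# Hash the presented token and compare it against every configured
# token hash with constant-time comparisons, so the check does not
# leak which token (if any) matched through timing.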
|
||||
|
||||
token_match = False
|
||||
for token_hash in self.api_tokens:
|
||||
token_match |= secrets.compare_digest(param_hash, token_hash)
|
||||
|
||||
return token_match
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive,
|
||||
send: Send) -> Awaitable[None]:
|
||||
@ -1276,8 +1297,7 @@ class AuthenticationMiddleware:
|
||||
url_path = URL(scope=scope).path.removeprefix(root_path)
|
||||
headers = Headers(scope=scope)
|
||||
# Type narrow to satisfy mypy.
|
||||
if url_path.startswith("/v1") and headers.get(
|
||||
"Authorization") not in self.api_tokens:
|
||||
if url_path.startswith("/v1") and not self.verify_token(headers):
|
||||
response = JSONResponse(content={"error": "Unauthorized"},
|
||||
status_code=401)
|
||||
return response(scope, receive, send)
|
||||
@ -1696,6 +1716,7 @@ async def init_app_state(
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
exclude_tools_when_tool_choice_none=args.
|
||||
|
@ -103,9 +103,13 @@ class FrontendArgs:
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
|
||||
"""The format to render message content within a chat template.
|
||||
|
||||
* "string" will render the content as a string. Example: `"Hello World"`
|
||||
* "openai" will render the content as a list of dictionaries, similar to OpenAI
|
||||
schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
||||
* "string" will render the content as a string. Example: `"Hello World"`
|
||||
* "openai" will render the content as a list of dictionaries, similar to
|
||||
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
||||
trust_request_chat_template: bool = False
|
||||
"""Whether to trust the chat template provided in the request. If False,
|
||||
the server will always use the chat template specified by `--chat-template`
|
||||
or the ones from tokenizer."""
|
||||
response_role: str = "assistant"
|
||||
"""The role name to return if `request.add_generation_prompt=true`."""
|
||||
ssl_keyfile: Optional[str] = None
|
||||
|
@ -68,6 +68,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
trust_request_chat_template: bool = False,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
reasoning_parser: str = "",
|
||||
enable_auto_tools: bool = False,
|
||||
@ -89,6 +90,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self.response_role = response_role
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
self.enable_log_outputs = enable_log_outputs
|
||||
|
||||
# set up tool use
|
||||
@ -220,6 +222,16 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
if not self.use_harmony:
|
||||
# Common case.
|
||||
request_chat_template = request.chat_template
|
||||
chat_template_kwargs = request.chat_template_kwargs
|
||||
if not self.trust_request_chat_template and (
|
||||
request_chat_template is not None or
|
||||
(chat_template_kwargs and
|
||||
chat_template_kwargs.get("chat_template") is not None)):
|
||||
return self.create_error_response(
|
||||
"Chat template is passed with request, but "
|
||||
"--trust-request-chat-template is not set. "
|
||||
"Refused request with untrusted chat template.")
|
||||
(
|
||||
conversation,
|
||||
request_prompts,
|
||||
@ -228,7 +240,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template=request_chat_template or self.chat_template,
|
||||
chat_template_content_format=self.
|
||||
chat_template_content_format,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
|
@ -49,16 +49,29 @@ class BatchDescriptor(NamedTuple):
|
||||
return BatchDescriptor(self.num_tokens, uniform_decode=False)
|
||||
|
||||
|
||||
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
|
||||
def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
|
||||
sequence_parallel_size: int) -> list[int]:
|
||||
sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
|
||||
sequence_parallel_size)
|
||||
|
||||
sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
|
||||
return sp_tokens.tolist()
|
||||
|
||||
|
||||
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
|
||||
sequence_parallel_size: int,
|
||||
max_num_tokens: int,
|
||||
chunk_idx: int) -> list[int]:
|
||||
dp_size = len(num_tokens_across_dp_cpu)
|
||||
|
||||
local_size = [-1] * dp_size
|
||||
for i in range(dp_size):
|
||||
dp_tokens = num_tokens_across_dp_cpu[i]
|
||||
sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu,
|
||||
sequence_parallel_size)
|
||||
sp_size = len(sp_tokens)
|
||||
|
||||
local_size = [-1] * sp_size
|
||||
for i in range(sp_size):
|
||||
# Take into account sharding if MoE activation is sequence parallel.
|
||||
local_size[i] = min(max_num_tokens,
|
||||
dp_tokens - (max_num_tokens * chunk_idx))
|
||||
sp_tokens[i] - (max_num_tokens * chunk_idx))
|
||||
if local_size[i] <= 0:
|
||||
local_size[i] = 1 # ensure lockstep even if done
|
||||
return local_size
|
||||
@ -67,7 +80,9 @@ def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
|
||||
@dataclass
|
||||
class DPMetadata:
|
||||
max_tokens_across_dp_cpu: torch.Tensor
|
||||
cu_tokens_across_dp_cpu: torch.Tensor
|
||||
num_tokens_across_dp_cpu: torch.Tensor
|
||||
|
||||
# NOTE: local_sizes should only be set by the chunked_sizes context manager
|
||||
local_sizes: Optional[list[int]] = None
|
||||
|
||||
@staticmethod
|
||||
@ -98,6 +113,17 @@ class DPMetadata:
|
||||
dist.all_reduce(num_tokens_tensor, group=group)
|
||||
return num_tokens_tensor.cpu()
|
||||
|
||||
# Get the cumulative tokens across sequence parallel ranks.
|
||||
# In this case the input to the MoEs will be distributed w.r.t both
|
||||
# DP and TP rank.
|
||||
# When sp_size==1, this is just the cumulative num tokens across DP.
|
||||
def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
|
||||
num_tokens_across_sp_cpu = (
|
||||
(self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size)
|
||||
num_tokens_across_sp_cpu = (
|
||||
num_tokens_across_sp_cpu.repeat_interleave(sp_size))
|
||||
return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
|
||||
|
||||
@staticmethod
|
||||
def should_ubatch_across_dp(
|
||||
should_ubatch: bool, orig_num_tokens_per_ubatch: int,
|
||||
@ -147,10 +173,10 @@ class DPMetadata:
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
parallel_config: ParallelConfig,
|
||||
attn_metadata: Any,
|
||||
num_tokens: int,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None
|
||||
parallel_config: ParallelConfig,
|
||||
attn_metadata: Any,
|
||||
num_tokens: int,
|
||||
num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
|
||||
) -> "DPMetadata":
|
||||
|
||||
assert parallel_config.data_parallel_size > 1
|
||||
@ -167,18 +193,18 @@ class DPMetadata:
|
||||
|
||||
# If num_tokens_across_dp is None, it will be computed by all_reduce
|
||||
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
|
||||
assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
|
||||
== batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
|
||||
if num_tokens_across_dp is None:
|
||||
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
|
||||
assert (num_tokens_across_dp_cpu is None
|
||||
or num_tokens_across_dp_cpu[dp_rank] == batchsize
|
||||
), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
|
||||
if num_tokens_across_dp_cpu is None:
|
||||
num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
|
||||
batchsize, dp_size, dp_rank)
|
||||
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
|
||||
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
|
||||
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
|
||||
num_tokens_across_dp)
|
||||
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
|
||||
return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
|
||||
|
||||
@contextmanager
|
||||
def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
|
||||
def chunked_sizes(self, sequence_parallel_size: int,
|
||||
max_chunk_size_per_rank: int, chunk_idx: int):
|
||||
"""
|
||||
Context manager to compute and temporarily set the per-rank local token
|
||||
sizes for a specific chunk during chunked forward execution.
|
||||
@ -192,31 +218,40 @@ class DPMetadata:
|
||||
`chunk_idx`, this context manager sets `self.local_sizes` to the number
|
||||
of tokens to process in that chunk on each rank.
|
||||
|
||||
It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
|
||||
number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
|
||||
to determine the chunk-wise split.
|
||||
|
||||
`self.local_sizes` is only valid inside the context.
|
||||
|
||||
Args:
|
||||
sequence_parallel_size: When Attn is TP and MoE layers are EP,
|
||||
we use SP between the layers to avoid
|
||||
redundant ops. We need this value to
|
||||
compute the chunked sizes.
|
||||
max_chunk_size_per_rank: The max number of tokens each rank is
|
||||
allowed to process in this chunk.
|
||||
chunk_idx: The index of the chunk to compute sizes for.
|
||||
"""
|
||||
cu_sizes = self.cu_tokens_across_dp_cpu
|
||||
num_tokens_across_dp_cpu = [
|
||||
(cu_sizes[i] -
|
||||
cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
|
||||
for i in range(len(cu_sizes))
|
||||
]
|
||||
self.local_sizes = _compute_chunked_local_num_tokens(
|
||||
num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
|
||||
self.num_tokens_across_dp_cpu, sequence_parallel_size,
|
||||
max_chunk_size_per_rank, chunk_idx)
|
||||
try:
|
||||
yield self.local_sizes
|
||||
finally:
|
||||
self.local_sizes = None
|
||||
|
||||
@contextmanager
|
||||
def sp_local_sizes(self, sequence_parallel_size: int):
|
||||
"""
|
||||
Context manager for setting self.local_sizes. Same as self.chunked_sizes
|
||||
but without any chunking.
|
||||
"""
|
||||
self.local_sizes = _compute_sp_num_tokens(
|
||||
self.num_tokens_across_dp_cpu, sequence_parallel_size)
|
||||
try:
|
||||
yield self.local_sizes
|
||||
finally:
|
||||
self.local_sizes = None
|
||||
|
||||
def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
|
||||
assert self.local_sizes is not None
|
||||
return self.local_sizes
|
||||
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from contextlib import nullcontext
|
||||
from enum import Enum
|
||||
from typing import Callable, Literal, Optional, Union, get_args, overload
|
||||
|
||||
@ -983,8 +984,7 @@ class FusedMoE(CustomOp):
|
||||
if dp_size is not None else get_dp_group().world_size)
|
||||
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
if self.is_sequence_parallel:
|
||||
self.sp_size = tp_size_
|
||||
self.sp_size = tp_size_ if is_sequence_parallel else 1
|
||||
|
||||
self.moe_parallel_config: FusedMoEParallelConfig = (
|
||||
FusedMoEParallelConfig.make(
|
||||
@ -1966,7 +1966,8 @@ class FusedMoE(CustomOp):
|
||||
# clamp start and end
|
||||
chunk_start = min(chunk_start, num_tokens - 1)
|
||||
chunk_end = min(chunk_end, num_tokens)
|
||||
with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
|
||||
with ctx.dp_metadata.chunked_sizes(self.sp_size,
|
||||
moe_dp_chunk_size_per_rank,
|
||||
chunk_idx):
|
||||
process_chunk(chunk_start,
|
||||
chunk_end,
|
||||
@ -2011,65 +2012,73 @@ class FusedMoE(CustomOp):
|
||||
else:
|
||||
shared_output = None
|
||||
|
||||
if do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch(
|
||||
hidden_states, router_logits)
|
||||
ctx = get_forward_context()
|
||||
sp_ctx = ctx.dp_metadata.sp_local_sizes(
|
||||
self.sp_size) if ctx.dp_metadata else nullcontext()
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
enable_eplb=self.enable_eplb,
|
||||
expert_load_view=self.expert_load_view,
|
||||
logical_to_physical_map=self.logical_to_physical_map,
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
with sp_ctx:
|
||||
if do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch(
|
||||
hidden_states, router_logits, self.is_sequence_parallel)
|
||||
|
||||
if shared_output is not None:
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is not None
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
enable_eplb=self.enable_eplb,
|
||||
expert_load_view=self.expert_load_view,
|
||||
logical_to_physical_map=self.logical_to_physical_map,
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, tuple)
|
||||
final_hidden_states, zero_expert_result = final_hidden_states
|
||||
|
||||
def reduce_output(states: torch.Tensor,
|
||||
do_combine: bool = True) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine and do_combine:
|
||||
states = get_ep_group().combine(states)
|
||||
if shared_output is not None:
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is not None
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, tuple)
|
||||
final_hidden_states, zero_expert_result = final_hidden_states
|
||||
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
states = self.maybe_all_reduce_tensor_model_parallel(states)
|
||||
def reduce_output(states: torch.Tensor,
|
||||
do_combine: bool = True) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine and do_combine:
|
||||
states = get_ep_group().combine(states,
|
||||
self.is_sequence_parallel)
|
||||
|
||||
return states
|
||||
if (not self.is_sequence_parallel and self.reduce_results
|
||||
and (self.tp_size > 1 or self.ep_size > 1)):
|
||||
states = self.maybe_all_reduce_tensor_model_parallel(
|
||||
states)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
return (
|
||||
reduce_output(final_hidden_states[0], do_combine=False),
|
||||
reduce_output(final_hidden_states[1]),
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, torch.Tensor)
|
||||
return reduce_output(final_hidden_states) + zero_expert_result
|
||||
else:
|
||||
return reduce_output(final_hidden_states)
|
||||
return states
|
||||
|
||||
if self.shared_experts is not None:
|
||||
return (
|
||||
reduce_output(final_hidden_states[0], do_combine=False),
|
||||
reduce_output(final_hidden_states[1]),
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, torch.Tensor)
|
||||
return reduce_output(final_hidden_states) + zero_expert_result
|
||||
else:
|
||||
return reduce_output(final_hidden_states)
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
|
@ -9,7 +9,7 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
|
||||
from transformers.models.aria.modeling_aria import AriaCrossAttention
|
||||
from transformers.models.aria.processing_aria import AriaProcessor
|
||||
|
||||
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
|
||||
from vllm.config import QuantizationConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
@ -298,14 +298,12 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
|
||||
Experts (MoE) Layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AriaTextConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config, cache_config, quant_config, prefix)
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__(vllm_config, prefix)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.mlp = AriaTextMoELayer(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
|
@ -32,7 +32,6 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||
@ -56,8 +55,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
@ -108,43 +107,6 @@ class DeepseekV2MLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# Chunk x along the num_tokens axis for sequence parallelism
|
||||
# NOTE: This is wrapped in a torch custom op to work around the following issue:
|
||||
# The output tensor can have a sequence length 0 at small input sequence lengths
|
||||
# even though we explicitly pad to avoid this.
|
||||
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
# all_gather needs the sequence length to be divisible by tp_size
|
||||
seq_len = x.size(0)
|
||||
remainder = seq_len % tp_size
|
||||
if remainder != 0:
|
||||
pad_len = tp_size - remainder
|
||||
x = nn.functional.pad(x, (0, 0, 0, pad_len))
|
||||
|
||||
chunk = x.shape[0] // tp_size
|
||||
start = tp_rank * chunk
|
||||
return torch.narrow(x, 0, start, chunk)
|
||||
|
||||
|
||||
def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
seq_len = cdiv(x.size(0), tp_size)
|
||||
shape = list(x.shape)
|
||||
shape[0] = seq_len
|
||||
out = torch.empty(shape, dtype=x.dtype, device=x.device)
|
||||
return out
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="sequence_parallel_chunk",
|
||||
op_func=sequence_parallel_chunk,
|
||||
fake_impl=sequence_parallel_chunk_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@ -166,20 +128,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.n_routed_experts: int = config.n_routed_experts
|
||||
self.n_shared_experts: int = config.n_shared_experts
|
||||
|
||||
# The all_reduce at the end of attention (during o_proj) means that
|
||||
# inputs are replicated across each rank of the tensor parallel group.
|
||||
# If using expert-parallelism with DeepEP All2All ops, replicated
|
||||
# tokens results in useless duplicate computation and communication.
|
||||
#
|
||||
# In this case, ensure the input to the experts is sequence parallel
|
||||
# to avoid the excess work.
|
||||
#
|
||||
# Not needed for pplx-kernels as it can handle duplicate input tokens.
|
||||
self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
|
||||
in ("deepep_high_throughput",
|
||||
"deepep_low_latency")
|
||||
and parallel_config.enable_expert_parallel
|
||||
and self.tp_size > 1)
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
|
||||
if config.hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||
@ -278,8 +227,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
# TODO: We can replace the all_reduce at the end of attn with a
|
||||
# reduce_scatter instead of chunking here.
|
||||
if self.is_sequence_parallel:
|
||||
hidden_states = torch.ops.vllm.sequence_parallel_chunk(
|
||||
hidden_states)
|
||||
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
@ -29,10 +29,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@ -47,13 +46,11 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.mtp_emb_norm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -62,8 +59,7 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
|
||||
self.mtp_linear_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config,
|
||||
prefix)
|
||||
self.mtp_block = LlamaDecoderLayer(vllm_config, prefix)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -102,10 +98,8 @@ class ErnieMultiTokenPredictor(nn.Module):
|
||||
self.layers = torch.nn.ModuleDict({
|
||||
str(idx):
|
||||
ErnieMultiTokenPredictorLayer(
|
||||
config,
|
||||
vllm_config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
)
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
|
@ -136,14 +136,16 @@ class Glm4Attention(nn.Module):
|
||||
|
||||
class Glm4DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Glm4Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
config: Optional[Glm4Config] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = config or vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
|
@ -13,7 +13,8 @@ from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
@ -24,6 +25,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv
|
||||
|
||||
@ -132,12 +134,18 @@ class MLPBlock(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GptOssConfig,
|
||||
vllm_config: VllmConfig,
|
||||
layer_idx: int,
|
||||
quant_config: QuantizationConfig,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.num_experts = config.num_local_experts
|
||||
self.experts_per_token = config.num_experts_per_tok
|
||||
@ -155,11 +163,20 @@ class MLPBlock(torch.nn.Module):
|
||||
prefix=f"{prefix}.experts",
|
||||
apply_router_weight_on_input=False,
|
||||
has_bias=True,
|
||||
activation="swigluoai")
|
||||
activation="swigluoai",
|
||||
is_sequence_parallel=self.is_sequence_parallel)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens = x.shape[0]
|
||||
if self.is_sequence_parallel:
|
||||
x = sequence_parallel_chunk(x)
|
||||
|
||||
g = self.router(x)
|
||||
x = self.experts(hidden_states=x, router_logits=g)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
|
||||
x = x[:num_tokens]
|
||||
return x
|
||||
|
||||
|
||||
@ -167,19 +184,20 @@ class TransformerBlock(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GptOssConfig,
|
||||
cache_config: CacheConfig,
|
||||
quant_config: QuantizationConfig,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
self.attn = OAIAttention(config,
|
||||
prefix=f"{prefix}.attn",
|
||||
cache_config=cache_config)
|
||||
self.mlp = MLPBlock(config,
|
||||
self.mlp = MLPBlock(vllm_config,
|
||||
self.layer_idx,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
@ -216,8 +234,6 @@ class GptOssModel(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.quant_config = vllm_config.quant_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.config.hidden_size = self.config.hidden_size
|
||||
self.embedding = VocabParallelEmbedding(
|
||||
@ -227,9 +243,7 @@ class GptOssModel(nn.Module):
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
self.config.num_hidden_layers,
|
||||
lambda prefix: TransformerBlock(
|
||||
self.config,
|
||||
cache_config=self.cache_config,
|
||||
quant_config=self.quant_config,
|
||||
vllm_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
|
@ -29,12 +29,13 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.models.granitemoe import GraniteMoeConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import (get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
@ -48,6 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
@ -71,9 +73,11 @@ class GraniteMoeMoE(nn.Module):
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
is_sequence_parallel=False,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
|
||||
# Gate always runs at half / full precision for now.
|
||||
self.gate = ReplicatedLinear(hidden_size,
|
||||
@ -92,15 +96,27 @@ class GraniteMoeMoE(nn.Module):
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
prefix=f"{prefix}.experts")
|
||||
prefix=f"{prefix}.experts",
|
||||
is_sequence_parallel=self.is_sequence_parallel)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = self.experts(hidden_states, router_logits)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
final_hidden_states = tensor_model_parallel_all_gather(
|
||||
final_hidden_states, 0)
|
||||
num_tokens = orig_shape[0]
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
|
||||
return final_hidden_states.view(orig_shape)
|
||||
|
||||
|
||||
@ -191,12 +207,16 @@ class GraniteMoeDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GraniteMoeConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@ -218,6 +238,7 @@ class GraniteMoeDecoderLayer(nn.Module):
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
is_sequence_parallel=parallel_config.use_sequence_parallel_moe,
|
||||
prefix=f"{prefix}.block_sparse_moe")
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
@ -255,7 +276,6 @@ class GraniteMoeModel(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
@ -275,9 +295,7 @@ class GraniteMoeModel(nn.Module):
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: GraniteMoeDecoderLayer(
|
||||
config, cache_config, quant_config=quant_config, prefix=prefix
|
||||
),
|
||||
lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
@ -68,6 +68,7 @@ class LlamaMLP(nn.Module):
|
||||
bias: bool = False,
|
||||
prefix: str = "",
|
||||
reduce_results: bool = True,
|
||||
disable_tp: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
@ -75,6 +76,7 @@ class LlamaMLP(nn.Module):
|
||||
output_sizes=[intermediate_size] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
disable_tp=disable_tp,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
@ -83,6 +85,7 @@ class LlamaMLP(nn.Module):
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
disable_tp=disable_tp,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
@ -237,14 +240,16 @@ class LlamaAttention(nn.Module):
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
config: Optional[LlamaConfig] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = config or vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
@ -335,7 +340,6 @@ class LlamaModel(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
@ -357,10 +361,7 @@ class LlamaModel(nn.Module):
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: layer_type(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
|
@ -28,7 +28,8 @@ from vllm.attention import Attention
|
||||
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
@ -39,6 +40,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
|
||||
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
|
||||
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
|
||||
@ -59,13 +61,16 @@ class Llama4MoE(nn.Module):
|
||||
router_scores = torch.sigmoid(router_scores.float())
|
||||
return (router_scores, router_indices.to(torch.int32))
|
||||
|
||||
def __init__(self,
|
||||
config: Llama4TextConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
|
||||
intermediate_size_moe = config.intermediate_size
|
||||
self.router = ReplicatedLinear(config.hidden_size,
|
||||
@ -82,6 +87,7 @@ class Llama4MoE(nn.Module):
|
||||
bias=False,
|
||||
prefix=f"{prefix}.shared_expert",
|
||||
reduce_results=False,
|
||||
disable_tp=self.is_sequence_parallel,
|
||||
)
|
||||
|
||||
self.experts = SharedFusedMoE(
|
||||
@ -96,9 +102,14 @@ class Llama4MoE(nn.Module):
|
||||
renormalize=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
num_tokens = hidden_states.shape[0]
|
||||
if self.is_sequence_parallel:
|
||||
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||
|
||||
router_logits, _ = self.router(hidden_states)
|
||||
|
||||
shared_out, routed_out = self.experts(
|
||||
@ -107,7 +118,10 @@ class Llama4MoE(nn.Module):
|
||||
)
|
||||
experts_out = routed_out + shared_out
|
||||
|
||||
if self.tp_size > 1:
|
||||
if self.is_sequence_parallel:
|
||||
experts_out = tensor_model_parallel_all_gather(experts_out, 0)
|
||||
experts_out = experts_out[:num_tokens]
|
||||
elif self.tp_size > 1:
|
||||
experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
|
||||
experts_out)
|
||||
|
||||
@ -257,15 +271,16 @@ class Llama4Attention(nn.Module):
|
||||
|
||||
class Llama4DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Llama4TextConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
config: Optional[Llama4TextConfig] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = config or vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
self.global_layer = config.no_rope_layers[self.layer_idx] == 0
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -291,8 +306,7 @@ class Llama4DecoderLayer(nn.Module):
|
||||
self.layer_idx + 1) % config.interleave_moe_layer_step == 0
|
||||
if is_moe_layer:
|
||||
self.feed_forward = Llama4MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
else:
|
||||
|
@ -68,9 +68,9 @@ class LlamaModel(nn.Module):
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
Llama4DecoderLayer(
|
||||
self.config,
|
||||
quant_config=quant_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
||||
config=self.config,
|
||||
) for i in range(self.config.num_hidden_layers)
|
||||
])
|
||||
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
|
||||
|
@ -28,11 +28,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
vllm_config: VllmConfig,
|
||||
disable_input_layernorm: bool,
|
||||
prefix: str = "",
|
||||
config: Optional[LlamaConfig] = None,
|
||||
) -> None:
|
||||
super().__init__(config, prefix=prefix)
|
||||
super().__init__(vllm_config, prefix=prefix, config=config)
|
||||
|
||||
# Skip the input_layernorm
|
||||
# https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
|
||||
@ -64,9 +65,10 @@ class LlamaModel(nn.Module):
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(
|
||||
self.config,
|
||||
vllm_config,
|
||||
i == 0,
|
||||
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
||||
config=self.config,
|
||||
) for i in range(self.config.num_hidden_layers)
|
||||
])
|
||||
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
|
||||
|
@ -9,13 +9,11 @@ import torch.nn as nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@ -29,17 +27,14 @@ logger = init_logger(__name__)
|
||||
|
||||
class LlamaDecoderLayer(LlamaDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
config: Optional[LlamaConfig] = None) -> None:
|
||||
super().__init__(vllm_config, prefix=prefix, config=config)
|
||||
|
||||
config = config or vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
# override qkv
|
||||
self.self_attn.qkv_proj = QKVParallelLinear(
|
||||
@ -127,9 +122,9 @@ class LlamaModel(nn.Module):
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(
|
||||
config=self.config,
|
||||
cache_config=current_vllm_config.cache_config,
|
||||
current_vllm_config,
|
||||
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
|
||||
config=self.config,
|
||||
)
|
||||
])
|
||||
if hasattr(self.config, "target_hidden_size"):
|
||||
|
@ -274,6 +274,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend: _Backend = _Backend.TORCH_SDPA,
|
||||
use_upstream_fa: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# Per attention head and per partition values.
|
||||
@ -300,25 +302,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.proj",
|
||||
disable_tp=use_data_parallel)
|
||||
|
||||
# Detect attention implementation.
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
dtype=torch.get_default_dtype())
|
||||
self.use_upstream_fa = False
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||
check_upstream_fa_availability(
|
||||
torch.get_default_dtype()):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
self.use_upstream_fa = True
|
||||
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
self.attn_backend = attn_backend
|
||||
self.use_upstream_fa = use_upstream_fa
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
|
||||
}
|
||||
@ -443,6 +428,8 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend: _Backend = _Backend.TORCH_SDPA,
|
||||
use_upstream_fa: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
@ -455,7 +442,9 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
projection_size=dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_data_parallel=use_data_parallel)
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=attn_backend,
|
||||
use_upstream_fa=use_upstream_fa)
|
||||
self.mlp = Qwen2_5_VisionMLP(dim,
|
||||
mlp_hidden_dim,
|
||||
act_fn=act_fn,
|
||||
@ -627,17 +616,35 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
|
||||
|
||||
use_upstream_fa = False
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=head_dim, dtype=torch.get_default_dtype())
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||
check_upstream_fa_availability(
|
||||
torch.get_default_dtype()):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
Qwen2_5_VisionBlock(dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=get_act_and_mul_fn(
|
||||
vision_config.hidden_act),
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel)
|
||||
for layer_idx in range(depth)
|
||||
Qwen2_5_VisionBlock(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=self.attn_backend,
|
||||
use_upstream_fa=use_upstream_fa) for layer_idx in range(depth)
|
||||
])
|
||||
self.merger = Qwen2_5_VisionPatchMerger(
|
||||
d_model=vision_config.out_hidden_size,
|
||||
@ -648,12 +655,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
prefix=f"{prefix}.merger",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=head_dim, dtype=torch.get_default_dtype())
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||
check_upstream_fa_availability(
|
||||
torch.get_default_dtype()):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
|
@ -79,7 +79,7 @@ from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# For profile run
|
||||
_MAX_FRAMES_PER_VIDEO = 32
|
||||
_MAX_FRAMES_PER_VIDEO = 14
|
||||
|
||||
# === Vision Inputs === #
|
||||
|
||||
@ -932,6 +932,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
_, num_image_tokens = self._get_vision_info(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
num_frames=1,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
return num_image_tokens
|
||||
@ -956,6 +957,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
max_image_size, _ = self._get_vision_info(
|
||||
image_width=9999999,
|
||||
image_height=9999999,
|
||||
num_frames=1,
|
||||
image_processor=None,
|
||||
)
|
||||
return max_image_size
|
||||
@ -969,10 +971,12 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
image_processor=None,
|
||||
)
|
||||
|
||||
def _get_max_video_frames(self, max_tokens: int) -> int:
|
||||
def _get_max_video_frames(self,
|
||||
max_tokens: int,
|
||||
start_num_frames: int = 1) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
num_frames = 0
|
||||
num_frames = start_num_frames
|
||||
|
||||
while True:
|
||||
next_num_frames = num_frames + 1
|
||||
@ -994,12 +998,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
|
||||
) -> int:
|
||||
max_videos = mm_counts.get("video", 0)
|
||||
|
||||
max_total_frames = self._get_max_video_frames(seq_len)
|
||||
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
|
||||
_MAX_FRAMES_PER_VIDEO)
|
||||
max_frames_per_video)
|
||||
|
||||
return max(max_frames_per_video, 1)
|
||||
|
||||
|
@ -29,13 +29,13 @@ from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Qwen3MoeConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||
@ -101,12 +102,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3MoeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_text_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.ep_group = get_ep_group().device_group
|
||||
@ -114,6 +118,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
self.ep_size = self.ep_group.size()
|
||||
self.n_routed_experts = config.num_experts
|
||||
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
|
||||
if self.tp_size > config.num_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
@ -122,7 +128,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
# Load balancing settings.
|
||||
vllm_config = get_current_vllm_config()
|
||||
eplb_config = vllm_config.parallel_config.eplb_config
|
||||
self.enable_eplb = enable_eplb
|
||||
self.enable_eplb = parallel_config.enable_eplb
|
||||
|
||||
self.n_logical_experts = self.n_routed_experts
|
||||
self.n_redundant_experts = eplb_config.num_redundant_experts
|
||||
@ -144,7 +150,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
is_sequence_parallel=self.is_sequence_parallel)
|
||||
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
config.num_experts,
|
||||
@ -156,14 +163,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
assert hidden_states.dim(
|
||||
) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
|
||||
is_input_1d = hidden_states.dim() == 1
|
||||
hidden_dim = hidden_states.shape[-1]
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
final_hidden_states = tensor_model_parallel_all_gather(
|
||||
final_hidden_states, 0)
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
|
||||
# return to 1d if input is 1d
|
||||
return final_hidden_states.squeeze(0) if is_input_1d else \
|
||||
final_hidden_states
|
||||
@ -275,15 +290,13 @@ class Qwen3MoeAttention(nn.Module):
|
||||
|
||||
class Qwen3MoeDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3MoeConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
) -> None:
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_text_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
@ -315,10 +328,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
||||
if (layer_idx not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and
|
||||
(layer_idx + 1) % config.decoder_sparse_step == 0):
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
enable_eplb=enable_eplb)
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
@ -361,11 +372,9 @@ class Qwen3MoeModel(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config.get_text_config()
|
||||
cache_config = vllm_config.cache_config
|
||||
config = vllm_config.model_config.hf_text_config
|
||||
quant_config = vllm_config.quant_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
enable_eplb = parallel_config.enable_eplb
|
||||
eplb_config = parallel_config.eplb_config
|
||||
self.num_redundant_experts = eplb_config.num_redundant_experts
|
||||
|
||||
@ -379,11 +388,8 @@ class Qwen3MoeModel(nn.Module):
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Qwen3MoeDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
enable_eplb=enable_eplb),
|
||||
lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -580,7 +586,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
config = vllm_config.model_config.hf_text_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
@ -17,7 +17,8 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
|
||||
VllmConfig, get_current_vllm_config)
|
||||
from vllm.distributed import (divide, get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fla.ops import (
|
||||
@ -47,6 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@ -69,14 +71,13 @@ KVCache = tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3NextConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.ep_group = get_ep_group().device_group
|
||||
@ -84,6 +85,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
self.ep_size = self.ep_group.size()
|
||||
self.n_routed_experts = config.num_experts
|
||||
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
|
||||
if self.tp_size > config.num_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
@ -92,7 +95,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
# Load balancing settings.
|
||||
vllm_config = get_current_vllm_config()
|
||||
eplb_config = vllm_config.parallel_config.eplb_config
|
||||
self.enable_eplb = enable_eplb
|
||||
self.enable_eplb = parallel_config.enable_eplb
|
||||
|
||||
self.n_logical_experts = self.n_routed_experts
|
||||
self.n_redundant_experts = eplb_config.num_redundant_experts
|
||||
@ -114,7 +117,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
is_sequence_parallel=self.is_sequence_parallel)
|
||||
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
config.num_experts,
|
||||
@ -141,9 +145,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_dim = hidden_states.shape[-1]
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||
|
||||
shared_output = None
|
||||
if self.shared_expert is not None:
|
||||
shared_output = self.shared_expert(hidden_states)
|
||||
@ -158,7 +165,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
if self.tp_size > 1:
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
final_hidden_states = tensor_model_parallel_all_gather(
|
||||
final_hidden_states, 0)
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
elif self.tp_size > 1:
|
||||
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
|
||||
final_hidden_states)
|
||||
|
||||
@ -719,17 +731,17 @@ class Qwen3NextDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3NextConfig,
|
||||
vllm_config: VllmConfig,
|
||||
layer_type: str,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
speculative_config: Optional[SpeculativeConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
speculative_config = vllm_config.speculative_config
|
||||
|
||||
self.layer_type = layer_type
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
@ -759,10 +771,8 @@ class Qwen3NextDecoderLayer(nn.Module):
|
||||
config.num_experts > 0 and
|
||||
(self.layer_idx + 1) % config.decoder_sparse_step == 0):
|
||||
self.mlp = Qwen3NextSparseMoeBlock(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
enable_eplb=enable_eplb,
|
||||
)
|
||||
else:
|
||||
self.mlp = Qwen3NextMLP(
|
||||
@ -783,14 +793,14 @@ class Qwen3NextDecoderLayer(nn.Module):
|
||||
torch.zeros(
|
||||
1,
|
||||
1,
|
||||
self.config.hidden_size,
|
||||
config.hidden_size,
|
||||
dtype=config.torch_dtype,
|
||||
), )
|
||||
self.ffn_layer_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
1,
|
||||
1,
|
||||
self.config.hidden_size,
|
||||
config.hidden_size,
|
||||
dtype=config.torch_dtype,
|
||||
), )
|
||||
|
||||
@ -858,13 +868,8 @@ class Qwen3NextModel(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
config: Qwen3NextConfig = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
lora_config = vllm_config.lora_config
|
||||
speculative_config = vllm_config.speculative_config
|
||||
enable_eplb = parallel_config.enable_eplb
|
||||
eplb_config = parallel_config.eplb_config
|
||||
self.num_redundant_experts = eplb_config.num_redundant_experts
|
||||
|
||||
@ -881,14 +886,9 @@ class Qwen3NextModel(nn.Module):
|
||||
|
||||
def get_layer(prefix: str):
|
||||
return Qwen3NextDecoderLayer(
|
||||
config,
|
||||
vllm_config,
|
||||
layer_type=config.layer_types[extract_layer_index(prefix)],
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
speculative_config=speculative_config,
|
||||
prefix=prefix,
|
||||
enable_eplb=enable_eplb,
|
||||
)
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
|
@ -38,7 +38,6 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
config: Qwen3NextConfig = model_config.hf_config
|
||||
@ -68,11 +67,8 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
|
||||
|
||||
self.layers = torch.nn.ModuleList(
|
||||
Qwen3NextDecoderLayer(
|
||||
config,
|
||||
vllm_config,
|
||||
layer_type="full_attention",
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f'{prefix}.layers.{idx}',
|
||||
) for idx in range(self.num_mtp_layers))
|
||||
|
||||
|
@ -33,11 +33,14 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import BatchFeature
|
||||
from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
|
||||
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
|
||||
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
|
||||
smart_resize as image_smart_resize)
|
||||
from transformers.models.qwen3_vl import (Qwen3VLProcessor,
|
||||
Qwen3VLVideoProcessor)
|
||||
from transformers.models.qwen3_vl.configuration_qwen3_vl import (
|
||||
Qwen3VLConfig, Qwen3VLVisionConfig)
|
||||
from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
|
||||
smart_resize as video_smart_resize)
|
||||
from transformers.video_utils import VideoMetadata
|
||||
|
||||
from vllm.attention.layer import check_upstream_fa_availability
|
||||
@ -63,7 +66,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
from vllm.utils import is_list_of
|
||||
@ -84,6 +87,9 @@ from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Official recommended max pixels is 24576 * 32 * 32
|
||||
_MAX_FRAMES_PER_VIDEO = 24576
|
||||
|
||||
|
||||
class Qwen3_VisionPatchEmbed(nn.Module):
|
||||
|
||||
@ -158,6 +164,8 @@ class Qwen3_VisionBlock(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend: _Backend = _Backend.TORCH_SDPA,
|
||||
use_upstream_fa: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
@ -170,7 +178,9 @@ class Qwen3_VisionBlock(nn.Module):
|
||||
projection_size=dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_data_parallel=use_data_parallel)
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=attn_backend,
|
||||
use_upstream_fa=use_upstream_fa)
|
||||
self.mlp = Qwen3_VisionMLP(dim,
|
||||
mlp_hidden_dim,
|
||||
act_fn=act_fn,
|
||||
@ -287,19 +297,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
Qwen3_VisionBlock(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel)
|
||||
for layer_idx in range(vision_config.depth)
|
||||
])
|
||||
|
||||
self.merger = Qwen3_VisionPatchMerger(
|
||||
d_model=vision_config.out_hidden_size,
|
||||
context_dim=self.hidden_size,
|
||||
@ -325,10 +322,42 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=head_dim, dtype=torch.get_default_dtype())
|
||||
use_upstream_fa = False
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||
check_upstream_fa_availability(
|
||||
torch.get_default_dtype()):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen3-VL does not support {self.attn_backend} backend now.")
|
||||
if current_platform.is_device_capability(
|
||||
100) and self.attn_backend != _Backend.TORCH_SDPA:
|
||||
# TODO(Roger/Wentao): remove this after FA
|
||||
# or XFORMERS's issue fixed on Blackwell
|
||||
logger.info_once("Qwen3-VL vision attention does not support "
|
||||
f"{self.attn_backend} backend on Blackwell now. "
|
||||
"Vision attention backend is set to TORCH_SDPA.")
|
||||
self.attn_backend = _Backend.TORCH_SDPA
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
Qwen3_VisionBlock(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=self.attn_backend,
|
||||
use_upstream_fa=use_upstream_fa)
|
||||
for layer_idx in range(vision_config.depth)
|
||||
])
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
@ -569,11 +598,16 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
|
||||
image_height: int,
|
||||
num_frames: int = 2,
|
||||
do_resize: bool = True,
|
||||
image_processor: Optional[Qwen2VLImageProcessorFast],
|
||||
image_processor: Optional[Union[Qwen2VLImageProcessorFast,
|
||||
Qwen3VLVideoProcessor]],
|
||||
) -> tuple[ImageSize, int]:
|
||||
if image_processor is None:
|
||||
if image_processor is None and num_frames > 1:
|
||||
image_processor = self.get_video_processor()
|
||||
elif image_processor is None:
|
||||
image_processor = self.get_image_processor()
|
||||
|
||||
is_video = isinstance(image_processor, Qwen3VLVideoProcessor)
|
||||
|
||||
hf_config = self.get_hf_config()
|
||||
vision_config = hf_config.vision_config
|
||||
patch_size = vision_config.patch_size
|
||||
@ -581,12 +615,22 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
|
||||
temporal_patch_size = vision_config.temporal_patch_size
|
||||
|
||||
if do_resize:
|
||||
if is_video:
|
||||
smart_resize = video_smart_resize
|
||||
extra_kwargs = {
|
||||
"num_frames": num_frames,
|
||||
"temporal_factor": temporal_patch_size
|
||||
}
|
||||
else:
|
||||
smart_resize = image_smart_resize
|
||||
extra_kwargs = {}
|
||||
resized_height, resized_width = smart_resize(
|
||||
height=image_height,
|
||||
width=image_width,
|
||||
factor=patch_size * merge_size,
|
||||
min_pixels=image_processor.size["shortest_edge"],
|
||||
max_pixels=image_processor.size["longest_edge"],
|
||||
**extra_kwargs,
|
||||
)
|
||||
preprocessed_size = ImageSize(width=resized_width,
|
||||
height=resized_height)
|
||||
@ -605,6 +649,39 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):

        return preprocessed_size, num_vision_tokens

    def _get_max_video_frames(self,
                              max_tokens: int,
                              start_num_frames: int = 2) -> int:
        return super()._get_max_video_frames(max_tokens,
                                             start_num_frames=start_num_frames)

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        return super().get_num_frames_with_most_features(
            seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        target_width, target_height = self.get_image_size_with_most_features()
        video_soft_tokens = self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(
                seq_len, mm_counts),
            image_processor=None,
        )

        # NOTE: By default in Qwen3-VL, one video token is converted to
        # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token  # noqa: E501
        formatted_video_soft_tokens = video_soft_tokens * 12.5
        return int(formatted_video_soft_tokens)

    def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
                              video_fps: float, merge_size: int):
        if not isinstance(indices, list):
@ -674,6 +751,12 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
            self.info.get_image_size_with_most_features())
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts)
        target_video_size, _ = self.info._get_vision_info(
            image_width=target_width,
            image_height=target_height,
            num_frames=target_num_frames,
            image_processor=self.info.get_video_processor(),
        )
        return {
            "image":
            self._get_dummy_images(width=target_width,
@ -681,8 +764,8 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
                                   num_images=num_images),
            "video":
            self._get_dummy_videos(
                width=target_width,
                height=target_height,
                width=target_video_size.width,
                height=target_video_size.height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            ),
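The `* 12.5` factor in `get_max_video_tokens` above follows directly from the code comment: each video soft token is rendered as a timestamp string averaging about 9.5 text tokens plus the three special tokens (`vision_start`, `video_token`, `vision_end`). A tiny arithmetic check, using only the numbers stated in that comment:

```python
# Budget per video "soft token" after formatting, per the NOTE above:
timestamp_tokens = 9.5   # "<{timestamp} seconds>" averages ~9.5 text tokens
special_tokens = 3       # vision_start + video_token + vision_end
per_token_budget = timestamp_tokens + special_tokens
assert per_token_budget == 12.5

# So, for example, 1,000 raw video soft tokens are budgeted as 12,500 tokens:
print(int(1000 * per_token_budget))  # 12500
```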
@ -212,6 +212,8 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
                    # attempted to load as other weights later
                    is_expert_weight = True
                    name_mapped = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name_mapped, self):
                        continue
                    if is_fused_expert:
                        loaded_weight = loaded_weight.transpose(-1,
                                                                -2)  # no bias
@ -230,8 +232,6 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
                                name_mapped, params_dict, loaded_weight,
                                shard_id, num_experts)
                    else:
                        if is_pp_missing_parameter(name_mapped, self):
                            continue
                        # Skip loading extra parameters for GPTQ/modelopt models
                        if name_mapped.endswith(
                                ignore_suffixes
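For context on the relocated `is_pp_missing_parameter` check above: checkpoint expert weights are first renamed to their fused parameter name, and any name that this pipeline-parallel rank does not own is skipped before the loader branches on fused vs. per-expert layout. A generic, standalone sketch of that filtering pattern — `EXPERT_MAPPING`, `owned_params`, and the weights below are invented for illustration and are not vLLM's actual structures:

```python
import torch

# Hypothetical mapping: (fused param name, checkpoint weight name, shard id)
EXPERT_MAPPING = [
    ("experts.w13_weight", "experts.gate_proj.weight", "w1"),
    ("experts.w13_weight", "experts.up_proj.weight", "w3"),
    ("experts.w2_weight", "experts.down_proj.weight", "w2"),
]


def load_expert_weights(weights: dict[str, torch.Tensor],
                        owned_params: set[str]) -> dict[str, list]:
    """Rename expert weights to their fused parameter name; skip names that
    this (hypothetical) pipeline stage does not own."""
    loaded: dict[str, list] = {}
    for name, tensor in weights.items():
        for param_name, weight_name, shard_id in EXPERT_MAPPING:
            if weight_name not in name:
                continue
            name_mapped = name.replace(weight_name, param_name)
            if name_mapped not in owned_params:  # analogue of the PP check
                continue
            loaded.setdefault(name_mapped, []).append((shard_id, tuple(tensor.shape)))
    return loaded


weights = {
    "layers.0.experts.gate_proj.weight": torch.zeros(8, 16),
    "layers.0.experts.down_proj.weight": torch.zeros(16, 8),
    "layers.9.experts.up_proj.weight": torch.zeros(8, 16),  # not owned here
}
print(load_expert_weights(weights, owned_params={"layers.0.experts.w13_weight",
                                                 "layers.0.experts.w2_weight"}))
```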
@ -13,11 +13,14 @@ from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
from vllm.utils import (cdiv, direct_register_custom_op,
                        get_cuda_view_from_cpu_tensor, is_pin_memory_available,
                        is_uva_available)

logger = init_logger(__name__)
@ -760,3 +763,46 @@ def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
        return hf_config.hidden_size
    text_config = hf_config.get_text_config()
    return text_config.hidden_size


# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.vllm.sequence_parallel_chunk_impl(x)


def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    # all_gather needs the sequence length to be divisible by tp_size
    seq_len = x.size(0)
    remainder = seq_len % tp_size
    if remainder != 0:
        pad_len = tp_size - remainder
        y = nn.functional.pad(x, (0, 0, 0, pad_len))
    else:
        y = x

    chunk = y.shape[0] // tp_size
    start = tp_rank * chunk
    return torch.narrow(y, 0, start, chunk)


def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
    tp_size = get_tensor_model_parallel_world_size()
    seq_len = cdiv(x.size(0), tp_size)
    shape = list(x.shape)
    shape[0] = seq_len
    out = torch.empty(shape, dtype=x.dtype, device=x.device)
    return out


direct_register_custom_op(
    op_name="sequence_parallel_chunk_impl",
    op_func=sequence_parallel_chunk_impl,
    fake_impl=sequence_parallel_chunk_impl_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
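The new `sequence_parallel_chunk` helper above pads the token dimension up to a multiple of the tensor-parallel world size and then takes this rank's contiguous slice, so a later all_gather can reassemble the full padded sequence. A standalone sketch of the same pad-and-narrow arithmetic, with `tp_rank`/`tp_size` passed as plain arguments instead of being read from vLLM's process group:

```python
import torch
import torch.nn.functional as F


def sequence_parallel_chunk_sketch(x: torch.Tensor, tp_rank: int,
                                   tp_size: int) -> torch.Tensor:
    """Pad dim 0 to a multiple of tp_size, then return this rank's chunk."""
    seq_len = x.size(0)
    remainder = seq_len % tp_size
    if remainder != 0:
        # Pad rows at the end so every rank gets an equally sized chunk.
        x = F.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.size(0) // tp_size
    return torch.narrow(x, 0, tp_rank * chunk, chunk)


x = torch.arange(10.0).reshape(5, 2)  # 5 tokens, hidden size 2
for rank in range(4):
    print(rank, sequence_parallel_chunk_sketch(x, rank, tp_size=4).shape)
# Each rank sees a (2, 2) chunk: 5 tokens are padded up to 8 = 4 * 2.
```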
@ -50,6 +50,7 @@ class MediaConnector:
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: Optional[list[str]] = None,
    ) -> None:
        """
        Args:
@ -82,6 +83,9 @@ class MediaConnector:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_
        if allowed_media_domains is None:
            allowed_media_domains = []
        self.allowed_media_domains = allowed_media_domains

    def _load_data_url(
        self,
@ -115,6 +119,14 @@ class MediaConnector:

        return media_io.load_file(filepath)

    def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
        if self.allowed_media_domains and url_spec.hostname not in \
                self.allowed_media_domains:
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{self.allowed_media_domains}. Input URL domain: "
                f"{url_spec.hostname}")

    def load_from_url(
        self,
        url: str,
@ -125,6 +137,8 @@ class MediaConnector:
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

        connection = self.connection
        data = connection.get_bytes(url, timeout=fetch_timeout)

@ -150,6 +164,8 @@ class MediaConnector:
        loop = asyncio.get_running_loop()

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

        connection = self.connection
        data = await connection.async_get_bytes(url, timeout=fetch_timeout)
        future = loop.run_in_executor(global_thread_pool,
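The new `allowed_media_domains` check above compares the parsed URL's hostname against an allow-list before any HTTP(S) fetch, with an empty list meaning no restriction. A minimal standalone sketch of the same check using only `urllib.parse` — the function name and error text here are illustrative, not vLLM's exact ones:

```python
from urllib.parse import urlparse


def check_media_url(url: str, allowed_domains: list[str]) -> None:
    """Raise if `url` is an HTTP(S) URL whose hostname is not allow-listed.
    An empty allow-list means no restriction, matching the diff above."""
    spec = urlparse(url)
    if not spec.scheme.startswith("http"):
        return  # data:/file: URLs go through separate code paths
    if allowed_domains and spec.hostname not in allowed_domains:
        raise ValueError(
            f"URL host {spec.hostname!r} is not in the allowed domains: "
            f"{allowed_domains}")


check_media_url("https://upload.wikimedia.org/x.png",
                ["upload.wikimedia.org", "github.com"])  # ok
try:
    check_media_url("https://evil.internal/x.png", ["upload.wikimedia.org"])
except ValueError as e:
    print(e)
```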
@ -1288,4 +1288,9 @@ class Scheduler(SchedulerInterface):
            self.finished_recving_kv_req_ids.add(req_id)
        for req_id in (kv_connector_output.finished_sending or ()):
            logger.debug("Finished sending KV transfer for request %s", req_id)
            self._free_blocks(self.requests[req_id])
            if req_id not in self.requests:
                logger.warning(
                    "Got finished sending KV transfer for request %s,"
                    "but the request is already freed.", req_id)
            else:
                self._free_blocks(self.requests[req_id])
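The scheduler change above replaces the unconditional `self._free_blocks(self.requests[req_id])` with a membership check, since the finished-sending notification can arrive after the request has already been freed. A tiny generic sketch of the guard pattern (the class below is a stand-in, not vLLM's scheduler):

```python
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("sketch")


class TinyScheduler:
    def __init__(self) -> None:
        self.requests = {"req-1": object()}

    def on_finished_sending(self, req_id: str) -> None:
        if req_id not in self.requests:
            # Late notification: the request was already freed elsewhere.
            logger.warning("finished sending for %s, already freed", req_id)
            return
        del self.requests[req_id]  # stand-in for _free_blocks(...)


s = TinyScheduler()
s.on_finished_sending("req-1")  # frees normally
s.on_finished_sending("req-1")  # logs a warning instead of raising KeyError
```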
@ -3351,6 +3351,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                expected_num_items=max_mm_items_per_batch,
            )

            # NOTE: This happens when encoder cache needs to store
            # the embeddings that encoder outputs are scattered onto.
            # In this case we create dummy embeddings of size
            # (encode_budget, hidden_size) and scatter encoder
            # output into it.
            encoder_output_shape = dummy_encoder_outputs[0].shape
            if encoder_output_shape[0] < encoder_budget:
                expanded_outputs = []
                for output in dummy_encoder_outputs:
                    expanded = output.new_zeros(
                        (encoder_budget, encoder_output_shape[-1]))
                    num_tokens = output.shape[0]
                    expanded[:num_tokens].copy_(output)
                    expanded_outputs.append(expanded)

                dummy_encoder_outputs = expanded_outputs

            # Cache the dummy encoder outputs.
            self.encoder_cache["tmp"] = dict(
                enumerate(dummy_encoder_outputs))
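The expansion loop above zero-pads each dummy encoder output from `(num_tokens, hidden)` up to `(encoder_budget, hidden)` so the cache entry has room for the scattered embeddings. The core tensor manipulation, shown standalone with made-up sizes:

```python
import torch

encoder_budget, hidden_size = 16, 8
dummy_outputs = [torch.randn(5, hidden_size), torch.randn(7, hidden_size)]

expanded_outputs = []
for output in dummy_outputs:
    if output.shape[0] < encoder_budget:
        # Allocate a budget-sized buffer and copy the real rows into it.
        expanded = output.new_zeros((encoder_budget, output.shape[-1]))
        expanded[:output.shape[0]].copy_(output)
        output = expanded
    expanded_outputs.append(output)

print([tuple(o.shape) for o in expanded_outputs])  # [(16, 8), (16, 8)]
```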
@ -3468,8 +3485,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            # We skip EPLB here since we don't want to record dummy metrics
            for num_tokens in compilation_cases:
                # We currently only capture ubatched graphs when its a FULL
                # cudagraph and for uniform decode batches.
                capture_ubatched_graph = self.parallel_config.enable_dbo \
                # cudagraph, a uniform decode batch, and the number of tokens
                # is above the threshold. Otherwise we just capture a non-ubatched
                # version of the graph
                allow_microbatching = self.parallel_config.enable_dbo \
                    and cudagraph_runtime_mode == CUDAGraphMode.FULL \
                    and uniform_decode \
                    and check_ubatch_thresholds(
@ -3478,37 +3497,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                        uniform_decode=uniform_decode,
                    )

                # Currently we capture both microbatched and non-microbatched
                # graphs when capture_ubatched_graph is True, this is because
                # occasionally we will be forced out of microbatching due to other
                # DP ranks not microbatching (usually caused by an empty second
                # microbatch; once we resolve this, we can remove the
                # non-microbatched graph capture).
                allow_microbatching_options = [True, False] if \
                    capture_ubatched_graph else [False]
                for allow_microbatching in allow_microbatching_options:
                    for _ in range(
                            self.compilation_config.cudagraph_num_of_warmups):
                        # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
                        # But be careful, warm up with `NONE`is orthogonal to
                        # if we want to warm up attention or not. This is
                        # different from the case where `FULL` implies capture
                        # attention while `PIECEWISE` implies no attention.
                        force_attention = (
                            cudagraph_runtime_mode == CUDAGraphMode.FULL)
                        self._dummy_run(num_tokens,
                                        cudagraph_runtime_mode=CUDAGraphMode.NONE,
                                        force_attention=force_attention,
                                        uniform_decode=uniform_decode,
                                        allow_microbatching=allow_microbatching,
                                        skip_eplb=True,
                                        remove_lora=False)
                for _ in range(self.compilation_config.cudagraph_num_of_warmups):
                    # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
                    # But be careful, warm up with `NONE`is orthogonal to
                    # if we want to warm up attention or not. This is
                    # different from the case where `FULL` implies capture
                    # attention while `PIECEWISE` implies no attention.
                    force_attention = (
                        cudagraph_runtime_mode == CUDAGraphMode.FULL)
                    self._dummy_run(num_tokens,
                                    cudagraph_runtime_mode=cudagraph_runtime_mode,
                                    cudagraph_runtime_mode=CUDAGraphMode.NONE,
                                    force_attention=force_attention,
                                    uniform_decode=uniform_decode,
                                    allow_microbatching=allow_microbatching,
                                    skip_eplb=True,
                                    remove_lora=False)
                self._dummy_run(num_tokens,
                                cudagraph_runtime_mode=cudagraph_runtime_mode,
                                uniform_decode=uniform_decode,
                                allow_microbatching=allow_microbatching,
                                skip_eplb=True,
                                remove_lora=False)
            self.maybe_remove_all_loras(self.lora_config)

    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
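The rewritten capture loop above collapses the old two-pass logic into a single `allow_microbatching` flag per batch size: the old code looped over both microbatched and non-microbatched variants whenever ubatching was possible, while the new code decides once (DBO enabled, FULL cudagraph, uniform decode, above the token threshold) and passes that flag through warmup and capture. A small sketch of just the decision; all names below are placeholders, not vLLM's config fields:

```python
def capture_variants(enable_dbo: bool, full_cudagraph: bool,
                     uniform_decode: bool, above_threshold: bool,
                     capture_both_when_ubatched: bool = False) -> list[bool]:
    """Return the allow_microbatching values to capture for one batch size.

    Placeholder logic mirroring the diff: microbatching is only allowed for
    FULL cudagraphs on uniform decode batches above the DBO token threshold.
    With capture_both_when_ubatched=True this reproduces the old behaviour of
    also capturing a non-microbatched fallback graph.
    """
    allow = enable_dbo and full_cudagraph and uniform_decode and above_threshold
    if allow and capture_both_when_ubatched:
        return [True, False]  # old: ubatched + non-ubatched fallback
    return [allow]            # new: a single graph per shape


print(capture_variants(True, True, True, True))        # [True]
print(capture_variants(True, True, True, True, True))  # [True, False]
print(capture_variants(False, True, True, True))       # [False]
```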
@ -330,6 +330,18 @@ class UBatchWrapper:

        # If there's no ubatching, just run the runnable object
        if ubatch_slices is None:

            # This is to account for the case where ubatching was aborted.
            # When we capture full graphs we only capture one graph per shape,
            # meaning that if we have a ubatched cudagraph for the current
            # num_tokens, we don't have a non-ubatched one. Without this
            # check, the cudagraph wrapper will try to capture a cudagraph
            # for this shape during a normal run.
            if cudagraph_runtime_mode is CUDAGraphMode.FULL:
                assert batch_descriptor is not None
                if batch_descriptor.num_tokens in self.cudagraphs:
                    cudagraph_runtime_mode = CUDAGraphMode.NONE

            if cudagraph_runtime_mode in (CUDAGraphMode.NONE,
                                          CUDAGraphMode.PIECEWISE):
                return self.runnable(*args, **kwargs)
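The `UBatchWrapper` change above handles the case where ubatching was aborted at runtime: if a FULL cudagraph was already captured for this token count (it would be the ubatched one), the runtime mode is downgraded to NONE so the inner cudagraph wrapper does not try to capture a second, non-ubatched graph for the same shape. A minimal sketch of that dispatch decision with plain Python stand-ins for the enum and the captured-graph table:

```python
from enum import Enum, auto


class CUDAGraphModeSketch(Enum):
    NONE = auto()
    PIECEWISE = auto()
    FULL = auto()


def resolve_runtime_mode(mode: CUDAGraphModeSketch, num_tokens: int,
                         captured_ubatched_shapes: set[int]) -> CUDAGraphModeSketch:
    """If ubatching was aborted but a ubatched FULL graph already exists for
    this shape, fall back to eager (NONE) instead of re-capturing."""
    if mode is CUDAGraphModeSketch.FULL and num_tokens in captured_ubatched_shapes:
        return CUDAGraphModeSketch.NONE
    return mode


print(resolve_runtime_mode(CUDAGraphModeSketch.FULL, 256, {256}))  # NONE
print(resolve_runtime_mode(CUDAGraphModeSketch.FULL, 128, {256}))  # FULL
```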