Compare commits

...

2 Commits

Author SHA1 Message Date
296c6572dd Revert "[V1] DP scale-out (1/N): Use zmq ROUTER/DEALER sockets for input queue (#15906)"
This reverts commit 651cf0fec19ed1da44ad266066f86b74e5246006.
2025-04-05 21:10:57 -07:00
c575232395 [Model] Support Llama4 in vLLM (#16104) 2025-04-05 21:01:00 -07:00
39 changed files with 2438 additions and 255 deletions

View File

@ -553,6 +553,9 @@ def main(args: argparse.Namespace):
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
if not hasattr(config, "hidden_size"):
# Support for llama4
config = config.text_config
# Default: Mixtral.
E = config.num_local_experts
topk = config.num_experts_per_tok

View File

@ -24,7 +24,7 @@ vLLM also supports model implementations that are available in Transformers. Thi
To check if the modeling backend is Transformers, you can simply do this:
```python
from vllm import LLM
llm = LLM(model=..., task="generate") # Name or path of your model
llm.apply_model(lambda model: print(type(model)))
@ -55,7 +55,7 @@ If your model is neither supported natively by vLLM nor Transformers, you can sti
Simply set `trust_remote_code=True` and vLLM will run any model on the Model Hub that is compatible with Transformers.
Provided that the model writer implements their model in a compatible way, this means that you can run new models before they are officially supported in Transformers or vLLM!
```python
from vllm import LLM
llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model
llm.apply_model(lambda model: print(model.__class__))
@ -840,6 +840,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
- * `Llama4ForConditionalGeneration`
* Llama-4-17B-Omni-Instruct
* T + I<sup>+</sup>
* `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc.
*
*
* ✅︎
- * `LlavaForConditionalGeneration`
* LLaVA-1.5
* T + I<sup>E+</sup>
@ -982,10 +989,10 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
:::
<sup>^</sup> You need to set the architecture name via `--hf-overrides` to match the one in vLLM.
&nbsp;&nbsp;&nbsp;&nbsp;• For example, to use DeepSeek-VL2 series models:
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;`--hf-overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'`
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.
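A quick way to confirm the new entry resolves to vLLM's Llama 4 implementation (a sketch using the same `apply_model` pattern shown above; the multi-GPU settings mirror the examples elsewhere in this PR):
```python
from vllm import LLM

llm = LLM(model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
          max_model_len=8192,
          tensor_parallel_size=8)
llm.apply_model(lambda model: print(type(model)))  # expect the vLLM Llama 4 class
```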
:::{important}

View File

@ -582,6 +582,42 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
)
def run_llama4(questions: list[str], modality: str):
assert modality == "image"
model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=4,
tensor_parallel_size=8,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
gpu_memory_utilization=0.4,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [[{
"role":
"user",
"content": [{
"type": "image"
}, {
"type": "text",
"text": f"{question}"
}]
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
add_generation_prompt=True,
tokenize=False)
stop_token_ids = None
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
stop_token_ids=stop_token_ids,
)
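For reference, a rough sketch of how the returned `ModelRequestData` is consumed (this assumes the surrounding script's argparse `args` is in scope for `run_llama4`, and the image path is illustrative):
```python
from dataclasses import asdict

from PIL import Image
from vllm import LLM, SamplingParams

req = run_llama4(["What is shown in this image?"], modality="image")
llm = LLM(**asdict(req.engine_args))
image = Image.open("example.jpg")  # any local test image
outputs = llm.generate(
    {"prompt": req.prompts[0], "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.2, max_tokens=64,
                   stop_token_ids=req.stop_token_ids),
)
print(outputs[0].outputs[0].text)
```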
# Molmo
def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@ -907,6 +943,7 @@ model_example_map = {
"minicpmv": run_minicpmv,
"mistral3": run_mistral3,
"mllama": run_mllama,
"llama4": run_llama4,
"molmo": run_molmo,
"NVLM_D": run_nvlm_d,
"paligemma": run_paligemma,

View File

@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
)
def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=4,
tensor_parallel_size=8,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
)
def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
@ -567,6 +604,7 @@ model_example_map = {
"h2ovl_chat": load_h2ovl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
"llama4": load_llama4,
"mistral3": load_mistral3,
"mllama": load_mllama,
"NVLM_D": load_nvlm_d,

View File

@ -6,7 +6,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers >= 4.50.3
transformers >= 4.51.0
huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.

View File

@ -30,7 +30,7 @@ mistral_common[opencv] >= 1.5.4 # required for pixtral test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
transformers==4.50.3
transformers==4.51.0
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
# quantization
bitsandbytes>=0.45.3

View File

@ -645,7 +645,7 @@ tqdm==4.66.6
# transformers
tqdm-multiprocess==0.0.11
# via lm-eval
transformers==4.50.3
transformers==4.51.0
# via
# -r requirements/test.in
# genai-perf

View File

@ -536,6 +536,22 @@ VLM_TEST_SETTINGS = {
limit_mm_per_prompt={"image": 1},
)],
),
"llama4": VLMTestInfo(
models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"],
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501
img_idx_to_prompt=lambda _: "<|image|>",
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
distributed_executor_backend="mp",
image_size_factors=[(.25, 0.5, 1.0)],
hf_model_kwargs={"device_map": "auto"},
max_model_len=8192,
max_num_seqs=4,
dtype="bfloat16",
auto_cls=AutoModelForImageTextToText,
tensor_parallel_size=8,
vllm_runner_kwargs={"gpu_memory_utilization": 0.8},
marks=[large_gpu_mark(min_gb=80), multi_gpu_marks(num_gpus=8)],
),
}
# yapf: enable

View File

@ -280,6 +280,7 @@ def _test_processing_correctness_mistral(
"Skywork/Skywork-R1V-38B",
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])

View File

@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for Llama4's multimodal preprocessing kwargs."""
import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.transformers_utils.tokenizer import encode_tokens
from ....conftest import _ImageAssets
from ...utils import build_model_context
@pytest.mark.parametrize("model_id",
["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
@pytest.mark.parametrize("mm_processor_kwargs", [{}])
@pytest.mark.parametrize("num_imgs", [1, 5])
@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False])
@pytest.mark.parametrize("tokenized_prompt", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
model_id: str,
mm_processor_kwargs: dict,
num_imgs: int,
disable_mm_preprocessor_cache: bool,
tokenized_prompt: bool,
):
"""Ensure llama4 processor works properly."""
ctx = build_model_context(
model_id,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": num_imgs},
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
config = processor.info.get_hf_config()
tokenizer = processor.info.get_tokenizer()
hf_processor = processor.info.get_hf_processor()
vocab = tokenizer.get_vocab()
prompt = "<|begin_of_text|><|header_start|>user<|header_end|>" \
+ "<|image|>" * num_imgs \
+ "<|eot|><|header_start|>assistant<|header_end|>"
mm_data = {
"image": [
image_assets[(i % len(image_assets))].pil_image
for i in range(num_imgs)
]
}
if tokenized_prompt:
prompt = encode_tokens(tokenizer, prompt)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
mm_kwargs = processed_inputs["mm_kwargs"]
# placeholder replacements
prompt_token_ids = processed_inputs["prompt_token_ids"]
assert prompt_token_ids.count(config.boi_token_index) == num_imgs
assert prompt_token_ids.count(config.eoi_token_index) == num_imgs
assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs
aspect_ratios = mm_kwargs["aspect_ratios"]
num_x_separators = num_y_separators = 0
for tiles_y, tiles_x in aspect_ratios:
if tiles_x * tiles_y > 1:
num_x_separators += (tiles_x - 1) * tiles_y
num_y_separators += tiles_y
assert prompt_token_ids.count(vocab[hf_processor.tile_token]) \
== num_x_separators
assert prompt_token_ids.count(vocab[hf_processor.tile_global_token]) \
== num_y_separators
# image token offsets
img_locs = processed_inputs["mm_placeholders"].get("image", [])
assert len(img_locs) == num_imgs
assert [img_loc["offset"] for img_loc in img_locs] == \
[i for i, v in enumerate(prompt_token_ids) \
if v == config.boi_token_index]
# patch sizes and masks
assert prompt_token_ids.count(config.image_token_index) \
== sum(img_patch.sum() for img_patch in mm_kwargs["embed_is_patch"])
patch_token_id = vocab[hf_processor.img_patch_token]
num_patches = processed_inputs["prompt_token_ids"].count(patch_token_id)
mm_counts = {"image": num_imgs}
assert num_patches / num_imgs <= \
processor.info.get_mm_max_tokens_per_item(32768, mm_counts)["image"]
num_patches_per_chunk = processor.info.get_patch_per_chunk(
config.vision_config)
assert prompt_token_ids.count(config.image_token_index) \
== mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk
assert mm_kwargs["pixel_values"].shape[0] \
== mm_kwargs["patches_per_image"].sum()
for embed_is_patch, aspect_ratio in zip(mm_kwargs["embed_is_patch"],
mm_kwargs["aspect_ratios"]):
assert embed_is_patch.shape[0] == \
len(tokenizer.encode(
hf_processor._prompt_split_image(
aspect_ratio, num_patches_per_chunk),
add_special_tokens=False))

View File

@ -337,6 +337,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
tokenizer="facebook/bart-base",
trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
}

View File

@ -23,6 +23,11 @@ from .registry import HF_EXAMPLE_MODELS
@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch):
# Llama4ForCausalLM does not have a standalone model
if model_arch == "Llama4ForCausalLM":
return
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_transformers_version(on_fail="skip")
@ -91,8 +96,11 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda):
def test_hf_registry_coverage():
untested_archs = (ModelRegistry.get_supported_archs() -
HF_EXAMPLE_MODELS.get_supported_archs())
untested_archs = set(ModelRegistry.get_supported_archs() -
HF_EXAMPLE_MODELS.get_supported_archs())
# Llama4ForCausalLM does not have a standalone model
untested_archs.discard("Llama4ForCausalLM")
assert not untested_archs, (
"Please add the following architectures to "

View File

@ -354,6 +354,8 @@ class ModelConfig:
self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(self.hf_config)
self.attention_chunk_size = getattr(self.hf_text_config,
"attention_chunk_size", None)
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)

View File

@ -500,7 +500,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
"internvl_chat", "skywork_chat", "NVLM_D",
"h2ovl_chat"):
return "<image>"
if model_type == "mllama":
if model_type in ("mllama", "llama4"):
return "<|image|>"
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|image_pad|><|vision_end|>"

View File

@ -0,0 +1,200 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
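For context, this tuned-kernel config is keyed by token count (M); at runtime vLLM's fused MoE path looks up the entry whose key is closest to the current batch size. A minimal sketch of that lookup (the selection rule mirrors `fused_moe.try_get_optimal_moe_config` at the time of writing; the file path is a placeholder since the diff view omits it):
```python
import json

def pick_moe_config(path: str, num_tokens: int) -> dict:
    """Return the tuned kernel config whose token-count key is nearest."""
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    return configs[min(configs, key=lambda k: abs(k - num_tokens))]

# e.g. a batch of 30 tokens picks the "32" entry above
# pick_moe_config("<tuned_config>.json", 30)
```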

View File

@ -23,6 +23,7 @@ def cutlass_moe_fp8(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.half,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
This function computes an a8w8-quantized Mixture of Experts (MoE) layer
@ -96,8 +97,14 @@ def cutlass_moe_fp8(
n = w2_q.size(1)
topk = topk_ids.size(1)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
if apply_router_weight_on_input:
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
# TODO: this only works for topK=1, will need to update for topK>1
a = a * topk_weights.to(out_dtype)
a_q, a1_scale = ops.scaled_fp8_quant(
a, a1_scale, use_per_token_if_dynamic=per_act_token)
@ -139,6 +146,8 @@ def cutlass_moe_fp8(
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
expert_offsets[:-1], problem_sizes2, ab_strides2,
ab_strides2, c_strides2)
return (c2[c_map].view(m, topk, k) *
topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
# Gather tokens
c2 = c2[c_map].view(m, topk, k)
if not apply_router_weight_on_input:
c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
return c2.sum(dim=1)
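The intent behind `apply_router_weight_on_input`, in a small numerical sketch (not the kernel itself): for top-1 routing, scaling the activations before the expert GEMMs is equivalent to scaling the expert outputs afterwards, so the router weight is applied exactly once either way.
```python
import torch

m, k, n, topk = 4, 8, 8, 1
a = torch.randn(m, k)
w = torch.randn(k, n)                  # stand-in for a single expert's weight
topk_weights = torch.rand(m, topk)

out_on_input = (a * topk_weights) @ w  # weight folded into the input (topk == 1)
out_on_output = ((a @ w).view(m, topk, n) *
                 topk_weights.view(m, topk, 1)).sum(dim=1)
assert torch.allclose(out_on_input, out_on_output, atol=1e-5)
```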

View File

@ -954,6 +954,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
@ -967,10 +968,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape)
def inplace_fused_experts_fake(
@ -980,6 +981,7 @@ def inplace_fused_experts_fake(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
@ -1010,6 +1012,7 @@ def outplace_fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
@ -1023,10 +1026,11 @@ def outplace_fused_experts(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape)
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
def outplace_fused_experts_fake(
@ -1084,6 +1088,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
@ -1099,6 +1104,7 @@ def fused_experts(hidden_states: torch.Tensor,
allow_deep_gemm: bool = False) -> torch.Tensor:
if (allow_deep_gemm and use_fp8_w8a8
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8(
hidden_states=hidden_states,
w1=w1,
@ -1122,6 +1128,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
@ -1143,6 +1150,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
@ -1270,7 +1278,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
apply_router_weight_on_input,
top_k_num,
config,
compute_type=compute_type,
@ -1307,7 +1315,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,

View File

@ -65,7 +65,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
raise NotImplementedError
@ -156,22 +158,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
return self.forward(x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
global_num_experts=global_num_experts,
expert_map=expert_map,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
activation=activation)
return self.forward(
x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
global_num_experts=global_num_experts,
expert_map=expert_map,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
def forward_cuda(
self,
@ -188,6 +193,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts(
@ -202,15 +208,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)
def forward_cpu(
self,
@ -228,9 +236,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
**kwargs,
):
assert activation == "silu", f"{activation} is not supported."
assert apply_router_weight_on_input is False
return layer.ipex_fusion(
x,
use_grouped_topk,
@ -259,6 +269,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert not use_grouped_topk
@ -266,6 +277,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
assert topk_group is None
assert custom_routing_function is None
assert layer is not None
assert apply_router_weight_on_input is False
if scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for HPU.")
@ -290,12 +302,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
assert custom_routing_function is None
assert apply_router_weight_on_input is False
if scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for TPU.")
@ -401,6 +415,7 @@ class FusedMoE(torch.nn.Module):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
):
super().__init__()
@ -486,6 +501,7 @@ class FusedMoE(torch.nn.Module):
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None
self.apply_router_weight_on_input = apply_router_weight_on_input
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
@ -853,6 +869,7 @@ class FusedMoE(torch.nn.Module):
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
)
if self.dp_size > 1:

View File

@ -92,6 +92,7 @@ class RMSNorm(CustomOp):
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()
@ -100,8 +101,10 @@ class RMSNorm(CustomOp):
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)
self.has_weight = has_weight
self.weight = torch.ones(hidden_size)
if dtype is not None:
self.weight = torch.ones(hidden_size, dtype=dtype)
else:
self.weight = torch.ones(hidden_size)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
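The new `dtype` argument is used by the Llama 4 attention QK-norm later in this PR, which keeps the weightless norm in float32; a minimal sketch of that pattern (sizes are illustrative):
```python
import torch

from vllm.model_executor.layers.layernorm import RMSNorm

# Weightless float32 RMSNorm, as used for Llama 4 QK normalization.
q_norm = RMSNorm(hidden_size=128, eps=1e-5, has_weight=False, dtype=torch.float32)
q = torch.randn(2, 128, dtype=torch.bfloat16)
q = q_norm(q.float()).to(q.dtype)
```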

View File

@ -469,6 +469,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
@ -476,6 +477,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
raise NotImplementedError(
"Expert Parallelism is not supported for "
"fused Marlin MoE method.")
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,

View File

@ -224,6 +224,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
@ -240,20 +241,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
@ -438,6 +441,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
@ -474,6 +478,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
out_dtype=x.dtype,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@ -778,6 +783,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
@ -785,6 +791,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
raise NotImplementedError(
"Expert Parallelism is not supported for "
"fused Marlin MoE method.")
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for "
"fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,

View File

@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
@ -129,18 +130,20 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int8_w8a16=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int8_w8a16=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale)
@staticmethod
def quantizing_weight_loader(layer, weight_loader):

View File

@ -773,6 +773,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
@ -800,6 +801,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),

View File

@ -338,9 +338,15 @@ class GGUFMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
):
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused GGUF MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,

View File

@ -592,9 +592,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
# The input must currently be float16
orig_dtype = x.dtype

View File

@ -293,6 +293,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
@ -312,21 +313,23 @@ class MoeWNA16Method(FusedMoEMethodBase):
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
return fused_experts(x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size])
return fused_experts(
x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size])
@staticmethod
def get_weight_loader(layer, weight_loader):

View File

@ -202,6 +202,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
@ -217,16 +219,18 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)

View File

@ -851,6 +851,70 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
return new_freqs
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
):
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
inv_freqs = super()._compute_inv_freq(base)
inv_freqs = inv_freqs[:(self.rotary_dim // 2)]
return inv_freqs
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
# self.max_position_embeddings here is number of image patches
# i.e. (image_size // patch_size) ** 2
num_patches = self.max_position_embeddings
img_idx = torch.arange(num_patches,
dtype=torch.int32) \
.reshape(num_patches, 1)
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN
num_patches_single_dim = int(math.sqrt(num_patches))
frequencies_x = img_idx % num_patches_single_dim
frequencies_y = img_idx // num_patches_single_dim
freqs_x = ((frequencies_x + 1)[..., None] *
inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
freqs_y = ((frequencies_y + 1)[..., None] *
inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
freqs = torch.cat([freqs_x, freqs_y],
dim=-1).float().contiguous()[..., ::2]
freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
cache = torch.view_as_complex(
torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
return cache
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
query_ = torch.view_as_complex(query.float().reshape(
*query.shape[:-1], -1, 2))
key_ = torch.view_as_complex(key.float().reshape(
*key.shape[:-1], -1, 2))
broadcast_shape = [
d if i == 1 or i == (query_.ndim - 1) else 1
for i, d in enumerate(query_.shape)
]
freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
return query_out.type_as(query), key_out.type_as(key)
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
@ -1130,6 +1194,10 @@ def get_rope(
scaling_factor, low_freq_factor,
high_freq_factor,
original_max_position)
elif scaling_type == "mllama4":
rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
rotary_emb = MRotaryEmbedding(

View File

@ -65,6 +65,7 @@ class LlamaMLP(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@ -79,6 +80,7 @@ class LlamaMLP(nn.Module):
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
@ -466,10 +468,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm"
"norm": "model.norm",
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
@ -478,7 +484,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
prefix=maybe_prefix(prefix, "model"),
layer_type=layer_type)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
@ -513,8 +520,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return LlamaModel(vllm_config=vllm_config, prefix=prefix)
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
return LlamaModel(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

View File

@ -0,0 +1,530 @@
# SPDX-License-Identifier: Apache-2.0
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type
import torch
from torch import nn
from transformers import Llama4TextConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter)
class Llama4MoE(nn.Module):
@staticmethod
def custom_routing_function(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
router_scores, router_indices = torch.topk(gating_output, topk, dim=-1)
router_scores = torch.sigmoid(router_scores.float()).to(
hidden_states.dtype)
return (router_scores, router_indices.to(torch.int32))
def __init__(self,
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok
intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear(config.hidden_size,
config.num_local_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.router")
self.experts = FusedMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
custom_routing_function=Llama4MoE.custom_routing_function,
intermediate_size=intermediate_size_moe,
apply_router_weight_on_input=True,
reduce_results=False,
renormalize=False,
quant_config=quant_config,
prefix=f"{prefix}.experts")
self.shared_expert = LlamaMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size_moe,
hidden_act="silu",
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.shared_expert",
reduce_results=False, # We need to do scatter before reduce
)
def forward(self, hidden_states):
router_logits, _ = self.router(hidden_states)
shared_out = self.shared_expert(hidden_states)
routed_out = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
experts_out = routed_out + shared_out
if self.tp_size > 1:
experts_out = tensor_model_parallel_all_reduce(experts_out)
return experts_out
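For intuition, `custom_routing_function` above departs from the usual softmax-over-all-experts routing: it takes the top-k of the raw logits first and then applies a sigmoid, so the scores are not normalized across experts. A standalone numerical sketch:
```python
import torch

gating_output = torch.tensor([[2.0, -1.0, 0.5, 3.0]])  # logits for 4 experts
scores, indices = torch.topk(gating_output, k=1, dim=-1)
scores = torch.sigmoid(scores.float())
print(indices.tolist(), scores.tolist())  # expert 3 chosen, weight sigmoid(3.0) ~= 0.95
```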
class Llama4Attention(nn.Module):
def __init__(self,
config: Llama4TextConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
bias_o_proj: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.layer_idx = extract_layer_index(prefix)
self.hidden_size = hidden_size
self.no_rope_layers = config.no_rope_layers
self.nope = self.no_rope_layers[self.layer_idx] == 0
self.use_qk_norm = config.use_qk_norm and not self.nope
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
# TODO: attn_temperature_tuning should be a bool in huggingface
self.attn_temperature_tuning = self.nope and \
config.attn_temperature_tuning > 0
self.floor_scale = getattr(config, "floor_scale", 8192.0)
self.attn_scale = getattr(config, "attn_scale", 0.1)
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.n_rep = self.num_heads // self.num_kv_heads
self.q_norm = RMSNorm(
hidden_size=self.q_size,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
) if self.use_qk_norm else None
self.k_norm = RMSNorm(
hidden_size=self.kv_size,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
) if self.use_qk_norm else None
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias_o_proj,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
is_neox_style = True
is_gguf = quant_config and quant_config.get_name() == "gguf"
if is_gguf and config.model_type == "llama":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=int(rope_theta),
rope_scaling=rope_scaling if rope_scaling != "default" else None,
is_neox_style=is_neox_style,
) if not self.nope else None
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=None,
use_irope=not self.nope,
prefix=f"{prefix}.attn",
)
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
floor = torch.floor((positions + 1.0) / self.floor_scale)
attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
return attn_scale.unsqueeze(-1)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.rotary_emb is not None:
q, k = self.rotary_emb(positions, q, k)
if self.q_norm is not None:
q = self.q_norm(q.float()).to(q.dtype)
if self.k_norm is not None:
k = self.k_norm(k.float()).to(k.dtype)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
# to NoPE layers, where the inference-time temperature tuning function
# is customized to not affect short context
# while working at very long context.
#
# We should apply temperature tuning between (after) rotary / QK norm
# and (before) attention.
if self.attn_temperature_tuning and self.nope:
attn_scale = self._get_attn_scale(positions)
q = (q * attn_scale).to(q.dtype)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
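To see why the temperature tuning above "does not affect short context": with the defaults `floor_scale=8192` and `attn_scale=0.1`, positions below 8191 give a multiplier of exactly 1.0, and it grows only logarithmically afterwards. A quick numeric check:
```python
import torch

floor_scale, attn_scale = 8192.0, 0.1
positions = torch.tensor([0, 4096, 8191, 81919, 1048575])
floor = torch.floor((positions + 1.0) / floor_scale)
scale = torch.log(floor + 1.0) * attn_scale + 1.0
print(scale)  # ~[1.00, 1.00, 1.07, 1.24, 1.49]
```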
class Llama4DecoderLayer(LlamaDecoderLayer):
def __init__(
self,
config: Llama4TextConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
self.layer_idx = extract_layer_index(prefix)
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
max_position_embeddings = config.max_position_embeddings
self.self_attn = Llama4Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=False,
bias_o_proj=False,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
is_moe_layer = (self.layer_idx +
1) % config.interleave_moe_layer_step == 0
if is_moe_layer:
self.feed_forward = Llama4MoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
else:
self.feed_forward = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size_mlp,
hidden_act="silu",
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.feed_forward",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
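The MoE/dense interleave above is periodic in the layer index; for example, with `interleave_moe_layer_step=2`, layers 1, 3, 5, ... (0-indexed) get `Llama4MoE` and the rest a dense `LlamaMLP`. A quick check of the predicate:
```python
interleave_moe_layer_step = 2
moe_layers = [i for i in range(8)
              if (i + 1) % interleave_moe_layer_step == 0]
print(moe_layers)  # [1, 3, 5, 7]
```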
@support_torch_compile
class Llama4Model(LlamaModel):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer):
self.num_experts = vllm_config.model_config.hf_config.num_local_experts
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def load_moe_expert_weights(
self,
name: str,
loaded_weight: torch.Tensor,
params_dict: Dict[str, nn.Parameter],
loaded_params: Set[str],
expert_params_mapping: List[Tuple[str, str, int, str]],
fused: bool = True,
) -> bool:
expert_param_loaded = False
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-1)
for (param_name, weight_name, expert_id,
shard_id) in expert_params_mapping:
new_loaded_weight = loaded_weight
if fused:
e_str, _, proj_str, _ = weight_name.split('.')
weight_name = f"{e_str}.{proj_str}"
param_name = f"{param_name}weight"
if weight_name not in name:
continue
full_param_name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[full_param_name]
weight_loader = param.weight_loader
if fused:
if "w13" in full_param_name:
shard_idx = 0 if shard_id == "w1" else 1
new_loaded_weight = new_loaded_weight[shard_idx]
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
layer_idx = extract_layer_index(name)
# EP mapping
expert_map = self.layers[
layer_idx].feed_forward.experts.expert_map
if expert_map is not None:
local_expert_indices = (expert_map != -1) \
.nonzero() \
.flatten() \
.to(new_loaded_weight.device)
new_loaded_weight = new_loaded_weight[local_expert_indices]
expert_id = local_expert_indices[0].item()
else:
# TODO: add EP support for non fused weights
pass
weight_loader(param,
new_loaded_weight,
full_param_name,
shard_id=shard_id,
expert_id=expert_id)
loaded_params.add(full_param_name)
expert_param_loaded = True
return expert_param_loaded
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
fused_experts_params = False
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.num_experts)
expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="gate_up_proj",
num_experts=1)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
fused_experts_params = True
expert_params_mapping = expert_params_mapping_fused
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name or "experts" in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
else:
moe_loaded = self.load_moe_expert_weights(
name,
loaded_weight,
params_dict,
loaded_params,
expert_params_mapping,
fused=fused_experts_params)
if not moe_loaded:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Llama4ForCausalLM(LlamaForCausalLM):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Update temperature tuning config from generation config
gen_config = vllm_config.model_config.try_get_generation_config()
gen_config.update(vllm_config.model_config.override_generation_config)
vllm_config.model_config.hf_config.attn_temperature_tuning \
= gen_config.get("attn_temperature_tuning", False)
LlamaForCausalLM.__init__(self,
vllm_config=vllm_config,
prefix=prefix,
layer_type=Llama4DecoderLayer)
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer):
return Llama4Model(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
weights = [
self.permute_qk_weight_for_rotary(name, loaded_weight)
for name, loaded_weight in weights
]
return loader.load_weights(weights)
def permute_qk_weight_for_rotary(
self,
name: str,
loaded_weight: torch.Tensor,
) -> Tuple[str, torch.Tensor]:
def permute(w: torch.Tensor, n_heads: int):
attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size
return w.view(n_heads, attn_in // n_heads // 2, 2,
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
modules = name.split(".")
# rotary embeds should be sliced
if ("wk" in modules or "k_proj" in modules) \
and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)
elif ("wq" in modules or "q_proj" in modules) \
and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)
return name, loaded_weight
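What the permutation above does per head: rows stored as interleaved (even, odd) rotary pairs are regrouped into first-half/second-half order, presumably to match the rotary layout vLLM expects. A tiny index-level check with illustrative sizes:
```python
import torch

n_heads, head_dim, hidden = 2, 4, 6
# Row i of w is filled with the value i so we can track where it moves.
w = (torch.arange(n_heads * head_dim).repeat_interleave(hidden)
     .float().reshape(n_heads * head_dim, hidden))

permuted = (w.view(n_heads, head_dim // 2, 2, hidden)
            .transpose(1, 2).reshape(n_heads * head_dim, hidden))
print(permuted[:, 0].tolist())  # [0, 2, 1, 3, 4, 6, 5, 7]: evens first, then odds, per head
```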

View File

@ -0,0 +1,886 @@
# SPDX-License-Identifier: Apache-2.0
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Iterable, Mapping
from itertools import tee
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
from torch import nn
from transformers import BatchFeature, Llama4Config, Llama4VisionConfig
from transformers.image_utils import SizeDict
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.llama4 import Llama4Processor
from transformers.models.llama4.image_processing_llama4_fast import (
find_supported_resolutions, get_best_fit)
from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
logger = init_logger(__name__)
class Llama4ImagePatchInputs(TypedDict):
type: Literal["pixel_values"]
flat_data: torch.Tensor
"""
Shape:
`(batch_size * num_chunks, num_channels, image size, image size)`
"""
patches_per_image: torch.Tensor
"""
The total number of patches for each image in the batch.
This is used to split the embeddings, which have their first two
dimensions flattened just like `flat_data`.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
"""
aspect_ratios: Union[torch.Tensor, list[torch.Tensor]]
"""
A list of aspect ratios corresponding to the number of tiles
in each dimension that each image in the batch corresponds to.
Shape:
`(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)`
"""
class Llama4VisionMLP(nn.Module):
def __init__(self,
input_size: int,
intermediate_size: int,
output_size: int,
bias: bool,
output_activation: bool,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.fc1 = ColumnParallelLinear(
input_size=input_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
input_size=intermediate_size,
output_size=output_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
self.activation_fn = nn.GELU()
self.output_activation = output_activation
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
if self.output_activation:
return self.activation_fn(hidden_states)
return hidden_states
class Llama4MultiModalProjector(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.linear_1 = ColumnParallelLinear(
input_size=config.vision_config.vision_output_dim,
output_size=config.text_config.hidden_size,
bias=False,
quant_config=quant_config,
gather_output=True,
prefix=f"{prefix}.linear_1",
)
def forward(self, image_features):
hidden_states, _ = self.linear_1(image_features)
return hidden_states
def pixel_shuffle(input_tensor, shuffle_ratio):
# input_tensor: [batch_size, num_patches, channels]
batch_size, num_patches, channels = input_tensor.shape
patch_size = int(math.sqrt(num_patches))
input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
batch_size, height, width, channels = input_tensor.size()
reshaped_tensor = input_tensor.view(batch_size, height,
int(width * shuffle_ratio),
int(channels / shuffle_ratio))
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
reshaped_tensor = reshaped_tensor.view(batch_size,
int(height * shuffle_ratio),
int(width * shuffle_ratio),
int(channels / (shuffle_ratio**2)))
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
output_tensor = reshaped_tensor.view(batch_size, -1,
reshaped_tensor.shape[-1])
return output_tensor
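# A minimal shape sketch for `pixel_shuffle` above (the 2 x 16 x 8 input and
# shuffle_ratio=0.5 are arbitrary illustration values): the number of patches
# shrinks by a factor of 1 / shuffle_ratio**2 while the channel dimension
# grows by the same factor, so the total element count is preserved.
def _pixel_shuffle_shape_example() -> torch.Tensor:
    x = torch.randn(2, 16, 8)  # [batch_size, num_patches, channels]
    y = pixel_shuffle(x, shuffle_ratio=0.5)
    assert y.shape == (2, 4, 32)
    return y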
class Llama4VisionPixelShuffleMLP(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
self.inner_dim = int(config.projector_input_dim //
(self.pixel_shuffle_ratio**2))
self.output_dim = config.projector_output_dim
self.mlp = Llama4VisionMLP(
input_size=config.intermediate_size,
intermediate_size=config.projector_input_dim,
output_size=config.projector_output_dim,
bias=config.multi_modal_projector_bias,
output_activation=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
encoded_patches = pixel_shuffle(encoded_patches,
self.pixel_shuffle_ratio)
return self.mlp(encoded_patches)
class Llama4VisionAttention(nn.Module):
def __init__(
self,
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
):
super().__init__()
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // self.num_heads
assert self.num_heads % self.tp_size == 0
self.num_local_heads = self.num_heads // self.tp_size
self.q_size = self.num_local_heads * self.head_dim
self.kv_size = self.num_local_heads * self.head_dim
self.attention_dropout = config.attention_dropout
self.scaling = self.head_dim**-0.5
self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
self.scaling)
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=True,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=config.hidden_size // config.num_attention_heads // 2,
# number of image patches
max_position=(config.image_size // config.patch_size)**2,
base=config.rope_theta,
rope_scaling={"rope_type": "mllama4"},
is_neox_style=False,
dtype=torch.complex64, # important
)
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
input_shape = hidden_states.shape[:-1]
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.view(q.shape[0], q.shape[1], self.num_local_heads, self.head_dim)
k = k.view(k.shape[0], k.shape[1], self.num_local_heads, self.head_dim)
q, k = self.rotary_emb(q, k)
q = q.view(q.shape[0], q.shape[1], -1)
k = k.view(k.shape[0], k.shape[1], -1)
attn_output = self.attn(q, k, v)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output, _ = self.o_proj(attn_output)
return attn_output
class Llama4VisionEncoderLayer(nn.Module):
def __init__(
self,
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.intermediate_size = config.intermediate_size
self.self_attn = Llama4VisionAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Llama4VisionMLP(input_size=config.hidden_size,
intermediate_size=config.intermediate_size,
output_size=config.hidden_size,
bias=True,
output_activation=False,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = nn.LayerNorm(config.hidden_size)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
def forward(
self,
hidden_state: torch.Tensor,
):
# Self Attention
residual = hidden_state
hidden_state = self.input_layernorm(hidden_state)
hidden_state = self.self_attn(hidden_state)
hidden_state = residual + hidden_state
# Feed forward
residual = hidden_state
hidden_state = self.post_attention_layernorm(hidden_state)
hidden_state = self.mlp(hidden_state)
hidden_state = residual + hidden_state
outputs = (hidden_state, )
return outputs
class Llama4VisionEncoder(nn.Module):
def __init__(
self,
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
):
super().__init__()
self.config = config
self.layers = nn.ModuleList([
Llama4VisionEncoderLayer(
config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
) for layer_idx in range(config.num_hidden_layers)
])
def forward(
self,
hidden_states: torch.Tensor,
) -> BaseModelOutput:
r"""
Args:
hidden_states (`torch.FloatTensor` of shape
    `(batch_size, sequence_length, hidden_size)`):
    Embedded image patches to encode (patch embeddings plus class and
    positional embeddings). The vision encoder operates directly on
    these vectors; there are no `input_ids`.
"""
for encoder_layer in self.layers:
layer_outputs = encoder_layer(hidden_states)
hidden_states = layer_outputs[0]
return BaseModelOutput(last_hidden_state=hidden_states, )
class Llama4UnfoldConvolution(nn.Module):
def __init__(self,
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
kernel_size = config.patch_size
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
stride=config.patch_size)
self.linear = ColumnParallelLinear(config.num_channels *
kernel_size[0] * kernel_size[1],
config.hidden_size,
bias=False,
quant_config=quant_config,
gather_output=True,
prefix=f"{prefix}.linear")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.unfold(hidden_states)
hidden_states = hidden_states.permute(0, 2, 1)
hidden_states, _ = self.linear(hidden_states)
return hidden_states
class Llama4VisionModel(nn.Module):
def __init__(
self,
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.image_size = config.image_size
self.patch_size = config.patch_size
self.hidden_size = config.hidden_size
self.num_channels = config.num_channels
self.num_patches = (self.image_size // self.patch_size)**2 + 1
self.scale = config.hidden_size**-0.5
self.patch_embedding = Llama4UnfoldConvolution(
config,
quant_config=quant_config,
prefix=f"{prefix}.patch_embedding")
self.class_embedding = nn.Parameter(self.scale *
torch.randn(self.hidden_size))
self.positional_embedding_vlm = nn.Parameter(
self.scale * torch.randn(self.num_patches, self.hidden_size))
# layer norms
self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)
# encoders
self.model = Llama4VisionEncoder(config,
quant_config=quant_config,
prefix=f"{prefix}.model")
self.vision_adapter = Llama4VisionPixelShuffleMLP(
config, quant_config, prefix=f"{prefix}.vision_adapter")
def forward(
self,
images_flattened: torch.Tensor,
) -> BaseModelOutput:
# Patch embedding
hidden_state = self.patch_embedding(images_flattened)
num_tiles, num_patches, hidden_dim = hidden_state.shape
# Add cls token
class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1,
hidden_state.shape[-1])
hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
num_patches += 1
# Position embeddings
hidden_state = hidden_state.reshape(
num_tiles,
1,
num_patches,
hidden_dim,
)
positional_embedding = self.positional_embedding_vlm.to(
dtype=hidden_state.dtype, device=hidden_state.device)
hidden_state = hidden_state + positional_embedding
hidden_state = self.layernorm_pre(hidden_state)
hidden_state = hidden_state.view(num_tiles, -1, hidden_dim)
# Apply encoder
output = self.model(hidden_state)
hidden_state = output.last_hidden_state
hidden_state = self.layernorm_post(hidden_state)
# Remove CLS token output
hidden_state = hidden_state[:, :-1, :]
# now, we use Llama4VisionPixelShuffle + mlp to project embeddings
hidden_state = self.vision_adapter(hidden_state)
return BaseModelOutput(
last_hidden_state=hidden_state,
attentions=None,
)
class Mllama4ProcessingInfo(BaseProcessingInfo):
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(ctx)
def get_hf_config(self) -> Llama4Config:
return self.ctx.get_hf_config(Llama4Config)
def get_hf_processor(self, **kwargs: object) -> Llama4Processor:
return self.ctx.get_hf_processor(Llama4Processor,
use_fast=True,
**kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 10}
@staticmethod
def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int:
image_size = vision_config.image_size
patch_size = vision_config.patch_size
assert image_size % patch_size == 0, (
    f"chunk size {image_size} should be a multiple of "
    f"patch_size {patch_size}")
ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2)))
return (image_size // patch_size)**2 // ds_ratio
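# For example (illustrative values only, not necessarily the shipped config):
# with image_size=336, patch_size=14 and pixel_shuffle_ratio=0.5,
# ds_ratio = round(1 / 0.25) = 4 and this returns
# (336 // 14) ** 2 // 4 = 576 // 4 = 144 patches per chunk.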
def get_max_num_tiles(self) -> int:
image_processor = self.get_hf_processor().image_processor
return image_processor.max_patches
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config
# image_start + local tiles * (patches + 1 x separator) +
# 1 global tile * (image x 1 + patches) + image_end
token_per_chunk = self.get_patch_per_chunk(vision_config) + 1
mm_max_tokens = (self.get_max_num_tiles() + 1) * token_per_chunk + 2
return {"image": mm_max_tokens}
def get_image_size_with_most_features(self) -> ImageSize:
vision_config = self.get_hf_config().vision_config
image_size = vision_config.image_size
# Result in the max possible feature size (h:w = 16:1)
return ImageSize(height=self.get_max_num_tiles() * image_size,
width=image_size)
class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
if mm_data is None:
return tokenizer(prompt, add_special_tokens=False) # exclude bos
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
processor = self.info.get_hf_processor(**mm_kwargs)
image_processor = processor.image_processor
vision_config = self.info.get_hf_config().vision_config
if processed_outputs.get("pixel_values") is not None:
assert "images" in mm_data, \
"images expected to be in mm_data when pixel_values is present"
images = mm_data["images"]
parsed_images = (self._get_data_parser().parse_mm_data({
"image":
images
}).get_items("image", ImageProcessorItems))
tile_size = vision_config.image_size
possible_resolutions = find_supported_resolutions(
max_num_chunks=self.info.get_max_num_tiles(),
patch_size=SizeDict(height=tile_size, width=tile_size),
)
best_fit_sizes = [
get_best_fit(
(image.size[1], image.size[0]),
torch.tensor(possible_resolutions),
resize_to_max_canvas=image_processor.resize_to_max_canvas)
for image in parsed_images
]
# TODO tile height/width do not necessarily need to match
aspect_ratios = [(image_size[0] // tile_size,
image_size[1] // tile_size)
for image_size in best_fit_sizes]
patches_per_image = [
1 if r_h * r_w == 1 else 1 + r_h * r_w
for (r_h, r_w) in aspect_ratios
]
# embed_is_patch should have one feature per image-related token:
# <|image_start|>, <|tile_*_separator|>, <|image|>, <|image_end|>
# -> False
# <|patch|> -> True
# embed_is_patch has no entries corresponding to non-image-related
# tokens.
patch_id = tokenizer.get_vocab()[processor.img_patch_token]
num_patches_per_chunk = self.info.get_patch_per_chunk(
vision_config)
expanded_image_tokens_list = [
processor._prompt_split_image(aspect_ratio,
num_patches_per_chunk)
for aspect_ratio in aspect_ratios
]
expanded_image_token_ids = [
tokenizer.encode(image_tokens, add_special_tokens=False)
for image_tokens in expanded_image_tokens_list
]
embed_is_patch = [
torch.tensor(tokens) == patch_id
for tokens in expanded_image_token_ids
]
processed_outputs["aspect_ratios"] = aspect_ratios
processed_outputs["patches_per_image"] = torch.tensor(
patches_per_image)
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", patches_per_image),
patches_per_image=MultiModalFieldConfig.batched("image"),
aspect_ratios=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> List[PromptUpdate]:
assert (
mm_items.get_count("image", strict=False) == 0
or "aspect_ratios" in out_mm_kwargs
), "Transformers expect to include aspect_ratios in out_mm_kwargs"
config = self.info.get_hf_config()
vision_config = config.vision_config
num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token
def get_replacement(item_idx: int):
aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx]
return hf_processor._prompt_split_image(
aspect_ratio=aspect_ratio,
num_patches_per_chunk=num_patches_per_chunk)
return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement,
)
]
class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
(target_width,
target_height) = self.info.get_image_size_with_most_features()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
image_token = self.info.get_hf_processor().fake_image_token
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
@MULTIMODAL_REGISTRY.register_processor(
Mllama4MultiModalProcessor,
info=Mllama4ProcessingInfo,
dummy_inputs=Mllama4DummyInputsBuilder,
)
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.vision_model = Llama4VisionModel(config.vision_config,
None,
prefix=maybe_prefix(
prefix, "vision_model"))
self.multi_modal_projector = Llama4MultiModalProjector(
self.config,
None,
prefix=maybe_prefix(prefix, "multi_modal_projector"))
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
architectures=["Llama4ForCausalLM"],
prefix=maybe_prefix(prefix, "language_model"))
self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]:
# num_images, 1, num_chunks, channel, image_size, image_size
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is None:
return None
# num_images x num_chunks, channel, image_size, image_size
# TODO: confirm handling for variable lengths
flat_pixel_values = flatten_bn(pixel_values, concat=True)
patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
embed_is_patch = kwargs.pop("embed_is_patch", None)
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
aspect_ratios = kwargs.pop("aspect_ratios", None)
if not isinstance(aspect_ratios, (torch.Tensor, list)):
raise ValueError("Incorrect type of aspect_ratios. "
f"Got type: {type(aspect_ratios)}")
return Llama4ImagePatchInputs(
type="pixel_values",
flat_data=flat_pixel_values,
patches_per_image=patches_per_image,
embed_is_patch=embed_is_patch,
aspect_ratios=aspect_ratios,
)
def _process_image_input(
self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
flat_data = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"].tolist()
vision_embeddings_flat = self.vision_model(flat_data).last_hidden_state
return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_multimodal_embeddings(self,
**kwargs) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
# num_images x [num_chunks, num_patches, hidden_dim]
image_features = self._process_image_input(image_input)
# num_images x [num_chunks x num_patches, hidden_dim]
image_features_flat = [img.flatten(0, 1) for img in image_features]
# num_images x [1, input_len] -> num_images x [input_len]
embed_is_patch_flat = [
is_patch.flatten(0, 1)
for is_patch in image_input["embed_is_patch"]
]
return scatter_patch_features(
image_features_flat,
embed_is_patch_flat,
)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
multimodal_embeddings = torch.cat(multimodal_embeddings)
mm_embeddings = self.multi_modal_projector(multimodal_embeddings)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, select_patch_features(mm_embeddings),
self.config.image_token_index)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
if "pixel_values" in kwargs:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
return self.language_model(input_ids, positions, intermediate_tensors,
inputs_embeds)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def separate_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
prefix: str,
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[
str, torch.Tensor]]]:
weights1, weights2 = tee(weights, 2)
def get_prefix_weights() -> Iterable[Tuple[str, torch.Tensor]]:
for name, data in weights1:
if name.startswith(prefix):
yield (name, data)
def get_other_weights() -> Iterable[Tuple[str, torch.Tensor]]:
for name, data in weights2:
if not name.startswith(prefix):
yield (name, data)
return get_prefix_weights(), get_other_weights()
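# Illustrative example (hypothetical weight names): with
# prefix="language_model.model.", an entry named
# "language_model.model.layers.0.self_attn.q_proj.weight" is yielded by
# get_prefix_weights(), while "vision_model.patch_embedding.linear.weight"
# is yielded by get_other_weights(). itertools.tee lets both generators
# share a single pass over the underlying weights iterator.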
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
]
params_dict = dict(self.named_parameters())
updated_params: Set[str] = set()
# language_model is a Llama4ForCausalLM instance. We load its weights
# using Llama4ForCausalLM's load_weights routine.
language_model_prefix = "language_model.model."
language_model_weights, other_weights = self.separate_weights(
weights, prefix=language_model_prefix)
loader = AutoWeightsLoader(self)
loaded_language_model_params = loader.load_weights(
language_model_weights)
assert loaded_language_model_params is not None
updated_params.update(loaded_language_model_params)
for name, loaded_weight in other_weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
updated_params.add(name)
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
updated_params.add(name)
return updated_params

View File

@ -73,6 +73,7 @@ _TEXT_GENERATION_MODELS = {
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
@ -194,6 +195,7 @@ _MULTIMODAL_MODELS = {
# [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
"Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
}

View File

@ -2189,8 +2189,6 @@ def make_zmq_socket(
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
path: str,
socket_type: Any,
bind: Optional[bool] = None,
identity: Optional[bytes] = None,
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""
@ -2209,24 +2207,16 @@ def make_zmq_socket(
else:
buf_size = -1 # Use system default buffer size
if bind is None:
bind = socket_type != zmq.PUSH
if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
socket.setsockopt(zmq.RCVHWM, 0)
socket.setsockopt(zmq.RCVBUF, buf_size)
if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
socket.setsockopt(zmq.SNDHWM, 0)
socket.setsockopt(zmq.SNDBUF, buf_size)
if identity is not None:
socket.setsockopt(zmq.IDENTITY, identity)
if bind:
if socket_type == zmq.constants.PULL:
socket.setsockopt(zmq.constants.RCVHWM, 0)
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
socket.bind(path)
else:
elif socket_type == zmq.constants.PUSH:
socket.setsockopt(zmq.constants.SNDHWM, 0)
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
socket.connect(path)
else:
raise ValueError(f"Unknown Socket Type: {socket_type}")
return socket
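# A minimal sketch of how this helper is paired up elsewhere in this diff (the
# ipc path is an arbitrary assumption): the engine core binds a PULL socket on
# its input path while the client connects a PUSH socket to the same path.
#
#   ctx = zmq.Context()
#   pull = make_zmq_socket(ctx, "ipc:///tmp/engine_input", zmq.constants.PULL)  # binds
#   push = make_zmq_socket(ctx, "ipc:///tmp/engine_input", zmq.constants.PUSH)  # connects
#   push.send_multipart((b"ADD", b"<serialized request>"), copy=False)
#   request_type, request_data = pull.recv_multipart(copy=False)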
@ -2235,19 +2225,14 @@ def make_zmq_socket(
def zmq_socket_ctx(
path: str,
socket_type: Any,
bind: Optional[bool] = None,
linger: int = 0,
identity: Optional[bytes] = None,
) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""
ctx = zmq.Context() # type: ignore[attr-defined]
try:
yield make_zmq_socket(ctx,
path,
socket_type,
bind=bind,
identity=identity)
yield make_zmq_socket(ctx, path, socket_type)
except KeyboardInterrupt:
logger.debug("Got Keyboard Interrupt.")

View File

@ -96,6 +96,183 @@ class FlashAttentionMetadata:
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
# for local attention
@dataclass
class LocalAttentionMetadata:
local_query_start_loc: torch.Tensor
local_seqused_k: torch.Tensor
local_block_table: torch.Tensor
local_max_query_len: int
local_max_seq_len: int
local_attn_metadata: Optional[LocalAttentionMetadata] = None
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
# q_seqlens = [4, 10, 5]
# kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1 1 1 1 1
# 3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1
# 3 | 1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
# k_toks > 0 1 2 3
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
# k_toks > 4 5
# q_toks v _____________
# 2 | 1
# 3 | 1 1
#
# e.g. if we have:
# attn_chunk_size = 4
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
# __b0__ ______b1______ __b2__ < orig batch indices
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
# cu_seqlens_q_local = [0, 2, 4, 5, 9, 13, 14, 18, 19]
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
attn_chunk_size: int,
query_start_loc_np: np.ndarray,
seq_lens_np: np.ndarray,
block_table: torch.tensor,
page_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.tensor]:
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
actual_batch_size = seq_lens_np.shape[0]
# Handle the case where we start in the middle of a local attention block:
# assuming q_seqlens > 0 (for all elements), for each batch idx we compute
# the number of tokens that are not in the first local attention block and
# then we can simply use a cdiv for the rest.
# For example if we have:
# attn_chunk_size = 4
# q_seqlens = [4, 10, 5]
# k_seqlens = [6, 17, 9]
# Then we would get:
# q_tokens_in_first_block = [2, 1, 4]
# local_blocks = [2, 4, 2]
q_tokens_in_first_block = np.minimum(
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
q_seqlens).astype(np.int32)
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
attn_chunk_size)
# Once we know the number of local blocks we can compute the request spans
# for each batch idx and figure out how many "virtual" requests we have to
# make.
# For the above example we would get:
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
#
# First, get a batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
# (TODO: make a utility to share this code with _prepare_inputs)
# arange step 1. [2, 4, 2] -> [2, 6, 8]
cu_num_blocks = np.cumsum(local_blocks)
virtual_batches = cu_num_blocks[-1]
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
rarange = np.repeat(local_blocks, local_blocks) - arange - 1
# Then we can compute the seqlens_q_local, handling the fact that the
# first and last blocks could be partial
seqlens_q_local = \
np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
# set the first block since this may be a partial block
seqlens_q_local[arange == 0] = q_tokens_in_first_block
# set the remaining blocks
seqlens_q_local[arange > 0] = np.minimum(
seqlens_q_local - attn_chunk_size * (arange - 1),
attn_chunk_size)[arange > 0]
# convert from q_seqlens to cu_seqlens_q
cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
.astype(np.int32)
# compute the seqlens_k_local,
# basically a full local attention block for all but the last block in each
# batch
# For our example this will be:
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
seqlens_k_local = np.full(cu_num_blocks[-1],
attn_chunk_size,
dtype=np.int32)
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
(rarange * attn_chunk_size + \
np.repeat(tokens_in_last_block, local_blocks))
# For the example the local attention blocks start at:
# _b0_ _____b1_____ _b2_
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
block_starts = k_seqstarts_absolute // page_size
assert attn_chunk_size % page_size == 0, \
f"attn_chunk_size {attn_chunk_size} is not " \
f"divisible by page_size {page_size}"
pages_per_local_batch = attn_chunk_size // page_size
# Create a block_table for the local attention blocks
# For our example, if we have a block-table like (assuming page_size=2):
# block_table = [
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
# ]
# Then for the local batches we would want a block-table like
# block_table_local = [
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
# ]
block_indices = np.broadcast_to(
np.arange(pages_per_local_batch, dtype=np.int32),
(virtual_batches, pages_per_local_batch)) \
+ np.expand_dims(block_starts, axis=1)
block_indices = block_indices.flatten()
batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
local_blocks * pages_per_local_batch)
block_table_local = block_table[batch_indices, block_indices]\
.view(virtual_batches, -1)
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
block_table_local
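# A minimal sketch exercising the helper above with the example from the
# comments (page_size=2 and the 3 x 10 block_table are assumed values):
def _local_attention_virtual_batches_example():
    query_start_loc_np = np.array([0, 4, 14, 19], dtype=np.int32)
    seq_lens_np = np.array([6, 17, 9], dtype=np.int32)
    block_table = torch.arange(30, dtype=torch.int32).view(3, 10)
    seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local = \
        make_local_attention_virtual_batches(4, query_start_loc_np,
                                             seq_lens_np, block_table,
                                             page_size=2)
    # seqlens_q_local    -> [2, 2, 1, 4, 4, 1, 4, 1]
    # cu_seqlens_q_local -> [0, 2, 4, 5, 9, 13, 14, 18, 19]
    # seqlens_k_local    -> [4, 2, 4, 4, 4, 1, 4, 1]
    # block_table_local  -> 8 local batches x 2 pages, matching the table above
    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
        block_table_local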
class FlashAttentionMetadataBuilder:
@ -109,18 +286,40 @@ class FlashAttentionMetadataBuilder:
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
self.runner.device, non_blocking=True)
seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
non_blocking=True)
query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
query_start_loc = query_start_loc_cpu.to(self.runner.device,
non_blocking=True)
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True)
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()
# for local attention
local_attn_metadata = None
if self.runner.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table = make_local_attention_virtual_batches(
self.runner.attention_chunk_size,
self.runner.query_start_loc_np[:num_reqs + 1],
self.runner.seq_lens_np[:num_reqs],
block_table,
self.runner.block_size,
)
local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=torch.from_numpy(
virt_q_cu_seqlens_np).to(self.runner.device,
non_blocking=True),
local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to(
self.runner.device, non_blocking=True),
local_block_table=virt_block_table,
local_max_query_len=seqlens_q_local_np.max(),
local_max_seq_len=virt_k_seqlens_np.max(),
)
use_cascade = common_prefix_len > 0
if use_cascade:
# TODO: Optimize.
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
dtype=torch.int32,
device=self.runner.device)
@ -149,6 +348,7 @@ class FlashAttentionMetadataBuilder:
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
local_attn_metadata=local_attn_metadata,
)
return attn_metadata
@ -167,6 +367,7 @@ class FlashAttentionImpl(AttentionImpl):
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@ -203,6 +404,7 @@ class FlashAttentionImpl(AttentionImpl):
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
self.use_irope = use_irope
self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
and not flash_attn_supports_fp8():
@ -265,8 +467,7 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale,
layer._v_scale,
)
descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
key.shape[1])
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache = value_cache.view(torch.float8_e4m3fn)
@ -278,22 +479,41 @@ class FlashAttentionImpl(AttentionImpl):
query = query.reshape((num_tokens, num_heads, head_size))
# Compute attention and update output up to `num_actual_tokens`.
if not attn_metadata.use_cascade:
# Regular attention (common case).
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)
if not attn_metadata.use_cascade or use_local_attn:
if use_local_attn:
assert attn_metadata.local_attn_metadata is not None
local_metadata = attn_metadata.local_attn_metadata
cu_seqlens_q = local_metadata.local_query_start_loc
seqused_k = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
else:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
seqused_k=attn_metadata.seq_lens,
max_seqlen_k=attn_metadata.max_seq_len,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=attn_metadata.block_table,
block_table=block_table,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
@ -302,6 +522,8 @@ class FlashAttentionImpl(AttentionImpl):
)
return output
assert not use_local_attn, (
"Cascade attention does not support local attention.")
# Cascade attention (rare case).
cascade_attention(
output[:num_actual_tokens],

View File

@ -70,6 +70,7 @@ class TritonAttentionImpl(AttentionImpl):
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@ -86,6 +87,7 @@ class TritonAttentionImpl(AttentionImpl):
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
self.use_irope = use_irope
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@ -156,24 +158,41 @@ class TritonAttentionImpl(AttentionImpl):
layer._v_scale,
)
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)
if use_local_attn:
assert attn_metadata.local_attn_metadata is not None
local_metadata = attn_metadata.local_attn_metadata
cu_seqlens_q = local_metadata.local_query_start_loc
seqused_k = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
else:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=attn_metadata.block_table,
query_start_loc=attn_metadata.query_start_loc,
seq_lens=attn_metadata.seq_lens,
max_seq_len=attn_metadata.max_seq_len,
max_query_len=attn_metadata.max_query_len,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale)
chunked_prefill_paged_decode(query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale)
return output

View File

@ -318,11 +318,6 @@ class EngineCoreProc(EngineCore):
):
super().__init__(vllm_config, executor_class, log_stats)
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
self.global_unfinished_reqs = False
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
@ -332,16 +327,22 @@ class EngineCoreProc(EngineCore):
Any]] = queue.Queue()
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
threading.Thread(target=self.process_input_socket,
args=(input_path, engine_index),
args=(input_path, ),
daemon=True).start()
threading.Thread(target=self.process_output_socket,
args=(output_path, engine_index),
daemon=True).start()
self.global_unfinished_reqs = False
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
@staticmethod
def run_engine_core(*args,
dp_rank: int = 0,
local_dp_rank: int = 0,
ready_pipe,
**kwargs):
"""Launch EngineCore busy loop in background process."""
@ -376,6 +377,9 @@ class EngineCoreProc(EngineCore):
else:
engine_core = EngineCoreProc(*args, **kwargs)
# Send Readiness signal to EngineClient.
ready_pipe.send({"status": "READY"})
engine_core.run_busy_loop()
except SystemExit:
@ -472,22 +476,14 @@ class EngineCoreProc(EngineCore):
and not isinstance(v, p.annotation) else v
for v, p in zip(args, arg_types))
def process_input_socket(self, input_path: str, engine_index: int):
def process_input_socket(self, input_path: str):
"""Input socket IO thread."""
# Msgpack serialization decoding.
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
generic_decoder = MsgpackDecoder()
identity = engine_index.to_bytes(length=2, byteorder="little")
with zmq_socket_ctx(input_path,
zmq.DEALER,
identity=identity,
bind=False) as socket:
# Send ready message to front-end once input socket is connected.
socket.send(b'READY')
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)

View File

@ -8,7 +8,7 @@ import threading
import uuid
import weakref
from abc import ABC, abstractmethod
from collections.abc import Awaitable
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass, field
from threading import Thread
@ -35,8 +35,6 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R = TypeVar('_R') # Return type for collective_rpc
STARTUP_POLL_PERIOD_MS = 10000
class EngineCoreClient(ABC):
"""
@ -263,13 +261,15 @@ class CoreEngine:
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
input_path: str,
ctx: Union[zmq.Context, zmq.asyncio.Context],
output_path: str,
index: int = 0,
local_dp_rank: int = 0,
):
self.index = index
self.identity = index.to_bytes(length=2, byteorder="little")
# Paths and sockets for IPC.
input_path = get_open_zmq_ipc_path()
self.input_socket = make_zmq_socket(ctx, input_path,
zmq.constants.PUSH)
try:
# Start EngineCore in background process.
self.proc_handle = BackgroundProcHandle(
@ -291,9 +291,14 @@ class CoreEngine:
# Ensure socket is closed if process fails to start.
self.close()
def send_multipart(self, msg_parts: Sequence):
return self.input_socket.send_multipart(msg_parts, copy=False)
def close(self):
if proc_handle := getattr(self, "proc_handle", None):
proc_handle.shutdown()
if socket := getattr(self, "input_socket", None):
socket.close(linger=0)
@dataclass
@ -304,7 +309,6 @@ class BackgroundResources:
ctx: Union[zmq.Context]
core_engines: list[CoreEngine] = field(default_factory=list)
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
shutdown_path: Optional[str] = None
def __call__(self):
@ -317,8 +321,6 @@ class BackgroundResources:
# aren't explicitly closed first.
if self.output_socket is not None:
self.output_socket.close(linger=0)
if self.input_socket is not None:
self.input_socket.close(linger=0)
if self.shutdown_path is not None:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
@ -385,51 +387,21 @@ class MPClient(EngineCoreClient):
# Paths and sockets for IPC.
self.output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
self.input_socket = make_zmq_socket(self.ctx,
input_path,
zmq.ROUTER,
bind=True)
self.resources.input_socket = self.input_socket
new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
vllm_config, executor_class, log_stats, input_path, self.
output_path, index, local_dp_rank)
vllm_config, executor_class, log_stats, self.ctx, self.output_path,
index, local_dp_rank)
# Start engine core process(es).
self._init_core_engines(vllm_config, new_core_engine,
self.resources.core_engines)
# Wait for engine core process(es) to start.
self._wait_for_engine_startup()
for engine in self.resources.core_engines:
engine.proc_handle.wait_for_startup()
self.utility_results: dict[int, AnyFuture] = {}
def _wait_for_engine_startup(self):
# Get a sync handle to the socket which can be sync or async.
sync_input_socket = zmq.Socket.shadow(self.input_socket)
# Wait for engine core process(es) to send ready messages.
identities = set(eng.index for eng in self.resources.core_engines)
while identities:
while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS):
logger.info("Waiting for %d core engine proc(s) to start: %s",
len(identities), identities)
eng_id_bytes, msg = sync_input_socket.recv_multipart()
eng_id = int.from_bytes(eng_id_bytes, byteorder="little")
if eng_id not in identities:
raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}")
if msg != b'READY':
raise RuntimeError(f"Engine {eng_id} failed: {msg.decode()}")
logger.info("Core engine process %d ready.", eng_id)
identities.discard(eng_id)
# Double check that the processes are running.
for engine in self.resources.core_engines:
proc = engine.proc_handle.proc
if proc.exitcode is not None:
raise RuntimeError(f"Engine proc {proc.name} not running")
def _init_core_engines(
self,
vllm_config: VllmConfig,
@ -522,10 +494,9 @@ class SyncMPClient(MPClient):
return self.outputs_queue.get()
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
# (Identity, RequestType, SerializedRequest)
msg = (self.core_engine.identity, request_type.value,
self.encoder.encode(request))
self.input_socket.send_multipart(msg, copy=False)
# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
self.core_engine.send_multipart(msg)
def call_utility(self, method: str, *args) -> Any:
call_id = uuid.uuid1().int >> 64
@ -654,34 +625,30 @@ class AsyncMPClient(MPClient):
assert self.outputs_queue is not None
return await self.outputs_queue.get()
def _send_input(self,
request_type: EngineCoreRequestType,
request: Any,
engine: Optional[CoreEngine] = None) -> Awaitable[None]:
if engine is None:
engine = self.core_engine
async def _send_input(self, request_type: EngineCoreRequestType,
request: Any) -> None:
await self.core_engine.send_multipart(
(request_type.value, self.encoder.encode(request)))
message = (request_type.value, self.encoder.encode(request))
return self._send_input_message(message, engine)
def _send_input_message(self, message: tuple[bytes, bytes],
engine: CoreEngine) -> Awaitable[None]:
message = (engine.identity, ) + message # type: ignore[assignment]
return self.input_socket.send_multipart(message, copy=False)
self._ensure_output_queue_task()
async def call_utility_async(self, method: str, *args) -> Any:
return await self._call_utility_async(method,
*args,
engine=self.core_engine)
async def _call_utility_async(self, method: str, *args,
engine: CoreEngine) -> Any:
async def _call_utility_async(
self,
method: str,
*args,
engine: CoreEngine,
) -> Any:
call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future
message = (EngineCoreRequestType.UTILITY.value,
self.encoder.encode((call_id, method, args)))
await self._send_input_message(message, engine)
await engine.send_multipart(message)
self._ensure_output_queue_task()
return await future
@ -690,7 +657,6 @@ class AsyncMPClient(MPClient):
# tokenized.
request.prompt = None
await self._send_input(EngineCoreRequestType.ADD, request)
self._ensure_output_queue_task()
async def abort_requests_async(self, request_ids: list[str]) -> None:
if len(request_ids) > 0:
@ -795,15 +761,15 @@ class DPAsyncMPClient(AsyncMPClient):
self.reqs_in_flight[request.request_id] = chosen_engine
chosen_engine.num_reqs_in_flight += 1
if self.num_engines_running >= len(self.core_engines):
await self._send_input_message(msg, chosen_engine)
await chosen_engine.send_multipart(msg)
else:
# Send request to chosen engine and dp start loop
# control message to all other engines.
self.num_engines_running += len(self.core_engines)
await asyncio.gather(*[
self._send_input_message(
msg if engine is chosen_engine else self.start_dp_msg,
engine) for engine in self.core_engines
engine.send_multipart(msg if engine is
chosen_engine else self.start_dp_msg)
for engine in self.core_engines
])
self._ensure_output_queue_task()
@ -828,7 +794,7 @@ class DPAsyncMPClient(AsyncMPClient):
# sure to start the other engines:
self.num_engines_running = len(self.core_engines)
coros = [
self._send_input_message(self.start_dp_msg, engine)
engine.send_multipart(self.start_dp_msg)
for engine in self.core_engines
if not engine.num_reqs_in_flight
]
@ -854,5 +820,5 @@ class DPAsyncMPClient(AsyncMPClient):
async def _abort_requests(self, request_ids: list[str],
engine: CoreEngine) -> None:
await self._send_input(EngineCoreRequestType.ABORT, request_ids,
engine)
await engine.send_multipart((EngineCoreRequestType.ABORT.value,
self.encoder.encode(request_ids)))

View File

@ -105,9 +105,12 @@ class BackgroundProcHandle:
process_kwargs: dict[Any, Any],
):
context = get_mp_context()
self.reader, writer = context.Pipe(duplex=False)
assert ("input_path" not in process_kwargs
assert ("ready_pipe" not in process_kwargs
and "input_path" not in process_kwargs
and "output_path" not in process_kwargs)
process_kwargs["ready_pipe"] = writer
process_kwargs["input_path"] = input_path
process_kwargs["output_path"] = output_path
@ -119,6 +122,12 @@ class BackgroundProcHandle:
input_path, output_path)
self.proc.start()
def wait_for_startup(self):
# Wait for startup.
if self.reader.recv()["status"] != "READY":
raise RuntimeError(f"{self.proc.name} initialization failed. "
"See root cause above.")
def shutdown(self):
self._finalizer()

View File

@ -113,6 +113,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
self.attention_chunk_size = model_config.attention_chunk_size
self.attn_backend = get_attn_backend(
self.head_size,