Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/dbo-cudagraph-size

This commit is contained in:
Sage Moore
2025-09-26 20:45:41 +00:00
44 changed files with 1170 additions and 473 deletions

View File

@ -522,7 +522,7 @@ steps:
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
# we can only upgrade after this is resolved
- pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/
- label: LM Eval Small Models # 53min
timeout_in_minutes: 75
@ -830,6 +830,23 @@ steps:
- uv pip install --system 'gpt-oss[eval]==0.0.5'
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 --server-args '--tensor-parallel-size 2'
- label: Blackwell Quantized MoE Test
timeout_in_minutes: 60
working_dir: "/vllm-workspace/"
gpu: b200
source_file_dependencies:
- tests/quantization/test_blackwell_moe.py
- vllm/model_executor/models/deepseek_v2.py
- vllm/model_executor/models/gpt_oss.py
- vllm/model_executor/models/llama4.py
- vllm/model_executor/layers/fused_moe
- vllm/model_executor/layers/quantization/compressed_tensors
- vllm/model_executor/layers/quantization/modelopt.py
- vllm/model_executor/layers/quantization/mxfp4.py
- vllm/v1/attention/backends/flashinfer.py
commands:
- pytest -s -v tests/quantization/test_blackwell_moe.py
##### 1 GPU test #####
##### multi gpus test #####

View File

@ -21,6 +21,7 @@ Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundatio
*Latest News* 🔥
- [2025/09] We hosted [vLLM Toronto Meetup](https://luma.com/e80e0ymm) focused on tackling inference at scale and speculative decoding with speakers from NVIDIA and Red Hat! Please find the meetup slides [here](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing).
- [2025/08] We hosted [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ) focusing on the ecosystem around vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA).
- [2025/08] We hosted [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet). We shared V1 updates, disaggregated serving and MLLM speedups with speakers from Embedded LLM, AMD, WekaIO, and A*STAR. Please find the meetup slides [here](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing).
- [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH).

View File

@ -2,6 +2,7 @@
We host regular meetups in the San Francisco Bay Area every two months. We share project updates from the vLLM team and invite guest speakers from industry to share their experience and insights. Please find the materials from our previous meetups below:
- [vLLM Toronto Meetup](https://luma.com/e80e0ymm), September 25th 2025. [[Slides]](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing)
- [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ), August 30th 2025. [[Slides]](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA)
- [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet), August 27th 2025. [[Slides]](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing)
- [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH)

View File

@ -20,7 +20,80 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data
# --8<-- [end:pre-built-wheels]
# --8<-- [start:build-wheel-from-source]
--8<-- "docs/getting_started/installation/cpu/build.inc.md"
Install the recommended compiler. We recommend using `gcc/g++ >= 12.3.0` as the default compiler to avoid potential problems. For example, on Ubuntu 22.04, you can run:
```bash
sudo apt-get update -y
sudo apt-get install -y gcc-12 g++-12 libnuma-dev python3-dev
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
```
Clone the vLLM project:
```bash
git clone https://github.com/vllm-project/vllm.git vllm_source
cd vllm_source
```
Install the required dependencies:
```bash
uv pip install -r requirements/cpu-build.txt --torch-backend cpu
uv pip install -r requirements/cpu.txt --torch-backend cpu
```
??? console "pip"
```bash
pip install --upgrade pip
pip install -v -r requirements/cpu-build.txt --extra-index-url https://download.pytorch.org/whl/cpu
pip install -v -r requirements/cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
```
Build and install vLLM:
```bash
VLLM_TARGET_DEVICE=cpu uv pip install . --no-build-isolation
```
If you want to develop vLLM, install it in editable mode instead.
```bash
VLLM_TARGET_DEVICE=cpu uv pip install -e . --no-build-isolation
```
Optionally, build a portable wheel which you can then install elsewhere:
```bash
VLLM_TARGET_DEVICE=cpu uv build --wheel
```
```bash
uv pip install dist/*.whl
```
??? console "pip"
```bash
VLLM_TARGET_DEVICE=cpu python -m build --wheel --no-isolation
```
```bash
pip install dist/*.whl
```
!!! example "Troubleshooting"
- **NumPy ≥2.0 error**: Downgrade using `pip install "numpy<2.0"`.
- **CMake picks up CUDA**: Add `CMAKE_DISABLE_FIND_PACKAGE_CUDA=ON` to prevent CUDA detection during CPU builds, even if CUDA is installed.
- `AMD` requires a 4th-gen (Zen 4/Genoa) or newer processor that supports [AVX512](https://www.phoronix.com/review/amd-zen4-avx512) to run vLLM on CPU.
- If you receive an error such as: `Could not find a version that satisfies the requirement torch==X.Y.Z+cpu+cpu`, consider updating [pyproject.toml](https://github.com/vllm-project/vllm/blob/main/pyproject.toml) to help pip resolve the dependency.
```toml title="pyproject.toml"
[build-system]
requires = [
"cmake>=3.26.1",
...
"torch==X.Y.Z+cpu" # <-------
]
```
- If you are building vLLM from source and not using the pre-built images, remember to set `LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD"` on x86 machines before running vLLM, as in the sketch below.
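For reference, a minimal end-to-end sketch that combines the CMake and `LD_PRELOAD` tips above; the model name and the `libtcmalloc-minimal4` package name are illustrative assumptions:
```bash
# Build without letting CMake pick up a locally installed CUDA toolkit
CMAKE_DISABLE_FIND_PACKAGE_CUDA=ON VLLM_TARGET_DEVICE=cpu uv pip install . --no-build-isolation

# Preload tcmalloc before starting vLLM (install it first, e.g. sudo apt-get install -y libtcmalloc-minimal4)
export LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD"

# Start the OpenAI-compatible server on CPU
vllm serve meta-llama/Llama-3.2-1B-Instruct
```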
# --8<-- [end:build-wheel-from-source]
# --8<-- [start:pre-built-images]
@ -57,4 +130,4 @@ docker run --rm \
# --8<-- [end:build-image-from-source]
# --8<-- [start:extra-information]
# --8<-- [end:extra-information]
# --8<-- [end:extra-information]

View File

@ -3,12 +3,11 @@
import contextlib
import os
import weakref
from dataclasses import dataclass
from typing import Optional
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
@ -33,89 +32,6 @@ def temporary_environ(env_vars):
os.environ[k] = v
@dataclass
class BackendConfig:
name: str
env_vars: dict
comp_config: dict
specific_gpu_arch: Optional[tuple] = None
# Define all backend configurations of full cudagraph to be tested
backend_configs = {
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={
"VLLM_FLASH_ATTN_VERSION": "3",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
},
specific_gpu_arch=(9, 0)),
# FlashMLA on Hopper
"FlashMLA":
BackendConfig(name="FlashMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA":
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
specific_gpu_arch=(9, 0)),
# Cutlass MLA on Blackwell
"CutlassMLA":
BackendConfig(
name="CutlassMLA",
env_vars={
"VLLM_USE_V1": "1",
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
"FORCE_NUM_KV_SPLITS":
"1", # TODO: remove this when hang issue is fixed
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
"cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
},
specific_gpu_arch=(10, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={
"VLLM_FLASH_ATTN_VERSION": "2",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
}),
# Triton Attention
"TritonAttn":
BackendConfig(name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
comp_config={
"cudagraph_mode": "FULL",
}),
# FlashInfer
"FlashInfer":
BackendConfig(name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
}
test_params_full_cudagraph = []
# deepseek-ai/DeepSeek-V2-Lite with MLA

View File

@ -4,7 +4,7 @@ import pytest
import vllm
from vllm.compilation.counter import compilation_counter
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.utils import _is_torch_equal_or_newer
@ -106,7 +106,6 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
def test_no_compilation(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
with (
compilation_counter.expect(num_graphs_seen=0,
dynamo_as_is_count=0),
@ -131,3 +130,67 @@ def test_enforce_eager(vllm_runner, monkeypatch):
enforce_eager=True,
gpu_memory_utilization=0.4) as _):
pass
def test_splitting_ops_dynamic():
# Default config
config = VllmConfig()
assert config.compilation_config.cudagraph_mode == \
CUDAGraphMode.FULL_AND_PIECEWISE
assert config.compilation_config.splitting_ops_contain_attention()
# When use_inductor_graph_partition=True
if _is_torch_equal_or_newer('2.9.0.dev'):
# inductor graph partition is only available in PyTorch 2.9+.
# this is a fast config check so we are not using pytest.skip.
config = VllmConfig(compilation_config=CompilationConfig(
use_inductor_graph_partition=True,
splitting_ops=["silly_attention"]))
# should ignore splitting_ops
assert config.compilation_config.splitting_ops == []
# When the attn_fusion pass is enabled.
config = VllmConfig(compilation_config=CompilationConfig(
pass_config={
"enable_attn_fusion": True,
"enable_noop": True
},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
))
assert config.compilation_config.splitting_ops == []
# cudagraph mode also falls back to FULL
assert config.compilation_config.cudagraph_mode == \
CUDAGraphMode.FULL
# splitting_ops cannot contain attention ops when the attn_fusion
# pass is enabled.
with pytest.raises(AssertionError):
config = VllmConfig(compilation_config=CompilationConfig(
pass_config={
"enable_attn_fusion": True,
"enable_noop": True
},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
# workaround for accessing all attention ops
splitting_ops=CompilationConfig()._attention_ops,
))
# When both use_inductor_graph_partition and the attn_fusion pass are enabled.
if _is_torch_equal_or_newer('2.9.0.dev'):
config = VllmConfig(compilation_config=CompilationConfig(
use_inductor_graph_partition=True,
pass_config={
"enable_attn_fusion": True,
"enable_noop": True
},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
))
assert config.compilation_config.splitting_ops == []
# enable_attn_fusion is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged.
assert config.compilation_config.cudagraph_mode == \
CUDAGraphMode.PIECEWISE

View File

@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
import pytest
from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform
if not current_platform.is_device_capability(100):
pytest.skip("This test only runs on Blackwell GPUs (SM100).",
allow_module_level=True)
os.environ["FLASHINFER_NVCC_THREADS"] = "16"
# dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4,
# "text_config": {"num_layers": 4, "num_hidden_layers": 4}}
dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4}
def can_initialize(model: str, extra_args: list[str]):
# Server arguments
server_args = [
"--max-model-len",
"2048",
"--max-num-batched-tokens",
"256",
"--load-format",
"dummy",
"--trust-remote-code",
"--limit-mm-per-prompt",
json.dumps({"image": 0}),
*extra_args,
]
# Launch server and make a simple request
with RemoteOpenAIServer(
model,
server_args,
max_wait_seconds=1000, # Due to FlashInfer compile
override_hf_configs=dummy_hf_overrides) as server:
client = server.get_client()
# Make a simple request to verify the server works
completion = client.completions.create(
model=model,
prompt=["Hello, World!"],
temperature=0,
max_tokens=2,
)
print(completion)
assert completion.choices[0].text is not None
## Llama4 ##
@pytest.mark.skip(reason=(
"RuntimeError: run_moe() Expected a value of type "
"'Optional[List[Tensor]]' for argument '_9' but instead found type "
"'list'."))
def test_llama4_fp8_tensor_moe_flashinfer_cutlass(
monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])
@pytest.mark.skip(reason="Works, but takes too long to run")
def test_llama4_fp8_tensor_moe_flashinfer_trtllm(
monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])
@pytest.mark.skip(reason="Works, but takes too long to run")
def test_llama4_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])
@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])
## DeepSeekV3 ##
def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
can_initialize("deepseek-ai/DeepSeek-V3.1", [])
def test_deepseek_nvfp4_moe_flashinfer_cutlass(
monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])
@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])
## GPT-OSS ##
def test_gptoss_mxfp4bf16_moe_flashinfer(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
can_initialize("openai/gpt-oss-20b", [])
def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(
monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "1")
can_initialize("openai/gpt-oss-20b", [])
def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(
monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
can_initialize("openai/gpt-oss-20b", [])

View File

@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "glm45"
start_token = "<think>"
end_token = "</think>"
REASONING_MODEL_NAME = "zai-org/GLM-4.5"
@pytest.fixture(scope="module")
def glm45_tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
WITH_THINK = {
"output": "<think>This is a reasoning section</think>This is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
WITH_THINK_STREAM = {
"output": "<think>This is a reasoning section</think>This is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
WITHOUT_THINK = {
"output": "This is the rest",
"reasoning_content": None,
"content": "This is the rest",
"is_reasoning_end": False,
}
WITHOUT_THINK_STREAM = {
"output": "This is the rest",
"reasoning_content": None,
"content": "This is the rest",
"is_reasoning_end": False,
}
COMPLETE_REASONING = {
"output": "<think>This is a reasoning section</think>",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
MULTILINE_REASONING = {
"output":
"<think>This is a reasoning\nsection</think>This is the rest\nThat",
"reasoning_content": "This is a reasoning\nsection",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
}
ONLY_OPEN_TAG = {
"output": "<think>This is a reasoning section",
"reasoning_content": None,
"content": "<think>This is a reasoning section",
"is_reasoning_end": False,
}
ONLY_OPEN_TAG_STREAM = {
"output": "<think>This is a reasoning section",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
}
TEST_CASES = [
pytest.param(
False,
WITH_THINK,
id="with_think",
),
pytest.param(
True,
WITH_THINK_STREAM,
id="with_think_stream",
),
pytest.param(
False,
WITHOUT_THINK,
id="without_think",
),
pytest.param(
True,
WITHOUT_THINK_STREAM,
id="without_think_stream",
),
pytest.param(
False,
COMPLETE_REASONING,
id="complete_reasoning",
),
pytest.param(
True,
COMPLETE_REASONING,
id="complete_reasoning_stream",
),
pytest.param(
False,
MULTILINE_REASONING,
id="multiline_reasoning",
),
pytest.param(
True,
MULTILINE_REASONING,
id="multiline_reasoning_stream",
),
pytest.param(
False,
ONLY_OPEN_TAG,
id="only_open_tag",
),
pytest.param(
True,
ONLY_OPEN_TAG_STREAM,
id="only_open_tag_stream",
),
]
STILL_REASONING_PROMPT = """[gMASK]<sop><|system|>
You are a helpful assistant.<|user|>
What is the capital of France?<|assistant|>
<think>The user is asking for the capital of"""
DONE_REASONING_PROMPT = """[gMASK]<sop><|system|>
You are a helpful assistant.<|user|>
What is the capital of France?<|assistant|>
<think>The user is asking for the capital of France.</think>
The capital of France is Paris."""
MULTI_TURN_STILL_REASONING_PROMPT = """[gMASK]<sop><|system|>
You are a helpful assistant.<|user|>
What is the capital of France?<|assistant|>
<think></think>
The capital of France is Paris.<|user|>
What about Chile?<|assistant|>
<think>The user is asking for the capital of"""
MULTI_TURN_DONE_REASONING_PROMPT = """[gMASK]<sop><|system|>
You are a helpful assistant.<|user|>
What is the capital of France?<|assistant|>
<think></think>
The capital of France is Paris.<|user|>
What about Chile?<|assistant|>
<think>The user is asking for the capital of Chile.</think>
The capital of Chile is Santiago."""
REASONING_END_TEST_CASES = [
pytest.param(STILL_REASONING_PROMPT, False, id="still_reasoning"),
pytest.param(DONE_REASONING_PROMPT, True, id="done_reasoning"),
pytest.param(MULTI_TURN_STILL_REASONING_PROMPT,
False,
id="multi_turn_still_reasoning"),
pytest.param(MULTI_TURN_DONE_REASONING_PROMPT,
True,
id="multi_turn_done_reasoning")
]
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
glm45_tokenizer,
):
output = glm45_tokenizer.tokenize(param_dict["output"])
output_tokens: list[str] = [
glm45_tokenizer.convert_tokens_to_string([token]) for token in output
]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
parser_name)(glm45_tokenizer)
reasoning, content = run_reasoning_extraction(parser,
output_tokens,
streaming=streaming)
assert reasoning == param_dict["reasoning_content"]
assert content == param_dict["content"]
output_ids = glm45_tokenizer.convert_tokens_to_ids(output)
is_reasoning_end = parser.is_reasoning_end(output_ids)
assert is_reasoning_end == param_dict["is_reasoning_end"]
@pytest.mark.parametrize("prompt, is_reasoning_end", REASONING_END_TEST_CASES)
def test_is_reasoning_end_full_prompt(prompt: str, is_reasoning_end: bool,
glm45_tokenizer):
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
parser_name)(glm45_tokenizer)
tokens = glm45_tokenizer.tokenize(prompt)
token_ids = glm45_tokenizer.convert_tokens_to_ids(tokens)
check_is_reasoning_end = parser.is_reasoning_end(token_ids)
assert check_is_reasoning_end == is_reasoning_end

View File

@ -91,8 +91,10 @@ class RemoteOpenAIServer:
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
if env_dict is not None:
env.update(env_dict)
serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
self.proc: subprocess.Popen = subprocess.Popen(
["vllm", "serve", model, *vllm_serve_args],
serve_cmd,
env=env,
stdout=sys.stdout,
stderr=sys.stderr,

View File

@ -23,15 +23,16 @@ from vllm_test_utils.monitor import monitor
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.transformers_utils.detokenizer_utils import (
convert_ids_list_to_tokens)
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, common_broadcastable_dtype,
current_stream, deprecate_kwargs, get_open_port,
get_tcp_uri, is_lossless_cast, join_host_port,
make_zmq_path, make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_host_port,
split_zmq_path, supports_kw, swap_dict_values)
# isort: off
from vllm.utils import (
CacheInfo, FlexibleArgumentParser, LRUCache, MemorySnapshot,
PlaceholderModule, bind_kv_cache, common_broadcastable_dtype,
current_stream, deprecate_kwargs, get_open_port, get_tcp_uri,
is_lossless_cast, join_host_port, make_zmq_path, make_zmq_socket,
memory_profiling, merge_async_iterators, sha256, split_host_port,
split_zmq_path, supports_kw, swap_dict_values, unique_filepath)
# isort: on
from ..utils import create_new_process_for_each_test, error_on_warning
@ -1032,3 +1033,15 @@ def test_load_config_file(tmp_path):
# Assert that the processed arguments match the expected output
assert processed_args == expected_args
os.remove(str(config_file_path))
def test_unique_filepath():
temp_dir = tempfile.mkdtemp()
path_fn = lambda i: Path(temp_dir) / f"file_{i}.txt"
paths = set()
for i in range(10):
path = unique_filepath(path_fn)
path.write_text("test")
paths.add(path)
assert len(paths) == 10
assert len(list(Path(temp_dir).glob("*.txt"))) == 10

View File

@ -3,7 +3,7 @@
"""Utility functions for attention-related v1 tests."""
from dataclasses import dataclass
from typing import Union
from typing import Optional, Union
import pytest
import torch
@ -260,3 +260,88 @@ def create_dummy_kv_cache(block_size: int,
dtype=dtype,
device=device)
return kv_cache
@dataclass
class BackendConfig:
name: str
env_vars: dict
comp_config: dict # compilation config
specific_gpu_arch: Optional[tuple] = None
# Define all backend configurations of full cudagraph to be tested
full_cg_backend_configs = {
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
"VLLM_FLASH_ATTN_VERSION": "3",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
},
specific_gpu_arch=(9, 0)),
# FlashMLA on Hopper
"FlashMLA":
BackendConfig(name="FlashMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# Cutlass MLA on Blackwell
"CutlassMLA":
BackendConfig(
name="CutlassMLA",
env_vars={
"VLLM_USE_V1": "1",
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
"FORCE_NUM_KV_SPLITS":
"1", # TODO: remove this when hang issue is fixed
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(10, 0)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA":
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
specific_gpu_arch=(9, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
"VLLM_FLASH_ATTN_VERSION": "2",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
# Triton Attention
"TritonAttn":
BackendConfig(name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
# FlashInfer
"FlashInfer":
BackendConfig(name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
}

View File

@ -45,39 +45,22 @@ def _create_vllm_config(compilation_config: CompilationConfig,
class TestCudagraphDispatcher:
@pytest.mark.parametrize(
"params",
"case_id,cudagraph_mode_str,compilation_level",
[
# Test case 0: Full CG for mixed batches, no separate routine
{
"case_id": 0,
"cudagraph_mode": "FULL",
"compilation_level": CompilationLevel.NO_COMPILATION,
},
(0, "FULL", CompilationLevel.NO_COMPILATION),
# Test case 1: Full CG for uniform batches, piecewise for mixed
{
"case_id": 1,
"cudagraph_mode": "FULL_AND_PIECEWISE",
"compilation_level": CompilationLevel.PIECEWISE,
},
(1, "FULL_AND_PIECEWISE", CompilationLevel.NO_COMPILATION),
# Test case 2: Full CG for uniform batches, no CG for mixed
{
"case_id": 2,
"cudagraph_mode": "FULL_DECODE_ONLY",
"compilation_level": CompilationLevel.NO_COMPILATION,
},
(2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION),
# Test case 3: Piecewise for all
{
"case_id": 3,
"cudagraph_mode": "PIECEWISE",
"compilation_level": CompilationLevel.PIECEWISE,
},
(3, "PIECEWISE", CompilationLevel.PIECEWISE),
])
def test_dispatcher(self, params):
def test_dispatcher(self, cudagraph_mode_str, compilation_level):
# Setup dispatcher
comp_config = CompilationConfig(
cudagraph_mode=params["cudagraph_mode"],
level=params["compilation_level"],
cudagraph_capture_sizes=[1, 8])
comp_config = CompilationConfig(cudagraph_mode=cudagraph_mode_str,
level=compilation_level,
cudagraph_capture_sizes=[1, 8])
config = _create_vllm_config(comp_config, max_num_seqs=8)
dispatcher = CudagraphDispatcher(config)
@ -86,11 +69,11 @@ class TestCudagraphDispatcher:
uniform_decode_query_len=1)
# Verify the key is initialized correctly
if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]:
if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
@ -99,10 +82,10 @@ class TestCudagraphDispatcher:
# 1. non-uniform batch, size in cudagraph size list
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc_full_exact)
if params["cudagraph_mode"] == "FULL":
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_full_exact
elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact
else:
@ -111,15 +94,13 @@ class TestCudagraphDispatcher:
# 2. uniform decode batch, size in cudagraph size list
desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
if params["cudagraph_mode"] == "FULL":
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact.non_uniform
elif params["cudagraph_mode"] in [
"FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"
]:
elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact
elif params["cudagraph_mode"] == "PIECEWISE":
elif cudagraph_mode_str == "PIECEWISE":
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_uniform_exact.non_uniform
else:
@ -131,6 +112,16 @@ class TestCudagraphDispatcher:
assert rt_mode == CUDAGraphMode.NONE
assert key is None
# 4. Cascade attention should have a fallback mode
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc_full_exact,
use_cascade_attn=True)
if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact.non_uniform
else:
assert rt_mode == CUDAGraphMode.NONE
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:

View File

@ -4,12 +4,11 @@ import contextlib
import os
import weakref
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Optional
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
@ -34,74 +33,6 @@ def temporary_environ(env_vars):
os.environ[k] = v
@dataclass
class BackendConfig:
name: str
env_vars: dict
comp_config: dict
specific_gpu_arch: Optional[tuple] = None
# Define all backend configurations of full cudagraph to be tested
backend_configs = {
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={
"VLLM_FLASH_ATTN_VERSION": "3",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
},
specific_gpu_arch=(9, 0)),
# FlashMLA on Hopper
"FlashMLA":
BackendConfig(name="FlashMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA":
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
specific_gpu_arch=(9, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={
"VLLM_FLASH_ATTN_VERSION": "2",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
# Triton Attention
"TritonAttn":
BackendConfig(name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
# FlashInfer
"FlashInfer":
BackendConfig(name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
}
# test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported)
combo_cases_1 = [
@ -114,9 +45,10 @@ combo_cases_1 = [
]
@pytest.mark.parametrize("combo_case", combo_cases_1)
def test_backend_and_cudagraph_mode_combo(combo_case):
backend_name, cudagraph_mode, supported = combo_case
@pytest.mark.parametrize("backend_name, cudagraph_mode, supported",
combo_cases_1)
def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode,
supported):
if backend_name == "FlashInfer":
try:
import flashinfer # noqa: F401
@ -142,7 +74,7 @@ def test_backend_and_cudagraph_mode_combo(combo_case):
compilation_config=CompilationConfig(
level=3, cudagraph_mode=cudagraph_mode))
llm.generate(["Hello, my name is"] * 10)
# when the above code raises, `llm` may be undefined, so we need to catch that
try:
llm = weakref.proxy(llm)
del llm
@ -173,7 +105,8 @@ combo_cases_2 = [
]
@pytest.mark.parametrize("combo_case", combo_cases_2)
@pytest.mark.parametrize("backend_name,cudagraph_mode,compilation_level,"\
"supported", combo_cases_2)
def test_cudagraph_compilation_combo(combo_case):
backend_name, cudagraph_mode, compilation_level, supported\
= combo_case
@ -192,6 +125,7 @@ def test_cudagraph_compilation_combo(combo_case):
compilation_config=CompilationConfig(
level=compilation_level, cudagraph_mode=cudagraph_mode))
llm.generate(["Hello, my name is"] * 10)
# when the above code raises, `llm` may be undefined, so we need to catch that
try:
llm = weakref.proxy(llm)
del llm

View File

@ -8,7 +8,8 @@ import ray
from vllm.config import ModelDType
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger
from vllm.v1.metrics.ray_wrappers import (RayPrometheusMetric,
RayPrometheusStatLogger)
@pytest.fixture(scope="function", autouse=True)
@ -65,3 +66,39 @@ def test_engine_log_metrics_ray(
# Create the actor and call the async method
actor = EngineTestActor.remote() # type: ignore[attr-defined]
ray.get(actor.run.remote())
def test_sanitized_opentelemetry_name():
"""Test the metric name sanitization logic for Ray."""
# Only a-z, A-Z, 0-9, and _ are valid; test that valid characters are preserved
valid_name = "valid_metric_123_abcDEF"
assert RayPrometheusMetric._get_sanitized_opentelemetry_name(
valid_name) == valid_name
# Test that dash and dot are replaced with underscores
name_with_dash_dot = "metric-name.test"
expected = "metric_name_test"
assert RayPrometheusMetric._get_sanitized_opentelemetry_name(
name_with_dash_dot) == expected
# Test colon is replaced with underscore
name_with_colon = "metric:name"
expected = "metric_name"
assert RayPrometheusMetric._get_sanitized_opentelemetry_name(
name_with_colon) == expected
# Test multiple invalid characters are replaced
name_with_invalid = "metric:name@with#special%chars"
expected = "metric_name_with_special_chars"
assert RayPrometheusMetric._get_sanitized_opentelemetry_name(
name_with_invalid) == expected
# Test mixed valid and invalid characters
complex_name = "vllm:engine_stats/time.latency_ms-99p"
expected = "vllm_engine_stats_time_latency_ms_99p"
assert RayPrometheusMetric._get_sanitized_opentelemetry_name(
complex_name) == expected
# Test empty string
assert RayPrometheusMetric._get_sanitized_opentelemetry_name("") == ""

View File

@ -340,15 +340,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
num_graphs=len(self.compile_submod_names),
runtime_shape=None)
# Lazy import here to avoid circular import
from .cuda_piecewise_backend import PiecewiseBackend
from .piecewise_backend import PiecewiseBackend
piecewise_backend = PiecewiseBackend(
submod, self.vllm_config, index,
len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_dynamic_shape, self.vllm_backend)
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and
if (self.compilation_config.cudagraph_mode.\
has_piecewise_cudagraphs() and
not self.compilation_config.use_inductor_graph_partition):
# We're using Dynamo-based piecewise splitting, so we wrap
# the whole subgraph with a static graph wrapper.

View File

@ -336,7 +336,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
from vllm.config import CUDAGraphMode
compilation_config = vllm_config.compilation_config
if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
and compilation_config.use_inductor_graph_partition):
from torch._inductor.utils import CUDAGraphWrapperMetadata
@ -365,7 +365,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
yield
if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
and compilation_config.use_inductor_graph_partition):
torch._inductor.utils.set_customized_partition_wrappers(None)

View File

@ -270,6 +270,7 @@ class VllmConfig:
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
quant_config.maybe_update_config(model_config.model)
return quant_config
return None
@ -458,15 +459,22 @@ class VllmConfig:
"to True to enable.")
current_platform.check_and_update_config(self)
# final check of cudagraph mode after platform-specific update
# Do this after all the updates to compilation_config.level
if envs.VLLM_USE_V1 and \
self.compilation_config.level == CompilationLevel.PIECEWISE:
self.compilation_config.set_splitting_ops_for_v1()
# final check of cudagraph mode after all possible updates
if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
if self.compilation_config.cudagraph_mode.has_full_cudagraphs()\
and self.model_config is not None and \
not self.model_config.disable_cascade_attn:
logger.info("CUDAGraphMode.FULL is not supported with "
"cascade attention currently. Disabling cascade"
"attention.")
self.model_config.disable_cascade_attn = True
not self.model_config.disable_cascade_attn and\
not self.compilation_config.cudagraph_mode.\
has_piecewise_cudagraphs():
logger.warning_once(
"No piecewise cudagraph for executing cascade attention."
" Will fall back to eager execution if a batch runs "
"into cascade attentions")
if self.compilation_config.cudagraph_mode\
.requires_piecewise_compilation():
@ -476,6 +484,12 @@ class VllmConfig:
"when cudagraph_mode piecewise cudagraphs is used, "\
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
# finally, migrate the deprecated flags
self.compilation_config.use_cudagraph = self.compilation_config.\
cudagraph_mode!= CUDAGraphMode.NONE
self.compilation_config.full_cuda_graph = self.compilation_config.\
cudagraph_mode.has_full_cudagraphs()
if self.parallel_config.enable_dbo:
a2a_backend = envs.VLLM_ALL2ALL_BACKEND
assert a2a_backend in \
@ -486,14 +500,14 @@ class VllmConfig:
"variable to deepep_low_latency or deepep_high_throughput and "\
"install the DeepEP kernels."
if not self.model_config.disable_cascade_attn:
self.model_config.disable_cascade_attn = True
logger.warning_once(
"Disabling cascade attention when DBO is enabled.")
if not self.instance_id:
self.instance_id = random_uuid()[:5]
# Do this after all the updates to compilation_config.level
if envs.VLLM_USE_V1 and \
self.compilation_config.level == CompilationLevel.PIECEWISE:
self.compilation_config.set_splitting_ops_for_v1()
if (envs.VLLM_USE_V1
and not self.scheduler_config.disable_hybrid_kv_cache_manager):
# logger should only print warning message for hybrid models. As we

View File

@ -61,9 +61,17 @@ class CUDAGraphMode(enum.Enum):
def has_full_cudagraphs(self) -> bool:
return self.max_cudagraph_mode() == CUDAGraphMode.FULL
def has_piecewise_cudagraphs(self) -> bool:
return self.requires_piecewise_compilation()
def separate_routine(self) -> bool:
return isinstance(self.value, tuple)
def valid_runtime_modes(self) -> bool:
return self in [
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
]
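A rough sanity check of the new helpers, assuming `requires_piecewise_compilation()` returns True only for modes that include a PIECEWISE component (that method is defined outside this hunk):
```python
from vllm.config import CUDAGraphMode

# FULL_AND_PIECEWISE includes piecewise cudagraphs, but is not itself a
# valid runtime mode; it is resolved to FULL or PIECEWISE at dispatch time.
assert CUDAGraphMode.FULL_AND_PIECEWISE.has_piecewise_cudagraphs()
assert not CUDAGraphMode.FULL_AND_PIECEWISE.valid_runtime_modes()

# FULL is a valid runtime mode and (by assumption) has no piecewise graphs.
assert CUDAGraphMode.FULL.valid_runtime_modes()
assert not CUDAGraphMode.FULL.has_piecewise_cudagraphs()
```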
@config
@dataclass
@ -269,7 +277,8 @@ class CompilationConfig:
Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE
instead.
"""
cudagraph_num_of_warmups: int = 0
"""Number of warmup runs for cudagraph.
@ -294,7 +303,8 @@ class CompilationConfig:
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models.
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
FULL_AND_PIECEWISE instead.
"""
use_inductor_graph_partition: bool = False
@ -464,7 +474,8 @@ class CompilationConfig:
if not self.use_cudagraph:
logger.warning("use_cudagraph is deprecated, use "
"cudagraph_mode=NONE instead.")
if self.cudagraph_mode is not None:
if self.cudagraph_mode is not None and \
self.cudagraph_mode != CUDAGraphMode.NONE:
raise ValueError(
"use_cudagraph and cudagraph_mode are mutually"
" exclusive, prefer cudagraph_mode since "
@ -473,7 +484,8 @@ class CompilationConfig:
if self.full_cuda_graph:
logger.warning("full_cuda_graph is deprecated, use "
"cudagraph_mode=FULL instead.")
if self.cudagraph_mode is not None:
if self.cudagraph_mode is not None and \
not self.cudagraph_mode.has_full_cudagraphs():
raise ValueError("full_cuda_graph and cudagraph_mode are "
"mutually exclusive, prefer cudagraph_mode "
"since full_cuda_graph is deprecated.")
@ -570,48 +582,75 @@ class CompilationConfig:
"set_splitting_ops_for_v1 should only be called when "
"level is CompilationLevel.PIECEWISE")
if self.use_inductor_graph_partition:
self.set_splitting_ops_for_inductor_graph_partition()
return
if self.pass_config.enable_attn_fusion:
# here use_inductor_graph_partition is False
self.set_splitting_ops_for_attn_fusion()
return
if self.splitting_ops is None:
# NOTE: When using full cudagraph, instead of setting an empty
# list and capture the full cudagraph inside the flattened fx
# graph, we keep the piecewise fx graph structure but capture
# the full cudagraph outside the fx graph. This reduces some
# cpu overhead when the runtime batch_size is not cudagraph
# captured. see https://github.com/vllm-project/vllm/pull/20059
# for details. Make a copy to avoid mutating the class-level
# list via reference.
self.splitting_ops = list(self._attention_ops)
elif len(self.splitting_ops) == 0:
logger.warning_once(
"Using piecewise compilation with empty splitting_ops")
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once(
"Piecewise compilation with empty splitting_ops do not" \
"contains piecewise cudagraph. Setting cudagraph_"
"mode to NONE. Hint: If you are using attention backends "
"that support cudagraph, consider manually setting "
"cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
"full cudagraphs.")
self.cudagraph_mode = CUDAGraphMode.NONE
elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
logger.warning_once(
"Piecewise compilation with empty splitting_ops do not "
"contains piecewise cudagraph. Setting cudagraph_mode "
"to FULL.")
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
def set_splitting_ops_for_inductor_graph_partition(self):
assert self.use_inductor_graph_partition
use_inductor_graph_partition_msg = (
"When use_inductor_graph_partition=True, splitting_ops "
"are ignored and set to an empty list. Instead, "
"\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is "
"used to annotate custom ops for graph partition.")
if self.splitting_ops is None:
if self.use_inductor_graph_partition:
# When using inductor graph partition, we set splitting_ops
# to be empty and rely on torch._C.Tag.cudagraph_unsafe to
# annotate custom ops as splitting ops.
logger.warning_once(use_inductor_graph_partition_msg)
self.splitting_ops = []
else:
# NOTE: When using full cudagraph, instead of setting an empty
# list and capture the full cudagraph inside the flattened fx
# graph, we keep the piecewise fx graph structure but capture
# the full cudagraph outside the fx graph. This reduces some
# cpu overhead when the runtime batch_size is not cudagraph
# captured. see https://github.com/vllm-project/vllm/pull/20059
# for details. make a copy to avoid mutating the class-level
# list via reference.
self.splitting_ops = list(self._attention_ops)
elif len(self.splitting_ops) == 0:
logger.warning_once(
"Using piecewise compilation with empty "
"splitting_ops and use_inductor_graph_partition"
f"={self.use_inductor_graph_partition}.")
if (self.cudagraph_mode == CUDAGraphMode.PIECEWISE
and not self.use_inductor_graph_partition):
logger.warning_once(
"When compilation level is piecewise with empty "
"splitting_ops, PIECEWISE cudagraph_mode will be "
"treated as FULL cudagraph_mode. Please ensure you are "
"using attention backends that support cudagraph or set "
"cudagraph_mode to NONE explicitly if encountering "
"any problems.")
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
elif self.use_inductor_graph_partition:
if self.splitting_ops is not None and \
len(self.splitting_ops) > 0:
logger.warning_once(use_inductor_graph_partition_msg)
self.splitting_ops = []
def set_splitting_ops_for_attn_fusion(self):
assert self.pass_config.enable_attn_fusion
if self.splitting_ops is None:
self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs():
logger.warning_once(
"enable_attn_fusion is incompatible with piecewise "
"cudagraph when use_inductor_graph_partition is off."
"In this case, splitting_ops will be set to empty "
"list, and cudagraph_mode will be set to FULL. "
"Please ensure you are using attention backends that "
"support cudagraph or set cudagraph_mode to NONE "
"explicitly if encountering any problems.")
self.cudagraph_mode = CUDAGraphMode.FULL
assert not self.splitting_ops_contain_attention(), (
"attention ops should not be in splitting_ops "
"when enable_attn_fusion is True")
def splitting_ops_contain_attention(self) -> bool:
return self.splitting_ops is not None and all(

View File

@ -540,7 +540,7 @@ class SpeculativeConfig:
"speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}")
eagle3_target_supported = ["llama", "qwen", "gpt_oss"]
eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
if self.method == "eagle3" and self.target_model_config and not any(
supported_model in
self.target_model_config.hf_text_config.model_type

View File

@ -246,8 +246,7 @@ class ForwardContext:
ubatch_slices: Optional[UBatchSlices] = None
def __post_init__(self):
assert self.cudagraph_runtime_mode in [
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
assert self.cudagraph_runtime_mode.valid_runtime_modes(), \
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"

View File

@ -40,6 +40,8 @@ def flashinfer_fused_moe_blockscale_fp8(
assert global_num_experts % 4 == 0
assert top_k < (topk_group * global_num_experts / num_expert_group)
assert block_shape == [128, 128]
# Routing kernel expects #experts <= #threads (256)
assert global_num_experts <= 256
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
# NOTE: scales of hidden states have to be transposed!

View File

@ -162,3 +162,9 @@ class QuantizationConfig(ABC):
"""
# TODO (@kylesayrs): add implementations for all subclasses
pass
def maybe_update_config(self, model_name: str): # noqa: B027
"""
Interface to update values after config initialization.
"""
pass

View File

@ -7,6 +7,7 @@ from fractions import Fraction
from typing import Any, Optional, Union
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
@ -22,6 +23,8 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils import is_list_of
class GPTQConfig(QuantizationConfig):
@ -38,6 +41,7 @@ class GPTQConfig(QuantizationConfig):
lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
autoround_version: str = "",
modules_in_block_to_quantize: Optional[list[str]] = None,
) -> None:
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
@ -75,15 +79,20 @@ class GPTQConfig(QuantizationConfig):
"Currently, only 2/3/4/8-bit weight quantization is "
f"supported for GPTQ, but got {self.weight_bits} bits.")
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
# used to identify GPTQ model quantized by autoround
self.autoround_version = autoround_version
def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}), "
f"lm_head_quantized={self.lm_head_quantized}), "
f"dynamic={self.dynamic}")
return (
f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}), "
f"lm_head_quantized={self.lm_head_quantized}, "
f"dynamic={self.dynamic}, "
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
)
@classmethod
def get_name(cls) -> QuantizationMethods:
@ -114,8 +123,10 @@ class GPTQConfig(QuantizationConfig):
default=False)
autoround_version = cls.get_from_keys_or(config, ["autoround_version"],
default="")
modules_in_block_to_quantize = cls.get_from_keys_or(
config, ["modules_in_block_to_quantize"], default=None)
return cls(weight_bits, group_size, desc_act, lm_head_quantized,
dynamic, autoround_version)
dynamic, autoround_version, modules_in_block_to_quantize)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
@ -136,6 +147,35 @@ class GPTQConfig(QuantizationConfig):
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
def apply_vllm_mapper(self, hf_to_vllm_mapper):
if self.modules_in_block_to_quantize is not None:
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
self.modules_in_block_to_quantize)
def maybe_update_config(self,
model_name: str,
revision: Optional[str] = None):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
# flatten original modules_in_block_to_quantize
self.modules_in_block_to_quantize = [
item for sublist in self.modules_in_block_to_quantize
for item in sublist
]
return
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
metadata = get_safetensors_params_metadata(model_name,
revision=revision)
quant_layers: set[str] = {
param_name.rsplit(".", 1)[0]
for param_name, info in metadata.items()
if (dtype := info.get('dtype', None))
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
self.modules_in_block_to_quantize = list(quant_layers)
class ExllamaState(Enum):

View File

@ -5,6 +5,7 @@ from copy import deepcopy
from typing import Any, Callable, Optional, Union
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
@ -35,6 +36,8 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
RowvLLMParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils import is_list_of
logger = init_logger(__name__)
@ -71,10 +74,16 @@ class GPTQMarlinConfig(QuantizationConfig):
(8, True): scalar_types.uint8b128,
}
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool, lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
full_config: dict[str, Any]) -> None:
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
full_config: dict[str, Any],
modules_in_block_to_quantize: Optional[list[str]] = None) -> None:
super().__init__()
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
@ -121,15 +130,19 @@ class GPTQMarlinConfig(QuantizationConfig):
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
# used to identify GPTQ model quantized by autoround
self.autoround_version = full_config.get("autoround_version", "")
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized}), "
f"dynamic={self.dynamic}")
return (
f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized}, "
f"dynamic={self.dynamic}, "
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
)
@classmethod
def get_name(cls) -> QuantizationMethods:
@ -158,8 +171,11 @@ class GPTQMarlinConfig(QuantizationConfig):
is_sym = cls.get_from_keys(config, ["sym"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
modules_in_block_to_quantize = cls.get_from_keys_or(
config, ["modules_in_block_to_quantize"], default=None)
return cls(weight_bits, group_size, desc_act, is_sym,
lm_head_quantized, dynamic, config)
lm_head_quantized, dynamic, config,
modules_in_block_to_quantize)
@classmethod
def override_quantization_method(
@ -223,6 +239,35 @@ class GPTQMarlinConfig(QuantizationConfig):
return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size)
def apply_vllm_mapper(self, hf_to_vllm_mapper):
if self.modules_in_block_to_quantize is not None:
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
self.modules_in_block_to_quantize)
def maybe_update_config(self,
model_name: str,
revision: Optional[str] = None):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
# flatten original modules_in_block_to_quantize
self.modules_in_block_to_quantize = [
item for sublist in self.modules_in_block_to_quantize
for item in sublist
]
return
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
metadata = get_safetensors_params_metadata(model_name,
revision=revision)
quant_layers: set[str] = {
param_name.rsplit(".", 1)[0]
for param_name, info in metadata.items()
if (dtype := info.get('dtype', None))
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
self.modules_in_block_to_quantize = list(quant_layers)
class GPTQMarlinLinearMethod(LinearMethodBase):
"""Linear method for GPTQ Marlin.

View File

@ -1,7 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from copy import deepcopy
from fractions import Fraction
from types import MappingProxyType
from typing import Optional, Union
import regex as re
@ -70,6 +72,49 @@ def get_dynamic_override(
return default_value
def is_layer_gptq_quantized(
prefix: str,
quantized_layers: list[str],
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
# GPTQ's `modules_in_block_to_quantize`:
# Substr: ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"]
# Full prefix ["model.layers.0.self_attn.q_proj"]
proj_name = prefix.split(".")[-1]
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in fused_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in fused_mapping[proj_name]
]
is_quantized = None
for shard_prefix in shard_prefixes:
is_shard_quantized = any(layer in shard_prefix
for layer in quantized_layers)
if is_quantized is None:
is_quantized = is_shard_quantized
elif is_shard_quantized != is_quantized:
raise ValueError(
f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers "
"to have the same precision.")
else:
is_quantized = any(layer in prefix for layer in quantized_layers)
assert is_quantized is not None
return is_quantized
def get_linear_quant_method(
config: QuantizationConfig,
layer: torch.nn.Module,
@ -80,10 +125,15 @@ def get_linear_quant_method(
parallel_lm_head_quantized = isinstance(
layer, ParallelLMHead) and cloned_config.lm_head_quantized
if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
is_layer_quantized = is_layer_gptq_quantized(
prefix=prefix,
quantized_layers=cloned_config.modules_in_block_to_quantize,
fused_mapping=cloned_config.packed_modules_mapping)
# False = skip module, None = no override, else = Positive match
if get_dynamic_override( # noqa: E712
cloned_config, # noqa: E712
layer_name=prefix) == False: # noqa: E712
layer_name=prefix) == False or (
not is_layer_quantized): # noqa: E712
if parallel_lm_head_quantized:
return UnquantizedEmbeddingMethod()
return UnquantizedLinearMethod()

View File

@ -25,9 +25,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.module_mapping import MultiModelKeys
@ -1281,11 +1278,6 @@ class BaseKeyeModule(nn.Module):
raise ValueError("Only image or video modality is supported")
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config
@ -1297,14 +1289,14 @@ class BaseKeyeModule(nn.Module):
self.visual = KeyeSiglipVisionModel(
config.vision_config,
quant_config=self._maybe_ignore_quant_config(quant_config),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
self.mlp_AR = self._build_projector(
config,
config.vision_config,
quant_config=self._maybe_ignore_quant_config(quant_config),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "mlp_AR"),
)

View File

@ -55,7 +55,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@ -381,6 +381,9 @@ class MiniCPMModel(nn.Module):
self.num_experts = getattr(self.config, "num_experts", 0)
self._init_layers(prefix, config, cache_config, quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.aux_hidden_state_layers = tuple[int, ...]()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], self.config.hidden_size))
@ -408,7 +411,8 @@ class MiniCPMModel(nn.Module):
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
list[torch.Tensor]]]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
@ -419,18 +423,29 @@ class MiniCPMModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in islice(self.layers, self.start_layer, self.end_layer):
aux_hidden_states = []
for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(
hidden_states +
residual if residual is not None else hidden_states)
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
@ -502,7 +517,7 @@ class MiniCPMModel(nn.Module):
return loaded_params
class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -568,16 +583,36 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
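As a worked example of the default above (the layer count is hypothetical):

num_layers = 40  # illustrative depth, not a real MiniCPM config value
assert (2, num_layers // 2, num_layers - 3) == (2, 20, 37)
# EAGLE3 taps one early, one middle, and one late layer for aux hidden states.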
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds) / self.scale_width
return hidden_states
) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
list[torch.Tensor]]]:
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
if isinstance(model_output, tuple) and len(model_output) == 2:
# Aux hidden states are present.
hidden_states, aux_hidden_states = model_output
hidden_states = hidden_states / self.scale_width
return hidden_states, aux_hidden_states
else:
# Only hidden states or IntermediateTensors
if isinstance(model_output, IntermediateTensors):
return model_output
else:
hidden_states = model_output / self.scale_width
return hidden_states
def compute_logits(
self,

View File

@ -28,7 +28,7 @@ from typing import Annotated, Any, Callable, Literal, Optional, Union
import torch
from torch import nn
from transformers import BatchFeature, PretrainedConfig
from transformers import BatchFeature
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.whisper.modeling_whisper import (ACT2FN,
WhisperAttention,
@ -36,10 +36,6 @@ from transformers.models.whisper.modeling_whisper import (ACT2FN,
WhisperEncoder)
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors)
@ -548,36 +544,6 @@ class MiniCPMO(MiniCPMV2_6):
self.audio_token_id = None
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
# See: https://huggingface.co/openbmb/MiniCPM-o-2_6-int4
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
# MiniCPMO GPTQ model leave vpm unquantized.
quant_config = self._maybe_ignore_quant_config(quant_config)
return super().init_vision_module(config, quant_config, prefix)
def init_resampler(
self,
embed_dim: int,
vision_dim: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
# MiniCPMO GPTQ model leave resampler unquantized.
quant_config = self._maybe_ignore_quant_config(quant_config)
return super().init_resampler(embed_dim, vision_dim, quant_config,
prefix)
def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Do not use parameters temporarily
audio_config = self.config.audio_config

View File

@ -31,9 +31,6 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.models.aimv2 import AIMv2Model
from vllm.model_executor.models.siglip import SiglipVisionModel
from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
@ -416,7 +413,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
self.visual_tokenizer = VisualTokenizer(
config=config.visual_tokenizer_config,
quant_config=self._maybe_ignore_quant_config(quant_config),
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer",
)
@ -430,14 +427,6 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
self.make_empty_intermediate_tensors = (
self.get_language_model().make_empty_intermediate_tensors)
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
# See: https://huggingface.co/AIDC-AI/Ovis2-2B-GPTQ-Int4
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
pixel_values = kwargs.pop("pixel_values", None)

View File

@ -52,9 +52,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -1015,8 +1012,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.visual = Qwen2_5_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(
self.quant_config),
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
@ -1032,13 +1028,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _maybe_ignore_quant_config(self, config: Optional[QuantizationConfig]):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
if isinstance(config, (GPTQConfig, GPTQMarlinConfig)):
return None
return config
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):

View File

@ -50,9 +50,6 @@ from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -1270,7 +1267,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
@ -1286,14 +1283,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
# See: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):

View File

@ -46,9 +46,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
@ -149,24 +146,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=f"{prefix}.gate")
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid gate quantization while AutoRound does.
# See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4,
# and https://huggingface.co/jart25/Qwen3-Coder-30B-A3B-Instruct-Int4-gptq
if isinstance(
quant_config,
(GPTQConfig,
GPTQMarlinConfig)) and not quant_config.autoround_version:
return None
return quant_config
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert hidden_states.dim(
@ -699,4 +683,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
return self.model.get_expert_mapping()

View File

@ -41,9 +41,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@ -119,12 +116,11 @@ class Qwen3NextSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=f"{prefix}.gate")
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate")
if config.shared_expert_intermediate_size > 0:
self.shared_expert = Qwen3NextMLP(
@ -142,16 +138,6 @@ class Qwen3NextSparseMoeBlock(nn.Module):
1,
bias=False)
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid gate quantization while AutoRound does.
if isinstance(
quant_config,
(GPTQConfig,
GPTQMarlinConfig)) and not quant_config.autoround_version:
return None
return quant_config
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape

View File

@ -50,9 +50,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
@ -1058,7 +1055,7 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
@ -1116,13 +1113,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].zero_()
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):

View File

@ -322,7 +322,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)

View File

@ -30,6 +30,7 @@ class Glm4MoeModelReasoningParser(ReasoningParser):
super().__init__(tokenizer, *args, **kwargs)
self.think_start_token = "<think>"
self.think_end_token = "</think>"
self.assistant_token = "<|assistant|>"
if not self.model_tokenizer:
raise ValueError(
@ -38,14 +39,26 @@ class Glm4MoeModelReasoningParser(ReasoningParser):
self.think_start_token_id = self.vocab.get(self.think_start_token)
self.think_end_token_id = self.vocab.get(self.think_end_token)
self.assistant_token_id = self.vocab.get(self.assistant_token)
if (self.think_start_token_id is None
or self.think_end_token_id is None):
or self.think_end_token_id is None
or self.assistant_token_id is None):
raise RuntimeError(
"Glm4MoeModel reasoning parser could not locate "
"think start/end tokens in the tokenizer!")
"think start/end or assistant tokens in the tokenizer!")
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.think_end_token_id in input_ids
"""
GLM's chat template has <think></think> tokens after every
<|assistant|> token. Thus, we need to check whether </think>
appears after the most recent <|assistant|> token (if present).
"""
for token_id in input_ids[::-1]:
if token_id == self.think_end_token_id:
return True
elif token_id == self.assistant_token_id:
return False
return False
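A standalone sketch of the reverse scan above, with hypothetical token ids (2 for </think>, 1 for <|assistant|>):

def reasoning_ended(input_ids: list[int], think_end_id: int,
                    assistant_id: int) -> bool:
    # Walk backwards; whichever marker is met first decides the result.
    for token_id in reversed(input_ids):
        if token_id == think_end_id:
            return True
        if token_id == assistant_id:
            return False
    return False

assert reasoning_ended([1, 5, 2], think_end_id=2, assistant_id=1)
assert not reasoning_ended([1, 5, 2, 7, 1], think_end_id=2, assistant_id=1)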
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""

View File

@ -4,6 +4,7 @@
import json
import os
import time
from dataclasses import asdict
from functools import cache, partial
from pathlib import Path
from typing import Any, Callable, Literal, Optional, TypeVar, Union
@ -27,7 +28,8 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
from vllm import envs
from vllm.logger import init_logger
from vllm.transformers_utils.config_parser_base import ConfigParserBase
from vllm.transformers_utils.utils import check_gguf_file
from vllm.transformers_utils.utils import (check_gguf_file,
parse_safetensors_file_metadata)
if envs.VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
@ -999,6 +1001,34 @@ def try_get_tokenizer_config(
return None
def get_safetensors_params_metadata(
model: str,
*,
revision: Optional[str] = None,
) -> dict[str, Any]:
"""
Get the safetensors parameter metadata for a local path or a
remote model repository.
"""
full_metadata = {}
if (model_path := Path(model)).exists():
safetensors_to_check = model_path.glob("*.safetensors")
full_metadata = {
param_name: info
for file_path in safetensors_to_check if file_path.is_file()
for param_name, info in parse_safetensors_file_metadata(
file_path).items()
}
else:
repo_mt = try_get_safetensors_metadata(model, revision=revision)
if repo_mt and (files_mt := repo_mt.files_metadata):
full_metadata = {
param_name: asdict(info)
for file_mt in files_mt.values()
for param_name, info in file_mt.tensors.items()
}
return full_metadata
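Illustrative usage of the helper above (the repo id is only an example; any local directory or Hugging Face repo containing safetensors works):

meta = get_safetensors_params_metadata("openbmb/MiniCPM-o-2_6-int4")
# Each entry maps a parameter name to its safetensors info, e.g.
# {"dtype": "I32", "shape": [...], "data_offsets": [...]}.
packed_params = [
    name for name, info in meta.items() if info.get("dtype") == "I32"
]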
def _download_mistral_config_file(model, revision) -> dict:
config_file_name = "params.json"
config_dict = get_hf_file_to_dict(config_file_name, model, revision)

View File

@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import struct
from functools import cache
from os import PathLike
from pathlib import Path
from typing import Optional, Union
from typing import Any, Optional, Union
from vllm.envs import VLLM_MODEL_REDIRECT_PATH
from vllm.logger import init_logger
@ -97,3 +98,11 @@ def maybe_model_redirect(model: str) -> str:
return redirect_model
return model
def parse_safetensors_file_metadata(
path: Union[str, PathLike]) -> dict[str, Any]:
with open(path, "rb") as f:
length_of_metadata = struct.unpack('<Q', f.read(8))[0]
metadata = json.loads(f.read(length_of_metadata).decode('utf-8'))
return metadata
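The parser above relies on the standard safetensors header layout; a minimal self-contained sketch of that layout:

import json
import struct

# A safetensors file starts with an 8-byte little-endian length N, followed
# by N bytes of UTF-8 JSON mapping tensor names to dtype/shape/data_offsets
# (plus an optional "__metadata__" entry).
header = {"w": {"dtype": "F16", "shape": [2, 2], "data_offsets": [0, 8]}}
blob = json.dumps(header).encode("utf-8")
file_prefix = struct.pack("<Q", len(blob)) + blob  # what the parser reads back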

View File

@ -45,6 +45,7 @@ from concurrent.futures import ThreadPoolExecutor
from concurrent.futures.process import ProcessPoolExecutor
from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from pathlib import Path
from types import MappingProxyType
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, TextIO, TypeVar, Union, cast, overload)
@ -3536,3 +3537,23 @@ def set_env_var(key, value):
del os.environ[key]
else:
os.environ[key] = old
def unique_filepath(fn: Callable[[int], Path]) -> Path:
"""
unique_filepath returns a unique path by trying
successive integers in increasing order.
fn should be a callable that returns a path that
embeds the passed int at a fixed location.
Note: This function has a TOCTOU race condition.
Caller should use atomic operations (e.g., open with 'x' mode)
when creating the file to ensure thread safety.
"""
i = 0
while True:
p = fn(i)
if not p.exists():
return p
i += 1
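A hedged usage sketch (the dump path is hypothetical), pairing the helper with an atomic 'x'-mode open as the docstring recommends:

from pathlib import Path

path = unique_filepath(lambda i: Path(f"/tmp/vllm_dump_{i}.json"))
with open(path, "x") as f:  # 'x' mode fails if the file appeared meanwhile
    f.write("{}")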

View File

@ -22,10 +22,10 @@ class CudagraphDispatcher:
At runtime, the dispatch method generates the runtime cudagraph mode (FULL,
PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor)
based on the input key. After dispatching (communicate via forward context),
the cudagraph wrappers will trust the dispatch key to do either capturing
or replaying (if mode matched), or pass through to the underlying runnable
without cudagraph (if mode no match or mode is NONE).
based on the input key. After dispatching (communicated via forward
context), the cudagraph wrappers will trust the dispatch key to either
capture or replay (if the mode matches), or pass through to the underlying
runnable without cudagraph (if the mode does not match or mode is NONE).
"""
def __init__(self, vllm_config: VllmConfig):
@ -57,19 +57,15 @@ class CudagraphDispatcher:
def add_cudagraph_key(self, runtime_mode: CUDAGraphMode,
batch_descriptor: BatchDescriptor):
assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
f"Invalid cudagraph runtime mode: {runtime_mode}"
f"Invalid cudagraph runtime mode for keys: {runtime_mode}"
self.cudagraph_keys[runtime_mode].add(batch_descriptor)
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
uniform_decode_query_len: int):
# This should be called only after attention backend is initialized.
# Note: we create all valid keys possible for cudagraph but do not
# guarantee all keys would be used. For example, we create keys for
# piecewise cudagraphs when it is piecewise compilation, which is always
# valid, but for attention backend support unified routine, we may not
# trigger capturing/replaying the piecewise cudagraphs depending on
# CompilationConfig.cudagraph_mode. In addition, if we allow lazy
# Note: we create all valid keys for cudagraph here but do not
# guarantee all keys will be used. For example, if we allow lazy
# capturing in a future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
for bs in self.compilation_config.cudagraph_capture_sizes:
@ -94,10 +90,13 @@ class CudagraphDispatcher:
self.keys_initialized = True
def dispatch(
self, batch_descriptor: BatchDescriptor
self,
batch_descriptor: BatchDescriptor,
use_cascade_attn: bool = False
) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
"""
Given a batch descriptor, dispatch to a cudagraph mode.
Given the dispatch conditions (e.g., the batch descriptor and whether
cascade attention is used), dispatch to a cudagraph runtime mode and
the valid batch descriptor.
A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform).
"""
@ -107,14 +106,16 @@ class CudagraphDispatcher:
"initialized. No cudagraph will be used.")
return CUDAGraphMode.NONE, None
# check if key exists for full cudagraph
if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_descriptor
# otherwise, check if non-uniform key exists
non_uniform_key = batch_descriptor.non_uniform
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, non_uniform_key
# if a batch uses cascade attention, bypass checking full cudagraphs
if not use_cascade_attn:
# check if key exists for full cudagraph
if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_descriptor
# otherwise, check if non-uniform key exists
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, non_uniform_key
# also check if non-uniform key exists for more "general"
# piecewise cudagraph
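Illustrative call of the updated dispatch (the dispatcher instance is hypothetical; BatchDescriptor fields are as used elsewhere in this diff). With cascade attention the FULL-graph keys are bypassed, so the batch falls through to the piecewise or no-cudagraph paths:

mode, key = dispatcher.dispatch(
    BatchDescriptor(num_tokens=64, uniform_decode=True),
    use_cascade_attn=True,
)
# mode is CUDAGraphMode.PIECEWISE or CUDAGraphMode.NONE here, never FULL.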

View File

@ -11,6 +11,7 @@ try:
from ray.util.metrics import Metric
except ImportError:
ray_metrics = None
import regex as re
class RayPrometheusMetric:
@ -42,6 +43,21 @@ class RayPrometheusMetric:
return self
@staticmethod
def _get_sanitized_opentelemetry_name(name: str) -> str:
"""
For compatibility with Ray + OpenTelemetry, the metric name must be
sanitized. In particular, this replaces disallowed characters
(e.g., ':') with '_' in the metric name.
Allowed characters: a-z, A-Z, 0-9, _
# ruff: noqa: E501
Ref: https://github.com/open-telemetry/opentelemetry-cpp/blob/main/sdk/src/metrics/instrument_metadata_validator.cc#L22-L23
Ref: https://github.com/ray-project/ray/blob/master/src/ray/stats/metric.cc#L107
"""
return re.sub(r"[^a-zA-Z0-9_]", "_", name)
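For example (the metric name is illustrative of vLLM's 'vllm:'-prefixed metrics; the module imports `regex as re`, but the stdlib re module behaves the same for this pattern):

import re

assert re.sub(r"[^a-zA-Z0-9_]", "_",
              "vllm:num_requests_running") == "vllm_num_requests_running"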
class RayGaugeWrapper(RayPrometheusMetric):
"""Wraps around ray.util.metrics.Gauge to provide same API as
@ -58,6 +74,7 @@ class RayGaugeWrapper(RayPrometheusMetric):
# implemented at the observability layer (Prometheus/Grafana).
del multiprocess_mode
labelnames_tuple = tuple(labelnames) if labelnames else None
name = self._get_sanitized_opentelemetry_name(name)
self.metric = ray_metrics.Gauge(name=name,
description=documentation,
tag_keys=labelnames_tuple)
@ -79,6 +96,7 @@ class RayCounterWrapper(RayPrometheusMetric):
documentation: Optional[str] = "",
labelnames: Optional[list[str]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
name = self._get_sanitized_opentelemetry_name(name)
self.metric = ray_metrics.Counter(name=name,
description=documentation,
tag_keys=labelnames_tuple)
@ -99,6 +117,7 @@ class RayHistogramWrapper(RayPrometheusMetric):
labelnames: Optional[list[str]] = None,
buckets: Optional[list[float]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
name = self._get_sanitized_opentelemetry_name(name)
boundaries = buckets if buckets else []
self.metric = ray_metrics.Histogram(name=name,
description=documentation,

View File

@ -923,11 +923,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> tuple[PerLayerAttnMetadata, torch.Tensor,
Optional[SpecDecodeMetadata], np.ndarray,
Optional[CommonAttentionMetadata], int, Optional[UBatchSlices],
Optional[torch.Tensor]]:
Optional[torch.Tensor], bool]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
logits_indices, spec_decode_metadata
logits_indices, spec_decode_metadata,
num_scheduled_tokens, spec_decode_common_attn_metadata,
max_num_scheduled_tokens, ubatch_slices,
num_tokens_after_padding, use_cascade_attn
]
"""
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@ -1045,11 +1047,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
num_tokens_padded = num_tokens_unpadded + self.get_local_padding(
num_tokens_unpadded)
uniform_decode = \
(max_num_scheduled_tokens == self.uniform_decode_query_len) and \
(total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
ubatch_slices, num_tokens_after_padding = \
ubatch_split(num_scheduled_tokens,
num_tokens_unpadded,
num_tokens_padded,
self.vllm_config)
uniform_decode=uniform_decode,
vllm_config=self.vllm_config)
self.seq_lens.np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
@ -1131,6 +1137,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata: PerLayerAttnMetadata = {}
if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
use_cascade_attn = False
# Used in the below loop.
query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
@ -1247,9 +1254,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args)
use_cascade_attn |= getattr(attn_metadata_i, "use_cascade",
False)
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
# disable cascade attention when DBO (microbatching) is enabled
if ubatch_slices is not None:
use_cascade_attn = False
# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
@ -1257,7 +1270,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return (attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens, spec_decode_common_attn_metadata,
max_num_scheduled_tokens, ubatch_slices,
num_tokens_after_padding)
num_tokens_after_padding, use_cascade_attn)
def _compute_cascade_attn_prefix_len(
self,
@ -2247,8 +2260,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Prepare the decoder inputs.
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
max_query_len, ubatch_slices, num_tokens_after_padding
) = self._prepare_inputs(scheduler_output)
max_query_len, ubatch_slices, num_tokens_after_padding,
use_cascade_attn) = self._prepare_inputs(scheduler_output)
(
num_scheduled_tokens,
@ -2269,7 +2282,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=uniform_decode)
cudagraph_runtime_mode, batch_descriptor = \
self.cudagraph_dispatcher.dispatch(batch_descriptor)
self.cudagraph_dispatcher.dispatch(batch_descriptor,
use_cascade_attn)
# This is currently to get around the assert in the DPMetadata
# where it wants `num_tokens_across_dp` to align with `num_tokens`
@ -2697,16 +2711,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"Cannot reload weights before model is loaded."
model_loader = get_model_loader(self.load_config)
logger.info("Reloading weights inplace...")
model = self.get_model()
model_loader.load_weights(model, model_config=self.model_config)
model_loader.load_weights(self.get_model(),
model_config=self.model_config)
def save_tensorized_model(
self,
tensorizer_config: "TensorizerConfig",
) -> None:
model = self.get_model()
TensorizerLoader.save_model(
model,
self.get_model(),
tensorizer_config=tensorizer_config,
model_config=self.model_config,
)
@ -2922,9 +2935,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
(1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run
"""
assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
}
assert cudagraph_runtime_mode is None or \
cudagraph_runtime_mode.valid_runtime_modes()
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using
@ -2989,7 +3001,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_scheduled_tokens,
total_num_scheduled_tokens,
total_num_scheduled_tokens,
self.vllm_config,
uniform_decode=uniform_decode,
vllm_config=self.vllm_config,
)
# If we failed to microbatch, currently need to resynchronize
@ -3108,7 +3121,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# filter out the valid batch descriptor
_cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
BatchDescriptor(num_tokens=num_tokens,
uniform_decode=uniform_decode))
uniform_decode=uniform_decode)) \
if not is_profile else (CUDAGraphMode.NONE, None)
if cudagraph_runtime_mode is not None:
# we allow forcing NONE when the dispatcher disagrees to support
# warm ups for cudagraph capture
@ -3448,8 +3462,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cudagraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool):
assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \
cudagraph_runtime_mode in [CUDAGraphMode.FULL,
CUDAGraphMode.PIECEWISE]
cudagraph_runtime_mode.valid_runtime_modes(), \
f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"
# Only rank 0 should print progress bar during capture
if is_global_first_rank():
@ -3572,6 +3586,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.calculate_reorder_batch_threshold()
def initialize_cudagraph_capture(self) -> None:
"""
Resolve the cudagraph_mode when there are multiple attention
backends with potentially conflicting CUDA graph support.
Then initialize the cudagraph_dispatcher based on the resolved
cudagraph_mode.
"""
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_builder_name = None

View File

@ -139,6 +139,7 @@ def ubatch_split(
num_scheduled_tokens_per_request: np.ndarray,
num_tokens_unpadded: int,
num_tokens_padded: int,
uniform_decode: bool,
vllm_config: VllmConfig,
) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
"""
@ -164,7 +165,7 @@ def ubatch_split(
should_attempt_ubatching = check_ubatch_thresholds(
parallel_config,
num_tokens_unpadded,
vllm_config,
uniform_decode=uniform_decode,
)
# Don't microbatch unless every other DP worker is also microbatching