Compare commits


21 Commits

Author SHA1 Message Date
f67299f66d [compile] Enable sequence parallelism matching w/o custom ops enabled (#27126)
Signed-off-by: angelayi <yiangela7@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
(cherry picked from commit f36292dbee27a5ebe0e7115c061b82f6f5372dcf)
2025-11-15 22:05:00 -08:00
5f6666fb5a LLaMA4 LoRA Adapter Enablement (#28602)
Signed-off-by: Fardin Hoque <kfhfar@amazon.com>
Co-authored-by: Wei Wei <wwei6@meta.com>
(cherry picked from commit 964d65deedb9ae0480fecdb2e726ba16d63409d7)
2025-11-15 21:57:58 -08:00
66a62d73da [Bugfix][Nixl] Fix kernel physical<>logical block_size issue (#28677)
Signed-off-by: NickLucche <nlucches@redhat.com>
(cherry picked from commit 96b23b8e3b5cd5d05345489a304e65f7ab53ef8e)
2025-11-15 21:57:42 -08:00
c505dd6b61 [BugFix] Fix FA3 IMA with FULL_AND_PIECEWISE and cascade attention (default) (#28702)
(cherry picked from commit db56a59970a84842da2adc3aa64e436f42448b48)
2025-11-15 21:56:16 -08:00
f7adf64aac [BugFix] Fix multi-modal async scheduling race condition (#28706)
Signed-off-by: Nick Hill <nhill@redhat.com>
(cherry picked from commit bc3e43069aadb1fa301a9f60a22872b6ec4453b9)
2025-11-15 21:56:05 -08:00
240d6b1758 [Bugfix] fix dots.ocr pp support (#28705)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
(cherry picked from commit c36bcfe6b37967ab52763f2ddb9400ff4fe3885b)
2025-11-15 21:54:30 -08:00
b315ba9052 [Misc] Update xformers to 0.33.0.post1 (#28678)
Signed-off-by: Roger Wang <hey@rogerw.io>
(cherry picked from commit 0aecd9138f45f6f687858ac1e0c5206d30c8425e)
2025-11-15 21:54:26 -08:00
9b24cf6f47 [bugfix] correct local_chunk_len for DCP in reorg_kvcache with long context (#28526)
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
(cherry picked from commit 968060c15adc0b68a76d37db00acf1273a23b829)
2025-11-15 21:54:19 -08:00
facbc2c21e [BugFix] Ensure EngineArgs.create_engine_config is idempotent (#28515)
Signed-off-by: Nick Hill <nhill@redhat.com>
(cherry picked from commit 327c0a9a23f2939923d02fbf882640753bf1e030)
2025-11-15 21:54:15 -08:00
e2fd9a2edf [Misc] Turn off encoder torch compile by default (#28634)
Signed-off-by: Roger Wang <hey@rogerw.io>
(cherry picked from commit d3387750f191f3bcf6607db95436147bbccfacb3)
2025-11-15 21:54:05 -08:00
1326f17492 Use official xformers-0.0.33 built for PT 2.9 (#28600)
Signed-off-by: Huy Do <huydhn@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
(cherry picked from commit c33b87e7778d2a6900e73969c38785e0254f880b)
2025-11-15 21:53:04 -08:00
caf412e593 Skip models that cannot currently init on Transformers v5 (#28471)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
(cherry picked from commit 51c599f0ec9c754ddf9f6094f27c1fa2be76b318)
2025-11-15 21:52:58 -08:00
a035b5cffb [CI] Skip "Multi-Modal Models Test (Extended) 3" test that's broken in current Transformers (#28559)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
(cherry picked from commit a39dd7bb06c3bea055057d5c272ca952e0e000bf)
2025-11-15 21:52:46 -08:00
5b4dcecdd7 Remove deprecated fields from CompilationConfig (#27593)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
(cherry picked from commit a742134cc5fbdec6c2af1ef383704aac5c445fbd)
2025-11-15 21:48:13 -08:00
609bb244bd [Performance] Cache loaded custom logitsprocs to avoid overheads (#28462)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
(cherry picked from commit 3f770f4427cb926c24af540cc72d1b5901f7f702)
2025-11-15 21:44:19 -08:00
3a9ea77c35 [Bugfix] Fix max image size for PaddleOCR-VL (#28442)
Signed-off-by: Roger Wang <hey@rogerw.io>
(cherry picked from commit 4fd4b743a23cc6ccbd832f11be12317a8c2f0fbc)
2025-11-15 21:44:19 -08:00
28a82bb5e6 [Bugfix] Fix Stream Sync for Shared Expert Overlap (#28430)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
(cherry picked from commit e605e8e3233f895340f46665f93ab37b307491aa)
2025-11-15 21:44:19 -08:00
2a21f3e7c2 Only register rocm_aiter_ops if aiter is found (#28428)
Signed-off-by: mgoin <mgoin64@gmail.com>
(cherry picked from commit f2d9ad0620d9aa71481527dcfafdb8357da00470)
2025-11-15 21:36:19 -08:00
ab625ba2fc [CI/Test Fix] Fix CP tests on Blackwell (#28404)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
(cherry picked from commit 39029d519276fddbe0c36440e0eefcdda069b969)
2025-11-15 21:36:19 -08:00
324c8cbd79 [Feature] Refactor batch invariant fp8 DeepGEMM (#27606)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
(cherry picked from commit 35d801f13fa5bd79ae74707388b1fa4e1caf9ba5)
2025-11-15 21:35:58 -08:00
75ecaf48fe [Bugfix] Ensure calculated KV scales are applied in attention. (#27232)
Signed-off-by: adabeyta <aabeyta@redhat.com>
(cherry picked from commit a5a790eea6035760c71eae1861c1e5f369bc6d08)
2025-11-15 21:33:58 -08:00
44 changed files with 929 additions and 876 deletions

View File

@ -441,6 +441,7 @@ steps:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_config.py
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py
- pytest -v -s compile/test_fusion_attn.py
@ -471,10 +472,11 @@ steps:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph.py
# Limit to no custom ops to reduce running time
# fp8 kv scales not supported on sm89, tested on Blackwell instead
- pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
# Limit to no custom ops to reduce running time
# Wrap with quotes to escape yaml and avoid starting -k string with a -
- "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'"
- "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'"
- label: Cudagraph test
timeout_in_minutes: 20
@ -867,12 +869,12 @@ steps:
optional: true
commands:
- pip install --upgrade git+https://github.com/huggingface/transformers
- pytest -v -s tests/models/test_initialization.py
- pytest -v -s tests/models/test_initialization.py -k 'not (Gemma3 or ModernBert or Qwen2_5_VL or Qwen2_5vl or Qwen2VL or TransformersMultiModalEmbeddingModel or TransformersMultiModalForSequenceClassification or Ultravox or Phi4Multimodal or LlavaNextVideo or MiniCPMO or Lfm2Moe or PaliGemma or RobertaForSequenceClassification or Ovis2_5 or Fuyu or DeepseekOCR or KimiVL)'
- pytest -v -s tests/models/test_transformers.py
- pytest -v -s tests/models/multimodal/processing/
- pytest -v -s tests/models/multimodal/test_mapping.py
# - pytest -v -s tests/models/multimodal/processing/
- pytest -v -s tests/models/multimodal/test_mapping.py -k 'not (Gemma3 or Qwen2VL or Qwen2_5_VL)'
- python3 examples/offline_inference/basic/chat.py
- python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
# - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
# Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper
@ -912,7 +914,7 @@ steps:
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py
- label: Blackwell Fusion Tests # 30 min
- label: Blackwell Fusion & Compile Tests # 30 min
timeout_in_minutes: 40
working_dir: "/vllm-workspace/"
gpu: b200
@ -932,8 +934,10 @@ steps:
# this runner has 2 GPUs available even though num_gpus=2 is not set
- pytest -v -s tests/compile/test_fusion_all_reduce.py
# Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
# Wrap with quotes to escape yaml
- "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'"
# Wrap with quotes to escape yaml
- "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
- pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
- label: Blackwell Fusion E2E Tests # 30 min
timeout_in_minutes: 40
@ -951,6 +955,7 @@ steps:
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
- tests/compile/test_fusions_e2e.py
- tests/compile/test_full_graph.py
commands:
- nvidia-smi
# Run all e2e fusion tests
@ -1250,7 +1255,8 @@ steps:
- pytest -v -s tests/compile/test_async_tp.py
- pytest -v -s tests/compile/test_sequence_parallelism.py
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
- "pytest -v -s tests/compile/test_fusions_e2e.py -k 'not Llama-4'"
- pytest -v -s tests/distributed/test_sequence_parallel.py
- pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
- pytest -v -s tests/v1/distributed/test_dbo.py

View File

@ -218,16 +218,6 @@ outputs = model.generate(
)
```
### Migration from legacy flags
Legacy `use_cudagraph` and `full_cuda_graph` are unified by `cudagraph_mode`:
* `use_cudagraph=False` → `NONE`.
* `use_cudagraph=True` and `full_cuda_graph=False` → `PIECEWISE`.
* `full_cuda_graph=True` → directly set `FULL` and rely on the graceful fallback policy.
As these flags are deprecated and will be removed in the next major or minor release (i.e., v0.11.0 or v1.0.0), we recommend using `cudagraph_mode` instead.
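For illustration, a minimal sketch of this mapping using the new field directly; it assumes only the `CompilationConfig` and `CUDAGraphMode` imports from `vllm.config` that appear in the test files later in this diff:
```python
from vllm.config import CompilationConfig, CUDAGraphMode

# Legacy: use_cudagraph=False
cfg_none = CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE)

# Legacy: use_cudagraph=True, full_cuda_graph=False
cfg_piecewise = CompilationConfig(cudagraph_mode=CUDAGraphMode.PIECEWISE)

# Legacy: full_cuda_graph=True (the graceful fallback policy still applies)
cfg_full = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL)
```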
### Piecewise compilation and full graph custom passes (attention fusion, sequence parallelism)
Unfortunately, some custom compile passes have to see the whole graph to be effective and hence aren't compatible with piecewise compilation. This includes `AttnFusionPass` and `SequenceParallelismPass`. As a short-term solution, we automatically disable piecewise compilation (by setting `splitting_ops=[]`) when attention fusion is enabled. We use CUDA Graph modes `FULL` or `FULL_DECODE_ONLY` (depending on backend support). However, this leads to another optimization incompatibility and confusing performance tradeoffs.
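As a rough sketch of what that short-term workaround amounts to in configuration terms (not the exact defaults vLLM applies automatically), attention fusion paired with a full-graph CUDA graph mode could be expressed like this, reusing the `CompilationConfig`, `PassConfig`, `CUDAGraphMode`, and `CompilationMode` names that appear in the tests in this diff:
```python
from vllm.config import CompilationConfig, CUDAGraphMode, PassConfig
from vllm.config.compilation import CompilationMode

# Attention fusion must see the whole graph, so piecewise splitting is
# disabled (splitting_ops=[]) and a full-graph CUDA graph mode is used
# instead of PIECEWISE.
config = CompilationConfig(
    mode=CompilationMode.VLLM_COMPILE,
    splitting_ops=[],
    cudagraph_mode=CUDAGraphMode.FULL_DECODE_ONLY,
    pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
)
```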

View File

@ -9,7 +9,6 @@ torch==2.9.0
torchaudio==2.9.0
# These must be updated alongside torch
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# Build from https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
xformers==0.0.33.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.5.2

View File

@ -203,7 +203,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
@ -281,7 +281,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=False,
cudagraph_mode=CUDAGraphMode.NONE,
splitting_ops=["silly::attention"],
use_inductor_graph_partition=use_inductor_graph_partition,
)

View File

@ -62,7 +62,6 @@ def _run_simple_model(
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
use_inductor=use_inductor,
splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition,

View File

@ -449,7 +449,6 @@ def benchmark():
if piecewise:
compilation_config = CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=cudagraph_sizes,
)

View File

@ -2,8 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from contextlib import nullcontext
from unittest.mock import patch
import pytest
from pydantic import ValidationError
from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
@ -11,7 +13,7 @@ from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
from vllm.utils.torch_utils import _is_torch_equal_or_newer
def test_version():
@ -23,14 +25,6 @@ def test_version():
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
def test_use_cudagraphs_dynamic():
vllm_config = VllmConfig()
# Default V1 configuration now starts without cudagraphs enabled; the
# engine decides when to capture based on runtime settings instead of a
# blanket default.
assert vllm_config.compilation_config.use_cudagraph
def test_copy_pass():
vllm_config = VllmConfig()
inductor_pass = FixFunctionalizationPass(vllm_config)
@ -65,7 +59,7 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)
compilation_config = {
"use_cudagraph": False, # speed things up a bit
"cudagraph_mode": CUDAGraphMode.NONE, # speed things up a bit
}
with (
compilation_counter.expect(
@ -83,20 +77,31 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize("enabled", [True, False])
def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
@pytest.mark.parametrize(
"cudagraph_mode,num_cudagraph_captured",
[
(CUDAGraphMode.NONE, 0),
(CUDAGraphMode.FULL_DECODE_ONLY, 1),
(CUDAGraphMode.PIECEWISE, 13),
(CUDAGraphMode.FULL_AND_PIECEWISE, 14),
],
)
def test_use_cudagraphs(
vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
compilation_config = {
"cudagraph_capture_sizes": [100],
"use_cudagraph": enabled,
"cudagraph_mode": cudagraph_mode,
}
num_gpu_runner_capture_triggers = 1 if cudagraph_mode != CUDAGraphMode.NONE else 0
with (
compilation_counter.expect(
num_graphs_seen=1,
num_gpu_runner_capture_triggers=1 if enabled else 0,
num_cudagraph_captured=13 if enabled else 0,
num_gpu_runner_capture_triggers=num_gpu_runner_capture_triggers,
num_cudagraph_captured=num_cudagraph_captured,
),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
@ -168,19 +173,18 @@ def test_splitting_ops_dynamic():
assert not config.compilation_config.splitting_ops_contain_attention()
# When use_inductor_graph_partition=True
if is_torch_equal_or_newer("2.9.0.dev"):
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
splitting_ops=["vllm::unified_attention"],
)
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
splitting_ops=["vllm::unified_attention"],
)
# with inductor partition we use splitting_ops directly for
# partition rules
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
)
# with inductor partition we use splitting_ops directly for
# partition rules
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
# When attn_fusion pass enabled, splitting_ops now default to attention ops.
# When attn_fusion pass enabled.
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
@ -189,29 +193,41 @@ def test_splitting_ops_dynamic():
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
)
# With the new simplified logic, attention fusion works with splitting_ops
assert config.compilation_config.splitting_ops_contain_attention()
# cudagraph mode remains PIECEWISE
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
assert config.compilation_config.splitting_ops == []
# cudagraph mode also falls back to FULL
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
# When both use_inductor_graph_partition and attn_fusion pass enabled.
if is_torch_equal_or_newer("2.9.0.dev"):
# splitting_ops cannot contain attention ops when the attn_fusion
# pass is enabled.
with pytest.raises(ValidationError):
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
# workaround for accessing all attention ops
splitting_ops=CompilationConfig()._attention_ops,
)
)
# With inductor graph partition, attn_fusion and splitting_ops
# work together. Default splitting_ops include attention ops.
assert config.compilation_config.splitting_ops_contain_attention()
# enable_attn_fusion is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
# When both use_inductor_graph_partition and attn_fusion pass enabled.
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
)
# With inductor graph partition, attn_fusion and splitting_ops
# work together. Default splitting_ops include attention ops.
assert config.compilation_config.splitting_ops_contain_attention()
# enable_attn_fusion is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
def test_should_split():
@ -293,25 +309,36 @@ def test_should_split():
"tp_size",
"enable_sequence_parallelism",
"max_num_batched_tokens",
"use_cudagraph",
"cudagraph_mode",
"expected_max_size",
),
[
(None, None, 1, False, 2048, True, 512),
([1, 2, 4], 4, 1, False, 2048, True, 4),
([1, 2, 4], 8, 1, False, 2048, True, RuntimeError),
([1, 256], None, 1, False, 2048, 256),
([], None, 1, False, 2048, False, 0),
(None, 0, 1, False, 2048, False, 0),
(None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
(
[1, 2, 4],
8,
1,
False,
2048,
CUDAGraphMode.FULL_AND_PIECEWISE,
ValidationError,
),
([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
(None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
# truncated to nearest multiple of 8 or 16
(None, 257, 1, False, 2048, True, 256),
([1, 2, 4, 15], None, 1, False, 2048, True, 15), # max from list
([1, 2, 4, 15], None, 2, True, 2048, True, 4), # filtered out 15 due to SP
([1, 2, 4, 15], None, 1, False, 8, True, 4), # limited by the max_tokens
(None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
# max from list
([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
# filtered out 15 due to SP
([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
# limited by the max_tokens
([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
# the list should contain at least 1 element when using cudagraphs
([], None, 1, False, 2048, True, RuntimeError),
([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
# the max capturing size should be >= 1 when using cudagraphs
(None, 0, 1, False, 2048, True, RuntimeError),
(None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
],
)
def test_cudagraph_sizes_post_init(
@ -320,15 +347,17 @@ def test_cudagraph_sizes_post_init(
tp_size,
enable_sequence_parallelism,
max_num_batched_tokens,
use_cudagraph,
cudagraph_mode,
expected_max_size,
):
ctx = nullcontext()
if isinstance(expected_max_size, Exception):
if expected_max_size == ValidationError:
ctx = pytest.raises(expected_max_size)
cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE
with ctx:
with (
ctx,
patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
):
compilation_config = CompilationConfig(
cudagraph_capture_sizes=cudagraph_capture_sizes,
max_cudagraph_capture_size=max_cudagraph_capture_size,
@ -342,11 +371,13 @@ def test_cudagraph_sizes_post_init(
engine_args = EngineArgs(
model="facebook/opt-125m",
tensor_parallel_size=tp_size,
max_num_seqs=min(max_num_batched_tokens, 128),
max_num_batched_tokens=max_num_batched_tokens,
compilation_config=compilation_config,
)
vllm_config = engine_args.create_engine_config()
assert (
vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size
)
assert (
vllm_config.compilation_config.max_cudagraph_capture_size
== expected_max_size
)

View File

@ -80,7 +80,6 @@ def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatc
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
@ -215,7 +214,6 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
@ -257,7 +255,6 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,

View File

@ -183,8 +183,14 @@ def test_custom_compile_config(
"compilation_mode",
[CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
)
def test_fp8_kv_scale_compile(compilation_mode: int):
model = "Qwen/Qwen2-0.5B"
@pytest.mark.parametrize(
"model",
[
"Qwen/Qwen2-0.5B", # Standard attention model
"deepseek-ai/DeepSeek-V2-Lite", # MLA (Multi-head Latent Attention) model
],
)
def test_fp8_kv_scale_compile(compilation_mode: int, model: str):
model_kwargs = {
"quantization": "fp8",
"kv_cache_dtype": "fp8_e4m3",

View File

@ -20,13 +20,22 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import flat_product, multi_gpu_test
is_blackwell = lambda: current_platform.is_device_capability(100)
"""Are we running on Blackwell, a lot of tests depend on it"""
class Matches(NamedTuple):
attention_fusion: int = 0
allreduce_fusion: int = 0
sequence_parallel: int = 0
async_tp: int = 0
class ModelBackendTestCase(NamedTuple):
model_name: str
model_kwargs: dict[str, Any]
backend: _Backend
attention_fusions: int
allreduce_fusions: int | None = None
backend: AttentionBackendEnum
matches: Matches
MODELS_FP8: list[ModelBackendTestCase] = []
@ -38,17 +47,33 @@ if current_platform.is_cuda():
ModelBackendTestCase(
# Use smaller model for L40s in CI
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=32,
allreduce_fusions=65,
# TODO while llama4 is broken, use FLASHINFER for llama3 on Blackwell
# so FI attention+fp8_quant is at least tested once
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=AttentionBackendEnum.FLASHINFER
if is_blackwell()
else AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=32,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
),
ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER,
attention_fusions=48,
allreduce_fusions=96,
# TODO FlashInfer attn broken on Hopper with kvcache=fp8:
# https://github.com/vllm-project/vllm/issues/28568
# TODO FlashInfer attn broken on Blackwell for llama4:
# https://github.com/vllm-project/vllm/issues/28604
backend=AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=48,
allreduce_fusion=96,
sequence_parallel=96,
async_tp=95, # mlp is moe, no fusion there
),
),
]
@ -56,9 +81,13 @@ if current_platform.is_cuda():
ModelBackendTestCase(
model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER,
attention_fusions=32,
allreduce_fusions=65,
backend=AttentionBackendEnum.FLASHINFER,
matches=Matches(
attention_fusion=32,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
),
]
@ -67,9 +96,24 @@ if current_platform.is_cuda():
ModelBackendTestCase(
model_name="meta-llama/Llama-3.1-8B-Instruct",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=0,
allreduce_fusions=65,
backend=AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=0,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
),
ModelBackendTestCase(
model_name="Qwen/Qwen3-30B-A3B",
model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=0,
allreduce_fusion=97,
sequence_parallel=97,
async_tp=96, # MLP is MoE, half the fusions of dense
),
),
]
@ -78,20 +122,20 @@ elif current_platform.is_rocm():
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=32,
backend=AttentionBackendEnum.TRITON_ATTN,
matches=Matches(attention_fusion=32),
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_ATTN,
attention_fusions=32,
backend=AttentionBackendEnum.ROCM_ATTN,
matches=Matches(attention_fusion=32),
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
attention_fusions=32,
backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
matches=Matches(attention_fusion=32),
),
]
@ -99,8 +143,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, "
"attention_fusions, allreduce_fusions, custom_ops",
"model_name, model_kwargs, backend, matches, custom_ops",
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
# quant_fp4 only has the custom impl
@ -110,16 +153,15 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
def test_attn_quant(
model_name: str,
model_kwargs: dict[str, Any],
backend: _Backend,
attention_fusions: int,
allreduce_fusions: int,
backend: AttentionBackendEnum,
matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if backend == _Backend.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
if backend == AttentionBackendEnum.FLASHINFER and (
not is_blackwell() or not has_flashinfer()
):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
@ -162,12 +204,12 @@ def test_attn_quant(
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(compilation_config, model_name, **model_kwargs)
matches = re.findall(
log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(matches) == 1, log_holder.text
assert int(matches[0]) == attention_fusions
assert len(log_matches) == 1, log_holder.text
assert int(log_matches[0]) == matches.attention_fusion
CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"]
@ -180,8 +222,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, "
"attention_fusions, allreduce_fusions, custom_ops",
"model_name, model_kwargs, backend, matches, custom_ops",
# Toggle RMSNorm and QuantFP8 for FP8 models
list(
flat_product(
@ -201,9 +242,8 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
def test_tp2_attn_quant_allreduce_rmsnorm(
model_name: str,
model_kwargs: dict,
backend: _Backend,
attention_fusions: int,
allreduce_fusions: int,
backend: AttentionBackendEnum,
matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
@ -212,6 +252,13 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9")
if "fp4" in model_name.lower() and not is_blackwell():
pytest.skip("NVFP4 quant requires Blackwell")
if backend == AttentionBackendEnum.FLASHINFER and not is_blackwell():
# FlashInfer attn fusion requires Blackwell
matches = matches._replace(attention_fusion=0)
custom_ops_list = custom_ops.split(",") if custom_ops else []
if inductor_graph_partition:
@ -251,23 +298,135 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
run_model(
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
)
matches = re.findall(
log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(matches) == 2, log_holder.text
assert len(log_matches) == 2, log_holder.text
assert int(matches[0]) == attention_fusions
assert int(matches[1]) == attention_fusions
assert int(log_matches[0]) == matches.attention_fusion
assert int(log_matches[1]) == matches.attention_fusion
matches = re.findall(
log_matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(matches) == 2, log_holder.text
assert len(log_matches) == 2, log_holder.text
assert int(matches[0]) == allreduce_fusions
assert int(matches[1]) == allreduce_fusions
assert int(log_matches[0]) == matches.allreduce_fusion
assert int(log_matches[1]) == matches.allreduce_fusion
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, matches, custom_ops",
# Toggle RMSNorm and QuantFP8 for FP8 models
list(
flat_product(
MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
)
)
# Toggle RMSNorm for FP4 models and unquant models
+ list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="sequence parallel only tested on CUDA",
)
def test_tp2_attn_quant_async_tp(
model_name: str,
model_kwargs: dict,
backend: AttentionBackendEnum,
matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if is_blackwell():
# TODO: https://github.com/vllm-project/vllm/issues/27893
pytest.skip("Blackwell is not supported for AsyncTP pass")
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9")
if "fp4" in model_name.lower() and not is_blackwell():
pytest.skip("NVFP4 quant requires Blackwell")
if backend == AttentionBackendEnum.FLASHINFER:
if not has_flashinfer():
pytest.skip("FlashInfer backend requires flashinfer installed")
if not is_blackwell():
# FlashInfer attn fusion requires Blackwell
matches = matches._replace(attention_fusion=0)
custom_ops_list = custom_ops.split(",") if custom_ops else []
if inductor_graph_partition:
mode = CUDAGraphMode.FULL_AND_PIECEWISE
splitting_ops: list[str] | None = None
else:
mode = CUDAGraphMode.FULL_DECODE_ONLY
splitting_ops = []
# Disable compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
# To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
compilation_config = CompilationConfig(
# Testing properties
use_inductor_graph_partition=inductor_graph_partition,
cudagraph_mode=mode,
custom_ops=custom_ops_list,
splitting_ops=splitting_ops,
# Common
level=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(
enable_attn_fusion=True,
enable_noop=True,
enable_sequence_parallelism=True,
enable_async_tp=True,
),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True},
)
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
)
log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text
assert int(log_matches[0]) == matches.attention_fusion
assert int(log_matches[1]) == matches.attention_fusion
log_matches = re.findall(
r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text
assert int(log_matches[0]) == matches.sequence_parallel
assert int(log_matches[1]) == matches.sequence_parallel
log_matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text
assert int(log_matches[0]) == matches.async_tp
assert int(log_matches[1]) == matches.async_tp
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):

View File

@ -10,8 +10,8 @@ from vllm.platforms import current_platform
def test_compile():
vllm_config = VllmConfig()
# Default configuration compiles mm encoder
assert vllm_config.compilation_config.compile_mm_encoder
# Default configuration does not compile mm encoder
assert not vllm_config.compilation_config.compile_mm_encoder
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@ -39,7 +39,10 @@ def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
"Qwen/Qwen2.5-VL-3B-Instruct",
max_model_len=2048,
gpu_memory_utilization=0.8,
compilation_config={"mode": CompilationMode.VLLM_COMPILE},
compilation_config={
"mode": CompilationMode.VLLM_COMPILE,
"compile_mm_encoder": True,
},
) as _,
):
pass

View File

@ -5,15 +5,15 @@ import pytest
import torch
import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.fx_utils import find_auto_fn
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import (
CompilationConfig,
CUDAGraphMode,
DeviceConfig,
ModelConfig,
PassConfig,
@ -27,6 +27,7 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
@ -43,172 +44,157 @@ prompts = [
]
class TestModel(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32):
class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size))
)
self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
def forward(self, hidden_states, residual):
"""
Forward pass implementing the operations in the FX graph
def forward(self, x):
z = torch.relu(x)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
Args:
hidden_states: Input tensor
residual: Residual tensor from previous layer
z2 = torch.mm(y, self.w[0])
x2 = tensor_model_parallel_all_reduce(z2)
Returns:
Tuple containing the output tensor
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
y2, resid = self.norm[1](x2, resid)
# matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)
z3 = torch.mm(y2, self.w[1])
x3 = tensor_model_parallel_all_reduce(z3)
# Tensor parallel all-reduce
all_reduce = tensor_model_parallel_all_reduce(mm)
y3, resid = self.norm[2](x3, resid)
# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)
z4 = torch.mm(y3, self.w[2])
x4 = tensor_model_parallel_all_reduce(z4)
return norm_output, residual_output
y4, resid = self.norm[3](x4, resid)
return y4
def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]
def ops_in_model_after(self):
return [
torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default,
torch.ops.vllm.reduce_scatter.default,
]
def ops_in_model(self):
return [torch.ops._C.fused_add_rms_norm.default]
if RMSNorm.enabled():
return [
torch.ops._C.rms_norm.default,
torch.ops._C.fused_add_rms_norm.default,
]
else:
return []
class TestQuantModel(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
def __init__(self, hidden_size=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.vllm_config = get_current_vllm_config()
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size)), requires_grad=False
)
self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
self.hidden_size = hidden_size
self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
.t()
for _ in range(3)
]
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm,
# which expects a column-major layout.
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)
def forward(self, hidden_states, residual):
"""
Forward pass implementing the operations in the FX graph
Args:
hidden_states: Input tensor
residual: Residual tensor from previous layer
Returns:
Tuple containing the output tensor
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
# matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)
# Tensor parallel all-reduce
all_reduce = tensor_model_parallel_all_reduce(mm)
# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
return fp8_linear_result, residual_output
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
def ops_in_model_before(self):
ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP
# The following are only removed if fusion happens
if (
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
ops_to_remove.extend(
[
torch.ops._C.fused_add_rms_norm.default,
torch.ops._C.static_scaled_fp8_quant.default,
]
)
return ops_to_remove
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
z2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self):
ops_to_add = [
torch.ops.vllm.reduce_scatter.default,
return [
torch.ops.vllm.all_gather.default,
torch.ops.vllm.reduce_scatter.default,
]
def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
]
# The following is only added if fusion happens
if (
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
return ops_to_add
def ops_in_model(self):
if (
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
# If fusion happens, the fused op is the one
# we check for (de)functionalization
if self.vllm_config.compilation_config.pass_config.enable_fusion:
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
else:
# If no fusion, the original ops are checked
elif RMSNorm.enabled():
return [
torch.ops._C.fused_add_rms_norm.default,
# TODO functionalization pass does not handle this yet
# torch.ops._C.static_scaled_fp8_quant.default,
]
elif self.fp8_linear.quant_fp8.enabled():
return [
torch.ops._C.static_scaled_fp8_quant.default,
]
else:
return []
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel])
@pytest.mark.parametrize(
"test_model_cls, custom_ops",
[
(TestAllReduceRMSNormModel, "+rms_norm"),
(TestAllReduceRMSNormModel, "-rms_norm"),
(TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"),
(TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"),
(TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"),
(TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"),
],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("enable_fusion", [True, False])
@pytest.mark.parametrize("dynamic", [False, True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_sequence_parallelism_pass(
test_model_cls: type[torch.nn.Module],
custom_ops: str,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
dynamic: bool,
):
num_processes = 2
@ -220,11 +206,13 @@ def test_sequence_parallelism_pass(
args=(
num_processes,
test_model_cls,
custom_ops,
batch_size,
seq_len,
hidden_size,
dtype,
enable_fusion,
dynamic,
),
nprocs=nprocs,
)
@ -236,11 +224,13 @@ def sequence_parallelism_pass_on_test_model(
local_rank: int,
world_size: int,
test_model_cls: type[torch.nn.Module],
custom_ops: str,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
dynamic: bool,
):
current_platform.seed_everything(0)
@ -264,12 +254,16 @@ def sequence_parallelism_pass_on_test_model(
initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass
custom_ops_list = custom_ops.split(",") if custom_ops else []
compilation_config = CompilationConfig(
splitting_ops=[], # avoid automatic rms_norm enablement
cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings
custom_ops=custom_ops_list,
pass_config=PassConfig(
enable_sequence_parallelism=True,
enable_fusion=enable_fusion,
enable_noop=True,
)
),
) # NoOp needed for fusion
device_config = DeviceConfig(device=torch.device("cuda"))
@ -289,7 +283,6 @@ def sequence_parallelism_pass_on_test_model(
with set_current_vllm_config(vllm_config):
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
assert (
sequence_parallelism_pass.compilation_config.splitting_ops
@ -310,38 +303,29 @@ def sequence_parallelism_pass_on_test_model(
passes_for_backend.append(cleanup_pass)
backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
backend = TestBackend(*passes_for_backend)
model = test_model_cls(hidden_size, hidden_size * 2)
model = test_model_cls(hidden_size)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
compiled_model_no_func(hidden_states, residual)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
if dynamic:
torch._dynamo.mark_dynamic(hidden_states, 0)
assert sequence_parallelism_pass.matched_count == 1
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
assert sequence_parallelism_pass.matched_count == 4
# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before())
for op in model.ops_in_model_before():
assert backend.op_count(op, before=True) == 4
# In post-nodes, reduce scatter and all gather should be there,
# all reduce should not
backend_no_func.check_after_ops(model.ops_in_model_after())
for op in model.ops_in_model_after():
assert backend.op_count(op, before=False) == 4
# check if the functionalization pass is applied
for op in model.ops_in_model():
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in model.ops_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model())
find_auto_fn(backend.graph_post_pass.nodes, op)

View File

@ -14,6 +14,7 @@ from dataclasses import dataclass
from typing import Literal, NamedTuple
import pytest
import torch
from vllm.config.model import RunnerOption
from vllm.logger import init_logger
@ -254,6 +255,17 @@ def test_cp_generation(
test_options: CPTestOptions,
num_gpus_available,
):
if (
model_id == "deepseek-ai/DeepSeek-V2-Lite-Chat"
and torch.cuda.get_device_capability() < (9, 0)
):
pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
if (
model_id == "bigcode/gpt_bigcode-santacoder"
and torch.cuda.get_device_capability() != (9, 0)
):
pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")
_compare_cp_with_tp(
model_id,
parallel_setup,

View File

@ -18,6 +18,7 @@ import pytest
from vllm.config.compilation import CompilationMode
from vllm.config.model import RunnerOption
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..models.registry import HF_EXAMPLE_MODELS
@ -161,6 +162,7 @@ def _compare_sp(
test_options: SPTestOptions,
num_gpus_available: int,
use_inductor_graph_partition: bool,
enable_async_tp: bool,
*,
method: Literal["generate", "encode"],
is_multimodal: bool,
@ -244,10 +246,10 @@ def _compare_sp(
compilation_config = {
"mode": CompilationMode.VLLM_COMPILE,
"custom_ops": ["+rms_norm"],
"compile_sizes": [4, 8],
"pass_config": {
"enable_sequence_parallelism": True,
"enable_async_tp": enable_async_tp,
"enable_fusion": enable_fusion,
"enable_noop": True,
},
@ -307,6 +309,7 @@ SP_TEST_MODELS = [
],
)
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
@pytest.mark.parametrize("enable_async_tp", [False]) # TODO: enable async TP
@create_new_process_for_each_test()
def test_tp_sp_generation(
model_id: str,
@ -316,10 +319,19 @@ def test_tp_sp_generation(
test_options: SPTestOptions,
num_gpus_available,
use_inductor_graph_partition: bool,
enable_async_tp: bool,
):
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
# Skip FP8 SP-only test on sm89 (compute capability 8.9)
if (
"fp8" in model_id.lower()
and current_platform.get_device_capability() < (9, 0)
and (not enable_async_tp)
):
pytest.skip("FP8 reduction support begins with sm90 capable devices.")
_compare_sp(
model_id,
parallel_setup,
@ -328,6 +340,7 @@ def test_tp_sp_generation(
test_options,
num_gpus_available,
use_inductor_graph_partition,
enable_async_tp=enable_async_tp,
method="generate",
is_multimodal=False,
)

View File

@ -3,6 +3,3 @@ accuracy_threshold: 0.45
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
# Duo stream incompatible with this model: https://github.com/vllm-project/vllm/issues/28220
env:
VLLM_DISABLE_SHARED_EXPERTS_STREAM: "1"

View File

@ -170,6 +170,7 @@ def test_cascade(
logits_soft_cap=soft_cap if soft_cap is not None else 0,
block_table=block_tables,
common_prefix_len=common_prefix_len,
max_num_splits=0, # no max
fa_version=fa_version,
)

View File

@ -10,11 +10,13 @@ from collections import defaultdict
from pathlib import PosixPath
import pytest
from packaging.version import Version
from transformers import (
AutoModel,
AutoModelForImageTextToText,
AutoModelForTextToWaveform,
)
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.platforms import current_platform
from vllm.utils.func_utils import identity
@ -137,6 +139,7 @@ VLM_TEST_SETTINGS = {
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>",
video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>",
enforce_eager=False,
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
@ -166,6 +169,7 @@ VLM_TEST_SETTINGS = {
VLMTestType.MULTI_IMAGE,
VLMTestType.VIDEO,
),
enforce_eager=False,
needs_video_metadata=True,
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
@ -859,6 +863,12 @@ VLM_TEST_SETTINGS = {
limit_mm_per_prompt={"image": 4},
)
],
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_VERSION) == Version("4.57.1"),
reason="This model is broken in Transformers v4.57.1",
)
],
),
# regression test for https://github.com/vllm-project/vllm/issues/15122
"qwen2_5_vl-windows-attention": VLMTestInfo(

View File

@ -61,10 +61,8 @@ def test_qwen2_5_vl_evs_functionality(
model,
runner="generate",
max_model_len=4000,
max_num_seqs=1,
dtype=dtype,
limit_mm_per_prompt={"video": 1},
tensor_parallel_size=1,
video_pruning_rate=video_pruning_rate,
) as vllm_model:
# Generate output - this should not crash

View File

@ -980,8 +980,10 @@ def test_hybrid_block_table_initialization():
req_index = 0
block_table.append_row(kvcache_manager_blocks, req_index)
# Get expected kernel blocks from the implementation for verification.
expected_kernel_blocks = block_table._map_to_kernel_blocks(
np.array(kvcache_manager_blocks)
expected_kernel_blocks = block_table.map_to_kernel_blocks(
np.array(kvcache_manager_blocks),
block_table.blocks_per_kv_block,
block_table._kernel_block_arange,
)
# Verify block table state
assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)

View File

@ -938,4 +938,5 @@ class rocm_aiter_ops:
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
rocm_aiter_ops.register_ops_once()
if IS_AITER_FOUND:
rocm_aiter_ops.register_ops_once()

View File

@ -745,6 +745,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
k_pe: torch.Tensor,
output_shape: torch.Size | None = None,
) -> torch.Tensor:
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
@ -752,12 +755,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# Mirror Attention.forward scale calculation path
if self.calculate_kv_scales and getattr(
attn_metadata, "enable_kv_scales_calculation", False
):
self.calc_kv_scales(q, kv_c_normed, k_pe)
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
self.impl.forward(
@ -786,14 +783,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
)
return output
else:
# We can still access forward context to check calculation flag
if self.calculate_kv_scales:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
if getattr(attn_metadata, "enable_kv_scales_calculation", False):
self.calc_kv_scales(q, kv_c_normed, k_pe)
return torch.ops.vllm.unified_mla_attention(
q,
kv_c_normed,
@ -881,17 +870,13 @@ def maybe_calc_kv_scales(
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
self = forward_context.no_compile_layers[layer_name]
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
if attn_metadata is None or not getattr(
attn_metadata, "enable_kv_scales_calculation", False
):
# Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass
if not self.calculate_kv_scales:
return
self = forward_context.no_compile_layers[layer_name]
self.calc_kv_scales(query, key, value)

View File

@ -195,7 +195,6 @@ def cp_lse_ag_out_rs(
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
assert out.is_contiguous()
out = cp_group.reduce_scatter(out, dim=1)
if return_lse:

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
@ -10,98 +12,28 @@ from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .noop_elimination import NoOpEliminationPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
class _RMSNormAndQuantOpHelper:
"""Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
def get_first_out_wrapper(fn):
@functools.wraps(fn)
def wrapper(*args):
return fn(*args)[0]
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str,
quant_op: torch._ops.OpOverload | None = None,
**kwargs,
):
self.epsilon = epsilon
self.dtype = dtype
self.device = device
self.quant_op = quant_op
def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor):
return torch.ops.higher_order.auto_functionalized(
torch.ops._C.rms_norm.default,
result=result_buffer,
input=input_tensor,
weight=weight_tensor,
epsilon=self.epsilon,
)
def _functional_fused_add_rmsnorm(
self, input_tensor, residual_tensor, weight_tensor
):
return torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=input_tensor,
residual=residual_tensor,
weight=weight_tensor,
epsilon=self.epsilon,
)
def _functional_rmsnorm_then_quant(
self,
rmsnorm_result_buffer,
quant_result_buffer,
input_tensor,
weight_tensor,
scale_tensor,
):
if self.quant_op is None:
raise RuntimeError(
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
)
rmsnorm_out_tuple = self._functional_rmsnorm(
rmsnorm_result_buffer, input_tensor, weight_tensor
)
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
self.quant_op,
result=quant_result_buffer,
input=rmsnorm_out_tuple[1],
scale=scale_tensor,
)
return quant_out_tuple
def _functional_fused_add_rmsnorm_then_quant(
self,
quant_result_buffer,
input_tensor,
residual_tensor,
weight_tensor,
scale_tensor,
):
if self.quant_op is None:
raise RuntimeError(
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
)
fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
input_tensor, residual_tensor, weight_tensor
)
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
self.quant_op,
result=quant_result_buffer,
input=fused_add_rmsnorm_out_tuple[1],
scale=scale_tensor,
)
return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]
return wrapper
class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
class _SequenceParallelPatternHelper:
"""Helper for sequence parallelism patterns."""
def __init__(
@ -109,10 +41,10 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
epsilon: float,
dtype: torch.dtype,
device: str,
quant_op: torch._ops.OpOverload | None = None,
**kwargs,
):
super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
self.epsilon = epsilon
self.dtype = dtype
self.device = device
self.tp_group = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
@ -131,36 +63,34 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self):
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
return [input, permute, arg3_1]
return [input, arg3_1]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
permute: torch.Tensor,
arg3_1: torch.Tensor,
):
all_reduce = self._all_reduce(input)
rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1)
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
return rmsnorm[1], all_reduce
return rmsnorm, all_reduce
def replacement(
input: torch.Tensor,
permute: torch.Tensor,
arg3_1: torch.Tensor,
):
reduce_scatter = self._reduce_scatter(input)
rmsnorm_result = torch.empty_like(reduce_scatter)
rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1)
all_gather = self._all_gather(rmsnorm[1])
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
all_gather = self._all_gather(rmsnorm)
return all_gather, reduce_scatter
pm.register_replacement(
@ -169,6 +99,10 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
@ -188,67 +122,34 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
all_reduce, residual, rms_norm_weights
)
return rmsnorm[1], rmsnorm[2]
rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
return rmsnorm[0], rmsnorm[1]
def replacement(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# pattern matcher replaces from top-to-bottom,
# so residual is still the full size here.
# add a temporary slice which will become a noop
# once the seqpar pattern with the previous rmsnorm is replaced
reduce_scatter = self._reduce_scatter(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
reduce_scatter, residual, rms_norm_weights
)
all_gather = self._all_gather(rmsnorm[1])
return all_gather, rmsnorm[2]
residual = residual[0 : reduce_scatter.size(0), ...]
rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
all_gather = self._all_gather(rmsnorm[0])
# shape of residual changes but that's fine,
# next node is already slicing it, now becomes a noop
return all_gather, rmsnorm[1]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [
residual,
mm_1,
rms_norm_weights,
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
all_reduce, residual, rms_norm_weights
)
return rmsnorm[1]
def replacement(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
reduce_scatter, residual, rms_norm_weights
)
normalized = self._all_gather(rmsnorm[1])
return normalized
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
get_first_out_wrapper(pattern),
get_first_out_wrapper(replacement),
self.get_inputs(),
pm.fwd_only,
pm_pass,
)
@ -257,52 +158,41 @@ FP8_DTYPE = current_platform.fp8_dtype()
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
self,
epsilon: float,
dtype: torch.dtype,
device: str,
):
super().__init__(epsilon, dtype, device, quant_op=op)
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self):
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, rmsnorm_result, quant_result, weight, scale]
return [input, weight, scale]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
all_reduce = self._all_reduce(input)
static_fp8 = self._functional_rmsnorm_then_quant(
rmsnorm_result, quant_result, all_reduce, weight, scale
)
return static_fp8[1], all_reduce
rms = self.rmsnorm_matcher(all_reduce, weight)
quant, _ = self.quant_matcher(rms, scale)
return quant, all_reduce
def replacement(
input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
reduce_scatter = self._reduce_scatter(input)
rmsnorm_result = torch.empty_like(
reduce_scatter, dtype=rmsnorm_result.dtype
)
quant_result = torch.empty_like(
rmsnorm_result, # Output of RMSNorm
dtype=quant_result.dtype,
)
static_fp8 = self._functional_rmsnorm_then_quant(
rmsnorm_result, quant_result, reduce_scatter, weight, scale
)
all_gather = self._all_gather(static_fp8[1])
rms = self.rmsnorm_matcher(reduce_scatter, weight)
quant, _ = self.quant_matcher(rms, scale)
all_gather = self._all_gather(quant)
return all_gather, reduce_scatter
@ -312,118 +202,64 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
):
super().__init__(epsilon, dtype, device, quant_op=op)
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
return [
result,
residual,
mm_1,
rms_norm_weights,
scale,
]
return [residual, mm_1, rms_norm_weights, scale]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
static_fp8, rmsnorm_residual_out = (
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
result, all_reduce, residual, rms_norm_weights, scale
)
rms, residual_out = self.rmsnorm_matcher(
all_reduce, rms_norm_weights, residual
)
return static_fp8[1], rmsnorm_residual_out
quant, _ = self.quant_matcher(rms, scale)
return quant, residual_out
def replacement(
result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# pattern matcher replaces from top-to-bottom,
# so residual is still the full size here.
# add a temporary slice which will become a noop
# once the seqpar pattern with the previous rmsnorm is replaced
reduce_scatter = self._reduce_scatter(mm_1)
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
static_fp8, rmsnorm_residual_out = (
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
)
residual = residual[0 : reduce_scatter.size(0), ...]
rms, residual_out = self.rmsnorm_matcher(
reduce_scatter, rms_norm_weights, residual
)
all_gather = self._all_gather(static_fp8[1])
return all_gather, rmsnorm_residual_out
quant, _ = self.quant_matcher(rms, scale)
all_gather = self._all_gather(quant)
# shape of residual changes but that's fine,
# next node is already slicing it, now becomes a noop
return all_gather, residual_out
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
):
super().__init__(epsilon, dtype, device, quant_op=op)
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
return [
result,
residual,
mm_1,
rms_norm_weights,
scale,
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
result, all_reduce, residual, rms_norm_weights, scale
)
return static_fp8[1]
def replacement(
result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(mm_1)
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
)
normalized = self._all_gather(static_fp8[1])
return normalized
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
get_first_out_wrapper(pattern),
get_first_out_wrapper(replacement),
self.get_inputs(),
pm.fwd_only,
pm_pass,
)
@ -445,27 +281,45 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
significantly reduce communication overhead and improve overall model
performance.
This pass splits up the residual tensor across TP ranks and hence divides its size.
Because the pattern matcher starts at the end of the graph, the replacement
contains a slice that temporarily conforms the input residual to the correct size.
After all patterns have been matched, we use a NoOpEliminationPass to clean up
the slices that have now become no-ops.
Note that an older version of the pass did not need this, as it operated only on the
custom rms_norm and fused_add_rms_norm ops, which did not complain about
mismatched shapes during replacement. So this approach keeps the same assumption
as before: correctness is only maintained if all rms_norm operations are split across ranks.
Correctness-wise, this approach is strictly better than before: previously,
the graph was incorrect both semantically and shape-wise during the pass.
With this approach, there is only semantic incorrectness during the pass.
Both approaches restore a correct graph once all patterns are matched.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
# Used to cleanup redundant views created temporarily
# to circumvent residual shape change issues
self.noop_cleanup = NoOpEliminationPass(config)
self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="sequence_parallelism_pass"
)
for epsilon in [1e-5, 1e-6]:
# RMSNorm + Static FP8 quantization patterns
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
FirstAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, fp8_quant_op
epsilon, self.model_dtype, self.device
).register(self.patterns)
MiddleAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, fp8_quant_op
).register(self.patterns)
LastAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, fp8_quant_op
epsilon, self.model_dtype, self.device
).register(self.patterns)
# Normal RMSNorm patterns
@ -477,9 +331,6 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
epsilon, self.model_dtype, self.device
).register(self.patterns)
LastAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
self.dump_patterns(config, self.patterns)
def is_applicable(self, shape: int | None) -> bool:
@ -508,3 +359,5 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
def __call__(self, graph: fx.Graph):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
# Clean up reshape nodes
self.noop_cleanup(graph)
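
The pass docstring above leans on one detail that is easier to see concretely: the temporary slice in each replacement keeps shapes consistent while the graph is only partially rewritten, and turns into a no-op (removable by NoOpEliminationPass) once the upstream pattern has also been replaced. A minimal standalone sketch with made-up shapes, not vLLM code:

import torch

tp_size = 2
full_residual = torch.randn(8, 16)                 # residual before any rewrite
reduce_scattered = torch.randn(8 // tp_size, 16)   # per-rank shard after reduce_scatter

# While only this pattern has been rewritten, the slice genuinely truncates:
tmp = full_residual[0:reduce_scattered.size(0), ...]
assert tmp.shape[0] == 4 and full_residual.shape[0] == 8

# Once the previous rmsnorm pattern is rewritten too, the incoming residual is
# already sharded, so the very same slice keeps every element, i.e. it is a
# no-op that NoOpEliminationPass can safely remove:
sharded_residual = torch.randn(8 // tp_size, 16)
noop = sharded_residual[0:reduce_scattered.size(0), ...]
assert noop.shape == sharded_residual.shape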

View File

@ -152,7 +152,6 @@ class CompilationConfig:
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
- [`compile_mm_encoder`][vllm.config.CompilationConfig.compile_mm_encoder]
- CudaGraph capture:
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
- [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
- [`cudagraph_capture_sizes`]
[vllm.config.CompilationConfig.cudagraph_capture_sizes]
@ -162,7 +161,6 @@ class CompilationConfig:
[vllm.config.CompilationConfig.cudagraph_num_of_warmups]
- [`cudagraph_copy_inputs`]
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
- [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph]
- Inductor compilation:
- [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
@ -268,9 +266,10 @@ class CompilationConfig:
If None, defaults to attention ops for piecewise cudagraphs.
If empty list [], no ops are excluded (suitable for full cudagraphs)."""
compile_mm_encoder: bool = True
compile_mm_encoder: bool = False
"""Whether or not to compile the multimodal encoder.
Currently, this only works for `Qwen2_5_vl`."""
Currently, this only works for `Qwen2_5_vl` on selected platforms.
Disabled by default until more models are supported/tested to work."""
# Inductor capture
use_inductor: bool | None = None
@ -342,18 +341,6 @@ class CompilationConfig:
Warning: This flag is new and subject to change; in addition,
more modes may be added.
"""
use_cudagraph: bool = True
"""Whether to use cudagraph inside compilation:
- False: cudagraph inside compilation is not used.\n
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=FULL_AND
_PIECEWISE instead.
"""
cudagraph_num_of_warmups: int = 0
"""Number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
@ -371,15 +358,6 @@ class CompilationConfig:
internally managed buffer. Default is False.
Note that this flag is only effective when cudagraph_mode is PIECEWISE.
"""
full_cuda_graph: bool | None = False
"""whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models.
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
FULL_AND_PIECEWISE instead.
"""
cudagraph_specialize_lora: bool = True
"""Whether to create separate cuda graphs for cases with and without active
LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used
@ -528,13 +506,19 @@ class CompilationConfig:
@field_validator("cudagraph_mode", mode="before")
@classmethod
def validate_cudagraph_mode_before(cls, value: Any) -> Any:
"""
enable parse the `cudagraph_mode` enum type from string
"""
"""Enable parsing of the `cudagraph_mode` enum type from string."""
if isinstance(value, str):
return CUDAGraphMode[value.upper()]
return value
@field_validator("pass_config", mode="before")
@classmethod
def validate_pass_config_before(cls, value: Any) -> Any:
"""Enable parsing of the `pass_config` field from a dictionary."""
if isinstance(value, dict):
return PassConfig(**value)
return value
@field_validator("compile_cache_save_format")
@classmethod
def validate_compile_cache_save_format(cls, value: str) -> str:
@ -591,8 +575,10 @@ class CompilationConfig:
func if isinstance(func, InductorPass) else CallableInductorPass(func)
)
if isinstance(self.pass_config, dict):
self.pass_config = PassConfig(**self.pass_config)
if self.pass_config.enable_qk_norm_rope_fusion:
# TODO(zhuhaoran): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")
if (
is_torch_equal_or_newer("2.9.0.dev")
@ -604,36 +590,6 @@ class CompilationConfig:
self.inductor_compile_config["combo_kernels"] = True
self.inductor_compile_config["benchmark_combo_kernel"] = True
# migrate the deprecated flags
if not self.use_cudagraph:
logger.warning(
"use_cudagraph is deprecated, use cudagraph_mode=NONE instead."
)
if (
self.cudagraph_mode is not None
and self.cudagraph_mode != CUDAGraphMode.NONE
):
raise ValueError(
"use_cudagraph and cudagraph_mode are mutually"
" exclusive, prefer cudagraph_mode since "
"use_cudagraph is deprecated."
)
self.cudagraph_mode = CUDAGraphMode.NONE
if self.full_cuda_graph:
logger.warning(
"full_cuda_graph is deprecated, use cudagraph_mode=FULL instead."
)
if (
self.cudagraph_mode is not None
and not self.cudagraph_mode.has_full_cudagraphs()
):
raise ValueError(
"full_cuda_graph and cudagraph_mode are "
"mutually exclusive, prefer cudagraph_mode "
"since full_cuda_graph is deprecated."
)
self.cudagraph_mode = CUDAGraphMode.FULL
if self.use_inductor_graph_partition and not is_torch_equal_or_newer(
"2.9.0.dev"
):
@ -811,20 +767,19 @@ class CompilationConfig:
def set_splitting_ops_for_attn_fusion(self):
assert self.pass_config.enable_attn_fusion
# For dynamo-partition (non-inductor) attention fusion,
# set splitting_ops to empty to avoid splitting at attention ops
self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs():
logger.warning_once(
"enable_attn_fusion is incompatible with piecewise "
"cudagraph when use_inductor_graph_partition is off. "
"In this case, splitting_ops will be set to empty "
"list, and cudagraph_mode will be set to FULL. "
"Please ensure you are using attention backends that "
"support cudagraph or set cudagraph_mode to NONE "
"explicitly if encountering any problems."
)
self.cudagraph_mode = CUDAGraphMode.FULL
if self.splitting_ops is None:
self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs():
logger.warning_once(
"enable_attn_fusion is incompatible with piecewise "
"cudagraph when use_inductor_graph_partition is off. "
"In this case, splitting_ops will be set to empty "
"list, and cudagraph_mode will be set to FULL. "
"Please ensure you are using attention backends that "
"support cudagraph or set cudagraph_mode to NONE "
"explicitly if encountering any problems."
)
self.cudagraph_mode = CUDAGraphMode.FULL
assert not self.splitting_ops_contain_attention(), (
"attention ops should not be in splitting_ops "

View File

@ -441,8 +441,6 @@ class VllmConfig:
# and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp:
self.compilation_config.pass_config.enable_sequence_parallelism = True
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")
if current_platform.support_static_graph_mode():
# if cudagraph_mode is not explicitly set by users, set default
@ -631,6 +629,32 @@ class VllmConfig:
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
self.compilation_config.set_splitting_ops_for_v1()
if self.compilation_config.pass_config.enable_sequence_parallelism:
# With pipeline parallelism or dynamo partitioning,
# tracing the native rms_norm fails due to an incorrect residual shape.
# Use the custom rms_norm op to unblock. In the future,
# the pass will operate on higher-level IR to avoid the issue.
# TODO: https://github.com/vllm-project/vllm/issues/27894
is_fullgraph = (
self.compilation_config.use_inductor_graph_partition
or len(self.compilation_config.splitting_ops) == 0
)
if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
if "-rms_norm" not in self.compilation_config.custom_ops:
self.compilation_config.custom_ops.append("+rms_norm")
else:
regime = (
"Dynamo partition"
if not is_fullgraph
else "pipeline parallelism"
)
logger.warning_once(
"Sequence parallelism not supported with"
"native rms_norm when using %s, "
"this will likely lead to an error.",
regime,
)
# final check of cudagraph mode after all possible updates
if current_platform.is_cuda_alike():
if (
@ -652,14 +676,6 @@ class VllmConfig:
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
)
# final migrate the deprecated flags
self.compilation_config.use_cudagraph = (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
)
self.compilation_config.full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
if self.parallel_config.enable_dbo:
a2a_backend = self.parallel_config.all2all_backend
assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
@ -849,7 +865,9 @@ class VllmConfig:
)
# de-duplicate the sizes provided by the config
dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes))
cudagraph_capture_sizes = dedup_sizes
cudagraph_capture_sizes = [
i for i in dedup_sizes if i <= max_num_tokens
]
# sort to make sure the sizes are in ascending order
cudagraph_capture_sizes.sort()
else:
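
A toy run of the sanitization above, with assumed user-provided sizes and max_num_tokens = 512, showing that oversized entries are now dropped alongside duplicates:

max_num_tokens = 512
user_sizes = [1024, 256, 8, 256, 512]      # hypothetical --cudagraph-capture-sizes input
dedup_sizes = list(set(user_sizes))
cudagraph_capture_sizes = [i for i in dedup_sizes if i <= max_num_tokens]
cudagraph_capture_sizes.sort()
print(cudagraph_capture_sizes)             # [8, 256, 512]; 1024 exceeds max_num_tokens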

View File

@ -342,8 +342,8 @@ class MsgpackSerde(ObjectSerde):
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
self.encoder = MsgpackEncoder()
self.tensor_decoder = MsgpackDecoder(torch.Tensor)
self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem)
self.tensor_decoder = MsgpackDecoder(torch.Tensor, share_mem=False)
self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem, share_mem=False)
self._mm_kwargs_item_cls = MultiModalKwargsItem
def serialize(self, value: Any) -> tuple[bytes | list[bytes], int, bytes, int]:
@ -368,7 +368,7 @@ class MsgpackSerde(ObjectSerde):
# pickle.loads do not read past the end of a pickled object
# within a large buffer, so we can skip storing the metadata size
type_name, nbytes, len_arr = pickle.loads(data_view)
serialized_data = bytearray(data_view[-nbytes:])
serialized_data = data_view[-nbytes:]
if type_name == torch.Tensor.__name__:
obj = []

View File

@ -48,6 +48,7 @@ from vllm.platforms import current_platform
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@ -110,6 +111,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
@dataclass
class ReqMeta:
local_block_ids: list[int]
# To be used when logical block size does not match the kernel block size
local_physical_block_ids: list[int]
remote_block_ids: list[int]
remote_host: str
remote_port: int
@ -137,6 +140,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
assert load_remote_cache ^ save_to_host
_req = ReqMeta(
local_block_ids=local_block_ids,
local_physical_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
@ -897,6 +901,8 @@ class NixlConnectorWorker:
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
)
self._use_pallas = self.kv_topo._use_pallas
self._physical_blocks_per_logical_kv_block = 1
def _nixl_handshake(
self,
@ -1092,6 +1098,22 @@ class NixlConnectorWorker:
if base_addr in seen_base_addresses:
continue
# TODO (NickLucche): Get kernel_block_size in a cleaner way
# NHD default "view" for non-MLA cache
kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]
if self.block_size != kernel_block_size:
logger.info_once(
"User-specified logical block size (%s) does not match"
" physical kernel block size (%s). Using the latter. ",
self.block_size,
kernel_block_size,
)
self._physical_blocks_per_logical_kv_block = (
self.block_size // kernel_block_size
)
self.block_size = kernel_block_size
seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.numel() * cache.element_size()
@ -1438,7 +1460,7 @@ class NixlConnectorWorker:
assert self.use_host_buffer
assert self.copy_blocks is not None
local_block_ids = meta.local_block_ids
local_block_ids = meta.local_physical_block_ids
self.copy_blocks(
self.host_xfer_buffers,
self.device_kv_caches,
@ -1451,7 +1473,7 @@ class NixlConnectorWorker:
"synced recved kv of request[%s] to device kv buffer,"
"local_block_ids: %s. ",
req_id,
",".join(map(str, meta.local_block_ids)),
",".join(map(str, local_block_ids)),
)
def save_kv_to_host(self, metadata: NixlConnectorMetadata):
@ -1460,19 +1482,22 @@ class NixlConnectorWorker:
assert self.copy_blocks is not None
for req_id, meta in metadata.reqs_to_save.items():
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
meta.local_block_ids
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"save_load_kv for request[%s] to host xfer buffer."
"local_block_ids: %s. ",
req_id,
",".join(map(str, meta.local_block_ids)),
",".join(map(str, meta.local_physical_block_ids)),
)
# blocking
self.copy_blocks(
self.device_kv_caches,
self.host_xfer_buffers,
meta.local_block_ids,
meta.local_block_ids,
meta.local_physical_block_ids,
meta.local_physical_block_ids,
"d2h",
)
@ -1541,7 +1566,7 @@ class NixlConnectorWorker:
if self.use_host_buffer:
self.sync_recved_kv_to_device(req_id, meta)
if self.enable_permute_local_kv:
block_ids_to_permute += meta.local_block_ids
block_ids_to_permute += meta.local_physical_block_ids
if len(block_ids_to_permute) > 0:
self.permute_device_kv(block_ids_to_permute)
@ -1628,7 +1653,7 @@ class NixlConnectorWorker:
req_id,
xfer_state,
)
# mark all blocks for this request as invalid
# mark all (logical) blocks for this request as invalid
if meta := self._recving_metadata.pop(req_id, None):
self._invalid_block_ids.update(meta.local_block_ids)
self._recving_metadata.pop(req_id, None)
@ -1645,13 +1670,19 @@ class NixlConnectorWorker:
We check for these transfers to complete in each step().
"""
for req_id, meta in metadata.reqs_to_recv.items():
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
meta.local_block_ids
)
meta.remote_block_ids = self._logical_to_kernel_block_ids(
meta.remote_block_ids
)
remote_engine_id = meta.remote_engine_id
logger.debug(
"start_load_kv for request %s from remote engine %s. "
"Num local_block_ids: %s. Num remote_block_ids: %s. ",
req_id,
remote_engine_id,
len(meta.local_block_ids),
len(meta.local_physical_block_ids),
len(meta.remote_block_ids),
)
# always store metadata for failure recovery
@ -1699,7 +1730,7 @@ class NixlConnectorWorker:
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
local_block_ids=meta.local_block_ids,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote_block_ids,
)
@ -1826,7 +1857,7 @@ class NixlConnectorWorker:
"Marking blocks as invalid.",
request_id,
)
# mark all blocks for this request as invalid
# mark all (logical) blocks for this request as invalid
if meta := self._recving_metadata.get(request_id):
self._invalid_block_ids.update(meta.local_block_ids)
self.xfer_stats.record_failed_transfer()
@ -1865,6 +1896,23 @@ class NixlConnectorWorker:
descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten()
def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
"""
Convert logical block ids to kernel physical block ids.
This is required when the logical block size (the one set by the user)
does not match the one required by the attn backend.
"""
if self._physical_blocks_per_logical_kv_block == 1:
# Noop when physical and logical block sizes are the same
return block_ids
block_ids_np = np.array(block_ids)
block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
1, -1
)
return BlockTable.map_to_kernel_blocks(
block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
).tolist()
def get_backend_aware_kv_block_len(self, layer_idx: int):
"""
Get the block length for one K/V element (K and V have the same size).

View File

@ -1625,40 +1625,39 @@ class EngineArgs:
)
observability_config = ObservabilityConfig(
show_hidden_metrics_for_version=(self.show_hidden_metrics_for_version),
show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
otlp_traces_endpoint=self.otlp_traces_endpoint,
collect_detailed_traces=self.collect_detailed_traces,
)
# Compilation config overrides
compilation_config = copy.deepcopy(self.compilation_config)
if self.cuda_graph_sizes is not None:
logger.warning(
"--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or "
"v1.0.0, whichever is soonest. Please use --cudagraph-capture-sizes "
"instead."
)
if self.compilation_config.cudagraph_capture_sizes is not None:
if compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
"cuda_graph_sizes and compilation_config."
"cudagraph_capture_sizes are mutually exclusive"
)
self.compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes
compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes
if self.cudagraph_capture_sizes is not None:
if self.compilation_config.cudagraph_capture_sizes is not None:
if compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
"cudagraph_capture_sizes and compilation_config."
"cudagraph_capture_sizes are mutually exclusive"
)
self.compilation_config.cudagraph_capture_sizes = (
self.cudagraph_capture_sizes
)
compilation_config.cudagraph_capture_sizes = self.cudagraph_capture_sizes
if self.max_cudagraph_capture_size is not None:
if self.compilation_config.max_cudagraph_capture_size is not None:
if compilation_config.max_cudagraph_capture_size is not None:
raise ValueError(
"max_cudagraph_capture_size and compilation_config."
"max_cudagraph_capture_size are mutually exclusive"
)
self.compilation_config.max_cudagraph_capture_size = (
compilation_config.max_cudagraph_capture_size = (
self.max_cudagraph_capture_size
)
@ -1673,7 +1672,7 @@ class EngineArgs:
load_config=load_config,
structured_outputs_config=self.structured_outputs_config,
observability_config=observability_config,
compilation_config=self.compilation_config,
compilation_config=compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
additional_config=self.additional_config,
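
A stripped-down sketch (toy classes, not vLLM code) of why the copy.deepcopy above matters: without it, the first call's overrides are written back into self.compilation_config, so a second call trips the mutual-exclusion check and create_engine_config is no longer idempotent.

import copy

class ToyArgs:
    def __init__(self):
        self.cudagraph_capture_sizes = [8, 16]
        self.compilation_config = {"cudagraph_capture_sizes": None}

    def create_engine_config(self):
        # deep copy so every call starts from the user's original config
        compilation_config = copy.deepcopy(self.compilation_config)
        if self.cudagraph_capture_sizes is not None:
            if compilation_config["cudagraph_capture_sizes"] is not None:
                raise ValueError("mutually exclusive")
            compilation_config["cudagraph_capture_sizes"] = self.cudagraph_capture_sizes
        return compilation_config

args = ToyArgs()
args.create_engine_config()
args.create_engine_config()   # still fine: self.compilation_config was never mutated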

View File

@ -2439,28 +2439,6 @@ class FusedMoE(CustomOp):
staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True)
# If there are shared experts but we are not using a modular kernel,
# the shared experts must be called here
if has_separate_shared_experts:
assert self.shared_experts is not None
if self.shared_experts_stream is not None:
# For chunked, we start the shared experts stream here
# (Note that no concurrency with the router/gate)
self.shared_experts_stream.wait_stream(current_stream())
with torch.cuda.stream(self.shared_experts_stream):
# Note that staged_hidden_states clone() is necessary
# here to avoid conflict with the main stream
shared_output = self.shared_experts(
staged_hidden_states.clone()
)
else:
shared_output = self.shared_experts(staged_hidden_states)
else:
shared_output = None
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
@ -2489,11 +2467,7 @@ class FusedMoE(CustomOp):
if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
# Here we finish the shared experts stream
if self.shared_experts_stream is not None:
current_stream().wait_stream(self.shared_experts_stream)
shared_output = self.shared_experts(staged_hidden_states)
final_hidden_states = (
shared_output,
final_hidden_states,
@ -2602,11 +2576,22 @@ class FusedMoE(CustomOp):
assert self.shared_experts is not None
if self.shared_experts_stream is not None:
# Clone BEFORE switching streams to avoid a race condition
# where the routed-expert kernels may mutate hidden_states.
hidden_states_clone = hidden_states.clone()
self.shared_experts_stream.wait_stream(current_stream())
# Run shared experts in parallel on a separate stream
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output = self.shared_experts(hidden_states.clone())
shared_output = self.shared_experts(hidden_states_clone)
# Record that the clone will be used by shared_experts_stream
# to avoid a GC issue caused by early deallocation of hidden_states_clone.
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: we don't need shared_output.record_stream(current_stream())
# because we sync the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)
else:
shared_output = self.shared_experts(hidden_states)
else:
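
The clone-then-record_stream dance above can be hard to follow out of context. A minimal, self-contained sketch of the pattern (assumes a CUDA device; the side computation stands in for the shared-experts call):

import torch

def run_on_side_stream(x: torch.Tensor, side_stream: torch.cuda.Stream) -> torch.Tensor:
    # Clone on the current stream so the side stream never reads memory the
    # main stream's kernels may still be mutating.
    x_clone = x.clone()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        out = x_clone * 2          # stand-in for the shared-experts computation
    # Tell the caching allocator the clone is in use on the side stream, so its
    # memory is not handed back to the main stream before the side stream is done.
    x_clone.record_stream(side_stream)
    return out

# The caller must synchronize before consuming the result, mirroring the
# current_stream().wait_stream(shared_experts_stream) done in forward():
# torch.cuda.current_stream().wait_stream(side_stream)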

View File

@ -43,7 +43,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
@ -95,11 +94,9 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
get_col_major_tma_aligned_tensor,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear,
)
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm
@ -554,83 +551,19 @@ class Fp8LinearMethod(LinearMethodBase):
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
# we will use BF16 dequant when DeepGEMM is not supported.
if vllm_is_batch_invariant():
# Call is_deep_gemm_supported() ahead of time for torch.compile
# dynamo has trouble tracing through
if self.block_quant and should_use_deepgemm_for_fp8_linear(
torch.bfloat16, layer.weight, self.use_deep_gemm
):
# use group quant consistent with block size across K
assert self.act_q_group_shape is not None
q_input, input_scale = QuantFP8(
False,
self.act_q_group_shape,
column_major_scales=True,
)(x)
output_2d = torch.empty(
(q_input.shape[0], layer.weight.shape[0]),
dtype=torch.bfloat16,
device=q_input.device,
)
fp8_gemm_nt(
(q_input, input_scale),
(layer.weight, layer.weight_scale),
output_2d,
)
if bias is not None:
output_2d = output_2d + bias
return output_2d
# Dequantize FP8 weights to BF16
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)
# Handle different quantization granularities
if self.block_quant:
# Block-wise quantization:
# - Weight is NOT transposed, shape is [N, K] (output_size, input_size)
# - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!)
assert self.weight_block_size is not None
block_n, block_k = self.weight_block_size # Note: order is [N, K]
N, K = weight_fp8.shape
# determine expected number of blocks along N and K
num_blocks_n = (N + block_n - 1) // block_n
num_blocks_k = (K + block_k - 1) // block_k
# scale layout may be [num_blocks_n, num_blocks_k]
# or [num_blocks_k, num_blocks_n] depending on backend
if weight_scale.dim() != 2:
raise RuntimeError(
f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}"
)
scale_rows, scale_cols = weight_scale.shape
if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n):
if num_blocks_n == num_blocks_k:
# ambiguous square case, warn and skip transpose
logger.warning(
"Batch-invariant FP8: square block-scale %dx%d; "
"skipping transpose to avoid misorientation.",
scale_rows,
scale_cols,
)
else:
# clear KN -> transpose to NK
weight_scale = weight_scale.t()
# Expand scale to match weight dimensions
# scale_expanded should have shape [N, K]
scale_expanded = weight_scale.repeat_interleave(
block_n, dim=0
).repeat_interleave(block_k, dim=1)
# Trim to exact weight size (in case of padding)
scale_expanded = scale_expanded[:N, :K]
weight_bf16 = weight_fp8 * scale_expanded
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
else:
# Per-tensor quantization: weight IS transposed to [K, N]
# scale should be scalar or [1] or per-output-channel [N]
# per-tensor/channel: dequant to BF16 and run GEMM
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)
if weight_scale.numel() == 1:
# Per-tensor: simple scalar multiplication
weight_bf16 = weight_fp8 * weight_scale
@ -649,16 +582,7 @@ class Fp8LinearMethod(LinearMethodBase):
else:
# Fallback
weight_bf16 = weight_fp8 * weight_scale
# For block quant, weight is [N, K], for per-tensor it's [K, N]
# F.linear expects weight to be [N, K], so:
if self.block_quant:
# Already in correct shape [N, K]
output = torch.nn.functional.linear(x, weight_bf16, bias)
else:
# Need to transpose back: [K, N] -> [N, K]
output = torch.nn.functional.linear(x, weight_bf16.t(), bias)
return output
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
if self.use_marlin:
return apply_fp8_marlin_linear(
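
For the per-tensor path kept above, a small numeric sketch (illustrative shapes, FP8 e4m3 assumed available in the local torch build) of the dequantize-to-BF16-then-linear fallback:

import torch

K, N = 8, 4
w_ref = torch.randn(K, N, dtype=torch.bfloat16)               # weight stored as [K, N]
scale = w_ref.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w_ref.float() / scale).to(torch.float8_e4m3fn)

x = torch.randn(2, K, dtype=torch.bfloat16)
w_bf16 = w_fp8.to(torch.bfloat16) * scale.to(torch.bfloat16)  # per-tensor dequant
out = torch.nn.functional.linear(x, w_bf16.t(), None)         # transpose back to [N, K]
assert out.shape == (2, N)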

View File

@ -82,7 +82,8 @@ enable_hf_transfer()
class DisabledTqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
kwargs["disable"] = True
super().__init__(*args, **kwargs)
def get_lock(model_name_or_path: str | Path, cache_dir: str | None = None):
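
The one-line change above avoids a subtle failure: passing disable=True directly in the super() call raises "got multiple values for keyword argument 'disable'" whenever a caller also supplies disable=..., whereas overwriting kwargs first always wins quietly. A sketch of the fixed behavior (tqdm assumed installed):

from tqdm import tqdm

class DisabledTqdm(tqdm):
    def __init__(self, *args, **kwargs):
        kwargs["disable"] = True          # override any caller-provided value
        super().__init__(*args, **kwargs)

bar = DisabledTqdm(range(3), disable=False)   # no TypeError; bar stays disabled
assert bar.disable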

View File

@ -779,6 +779,10 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_image_input(
self, **kwargs: object
) -> DotsOCRImageInputs | None:

View File

@ -35,6 +35,7 @@ from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@ -45,6 +46,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.utils import initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
@ -68,11 +70,15 @@ from .interfaces import (
MixtureOfExperts,
MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from .llama4 import Llama4ForCausalLM
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import (
AutoWeightsLoader,
maybe_prefix,
)
from .vision import run_dp_sharded_vision_model
@ -724,7 +730,12 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
dummy_inputs=Mllama4DummyInputsBuilder,
)
class Llama4ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, MixtureOfExperts, SupportsEagle3
nn.Module,
SupportsMultiModal,
SupportsPP,
MixtureOfExperts,
SupportsEagle3,
SupportsLoRA,
):
merge_by_field_config = True
@ -1067,6 +1078,17 @@ class Llama4ForConditionalGeneration(
return updated_params
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.text_config.num_local_experts,
num_redundant_experts=self.num_redundant_experts,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@ -1113,3 +1135,13 @@ class Llama4ForConditionalGeneration(
)
return updated_params
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="multi_modal_projector.",
tower_model="vision_model.",
)

View File

@ -198,23 +198,18 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
if image_processor is None:
image_processor = self.get_image_processor()
do_resize = True
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
if do_resize:
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * merge_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
)
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
else:
preprocessed_size = ImageSize(width=image_width, height=image_height)
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * merge_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
)
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
grid_t = 1
grid_h = preprocessed_size.height // patch_size
@ -227,8 +222,19 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
def get_image_size_with_most_features(self) -> ImageSize:
hf_config = self.get_hf_config()
image_size = hf_config.vision_config.image_size
return ImageSize(height=image_size, width=image_size)
# See `smart_resize` for the calculation of the image size.
merge_size = hf_config.vision_config.spatial_merge_size
patch_size = hf_config.vision_config.patch_size
factor = merge_size * patch_size
max_num_tokens = self.get_image_processor().max_pixels // (factor**2)
# Find factors of max_num_tokens close to its square root
# to create a dummy image with a reasonable aspect ratio.
h_patches = int(math.sqrt(max_num_tokens))
while max_num_tokens % h_patches != 0:
h_patches -= 1
w_patches = max_num_tokens // h_patches
return ImageSize(height=h_patches * factor, width=w_patches * factor)
class PaddleOCRVLDummyInputsBuilder(BaseDummyInputsBuilder[PaddleOCRVLProcessingInfo]):
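
To make the divisor walk above concrete, a worked example with assumed processor values (patch_size=14, spatial_merge_size=2, max_pixels=1003520, i.e. at most 1280 image tokens): the loop lands on 32x40 patches, giving an 896x1120 dummy image.

import math

factor = 14 * 2                              # patch_size * spatial_merge_size (assumed)
max_num_tokens = 1_003_520 // factor**2      # 1280
h_patches = int(math.sqrt(max_num_tokens))   # 35
while max_num_tokens % h_patches != 0:
    h_patches -= 1                           # 35 -> 34 -> 33 -> 32
w_patches = max_num_tokens // h_patches      # 40
print(h_patches * factor, w_patches * factor)  # 896 1120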

View File

@ -13,7 +13,6 @@ from transformers import (
BatchFeature,
WhisperConfig,
WhisperFeatureExtractor,
WhisperProcessor,
)
from transformers.models.whisper.modeling_whisper import sinusoids
@ -660,16 +659,6 @@ class WhisperProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> WhisperConfig:
return self.ctx.get_hf_config(WhisperConfig)
def get_hf_processor(self, **kwargs: object) -> WhisperProcessor:
# HACK: Transformers 4.53.2 has issue with whisper tokenizer to
# initialize processor. We use a monkeypatch to fix it here.
# See: https://github.com/vllm-project/vllm/issues/20224
processor_class = WhisperProcessor
tokenizer_class = ("WhisperTokenizer", "WhisperTokenizerFast")
if processor_class.tokenizer_class != tokenizer_class:
processor_class.tokenizer_class = tokenizer_class
return self.ctx.get_hf_processor(processor_class, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": 1}

View File

@ -675,6 +675,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
max_num_splits=attn_metadata.max_num_splits,
fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
@ -921,6 +922,7 @@ def cascade_attention(
logits_soft_cap: float,
block_table: torch.Tensor,
common_prefix_len: int,
max_num_splits: int,
fa_version: int,
prefix_scheduler_metadata: torch.Tensor | None = None,
suffix_scheduler_metadata: torch.Tensor | None = None,
@ -965,7 +967,7 @@ def cascade_attention(
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
s_aux=s_aux,
num_splits=1 if vllm_is_batch_invariant() else 0,
num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
)
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@ -990,7 +992,7 @@ def cascade_attention(
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
num_splits=1 if vllm_is_batch_invariant() else 0,
num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
)
# Merge prefix and suffix outputs, and store the result in output.

View File

@ -123,7 +123,7 @@ class Mamba1AttentionMetadataBuilder(
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.full_cuda_graph
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
self.state_indices_tensor[:num_decodes].copy_(

View File

@ -302,7 +302,7 @@ class Mamba2AttentionMetadataBuilder(
elif (
num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.full_cuda_graph
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
# Pad state tensor for CUDA graph
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)

View File

@ -349,6 +349,7 @@ class MLACommonPrefillMetadata:
local_context_lens_allranks: list[list[int]] | None = None
padded_local_cu_seq_lens: torch.Tensor | None = None
cu_seq_lens_lst: list[list[int]] | None = None
chunk_size: int | None = None
block_table: torch.Tensor
query_start_loc: torch.Tensor
@ -914,6 +915,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device, non_blocking=True
),
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
chunk_size=padded_local_max_context_chunk_across_ranks,
)
else:
chunked_context_metadata = chunked_context_metadata_cls(
@ -998,6 +1000,8 @@ def reorg_kvcache(
local_context_lens_allranks: list[list[int]],
sum_seq_len: int,
max_seq_len: int,
chunk_size: int,
chunk_idx: int,
toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@ -1013,6 +1017,9 @@ def reorg_kvcache(
local_context_lens_allranks: local context lengths on each CP rank.
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
max_seq_len: the max value of cp_chunk_seq_lens_lst.
chunk_size: the padded local max context chunk size used when
building chunked_context_metadata.
chunk_idx: the chunk index of the chunked prefill.
toks: the number of tokens for local gather cache.
"""
kv_c_segments = []
@ -1024,20 +1031,31 @@ def reorg_kvcache(
):
cur_seq_len = 0
for rank, local_context_len in enumerate(local_context_lens):
if local_context_len != 0:
# Note(qcs): We split the context into multiple chunks,
# depending on the size of the workspace.
# local_context in dcp0: |-----------------|
# local_context in dcp1: |--------------|
# n*padded_local_chunk: |-----|-----|-----|
# local_chunk_len in dcp1: |-----|-----|--|
# so we need to update the last chunk length in dcp1.
local_chunk_len = min(
max(0, local_context_len - chunk_idx * chunk_size),
padded_local_chunk_seq_len,
)
if local_chunk_len != 0:
kv_c_segment = allgatered_kv_c_normed[
rank * toks + src_token_idx : rank * toks
+ src_token_idx
+ local_context_len
+ local_chunk_len
]
k_pe_segment = allgatered_k_pe[
rank * toks + src_token_idx : rank * toks
+ src_token_idx
+ local_context_len
+ local_chunk_len
]
kv_c_segments.append(kv_c_segment)
k_pe_segments.append(k_pe_segment)
cur_seq_len += local_context_len
cur_seq_len += local_chunk_len
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
src_token_idx += padded_local_chunk_seq_len
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
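
A quick arithmetic sketch of the correction above, with assumed numbers: a DCP rank holding 17 local context tokens and a workspace chunk of 7 tokens. The old code added local_context_len for every chunk; the new local_chunk_len only counts what actually lands in the current chunk.

local_context_len = 17
chunk_size = padded_local_chunk_seq_len = 7
for chunk_idx in range(3):
    local_chunk_len = min(
        max(0, local_context_len - chunk_idx * chunk_size),
        padded_local_chunk_seq_len,
    )
    print(chunk_idx, local_chunk_len)   # 0 -> 7, 1 -> 7, 2 -> 3 (last chunk is partial)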
@ -1688,6 +1706,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
assert prefill_metadata.chunked_context.local_context_lens_allranks is not None
assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None
assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
assert prefill_metadata.chunked_context.chunk_size is not None
output = None
iters = len(prefill_metadata.chunked_context.seq_tot)
@ -1737,6 +1756,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks,
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
chunk_size=prefill_metadata.chunked_context.chunk_size,
chunk_idx=i,
toks=toks,
)

View File

@ -81,7 +81,7 @@ class ShortConvAttentionMetadataBuilder(
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.full_cuda_graph
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
self.state_indices_tensor[:num_decodes].copy_(

View File

@ -5,7 +5,7 @@ import inspect
import itertools
from abc import abstractmethod
from collections.abc import Sequence
from functools import partial
from functools import lru_cache, partial
from typing import TYPE_CHECKING
import torch
@ -216,11 +216,17 @@ def build_logitsprocs(
)
cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)
def validate_logits_processors_parameters(
logits_processors: Sequence[str | type[LogitsProcessor]] | None,
sampling_params: SamplingParams,
):
for logits_procs in _load_custom_logitsprocs(logits_processors):
logits_processors = (
tuple(logits_processors) if logits_processors is not None else None
)
for logits_procs in cached_load_custom_logitsprocs(logits_processors):
logits_procs.validate_params(sampling_params)
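
The tuple conversion above exists because functools.lru_cache hashes its arguments, and a list of processor names is unhashable. A tiny illustration (the processor path is hypothetical):

from functools import lru_cache

@lru_cache
def load(names: tuple[str, ...] | None) -> tuple[str, ...]:
    return names or ()

load(("my_pkg.procs:MyLogitsProcessor",))   # cached for repeated identical calls
# load(["my_pkg.procs:MyLogitsProcessor"])  # TypeError: unhashable type: 'list'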

View File

@ -31,6 +31,7 @@ from vllm.multimodal.inputs import (
MultiModalSharedField,
NestedTensors,
)
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.engine import UtilityResult
from vllm.v1.utils import tensor_data
@ -282,7 +283,9 @@ class MsgpackDecoder:
not thread-safe when encoding tensors / numpy arrays.
"""
def __init__(self, t: Any | None = None):
def __init__(self, t: Any | None = None, share_mem: bool = True):
self.share_mem = share_mem
self.pin_tensors = is_pin_memory_available()
args = () if t is None else (t,)
self.decoder = msgpack.Decoder(
*args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
@ -347,21 +350,30 @@ class MsgpackDecoder:
# zero-copy decode. We assume the ndarray will not be kept around,
# as it now locks the whole received message buffer in memory.
buffer = self.aux_buffers[data] if isinstance(data, int) else data
return np.frombuffer(buffer, dtype=dtype).reshape(shape)
arr = np.frombuffer(buffer, dtype=dtype)
if not self.share_mem:
arr = arr.copy()
return arr.reshape(shape)
def _decode_tensor(self, arr: Any) -> torch.Tensor:
dtype, shape, data = arr
# Copy from inline representation, to decouple the memory storage
# of the message from the original buffer. And also make Torch
# not complain about a readonly memoryview.
buffer = self.aux_buffers[data] if isinstance(data, int) else bytearray(data)
is_aux = isinstance(data, int)
buffer = self.aux_buffers[data] if is_aux else data
buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
torch_dtype = getattr(torch, dtype)
assert isinstance(torch_dtype, torch.dtype)
if not buffer: # torch.frombuffer doesn't like empty buffers
if not buffer.nbytes: # torch.frombuffer doesn't like empty buffers
assert 0 in shape
return torch.empty(shape, dtype=torch_dtype)
# Create uint8 array
arr = torch.frombuffer(buffer, dtype=torch.uint8)
# Clone ensures tensor is backed by pytorch-owned memory for safe
# future async CPU->GPU transfer.
# Pin larger tensors for more efficient CPU->GPU transfer.
if not is_aux:
arr = arr.clone()
elif not self.share_mem:
arr = arr.pin_memory() if self.pin_tensors else arr.clone()
# Convert back to proper shape & type
return arr.view(torch_dtype).view(shape)
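
A hedged note on the new share_mem flag, matching how MsgpackSerde above constructs its decoders with MsgpackDecoder(..., share_mem=False): the default zero-copy mode returns arrays/tensors that alias the received message buffer (so the buffer must outlive them), while share_mem=False copies, and for larger out-of-band tensors may pin, the memory so the buffer can be released. The aliasing-versus-copying distinction in a standalone numpy sketch:

import numpy as np

buf = bytearray(np.arange(4, dtype=np.int32).tobytes())   # stands in for a received message
shared = np.frombuffer(buf, dtype=np.int32)                # share_mem=True behavior: aliases buf
owned = np.frombuffer(buf, dtype=np.int32).copy()          # share_mem=False behavior: private copy
buf[0] = 0xFF                                              # mutate the underlying message buffer
print(shared[0], owned[0])                                 # 255 0 -> only the alias sees the change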

View File

@ -98,7 +98,9 @@ class BlockTable:
return
if self.use_hybrid_blocks:
block_ids = self._map_to_kernel_blocks(np.array(block_ids))
block_ids = self.map_to_kernel_blocks(
np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
)
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
@ -188,7 +190,12 @@ class BlockTable:
self.block_table.gpu.fill_(0)
self.block_table.cpu.fill_(0)
def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
@staticmethod
def map_to_kernel_blocks(
kv_manager_block_ids: np.ndarray,
blocks_per_kv_block: int,
kernel_block_arange: np.ndarray,
) -> np.ndarray:
"""Convert kv_manager_block_id IDs to kernel block IDs.
Example:
@ -203,12 +210,12 @@ class BlockTable:
# kv_manager_block_id 1 → kernel block id [2, 3]
# kv_manager_block_id 2 → kernel block id [4, 5]
"""
if not self.use_hybrid_blocks:
if blocks_per_kv_block == 1:
return kv_manager_block_ids
kernel_block_ids = (
kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block
+ self._kernel_block_arange
kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
+ kernel_block_arange
)
return kernel_block_ids.reshape(-1)
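
Since the mapping is now a static method shared with the Nixl connector, it is easy to exercise standalone. A small sketch reproducing the docstring example above (2 kernel blocks per KV-manager block):

import numpy as np

def map_to_kernel_blocks(block_ids: np.ndarray, blocks_per_kv_block: int) -> np.ndarray:
    if blocks_per_kv_block == 1:
        return block_ids                   # logical and kernel block sizes already match
    arange = np.arange(blocks_per_kv_block).reshape(1, -1)
    return (block_ids.reshape(-1, 1) * blocks_per_kv_block + arange).reshape(-1)

print(map_to_kernel_blocks(np.array([0, 1, 2]), 2))   # [0 1 2 3 4 5]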

View File

@ -279,6 +279,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# This will be overridden in load_model()
self.is_multimodal_pruning_enabled = False
self.max_model_len = model_config.max_model_len
# Always set to false after the first forward pass
self.calculate_kv_scales = self.cache_config.calculate_kv_scales
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
self.max_num_tokens = scheduler_config.max_num_batched_tokens
@ -2587,27 +2590,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
)
dp_rank = self.parallel_config.data_parallel_rank
if ubatch_slices:
assert num_tokens_across_dp is not None
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
elif num_tokens_across_dp is not None:
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
else:
num_input_tokens = self._get_num_input_tokens(
scheduler_output.total_num_scheduled_tokens
)
dp_rank = self.parallel_config.data_parallel_rank
if ubatch_slices:
assert num_tokens_across_dp is not None
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
elif num_tokens_across_dp is not None:
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
else:
num_input_tokens = self._get_num_input_tokens(
scheduler_output.total_num_scheduled_tokens
)
(
input_ids,
inputs_embeds,
positions,
intermediate_tensors,
model_kwargs,
) = self._preprocess(
scheduler_output, num_input_tokens, intermediate_tensors
)
(
input_ids,
inputs_embeds,
positions,
intermediate_tensors,
model_kwargs,
ec_connector_output,
) = self._preprocess(
scheduler_output, num_input_tokens, intermediate_tensors
)
uniform_decode = (
max_num_scheduled_tokens == self.uniform_decode_query_len
@ -2625,16 +2629,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
# Set cudagraph mode to none if calc_kv_scales is true.
if attn_metadata is not None:
metadata_list = (
attn_metadata.values()
if isinstance(attn_metadata, dict)
else [attn_metadata]
)
if any(
getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list
):
cudagraph_runtime_mode = CUDAGraphMode.NONE
# KV scales calculation involves dynamic operations that are incompatible
# with CUDA graph capture.
if self.calculate_kv_scales:
cudagraph_runtime_mode = CUDAGraphMode.NONE
# Mark KV scales as calculated after the first forward pass
self.calculate_kv_scales = False
# Run the model.
# Use persistent buffers for CUDA graphs.
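
The new runner-level flag replaces the per-layer metadata probing removed above: the first forward pass both opts out of CUDA graph capture (scale calculation is dynamic) and flips the flag off. A behavioral sketch with hypothetical names:

class RunnerSketch:
    def __init__(self, calculate_kv_scales: bool):
        # mirrors cache_config.calculate_kv_scales copied in __init__ above
        self.calculate_kv_scales = calculate_kv_scales

    def pick_cudagraph_mode(self, requested: str) -> str:
        if self.calculate_kv_scales:
            self.calculate_kv_scales = False   # only the first pass calibrates
            return "NONE"                      # dynamic ops, skip graph capture
        return requested

runner = RunnerSketch(calculate_kv_scales=True)
print(runner.pick_cudagraph_mode("PIECEWISE"))  # NONE on the first pass
print(runner.pick_cudagraph_mode("PIECEWISE"))  # PIECEWISE afterwards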