Mirror of https://github.com/vllm-project/vllm.git (synced 2025-11-18 00:40:58 +08:00)

Compare commits: main...v0.11.1rc7 (21 commits)

| SHA1 |
|---|
| f67299f66d |
| 5f6666fb5a |
| 66a62d73da |
| c505dd6b61 |
| f7adf64aac |
| 240d6b1758 |
| b315ba9052 |
| 9b24cf6f47 |
| facbc2c21e |
| e2fd9a2edf |
| 1326f17492 |
| caf412e593 |
| a035b5cffb |
| 5b4dcecdd7 |
| 609bb244bd |
| 3a9ea77c35 |
| 28a82bb5e6 |
| 2a21f3e7c2 |
| ab625ba2fc |
| 324c8cbd79 |
| 75ecaf48fe |
@@ -441,6 +441,7 @@ steps:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_config.py
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py
- pytest -v -s compile/test_fusion_attn.py
@@ -471,10 +472,11 @@ steps:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph.py
# Limit to no custom ops to reduce running time
# fp8 kv scales not supported on sm89, tested on Blackwell instead
- pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
# Limit to no custom ops to reduce running time
# Wrap with quotes to escape yaml and avoid starting -k string with a -
- "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'"
- "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'"

- label: Cudagraph test
timeout_in_minutes: 20
@@ -867,12 +869,12 @@ steps:
optional: true
commands:
- pip install --upgrade git+https://github.com/huggingface/transformers
- pytest -v -s tests/models/test_initialization.py
- pytest -v -s tests/models/test_initialization.py -k 'not (Gemma3 or ModernBert or Qwen2_5_VL or Qwen2_5vl or Qwen2VL or TransformersMultiModalEmbeddingModel or TransformersMultiModalForSequenceClassification or Ultravox or Phi4Multimodal or LlavaNextVideo or MiniCPMO or Lfm2Moe or PaliGemma or RobertaForSequenceClassification or Ovis2_5 or Fuyu or DeepseekOCR or KimiVL)'
- pytest -v -s tests/models/test_transformers.py
- pytest -v -s tests/models/multimodal/processing/
- pytest -v -s tests/models/multimodal/test_mapping.py
# - pytest -v -s tests/models/multimodal/processing/
- pytest -v -s tests/models/multimodal/test_mapping.py -k 'not (Gemma3 or Qwen2VL or Qwen2_5_VL)'
- python3 examples/offline_inference/basic/chat.py
- python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
# - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
# Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper

@@ -912,7 +914,7 @@ steps:
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py

- label: Blackwell Fusion Tests # 30 min
- label: Blackwell Fusion & Compile Tests # 30 min
timeout_in_minutes: 40
working_dir: "/vllm-workspace/"
gpu: b200
@@ -932,8 +934,10 @@ steps:
# this runner has 2 GPUs available even though num_gpus=2 is not set
- pytest -v -s tests/compile/test_fusion_all_reduce.py
# Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
# Wrap with quotes to escape yaml
- "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'"
# Wrap with quotes to escape yaml
- "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
- pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile

- label: Blackwell Fusion E2E Tests # 30 min
timeout_in_minutes: 40
@@ -951,6 +955,7 @@ steps:
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
- tests/compile/test_fusions_e2e.py
- tests/compile/test_full_graph.py
commands:
- nvidia-smi
# Run all e2e fusion tests
@@ -1250,7 +1255,8 @@ steps:
- pytest -v -s tests/compile/test_async_tp.py
- pytest -v -s tests/compile/test_sequence_parallelism.py
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
- "pytest -v -s tests/compile/test_fusions_e2e.py -k 'not Llama-4'"
- pytest -v -s tests/distributed/test_sequence_parallel.py
- pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
- pytest -v -s tests/v1/distributed/test_dbo.py

@@ -218,16 +218,6 @@ outputs = model.generate(
)
```

### Migration from legacy flags

Legacy `use_cudagraph` and `full_cuda_graph` are unified by `cudagraph_mode`:

* `use_cudagraph=False` → `NONE`.
* `use_cudagraph=True` and `full_cuda_graph=False` → `PIECEWISE`.
* `full_cuda_graph=True` → directly set `FULL` and rely on the graceful fallback policy.

As they are deprecated and will be removed in the next major or minor release, i.e., v0.11.0 or v1.0.0, we recommend using cudagraph_mode instead.
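
A minimal sketch of the replacement configuration, assuming the `CompilationConfig`/`CUDAGraphMode` API exercised by the tests in this changeset:

```python
from vllm.config import CompilationConfig, CUDAGraphMode

# Legacy: CompilationConfig(use_cudagraph=True, full_cuda_graph=False)
# New:    select the capture behavior explicitly via cudagraph_mode.
compilation_config = CompilationConfig(
    cudagraph_mode=CUDAGraphMode.PIECEWISE,  # or NONE / FULL / FULL_AND_PIECEWISE
)
```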

### Piecewise compilation and full graph custom passes (attention fusion, sequence parallelism)

Unfortunately, some custom compile passes have to see the whole graph to be effective and hence aren't compatible with piecewise compilation. This includes `AttnFusionPass` and `SequenceParallelismPass`. As a short-term solution, we automatically disable piecewise compilation (by setting `splitting_ops=[]`) when attention fusion is enabled. We use CUDA Graph modes `FULL` or `FULL_DECODE_ONLY` (depending on backend support). However, this leads to another optimization incompatibility and confusing performance tradeoffs.
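
For illustration, a hedged sketch of the configuration this short-term behavior corresponds to (names taken from the tests in this changeset; the fallback is normally applied automatically rather than set by hand):

```python
from vllm.config import CompilationConfig, CUDAGraphMode, PassConfig

# With attention fusion enabled, piecewise compilation is disabled
# (splitting_ops=[]) and a full CUDA graph mode is used instead.
compilation_config = CompilationConfig(
    pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
    custom_ops=["+quant_fp8"],
    splitting_ops=[],  # no piecewise splitting
    cudagraph_mode=CUDAGraphMode.FULL,  # or FULL_DECODE_ONLY, depending on backend support
)
```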

@@ -9,7 +9,6 @@ torch==2.9.0
torchaudio==2.9.0
# These must be updated alongside torch
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# Build from https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
xformers==0.0.33.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.5.2

@@ -203,7 +203,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
@@ -281,7 +281,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=False,
cudagraph_mode=CUDAGraphMode.NONE,
splitting_ops=["silly::attention"],
use_inductor_graph_partition=use_inductor_graph_partition,
)

@@ -62,7 +62,6 @@ def _run_simple_model(
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
use_inductor=use_inductor,
splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition,

@@ -449,7 +449,6 @@ def benchmark():
if piecewise:
compilation_config = CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=cudagraph_sizes,
)

@@ -2,8 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from contextlib import nullcontext
from unittest.mock import patch

import pytest
from pydantic import ValidationError

from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
@@ -11,7 +13,7 @@ from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
from vllm.utils.torch_utils import _is_torch_equal_or_newer


def test_version():
@@ -23,14 +25,6 @@ def test_version():
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")


def test_use_cudagraphs_dynamic():
vllm_config = VllmConfig()
# Default V1 configuration now starts without cudagraphs enabled; the
# engine decides when to capture based on runtime settings instead of a
# blanket default.
assert vllm_config.compilation_config.use_cudagraph


def test_copy_pass():
vllm_config = VllmConfig()
inductor_pass = FixFunctionalizationPass(vllm_config)
@@ -65,7 +59,7 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)

compilation_config = {
"use_cudagraph": False, # speed things up a bit
"cudagraph_mode": CUDAGraphMode.NONE, # speed things up a bit
}
with (
compilation_counter.expect(
@@ -83,20 +77,31 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):

# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize("enabled", [True, False])
def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
@pytest.mark.parametrize(
"cudagraph_mode,num_cudagraph_captured",
[
(CUDAGraphMode.NONE, 0),
(CUDAGraphMode.FULL_DECODE_ONLY, 1),
(CUDAGraphMode.PIECEWISE, 13),
(CUDAGraphMode.FULL_AND_PIECEWISE, 14),
],
)
def test_use_cudagraphs(
vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

compilation_config = {
"cudagraph_capture_sizes": [100],
"use_cudagraph": enabled,
"cudagraph_mode": cudagraph_mode,
}
num_gpu_runner_capture_triggers = 1 if cudagraph_mode != CUDAGraphMode.NONE else 0
with (
compilation_counter.expect(
num_graphs_seen=1,
num_gpu_runner_capture_triggers=1 if enabled else 0,
num_cudagraph_captured=13 if enabled else 0,
num_gpu_runner_capture_triggers=num_gpu_runner_capture_triggers,
num_cudagraph_captured=num_cudagraph_captured,
),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
@@ -168,19 +173,18 @@ def test_splitting_ops_dynamic():
assert not config.compilation_config.splitting_ops_contain_attention()

# When use_inductor_graph_partition=True
if is_torch_equal_or_newer("2.9.0.dev"):
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
splitting_ops=["vllm::unified_attention"],
)
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
splitting_ops=["vllm::unified_attention"],
)
# with inductor partition we use splitting_ops directly for
# partition rules
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
)
# with inductor partition we use splitting_ops directly for
# partition rules
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]

# When attn_fusion pass enabled, splitting_ops now default to attention ops.
# When attn_fusion pass enabled.
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
@@ -189,29 +193,41 @@ def test_splitting_ops_dynamic():
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
)
# With the new simplified logic, attention fusion works with splitting_ops
assert config.compilation_config.splitting_ops_contain_attention()
# cudagraph mode remains PIECEWISE
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
assert config.compilation_config.splitting_ops == []
# cudagraph mode also fall back to FULL
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL

# When both use_inductor_graph_partition and attn_fusion pass enabled.
if is_torch_equal_or_newer("2.9.0.dev"):
# splitting_ops can not contain attention ops when attn_fusion
# pass enabled.
with pytest.raises(ValidationError):
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
# work around for accessing all attntion ops
splitting_ops=CompilationConfig()._attention_ops,
)
)
# With inductor graph partition, attn_fusion and splitting_ops
# work together. Default splitting_ops include attention ops.
assert config.compilation_config.splitting_ops_contain_attention()
# enable_attn_fusion is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE

# When both use_inductor_graph_partition and attn_fusion pass enabled.
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
)
# With inductor graph partition, attn_fusion and splitting_ops
# work together. Default splitting_ops include attention ops.
assert config.compilation_config.splitting_ops_contain_attention()
# enable_attn_fusion is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE


def test_should_split():
@@ -293,25 +309,36 @@ def test_should_split():
"tp_size",
"enable_sequence_parallelism",
"max_num_batched_tokens",
"use_cudagraph",
"cudagraph_mode",
"expected_max_size",
),
[
(None, None, 1, False, 2048, True, 512),
([1, 2, 4], 4, 1, False, 2048, True, 4),
([1, 2, 4], 8, 1, False, 2048, True, RuntimeError),
([1, 256], None, 1, False, 2048, 256),
([], None, 1, False, 2048, False, 0),
(None, 0, 1, False, 2048, False, 0),
(None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
(
[1, 2, 4],
8,
1,
False,
2048,
CUDAGraphMode.FULL_AND_PIECEWISE,
ValidationError,
),
([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
(None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
# truncated to nearest multiple of 8 or 16
(None, 257, 1, False, 2048, True, 256),
([1, 2, 4, 15], None, 1, False, 2048, True, 15), # max from list
([1, 2, 4, 15], None, 2, True, 2048, True, 4), # filtered out 15 due to SP
([1, 2, 4, 15], None, 1, False, 8, True, 4), # limited by the max_tokens
(None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
# max from list
([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
# filtered out 15 due to SP
([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
# limited by the max_tokens
([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
# the list should contain at least 1 element when use cudagraph
([], None, 1, False, 2048, True, RuntimeError),
([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
# the max capturing size should be >= 1 when use cudagraph
(None, 0, 1, False, 2048, True, RuntimeError),
(None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
],
)
def test_cudagraph_sizes_post_init(
@@ -320,15 +347,17 @@ def test_cudagraph_sizes_post_init(
tp_size,
enable_sequence_parallelism,
max_num_batched_tokens,
use_cudagraph,
cudagraph_mode,
expected_max_size,
):
ctx = nullcontext()
if isinstance(expected_max_size, Exception):
if expected_max_size == ValidationError:
ctx = pytest.raises(expected_max_size)

cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE
with ctx:
with (
ctx,
patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
):
compilation_config = CompilationConfig(
cudagraph_capture_sizes=cudagraph_capture_sizes,
max_cudagraph_capture_size=max_cudagraph_capture_size,
@@ -342,11 +371,13 @@ def test_cudagraph_sizes_post_init(
engine_args = EngineArgs(
model="facebook/opt-125m",
tensor_parallel_size=tp_size,
max_num_seqs=min(max_num_batched_tokens, 128),
max_num_batched_tokens=max_num_batched_tokens,
compilation_config=compilation_config,
)
vllm_config = engine_args.create_engine_config()

assert (
vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size
)
assert (
vllm_config.compilation_config.max_cudagraph_capture_size
== expected_max_size
)

@@ -80,7 +80,6 @@ def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatc
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
@@ -215,7 +214,6 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
@@ -257,7 +255,6 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,

@@ -183,8 +183,14 @@ def test_custom_compile_config(
"compilation_mode",
[CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
)
def test_fp8_kv_scale_compile(compilation_mode: int):
model = "Qwen/Qwen2-0.5B"
@pytest.mark.parametrize(
"model",
[
"Qwen/Qwen2-0.5B", # Standard attention model
"deepseek-ai/DeepSeek-V2-Lite", # MLA (Multi-head Latent Attention) model
],
)
def test_fp8_kv_scale_compile(compilation_mode: int, model: str):
model_kwargs = {
"quantization": "fp8",
"kv_cache_dtype": "fp8_e4m3",

@@ -20,13 +22,22 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer

from ..utils import flat_product, multi_gpu_test

is_blackwell = lambda: current_platform.is_device_capability(100)
"""Are we running on Blackwell, a lot of tests depend on it"""


class Matches(NamedTuple):
attention_fusion: int = 0
allreduce_fusion: int = 0
sequence_parallel: int = 0
async_tp: int = 0


class ModelBackendTestCase(NamedTuple):
model_name: str
model_kwargs: dict[str, Any]
backend: _Backend
attention_fusions: int
allreduce_fusions: int | None = None
backend: AttentionBackendEnum
matches: Matches


MODELS_FP8: list[ModelBackendTestCase] = []
@@ -38,17 +47,33 @@ if current_platform.is_cuda():
ModelBackendTestCase(
# Use smaller model for L40s in CI
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=32,
allreduce_fusions=65,
# TODO while llama4 is broken, use FLASHINFER for llama3 on Blackwell
# so FI attention+fp8_quant is at least tested once
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=AttentionBackendEnum.FLASHINFER
if is_blackwell()
else AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=32,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
),
ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER,
attention_fusions=48,
allreduce_fusions=96,
# TODO FlashInfer attn broken on Hopper with kvcache=fp8:
# https://github.com/vllm-project/vllm/issues/28568
# TODO FlashInfer attn broken on Blackwell for llama4:
# https://github.com/vllm-project/vllm/issues/28604
backend=AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=48,
allreduce_fusion=96,
sequence_parallel=96,
async_tp=95, # mlp is moe, no fusion there
),
),
]

@@ -56,9 +81,13 @@ if current_platform.is_cuda():
ModelBackendTestCase(
model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER,
attention_fusions=32,
allreduce_fusions=65,
backend=AttentionBackendEnum.FLASHINFER,
matches=Matches(
attention_fusion=32,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
),
]

@@ -67,9 +96,24 @@ if current_platform.is_cuda():
ModelBackendTestCase(
model_name="meta-llama/Llama-3.1-8B-Instruct",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=0,
allreduce_fusions=65,
backend=AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=0,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
),
ModelBackendTestCase(
model_name="Qwen/Qwen3-30B-A3B",
model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=0,
allreduce_fusion=97,
sequence_parallel=97,
async_tp=96, # MLP is MoE, half the fusions of dense
),
),
]

@@ -78,20 +122,20 @@ elif current_platform.is_rocm():
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=32,
backend=AttentionBackendEnum.TRITON_ATTN,
matches=Matches(attention_fusion=32),
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_ATTN,
attention_fusions=32,
backend=AttentionBackendEnum.ROCM_ATTN,
matches=Matches(attention_fusion=32),
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
attention_fusions=32,
backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
matches=Matches(attention_fusion=32),
),
]

@@ -99,8 +143,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]


@pytest.mark.parametrize(
"model_name, model_kwargs, backend, "
"attention_fusions, allreduce_fusions, custom_ops",
"model_name, model_kwargs, backend, matches, custom_ops",
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
# quant_fp4 only has the custom impl
@@ -110,16 +153,15 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
def test_attn_quant(
model_name: str,
model_kwargs: dict[str, Any],
backend: _Backend,
attention_fusions: int,
allreduce_fusions: int,
backend: AttentionBackendEnum,
matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if backend == _Backend.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
if backend == AttentionBackendEnum.FLASHINFER and (
not is_blackwell() or not has_flashinfer()
):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
@@ -162,12 +204,12 @@ def test_attn_quant(
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(compilation_config, model_name, **model_kwargs)

matches = re.findall(
log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(matches) == 1, log_holder.text
assert int(matches[0]) == attention_fusions
assert len(log_matches) == 1, log_holder.text
assert int(log_matches[0]) == matches.attention_fusion


CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"]
@@ -180,8 +222,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:

@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, "
"attention_fusions, allreduce_fusions, custom_ops",
"model_name, model_kwargs, backend, matches, custom_ops",
# Toggle RMSNorm and QuantFP8 for FP8 models
list(
flat_product(
@@ -201,9 +242,8 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
def test_tp2_attn_quant_allreduce_rmsnorm(
model_name: str,
model_kwargs: dict,
backend: _Backend,
attention_fusions: int,
allreduce_fusions: int,
backend: AttentionBackendEnum,
matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
@@ -212,6 +252,13 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9")

if "fp4" in model_name.lower() and not is_blackwell():
pytest.skip("NVFP4 quant requires Blackwell")

if backend == AttentionBackendEnum.FLASHINFER and not is_blackwell():
# FlashInfer attn fusion requires Blackwell
matches = matches._replace(attention_fusion=0)

custom_ops_list = custom_ops.split(",") if custom_ops else []

if inductor_graph_partition:
@@ -251,23 +298,135 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
run_model(
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
)
matches = re.findall(
log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(matches) == 2, log_holder.text
assert len(log_matches) == 2, log_holder.text

assert int(matches[0]) == attention_fusions
assert int(matches[1]) == attention_fusions
assert int(log_matches[0]) == matches.attention_fusion
assert int(log_matches[1]) == matches.attention_fusion

matches = re.findall(
log_matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(matches) == 2, log_holder.text
assert len(log_matches) == 2, log_holder.text

assert int(matches[0]) == allreduce_fusions
assert int(matches[1]) == allreduce_fusions
assert int(log_matches[0]) == matches.allreduce_fusion
assert int(log_matches[1]) == matches.allreduce_fusion


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, matches, custom_ops",
# Toggle RMSNorm and QuantFP8 for FP8 models
list(
flat_product(
MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
)
)
# Toggle RMSNorm for FP4 models and unquant models
+ list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="sequence parallel only tested on CUDA",
)
def test_tp2_attn_quant_async_tp(
model_name: str,
model_kwargs: dict,
backend: AttentionBackendEnum,
matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if is_blackwell():
# TODO: https://github.com/vllm-project/vllm/issues/27893
pytest.skip("Blackwell is not supported for AsyncTP pass")

if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9")

if "fp4" in model_name.lower() and not is_blackwell():
pytest.skip("NVFP4 quant requires Blackwell")

if backend == AttentionBackendEnum.FLASHINFER:
if not has_flashinfer():
pytest.skip("FlashInfer backend requires flashinfer installed")
if not is_blackwell():
# FlashInfer attn fusion requires Blackwell
matches = matches._replace(attention_fusion=0)

custom_ops_list = custom_ops.split(",") if custom_ops else []

if inductor_graph_partition:
mode = CUDAGraphMode.FULL_AND_PIECEWISE
splitting_ops: list[str] | None = None
else:
mode = CUDAGraphMode.FULL_DECODE_ONLY
splitting_ops = []

# Disable, compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

# To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)

compilation_config = CompilationConfig(
# Testing properties
use_inductor_graph_partition=inductor_graph_partition,
cudagraph_mode=mode,
custom_ops=custom_ops_list,
splitting_ops=splitting_ops,
# Common
level=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(
enable_attn_fusion=True,
enable_noop=True,
enable_sequence_parallelism=True,
enable_async_tp=True,
),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True},
)

with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
)
log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text

assert int(log_matches[0]) == matches.attention_fusion
assert int(log_matches[1]) == matches.attention_fusion

log_matches = re.findall(
r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text

assert int(log_matches[0]) == matches.sequence_parallel
assert int(log_matches[1]) == matches.sequence_parallel

log_matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text

assert int(log_matches[0]) == matches.async_tp
assert int(log_matches[1]) == matches.async_tp


def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):

@@ -10,8 +10,8 @@ from vllm.platforms import current_platform

def test_compile():
vllm_config = VllmConfig()
# Default configuration compiles mm encoder
assert vllm_config.compilation_config.compile_mm_encoder
# Default configuration does not compile mm encoder
assert not vllm_config.compilation_config.compile_mm_encoder


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@@ -39,7 +39,10 @@ def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
"Qwen/Qwen2.5-VL-3B-Instruct",
max_model_len=2048,
gpu_memory_utilization=0.8,
compilation_config={"mode": CompilationMode.VLLM_COMPILE},
compilation_config={
"mode": CompilationMode.VLLM_COMPILE,
"compile_mm_encoder": True,
},
) as _,
):
pass

@@ -5,15 +5,15 @@ import pytest
import torch

import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.fx_utils import find_auto_fn
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import (
CompilationConfig,
CUDAGraphMode,
DeviceConfig,
ModelConfig,
PassConfig,
@@ -27,6 +27,7 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
@@ -43,172 +44,157 @@ prompts = [
]


class TestModel(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32):
class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size))
)
self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]

def forward(self, hidden_states, residual):
"""
Forward pass implementing the operations in the FX graph
def forward(self, x):
z = torch.relu(x)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)

Args:
hidden_states: Input tensor
residual: Residual tensor from previous layer
z2 = torch.mm(y, self.w[0])
x2 = tensor_model_parallel_all_reduce(z2)

Returns:
Tuple containing the output tensor
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
y2, resid = self.norm[1](x2, resid)

# matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)
z3 = torch.mm(y2, self.w[1])
x3 = tensor_model_parallel_all_reduce(z3)

# Tensor parallel all-reduce
all_reduce = tensor_model_parallel_all_reduce(mm)
y3, resid = self.norm[2](x3, resid)

# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)
z4 = torch.mm(y3, self.w[2])
x4 = tensor_model_parallel_all_reduce(z4)

return norm_output, residual_output
y4, resid = self.norm[3](x4, resid)
return y4

def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]

def ops_in_model_after(self):
return [
torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default,
torch.ops.vllm.reduce_scatter.default,
]

def ops_in_model(self):
return [torch.ops._C.fused_add_rms_norm.default]
if RMSNorm.enabled():
return [
torch.ops._C.rms_norm.default,
torch.ops._C.fused_add_rms_norm.default,
]
else:
return []

class TestQuantModel(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
def __init__(self, hidden_size=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.vllm_config = get_current_vllm_config()
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size)), requires_grad=False
)
self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
self.hidden_size = hidden_size
self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
.t()
for _ in range(3)
]

self.fp8_linear = Fp8LinearOp(act_quant_static=True)

self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm,
# which expects a column-major layout.
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)

def forward(self, hidden_states, residual):
"""
Forward pass implementing the operations in the FX graph

Args:
hidden_states: Input tensor
residual: Residual tensor from previous layer

Returns:
Tuple containing the output tensor
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)

# matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)

# Tensor parallel all-reduce
all_reduce = tensor_model_parallel_all_reduce(mm)

# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)

# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)

return fp8_linear_result, residual_output
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]

def ops_in_model_before(self):
ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP
# The following are only removed if fusion happens
if (
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
ops_to_remove.extend(
[
torch.ops._C.fused_add_rms_norm.default,
torch.ops._C.static_scaled_fp8_quant.default,
]
)
return ops_to_remove
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)

z2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)

x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)

z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)

x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid)  # use resid here

z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid)  # use resid here
return y4

def ops_in_model_after(self):
ops_to_add = [
torch.ops.vllm.reduce_scatter.default,
return [
torch.ops.vllm.all_gather.default,
torch.ops.vllm.reduce_scatter.default,
]

def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
]
# The following is only added if fusion happens
if (
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
return ops_to_add

def ops_in_model(self):
if (
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
# If fusion happens, the fused op is the one
# we check for (de)functionalization
if self.vllm_config.compilation_config.pass_config.enable_fusion:
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
else:
# If no fusion, the original ops are checked
elif RMSNorm.enabled():
return [
torch.ops._C.fused_add_rms_norm.default,
# TODO functionalization pass does not handle this yet
# torch.ops._C.static_scaled_fp8_quant.default,
]
elif self.fp8_linear.quant_fp8.enabled():
return [
torch.ops._C.static_scaled_fp8_quant.default,
]
else:
return []


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel])
@pytest.mark.parametrize(
"test_model_cls, custom_ops",
[
(TestAllReduceRMSNormModel, "+rms_norm"),
(TestAllReduceRMSNormModel, "-rms_norm"),
(TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"),
(TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"),
(TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"),
(TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"),
],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("enable_fusion", [True, False])
@pytest.mark.parametrize("dynamic", [False, True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_sequence_parallelism_pass(
test_model_cls: type[torch.nn.Module],
custom_ops: str,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
dynamic: bool,
):
num_processes = 2

@@ -220,11 +206,13 @@ def test_sequence_parallelism_pass(
args=(
num_processes,
test_model_cls,
custom_ops,
batch_size,
seq_len,
hidden_size,
dtype,
enable_fusion,
dynamic,
),
nprocs=nprocs,
)
@@ -236,11 +224,13 @@ def sequence_parallelism_pass_on_test_model(
local_rank: int,
world_size: int,
test_model_cls: type[torch.nn.Module],
custom_ops: str,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
dynamic: bool,
):
current_platform.seed_everything(0)

@@ -264,12 +254,16 @@ def sequence_parallelism_pass_on_test_model(
initialize_model_parallel(tensor_model_parallel_size=world_size)

# configure vllm config for SequenceParallelismPass
custom_ops_list = custom_ops.split(",") if custom_ops else []
compilation_config = CompilationConfig(
splitting_ops=[], # avoid automatic rms_norm enablement
cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings
custom_ops=custom_ops_list,
pass_config=PassConfig(
enable_sequence_parallelism=True,
enable_fusion=enable_fusion,
enable_noop=True,
)
),
) # NoOp needed for fusion
device_config = DeviceConfig(device=torch.device("cuda"))

@@ -289,7 +283,6 @@ def sequence_parallelism_pass_on_test_model(
with set_current_vllm_config(vllm_config):
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
assert (
sequence_parallelism_pass.compilation_config.splitting_ops
@@ -310,38 +303,29 @@ def sequence_parallelism_pass_on_test_model(

passes_for_backend.append(cleanup_pass)

backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
backend = TestBackend(*passes_for_backend)

model = test_model_cls(hidden_size, hidden_size * 2)
model = test_model_cls(hidden_size)

hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)

compiled_model_no_func = torch.compile(model, backend=backend_no_func)
compiled_model_no_func(hidden_states, residual)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
if dynamic:
torch._dynamo.mark_dynamic(hidden_states, 0)

assert sequence_parallelism_pass.matched_count == 1
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)

assert sequence_parallelism_pass.matched_count == 4

# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before())
for op in model.ops_in_model_before():
assert backend.op_count(op, before=True) == 4

# In post-nodes, reduce scatter and all gather should be there,
# all reduce should not
backend_no_func.check_after_ops(model.ops_in_model_after())
for op in model.ops_in_model_after():
assert backend.op_count(op, before=False) == 4

# check if the functionalization pass is applied
for op in model.ops_in_model():
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None

# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in model.ops_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model())
find_auto_fn(backend.graph_post_pass.nodes, op)

@@ -14,6 +14,7 @@ from dataclasses import dataclass
from typing import Literal, NamedTuple

import pytest
import torch

from vllm.config.model import RunnerOption
from vllm.logger import init_logger
@@ -254,6 +255,17 @@ def test_cp_generation(
test_options: CPTestOptions,
num_gpus_available,
):
if (
model_id == "deepseek-ai/DeepSeek-V2-Lite-Chat"
and torch.cuda.get_device_capability() < (9, 0)
):
pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
if (
model_id == "bigcode/gpt_bigcode-santacoder"
and torch.cuda.get_device_capability() != (9, 0)
):
pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")

_compare_cp_with_tp(
model_id,
parallel_setup,

@@ -18,6 +18,7 @@ import pytest
from vllm.config.compilation import CompilationMode
from vllm.config.model import RunnerOption
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer

from ..models.registry import HF_EXAMPLE_MODELS
@@ -161,6 +162,7 @@ def _compare_sp(
test_options: SPTestOptions,
num_gpus_available: int,
use_inductor_graph_partition: bool,
enable_async_tp: bool,
*,
method: Literal["generate", "encode"],
is_multimodal: bool,
@@ -244,10 +246,10 @@ def _compare_sp(

compilation_config = {
"mode": CompilationMode.VLLM_COMPILE,
"custom_ops": ["+rms_norm"],
"compile_sizes": [4, 8],
"pass_config": {
"enable_sequence_parallelism": True,
"enable_async_tp": enable_async_tp,
"enable_fusion": enable_fusion,
"enable_noop": True,
},
@@ -307,6 +309,7 @@ SP_TEST_MODELS = [
],
)
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
@pytest.mark.parametrize("enable_async_tp", [False]) # TODO: enable async TP
@create_new_process_for_each_test()
def test_tp_sp_generation(
model_id: str,
@@ -316,10 +319,19 @@ def test_tp_sp_generation(
test_options: SPTestOptions,
num_gpus_available,
use_inductor_graph_partition: bool,
enable_async_tp: bool,
):
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

# Skip FP8 SP-only test on sm89 (compute capability 8.9)
if (
"fp8" in model_id.lower()
and current_platform.get_device_capability() < (9, 0)
and (not enable_async_tp)
):
pytest.skip("FP8 reduction support begins with sm90 capable devices.")

_compare_sp(
model_id,
parallel_setup,
@@ -328,6 +340,7 @@ def test_tp_sp_generation(
test_options,
num_gpus_available,
use_inductor_graph_partition,
enable_async_tp=enable_async_tp,
method="generate",
is_multimodal=False,
)

@@ -3,6 +3,3 @@ accuracy_threshold: 0.45
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
# Duo stream incompatabilbe with this model: https://github.com/vllm-project/vllm/issues/28220
env:
VLLM_DISABLE_SHARED_EXPERTS_STREAM: "1"

@@ -170,6 +170,7 @@ def test_cascade(
logits_soft_cap=soft_cap if soft_cap is not None else 0,
block_table=block_tables,
common_prefix_len=common_prefix_len,
max_num_splits=0, # no max
fa_version=fa_version,
)

@@ -10,11 +10,13 @@ from collections import defaultdict
from pathlib import PosixPath

import pytest
from packaging.version import Version
from transformers import (
AutoModel,
AutoModelForImageTextToText,
AutoModelForTextToWaveform,
)
from transformers import __version__ as TRANSFORMERS_VERSION

from vllm.platforms import current_platform
from vllm.utils.func_utils import identity
@@ -137,6 +139,7 @@ VLM_TEST_SETTINGS = {
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>",
video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>",
enforce_eager=False,
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
@@ -166,6 +169,7 @@ VLM_TEST_SETTINGS = {
VLMTestType.MULTI_IMAGE,
VLMTestType.VIDEO,
),
enforce_eager=False,
needs_video_metadata=True,
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
@@ -859,6 +863,12 @@ VLM_TEST_SETTINGS = {
limit_mm_per_prompt={"image": 4},
)
],
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_VERSION) == Version("4.57.1"),
reason="This model is broken in Transformers v4.57.1",
)
],
),
# regression test for https://github.com/vllm-project/vllm/issues/15122
"qwen2_5_vl-windows-attention": VLMTestInfo(

@@ -61,10 +61,8 @@ def test_qwen2_5_vl_evs_functionality(
model,
runner="generate",
max_model_len=4000,
max_num_seqs=1,
dtype=dtype,
limit_mm_per_prompt={"video": 1},
tensor_parallel_size=1,
video_pruning_rate=video_pruning_rate,
) as vllm_model:
# Generate output - this should not crash

@ -980,8 +980,10 @@ def test_hybrid_block_table_initialization():
req_index = 0
block_table.append_row(kvcache_manager_blocks, req_index)
# Get expected kernel blocks from the implementation for verification.
expected_kernel_blocks = block_table._map_to_kernel_blocks(
np.array(kvcache_manager_blocks)
expected_kernel_blocks = block_table.map_to_kernel_blocks(
np.array(kvcache_manager_blocks),
block_table.blocks_per_kv_block,
block_table._kernel_block_arange,
)
# Verify block table state
assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)

@ -938,4 +938,5 @@ class rocm_aiter_ops:
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)

rocm_aiter_ops.register_ops_once()
if IS_AITER_FOUND:
rocm_aiter_ops.register_ops_once()

@ -745,6 +745,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
k_pe: torch.Tensor,
output_shape: torch.Size | None = None,
) -> torch.Tensor:
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
@ -752,12 +755,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]

# Mirror Attention.forward scale calculation path
if self.calculate_kv_scales and getattr(
attn_metadata, "enable_kv_scales_calculation", False
):
self.calc_kv_scales(q, kv_c_normed, k_pe)

if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
self.impl.forward(
@ -786,14 +783,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
)
return output
else:
# We can still access forward context to check calculation flag
if self.calculate_kv_scales:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
if getattr(attn_metadata, "enable_kv_scales_calculation", False):
self.calc_kv_scales(q, kv_c_normed, k_pe)
return torch.ops.vllm.unified_mla_attention(
q,
kv_c_normed,
@ -881,17 +870,13 @@ def maybe_calc_kv_scales(
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
self = forward_context.no_compile_layers[layer_name]

if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]

if attn_metadata is None or not getattr(
attn_metadata, "enable_kv_scales_calculation", False
):
# Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass
if not self.calculate_kv_scales:
return

self = forward_context.no_compile_layers[layer_name]
self.calc_kv_scales(query, key, value)

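The hunk above removes the metadata-driven gate and keys the one-time scale calculation solely off the layer's `calculate_kv_scales` flag, which is flipped after the first forward pass. A minimal sketch of that "calibrate once, then disable" pattern, using hypothetical class and attribute names rather than vLLM's actual `MLAAttention`:

import torch


class ScaleCalcLayer:
    def __init__(self) -> None:
        self.calculate_kv_scales = True  # assumed default: calibrate on the first pass
        self.k_scale = torch.tensor(1.0)
        self.v_scale = torch.tensor(1.0)

    def calc_kv_scales(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> None:
        # Derive per-tensor scales from the first batch, then disable further updates.
        self.k_scale = k.abs().amax() / 448.0  # 448 = FP8 E4M3 max, illustrative only
        self.v_scale = v.abs().amax() / 448.0
        self.calculate_kv_scales = False


def maybe_calc_kv_scales(layer: ScaleCalcLayer, q, k, v) -> None:
    # Later forwards take this early return, so the op becomes a no-op.
    if not layer.calculate_kv_scales:
        return
    layer.calc_kv_scales(q, k, v)
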
@ -195,7 +195,6 @@ def cp_lse_ag_out_rs(
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
assert out.is_contiguous()
out = cp_group.reduce_scatter(out, dim=1)

if return_lse:

@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
@ -10,98 +12,28 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _RMSNormAndQuantOpHelper:
|
||||
"""Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
|
||||
def get_first_out_wrapper(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args):
|
||||
return fn(*args)[0]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: torch._ops.OpOverload | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.epsilon = epsilon
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.quant_op = quant_op
|
||||
|
||||
def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor):
|
||||
return torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.rms_norm.default,
|
||||
result=result_buffer,
|
||||
input=input_tensor,
|
||||
weight=weight_tensor,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
def _functional_fused_add_rmsnorm(
|
||||
self, input_tensor, residual_tensor, weight_tensor
|
||||
):
|
||||
return torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
input=input_tensor,
|
||||
residual=residual_tensor,
|
||||
weight=weight_tensor,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
def _functional_rmsnorm_then_quant(
|
||||
self,
|
||||
rmsnorm_result_buffer,
|
||||
quant_result_buffer,
|
||||
input_tensor,
|
||||
weight_tensor,
|
||||
scale_tensor,
|
||||
):
|
||||
if self.quant_op is None:
|
||||
raise RuntimeError(
|
||||
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
|
||||
)
|
||||
rmsnorm_out_tuple = self._functional_rmsnorm(
|
||||
rmsnorm_result_buffer, input_tensor, weight_tensor
|
||||
)
|
||||
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
|
||||
self.quant_op,
|
||||
result=quant_result_buffer,
|
||||
input=rmsnorm_out_tuple[1],
|
||||
scale=scale_tensor,
|
||||
)
|
||||
return quant_out_tuple
|
||||
|
||||
def _functional_fused_add_rmsnorm_then_quant(
|
||||
self,
|
||||
quant_result_buffer,
|
||||
input_tensor,
|
||||
residual_tensor,
|
||||
weight_tensor,
|
||||
scale_tensor,
|
||||
):
|
||||
if self.quant_op is None:
|
||||
raise RuntimeError(
|
||||
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
|
||||
)
|
||||
fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
|
||||
input_tensor, residual_tensor, weight_tensor
|
||||
)
|
||||
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
|
||||
self.quant_op,
|
||||
result=quant_result_buffer,
|
||||
input=fused_add_rmsnorm_out_tuple[1],
|
||||
scale=scale_tensor,
|
||||
)
|
||||
return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]
|
||||
return wrapper
|
||||
|
||||
|
||||
class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
|
||||
class _SequenceParallelPatternHelper:
|
||||
"""Helper for sequence parallelism patterns."""
|
||||
|
||||
def __init__(
|
||||
@ -109,10 +41,10 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: torch._ops.OpOverload | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
|
||||
self.epsilon = epsilon
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tp_group = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
@ -131,36 +63,34 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [input, permute, arg3_1]
|
||||
return [input, arg3_1]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
permute: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
):
|
||||
all_reduce = self._all_reduce(input)
|
||||
rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1)
|
||||
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
|
||||
|
||||
return rmsnorm[1], all_reduce
|
||||
return rmsnorm, all_reduce
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
permute: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
):
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
|
||||
rmsnorm_result = torch.empty_like(reduce_scatter)
|
||||
rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1)
|
||||
|
||||
all_gather = self._all_gather(rmsnorm[1])
|
||||
|
||||
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
|
||||
all_gather = self._all_gather(rmsnorm)
|
||||
return all_gather, reduce_scatter
|
||||
|
||||
pm.register_replacement(
|
||||
@ -169,6 +99,10 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
@ -188,67 +122,34 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
all_reduce, residual, rms_norm_weights
|
||||
)
|
||||
return rmsnorm[1], rmsnorm[2]
|
||||
rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
|
||||
return rmsnorm[0], rmsnorm[1]
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# pattern matcher replaces from top-to-bottom,
|
||||
# so residual is still the full size here.
|
||||
# once the seqpar pattern with the previous rmsnorm is replaced
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
reduce_scatter, residual, rms_norm_weights
|
||||
)
|
||||
all_gather = self._all_gather(rmsnorm[1])
|
||||
return all_gather, rmsnorm[2]
|
||||
residual = residual[0 : reduce_scatter.size(0), ...]
|
||||
rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
|
||||
all_gather = self._all_gather(rmsnorm[0])
|
||||
# shape of residual changes but that's fine,
|
||||
# next node is already slicing it, now becomes a noop
|
||||
return all_gather, rmsnorm[1]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [
|
||||
residual,
|
||||
mm_1,
|
||||
rms_norm_weights,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
all_reduce, residual, rms_norm_weights
|
||||
)
|
||||
return rmsnorm[1]
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
reduce_scatter, residual, rms_norm_weights
|
||||
)
|
||||
normalized = self._all_gather(rmsnorm[1])
|
||||
return normalized
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
get_first_out_wrapper(pattern),
|
||||
get_first_out_wrapper(replacement),
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
@ -257,52 +158,41 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
def __init__(
|
||||
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
||||
return [input, rmsnorm_result, quant_result, weight, scale]
|
||||
return [input, weight, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
rmsnorm_result: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
all_reduce = self._all_reduce(input)
|
||||
static_fp8 = self._functional_rmsnorm_then_quant(
|
||||
rmsnorm_result, quant_result, all_reduce, weight, scale
|
||||
)
|
||||
return static_fp8[1], all_reduce
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
return quant, all_reduce
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
rmsnorm_result: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
|
||||
rmsnorm_result = torch.empty_like(
|
||||
reduce_scatter, dtype=rmsnorm_result.dtype
|
||||
)
|
||||
quant_result = torch.empty_like(
|
||||
rmsnorm_result, # Output of RMSNorm
|
||||
dtype=quant_result.dtype,
|
||||
)
|
||||
static_fp8 = self._functional_rmsnorm_then_quant(
|
||||
rmsnorm_result, quant_result, reduce_scatter, weight, scale
|
||||
)
|
||||
all_gather = self._all_gather(static_fp8[1])
|
||||
rms = self.rmsnorm_matcher(reduce_scatter, weight)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
all_gather = self._all_gather(quant)
|
||||
|
||||
return all_gather, reduce_scatter
|
||||
|
||||
@ -312,118 +202,64 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
def __init__(
|
||||
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
|
||||
):
|
||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
|
||||
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [
|
||||
result,
|
||||
residual,
|
||||
mm_1,
|
||||
rms_norm_weights,
|
||||
scale,
|
||||
]
|
||||
return [residual, mm_1, rms_norm_weights, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
static_fp8, rmsnorm_residual_out = (
|
||||
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
|
||||
result, all_reduce, residual, rms_norm_weights, scale
|
||||
)
|
||||
rms, residual_out = self.rmsnorm_matcher(
|
||||
all_reduce, rms_norm_weights, residual
|
||||
)
|
||||
return static_fp8[1], rmsnorm_residual_out
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
return quant, residual_out
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# pattern matcher replaces from top-to-bottom,
|
||||
# so residual is still the full size here.
|
||||
# add a temporary slice which will become a noop
|
||||
# once the seqpar pattern with the previous rmsnorm is replaced
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
|
||||
static_fp8, rmsnorm_residual_out = (
|
||||
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
|
||||
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
|
||||
)
|
||||
residual = residual[0 : reduce_scatter.size(0), ...]
|
||||
rms, residual_out = self.rmsnorm_matcher(
|
||||
reduce_scatter, rms_norm_weights, residual
|
||||
)
|
||||
all_gather = self._all_gather(static_fp8[1])
|
||||
return all_gather, rmsnorm_residual_out
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
all_gather = self._all_gather(quant)
|
||||
# shape of residual changes but that's fine,
|
||||
# next node is already slicing it, now becomes a noop
|
||||
return all_gather, residual_out
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
def __init__(
|
||||
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
|
||||
):
|
||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
|
||||
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [
|
||||
result,
|
||||
residual,
|
||||
mm_1,
|
||||
rms_norm_weights,
|
||||
scale,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
|
||||
result, all_reduce, residual, rms_norm_weights, scale
|
||||
)
|
||||
return static_fp8[1]
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
|
||||
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
|
||||
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
|
||||
)
|
||||
normalized = self._all_gather(static_fp8[1])
|
||||
return normalized
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
get_first_out_wrapper(pattern),
|
||||
get_first_out_wrapper(replacement),
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
@ -445,27 +281,45 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
significantly reduce communication overhead and improve overall model
performance.

This pass splits up the residual tensor across TP ranks and hence divides its size.
Because the pattern matcher starts at the end of the graph, the replacement
contains a slice that temporarily conforms the input residual to the correct size.
After all patterns have been matched, we use a NoOpEliminationPass to clean up
what have now become no-op slices.

Note that an older version of the pass did not need this as it operated only on
custom rms_norm and fused_rms_norm_add custom ops which did not complain about
mismatched shapes during replacement. So this approach has the same assumption that
correctness is only maintained if all rms_norm operations are split across ranks.

Correctness-wise, this approach is strictly better than before - before,
the graph was incorrect semantically and shape-wise during the pass.
With this approach there's only semantic incorrectness during the pass.
Both approaches restore a correct graph once all patterns are matched.
"""

@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)

# Used to clean up redundant views created temporarily
# to circumvent residual shape change issues
self.noop_cleanup = NoOpEliminationPass(config)
self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"

self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="sequence_parallelism_pass"
)

for epsilon in [1e-5, 1e-6]:
# RMSNorm + Static FP8 quantization patterns
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
FirstAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, fp8_quant_op
epsilon, self.model_dtype, self.device
).register(self.patterns)
MiddleAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, fp8_quant_op
).register(self.patterns)
LastAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, fp8_quant_op
epsilon, self.model_dtype, self.device
).register(self.patterns)

# Normal RMSNorm patterns
@ -477,9 +331,6 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
epsilon, self.model_dtype, self.device
).register(self.patterns)

LastAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
self.dump_patterns(config, self.patterns)

def is_applicable(self, shape: int | None) -> bool:
@ -508,3 +359,5 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
def __call__(self, graph: fx.Graph):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
# Clean up reshape nodes
self.noop_cleanup(graph)

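The two-stage structure in the hunk above (match-and-replace, then sweep away the temporary views the replacements introduced) can be illustrated with a small standalone sketch. The class below is a hypothetical stand-in, not vLLM's real pass, and it substitutes a generic dead-code sweep for the dedicated NoOpEliminationPass:

import torch.fx as fx


class SeqParLikePass:
    """Hypothetical stand-in for the pass above."""

    def __init__(self, patterns) -> None:
        # `patterns` is assumed to expose .apply(graph) -> int,
        # like a torch._inductor PatternMatcherPass.
        self.patterns = patterns

    def __call__(self, graph: fx.Graph) -> int:
        matched = self.patterns.apply(graph)  # rewrite allreduce + rmsnorm regions
        # Cleanup stage: replacements leave temporary slice/view nodes behind;
        # a dead-code sweep stands in here for the dedicated no-op elimination.
        graph.eliminate_dead_code()
        return matched
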
@ -152,7 +152,6 @@ class CompilationConfig:
|
||||
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
|
||||
- [`compile_mm_encoder`][vllm.config.CompilationConfig.compile_mm_encoder]
|
||||
- CudaGraph capture:
|
||||
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
|
||||
- [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
|
||||
- [`cudagraph_capture_sizes`]
|
||||
[vllm.config.CompilationConfig.cudagraph_capture_sizes]
|
||||
@ -162,7 +161,6 @@ class CompilationConfig:
|
||||
[vllm.config.CompilationConfig.cudagraph_num_of_warmups]
|
||||
- [`cudagraph_copy_inputs`]
|
||||
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
|
||||
- [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph]
|
||||
- Inductor compilation:
|
||||
- [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
|
||||
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
|
||||
@ -268,9 +266,10 @@ class CompilationConfig:
|
||||
|
||||
If None, defaults to attention ops for piecewise cudagraphs.
|
||||
If empty list [], no ops are excluded (suitable for full cudagraphs)."""
|
||||
compile_mm_encoder: bool = True
|
||||
compile_mm_encoder: bool = False
|
||||
"""Whether or not to compile the multimodal encoder.
|
||||
Currently, this only works for `Qwen2_5_vl`."""
|
||||
Currently, this only works for `Qwen2_5_vl` on selected platforms.
|
||||
Disabled by default until more models are supported/tested to work."""
|
||||
|
||||
# Inductor capture
|
||||
use_inductor: bool | None = None
|
||||
@ -342,18 +341,6 @@ class CompilationConfig:
|
||||
Warning: This flag is new and subject to change in addition
|
||||
more modes may be added.
|
||||
"""
|
||||
use_cudagraph: bool = True
|
||||
"""Whether to use cudagraph inside compilation:
|
||||
|
||||
- False: cudagraph inside compilation is not used.\n
|
||||
- True: cudagraph inside compilation is used. It requires
|
||||
that all input buffers have fixed addresses, and all
|
||||
splitting ops write their outputs to input buffers.
|
||||
|
||||
Warning: This flag is deprecated and will be removed in the next major or
|
||||
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=FULL_AND
|
||||
_PIECEWISE instead.
|
||||
"""
|
||||
cudagraph_num_of_warmups: int = 0
|
||||
"""Number of warmup runs for cudagraph.
|
||||
It means the first several runs will be treated as warmup runs.
|
||||
@ -371,15 +358,6 @@ class CompilationConfig:
|
||||
internally managed buffer. Default is False.
|
||||
Note that this flag is only effective when cudagraph_mode is PIECEWISE.
|
||||
"""
|
||||
full_cuda_graph: bool | None = False
|
||||
"""whether to use a full cuda graph for the entire forward pass rather than
|
||||
splitting certain operations such as attention into subgraphs. Thus this
|
||||
flag cannot be used together with splitting_ops. This may provide
|
||||
performance benefits for smaller models.
|
||||
Warning: This flag is deprecated and will be removed in the next major or
|
||||
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
|
||||
FULL_AND_PIECEWISE instead.
|
||||
"""
|
||||
cudagraph_specialize_lora: bool = True
|
||||
"""Whether to create separate cuda graphs for cases with and without active
|
||||
LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used
|
||||
@ -528,13 +506,19 @@ class CompilationConfig:
|
||||
@field_validator("cudagraph_mode", mode="before")
|
||||
@classmethod
|
||||
def validate_cudagraph_mode_before(cls, value: Any) -> Any:
|
||||
"""
|
||||
enable parse the `cudagraph_mode` enum type from string
|
||||
"""
|
||||
"""Enable parsing of the `cudagraph_mode` enum type from string."""
|
||||
if isinstance(value, str):
|
||||
return CUDAGraphMode[value.upper()]
|
||||
return value
|
||||
|
||||
@field_validator("pass_config", mode="before")
|
||||
@classmethod
|
||||
def validate_pass_config_before(cls, value: Any) -> Any:
|
||||
"""Enable parsing of the `pass_config` field from a dictionary."""
|
||||
if isinstance(value, dict):
|
||||
return PassConfig(**value)
|
||||
return value
|
||||
|
||||
@field_validator("compile_cache_save_format")
|
||||
@classmethod
|
||||
def validate_compile_cache_save_format(cls, value: str) -> str:
|
||||
@ -591,8 +575,10 @@ class CompilationConfig:
|
||||
func if isinstance(func, InductorPass) else CallableInductorPass(func)
|
||||
)
|
||||
|
||||
if isinstance(self.pass_config, dict):
|
||||
self.pass_config = PassConfig(**self.pass_config)
|
||||
if self.pass_config.enable_qk_norm_rope_fusion:
|
||||
# TODO(zhuhaoran): support rope native forward match and remove this.
|
||||
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
|
||||
self.custom_ops.append("+rotary_embedding")
|
||||
|
||||
if (
|
||||
is_torch_equal_or_newer("2.9.0.dev")
|
||||
@ -604,36 +590,6 @@ class CompilationConfig:
|
||||
self.inductor_compile_config["combo_kernels"] = True
|
||||
self.inductor_compile_config["benchmark_combo_kernel"] = True
|
||||
|
||||
# migrate the deprecated flags
|
||||
if not self.use_cudagraph:
|
||||
logger.warning(
|
||||
"use_cudagraph is deprecated, use cudagraph_mode=NONE instead."
|
||||
)
|
||||
if (
|
||||
self.cudagraph_mode is not None
|
||||
and self.cudagraph_mode != CUDAGraphMode.NONE
|
||||
):
|
||||
raise ValueError(
|
||||
"use_cudagraph and cudagraph_mode are mutually"
|
||||
" exclusive, prefer cudagraph_mode since "
|
||||
"use_cudagraph is deprecated."
|
||||
)
|
||||
self.cudagraph_mode = CUDAGraphMode.NONE
|
||||
if self.full_cuda_graph:
|
||||
logger.warning(
|
||||
"full_cuda_graph is deprecated, use cudagraph_mode=FULL instead."
|
||||
)
|
||||
if (
|
||||
self.cudagraph_mode is not None
|
||||
and not self.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
raise ValueError(
|
||||
"full_cuda_graph and cudagraph_mode are "
|
||||
"mutually exclusive, prefer cudagraph_mode "
|
||||
"since full_cuda_graph is deprecated."
|
||||
)
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
|
||||
if self.use_inductor_graph_partition and not is_torch_equal_or_newer(
|
||||
"2.9.0.dev"
|
||||
):
|
||||
@ -811,20 +767,19 @@ class CompilationConfig:
|
||||
|
||||
def set_splitting_ops_for_attn_fusion(self):
|
||||
assert self.pass_config.enable_attn_fusion
|
||||
# For dynamo-partition (non-inductor) attention fusion,
|
||||
# set splitting_ops to empty to avoid splitting at attention ops
|
||||
self.splitting_ops = []
|
||||
if self.cudagraph_mode.has_piecewise_cudagraphs():
|
||||
logger.warning_once(
|
||||
"enable_attn_fusion is incompatible with piecewise "
|
||||
"cudagraph when use_inductor_graph_partition is off. "
|
||||
"In this case, splitting_ops will be set to empty "
|
||||
"list, and cudagraph_mode will be set to FULL. "
|
||||
"Please ensure you are using attention backends that "
|
||||
"support cudagraph or set cudagraph_mode to NONE "
|
||||
"explicitly if encountering any problems."
|
||||
)
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
if self.splitting_ops is None:
|
||||
self.splitting_ops = []
|
||||
if self.cudagraph_mode.has_piecewise_cudagraphs():
|
||||
logger.warning_once(
|
||||
"enable_attn_fusion is incompatible with piecewise "
|
||||
"cudagraph when use_inductor_graph_partition is off. "
|
||||
"In this case, splitting_ops will be set to empty "
|
||||
"list, and cudagraph_mode will be set to FULL. "
|
||||
"Please ensure you are using attention backends that "
|
||||
"support cudagraph or set cudagraph_mode to NONE "
|
||||
"explicitly if encountering any problems."
|
||||
)
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
|
||||
assert not self.splitting_ops_contain_attention(), (
|
||||
"attention ops should not be in splitting_ops "
|
||||
|
||||
@ -441,8 +441,6 @@ class VllmConfig:
# and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp:
self.compilation_config.pass_config.enable_sequence_parallelism = True
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")

if current_platform.support_static_graph_mode():
# if cudagraph_mode is not explicitly set by users, set default
@ -631,6 +629,32 @@ class VllmConfig:
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
self.compilation_config.set_splitting_ops_for_v1()

if self.compilation_config.pass_config.enable_sequence_parallelism:
# With pipeline parallelism or dynamo partitioning,
# native rms norm tracing errors due to incorrect residual shape.
# Use custom rms norm to unblock. In the future,
# the pass will operate on higher-level IR to avoid the issue.
# TODO: https://github.com/vllm-project/vllm/issues/27894
is_fullgraph = (
self.compilation_config.use_inductor_graph_partition
or len(self.compilation_config.splitting_ops) == 0
)
if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
if "-rms_norm" not in self.compilation_config.custom_ops:
self.compilation_config.custom_ops.append("+rms_norm")
else:
regime = (
"Dynamo partition"
if not is_fullgraph
else "pipeline parallelism"
)
logger.warning_once(
"Sequence parallelism not supported with"
"native rms_norm when using %s, "
"this will likely lead to an error.",
regime,
)

# final check of cudagraph mode after all possible updates
if current_platform.is_cuda_alike():
if (
@ -652,14 +676,6 @@ class VllmConfig:
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
)

# final migrate the deprecated flags
self.compilation_config.use_cudagraph = (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
)
self.compilation_config.full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)

if self.parallel_config.enable_dbo:
a2a_backend = self.parallel_config.all2all_backend
assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
@ -849,7 +865,9 @@ class VllmConfig:
)
# de-duplicate the sizes provided by the config
dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes))
cudagraph_capture_sizes = dedup_sizes
cudagraph_capture_sizes = [
i for i in dedup_sizes if i <= max_num_tokens
]
# sort to make sure the sizes are in ascending order
cudagraph_capture_sizes.sort()
else:

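The last hunk above normalizes the user-provided cudagraph capture sizes: de-duplicate, drop anything above the token budget, and sort ascending. A standalone illustration of that logic (the function name is made up):

def normalize_capture_sizes(sizes: list[int], max_num_tokens: int) -> list[int]:
    # de-duplicate, filter out sizes that exceed the budget, then sort ascending
    kept = [s for s in set(sizes) if s <= max_num_tokens]
    kept.sort()
    return kept


# e.g. normalize_capture_sizes([8, 16, 16, 4096, 32], 2048) -> [8, 16, 32]
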
@ -342,8 +342,8 @@ class MsgpackSerde(ObjectSerde):
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder

self.encoder = MsgpackEncoder()
self.tensor_decoder = MsgpackDecoder(torch.Tensor)
self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem)
self.tensor_decoder = MsgpackDecoder(torch.Tensor, share_mem=False)
self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem, share_mem=False)
self._mm_kwargs_item_cls = MultiModalKwargsItem

def serialize(self, value: Any) -> tuple[bytes | list[bytes], int, bytes, int]:
@ -368,7 +368,7 @@ class MsgpackSerde(ObjectSerde):
# pickle.loads do not read past the end of a pickled object
# within a large buffer, so we can skip storing the metadata size
type_name, nbytes, len_arr = pickle.loads(data_view)
serialized_data = bytearray(data_view[-nbytes:])
serialized_data = data_view[-nbytes:]

if type_name == torch.Tensor.__name__:
obj = []

@ -48,6 +48,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
|
||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
@ -110,6 +111,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
local_block_ids: list[int]
|
||||
# To be used when logical block size does not match the kernel block size
|
||||
local_physical_block_ids: list[int]
|
||||
remote_block_ids: list[int]
|
||||
remote_host: str
|
||||
remote_port: int
|
||||
@ -137,6 +140,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
|
||||
assert load_remote_cache ^ save_to_host
|
||||
_req = ReqMeta(
|
||||
local_block_ids=local_block_ids,
|
||||
local_physical_block_ids=local_block_ids,
|
||||
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
||||
remote_engine_id=kv_transfer_params["remote_engine_id"],
|
||||
remote_host=kv_transfer_params["remote_host"],
|
||||
@ -897,6 +901,8 @@ class NixlConnectorWorker:
|
||||
is_mla=self.use_mla,
|
||||
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
||||
)
|
||||
self._use_pallas = self.kv_topo._use_pallas
|
||||
self._physical_blocks_per_logical_kv_block = 1
|
||||
|
||||
def _nixl_handshake(
|
||||
self,
|
||||
@ -1092,6 +1098,22 @@ class NixlConnectorWorker:
|
||||
if base_addr in seen_base_addresses:
|
||||
continue
|
||||
|
||||
# TODO (NickLucche): Get kernel_block_size in a cleaner way
|
||||
# NHD default "view" for non-MLA cache
|
||||
kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]
|
||||
|
||||
if self.block_size != kernel_block_size:
|
||||
logger.info_once(
|
||||
"User-specified logical block size (%s) does not match"
|
||||
" physical kernel block size (%s). Using the latter. ",
|
||||
self.block_size,
|
||||
kernel_block_size,
|
||||
)
|
||||
self._physical_blocks_per_logical_kv_block = (
|
||||
self.block_size // kernel_block_size
|
||||
)
|
||||
self.block_size = kernel_block_size
|
||||
|
||||
seen_base_addresses.append(base_addr)
|
||||
curr_tensor_size_bytes = cache.numel() * cache.element_size()
|
||||
|
||||
@ -1438,7 +1460,7 @@ class NixlConnectorWorker:
|
||||
assert self.use_host_buffer
|
||||
assert self.copy_blocks is not None
|
||||
|
||||
local_block_ids = meta.local_block_ids
|
||||
local_block_ids = meta.local_physical_block_ids
|
||||
self.copy_blocks(
|
||||
self.host_xfer_buffers,
|
||||
self.device_kv_caches,
|
||||
@ -1451,7 +1473,7 @@ class NixlConnectorWorker:
|
||||
"synced recved kv of request[%s] to device kv buffer,"
|
||||
"local_block_ids: %s. ",
|
||||
req_id,
|
||||
",".join(map(str, meta.local_block_ids)),
|
||||
",".join(map(str, local_block_ids)),
|
||||
)
|
||||
|
||||
def save_kv_to_host(self, metadata: NixlConnectorMetadata):
|
||||
@ -1460,19 +1482,22 @@ class NixlConnectorWorker:
|
||||
assert self.copy_blocks is not None
|
||||
|
||||
for req_id, meta in metadata.reqs_to_save.items():
|
||||
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
|
||||
meta.local_block_ids
|
||||
)
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"save_load_kv for request[%s] to host xfer buffer."
|
||||
"local_block_ids: %s. ",
|
||||
req_id,
|
||||
",".join(map(str, meta.local_block_ids)),
|
||||
",".join(map(str, meta.local_physical_block_ids)),
|
||||
)
|
||||
# blocking
|
||||
self.copy_blocks(
|
||||
self.device_kv_caches,
|
||||
self.host_xfer_buffers,
|
||||
meta.local_block_ids,
|
||||
meta.local_block_ids,
|
||||
meta.local_physical_block_ids,
|
||||
meta.local_physical_block_ids,
|
||||
"d2h",
|
||||
)
|
||||
|
||||
@ -1541,7 +1566,7 @@ class NixlConnectorWorker:
|
||||
if self.use_host_buffer:
|
||||
self.sync_recved_kv_to_device(req_id, meta)
|
||||
if self.enable_permute_local_kv:
|
||||
block_ids_to_permute += meta.local_block_ids
|
||||
block_ids_to_permute += meta.local_physical_block_ids
|
||||
if len(block_ids_to_permute) > 0:
|
||||
self.permute_device_kv(block_ids_to_permute)
|
||||
|
||||
@ -1628,7 +1653,7 @@ class NixlConnectorWorker:
|
||||
req_id,
|
||||
xfer_state,
|
||||
)
|
||||
# mark all blocks for this request as invalid
|
||||
# mark all (logical)blocks for this request as invalid
|
||||
if meta := self._recving_metadata.pop(req_id, None):
|
||||
self._invalid_block_ids.update(meta.local_block_ids)
|
||||
self._recving_metadata.pop(req_id, None)
|
||||
@ -1645,13 +1670,19 @@ class NixlConnectorWorker:
|
||||
We check for these trnxs to complete in each step().
|
||||
"""
|
||||
for req_id, meta in metadata.reqs_to_recv.items():
|
||||
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
|
||||
meta.local_block_ids
|
||||
)
|
||||
meta.remote_block_ids = self._logical_to_kernel_block_ids(
|
||||
meta.remote_block_ids
|
||||
)
|
||||
remote_engine_id = meta.remote_engine_id
|
||||
logger.debug(
|
||||
"start_load_kv for request %s from remote engine %s. "
|
||||
"Num local_block_ids: %s. Num remote_block_ids: %s. ",
|
||||
req_id,
|
||||
remote_engine_id,
|
||||
len(meta.local_block_ids),
|
||||
len(meta.local_physical_block_ids),
|
||||
len(meta.remote_block_ids),
|
||||
)
|
||||
# always store metadata for failure recovery
|
||||
@ -1699,7 +1730,7 @@ class NixlConnectorWorker:
|
||||
self._read_blocks(
|
||||
request_id=req_id,
|
||||
dst_engine_id=meta.remote_engine_id,
|
||||
local_block_ids=meta.local_block_ids,
|
||||
local_block_ids=meta.local_physical_block_ids,
|
||||
remote_block_ids=meta.remote_block_ids,
|
||||
)
|
||||
|
||||
@ -1826,7 +1857,7 @@ class NixlConnectorWorker:
|
||||
"Marking blocks as invalid.",
|
||||
request_id,
|
||||
)
|
||||
# mark all blocks for this request as invalid
|
||||
# mark all (logical) blocks for this request as invalid
|
||||
if meta := self._recving_metadata.get(request_id):
|
||||
self._invalid_block_ids.update(meta.local_block_ids)
|
||||
self.xfer_stats.record_failed_transfer()
|
||||
@ -1865,6 +1896,23 @@ class NixlConnectorWorker:
|
||||
descs_ids = region_ids * num_blocks + block_ids
|
||||
return descs_ids.flatten()
|
||||
|
||||
def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
|
||||
"""
|
||||
Convert logical block ids to kernel physical block ids.
|
||||
This is required when the logical block size (the one set by the user)
|
||||
does not match the one required by the attn backend.
|
||||
"""
|
||||
if self._physical_blocks_per_logical_kv_block == 1:
|
||||
# Noop when physical and logical block sizes are the same
|
||||
return block_ids
|
||||
block_ids_np = np.array(block_ids)
|
||||
block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
|
||||
1, -1
|
||||
)
|
||||
return BlockTable.map_to_kernel_blocks(
|
||||
block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
|
||||
).tolist()
|
||||
|
||||
def get_backend_aware_kv_block_len(self, layer_idx: int):
|
||||
"""
|
||||
Get the block length for one K/V element (K and V have the same size).
|
||||
|
||||
@ -1625,40 +1625,39 @@ class EngineArgs:
)

observability_config = ObservabilityConfig(
show_hidden_metrics_for_version=(self.show_hidden_metrics_for_version),
show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
otlp_traces_endpoint=self.otlp_traces_endpoint,
collect_detailed_traces=self.collect_detailed_traces,
)

# Compilation config overrides
compilation_config = copy.deepcopy(self.compilation_config)
if self.cuda_graph_sizes is not None:
logger.warning(
"--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or "
"v1.0.0, whichever is soonest. Please use --cudagraph-capture-sizes "
"instead."
)
if self.compilation_config.cudagraph_capture_sizes is not None:
if compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
"cuda_graph_sizes and compilation_config."
"cudagraph_capture_sizes are mutually exclusive"
)
self.compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes
compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes
if self.cudagraph_capture_sizes is not None:
if self.compilation_config.cudagraph_capture_sizes is not None:
if compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
"cudagraph_capture_sizes and compilation_config."
"cudagraph_capture_sizes are mutually exclusive"
)
self.compilation_config.cudagraph_capture_sizes = (
self.cudagraph_capture_sizes
)
compilation_config.cudagraph_capture_sizes = self.cudagraph_capture_sizes
if self.max_cudagraph_capture_size is not None:
if self.compilation_config.max_cudagraph_capture_size is not None:
if compilation_config.max_cudagraph_capture_size is not None:
raise ValueError(
"max_cudagraph_capture_size and compilation_config."
"max_cudagraph_capture_size are mutually exclusive"
)
self.compilation_config.max_cudagraph_capture_size = (
compilation_config.max_cudagraph_capture_size = (
self.max_cudagraph_capture_size
)

@ -1673,7 +1672,7 @@ class EngineArgs:
load_config=load_config,
structured_outputs_config=self.structured_outputs_config,
observability_config=observability_config,
compilation_config=self.compilation_config,
compilation_config=compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
additional_config=self.additional_config,

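The change above deep-copies the compilation config before applying CLI overrides, so the overrides land on a private copy instead of mutating the caller's object. A small sketch of that pattern, with a made-up `SimpleConfig` standing in for the real config class:

import copy
from dataclasses import dataclass


@dataclass
class SimpleConfig:
    cudagraph_capture_sizes: list[int] | None = None


def with_override(base: SimpleConfig, cli_sizes: list[int] | None) -> SimpleConfig:
    cfg = copy.deepcopy(base)  # work on a private copy; `base` stays untouched
    if cli_sizes is not None:
        if cfg.cudagraph_capture_sizes is not None:
            # the two sources of the same setting are mutually exclusive
            raise ValueError("capture sizes given twice")
        cfg.cudagraph_capture_sizes = cli_sizes
    return cfg
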
@ -2439,28 +2439,6 @@ class FusedMoE(CustomOp):
staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True)

# If there are shared experts but we are not using a modular kernel,
# the shared experts must be called here
if has_separate_shared_experts:
assert self.shared_experts is not None

if self.shared_experts_stream is not None:
# For chunked, we start the shared experts stream here
# (Note that no concurrency with the router/gate)
self.shared_experts_stream.wait_stream(current_stream())

with torch.cuda.stream(self.shared_experts_stream):
# Note that staged_hidden_states clone() is necessary
# here to avoid conflict with the main stream
shared_output = self.shared_experts(
staged_hidden_states.clone()
)
else:
shared_output = self.shared_experts(staged_hidden_states)

else:
shared_output = None

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
@ -2489,11 +2467,7 @@ class FusedMoE(CustomOp):
if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None

# Here we finish the shared experts stream
if self.shared_experts_stream is not None:
current_stream().wait_stream(self.shared_experts_stream)

shared_output = self.shared_experts(staged_hidden_states)
final_hidden_states = (
shared_output,
final_hidden_states,
@ -2602,11 +2576,22 @@ class FusedMoE(CustomOp):
assert self.shared_experts is not None

if self.shared_experts_stream is not None:
# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
hidden_states_clone = hidden_states.clone()
self.shared_experts_stream.wait_stream(current_stream())

# Run shared experts in parallel on a separate stream
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output = self.shared_experts(hidden_states.clone())
shared_output = self.shared_experts(hidden_states_clone)

# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: we dont need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)

else:
shared_output = self.shared_experts(hidden_states)
else:

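The hunk above tightens the dual-stream shared-experts path: clone the activations on the main stream before switching, launch the shared experts on the side stream, and call record_stream() on the clone so the caching allocator does not recycle it while the side stream still reads it. A minimal sketch of that choreography; the helper name and signature are invented:

import torch


def run_shared_experts_async(shared_experts, hidden_states, side_stream):
    # Clone BEFORE switching streams: routed-expert kernels on the main
    # stream may mutate hidden_states in place.
    clone = hidden_states.clone()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        shared_output = shared_experts(clone)
    # Tell the caching allocator the clone is still in use on side_stream.
    clone.record_stream(side_stream)
    # Caller must current_stream().wait_stream(side_stream) before consuming.
    return shared_output
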
@ -43,7 +43,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
@ -95,11 +94,9 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.deep_gemm import (
|
||||
fp8_gemm_nt,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
should_use_deepgemm_for_fp8_linear,
|
||||
)
|
||||
from vllm.utils.flashinfer import has_flashinfer_moe
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
@ -554,83 +551,19 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
|
||||
# we will use BF16 dequant when DeepGEMM is not supported.
|
||||
if vllm_is_batch_invariant():
|
||||
# Call is_deep_gemm_supported() ahead of time for torch.compile
|
||||
# dynamo has trouble tracing through
|
||||
if self.block_quant and should_use_deepgemm_for_fp8_linear(
|
||||
torch.bfloat16, layer.weight, self.use_deep_gemm
|
||||
):
|
||||
# use group quant consistent with block size across K
|
||||
assert self.act_q_group_shape is not None
|
||||
q_input, input_scale = QuantFP8(
|
||||
False,
|
||||
self.act_q_group_shape,
|
||||
column_major_scales=True,
|
||||
)(x)
|
||||
|
||||
output_2d = torch.empty(
|
||||
(q_input.shape[0], layer.weight.shape[0]),
|
||||
dtype=torch.bfloat16,
|
||||
device=q_input.device,
|
||||
)
|
||||
fp8_gemm_nt(
|
||||
(q_input, input_scale),
|
||||
(layer.weight, layer.weight_scale),
|
||||
output_2d,
|
||||
)
|
||||
if bias is not None:
|
||||
output_2d = output_2d + bias
|
||||
return output_2d
|
||||
|
||||
# Dequantize FP8 weights to BF16
|
||||
weight_fp8 = layer.weight.to(torch.bfloat16)
|
||||
weight_scale = layer.weight_scale.to(torch.bfloat16)
|
||||
|
||||
# Handle different quantization granularities
|
||||
if self.block_quant:
|
||||
# Block-wise quantization:
|
||||
# - Weight is NOT transposed, shape is [N, K] (output_size, input_size)
|
||||
# - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!)
|
||||
assert self.weight_block_size is not None
|
||||
block_n, block_k = self.weight_block_size # Note: order is [N, K]
|
||||
|
||||
N, K = weight_fp8.shape
|
||||
|
||||
# determine expected number of blocks along N and K
|
||||
num_blocks_n = (N + block_n - 1) // block_n
|
||||
num_blocks_k = (K + block_k - 1) // block_k
|
||||
|
||||
# scale layout may be [num_blocks_n, num_blocks_k]
|
||||
# or [num_blocks_k, num_blocks_n] depending on backend
|
||||
if weight_scale.dim() != 2:
|
||||
raise RuntimeError(
|
||||
f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}"
|
||||
)
|
||||
|
||||
scale_rows, scale_cols = weight_scale.shape
|
||||
if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n):
|
||||
if num_blocks_n == num_blocks_k:
|
||||
# ambiguous square case, warn and skip transpose
|
||||
logger.warning(
|
||||
"Batch-invariant FP8: square block-scale %dx%d; "
|
||||
"skipping transpose to avoid misorientation.",
|
||||
scale_rows,
|
||||
scale_cols,
|
||||
)
|
||||
else:
|
||||
# clear KN -> transpose to NK
|
||||
weight_scale = weight_scale.t()
|
||||
|
||||
# Expand scale to match weight dimensions
|
||||
# scale_expanded should have shape [N, K]
|
||||
scale_expanded = weight_scale.repeat_interleave(
|
||||
block_n, dim=0
|
||||
).repeat_interleave(block_k, dim=1)
|
||||
# Trim to exact weight size (in case of padding)
|
||||
scale_expanded = scale_expanded[:N, :K]
|
||||
weight_bf16 = weight_fp8 * scale_expanded
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
# Per-tensor quantization: weight IS transposed to [K, N]
|
||||
# scale should be scalar or [1] or per-output-channel [N]
|
||||
# per-tensor/channel: dequant to BF16 and run GEMM
|
||||
weight_fp8 = layer.weight.to(torch.bfloat16)
|
||||
weight_scale = layer.weight_scale.to(torch.bfloat16)
|
||||
if weight_scale.numel() == 1:
|
||||
# Per-tensor: simple scalar multiplication
|
||||
weight_bf16 = weight_fp8 * weight_scale
|
||||
@ -649,16 +582,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
else:
|
||||
# Fallback
|
||||
weight_bf16 = weight_fp8 * weight_scale
|
||||
|
||||
# For block quant, weight is [N, K], for per-tensor it's [K, N]
|
||||
# F.linear expects weight to be [N, K], so:
|
||||
if self.block_quant:
|
||||
# Already in correct shape [N, K]
|
||||
output = torch.nn.functional.linear(x, weight_bf16, bias)
|
||||
else:
|
||||
# Need to transpose back: [K, N] -> [N, K]
|
||||
output = torch.nn.functional.linear(x, weight_bf16.t(), bias)
|
||||
return output
|
||||
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
|
||||
|
||||
if self.use_marlin:
|
||||
return apply_fp8_marlin_linear(
|
||||
|
||||
@ -82,7 +82,8 @@ enable_hf_transfer()

class DisabledTqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
kwargs["disable"] = True
super().__init__(*args, **kwargs)


def get_lock(model_name_or_path: str | Path, cache_dir: str | None = None):

@ -779,6 +779,10 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
architectures=["Qwen2ForCausalLM"],
)

self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)

def _parse_and_validate_image_input(
self, **kwargs: object
) -> DotsOCRImageInputs | None:

@ -35,6 +35,7 @@ from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
@ -45,6 +46,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.utils import initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
@ -68,11 +70,15 @@ from .interfaces import (
    MixtureOfExperts,
    MultiModalEmbeddings,
    SupportsEagle3,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .llama4 import Llama4ForCausalLM
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import (
    AutoWeightsLoader,
    maybe_prefix,
)
from .vision import run_dp_sharded_vision_model


@ -724,7 +730,12 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
    dummy_inputs=Mllama4DummyInputsBuilder,
)
class Llama4ForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, MixtureOfExperts, SupportsEagle3
    nn.Module,
    SupportsMultiModal,
    SupportsPP,
    MixtureOfExperts,
    SupportsEagle3,
    SupportsLoRA,
):
    merge_by_field_config = True

@ -1067,6 +1078,17 @@ class Llama4ForConditionalGeneration(

        return updated_params

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.text_config.num_local_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
@ -1113,3 +1135,13 @@ class Llama4ForConditionalGeneration(
        )

        return updated_params

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector.",
            tower_model="vision_model.",
        )

@ -198,23 +198,18 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
        if image_processor is None:
            image_processor = self.get_image_processor()

        do_resize = True
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)
        resized_height, resized_width = smart_resize(
            height=image_height,
            width=image_width,
            factor=patch_size * merge_size,
            min_pixels=image_processor.min_pixels,
            max_pixels=image_processor.max_pixels,
        )
        preprocessed_size = ImageSize(width=resized_width, height=resized_height)

        grid_t = 1
        grid_h = preprocessed_size.height // patch_size
@ -227,8 +222,19 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo):

    def get_image_size_with_most_features(self) -> ImageSize:
        hf_config = self.get_hf_config()
        image_size = hf_config.vision_config.image_size
        return ImageSize(height=image_size, width=image_size)

        # See `smart_resize` for the calculation of the image size.
        merge_size = hf_config.vision_config.spatial_merge_size
        patch_size = hf_config.vision_config.patch_size
        factor = merge_size * patch_size
        max_num_tokens = self.get_image_processor().max_pixels // (factor**2)
        # Find factors of max_num_tokens close to its square root
        # to create a dummy image with a reasonable aspect ratio.
        h_patches = int(math.sqrt(max_num_tokens))
        while max_num_tokens % h_patches != 0:
            h_patches -= 1
        w_patches = max_num_tokens // h_patches
        return ImageSize(height=h_patches * factor, width=w_patches * factor)


class PaddleOCRVLDummyInputsBuilder(BaseDummyInputsBuilder[PaddleOCRVLProcessingInfo]):

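The rewritten get_image_size_with_most_features above sizes the dummy image from the processor's pixel budget rather than the nominal vision_config.image_size: it converts max_pixels into a patch-token budget and factors it into a near-square grid. A standalone sketch of that calculation (parameter values below are made up):

import math

def dummy_image_size(max_pixels: int, patch_size: int, merge_size: int) -> tuple[int, int]:
    factor = patch_size * merge_size
    max_num_tokens = max_pixels // (factor**2)
    # Find factors of max_num_tokens close to its square root
    # so the dummy image keeps a reasonable aspect ratio.
    h_patches = int(math.sqrt(max_num_tokens))
    while max_num_tokens % h_patches != 0:
        h_patches -= 1
    w_patches = max_num_tokens // h_patches
    return h_patches * factor, w_patches * factor

print(dummy_image_size(max_pixels=28 * 28 * 1280, patch_size=14, merge_size=2))  # (896, 1120)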
@ -13,7 +13,6 @@ from transformers import (
    BatchFeature,
    WhisperConfig,
    WhisperFeatureExtractor,
    WhisperProcessor,
)
from transformers.models.whisper.modeling_whisper import sinusoids

@ -660,16 +659,6 @@ class WhisperProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self) -> WhisperConfig:
        return self.ctx.get_hf_config(WhisperConfig)

    def get_hf_processor(self, **kwargs: object) -> WhisperProcessor:
        # HACK: Transformers 4.53.2 has issue with whisper tokenizer to
        # initialize processor. We use a monkeypatch to fix it here.
        # See: https://github.com/vllm-project/vllm/issues/20224
        processor_class = WhisperProcessor
        tokenizer_class = ("WhisperTokenizer", "WhisperTokenizerFast")
        if processor_class.tokenizer_class != tokenizer_class:
            processor_class.tokenizer_class = tokenizer_class
        return self.ctx.get_hf_processor(processor_class, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"audio": 1}


@ -675,6 +675,7 @@ class FlashAttentionImpl(AttentionImpl):
                logits_soft_cap=self.logits_soft_cap,
                block_table=attn_metadata.block_table,
                common_prefix_len=attn_metadata.common_prefix_len,
                max_num_splits=attn_metadata.max_num_splits,
                fa_version=self.vllm_flash_attn_version,
                prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
                suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
@ -921,6 +922,7 @@ def cascade_attention(
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    max_num_splits: int,
    fa_version: int,
    prefix_scheduler_metadata: torch.Tensor | None = None,
    suffix_scheduler_metadata: torch.Tensor | None = None,
@ -965,7 +967,7 @@ def cascade_attention(
        # s_aux is incorporated into prefix_lse inside the GPU kernel,
        # enabling its effect during the final attention merge.
        s_aux=s_aux,
        num_splits=1 if vllm_is_batch_invariant() else 0,
        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
    )

    descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@ -990,7 +992,7 @@ def cascade_attention(
        q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
        k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
        v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
        num_splits=1 if vllm_is_batch_invariant() else 0,
        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
    )

    # Merge prefix and suffix outputs, and store the result in output.

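The num_splits change above threads the builder-provided max_num_splits through to the kernel instead of the hard-coded 0 (typically "auto" in FlashAttention), while batch-invariant mode still pins it to 1 so the split reduction order stays deterministic. The selection logic written out as a tiny helper (a stand-in for the call to vllm_is_batch_invariant(), not vLLM code):

def select_num_splits(batch_invariant: bool, max_num_splits: int) -> int:
    # One split => fixed reduction order; otherwise allow up to the cap.
    return 1 if batch_invariant else max_num_splits

assert select_num_splits(True, 8) == 1
assert select_num_splits(False, 8) == 8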
@ -123,7 +123,7 @@ class Mamba1AttentionMetadataBuilder(
        elif (
            num_decodes > 0
            and num_decodes <= self.decode_cudagraph_max_bs
            and self.compilation_config.full_cuda_graph
            and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        ):
            padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
            self.state_indices_tensor[:num_decodes].copy_(

@ -302,7 +302,7 @@ class Mamba2AttentionMetadataBuilder(

        elif (
            num_decodes <= self.decode_cudagraph_max_bs
            and self.compilation_config.full_cuda_graph
            and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        ):
            # Pad state tensor for CUDA graph
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)

@ -349,6 +349,7 @@ class MLACommonPrefillMetadata:
        local_context_lens_allranks: list[list[int]] | None = None
        padded_local_cu_seq_lens: torch.Tensor | None = None
        cu_seq_lens_lst: list[list[int]] | None = None
        chunk_size: int | None = None

    block_table: torch.Tensor
    query_start_loc: torch.Tensor
@ -914,6 +915,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                        device, non_blocking=True
                    ),
                    cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
                    chunk_size=padded_local_max_context_chunk_across_ranks,
                )
            else:
                chunked_context_metadata = chunked_context_metadata_cls(
@ -998,6 +1000,8 @@ def reorg_kvcache(
    local_context_lens_allranks: list[list[int]],
    sum_seq_len: int,
    max_seq_len: int,
    chunk_size: int,
    chunk_idx: int,
    toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
@ -1013,6 +1017,9 @@ def reorg_kvcache(
        local_context_lens_allranks: local context lengths on each CP rank.
        sum_seq_len: the sum of cp_chunk_seq_lens_lst.
        max_seq_len: the max value of cp_chunk_seq_lens_lst.
        chunk_size: the local padded max context chunk from
            chunked_context_metadata building.
        chunk_idx: chunk idx of chunked_prefill.
        toks: the number of tokens for local gather cache.
    """
    kv_c_segments = []
@ -1024,20 +1031,31 @@ def reorg_kvcache(
    ):
        cur_seq_len = 0
        for rank, local_context_len in enumerate(local_context_lens):
            if local_context_len != 0:
                # Note(qcs): We split the context into multiple chunks,
                # depending on the size of the workspace.
                # local_context in dcp0:   |-----------------|
                # local_context in dcp1:   |--------------|
                # n*padded_local_chunk:    |-----|-----|-----|
                # local_chunk_len in dcp1: |-----|-----|--|
                # so we need update the last chunk length in dcp1.
                local_chunk_len = min(
                    max(0, local_context_len - chunk_idx * chunk_size),
                    padded_local_chunk_seq_len,
                )
                if local_chunk_len != 0:
                    kv_c_segment = allgatered_kv_c_normed[
                        rank * toks + src_token_idx : rank * toks
                        + src_token_idx
                        + local_context_len
                        + local_chunk_len
                    ]
                    k_pe_segment = allgatered_k_pe[
                        rank * toks + src_token_idx : rank * toks
                        + src_token_idx
                        + local_context_len
                        + local_chunk_len
                    ]
                    kv_c_segments.append(kv_c_segment)
                    k_pe_segments.append(k_pe_segment)
                    cur_seq_len += local_context_len
                    cur_seq_len += local_chunk_len
                    max_seq_len_check = max(max_seq_len_check, cur_seq_len)
        src_token_idx += padded_local_chunk_seq_len
    reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
@ -1688,6 +1706,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
        assert prefill_metadata.chunked_context.local_context_lens_allranks is not None
        assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None
        assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
        assert prefill_metadata.chunked_context.chunk_size is not None

        output = None
        iters = len(prefill_metadata.chunked_context.seq_tot)
@ -1737,6 +1756,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks,
                sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
                max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
                chunk_size=prefill_metadata.chunked_context.chunk_size,
                chunk_idx=i,
                toks=toks,
            )


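The new chunk_size/chunk_idx arguments to reorg_kvcache above exist because each DCP rank's local context is gathered in fixed-size padded chunks, and the final chunk on a rank may be only partially filled; the usable length is therefore clamped by both the remaining context and the padded chunk length. A worked illustration of that clamp (values are invented):

def local_chunk_len(local_context_len: int, chunk_idx: int,
                    chunk_size: int, padded_local_chunk_seq_len: int) -> int:
    return min(
        max(0, local_context_len - chunk_idx * chunk_size),
        padded_local_chunk_seq_len,
    )

# A rank holding 14 context tokens, processed in padded chunks of 6:
print([local_chunk_len(14, i, 6, 6) for i in range(3)])  # [6, 6, 2]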
@ -81,7 +81,7 @@ class ShortConvAttentionMetadataBuilder(
        elif (
            num_decodes > 0
            and num_decodes <= self.decode_cudagraph_max_bs
            and self.compilation_config.full_cuda_graph
            and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        ):
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
            self.state_indices_tensor[:num_decodes].copy_(

@ -5,7 +5,7 @@ import inspect
import itertools
from abc import abstractmethod
from collections.abc import Sequence
from functools import partial
from functools import lru_cache, partial
from typing import TYPE_CHECKING

import torch
@ -216,11 +216,17 @@ def build_logitsprocs(
    )


cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)


def validate_logits_processors_parameters(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
    sampling_params: SamplingParams,
):
    for logits_procs in _load_custom_logitsprocs(logits_processors):
    logits_processors = (
        tuple(logits_processors) if logits_processors is not None else None
    )
    for logits_procs in cached_load_custom_logitsprocs(logits_processors):
        logits_procs.validate_params(sampling_params)


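The lru_cache wrapper above only works because the argument is made hashable first: a caller-supplied list of logits-processor specs is converted to a tuple (or left as None) before hitting the cached loader. A minimal sketch of the pattern with a stand-in loader (not vLLM's _load_custom_logitsprocs):

from functools import lru_cache

@lru_cache
def cached_load(specs: tuple[str, ...] | None) -> tuple[str, ...]:
    print("loading", specs)        # executed once per distinct tuple
    return specs or ()

def validate(specs: list[str] | None) -> None:
    key = tuple(specs) if specs is not None else None
    for proc in cached_load(key):
        pass                       # .validate_params(...) would run here

validate(["my.module:MyProc"])
validate(["my.module:MyProc"])     # second call hits the cache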
@ -31,6 +31,7 @@ from vllm.multimodal.inputs import (
    MultiModalSharedField,
    NestedTensors,
)
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.engine import UtilityResult
from vllm.v1.utils import tensor_data

@ -282,7 +283,9 @@ class MsgpackDecoder:
    not thread-safe when encoding tensors / numpy arrays.
    """

    def __init__(self, t: Any | None = None):
    def __init__(self, t: Any | None = None, share_mem: bool = True):
        self.share_mem = share_mem
        self.pin_tensors = is_pin_memory_available()
        args = () if t is None else (t,)
        self.decoder = msgpack.Decoder(
            *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
@ -347,21 +350,30 @@ class MsgpackDecoder:
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        return np.frombuffer(buffer, dtype=dtype).reshape(shape)
        arr = np.frombuffer(buffer, dtype=dtype)
        if not self.share_mem:
            arr = arr.copy()
        return arr.reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        dtype, shape, data = arr
        # Copy from inline representation, to decouple the memory storage
        # of the message from the original buffer. And also make Torch
        # not complain about a readonly memoryview.
        buffer = self.aux_buffers[data] if isinstance(data, int) else bytearray(data)
        is_aux = isinstance(data, int)
        buffer = self.aux_buffers[data] if is_aux else data
        buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer:  # torch.frombuffer doesn't like empty buffers
        if not buffer.nbytes:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Clone ensures tensor is backed by pytorch-owned memory for safe
        # future async CPU->GPU transfer.
        # Pin larger tensors for more efficient CPU->GPU transfer.
        if not is_aux:
            arr = arr.clone()
        elif not self.share_mem:
            arr = arr.pin_memory() if self.pin_tensors else arr.clone()
        # Convert back to proper shape & type
        return arr.view(torch_dtype).view(shape)


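The MsgpackDecoder changes above make the memory behaviour explicit: tensors reconstructed from the inline message payload are always cloned so they never alias the reusable receive buffer, while tensors shipped as out-of-band aux buffers either keep sharing that memory (share_mem=True) or are copied out, using pinned memory when available to speed up later CPU->GPU transfers. A simplified sketch of that decision tree (not the actual class):

import torch

def materialize(buffer: memoryview, *, is_aux: bool, share_mem: bool,
                pin_tensors: bool) -> torch.Tensor:
    arr = torch.frombuffer(buffer, dtype=torch.uint8)
    if not is_aux:
        return arr.clone()              # decouple from the message buffer
    if not share_mem:
        # Copy out of the aux buffer; pin if the platform supports it.
        return arr.pin_memory() if pin_tensors else arr.clone()
    return arr                          # zero-copy view into the aux buffer

buf = memoryview(bytearray(b"\x01\x02\x03\x04"))
print(materialize(buf, is_aux=True, share_mem=True, pin_tensors=False))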
@ -98,7 +98,9 @@ class BlockTable:
            return

        if self.use_hybrid_blocks:
            block_ids = self._map_to_kernel_blocks(np.array(block_ids))
            block_ids = self.map_to_kernel_blocks(
                np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
            )

        num_blocks = len(block_ids)
        start = self.num_blocks_per_row[row_idx]
@ -188,7 +190,12 @@ class BlockTable:
        self.block_table.gpu.fill_(0)
        self.block_table.cpu.fill_(0)

    def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
    @staticmethod
    def map_to_kernel_blocks(
        kv_manager_block_ids: np.ndarray,
        blocks_per_kv_block: int,
        kernel_block_arange: np.ndarray,
    ) -> np.ndarray:
        """Convert kv_manager_block_id IDs to kernel block IDs.

        Example:
@ -203,12 +210,12 @@ class BlockTable:
        # kv_manager_block_id 1 → kernel block id [2, 3]
        # kv_manager_block_id 2 → kernel block id [4, 5]
        """
        if not self.use_hybrid_blocks:
        if blocks_per_kv_block == 1:
            return kv_manager_block_ids

        kernel_block_ids = (
            kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block
            + self._kernel_block_arange
            kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
            + kernel_block_arange
        )

        return kernel_block_ids.reshape(-1)

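Making map_to_kernel_blocks a static method, as above, also makes the mapping easy to exercise in isolation: each KV-manager block expands into blocks_per_kv_block consecutive kernel blocks. A quick usage check mirroring the docstring's example (ratio 2):

import numpy as np

def map_to_kernel_blocks(kv_manager_block_ids: np.ndarray,
                         blocks_per_kv_block: int,
                         kernel_block_arange: np.ndarray) -> np.ndarray:
    if blocks_per_kv_block == 1:
        return kv_manager_block_ids
    return (
        kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
        + kernel_block_arange
    ).reshape(-1)

arange = np.arange(2)  # plays the role of BlockTable._kernel_block_arange
print(map_to_kernel_blocks(np.array([0, 1, 2]), 2, arange))  # [0 1 2 3 4 5]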
@ -279,6 +279,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # This will be overridden in load_model()
        self.is_multimodal_pruning_enabled = False
        self.max_model_len = model_config.max_model_len

        # Always set to false after the first forward pass
        self.calculate_kv_scales = self.cache_config.calculate_kv_scales
        self.dcp_world_size = self.parallel_config.decode_context_parallel_size
        self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
@ -2587,27 +2590,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            )
        )

        dp_rank = self.parallel_config.data_parallel_rank
        if ubatch_slices:
            assert num_tokens_across_dp is not None
            num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
            self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
        elif num_tokens_across_dp is not None:
            num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
        else:
            num_input_tokens = self._get_num_input_tokens(
                scheduler_output.total_num_scheduled_tokens
            )
        dp_rank = self.parallel_config.data_parallel_rank
        if ubatch_slices:
            assert num_tokens_across_dp is not None
            num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
            self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
        elif num_tokens_across_dp is not None:
            num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
        else:
            num_input_tokens = self._get_num_input_tokens(
                scheduler_output.total_num_scheduled_tokens
            )

        (
            input_ids,
            inputs_embeds,
            positions,
            intermediate_tensors,
            model_kwargs,
        ) = self._preprocess(
            scheduler_output, num_input_tokens, intermediate_tensors
        )
        (
            input_ids,
            inputs_embeds,
            positions,
            intermediate_tensors,
            model_kwargs,
            ec_connector_output,
        ) = self._preprocess(
            scheduler_output, num_input_tokens, intermediate_tensors
        )

        uniform_decode = (
            max_num_scheduled_tokens == self.uniform_decode_query_len
@ -2625,16 +2629,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        )

        # Set cudagraph mode to none if calc_kv_scales is true.
        if attn_metadata is not None:
            metadata_list = (
                attn_metadata.values()
                if isinstance(attn_metadata, dict)
                else [attn_metadata]
            )
            if any(
                getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list
            ):
                cudagraph_runtime_mode = CUDAGraphMode.NONE
        # KV scales calculation involves dynamic operations that are incompatible
        # with CUDA graph capture.
        if self.calculate_kv_scales:
            cudagraph_runtime_mode = CUDAGraphMode.NONE
            # Mark KV scales as calculated after the first forward pass
            self.calculate_kv_scales = False

        # Run the model.
        # Use persistent buffers for CUDA graphs.

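The simplification above replaces per-layer attention-metadata inspection with the runner-level calculate_kv_scales flag: KV-scale calculation is a dynamic, one-shot step, so the first forward pass runs eager (CUDAGraphMode.NONE) and the flag is then cleared. A toy model of that control flow (the enum and runner below are stand-ins, not the real classes):

from enum import Enum

class CUDAGraphMode(Enum):
    NONE = 0
    FULL = 1

class RunnerSketch:
    def __init__(self, calculate_kv_scales: bool):
        self.calculate_kv_scales = calculate_kv_scales

    def pick_mode(self, requested: CUDAGraphMode) -> CUDAGraphMode:
        if self.calculate_kv_scales:
            self.calculate_kv_scales = False   # only the first pass is eager
            return CUDAGraphMode.NONE
        return requested

r = RunnerSketch(calculate_kv_scales=True)
print(r.pick_mode(CUDAGraphMode.FULL))  # CUDAGraphMode.NONE on the first pass
print(r.pick_mode(CUDAGraphMode.FULL))  # CUDAGraphMode.FULL afterwards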