[Frontend][torch.compile] CompilationConfig Overhaul (#20283): name change compilation level to compilation mode, deprecation compilation level (#26355)

Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: Morrison Turnansky <mturnans@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Morrison Turnansky
2025-10-14 22:51:16 -04:00
committed by GitHub
parent e66d787bce
commit 96b9aa5aa0
42 changed files with 270 additions and 248 deletions

View File

@ -58,12 +58,12 @@ You can adjust `compilation_config` to achieve a better balance between inferenc
```python
from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel
from vllm.config import CompilationConfig, CompilationMode
llm = LLM(
model="meta-llama/Llama-3.1-8B-Instruct",
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
# By default, it goes up to max_num_seqs
cudagraph_capture_sizes=[1, 2, 4, 8, 16],
),

View File

@ -167,7 +167,7 @@ class AttentionCGSupport(enum.Enum):
"""NO CUDA Graphs support"""
```
Suppose we have hybrid attention backends (e.g., in mamba mixer models). In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation level. For the complete fallback policy, please see the code of [initialize_cudagraph_capture][vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_cudagraph_capture].
Suppose we have hybrid attention backends (e.g., in mamba mixer models). In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation mode. For the complete fallback policy, please see the code of [initialize_cudagraph_capture][vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_cudagraph_capture].
The following table lists backends that support full CUDA Graphs at the time of writing.
@ -202,7 +202,7 @@ os.environ.setdefault("VLLM_LOGGING_LEVEL", "DEBUG")
import vllm
from vllm.config import CUDAGraphMode
compilation_config = {"level": 3, "cudagraph_mode": "FULL_AND_PIECEWISE"}
compilation_config = {"mode": 3, "cudagraph_mode": "FULL_AND_PIECEWISE"}
model = vllm.LLM(
model="meta-llama/Llama-3.1-8B-Instruct",
dtype="auto",

View File

@ -95,7 +95,7 @@ def parse_args():
parser.add_argument(
"--compilation-config",
type=int,
help=("Compilation optimization (O) level 0-3."),
help=("Compilation optimization (O) mode 0-3."),
)
parser.add_argument(
"--quantization",

View File

@ -14,7 +14,7 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
from vllm.config import (
CompilationConfig,
CompilationLevel,
CompilationMode,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
@ -199,10 +199,10 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
outputs = []
# piecewise compile
# vllmcompile compile
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
@ -251,7 +251,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
# no compile or cudagraph
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.NO_COMPILATION,
mode=CompilationMode.NONE,
)
)
cudagraph_runtime_mode = CUDAGraphMode.NONE
@ -280,7 +280,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
# piecewise compile without CUDA graph
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=False,
splitting_ops=["silly::attention"],
use_inductor_graph_partition=use_inductor_graph_partition,

View File

@ -13,7 +13,7 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CompilationConfig,
CompilationLevel,
CompilationMode,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
@ -61,7 +61,7 @@ def _run_simple_model(
):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
use_inductor=use_inductor,
splitting_ops=splitting_ops,

View File

@ -21,7 +21,7 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CompilationConfig,
CompilationLevel,
CompilationMode,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
@ -356,13 +356,13 @@ def test_toy_llama(
)
compile_config_no_compile = CompilationConfig(
level=CompilationLevel.NO_COMPILATION,
level=CompilationMode.NONE,
cudagraph_mode=CUDAGraphMode.NONE,
backend="eager",
)
compile_config_no_split = CompilationConfig(
level=CompilationLevel.PIECEWISE,
level=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
backend=backend,
@ -458,14 +458,14 @@ def benchmark():
for piecewise in [False, True]:
if piecewise:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=cudagraph_sizes,
)
else:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
cudagraph_capture_sizes=cudagraph_sizes,
)

View File

@ -10,7 +10,7 @@ import torch
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CompilationConfig,
CompilationLevel,
CompilationMode,
VllmConfig,
set_current_vllm_config,
)
@ -38,7 +38,7 @@ class CompiledMod(torch.nn.Module):
def make_vllm_config() -> VllmConfig:
return VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
level=CompilationMode.VLLM_COMPILE,
)
)

View File

@ -10,6 +10,7 @@ import vllm.envs as envs
from vllm.compilation.collective_fusion import AsyncTPPass
from vllm.config import (
CompilationConfig,
CompilationMode,
DeviceConfig,
ModelConfig,
PassConfig,
@ -400,7 +401,7 @@ def test_async_tp_pass_correctness(
common_args.append("--enforce-eager")
compilation_config = {
"level": 3,
"mode": CompilationMode.VLLM_COMPILE,
"compile_sizes": [2, 4, 8],
"splitting_ops": [],
"pass_config": {"enable_async_tp": async_tp_enabled},

View File

@ -4,7 +4,7 @@ import dataclasses
import pytest
from vllm.config import CompilationLevel
from vllm.config import CompilationMode
from vllm.utils import cuda_device_count_stateless
from ..utils import compare_all_settings
@ -21,7 +21,7 @@ class TestSetting:
# we cannot afford testing the full Cartesian product
# of all models and all levels
# of all models and all modes
@pytest.mark.parametrize(
"test_setting",
[
@ -121,15 +121,13 @@ def test_compile_correctness(
all_args: list[list[str]] = []
all_envs: list[dict[str, str] | None] = []
for comp_level in [
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
CompilationLevel.PIECEWISE,
for comp_mode in [
CompilationMode.STOCK_TORCH_COMPILE,
CompilationMode.DYNAMO_TRACE_ONCE,
CompilationMode.VLLM_COMPILE,
]:
for level in [CompilationLevel.NO_COMPILATION, comp_level]:
all_args.append(
final_args + [f"-O.level={level}", "-O.backend=inductor"]
)
for mode in [CompilationMode.NONE, comp_mode]:
all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=inductor"])
# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
@ -142,13 +140,13 @@ def test_compile_correctness(
all_envs.clear()
all_args.clear()
for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
CompilationLevel.PIECEWISE,
for mode in [
CompilationMode.NONE,
CompilationMode.STOCK_TORCH_COMPILE,
CompilationMode.DYNAMO_TRACE_ONCE,
CompilationMode.VLLM_COMPILE,
]:
all_args.append(final_args + [f"-O.level={level}", "-O.backend=eager"])
all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=eager"])
all_envs.append({})
all_envs.append({})

View File

@ -4,7 +4,7 @@ import pytest
from vllm.compilation.counter import compilation_counter
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationLevel
from vllm.config.compilation import CompilationMode
from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
@ -90,16 +90,16 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_dynamo_as_is(vllm_runner, monkeypatch):
def test_stock_torch_compile(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with (
compilation_counter.expect(dynamo_as_is_count=1),
compilation_counter.expect(stock_torch_compile_count=1),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
"facebook/opt-125m",
compilation_config={"level": 1},
compilation_config={"mode": CompilationMode.STOCK_TORCH_COMPILE},
gpu_memory_utilization=0.4,
) as _,
):
@ -112,11 +112,11 @@ def test_no_compilation(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with (
compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
"facebook/opt-125m",
compilation_config={"level": 0},
compilation_config={"mode": CompilationMode.NONE},
gpu_memory_utilization=0.4,
) as _,
):
@ -130,7 +130,7 @@ def test_enforce_eager(vllm_runner, monkeypatch):
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with (
compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
"facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
@ -151,7 +151,7 @@ def test_splitting_ops_dynamic():
if is_torch_equal_or_newer("2.9.0.dev"):
config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
level=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
splitting_ops=["vllm::unified_attention"],
)
@ -163,7 +163,7 @@ def test_splitting_ops_dynamic():
# When attn_fusion pass enabled, splitting_ops now default to attention ops.
config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
level=CompilationMode.VLLM_COMPILE,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
@ -178,7 +178,7 @@ def test_splitting_ops_dynamic():
if is_torch_equal_or_newer("2.9.0.dev"):
config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
level=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],

View File

@ -8,7 +8,7 @@ from vllm.compilation.decorators import ignore_torch_compile, support_torch_comp
from vllm.config import (
CacheConfig,
CompilationConfig,
CompilationLevel,
CompilationMode,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
@ -66,10 +66,10 @@ def run_model(
def test_ignore_torch_compile_decorator():
# piecewise
# vllmcompile
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
@ -185,7 +185,7 @@ def test_conditional_compile_enable_if():
kv_sharing_fast_prefill=True,
),
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
@ -218,7 +218,7 @@ def test_conditional_compile_enable_if():
kv_sharing_fast_prefill=False,
),
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],

View File

@ -12,7 +12,7 @@ from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
@ -80,22 +80,22 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None):
@pytest.mark.parametrize(
"optimization_level",
[CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
"compilation_mode",
[CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE],
)
@pytest.mark.parametrize("model_info", models_list(all=True))
@create_new_process_for_each_test()
def test_full_graph(
monkeypatch: pytest.MonkeyPatch,
model_info: tuple[str, dict[str, Any]],
optimization_level: int,
compilation_mode: int,
):
model, model_kwargs = model_info
with monkeypatch.context():
print(f"MODEL={model}")
run_model(optimization_level, model, model_kwargs)
run_model(compilation_mode, model, model_kwargs)
# TODO(luka) add other supported compilation config scenarios here
@ -104,7 +104,7 @@ def test_full_graph(
[
# additional compile sizes, only some of the models
(
CompilationConfig(level=CompilationLevel.PIECEWISE, compile_sizes=[1, 2]),
CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]),
model,
)
for model in models_list(all=False)
@ -113,7 +113,7 @@ def test_full_graph(
# RMSNorm + quant fusion, only 8-bit quant models
(
CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+rms_norm"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
),
@ -125,7 +125,8 @@ def test_full_graph(
# Test depyf integration works
(
CompilationConfig(
level=CompilationLevel.PIECEWISE, debug_dump_path=tempfile.gettempdir()
mode=CompilationMode.VLLM_COMPILE,
debug_dump_path=tempfile.gettempdir(),
),
("facebook/opt-125m", {}),
),
@ -134,7 +135,7 @@ def test_full_graph(
# graph inductor partition
(
CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
# inductor graph partition uses
# torch._C.Tag.cudagraph_unsafe to specify splitting ops
use_inductor_graph_partition=True,
@ -164,10 +165,10 @@ def test_custom_compile_config(
@pytest.mark.parametrize(
"optimization_level",
[CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE],
"compilation_mode",
[CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
)
def test_fp8_kv_scale_compile(optimization_level: int):
def test_fp8_kv_scale_compile(compilation_mode: int):
model = "Qwen/Qwen2-0.5B"
model_kwargs = {
"quantization": "fp8",
@ -175,7 +176,7 @@ def test_fp8_kv_scale_compile(optimization_level: int):
"calculate_kv_scales": True,
"max_model_len": 512,
}
run_model(optimization_level, model, model_kwargs)
run_model(compilation_mode, model, model_kwargs)
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
@ -184,7 +185,7 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm):
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
custom_ops=["+quant_fp8"],

View File

@ -13,7 +13,7 @@ from vllm.compilation.fusion import (
)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@ -114,7 +114,7 @@ def test_fusion_rmsnorm_quant(
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+rms_norm", "+quant_fp8"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
)

View File

@ -12,7 +12,7 @@ from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (
CompilationConfig,
CompilationLevel,
CompilationMode,
DeviceConfig,
ModelConfig,
PassConfig,
@ -219,7 +219,7 @@ def all_reduce_fusion_pass_on_test_model(
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"]
mode=CompilationMode.VLLM_COMPILE, custom_ops=["+rms_norm", "+quant_fp8"]
)
)
vllm_config.compilation_config.pass_config = PassConfig(

View File

@ -19,7 +19,7 @@ from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (
CacheConfig,
CompilationConfig,
CompilationLevel,
CompilationMode,
ModelConfig,
PassConfig,
SchedulerConfig,
@ -321,7 +321,7 @@ def test_attention_quant_pattern(
),
scheduler_config=SchedulerConfig(max_num_seqs=1024),
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+quant_fp8"],
use_inductor_graph_partition=use_inductor_graph_partition,
),

View File

@ -6,7 +6,7 @@ import torch
import vllm
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
from .backend import TestBackend
@ -50,7 +50,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(enable_noop=True),
)
)
@ -98,7 +98,7 @@ def test_non_noop_slice_preserved():
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(enable_noop=True),
)
)

View File

@ -5,7 +5,7 @@
import torch
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationLevel
from vllm.config import CompilationMode
class MyMod(torch.nn.Module):
@ -20,7 +20,7 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
self.model = model
compiled_callable = torch.compile(self.forward, backend="eager")
super().__init__(
compiled_callable, compilation_level=CompilationLevel.DYNAMO_ONCE
compiled_callable, compilation_mode=CompilationMode.DYNAMO_TRACE_ONCE
)
def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):

View File

@ -15,6 +15,7 @@ from typing import Literal, NamedTuple
import pytest
from vllm.config.compilation import CompilationMode
from vllm.config.model import RunnerOption
from vllm.logger import init_logger
@ -234,7 +235,7 @@ def _compare_sp(
common_args.append("--skip-tokenizer-init")
compilation_config = {
"level": 3,
"mode": CompilationMode.VLLM_COMPILE,
"custom_ops": ["+rms_norm"],
"compile_sizes": [4, 8],
"pass_config": {

View File

@ -226,30 +226,30 @@ def test_compilation_config():
# set to O3
args = parser.parse_args(["-O0"])
assert args.compilation_config.level == 0
assert args.compilation_config.mode == 0
# set to O 3 (space)
args = parser.parse_args(["-O", "1"])
assert args.compilation_config.level == 1
assert args.compilation_config.mode == 1
# set to O 3 (equals)
args = parser.parse_args(["-O=2"])
assert args.compilation_config.level == 2
assert args.compilation_config.mode == 2
# set to O.level 3
args = parser.parse_args(["-O.level", "3"])
assert args.compilation_config.level == 3
# set to O.mode 3
args = parser.parse_args(["-O.mode", "3"])
assert args.compilation_config.mode == 3
# set to string form of a dict
args = parser.parse_args(
[
"-O",
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
'{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
'"use_inductor": false}',
]
)
assert (
args.compilation_config.level == 3
args.compilation_config.mode == 3
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
and not args.compilation_config.use_inductor
)
@ -258,12 +258,12 @@ def test_compilation_config():
args = parser.parse_args(
[
"--compilation-config="
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
'{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
'"use_inductor": true}',
]
)
assert (
args.compilation_config.level == 3
args.compilation_config.mode == 3
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
and args.compilation_config.use_inductor
)

View File

@ -3,7 +3,7 @@
import pytest
from vllm.config import CompilationLevel
from vllm.config import CompilationMode
from ..utils import compare_two_settings
@ -21,13 +21,13 @@ def test_custom_dispatcher(monkeypatch: pytest.MonkeyPatch):
"--max-model-len=256",
"--max-num-seqs=32",
"--enforce-eager",
f"-O{CompilationLevel.DYNAMO_ONCE}",
f"-O{CompilationMode.DYNAMO_TRACE_ONCE}",
],
arg2=[
"--max-model-len=256",
"--max-num-seqs=32",
"--enforce-eager",
f"-O{CompilationLevel.DYNAMO_AS_IS}",
f"-O{CompilationMode.STOCK_TORCH_COMPILE}",
],
env1={},
env2={},

View File

@ -299,7 +299,7 @@ def test_dict_args(parser):
"val2",
"--hf-overrides.key2.key4",
"val3",
# Test compile config and compilation level
# Test compile config and compilation mode
"-O.use_inductor=true",
"-O.backend",
"custom",
@ -352,7 +352,7 @@ def test_dict_args(parser):
},
}
assert parsed_args.compilation_config == {
"level": 1,
"mode": 1,
"use_inductor": True,
"backend": "custom",
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
@ -367,7 +367,7 @@ def test_duplicate_dict_args(caplog_vllm, parser):
"--hf-overrides.key1",
"val2",
"-O1",
"-O.level",
"-O.mode",
"2",
"-O3",
]
@ -375,12 +375,12 @@ def test_duplicate_dict_args(caplog_vllm, parser):
parsed_args = parser.parse_args(args)
# Should be the last value
assert parsed_args.hf_overrides == {"key1": "val2"}
assert parsed_args.compilation_config == {"level": 3}
assert parsed_args.compilation_config == {"mode": 3}
assert len(caplog_vllm.records) == 1
assert "duplicate" in caplog_vllm.text
assert "--hf-overrides.key1" in caplog_vllm.text
assert "-O.level" in caplog_vllm.text
assert "-O.mode" in caplog_vllm.text
@pytest.mark.parametrize(

View File

@ -11,7 +11,7 @@ from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
CompilationConfig,
CompilationLevel,
CompilationMode,
CUDAGraphMode,
ParallelConfig,
SchedulerConfig,
@ -42,7 +42,7 @@ def _create_vllm_config(
mock_config.parallel_config = ParallelConfig()
# Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.level == CompilationLevel.PIECEWISE:
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
compilation_config.set_splitting_ops_for_v1()
return mock_config
@ -50,23 +50,23 @@ def _create_vllm_config(
class TestCudagraphDispatcher:
@pytest.mark.parametrize(
"case_id,cudagraph_mode_str,compilation_level",
"case_id,cudagraph_mode_str,compilation_mode",
[
# Test case 0: Full CG for mixed batches, no separate routine
(0, "FULL", CompilationLevel.NO_COMPILATION),
(0, "FULL", CompilationMode.NONE),
# Test case 1: Full CG for uniform batches, piecewise for mixed
(1, "FULL_AND_PIECEWISE", CompilationLevel.NO_COMPILATION),
(1, "FULL_AND_PIECEWISE", CompilationMode.NONE),
# Test case 2: Full CG for uniform batches, no CG for mixed
(2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION),
# Test case 3: Piecewise for all
(3, "PIECEWISE", CompilationLevel.PIECEWISE),
(2, "FULL_DECODE_ONLY", CompilationMode.NONE),
# Test case 3: PIECEWISE for all
(3, "PIECEWISE", CompilationMode.VLLM_COMPILE),
],
)
def test_dispatcher(self, cudagraph_mode_str, compilation_level):
def test_dispatcher(self, cudagraph_mode_str, compilation_mode):
# Setup dispatcher
comp_config = CompilationConfig(
cudagraph_mode=cudagraph_mode_str,
level=compilation_level,
mode=compilation_mode,
cudagraph_capture_sizes=[1, 8],
)
@ -242,7 +242,7 @@ class TestCudagraphIntegration:
def setup_method(self):
# only FULL mode for non-uniform batches
self.comp_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
cudagraph_mode="FULL",
cudagraph_capture_sizes=[10, 20],
)

View File

@ -10,7 +10,7 @@ import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM
from vllm.config import CompilationConfig
from vllm.config import CompilationConfig, CompilationMode
from vllm.platforms import current_platform
@ -73,7 +73,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
gpu_memory_utilization=0.45,
max_model_len=1024,
compilation_config=CompilationConfig(
level=3, cudagraph_mode=cudagraph_mode
mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode
),
)
llm.generate(["Hello, my name is"] * 10)
@ -90,32 +90,27 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
)
# test cudagraph_mode with different compilation level.
# (backend_name, cudagraph_mode, compilation_level, supported)
# test cudagraph_mode with different compilation mode.
# (backend_name, cudagraph_mode, compilation_mode, supported)
combo_cases_2 = [
("FA2", "FULL", 0, True), # no compilation + full cudagraph
("FA2", "FULL", 3, True), # piecewise compilation + full cudagraph
("FA2", "PIECEWISE", 0, False), # no compilation + piecewise cudagraph
("FA2", "PIECEWISE", 3, True), # piecewise compilation + piecewise cudagraph
(
"FA2",
"FULL_AND_PIECEWISE",
0,
False,
), # piecewise cudagraph not supported without piecewise compilation
("FA2", "FULL_AND_PIECEWISE", 3, True),
("FA2", "FULL_DECODE_ONLY", 0, True),
("FA2", "FULL_DECODE_ONLY", 3, True),
("FA2", "NONE", 0, True), # no compilation + no cudagraph
("FA2", "NONE", 3, True), # piecewise compilation + no cudagraph
("FA2", "FULL", CompilationMode.NONE, True),
("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
("FA2", "PIECEWISE", CompilationMode.NONE, False),
("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
("FA2", "NONE", CompilationMode.NONE, True),
("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
]
@pytest.mark.parametrize(
"backend_name,cudagraph_mode,compilation_level,supported", combo_cases_2
"backend_name,cudagraph_mode,compilation_mode,supported", combo_cases_2
)
def test_cudagraph_compilation_combo(combo_case):
backend_name, cudagraph_mode, compilation_level, supported = combo_case
backend_name, cudagraph_mode, compilation_mode, supported = combo_case
env_vars = backend_configs[backend_name].env_vars
@ -130,7 +125,7 @@ def test_cudagraph_compilation_combo(combo_case):
gpu_memory_utilization=0.45,
max_model_len=1024,
compilation_config=CompilationConfig(
level=compilation_level, cudagraph_mode=cudagraph_mode
mode=compilation_mode, cudagraph_mode=cudagraph_mode
),
)
llm.generate(["Hello, my name is"] * 10)

View File

@ -7,7 +7,7 @@ import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.config import CompilationConfig, CompilationMode
from vllm.distributed import cleanup_dist_env_and_memory
from ...utils import fork_new_process_for_each_test
@ -75,9 +75,9 @@ def test_kv_sharing_fast_prefill(
# This allows vLLM compilation backend to handle allocating and
# managing buffers for cudagraph
cudagraph_copy_inputs=True,
level=CompilationLevel.PIECEWISE
mode=CompilationMode.VLLM_COMPILE
if not enforce_eager
else CompilationLevel.NO_COMPILATION,
else CompilationMode.NONE,
)
with monkeypatch.context() as m:

View File

@ -56,7 +56,7 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
return InductorAdaptor()
else:
assert compilation_config.backend == "eager", (
"Custom backends not supported with CompilationLevel.PIECEWISE"
"Custom backends not supported with CompilationMode.VLLM_COMPILE"
)
logger.debug("Using EagerAdaptor")
@ -481,7 +481,7 @@ def set_model_tag(tag: str):
class VllmBackend:
"""The compilation backend for `torch.compile` with vLLM.
It is used for compilation level of `CompilationLevel.PIECEWISE`,
It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
where we customize the compilation.
The major work of this backend is to split the graph into

View File

@ -575,7 +575,7 @@ class InductorAdaptor(CompilerInterface):
Because it is re-entrant, we always set it (even if entering via Dynamo
and the context was already entered). We might want to revisit if it
should be set at a different level of compilation.
should be set at a different mode of compilation.
This is likely a bug in PyTorch: public APIs should not rely on
manually setting up internal contexts. But we also rely on non-public

View File

@ -27,8 +27,8 @@ class CompilationCounter:
num_cache_entries_updated: int = 0
# The number of standalone_compile compiled artifacts saved
num_compiled_artifacts_saved: int = 0
# Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS
dynamo_as_is_count: int = 0
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
stock_torch_compile_count: int = 0
def clone(self) -> "CompilationCounter":
return copy.deepcopy(self)

View File

@ -18,7 +18,7 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import resolve_obj_by_qualname, supports_dynamo
@ -233,11 +233,11 @@ def _support_torch_compile(
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
self.vllm_config = vllm_config
enable_compile = enable_if is None or enable_if(vllm_config)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = (
vllm_config.compilation_config.level
in [CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS]
vllm_config.compilation_config.mode
in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
or not supports_dynamo()
or _should_ignore_torch_compile(self.__class__)
or not enable_compile
@ -247,7 +247,7 @@ def _support_torch_compile(
compilation_counter.num_models_seen += 1
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level
self, compilation_mode=vllm_config.compilation_config.mode
)
cls.__init__ = __init__

View File

@ -3,7 +3,7 @@
import time
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.config import CompilationConfig, CompilationMode, VllmConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
@ -18,7 +18,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
compilation_config: CompilationConfig = vllm_config.compilation_config
path = vllm_config.compile_debug_dump_path()
if compilation_config.level == CompilationLevel.PIECEWISE and path:
if compilation_config.mode == CompilationMode.VLLM_COMPILE and path:
import depyf
path.mkdir(parents=True, exist_ok=True)
@ -29,7 +29,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
def end_monitoring_torch_compile(vllm_config: VllmConfig):
compilation_config: CompilationConfig = vllm_config.compilation_config
if compilation_config.level == CompilationLevel.PIECEWISE:
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
logger.info(
"torch.compile takes %.2f s in total", compilation_config.compilation_time
)

View File

@ -11,7 +11,7 @@ from types import CodeType
import torch
import vllm.envs as envs
from vllm.config import CompilationLevel, CUDAGraphMode, get_current_vllm_config
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
from vllm.logger import init_logger
logger = init_logger(__name__)
@ -31,7 +31,7 @@ class TorchCompileWrapperWithCustomDispatcher:
"""
def __init__(
self, compiled_callable: Callable | None = None, compilation_level: int = 0
self, compiled_callable: Callable | None = None, compilation_mode: int = 0
):
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config
@ -72,7 +72,7 @@ class TorchCompileWrapperWithCustomDispatcher:
# subclasses can use this to switch between the custom dispatcher
# and the default Dynamo guard mechanism.
self.use_custom_dispatcher: bool = (
compilation_level >= CompilationLevel.DYNAMO_ONCE
compilation_mode >= CompilationMode.DYNAMO_TRACE_ONCE
)
def aot_compile(self, *args, **kwargs):
@ -85,7 +85,7 @@ class TorchCompileWrapperWithCustomDispatcher:
return self.compiled_callable.aot_compile((args, kwargs))
def __call__(self, *args, **kwargs):
"""Implement the dispatch logic here, beyond the torch.compile level.
"""Implement the dispatch logic here, beyond the torch.compile mode.
NOTE: this function can have additional arguments beyond the forward
method, for directly dispatching to the compiled code.
"""

View File

@ -4,7 +4,7 @@
from vllm.config.cache import CacheConfig
from vllm.config.compilation import (
CompilationConfig,
CompilationLevel,
CompilationMode,
CUDAGraphMode,
PassConfig,
)
@ -49,7 +49,7 @@ __all__ = [
"CacheConfig",
# From vllm.config.compilation
"CompilationConfig",
"CompilationLevel",
"CompilationMode",
"CUDAGraphMode",
"PassConfig",
# From vllm.config.device

View File

@ -26,12 +26,20 @@ else:
logger = init_logger(__name__)
class CompilationLevel:
# constants for the levels of the compilation process
NO_COMPILATION = 0
DYNAMO_AS_IS = 1
DYNAMO_ONCE = 2
PIECEWISE = 3
class CompilationMode:
"""The compilation approach used for torch.compile-based compilation of the
model."""
NONE = 0
"""No torch.compile compilation is applied, model runs in fully eager pytorch mode.
The model runs as-is."""
STOCK_TORCH_COMPILE = 1
"""The standard `torch.compile` compilation pipeline."""
DYNAMO_TRACE_ONCE = 2
"""Single Dynamo trace through the model, avoiding recompilation."""
VLLM_COMPILE = 3
"""Custom vLLM Inductor-based backend with caching, piecewise compilation,
shape specialization, and custom passes."""
class CUDAGraphMode(enum.Enum):
@ -134,7 +142,7 @@ class CompilationConfig:
"""Configuration for compilation. It has three parts:
- Top-level Compilation control:
- [`level`][vllm.config.CompilationConfig.level]
- [`mode`][vllm.config.CompilationConfig.mode]
- [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path]
- [`cache_dir`][vllm.config.CompilationConfig.cache_dir]
- [`backend`][vllm.config.CompilationConfig.backend]
@ -171,14 +179,26 @@ class CompilationConfig:
# Top-level Compilation control
level: int | None = None
"""The level of compilation:
"""
Level is deprecated and will be removed in the next release,
either 0.12.0 or 0.11.2 whichever is soonest.
Please use mode. Currently all levels are mapped to mode.
"""
# Top-level Compilation control
mode: int | None = None
"""The compilation approach used for torch.compile-based compilation of the
model.
- None: If None, we will select the default compilation level.
For V1 engine this is 3, for V0 engine this is 0.
- 0: no compilation.
- 1: dynamo as is.
- 2: dynamo once.
- 3: piecewise compilation."""
- None: If None, we will select the default compilation mode.
For V1 engine this is 3.
- 0: NONE: No torch.compile compilation is applied, model runs in fully
eager pytorch mode. The model runs as-is.
- 1: STOCK_TORCH_COMPILE: The standard `torch.compile` compilation pipeline.
- 2: DYNAMO_TRACE_ONCE: Single Dynamo trace through the model, avoiding
recompilation by removing guards.
Requires no dynamic-shape-dependent control-flow.
- 3: VLLM_COMPILE: Custom vLLM Inductor-based backend with caching,
piecewise compilation, shape specialization, and custom passes."""
debug_dump_path: Path | None = None
"""The path to dump the debug information."""
cache_dir: str = ""
@ -195,11 +215,11 @@ class CompilationConfig:
backend function.
We use string to avoid serialization issues when using compilation in a
distributed setting. When the compilation level is 1 or 2, the backend is
distributed setting. When the compilation mode is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation level is 3, the backend is used for the piecewise compilation
compilation mode is 3, the backend is used for the piecewise compilation
(it sees a part of the graph). The backend can not be custom for compilation
level 3, i.e. the backend must be either eager or inductor. Furthermore,
mode 3, i.e. the backend must be either eager or inductor. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and
use_inductor_graph_partition is off. Note that the default options for
splitting ops are sufficient for piecewise compilation.
@ -214,7 +234,7 @@ class CompilationConfig:
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor and
disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
disabled when running with Inductor: mode>=VLLM_COMPILE and use_inductor=True.
Inductor generates (fused) Triton kernels for disabled custom ops."""
splitting_ops: list[str] | None = None
"""A list of ops to exclude from cudagraphs, used in piecewise compilation.
@ -249,7 +269,7 @@ class CompilationConfig:
One graph for symbolic shape and one graph per size in compile_sizes
are compiled using configurations in inductor_compile_config.
This setting is ignored if level<PIECEWISE.
This setting is ignored if mode<VLLM_COMPILE.
For future compatibility:
If use_inductor is True, backend="inductor" otherwise backend="eager".
@ -299,7 +319,7 @@ class CompilationConfig:
Currently, the cudagraph mode is only used for the v1 engine.
Note that the cudagraph logic is generally orthogonal to the
compilation logic. While piecewise cudagraphs require piecewise
compilation (level=PIECEWISE and non-empty splitting_ops), full
compilation (mode=VLLM_COMPILE and non-empty splitting_ops), full
cudagraphs are supported with and without compilation.
Warning: This flag is new and subject to change in addition
@ -312,7 +332,7 @@ class CompilationConfig:
that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
In the vLLM V1 Engine, this flag only applies for
CompilationLevel.PIECEWISE (aka -O3).
CompilationMode.VLLM_COMPILE (aka -O3).
Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
Warning: This flag is deprecated and will be removed in the next major or
@ -426,7 +446,7 @@ class CompilationConfig:
the final hidden states.
"""
factors: list[Any] = []
factors.append(self.level)
factors.append(self.mode)
factors.append(self.backend)
factors.append(self.custom_ops)
factors.append(self.splitting_ops)
@ -477,6 +497,17 @@ class CompilationConfig:
return value
def __post_init__(self) -> None:
if self.level is not None:
logger.warning(
"Level is deprecated and will be removed in the next release,"
"either 0.12.0 or 0.11.2 whichever is soonest."
"Use mode instead."
"If both level and mode are given,"
"only mode will be used."
)
if self.mode is None:
self.mode = self.level
count_none = self.custom_ops.count("none")
count_all = self.custom_ops.count("all")
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
@ -574,7 +605,7 @@ class CompilationConfig:
# Currently only eager and inductor backend are supported.
# for piecewise compilation. Custom backends are not suppported for
# piecewise compilation. Update when more backends are supported.
if self.level == CompilationLevel.PIECEWISE and self.backend not in [
if self.mode == CompilationMode.VLLM_COMPILE and self.backend not in [
"",
"eager",
"inductor",
@ -602,24 +633,27 @@ class CompilationConfig:
Returns:
The backend for the compilation config.
"""
if self.level is None:
if self.mode is None:
raise ValueError(
"No compilation level is set. This method should only be \
"No compilation mode is set. This method should only be \
called via vllm config where the level is set if none is \
provided."
)
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
if self.mode == CompilationMode.NONE:
raise ValueError("No compilation mode is set.")
from torch._dynamo.backends.registry import list_backends
torch_backends = list_backends(exclude_tags=tuple())
if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
if self.mode in [
CompilationMode.STOCK_TORCH_COMPILE,
CompilationMode.DYNAMO_TRACE_ONCE,
]:
if self.backend in torch_backends:
return self.backend
return resolve_obj_by_qualname(self.backend)
assert self.level == CompilationLevel.PIECEWISE
assert self.mode == CompilationMode.VLLM_COMPILE
if self.backend not in ["eager", "inductor"]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
@ -684,11 +718,11 @@ class CompilationConfig:
self.bs_to_padded_graph_size[self.max_capture_size] = self.max_capture_size
def set_splitting_ops_for_v1(self):
# NOTE: this function needs to be called only when level is
# CompilationLevel.PIECEWISE
assert self.level == CompilationLevel.PIECEWISE, (
# NOTE: this function needs to be called only when mode is
# CompilationMode.VLLM_COMPILE
assert self.mode == CompilationMode.VLLM_COMPILE, (
"set_splitting_ops_for_v1 should only be called when "
"level is CompilationLevel.PIECEWISE"
"mode is CompilationMode.VLLM_COMPILE"
)
if self.use_inductor_graph_partition:
@ -769,12 +803,10 @@ class CompilationConfig:
if not self.use_inductor_graph_partition:
# Dynamo-level FX split case
return self.level == CompilationLevel.PIECEWISE
return self.mode == CompilationMode.VLLM_COMPILE
# Inductor partition case
return (
self.backend == "inductor" and self.level > CompilationLevel.NO_COMPILATION
)
return self.backend == "inductor" and self.mode > CompilationMode.NONE
def custom_op_log_check(self):
"""

View File

@ -22,7 +22,7 @@ from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationLevel, CUDAGraphMode
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
from .device import DeviceConfig
from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
@ -84,17 +84,11 @@ class VllmConfig:
compilation_config: CompilationConfig = Field(default_factory=CompilationConfig)
"""`torch.compile` and cudagraph capture configuration for the model.
As a shorthand, `-O<n>` can be used to directly specify the compilation
level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`).
Currently, -O <n> and -O=<n> are supported as well but this will likely be
removed in favor of clearer -O<n> syntax in the future.
NOTE: level 0 is the default level without any optimization. level 1 and 2
are for internal testing only. level 3 is the recommended level for
production, also default in V1.
As a shorthand, one can append compilation arguments via
-0.parameter=arguement such as `-O.mode=3` (same as `-O='{"mode":3}'`).
You can specify the full compilation config like so:
`{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
`{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
"""
kv_transfer_config: KVTransferConfig | None = None
"""The configurations for distributed KV cache transfer."""
@ -305,33 +299,33 @@ class VllmConfig:
"precision for chunked prefill triton kernels."
)
# If the user does not explicitly set a compilation level, then
# we use the default level. The default level depends on other
# If the user does not explicitly set a compilation mode, then
# we use the default mode. The default mode depends on other
# settings (see the below code).
if self.compilation_config.level is None:
if self.compilation_config.mode is None:
if envs.VLLM_USE_V1:
if (
self.model_config is not None
and not self.model_config.enforce_eager
):
self.compilation_config.level = CompilationLevel.PIECEWISE
self.compilation_config.mode = CompilationMode.VLLM_COMPILE
else:
self.compilation_config.level = CompilationLevel.NO_COMPILATION
self.compilation_config.mode = CompilationMode.NONE
else:
# NB: Passing both --enforce-eager and a compilation level
# in V0 means the compilation level wins out.
self.compilation_config.level = CompilationLevel.NO_COMPILATION
# NB: Passing both --enforce-eager and a compilation mode
# in V0 means the compilation mode wins out.
self.compilation_config.mode = CompilationMode.NONE
else:
assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION
assert self.compilation_config.level <= CompilationLevel.PIECEWISE
assert self.compilation_config.mode >= CompilationMode.NONE
assert self.compilation_config.mode <= CompilationMode.VLLM_COMPILE
# If user does not set custom ops via none or all set it here based on
# compilation level and backend.
# compilation mode and backend.
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
if (
self.compilation_config.backend == "inductor"
and self.compilation_config.level > CompilationLevel.NO_COMPILATION
and self.compilation_config.mode > CompilationMode.NONE
):
self.compilation_config.custom_ops.append("none")
else:
@ -350,7 +344,7 @@ class VllmConfig:
if self.compilation_config.cudagraph_mode is None:
if (
envs.VLLM_USE_V1
and self.compilation_config.level == CompilationLevel.PIECEWISE
and self.compilation_config.mode == CompilationMode.VLLM_COMPILE
):
# default to full and piecewise for most models
self.compilation_config.cudagraph_mode = (
@ -486,10 +480,10 @@ class VllmConfig:
)
current_platform.check_and_update_config(self)
# Do this after all the updates to compilation_config.level
# Do this after all the updates to compilation_config.mode
if (
envs.VLLM_USE_V1
and self.compilation_config.level == CompilationLevel.PIECEWISE
and self.compilation_config.mode == CompilationMode.VLLM_COMPILE
):
self.compilation_config.set_splitting_ops_for_v1()
@ -508,8 +502,8 @@ class VllmConfig:
)
if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
assert self.compilation_config.level == CompilationLevel.PIECEWISE, (
"Compilation level should be CompilationLevel.PIECEWISE "
assert self.compilation_config.mode == CompilationMode.VLLM_COMPILE, (
"Compilation mode should be CompilationMode.VLLM_COMPILE "
"when cudagraph_mode piecewise cudagraphs is used, "
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
)
@ -837,7 +831,7 @@ def set_current_vllm_config(
if (
check_compile
and vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
and vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
and compilation_counter.num_models_seen == num_models_seen
):
# If the model supports compilation,

View File

@ -176,7 +176,7 @@ class LLM:
argument is deprecated and will be removed in v0.12.0 or v1.0.0,
whichever is sooner.
compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the level of compilation optimization. If it
integer, it is used as the mode of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration.
**kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
@ -257,9 +257,7 @@ class LLM:
if compilation_config is not None:
if isinstance(compilation_config, int):
compilation_config_instance = CompilationConfig(
level=compilation_config
)
compilation_config_instance = CompilationConfig(mode=compilation_config)
elif isinstance(compilation_config, dict):
compilation_config_instance = CompilationConfig(
**{

View File

@ -8,7 +8,7 @@ from packaging import version
from vllm import _custom_ops as ops
from vllm import envs
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
@ -419,7 +419,7 @@ class Fp8LinearOp:
if pad_output is None:
config = get_current_vllm_config().compilation_config
pad_output = (
config.level < CompilationLevel.PIECEWISE
config.mode < CompilationMode.VLLM_COMPILE
and self.preferred_backend == "torch"
)

View File

@ -247,12 +247,12 @@ class CpuPlatform(Platform):
parallel_config.enable_dbo = False
# Note: workaround for v1 gpu_model_runner
from vllm.config import CompilationLevel
from vllm.config import CompilationMode
vllm_config.compilation_config.cudagraph_capture_sizes = []
compilation_config = vllm_config.compilation_config
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE:
if vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
# Note: vLLM V1 is using PIECEWISE level compilation, which will
# take time to compile kernels just-in-time with the inductor
# backend. For CPU CI tests, most of them are executed fast and
@ -265,7 +265,7 @@ class CpuPlatform(Platform):
else:
backend = "inductor"
compilation_config.level = CompilationLevel.DYNAMO_ONCE
compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
compilation_config.backend = backend
compilation_config.inductor_compile_config.update(
{
@ -277,7 +277,7 @@ class CpuPlatform(Platform):
)
if vllm_config.lora_config is not None:
compilation_config.level = CompilationLevel.NO_COMPILATION
compilation_config.mode = CompilationMode.NONE
assert vllm_config.device_config.device_type == "cpu"

View File

@ -114,7 +114,7 @@ class TpuPlatform(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
from vllm.config import CompilationLevel, CUDAGraphMode
from vllm.config import CompilationMode, CUDAGraphMode
cache_config = vllm_config.cache_config
# For v0, the default block size is 16.
@ -122,12 +122,13 @@ class TpuPlatform(Platform):
cache_config.block_size = cast(BlockSize, 16)
compilation_config = vllm_config.compilation_config
# TPU only supports DYNAMO_ONCE compilation level
if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
# TPU only supports DYNAMO_TRACE_ONCE compilation mode
if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE:
logger.info(
"[TPU] Forcing DYNAMO_ONCE compilation level, and disabling cudagraph."
"[TPU] Forcing DYNAMO_TRACE_ONCE compilation mode, and\
disabling cudagraph."
)
compilation_config.level = CompilationLevel.DYNAMO_ONCE
compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
if (
compilation_config.cudagraph_mode is None

View File

@ -144,7 +144,7 @@ class XPUPlatform(Platform):
cache_config.block_size = 64
# lazy import to avoid circular import
from vllm.config import CompilationLevel, CUDAGraphMode
from vllm.config import CompilationMode, CUDAGraphMode
compilation_config = vllm_config.compilation_config
if compilation_config.compile_sizes is None:
@ -155,7 +155,7 @@ class XPUPlatform(Platform):
)
if vllm_config.lora_config is not None:
compilation_config.level = CompilationLevel.NO_COMPILATION
compilation_config.mode = CompilationMode.NONE
# check and update parallel config
parallel_config = vllm_config.parallel_config

View File

@ -1686,16 +1686,16 @@ class FlexibleArgumentParser(ArgumentParser):
elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
# allow -O flag to be used without space, e.g. -O3 or -Odecode
# -O.<...> handled later
# also handle -O=<level> here
level = arg[3:] if arg[2] == "=" else arg[2:]
processed_args.append(f"-O.level={level}")
# also handle -O=<mode> here
mode = arg[3:] if arg[2] == "=" else arg[2:]
processed_args.append(f"-O.mode={mode}")
elif (
arg == "-O"
and i + 1 < len(args)
and args[i + 1] in {"0", "1", "2", "3"}
):
# Convert -O <n> to -O.level <n>
processed_args.append("-O.level")
# Convert -O <n> to -O.mode <n>
processed_args.append("-O.mode")
else:
processed_args.append(arg)

View File

@ -43,12 +43,12 @@ class CudagraphDispatcher:
not_use_piecewise_compilation
or self.compilation_config.is_attention_compiled_piecewise()
), (
"Compilation level should be CompilationLevel.PIECEWISE when "
"Compilation mode should be CompilationMode.VLLM_COMPILE when "
"cudagraph_mode piecewise cudagraphs is used, "
"and attention should be in splitting_ops or "
"inductor splitting should be used. "
f"cudagraph_mode={self.cudagraph_mode}, "
f"compilation_level={self.compilation_config.level}, "
f"compilation_mode={self.compilation_config.mode}, "
f"splitting_ops={self.compilation_config.splitting_ops}"
)

View File

@ -9,7 +9,7 @@ import torch
import torch.nn as nn
from vllm.config import (
CompilationLevel,
CompilationMode,
CUDAGraphMode,
VllmConfig,
get_layers_from_vllm_config,
@ -86,7 +86,7 @@ class EagleProposer:
self.use_cuda_graph = False
compilation_config = self.vllm_config.compilation_config
if compilation_config.level == CompilationLevel.PIECEWISE:
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
cudagraph_mode = compilation_config.cudagraph_mode
if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
CUDAGraphMode.PIECEWISE

View File

@ -25,7 +25,7 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
CompilationLevel,
CompilationMode,
CUDAGraphMode,
VllmConfig,
get_layers_from_vllm_config,
@ -2927,14 +2927,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
if (
self.vllm_config.compilation_config.level == CompilationLevel.DYNAMO_AS_IS
self.vllm_config.compilation_config.mode
== CompilationMode.STOCK_TORCH_COMPILE
and supports_dynamo()
):
backend = self.vllm_config.compilation_config.init_backend(self.vllm_config)
compilation_counter.dynamo_as_is_count += 1
compilation_counter.stock_torch_compile_count += 1
self.model.compile(fullgraph=True, backend=backend)
return
# for other compilation levels, cudagraph behavior is controlled by
# for other compilation modes, cudagraph behavior is controlled by
# CudagraphWraper and CudagraphDispatcher of vllm.
# wrap the model with full cudagraph wrapper if needed.
@ -3985,7 +3986,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# if not supported any full cudagraphs, just raise it.
msg += (
"; please try cudagraph_mode=PIECEWISE, and "
"make sure compilation level is piecewise"
"make sure compilation mode is VLLM_COMPILE"
)
raise ValueError(msg)
@ -4012,7 +4013,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
f"with {min_cg_builder_name} backend (support: "
f"{min_cg_support})"
)
if self.compilation_config.level == CompilationLevel.PIECEWISE and (
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE and (
self.compilation_config.splitting_ops_contain_attention()
or self.compilation_config.use_inductor_graph_partition
):
@ -4068,7 +4069,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
f"supported with {min_cg_builder_name} backend ("
f"support:{min_cg_support}) "
"; please try cudagraph_mode=PIECEWISE, "
"and make sure compilation level is piecewise"
"and make sure compilation mode is VLLM_COMPILE"
)
# Trigger cudagraph dispatching keys initialization here (after