mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00

[Frontend][torch.compile] CompilationConfig Overhaul (#20283): rename compilation level to compilation mode, deprecate compilation level (#26355)

Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: Morrison Turnansky <mturnans@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>

Committed via GitHub · parent e66d787bce · commit 96b9aa5aa0
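
At a glance, the user-facing change is a field and class rename with a deprecation shim. A minimal before/after sketch, assembled from the diffs below (the deprecated `level` field still works but now logs a warning and is mapped onto `mode`):

```python
from vllm import LLM
from vllm.config import CompilationConfig, CompilationMode

# Before this commit:
#   compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE)
# After this commit:
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    compilation_config=CompilationConfig(mode=CompilationMode.VLLM_COMPILE),
)
```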
@@ -58,12 +58,12 @@ You can adjust `compilation_config` to achieve a better balance between inferenc
 ```python
 from vllm import LLM
-from vllm.config import CompilationConfig, CompilationLevel
+from vllm.config import CompilationConfig, CompilationMode

 llm = LLM(
     model="meta-llama/Llama-3.1-8B-Instruct",
     compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
+        mode=CompilationMode.VLLM_COMPILE,
         # By default, it goes up to max_num_seqs
         cudagraph_capture_sizes=[1, 2, 4, 8, 16],
     ),
@@ -167,7 +167,7 @@ class AttentionCGSupport(enum.Enum):
     """NO CUDA Graphs support"""
 ```

-Suppose we have hybrid attention backends (e.g., in mamba mixer models). In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation level. For the complete fallback policy, please see the code of [initialize_cudagraph_capture][vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_cudagraph_capture].
+Suppose we have hybrid attention backends (e.g., in mamba mixer models). In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation mode. For the complete fallback policy, please see the code of [initialize_cudagraph_capture][vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_cudagraph_capture].

 The following table lists backends that support full CUDA Graphs at the time of writing.
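
As an illustration only (not code this diff touches), the downgrade policy described above can be sketched like this; the authoritative logic lives in `initialize_cudagraph_capture`:

```python
# Illustrative sketch of the documented fallback, for -O3 (VLLM_COMPILE).
# Names mirror the prose above, not the actual vLLM implementation.
def resolve_cudagraph_mode(requested: str, min_capability: str) -> str:
    if requested == "FULL":
        if min_capability == "UNIFORM_BATCH":
            return "FULL_AND_PIECEWISE"  # full graphs only for uniform batches
        if min_capability == "NEVER":
            return "PIECEWISE"  # backends offer no full-graph support
    return requested
```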
@@ -202,7 +202,7 @@ os.environ.setdefault("VLLM_LOGGING_LEVEL", "DEBUG")
 import vllm
 from vllm.config import CUDAGraphMode

-compilation_config = {"level": 3, "cudagraph_mode": "FULL_AND_PIECEWISE"}
+compilation_config = {"mode": 3, "cudagraph_mode": "FULL_AND_PIECEWISE"}
 model = vllm.LLM(
     model="meta-llama/Llama-3.1-8B-Instruct",
     dtype="auto",
@@ -95,7 +95,7 @@ def parse_args():
     parser.add_argument(
         "--compilation-config",
         type=int,
-        help=("Compilation optimization (O) level 0-3."),
+        help=("Compilation optimization (O) mode 0-3."),
     )
     parser.add_argument(
         "--quantization",
@@ -14,7 +14,7 @@ from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
 from vllm.config import (
     CompilationConfig,
-    CompilationLevel,
+    CompilationMode,
     CUDAGraphMode,
     VllmConfig,
     set_current_vllm_config,
@@ -199,10 +199,10 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):

     outputs = []

-    # piecewise compile
+    # vllmcompile compile
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
+            mode=CompilationMode.VLLM_COMPILE,
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],

@@ -251,7 +251,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
     # no compile or cudagraph
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            level=CompilationLevel.NO_COMPILATION,
+            mode=CompilationMode.NONE,
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.NONE

@@ -280,7 +280,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
     # piecewise compile without CUDA graph
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
+            mode=CompilationMode.VLLM_COMPILE,
             use_cudagraph=False,
             splitting_ops=["silly::attention"],
             use_inductor_graph_partition=use_inductor_graph_partition,
@@ -13,7 +13,7 @@ from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (
     CompilationConfig,
-    CompilationLevel,
+    CompilationMode,
     CUDAGraphMode,
     VllmConfig,
     set_current_vllm_config,

@@ -61,7 +61,7 @@ def _run_simple_model(
 ):
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
+            mode=CompilationMode.VLLM_COMPILE,
             use_cudagraph=True,
             use_inductor=use_inductor,
             splitting_ops=splitting_ops,
@@ -21,7 +21,7 @@ from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (
     CompilationConfig,
-    CompilationLevel,
+    CompilationMode,
     CUDAGraphMode,
     VllmConfig,
     set_current_vllm_config,

@@ -356,13 +356,13 @@ def test_toy_llama(
     )

     compile_config_no_compile = CompilationConfig(
-        level=CompilationLevel.NO_COMPILATION,
+        level=CompilationMode.NONE,
         cudagraph_mode=CUDAGraphMode.NONE,
         backend="eager",
     )

     compile_config_no_split = CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
+        level=CompilationMode.VLLM_COMPILE,
         use_inductor_graph_partition=use_inductor_graph_partition,
         cudagraph_mode=CUDAGraphMode.PIECEWISE,
         backend=backend,

@@ -458,14 +458,14 @@ def benchmark():
     for piecewise in [False, True]:
         if piecewise:
             compilation_config = CompilationConfig(
-                level=CompilationLevel.PIECEWISE,
+                mode=CompilationMode.VLLM_COMPILE,
                 use_cudagraph=True,
                 splitting_ops=["silly::attention"],
                 cudagraph_capture_sizes=cudagraph_sizes,
             )
         else:
             compilation_config = CompilationConfig(
-                level=CompilationLevel.PIECEWISE,
+                mode=CompilationMode.VLLM_COMPILE,
                 cudagraph_capture_sizes=cudagraph_sizes,
             )
@@ -10,7 +10,7 @@ import torch
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (
     CompilationConfig,
-    CompilationLevel,
+    CompilationMode,
     VllmConfig,
     set_current_vllm_config,
 )

@@ -38,7 +38,7 @@ class CompiledMod(torch.nn.Module):
 def make_vllm_config() -> VllmConfig:
     return VllmConfig(
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
+            level=CompilationMode.VLLM_COMPILE,
         )
     )
@@ -10,6 +10,7 @@ import vllm.envs as envs
 from vllm.compilation.collective_fusion import AsyncTPPass
 from vllm.config import (
     CompilationConfig,
+    CompilationMode,
     DeviceConfig,
     ModelConfig,
     PassConfig,

@@ -400,7 +401,7 @@ def test_async_tp_pass_correctness(
         common_args.append("--enforce-eager")

     compilation_config = {
-        "level": 3,
+        "mode": CompilationMode.VLLM_COMPILE,
         "compile_sizes": [2, 4, 8],
         "splitting_ops": [],
         "pass_config": {"enable_async_tp": async_tp_enabled},
@@ -4,7 +4,7 @@ import dataclasses

 import pytest

-from vllm.config import CompilationLevel
+from vllm.config import CompilationMode
 from vllm.utils import cuda_device_count_stateless

 from ..utils import compare_all_settings

@@ -21,7 +21,7 @@ class TestSetting:

 # we cannot afford testing the full Cartesian product
-# of all models and all levels
+# of all models and all modes
 @pytest.mark.parametrize(
     "test_setting",
     [
@@ -121,15 +121,13 @@ def test_compile_correctness(
     all_args: list[list[str]] = []
     all_envs: list[dict[str, str] | None] = []

-    for comp_level in [
-        CompilationLevel.DYNAMO_AS_IS,
-        CompilationLevel.DYNAMO_ONCE,
-        CompilationLevel.PIECEWISE,
+    for comp_mode in [
+        CompilationMode.STOCK_TORCH_COMPILE,
+        CompilationMode.DYNAMO_TRACE_ONCE,
+        CompilationMode.VLLM_COMPILE,
     ]:
-        for level in [CompilationLevel.NO_COMPILATION, comp_level]:
-            all_args.append(
-                final_args + [f"-O.level={level}", "-O.backend=inductor"]
-            )
+        for mode in [CompilationMode.NONE, comp_mode]:
+            all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=inductor"])

     # inductor will change the output, so we only compare if the output
     # is close, not exactly the same.

@@ -142,13 +140,13 @@ def test_compile_correctness(
     all_envs.clear()
     all_args.clear()

-    for level in [
-        CompilationLevel.NO_COMPILATION,
-        CompilationLevel.DYNAMO_AS_IS,
-        CompilationLevel.DYNAMO_ONCE,
-        CompilationLevel.PIECEWISE,
+    for mode in [
+        CompilationMode.NONE,
+        CompilationMode.STOCK_TORCH_COMPILE,
+        CompilationMode.DYNAMO_TRACE_ONCE,
+        CompilationMode.VLLM_COMPILE,
     ]:
-        all_args.append(final_args + [f"-O.level={level}", "-O.backend=eager"])
+        all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=eager"])
         all_envs.append({})
         all_envs.append({})
@@ -4,7 +4,7 @@ import pytest

 from vllm.compilation.counter import compilation_counter
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
-from vllm.config.compilation import CompilationLevel
+from vllm.config.compilation import CompilationMode
 from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
@@ -90,16 +90,16 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):

 # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
 @pytest.mark.forked
-def test_dynamo_as_is(vllm_runner, monkeypatch):
+def test_stock_torch_compile(vllm_runner, monkeypatch):
     # Disable multiprocessing so that the counter is in the same process
     monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

     with (
-        compilation_counter.expect(dynamo_as_is_count=1),
+        compilation_counter.expect(stock_torch_compile_count=1),
         # loading the model causes compilation (if enabled) to happen
         vllm_runner(
             "facebook/opt-125m",
-            compilation_config={"level": 1},
+            compilation_config={"mode": CompilationMode.STOCK_TORCH_COMPILE},
             gpu_memory_utilization=0.4,
         ) as _,
     ):

@@ -112,11 +112,11 @@ def test_no_compilation(vllm_runner, monkeypatch):
     # Disable multiprocessing so that the counter is in the same process
     monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
     with (
-        compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
+        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
         # loading the model causes compilation (if enabled) to happen
         vllm_runner(
             "facebook/opt-125m",
-            compilation_config={"level": 0},
+            compilation_config={"mode": CompilationMode.NONE},
             gpu_memory_utilization=0.4,
         ) as _,
     ):

@@ -130,7 +130,7 @@ def test_enforce_eager(vllm_runner, monkeypatch):
     monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

     with (
-        compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
+        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
         # loading the model causes compilation (if enabled) to happen
         vllm_runner(
             "facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
@@ -151,7 +151,7 @@ def test_splitting_ops_dynamic():
     if is_torch_equal_or_newer("2.9.0.dev"):
         config = VllmConfig(
             compilation_config=CompilationConfig(
-                level=CompilationLevel.PIECEWISE,
+                level=CompilationMode.VLLM_COMPILE,
                 use_inductor_graph_partition=True,
                 splitting_ops=["vllm::unified_attention"],
             )

@@ -163,7 +163,7 @@ def test_splitting_ops_dynamic():
     # When attn_fusion pass enabled, splitting_ops now default to attention ops.
     config = VllmConfig(
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
+            level=CompilationMode.VLLM_COMPILE,
             pass_config={"enable_attn_fusion": True, "enable_noop": True},
             custom_ops=["+quant_fp8"],
             cudagraph_mode=CUDAGraphMode.PIECEWISE,

@@ -178,7 +178,7 @@ def test_splitting_ops_dynamic():
     if is_torch_equal_or_newer("2.9.0.dev"):
         config = VllmConfig(
             compilation_config=CompilationConfig(
-                level=CompilationLevel.PIECEWISE,
+                level=CompilationMode.VLLM_COMPILE,
                 use_inductor_graph_partition=True,
                 pass_config={"enable_attn_fusion": True, "enable_noop": True},
                 custom_ops=["+quant_fp8"],
@@ -8,7 +8,7 @@ from vllm.compilation.decorators import ignore_torch_compile, support_torch_comp
 from vllm.config import (
     CacheConfig,
     CompilationConfig,
-    CompilationLevel,
+    CompilationMode,
     CUDAGraphMode,
     VllmConfig,
     set_current_vllm_config,

@@ -66,10 +66,10 @@ def run_model(


 def test_ignore_torch_compile_decorator():
-    # piecewise
+    # vllmcompile
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
+            mode=CompilationMode.VLLM_COMPILE,
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],

@@ -185,7 +185,7 @@ def test_conditional_compile_enable_if():
             kv_sharing_fast_prefill=True,
         ),
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
+            mode=CompilationMode.VLLM_COMPILE,
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],

@@ -218,7 +218,7 @@ def test_conditional_compile_enable_if():
             kv_sharing_fast_prefill=False,
         ),
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
+            mode=CompilationMode.VLLM_COMPILE,
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
|
@ -12,7 +12,7 @@ from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig
|
||||
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
|
||||
@ -80,22 +80,22 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"optimization_level",
|
||||
[CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
|
||||
"compilation_mode",
|
||||
[CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE],
|
||||
)
|
||||
@pytest.mark.parametrize("model_info", models_list(all=True))
|
||||
@create_new_process_for_each_test()
|
||||
def test_full_graph(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
model_info: tuple[str, dict[str, Any]],
|
||||
optimization_level: int,
|
||||
compilation_mode: int,
|
||||
):
|
||||
model, model_kwargs = model_info
|
||||
|
||||
with monkeypatch.context():
|
||||
print(f"MODEL={model}")
|
||||
|
||||
run_model(optimization_level, model, model_kwargs)
|
||||
run_model(compilation_mode, model, model_kwargs)
|
||||
|
||||
|
||||
# TODO(luka) add other supported compilation config scenarios here
|
||||
@@ -104,7 +104,7 @@ def test_full_graph(
     [
         # additional compile sizes, only some of the models
         (
-            CompilationConfig(level=CompilationLevel.PIECEWISE, compile_sizes=[1, 2]),
+            CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]),
             model,
         )
         for model in models_list(all=False)

@@ -113,7 +113,7 @@ def test_full_graph(
         # RMSNorm + quant fusion, only 8-bit quant models
         (
             CompilationConfig(
-                level=CompilationLevel.PIECEWISE,
+                mode=CompilationMode.VLLM_COMPILE,
                 custom_ops=["+rms_norm"],
                 pass_config=PassConfig(enable_fusion=True, enable_noop=True),
             ),

@@ -125,7 +125,8 @@ def test_full_graph(
         # Test depyf integration works
         (
             CompilationConfig(
-                level=CompilationLevel.PIECEWISE, debug_dump_path=tempfile.gettempdir()
+                mode=CompilationMode.VLLM_COMPILE,
+                debug_dump_path=tempfile.gettempdir(),
             ),
             ("facebook/opt-125m", {}),
         ),

@@ -134,7 +135,7 @@ def test_full_graph(
         # graph inductor partition
         (
             CompilationConfig(
-                level=CompilationLevel.PIECEWISE,
+                mode=CompilationMode.VLLM_COMPILE,
                 # inductor graph partition uses
                 # torch._C.Tag.cudagraph_unsafe to specify splitting ops
                 use_inductor_graph_partition=True,
@@ -164,10 +165,10 @@ def test_custom_compile_config(


 @pytest.mark.parametrize(
-    "optimization_level",
-    [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE],
+    "compilation_mode",
+    [CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
 )
-def test_fp8_kv_scale_compile(optimization_level: int):
+def test_fp8_kv_scale_compile(compilation_mode: int):
     model = "Qwen/Qwen2-0.5B"
     model_kwargs = {
         "quantization": "fp8",

@@ -175,7 +176,7 @@ def test_fp8_kv_scale_compile(optimization_level: int):
         "calculate_kv_scales": True,
         "max_model_len": 512,
     }
-    run_model(optimization_level, model, model_kwargs)
+    run_model(compilation_mode, model, model_kwargs)


 def test_inductor_graph_partition_attn_fusion(caplog_vllm):

@@ -184,7 +185,7 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm):

     model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
     compilation_config = CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
+        mode=CompilationMode.VLLM_COMPILE,
         use_inductor_graph_partition=True,
         cudagraph_mode=CUDAGraphMode.PIECEWISE,
         custom_ops=["+quant_fp8"],
@@ -13,7 +13,7 @@ from vllm.compilation.fusion import (
 )
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
-from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
+from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,

@@ -114,7 +114,7 @@ def test_fusion_rmsnorm_quant(

     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
+            mode=CompilationMode.VLLM_COMPILE,
             custom_ops=["+rms_norm", "+quant_fp8"],
             pass_config=PassConfig(enable_fusion=True, enable_noop=True),
         )
@@ -12,7 +12,7 @@ from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import (
     CompilationConfig,
-    CompilationLevel,
+    CompilationMode,
     DeviceConfig,
     ModelConfig,
     PassConfig,

@@ -219,7 +219,7 @@ def all_reduce_fusion_pass_on_test_model(

     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"]
+            mode=CompilationMode.VLLM_COMPILE, custom_ops=["+rms_norm", "+quant_fp8"]
         )
     )
     vllm_config.compilation_config.pass_config = PassConfig(
@@ -19,7 +19,7 @@ from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import (
     CacheConfig,
     CompilationConfig,
-    CompilationLevel,
+    CompilationMode,
     ModelConfig,
     PassConfig,
     SchedulerConfig,

@@ -321,7 +321,7 @@ def test_attention_quant_pattern(
         ),
         scheduler_config=SchedulerConfig(max_num_seqs=1024),
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
+            mode=CompilationMode.VLLM_COMPILE,
             custom_ops=["+quant_fp8"],
             use_inductor_graph_partition=use_inductor_graph_partition,
         ),
@@ -6,7 +6,7 @@ import torch

 import vllm
 from vllm.compilation.noop_elimination import NoOpEliminationPass
-from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
+from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig

 from .backend import TestBackend

@@ -50,7 +50,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size):

     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
+            mode=CompilationMode.VLLM_COMPILE,
             pass_config=PassConfig(enable_noop=True),
         )
     )

@@ -98,7 +98,7 @@ def test_non_noop_slice_preserved():

     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
+            mode=CompilationMode.VLLM_COMPILE,
             pass_config=PassConfig(enable_noop=True),
         )
     )
@@ -5,7 +5,7 @@
 import torch

 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import CompilationLevel
+from vllm.config import CompilationMode


 class MyMod(torch.nn.Module):

@@ -20,7 +20,7 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
         self.model = model
         compiled_callable = torch.compile(self.forward, backend="eager")
         super().__init__(
-            compiled_callable, compilation_level=CompilationLevel.DYNAMO_ONCE
+            compiled_callable, compilation_mode=CompilationMode.DYNAMO_TRACE_ONCE
         )

     def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
@@ -15,6 +15,7 @@ from typing import Literal, NamedTuple

 import pytest

+from vllm.config.compilation import CompilationMode
 from vllm.config.model import RunnerOption
 from vllm.logger import init_logger

@@ -234,7 +235,7 @@ def _compare_sp(
     common_args.append("--skip-tokenizer-init")

     compilation_config = {
-        "level": 3,
+        "mode": CompilationMode.VLLM_COMPILE,
         "custom_ops": ["+rms_norm"],
         "compile_sizes": [4, 8],
         "pass_config": {
@@ -226,30 +226,30 @@ def test_compilation_config():

     # set to O3
     args = parser.parse_args(["-O0"])
-    assert args.compilation_config.level == 0
+    assert args.compilation_config.mode == 0

     # set to O 3 (space)
     args = parser.parse_args(["-O", "1"])
-    assert args.compilation_config.level == 1
+    assert args.compilation_config.mode == 1

     # set to O 3 (equals)
     args = parser.parse_args(["-O=2"])
-    assert args.compilation_config.level == 2
+    assert args.compilation_config.mode == 2

-    # set to O.level 3
-    args = parser.parse_args(["-O.level", "3"])
-    assert args.compilation_config.level == 3
+    # set to O.mode 3
+    args = parser.parse_args(["-O.mode", "3"])
+    assert args.compilation_config.mode == 3

     # set to string form of a dict
     args = parser.parse_args(
         [
             "-O",
-            '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
+            '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
             '"use_inductor": false}',
         ]
     )
     assert (
-        args.compilation_config.level == 3
+        args.compilation_config.mode == 3
         and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
         and not args.compilation_config.use_inductor
     )

@@ -258,12 +258,12 @@ def test_compilation_config():
     args = parser.parse_args(
         [
             "--compilation-config="
-            '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
+            '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
             '"use_inductor": true}',
         ]
     )
     assert (
-        args.compilation_config.level == 3
+        args.compilation_config.mode == 3
         and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
         and args.compilation_config.use_inductor
     )
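
For reference, the spellings these tests exercise all populate `compilation_config.mode` now; a hedged summary, assuming the same parser fixture as in `test_compilation_config`:

```python
# All of the following now set `mode` rather than `level`:
args = parser.parse_args(["-O3"])                # -O<n> shorthand
args = parser.parse_args(["-O.mode", "3"])       # dotted key
args = parser.parse_args(["-O", '{"mode": 3}'])  # dict form
assert args.compilation_config.mode == 3
```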
@@ -3,7 +3,7 @@

 import pytest

-from vllm.config import CompilationLevel
+from vllm.config import CompilationMode

 from ..utils import compare_two_settings

@@ -21,13 +21,13 @@ def test_custom_dispatcher(monkeypatch: pytest.MonkeyPatch):
             "--max-model-len=256",
             "--max-num-seqs=32",
             "--enforce-eager",
-            f"-O{CompilationLevel.DYNAMO_ONCE}",
+            f"-O{CompilationMode.DYNAMO_TRACE_ONCE}",
         ],
         arg2=[
             "--max-model-len=256",
             "--max-num-seqs=32",
             "--enforce-eager",
-            f"-O{CompilationLevel.DYNAMO_AS_IS}",
+            f"-O{CompilationMode.STOCK_TORCH_COMPILE}",
         ],
         env1={},
         env2={},
@@ -299,7 +299,7 @@ def test_dict_args(parser):
         "val2",
         "--hf-overrides.key2.key4",
         "val3",
-        # Test compile config and compilation level
+        # Test compile config and compilation mode
         "-O.use_inductor=true",
         "-O.backend",
         "custom",

@@ -352,7 +352,7 @@ def test_dict_args(parser):
         },
     }
     assert parsed_args.compilation_config == {
-        "level": 1,
+        "mode": 1,
         "use_inductor": True,
         "backend": "custom",
         "custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],

@@ -367,7 +367,7 @@ def test_duplicate_dict_args(caplog_vllm, parser):
         "--hf-overrides.key1",
         "val2",
         "-O1",
-        "-O.level",
+        "-O.mode",
         "2",
         "-O3",
     ]

@@ -375,12 +375,12 @@ def test_duplicate_dict_args(caplog_vllm, parser):
     parsed_args = parser.parse_args(args)
     # Should be the last value
     assert parsed_args.hf_overrides == {"key1": "val2"}
-    assert parsed_args.compilation_config == {"level": 3}
+    assert parsed_args.compilation_config == {"mode": 3}

     assert len(caplog_vllm.records) == 1
     assert "duplicate" in caplog_vllm.text
     assert "--hf-overrides.key1" in caplog_vllm.text
-    assert "-O.level" in caplog_vllm.text
+    assert "-O.mode" in caplog_vllm.text


 @pytest.mark.parametrize(
@@ -11,7 +11,7 @@ from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
 from vllm.config import (
     CompilationConfig,
-    CompilationLevel,
+    CompilationMode,
     CUDAGraphMode,
     ParallelConfig,
     SchedulerConfig,

@@ -42,7 +42,7 @@ def _create_vllm_config(
     mock_config.parallel_config = ParallelConfig()

     # Mimic the behavior of VllmConfig.__post_init__()
-    if compilation_config.level == CompilationLevel.PIECEWISE:
+    if compilation_config.mode == CompilationMode.VLLM_COMPILE:
         compilation_config.set_splitting_ops_for_v1()

     return mock_config
@@ -50,23 +50,23 @@ def _create_vllm_config(

 class TestCudagraphDispatcher:
     @pytest.mark.parametrize(
-        "case_id,cudagraph_mode_str,compilation_level",
+        "case_id,cudagraph_mode_str,compilation_mode",
         [
             # Test case 0: Full CG for mixed batches, no separate routine
-            (0, "FULL", CompilationLevel.NO_COMPILATION),
+            (0, "FULL", CompilationMode.NONE),
             # Test case 1: Full CG for uniform batches, piecewise for mixed
-            (1, "FULL_AND_PIECEWISE", CompilationLevel.NO_COMPILATION),
+            (1, "FULL_AND_PIECEWISE", CompilationMode.NONE),
             # Test case 2: Full CG for uniform batches, no CG for mixed
-            (2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION),
-            # Test case 3: Piecewise for all
-            (3, "PIECEWISE", CompilationLevel.PIECEWISE),
+            (2, "FULL_DECODE_ONLY", CompilationMode.NONE),
+            # Test case 3: PIECEWISE for all
+            (3, "PIECEWISE", CompilationMode.VLLM_COMPILE),
         ],
     )
-    def test_dispatcher(self, cudagraph_mode_str, compilation_level):
+    def test_dispatcher(self, cudagraph_mode_str, compilation_mode):
         # Setup dispatcher
         comp_config = CompilationConfig(
             cudagraph_mode=cudagraph_mode_str,
-            level=compilation_level,
+            mode=compilation_mode,
             cudagraph_capture_sizes=[1, 8],
         )

@@ -242,7 +242,7 @@ class TestCudagraphIntegration:
     def setup_method(self):
         # only FULL mode for non-uniform batches
         self.comp_config = CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
+            mode=CompilationMode.VLLM_COMPILE,
             cudagraph_mode="FULL",
             cudagraph_capture_sizes=[10, 20],
         )
@@ -10,7 +10,7 @@ import pytest
 from tests.utils import wait_for_gpu_memory_to_clear
 from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
 from vllm import LLM
-from vllm.config import CompilationConfig
+from vllm.config import CompilationConfig, CompilationMode
 from vllm.platforms import current_platform

@@ -73,7 +73,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
         gpu_memory_utilization=0.45,
         max_model_len=1024,
         compilation_config=CompilationConfig(
-            level=3, cudagraph_mode=cudagraph_mode
+            mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode
         ),
     )
     llm.generate(["Hello, my name is"] * 10)
@@ -90,32 +90,27 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
     )


-# test cudagraph_mode with different compilation level.
-# (backend_name, cudagraph_mode, compilation_level, supported)
+# test cudagraph_mode with different compilation mode.
+# (backend_name, cudagraph_mode, compilation_mode, supported)
 combo_cases_2 = [
-    ("FA2", "FULL", 0, True),  # no compilation + full cudagraph
-    ("FA2", "FULL", 3, True),  # piecewise compilation + full cudagraph
-    ("FA2", "PIECEWISE", 0, False),  # no compilation + piecewise cudagraph
-    ("FA2", "PIECEWISE", 3, True),  # piecewise compilation + piecewise cudagraph
-    (
-        "FA2",
-        "FULL_AND_PIECEWISE",
-        0,
-        False,
-    ),  # piecewise cudagraph not supported without piecewise compilation
-    ("FA2", "FULL_AND_PIECEWISE", 3, True),
-    ("FA2", "FULL_DECODE_ONLY", 0, True),
-    ("FA2", "FULL_DECODE_ONLY", 3, True),
-    ("FA2", "NONE", 0, True),  # no compilation + no cudagraph
-    ("FA2", "NONE", 3, True),  # piecewise compilation + no cudagraph
+    ("FA2", "FULL", CompilationMode.NONE, True),
+    ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
+    ("FA2", "PIECEWISE", CompilationMode.NONE, False),
+    ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+    ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
+    ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+    ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
+    ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
+    ("FA2", "NONE", CompilationMode.NONE, True),
+    ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
 ]


 @pytest.mark.parametrize(
-    "backend_name,cudagraph_mode,compilation_level,supported", combo_cases_2
+    "backend_name,cudagraph_mode,compilation_mode,supported", combo_cases_2
 )
 def test_cudagraph_compilation_combo(combo_case):
-    backend_name, cudagraph_mode, compilation_level, supported = combo_case
+    backend_name, cudagraph_mode, compilation_mode, supported = combo_case

     env_vars = backend_configs[backend_name].env_vars

@@ -130,7 +125,7 @@ def test_cudagraph_compilation_combo(combo_case):
         gpu_memory_utilization=0.45,
         max_model_len=1024,
         compilation_config=CompilationConfig(
-            level=compilation_level, cudagraph_mode=cudagraph_mode
+            mode=compilation_mode, cudagraph_mode=cudagraph_mode
        ),
     )
     llm.generate(["Hello, my name is"] * 10)
@@ -7,7 +7,7 @@ import pytest
 import torch

 from vllm import LLM, SamplingParams
-from vllm.config import CompilationConfig, CompilationLevel
+from vllm.config import CompilationConfig, CompilationMode
 from vllm.distributed import cleanup_dist_env_and_memory

 from ...utils import fork_new_process_for_each_test

@@ -75,9 +75,9 @@ def test_kv_sharing_fast_prefill(
         # This allows vLLM compilation backend to handle allocating and
         # managing buffers for cudagraph
         cudagraph_copy_inputs=True,
-        level=CompilationLevel.PIECEWISE
+        mode=CompilationMode.VLLM_COMPILE
         if not enforce_eager
-        else CompilationLevel.NO_COMPILATION,
+        else CompilationMode.NONE,
     )

     with monkeypatch.context() as m:
@@ -56,7 +56,7 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
         return InductorAdaptor()
     else:
         assert compilation_config.backend == "eager", (
-            "Custom backends not supported with CompilationLevel.PIECEWISE"
+            "Custom backends not supported with CompilationMode.VLLM_COMPILE"
         )

         logger.debug("Using EagerAdaptor")

@@ -481,7 +481,7 @@ def set_model_tag(tag: str):

 class VllmBackend:
     """The compilation backend for `torch.compile` with vLLM.
-    It is used for compilation level of `CompilationLevel.PIECEWISE`,
+    It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
     where we customize the compilation.

     The major work of this backend is to split the graph into
@@ -575,7 +575,7 @@ class InductorAdaptor(CompilerInterface):

     Because it is re-entrant, we always set it (even if entering via Dynamo
     and the context was already entered). We might want to revisit if it
-    should be set at a different level of compilation.
+    should be set at a different mode of compilation.

     This is likely a bug in PyTorch: public APIs should not rely on
     manually setting up internal contexts. But we also rely on non-public

@@ -27,8 +27,8 @@ class CompilationCounter:
     num_cache_entries_updated: int = 0
     # The number of standalone_compile compiled artifacts saved
     num_compiled_artifacts_saved: int = 0
-    # Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS
-    dynamo_as_is_count: int = 0
+    # Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
+    stock_torch_compile_count: int = 0

     def clone(self) -> "CompilationCounter":
         return copy.deepcopy(self)
@@ -18,7 +18,7 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator
 import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
+from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils import resolve_obj_by_qualname, supports_dynamo

@@ -233,11 +233,11 @@ def _support_torch_compile(
         old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
         self.vllm_config = vllm_config
         enable_compile = enable_if is None or enable_if(vllm_config)
-        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
+        # for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
         # will handle the compilation, so we don't need to do anything here.
         self.do_not_compile = (
-            vllm_config.compilation_config.level
-            in [CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS]
+            vllm_config.compilation_config.mode
+            in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
             or not supports_dynamo()
             or _should_ignore_torch_compile(self.__class__)
             or not enable_compile

@@ -247,7 +247,7 @@ def _support_torch_compile(

         compilation_counter.num_models_seen += 1
         TorchCompileWrapperWithCustomDispatcher.__init__(
-            self, compilation_level=vllm_config.compilation_config.level
+            self, compilation_mode=vllm_config.compilation_config.mode
         )

     cls.__init__ = __init__
@@ -3,7 +3,7 @@

 import time

-from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
+from vllm.config import CompilationConfig, CompilationMode, VllmConfig
 from vllm.logger import init_logger

 logger = init_logger(__name__)

@@ -18,7 +18,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):

     compilation_config: CompilationConfig = vllm_config.compilation_config
     path = vllm_config.compile_debug_dump_path()
-    if compilation_config.level == CompilationLevel.PIECEWISE and path:
+    if compilation_config.mode == CompilationMode.VLLM_COMPILE and path:
         import depyf

         path.mkdir(parents=True, exist_ok=True)

@@ -29,7 +29,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):

 def end_monitoring_torch_compile(vllm_config: VllmConfig):
     compilation_config: CompilationConfig = vllm_config.compilation_config
-    if compilation_config.level == CompilationLevel.PIECEWISE:
+    if compilation_config.mode == CompilationMode.VLLM_COMPILE:
         logger.info(
             "torch.compile takes %.2f s in total", compilation_config.compilation_time
         )
@@ -11,7 +11,7 @@ from types import CodeType
 import torch

 import vllm.envs as envs
-from vllm.config import CompilationLevel, CUDAGraphMode, get_current_vllm_config
+from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
 from vllm.logger import init_logger

 logger = init_logger(__name__)

@@ -31,7 +31,7 @@ class TorchCompileWrapperWithCustomDispatcher:
     """

     def __init__(
-        self, compiled_callable: Callable | None = None, compilation_level: int = 0
+        self, compiled_callable: Callable | None = None, compilation_mode: int = 0
     ):
         vllm_config = get_current_vllm_config()
         self.vllm_config = vllm_config

@@ -72,7 +72,7 @@ class TorchCompileWrapperWithCustomDispatcher:
         # subclasses can use this to switch between the custom dispatcher
         # and the default Dynamo guard mechanism.
         self.use_custom_dispatcher: bool = (
-            compilation_level >= CompilationLevel.DYNAMO_ONCE
+            compilation_mode >= CompilationMode.DYNAMO_TRACE_ONCE
         )

     def aot_compile(self, *args, **kwargs):
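
The renamed constants keep their integer values, so ordered comparisons like the one above still work; a quick sanity sketch:

```python
from vllm.config import CompilationMode

# DYNAMO_TRACE_ONCE (2) and VLLM_COMPILE (3) enable the custom dispatcher;
# NONE (0) and STOCK_TORCH_COMPILE (1) do not.
assert CompilationMode.VLLM_COMPILE >= CompilationMode.DYNAMO_TRACE_ONCE
assert not (CompilationMode.STOCK_TORCH_COMPILE >= CompilationMode.DYNAMO_TRACE_ONCE)
```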
@@ -85,7 +85,7 @@ class TorchCompileWrapperWithCustomDispatcher:
         return self.compiled_callable.aot_compile((args, kwargs))

     def __call__(self, *args, **kwargs):
-        """Implement the dispatch logic here, beyond the torch.compile level.
+        """Implement the dispatch logic here, beyond the torch.compile mode.
         NOTE: this function can have additional arguments beyond the forward
         method, for directly dispatching to the compiled code.
         """
@@ -4,7 +4,7 @@
 from vllm.config.cache import CacheConfig
 from vllm.config.compilation import (
     CompilationConfig,
-    CompilationLevel,
+    CompilationMode,
     CUDAGraphMode,
     PassConfig,
 )

@@ -49,7 +49,7 @@ __all__ = [
     "CacheConfig",
     # From vllm.config.compilation
     "CompilationConfig",
-    "CompilationLevel",
+    "CompilationMode",
     "CUDAGraphMode",
     "PassConfig",
     # From vllm.config.device
@@ -26,12 +26,20 @@ else:
 logger = init_logger(__name__)


-class CompilationLevel:
-    # constants for the levels of the compilation process
-    NO_COMPILATION = 0
-    DYNAMO_AS_IS = 1
-    DYNAMO_ONCE = 2
-    PIECEWISE = 3
+class CompilationMode:
+    """The compilation approach used for torch.compile-based compilation of the
+    model."""
+
+    NONE = 0
+    """No torch.compile compilation is applied, model runs in fully eager pytorch mode.
+    The model runs as-is."""
+    STOCK_TORCH_COMPILE = 1
+    """The standard `torch.compile` compilation pipeline."""
+    DYNAMO_TRACE_ONCE = 2
+    """Single Dynamo trace through the model, avoiding recompilation."""
+    VLLM_COMPILE = 3
+    """Custom vLLM Inductor-based backend with caching, piecewise compilation,
+    shape specialization, and custom passes."""


 class CUDAGraphMode(enum.Enum):
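
The old constants map one-to-one onto the new names, value for value; a sketch of the correspondence:

```python
from vllm.config import CompilationMode

# Old CompilationLevel name -> new CompilationMode name (same integer value):
#   NO_COMPILATION (0) -> NONE
#   DYNAMO_AS_IS   (1) -> STOCK_TORCH_COMPILE
#   DYNAMO_ONCE    (2) -> DYNAMO_TRACE_ONCE
#   PIECEWISE      (3) -> VLLM_COMPILE
assert CompilationMode.VLLM_COMPILE == 3
```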
@@ -134,7 +142,7 @@ class CompilationConfig:
    """Configuration for compilation. It has three parts:

    - Top-level Compilation control:
-        - [`level`][vllm.config.CompilationConfig.level]
+        - [`mode`][vllm.config.CompilationConfig.mode]
        - [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path]
        - [`cache_dir`][vllm.config.CompilationConfig.cache_dir]
        - [`backend`][vllm.config.CompilationConfig.backend]
@@ -171,14 +179,26 @@ class CompilationConfig:

     # Top-level Compilation control
     level: int | None = None
-    """The level of compilation:
+    """
+    Level is deprecated and will be removed in the next release,
+    either 0.12.0 or 0.11.2 whichever is soonest.
+    Please use mode. Currently all levels are mapped to mode.
+    """
+    # Top-level Compilation control
+    mode: int | None = None
+    """The compilation approach used for torch.compile-based compilation of the
+    model.

-    - None: If None, we will select the default compilation level.
-        For V1 engine this is 3, for V0 engine this is 0.
-    - 0: no compilation.
-    - 1: dynamo as is.
-    - 2: dynamo once.
-    - 3: piecewise compilation."""
+    - None: If None, we will select the default compilation mode.
+        For V1 engine this is 3.
+    - 0: NONE: No torch.compile compilation is applied, model runs in fully
+        eager pytorch mode. The model runs as-is.
+    - 1: STOCK_TORCH_COMPILE: The standard `torch.compile` compilation pipeline.
+    - 2: DYNAMO_TRACE_ONCE: Single Dynamo trace through the model, avoiding
+        recompilation by removing guards.
+        Requires no dynamic-shape-dependent control-flow.
+    - 3: VLLM_COMPILE: Custom vLLM Inductor-based backend with caching,
+        piecewise compilation, shape specialization, and custom passes."""
     debug_dump_path: Path | None = None
     """The path to dump the debug information."""
     cache_dir: str = ""
@@ -195,11 +215,11 @@ class CompilationConfig:

    backend function.
     We use string to avoid serialization issues when using compilation in a
-    distributed setting. When the compilation level is 1 or 2, the backend is
+    distributed setting. When the compilation mode is 1 or 2, the backend is
     used for the compilation directly (it sees the whole graph). When the
-    compilation level is 3, the backend is used for the piecewise compilation
+    compilation mode is 3, the backend is used for the piecewise compilation
     (it sees a part of the graph). The backend can not be custom for compilation
-    level 3, i.e. the backend must be either eager or inductor. Furthermore,
+    mode 3, i.e. the backend must be either eager or inductor. Furthermore,
     compilation is only piecewise if splitting ops is set accordingly and
     use_inductor_graph_partition is off. Note that the default options for
     splitting ops are sufficient for piecewise compilation.
@@ -214,7 +234,7 @@ class CompilationConfig:
     - 'none,+op1,+op2' to enable only op1 and op2

     By default, all custom ops are enabled when running without Inductor and
-    disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
+    disabled when running with Inductor: mode>=VLLM_COMPILE and use_inductor=True.
     Inductor generates (fused) Triton kernels for disabled custom ops."""
     splitting_ops: list[str] | None = None
     """A list of ops to exclude from cudagraphs, used in piecewise compilation.
@@ -249,7 +269,7 @@ class CompilationConfig:
     One graph for symbolic shape and one graph per size in compile_sizes
     are compiled using configurations in inductor_compile_config.

-    This setting is ignored if level<PIECEWISE.
+    This setting is ignored if mode<VLLM_COMPILE.

     For future compatibility:
     If use_inductor is True, backend="inductor" otherwise backend="eager".
@@ -299,7 +319,7 @@ class CompilationConfig:
     Currently, the cudagraph mode is only used for the v1 engine.
     Note that the cudagraph logic is generally orthogonal to the
     compilation logic. While piecewise cudagraphs require piecewise
-    compilation (level=PIECEWISE and non-empty splitting_ops), full
+    compilation (mode=VLLM_COMPILE and non-empty splitting_ops), full
     cudagraphs are supported with and without compilation.

     Warning: This flag is new and subject to change in addition
@@ -312,7 +332,7 @@ class CompilationConfig:
     that all input buffers have fixed addresses, and all
     splitting ops write their outputs to input buffers.
     In the vLLM V1 Engine, this flag only applies for
-    CompilationLevel.PIECEWISE (aka -O3).
+    CompilationMode.VLLM_COMPILE (aka -O3).
     Note that this is orthogonal to the cudagraph capture logic
     outside of compilation.
     Warning: This flag is deprecated and will be removed in the next major or
@@ -426,7 +446,7 @@ class CompilationConfig:
         the final hidden states.
         """
         factors: list[Any] = []
-        factors.append(self.level)
+        factors.append(self.mode)
         factors.append(self.backend)
         factors.append(self.custom_ops)
         factors.append(self.splitting_ops)
@@ -477,6 +497,17 @@ class CompilationConfig:
         return value

     def __post_init__(self) -> None:
+        if self.level is not None:
+            logger.warning(
+                "Level is deprecated and will be removed in the next release, "
+                "either 0.12.0 or 0.11.2, whichever is soonest. "
+                "Use mode instead. "
+                "If both level and mode are given, "
+                "only mode will be used."
+            )
+            if self.mode is None:
+                self.mode = self.level
+
         count_none = self.custom_ops.count("none")
         count_all = self.custom_ops.count("all")
         assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
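
A sketch of the behavior this shim implies (derived from the branch above, not from separate documentation):

```python
from vllm.config import CompilationConfig, CompilationMode

# Deprecated field: warns, then is copied into `mode`.
cfg = CompilationConfig(level=3)
assert cfg.mode == CompilationMode.VLLM_COMPILE

# If both are given, the explicit `mode` wins.
cfg = CompilationConfig(level=0, mode=3)
assert cfg.mode == 3
```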
@@ -574,7 +605,7 @@ class CompilationConfig:
         # Currently only eager and inductor backend are supported.
         # for piecewise compilation. Custom backends are not suppported for
         # piecewise compilation. Update when more backends are supported.
-        if self.level == CompilationLevel.PIECEWISE and self.backend not in [
+        if self.mode == CompilationMode.VLLM_COMPILE and self.backend not in [
             "",
             "eager",
             "inductor",
@@ -602,24 +633,27 @@ class CompilationConfig:
         Returns:
             The backend for the compilation config.
         """
-        if self.level is None:
+        if self.mode is None:
             raise ValueError(
-                "No compilation level is set. This method should only be \
+                "No compilation mode is set. This method should only be \
                 called via vllm config where the level is set if none is \
                 provided."
             )
-        if self.level == CompilationLevel.NO_COMPILATION:
-            raise ValueError("No compilation level is set.")
+        if self.mode == CompilationMode.NONE:
+            raise ValueError("No compilation mode is set.")

         from torch._dynamo.backends.registry import list_backends

         torch_backends = list_backends(exclude_tags=tuple())
-        if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
+        if self.mode in [
+            CompilationMode.STOCK_TORCH_COMPILE,
+            CompilationMode.DYNAMO_TRACE_ONCE,
+        ]:
             if self.backend in torch_backends:
                 return self.backend
             return resolve_obj_by_qualname(self.backend)

-        assert self.level == CompilationLevel.PIECEWISE
+        assert self.mode == CompilationMode.VLLM_COMPILE
         if self.backend not in ["eager", "inductor"]:
             raise ValueError(
                 f"Invalid backend for piecewise compilation: {self.backend}"
@@ -684,11 +718,11 @@ class CompilationConfig:
         self.bs_to_padded_graph_size[self.max_capture_size] = self.max_capture_size

     def set_splitting_ops_for_v1(self):
-        # NOTE: this function needs to be called only when level is
-        # CompilationLevel.PIECEWISE
-        assert self.level == CompilationLevel.PIECEWISE, (
+        # NOTE: this function needs to be called only when mode is
+        # CompilationMode.VLLM_COMPILE
+        assert self.mode == CompilationMode.VLLM_COMPILE, (
             "set_splitting_ops_for_v1 should only be called when "
-            "level is CompilationLevel.PIECEWISE"
+            "mode is CompilationMode.VLLM_COMPILE"
         )

         if self.use_inductor_graph_partition:
@@ -769,12 +803,10 @@ class CompilationConfig:

         if not self.use_inductor_graph_partition:
             # Dynamo-level FX split case
-            return self.level == CompilationLevel.PIECEWISE
+            return self.mode == CompilationMode.VLLM_COMPILE

         # Inductor partition case
-        return (
-            self.backend == "inductor" and self.level > CompilationLevel.NO_COMPILATION
-        )
+        return self.backend == "inductor" and self.mode > CompilationMode.NONE

     def custom_op_log_check(self):
         """
@@ -22,7 +22,7 @@ from vllm.transformers_utils.runai_utils import is_runai_obj_uri
 from vllm.utils import random_uuid

 from .cache import CacheConfig
-from .compilation import CompilationConfig, CompilationLevel, CUDAGraphMode
+from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
 from .device import DeviceConfig
 from .kv_events import KVEventsConfig
 from .kv_transfer import KVTransferConfig
@@ -84,17 +84,11 @@ class VllmConfig:
     compilation_config: CompilationConfig = Field(default_factory=CompilationConfig)
     """`torch.compile` and cudagraph capture configuration for the model.

-    As a shorthand, `-O<n>` can be used to directly specify the compilation
-    level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`).
-    Currently, -O <n> and -O=<n> are supported as well but this will likely be
-    removed in favor of clearer -O<n> syntax in the future.
-
-    NOTE: level 0 is the default level without any optimization. level 1 and 2
-    are for internal testing only. level 3 is the recommended level for
-    production, also default in V1.
+    As a shorthand, one can append compilation arguments via
+    -O.parameter=argument, such as `-O.mode=3` (same as `-O='{"mode":3}'`).

     You can specify the full compilation config like so:
-    `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
+    `{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
     """
     kv_transfer_config: KVTransferConfig | None = None
     """The configurations for distributed KV cache transfer."""
@@ -305,33 +299,33 @@ class VllmConfig:
                 "precision for chunked prefill triton kernels."
             )

-        # If the user does not explicitly set a compilation level, then
-        # we use the default level. The default level depends on other
+        # If the user does not explicitly set a compilation mode, then
+        # we use the default mode. The default mode depends on other
         # settings (see the below code).
-        if self.compilation_config.level is None:
+        if self.compilation_config.mode is None:
             if envs.VLLM_USE_V1:
                 if (
                     self.model_config is not None
                     and not self.model_config.enforce_eager
                 ):
-                    self.compilation_config.level = CompilationLevel.PIECEWISE
+                    self.compilation_config.mode = CompilationMode.VLLM_COMPILE
                 else:
-                    self.compilation_config.level = CompilationLevel.NO_COMPILATION
+                    self.compilation_config.mode = CompilationMode.NONE

             else:
-                # NB: Passing both --enforce-eager and a compilation level
-                # in V0 means the compilation level wins out.
-                self.compilation_config.level = CompilationLevel.NO_COMPILATION
+                # NB: Passing both --enforce-eager and a compilation mode
+                # in V0 means the compilation mode wins out.
+                self.compilation_config.mode = CompilationMode.NONE
         else:
-            assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION
-            assert self.compilation_config.level <= CompilationLevel.PIECEWISE
+            assert self.compilation_config.mode >= CompilationMode.NONE
+            assert self.compilation_config.mode <= CompilationMode.VLLM_COMPILE

         # If user does not set custom ops via none or all set it here based on
-        # compilation level and backend.
+        # compilation mode and backend.
         if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
             if (
                 self.compilation_config.backend == "inductor"
-                and self.compilation_config.level > CompilationLevel.NO_COMPILATION
+                and self.compilation_config.mode > CompilationMode.NONE
             ):
                 self.compilation_config.custom_ops.append("none")
             else:
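
Put differently, with nothing set explicitly the V1 default resolves from `enforce_eager`; an illustrative restatement of the branch above:

```python
from vllm import LLM

# Per the branch above (V1 engine): enforce_eager selects CompilationMode.NONE,
# otherwise the default is CompilationMode.VLLM_COMPILE.
llm = LLM(model="facebook/opt-125m", enforce_eager=True)  # -> mode NONE
```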
@@ -350,7 +344,7 @@ class VllmConfig:
         if self.compilation_config.cudagraph_mode is None:
             if (
                 envs.VLLM_USE_V1
-                and self.compilation_config.level == CompilationLevel.PIECEWISE
+                and self.compilation_config.mode == CompilationMode.VLLM_COMPILE
             ):
                 # default to full and piecewise for most models
                 self.compilation_config.cudagraph_mode = (
@@ -486,10 +480,10 @@ class VllmConfig:
         )
         current_platform.check_and_update_config(self)

-        # Do this after all the updates to compilation_config.level
+        # Do this after all the updates to compilation_config.mode
         if (
             envs.VLLM_USE_V1
-            and self.compilation_config.level == CompilationLevel.PIECEWISE
+            and self.compilation_config.mode == CompilationMode.VLLM_COMPILE
         ):
             self.compilation_config.set_splitting_ops_for_v1()
@@ -508,8 +502,8 @@ class VllmConfig:
         )

         if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
-            assert self.compilation_config.level == CompilationLevel.PIECEWISE, (
-                "Compilation level should be CompilationLevel.PIECEWISE "
+            assert self.compilation_config.mode == CompilationMode.VLLM_COMPILE, (
+                "Compilation mode should be CompilationMode.VLLM_COMPILE "
                 "when cudagraph_mode piecewise cudagraphs is used, "
                 f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
             )
@ -837,7 +831,7 @@ def set_current_vllm_config(

    if (
        check_compile
        and vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
        and vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
        and compilation_counter.num_models_seen == num_models_seen
    ):
        # If the model supports compilation,

@ -176,7 +176,7 @@ class LLM:
            argument is deprecated and will be removed in v0.12.0 or v1.0.0,
            whichever is sooner.
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            integer, it is used as the mode of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

@ -257,9 +257,7 @@ class LLM:

        if compilation_config is not None:
            if isinstance(compilation_config, int):
                compilation_config_instance = CompilationConfig(
                    level=compilation_config
                )
                compilation_config_instance = CompilationConfig(mode=compilation_config)
            elif isinstance(compilation_config, dict):
                compilation_config_instance = CompilationConfig(
                    **{

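Both accepted forms of the argument, per the branches above (model name is a placeholder):

```python
from vllm import LLM

# int form: becomes CompilationConfig(mode=3)
llm_int = LLM(model="facebook/opt-125m", compilation_config=3)

# dict form: expanded into CompilationConfig(**{...})
llm_dict = LLM(model="facebook/opt-125m", compilation_config={"mode": 3})
```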
@ -8,7 +8,7 @@ from packaging import version

from vllm import _custom_ops as ops
from vllm import envs
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform

@ -419,7 +419,7 @@ class Fp8LinearOp:
        if pad_output is None:
            config = get_current_vllm_config().compilation_config
            pad_output = (
                config.level < CompilationLevel.PIECEWISE
                config.mode < CompilationMode.VLLM_COMPILE
                and self.preferred_backend == "torch"
            )

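The `<` comparison above only works because the modes are integer-valued and ordered; a standalone sketch of the assumed ordering (the stand-in class and its values are illustrative, not vLLM's definition):

```python
from enum import IntEnum

class CompilationModeSketch(IntEnum):
    # Hypothetical stand-in mirroring the names in this diff; values assumed 0-3.
    NONE = 0
    STOCK_TORCH_COMPILE = 1
    DYNAMO_TRACE_ONCE = 2
    VLLM_COMPILE = 3

# Any mode below VLLM_COMPILE opts into output padding when the torch
# backend is preferred, mirroring the pad_output expression above.
pad_output = CompilationModeSketch.DYNAMO_TRACE_ONCE < CompilationModeSketch.VLLM_COMPILE
assert pad_output
```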
@ -247,12 +247,12 @@ class CpuPlatform(Platform):
        parallel_config.enable_dbo = False

        # Note: workaround for v1 gpu_model_runner
        from vllm.config import CompilationLevel
        from vllm.config import CompilationMode

        vllm_config.compilation_config.cudagraph_capture_sizes = []

        compilation_config = vllm_config.compilation_config
        if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE:
        if vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
            # Note: vLLM V1 is using PIECEWISE level compilation, which will
            # take time to compile kernels just-in-time with the inductor
            # backend. For CPU CI tests, most of them are executed fast and

@ -265,7 +265,7 @@ class CpuPlatform(Platform):
        else:
            backend = "inductor"

        compilation_config.level = CompilationLevel.DYNAMO_ONCE
        compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
        compilation_config.backend = backend
        compilation_config.inductor_compile_config.update(
            {

@ -277,7 +277,7 @@ class CpuPlatform(Platform):
        )

        if vllm_config.lora_config is not None:
            compilation_config.level = CompilationLevel.NO_COMPILATION
            compilation_config.mode = CompilationMode.NONE

        assert vllm_config.device_config.device_type == "cpu"

@ -114,7 +114,7 @@ class TpuPlatform(Platform):

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        from vllm.config import CompilationLevel, CUDAGraphMode
        from vllm.config import CompilationMode, CUDAGraphMode

        cache_config = vllm_config.cache_config
        # For v0, the default block size is 16.

@ -122,12 +122,13 @@ class TpuPlatform(Platform):
            cache_config.block_size = cast(BlockSize, 16)
        compilation_config = vllm_config.compilation_config

        # TPU only supports DYNAMO_ONCE compilation level
        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
        # TPU only supports DYNAMO_TRACE_ONCE compilation mode
        if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE:
            logger.info(
                "[TPU] Forcing DYNAMO_ONCE compilation level, and disabling cudagraph."
                "[TPU] Forcing DYNAMO_TRACE_ONCE compilation mode, and\
                disabling cudagraph."
            )
            compilation_config.level = CompilationLevel.DYNAMO_ONCE
            compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE

        if (
            compilation_config.cudagraph_mode is None

@ -144,7 +144,7 @@ class XPUPlatform(Platform):
            cache_config.block_size = 64

        # lazy import to avoid circular import
        from vllm.config import CompilationLevel, CUDAGraphMode
        from vllm.config import CompilationMode, CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if compilation_config.compile_sizes is None:

@ -155,7 +155,7 @@ class XPUPlatform(Platform):
            )

        if vllm_config.lora_config is not None:
            compilation_config.level = CompilationLevel.NO_COMPILATION
            compilation_config.mode = CompilationMode.NONE

        # check and update parallel config
        parallel_config = vllm_config.parallel_config

@ -1686,16 +1686,16 @@ class FlexibleArgumentParser(ArgumentParser):
            elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
                # allow -O flag to be used without space, e.g. -O3 or -Odecode
                # -O.<...> handled later
                # also handle -O=<level> here
                level = arg[3:] if arg[2] == "=" else arg[2:]
                processed_args.append(f"-O.level={level}")
                # also handle -O=<mode> here
                mode = arg[3:] if arg[2] == "=" else arg[2:]
                processed_args.append(f"-O.mode={mode}")
            elif (
                arg == "-O"
                and i + 1 < len(args)
                and args[i + 1] in {"0", "1", "2", "3"}
            ):
                # Convert -O <n> to -O.mode <n>
                # Convert -O <n> to -O.level <n>
                processed_args.append("-O.level")
                processed_args.append("-O.mode")
            else:
                processed_args.append(arg)

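A toy re-implementation of the rewrite above, useful for checking the three accepted spellings (illustrative only, not the vLLM parser itself):

```python
def rewrite_O_args(args: list[str]) -> list[str]:
    """Toy mirror of the -O handling in the hunk above."""
    processed = []
    i = 0
    while i < len(args):
        arg = args[i]
        if arg.startswith("-O") and arg != "-O" and arg[2] != ".":
            # -O3 / -O=3 / -Odecode style, without a space
            mode = arg[3:] if arg[2] == "=" else arg[2:]
            processed.append(f"-O.mode={mode}")
        elif arg == "-O" and i + 1 < len(args) and args[i + 1] in {"0", "1", "2", "3"}:
            # -O <n> style, with a space; the value follows as its own arg
            processed.append("-O.mode")
        else:
            processed.append(arg)
        i += 1
    return processed

assert rewrite_O_args(["-O3"]) == ["-O.mode=3"]
assert rewrite_O_args(["-O=2"]) == ["-O.mode=2"]
assert rewrite_O_args(["-O", "1"]) == ["-O.mode", "1"]
```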
@ -43,12 +43,12 @@ class CudagraphDispatcher:
            not_use_piecewise_compilation
            or self.compilation_config.is_attention_compiled_piecewise()
        ), (
            "Compilation level should be CompilationLevel.PIECEWISE when "
            "Compilation mode should be CompilationMode.VLLM_COMPILE when "
            "cudagraph_mode piecewise cudagraphs is used, "
            "and attention should be in splitting_ops or "
            "inductor splitting should be used. "
            f"cudagraph_mode={self.cudagraph_mode}, "
            f"compilation_level={self.compilation_config.level}, "
            f"compilation_mode={self.compilation_config.mode}, "
            f"splitting_ops={self.compilation_config.splitting_ops}"
        )

@ -9,7 +9,7 @@ import torch
import torch.nn as nn

from vllm.config import (
    CompilationLevel,
    CompilationMode,
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,

@ -86,7 +86,7 @@ class EagleProposer:
        self.use_cuda_graph = False

        compilation_config = self.vllm_config.compilation_config
        if compilation_config.level == CompilationLevel.PIECEWISE:
        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
            cudagraph_mode = compilation_config.cudagraph_mode
            if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
                CUDAGraphMode.PIECEWISE

@ -25,7 +25,7 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
    CompilationLevel,
    CompilationMode,
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,

@ -2927,14 +2927,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            )

        if (
            self.vllm_config.compilation_config.level == CompilationLevel.DYNAMO_AS_IS
            self.vllm_config.compilation_config.mode
            == CompilationMode.STOCK_TORCH_COMPILE
            and supports_dynamo()
        ):
            backend = self.vllm_config.compilation_config.init_backend(self.vllm_config)
            compilation_counter.dynamo_as_is_count += 1
            compilation_counter.stock_torch_compile_count += 1
            self.model.compile(fullgraph=True, backend=backend)
            return
        # for other compilation levels, cudagraph behavior is controlled by
        # for other compilation modes, cudagraph behavior is controlled by
        # CUDAGraphWrapper and CudagraphDispatcher of vllm.

        # wrap the model with full cudagraph wrapper if needed.

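For context, the STOCK_TORCH_COMPILE branch above hands the whole model to stock `torch.compile`; a minimal standalone sketch (plain torch with a toy module, not vLLM's model runner):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
# Equivalent in spirit to `self.model.compile(fullgraph=True, backend=backend)`
# in the branch above; torch.compile defaults to the inductor backend.
compiled = torch.compile(model, fullgraph=True)
print(compiled(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```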
@ -3985,7 +3986,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                # if not supported any full cudagraphs, just raise it.
                msg += (
                    "; please try cudagraph_mode=PIECEWISE, and "
                    "make sure compilation level is piecewise"
                    "make sure compilation mode is VLLM_COMPILE"
                )
                raise ValueError(msg)

@ -4012,7 +4013,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                f"with {min_cg_builder_name} backend (support: "
                f"{min_cg_support})"
            )
            if self.compilation_config.level == CompilationLevel.PIECEWISE and (
            if self.compilation_config.mode == CompilationMode.VLLM_COMPILE and (
                self.compilation_config.splitting_ops_contain_attention()
                or self.compilation_config.use_inductor_graph_partition
            ):

@ -4068,7 +4069,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                    f"supported with {min_cg_builder_name} backend ("
                    f"support:{min_cg_support}) "
                    "; please try cudagraph_mode=PIECEWISE, "
                    "and make sure compilation level is piecewise"
                    "and make sure compilation mode is VLLM_COMPILE"
                )

        # Trigger cudagraph dispatching keys initialization here (after