[Bugfix] VLLM_V1 supports passing other compilation levels (#19340)

Signed-off-by: Richard Zou <zou3519@gmail.com>
Authored by Richard Zou on 2025-07-29 09:35:58 -04:00; committed by GitHub
Commit: 04e38500ee (parent: ab714131e4)
5 changed files with 88 additions and 5 deletions


@@ -26,6 +26,8 @@ def test_use_cudagraphs_dynamic(monkeypatch):
assert not vllm_config.compilation_config.use_cudagraph
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
# on the state of the cache directory on the current machine, which
# may be influenced by other tests.
@@ -33,8 +35,8 @@ def test_use_cudagraphs_dynamic(monkeypatch):
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
assert vllm.envs.VLLM_USE_V1
# spawn means that the counters are in the same process.
monkeypatch.setenv('VLLM_WORKER_MULTIPROC_METHOD', "spawn")
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val)
compilation_config = {
@@ -50,6 +52,8 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
pass
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize("enabled", [True, False])
def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
assert vllm.envs.VLLM_USE_V1
@@ -72,3 +76,50 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
compilation_config=compilation_config,
gpu_memory_utilization=0.4) as _):
pass
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_dynamo_as_is(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
with (
compilation_counter.expect(dynamo_as_is_count=1),
# loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m',
compilation_config={"level": 1},
gpu_memory_utilization=0.4) as _):
pass
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
with (
compilation_counter.expect(num_graphs_seen=0,
dynamo_as_is_count=0),
# loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m',
compilation_config={"level": 0},
gpu_memory_utilization=0.4) as _):
pass
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
with (
compilation_counter.expect(num_graphs_seen=0,
dynamo_as_is_count=0),
# loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m',
enforce_eager=True,
gpu_memory_utilization=0.4) as _):
pass
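
The new tests exercise the fix through vllm_runner with an explicit compilation_config. For reference, a minimal end-user sketch of the same request, assuming the vllm.LLM entrypoint forwards compilation_config the same way vllm_runner does above (the model name and memory fraction are just the test values):

    from vllm import LLM

    # Request DYNAMO_AS_IS (level 1) explicitly. Before this fix, the V1
    # engine ignored a user-provided level and forced PIECEWISE (level 3).
    llm = LLM(model="facebook/opt-125m",
              compilation_config={"level": 1},
              gpu_memory_utilization=0.4)
    outputs = llm.generate("Hello, my name is")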


@@ -27,6 +27,8 @@ class CompilationCounter:
num_cache_entries_updated: int = 0
# The number of standalone_compile compiled artifacts saved
num_compiled_artifacts_saved: int = 0
# Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS
dynamo_as_is_count: int = 0
def clone(self) -> "CompilationCounter":
return copy.deepcopy(self)
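
The tests above assert on these counters through compilation_counter.expect(...). As a rough sketch of that pattern (an assumption about the helper's shape, not the actual vLLM implementation), a delta check can be built on the clone() method shown here:

    import copy
    from contextlib import contextmanager
    from dataclasses import dataclass

    @dataclass
    class CompilationCounter:
        num_graphs_seen: int = 0
        dynamo_as_is_count: int = 0

        def clone(self) -> "CompilationCounter":
            return copy.deepcopy(self)

        @contextmanager
        def expect(self, **expected_deltas: int):
            # Snapshot on entry, compare the increments on exit.
            before = self.clone()
            yield
            for name, delta in expected_deltas.items():
                actual = getattr(self, name) - getattr(before, name)
                assert actual == delta, f"{name}: expected +{delta}, got +{actual}"

    compilation_counter = CompilationCounter()

    # Mirrors the usage in the tests above.
    with compilation_counter.expect(dynamo_as_is_count=1):
        compilation_counter.dynamo_as_is_count += 1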


@@ -4106,9 +4106,11 @@ class CompilationConfig:
certain small batchsizes, where inductor is good at optimizing.
"""
# Top-level Compilation control
level: int = 0
level: Optional[int] = None
"""The level of compilation:
- None: If None, we will select the default compilation level.
For V1 engine this is 3, for V0 engine this is 0.
- 0: no compilation.
- 1: dynamo as is.
- 2: dynamo once.
@@ -4664,6 +4666,22 @@ class VllmConfig:
"To workaround this limitation, vLLM will set 'ieee' input "
"precision for chunked prefill triton kernels.")
# If the user does not explicitly set a compilation level, then
# we use the default level. The default level depends on other
# settings (see the below code).
if self.compilation_config.level is None:
if envs.VLLM_USE_V1:
if (self.model_config is not None
and not self.model_config.enforce_eager):
self.compilation_config.level = CompilationLevel.PIECEWISE
else:
self.compilation_config.level = \
CompilationLevel.NO_COMPILATION
else:
# NB: Passing both --enforce-eager and a compilation level
# in V0 means the compilation level wins out.
self.compilation_config.level = CompilationLevel.NO_COMPILATION
# async tp is built on top of sequence parallelism
# and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp:
@@ -4676,7 +4694,6 @@ class VllmConfig:
# By default, V1 uses piecewise CUDA graphs. If full_cuda_graph
# is set to True, full CUDA graphs will be used.
self.compilation_config.cudagraph_num_of_warmups = 1
self.compilation_config.level = CompilationLevel.PIECEWISE
self.compilation_config.set_splitting_ops_for_v1()
self._set_cudagraph_sizes()
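
Taken together, the defaulting logic added above behaves like the standalone helper below. CompilationLevel is modeled as an IntEnum purely for illustration (the values come from the docstring above); the function name and the use_v1/enforce_eager parameters are likewise illustrative, since the real logic lives inside VllmConfig:

    from enum import IntEnum
    from typing import Optional

    class CompilationLevel(IntEnum):
        NO_COMPILATION = 0  # no compilation
        DYNAMO_AS_IS = 1    # dynamo as is
        DYNAMO_ONCE = 2     # dynamo once
        PIECEWISE = 3       # V1 default

    def resolve_compilation_level(level: Optional[int], use_v1: bool,
                                  enforce_eager: bool) -> int:
        if level is not None:
            # An explicitly requested level is now respected, on V1 as well.
            return level
        if use_v1 and not enforce_eager:
            return CompilationLevel.PIECEWISE
        return CompilationLevel.NO_COMPILATION

    assert resolve_compilation_level(None, use_v1=True, enforce_eager=False) == 3
    assert resolve_compilation_level(1, use_v1=True, enforce_eager=False) == 1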


@@ -43,7 +43,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up)
is_pin_memory_available, round_up, supports_dynamo)
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
@@ -1930,6 +1930,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
rank_mapping,
)
if (
self.vllm_config.compilation_config.level == \
CompilationLevel.DYNAMO_AS_IS and supports_dynamo()
):
backend = self.vllm_config.compilation_config.init_backend(
self.vllm_config)
compilation_counter.dynamo_as_is_count += 1
self.model.compile(
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=backend)
def reload_weights(self) -> None:
assert getattr(self, "model", None) is not None, \
"Cannot reload weights before model is loaded."


@@ -22,6 +22,7 @@ import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.counter import compilation_counter
from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import broadcast_tensor_dict, get_pp_group
@@ -1121,6 +1122,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
backend = self.vllm_config.compilation_config.init_backend(
self.vllm_config)
compilation_counter.dynamo_as_is_count += 1
self.model = torch.compile(
self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,