[FrontEnd] UNREVERT CompilationConfig overhaul (#20283): deprecate use_inductor in favor of backend, simplify custom_ops (#26502)

Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: Morrison Turnansky <mturnans@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Authored by Morrison Turnansky on 2025-10-13 18:47:16 -04:00; committed by GitHub
parent 7200a21cd1
commit e3fdb627d9
8 changed files with 153 additions and 86 deletions
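In short: `CompilationConfig.use_inductor` is deprecated in favor of the string-valued `backend` field (True maps to backend="inductor", False to backend="eager"), and `custom_ops` now always resolves to an explicit "all"/"none" default. A minimal migration sketch with field values taken from the diff below; building the config by hand like this is for illustration only:

from vllm.config import CompilationConfig, CompilationLevel

# Before (deprecated): still accepted, but logs a warning and is mapped
# onto `backend` internally.
cfg_old = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_inductor=True,   # True -> backend="inductor", False -> backend="eager"
)

# After: select the compiler backend directly.
cfg_new = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    backend="inductor",  # or "eager"; "" resolves to the platform default
)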

View File

@@ -258,13 +258,13 @@ def tractable_computation(
@torch.inference_mode
def run_model(
llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
llama_config, use_compile: bool, backend: str, split_attn: bool = False
) -> torch.Tensor:
if use_compile:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
use_inductor=use_inductor,
backend=backend,
cudagraph_capture_sizes=[1, 2],
)
if split_attn:
@@ -338,8 +338,8 @@ def run_model(
return output.cpu()
@pytest.mark.parametrize("use_inductor", [True, False])
def test_toy_llama(use_inductor: bool):
@pytest.mark.parametrize("backend", ["inductor", "eager"])
def test_toy_llama(backend: str):
# compare output with and without piecewise compilation
llama_config = LlamaConfig(
@@ -358,10 +358,10 @@ def test_toy_llama(use_inductor: bool):
num_backend_compilations=0,
num_cudagraph_captured=0,
):
outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
run_model(tractable_config, use_inductor=False, use_compile=False)
outputs.append(run_model(llama_config, backend="eager", use_compile=False))
run_model(tractable_config, backend="eager", use_compile=False)
if use_inductor:
if backend == "inductor":
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
else:
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
@@ -377,10 +377,8 @@ def test_toy_llama(use_inductor: bool):
num_cudagraph_captured=2,
**kwargs,
):
outputs.append(
run_model(llama_config, use_inductor=use_inductor, use_compile=True)
)
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
outputs.append(run_model(llama_config, backend=backend, use_compile=True))
run_model(tractable_config, backend=backend, use_compile=True)
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
@@ -395,16 +393,9 @@ def test_toy_llama(use_inductor: bool):
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs.append(
run_model(
llama_config,
use_inductor=use_inductor,
use_compile=True,
split_attn=True,
)
run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
)
run_model(
tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
)
run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
for i in range(1, len(outputs)):
assert torch.allclose(outputs[0], outputs[i])

View File

@@ -77,14 +77,15 @@ class TestSetting:
method="encode",
),
# vision language model
TestSetting(
model="microsoft/Phi-3.5-vision-instruct",
model_args=["--trust-remote-code", "--max-model-len", "2048"],
pp_size=2,
tp_size=1,
attn_backend="FLASH_ATTN",
method="generate_with_image",
),
# See https://github.com/vllm-project/vllm/issues/26716.
# TestSetting(
# model="microsoft/Phi-3.5-vision-instruct",
# model_args=["--trust-remote-code", "--max-model-len", "2048"],
# pp_size=2,
# tp_size=1,
# attn_backend="FLASH_ATTN",
# method="generate_with_image",
# ),
],
)
def test_compile_correctness(
@@ -109,41 +110,46 @@ def test_compile_correctness(
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
final_args = [
"--enforce-eager",
*model_args,
"-pp",
str(pp_size),
"-tp",
str(tp_size),
"-O.cudagraph_mode=none",
]
all_args: list[list[str]] = []
all_envs: list[dict[str, str] | None] = []
for level in [
CompilationLevel.NO_COMPILATION,
for comp_level in [
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
CompilationLevel.PIECEWISE,
]:
all_args.append(final_args + [f"-O{level}"])
all_envs.append({})
for level in [CompilationLevel.NO_COMPILATION, comp_level]:
all_args.append(
final_args + [f"-O.level={level}", "-O.backend=inductor"]
)
# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
compare_all_settings(
model,
all_args,
all_envs,
method=method if method != "generate" else "generate_close",
)
all_envs.clear()
all_args.clear()
# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
compare_all_settings(
model,
all_args,
all_envs,
method=method if method != "generate" else "generate_close",
)
all_envs.clear()
all_args.clear()
for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
CompilationLevel.PIECEWISE,
]:
all_args.append(final_args + [f"-O{level}"])
all_args.append(final_args + [f"-O.level={level}", "-O.backend=eager"])
all_envs.append({})
all_envs.append({})
compare_all_settings(model, all_args * 3, all_envs, method=method)

View File

@@ -36,55 +36,56 @@ class Relu3(ReLUSquaredActivation):
@pytest.mark.parametrize(
"env, torch_level, use_inductor, ops_enabled, default_on",
"env, torch_level, backend, ops_enabled, default_on",
[
# Default values based on compile level
# - All by default (no Inductor compilation)
(None, 0, False, [True] * 4, True),
(None, 1, True, [True] * 4, True),
(None, 2, False, [True] * 4, True),
(None, 0, "eager", [True] * 4, True),
(None, 1, "eager", [True] * 4, True),
(None, 2, "eager", [True] * 4, True),
(None, 3, "eager", [True] * 4, True),
# - None by default (with Inductor)
(None, 3, True, [False] * 4, False),
(None, 4, True, [False] * 4, False),
# - All by default (without Inductor)
(None, 3, False, [True] * 4, True),
(None, 4, False, [True] * 4, True),
(None, 0, "inductor", [True] * 4, True),
# - None by default (with Inductor)
(None, 1, "inductor", [False] * 4, False),
(None, 2, "inductor", [False] * 4, False),
(None, 3, "inductor", [False] * 4, False),
# Explicitly enabling/disabling
#
# Default: all
#
# All but SiluAndMul
("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1, 1], True),
# Only ReLU3
("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
("none,-rms_norm,+relu3", 1, "eager", [0, 0, 0, 1], False),
# All but SiluAndMul
("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
("all,-silu_and_mul", 2, "inductor", [1, 0, 1, 1], True),
# All but ReLU3 (even if ReLU2 is on)
("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True),
# RMSNorm and SiluAndMul
("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0, 0], False),
# All but RMSNorm
("-rms_norm", 3, False, [0, 1, 1, 1], True),
("-rms_norm", 3, "eager", [0, 1, 1, 1], True),
#
# Default: none
#
# Only ReLU3
("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
("none,+relu3", 3, "inductor", [0, 0, 0, 1], False),
# All but RMSNorm
("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
("all,-rms_norm", 3, "inductor", [0, 1, 1, 1], True),
],
)
def test_enabled_ops(
env: str | None,
torch_level: int,
use_inductor: bool,
backend: str,
ops_enabled: list[int],
default_on: bool,
):
custom_ops = env.split(",") if env else []
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops
backend=backend, level=torch_level, custom_ops=custom_ops
)
)
with set_current_vllm_config(vllm_config):

View File

@@ -41,7 +41,7 @@ logger = init_logger(__name__)
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
if compilation_config.use_inductor:
if compilation_config.backend == "inductor":
# Use standalone compile only if requested, version is new enough,
# and the symbol actually exists in this PyTorch build.
if (
@@ -55,6 +55,10 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
logger.debug("Using InductorAdaptor")
return InductorAdaptor()
else:
assert compilation_config.backend == "eager", (
"Custom backends not supported with CompilationLevel.PIECEWISE"
)
logger.debug("Using EagerAdaptor")
return EagerAdaptor()

View File

@@ -15,6 +15,7 @@ from pydantic.dataclasses import dataclass
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
if TYPE_CHECKING:
@@ -187,7 +188,8 @@ class CompilationConfig:
backend: str = ""
"""The backend for compilation. It needs to be a string:
- "" (empty string): use the default backend.
- "" (empty string): use the default backend ("inductor" on CUDA-alike
platforms).
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the
@@ -196,7 +198,12 @@ class CompilationConfig:
distributed setting. When the compilation level is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation level is 3, the backend is used for the piecewise compilation
(it sees a part of the graph)."""
(it sees a part of the graph). The backend cannot be custom for compilation
level 3, i.e. the backend must be either eager or inductor. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and
use_inductor_cudagraphs_partition is off. Note that the default options for
splitting ops are sufficient for piecewise compilation.
"""
custom_ops: list[str] = field(default_factory=list)
"""Fine-grained control over which custom ops to enable/disable. Use 'all'
to enable all, 'none' to disable all. Also specify a list of custom op
@@ -229,8 +236,12 @@ class CompilationConfig:
If empty list [], no ops are excluded (suitable for full cudagraphs)."""
# Inductor capture
use_inductor: bool = True
"""Whether to use inductor compilation:
use_inductor: bool | None = None
"""
Whether to use inductor compilation.
This flag is deprecated and will be removed in the next release (0.12.0).
Please use the 'backend' option instead.
- False: inductor compilation is not used. graph runs in eager
(custom_ops enabled by default).
@@ -238,7 +249,11 @@ class CompilationConfig:
One graph for symbolic shape and one graph per size in compile_sizes
are compiled using configurations in inductor_compile_config.
This setting is ignored if level<PIECEWISE."""
This setting is ignored if level<PIECEWISE.
For future compatibility:
If use_inductor is True, backend="inductor"; otherwise backend="eager".
"""
compile_sizes: list[int | str] | None = None
"""Sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
@@ -545,7 +560,43 @@ class CompilationConfig:
"(where 'op' is the registered op name)"
)
# Currently only the eager and inductor backends are supported
# for piecewise compilation. Custom backends are not supported for
# piecewise compilation. Update when more backends are supported.
if self.level == CompilationLevel.PIECEWISE and self.backend not in [
"",
"eager",
"inductor",
]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
)
if self.use_inductor is not None:
logger.warning_once(
"The 'use_inductor' flag is deprecated and will be "
"removed in the next release (v0.12.0). "
"Please use the 'backend' option instead.",
)
self.backend = "inductor" if self.use_inductor else "eager"
if self.backend == "":
self.backend = current_platform.simple_compile_backend
def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
"""
Initialize the backend for the compilation config from a vllm config.
Arguments:
vllm_config: The vllm config to initialize the backend from.
Returns:
The backend for the compilation config.
"""
if self.level is None:
raise ValueError(
"No compilation level is set. This method should only be \
called via vllm config where the level is set if none is \
provided."
)
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
@@ -553,15 +604,15 @@ class CompilationConfig:
torch_backends = list_backends(exclude_tags=tuple())
if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
if self.backend == "":
return "eager"
if self.backend in torch_backends:
return self.backend
return resolve_obj_by_qualname(self.backend)
# TODO: pass user-specified backend to piecewise compilation
# merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE
if self.backend not in ["eager", "inductor"]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
)
from vllm.compilation.backends import VllmBackend
@@ -710,7 +761,9 @@ class CompilationConfig:
return self.level == CompilationLevel.PIECEWISE
# Inductor partition case
return self.level > CompilationLevel.NO_COMPILATION and self.use_inductor
return (
self.backend == "inductor" and self.level > CompilationLevel.NO_COMPILATION
)
def custom_op_log_check(self):
"""

View File

@@ -322,6 +322,20 @@ class VllmConfig:
# NB: Passing both --enforce-eager and a compilation level
# in V0 means the compilation level wins out.
self.compilation_config.level = CompilationLevel.NO_COMPILATION
else:
assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION
assert self.compilation_config.level <= CompilationLevel.PIECEWISE
# If the user does not set custom ops via "none" or "all", set the default
# here based on the compilation level and backend.
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
if (
self.compilation_config.backend == "inductor"
and self.compilation_config.level > CompilationLevel.NO_COMPILATION
):
self.compilation_config.custom_ops.append("none")
else:
self.compilation_config.custom_ops.append("all")
# async tp is built on top of sequence parallelism
# and requires it to be enabled.
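The net effect of the new block is that `custom_ops` always ends up carrying an explicit global default. A rough paraphrase as a hypothetical helper (not the literal vLLM code):

def resolve_custom_ops_default(backend: str, level: int, custom_ops: list[str]) -> list[str]:
    # An explicit user-provided "all"/"none" wins; otherwise custom ops
    # default to off ("none") only when compiling with inductor, and to
    # on ("all") in every other case.
    if "all" in custom_ops or "none" in custom_ops:
        return custom_ops
    if backend == "inductor" and level > 0:  # 0 == CompilationLevel.NO_COMPILATION
        return custom_ops + ["none"]
    return custom_ops + ["all"]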

View File

@@ -113,7 +113,9 @@ class CustomOp(nn.Module):
custom_ops = compilation_config.custom_ops
if not hasattr(cls, "name"):
logger.warning_once(
"Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.", # noqa: E501
"Custom op %s was not registered, which means it won't appear "
"in the op registry. It will be enabled/disabled based on the "
"global settings.",
cls.__name__,
)
return CustomOp.default_on()
@@ -127,19 +129,17 @@ class CustomOp(nn.Module):
@staticmethod
def default_on() -> bool:
"""
On by default if PyTorch Inductor is not used.
Specifying 'all' or 'none' in custom_op takes precedence.
Behavior controlled by `CompilationConfig.custom_ops`: On by default if
'all', off by default if 'none'.
When PyTorch Inductor is used, 'none' is the default value,
otherwise 'all'.
"""
from vllm.config import CompilationLevel
compilation_config = get_cached_compilation_config()
default_on = (
compilation_config.level < CompilationLevel.PIECEWISE
or not compilation_config.use_inductor
)
count_none = compilation_config.custom_ops.count("none")
count_all = compilation_config.custom_ops.count("all")
return default_on and not count_none > 0 or count_all > 0
assert count_none + count_all == 1
return not count_none > 0 or count_all > 0
# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
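Because `VllmConfig` now fills in an explicit "all" or "none" when the user did not, `default_on()` can assert that exactly one of the two is present and read the global switch from it; per-op "+name"/"-name" entries still override it, as in the test table earlier. A small configuration sketch mirroring that test setup (op names are taken from the table above):

from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

vllm_config = VllmConfig(
    compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        backend="inductor",
        # Global default off, but keep the rms_norm custom op enabled.
        custom_ops=["none", "+rms_norm"],
    )
)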

View File

@@ -275,8 +275,6 @@ class CpuPlatform(Platform):
"epilogue_fusion": True,
}
)
if compilation_config.use_inductor:
compilation_config.custom_ops = ["none"]
if vllm_config.lora_config is not None:
compilation_config.level = CompilationLevel.NO_COMPILATION