[FrontEnd] UNREVERT CompilationConfig overhaul (#20283): deprecate use_inductor in favor of backend, simplify custom_ops (#26502)
Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: Morrison Turnansky <mturnans@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
committed by GitHub
parent 7200a21cd1
commit e3fdb627d9
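For illustration, a minimal sketch of the migration this commit describes (not taken from the diff itself; it assumes CompilationConfig and CompilationLevel are imported from vllm.config, as in the tests touched below):

    from vllm.config import CompilationConfig, CompilationLevel

    # Deprecated spelling: per the __post_init__ change in this commit, this now
    # warns and is rewritten to backend="inductor" ("eager" if use_inductor=False).
    old_style = CompilationConfig(level=CompilationLevel.PIECEWISE, use_inductor=True)

    # New spelling: name the torch.compile backend directly.
    new_style = CompilationConfig(level=CompilationLevel.PIECEWISE, backend="inductor")

    assert old_style.backend == new_style.backend == "inductor"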
@@ -258,13 +258,13 @@ def tractable_computation(

 @torch.inference_mode
 def run_model(
-    llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
+    llama_config, use_compile: bool, backend: str, split_attn: bool = False
 ) -> torch.Tensor:
     if use_compile:
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            use_inductor=use_inductor,
+            backend=backend,
             cudagraph_capture_sizes=[1, 2],
         )
         if split_attn:
@@ -338,8 +338,8 @@ def run_model(
     return output.cpu()


-@pytest.mark.parametrize("use_inductor", [True, False])
-def test_toy_llama(use_inductor: bool):
+@pytest.mark.parametrize("backend", ["inductor", "eager"])
+def test_toy_llama(backend: str):
     # compare output with and without piecewise compilation

     llama_config = LlamaConfig(
@@ -358,10 +358,10 @@ def test_toy_llama(use_inductor: bool):
         num_backend_compilations=0,
         num_cudagraph_captured=0,
     ):
-        outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
-        run_model(tractable_config, use_inductor=False, use_compile=False)
+        outputs.append(run_model(llama_config, backend="eager", use_compile=False))
+        run_model(tractable_config, backend="eager", use_compile=False)

-    if use_inductor:
+    if backend == "inductor":
         kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
     else:
         kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
@@ -377,10 +377,8 @@ def test_toy_llama(use_inductor: bool):
         num_cudagraph_captured=2,
         **kwargs,
     ):
-        outputs.append(
-            run_model(llama_config, use_inductor=use_inductor, use_compile=True)
-        )
-        run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
+        outputs.append(run_model(llama_config, backend=backend, use_compile=True))
+        run_model(tractable_config, backend=backend, use_compile=True)

     with compilation_counter.expect(
         num_graphs_seen=1,  # one graph for the model
@@ -395,16 +393,9 @@ def test_toy_llama(use_inductor: bool):
         ),  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
         outputs.append(
-            run_model(
-                llama_config,
-                use_inductor=use_inductor,
-                use_compile=True,
-                split_attn=True,
-            )
+            run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
         )
-        run_model(
-            tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
-        )
+        run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)

     for i in range(1, len(outputs)):
         assert torch.allclose(outputs[0], outputs[i])
@@ -77,14 +77,15 @@ class TestSetting:
             method="encode",
         ),
         # vision language model
-        TestSetting(
-            model="microsoft/Phi-3.5-vision-instruct",
-            model_args=["--trust-remote-code", "--max-model-len", "2048"],
-            pp_size=2,
-            tp_size=1,
-            attn_backend="FLASH_ATTN",
-            method="generate_with_image",
-        ),
+        # See https://github.com/vllm-project/vllm/issues/26716.
+        # TestSetting(
+        #     model="microsoft/Phi-3.5-vision-instruct",
+        #     model_args=["--trust-remote-code", "--max-model-len", "2048"],
+        #     pp_size=2,
+        #     tp_size=1,
+        #     attn_backend="FLASH_ATTN",
+        #     method="generate_with_image",
+        # ),
     ],
 )
 def test_compile_correctness(
@@ -109,41 +110,46 @@ def test_compile_correctness(
     with monkeypatch.context() as m:
         m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
         final_args = [
             "--enforce-eager",
             *model_args,
             "-pp",
             str(pp_size),
             "-tp",
             str(tp_size),
             "-O.cudagraph_mode=none",
         ]

         all_args: list[list[str]] = []
         all_envs: list[dict[str, str] | None] = []

-        for level in [
-            CompilationLevel.NO_COMPILATION,
+        for comp_level in [
             CompilationLevel.DYNAMO_AS_IS,
             CompilationLevel.DYNAMO_ONCE,
             CompilationLevel.PIECEWISE,
         ]:
-            all_args.append(final_args + [f"-O{level}"])
-            all_envs.append({})
-
-        # inductor will change the output, so we only compare if the output
-        # is close, not exactly the same.
-        compare_all_settings(
-            model,
-            all_args,
-            all_envs,
-            method=method if method != "generate" else "generate_close",
-        )
-        all_envs.clear()
-        all_args.clear()
+            for level in [CompilationLevel.NO_COMPILATION, comp_level]:
+                all_args.append(
+                    final_args + [f"-O.level={level}", "-O.backend=inductor"]
+                )
+                all_envs.append({})
+
+            # inductor will change the output, so we only compare if the output
+            # is close, not exactly the same.
+            compare_all_settings(
+                model,
+                all_args,
+                all_envs,
+                method=method if method != "generate" else "generate_close",
+            )
+            all_envs.clear()
+            all_args.clear()

         for level in [
             CompilationLevel.NO_COMPILATION,
             CompilationLevel.DYNAMO_AS_IS,
             CompilationLevel.DYNAMO_ONCE,
             CompilationLevel.PIECEWISE,
         ]:
-            all_args.append(final_args + [f"-O{level}"])
+            all_args.append(final_args + [f"-O.level={level}", "-O.backend=eager"])
             all_envs.append({})

         compare_all_settings(model, all_args * 3, all_envs, method=method)
@@ -36,55 +36,56 @@ class Relu3(ReLUSquaredActivation):


 @pytest.mark.parametrize(
-    "env, torch_level, use_inductor, ops_enabled, default_on",
+    "env, torch_level, backend, ops_enabled, default_on",
     [
         # Default values based on compile level
         # - All by default (no Inductor compilation)
-        (None, 0, False, [True] * 4, True),
-        (None, 1, True, [True] * 4, True),
-        (None, 2, False, [True] * 4, True),
+        (None, 0, "eager", [True] * 4, True),
+        (None, 1, "eager", [True] * 4, True),
+        (None, 2, "eager", [True] * 4, True),
+        (None, 3, "eager", [True] * 4, True),
-        # - None by default (with Inductor)
-        (None, 3, True, [False] * 4, False),
-        (None, 4, True, [False] * 4, False),
-        # - All by default (without Inductor)
-        (None, 3, False, [True] * 4, True),
-        (None, 4, False, [True] * 4, True),
+        (None, 0, "inductor", [True] * 4, True),
+        # - None by default (with Inductor)
+        (None, 1, "inductor", [False] * 4, False),
+        (None, 2, "inductor", [False] * 4, False),
+        (None, 3, "inductor", [False] * 4, False),
         # Explicitly enabling/disabling
         #
         # Default: all
         #
         # All but SiluAndMul
-        ("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
+        ("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1, 1], True),
         # Only ReLU3
-        ("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
+        ("none,-rms_norm,+relu3", 1, "eager", [0, 0, 0, 1], False),
         # All but SiluAndMul
-        ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
+        ("all,-silu_and_mul", 2, "inductor", [1, 0, 1, 1], True),
         # All but ReLU3 (even if ReLU2 is on)
-        ("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
+        ("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True),
         # RMSNorm and SiluAndMul
-        ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
+        ("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0, 0], False),
         # All but RMSNorm
-        ("-rms_norm", 3, False, [0, 1, 1, 1], True),
+        ("-rms_norm", 3, "eager", [0, 1, 1, 1], True),
         #
         # Default: none
         #
         # Only ReLU3
-        ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
+        ("none,+relu3", 3, "inductor", [0, 0, 0, 1], False),
         # All but RMSNorm
-        ("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
+        ("all,-rms_norm", 3, "inductor", [0, 1, 1, 1], True),
     ],
 )
 def test_enabled_ops(
     env: str | None,
     torch_level: int,
-    use_inductor: bool,
+    backend: str,
     ops_enabled: list[int],
     default_on: bool,
 ):
     custom_ops = env.split(",") if env else []
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops
+            backend=backend, level=torch_level, custom_ops=custom_ops
         )
     )
     with set_current_vllm_config(vllm_config):
@@ -41,7 +41,7 @@ logger = init_logger(__name__)


 def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
-    if compilation_config.use_inductor:
+    if compilation_config.backend == "inductor":
         # Use standalone compile only if requested, version is new enough,
         # and the symbol actually exists in this PyTorch build.
         if (
@@ -55,6 +55,10 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
             logger.debug("Using InductorAdaptor")
             return InductorAdaptor()
     else:
+        assert compilation_config.backend == "eager", (
+            "Custom backends not supported with CompilationLevel.PIECEWISE"
+        )
+
         logger.debug("Using EagerAdaptor")
         return EagerAdaptor()
@@ -15,6 +15,7 @@ from pydantic.dataclasses import dataclass
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
 from vllm.config.utils import config
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname

 if TYPE_CHECKING:
@@ -187,7 +188,8 @@ class CompilationConfig:
     backend: str = ""
     """The backend for compilation. It needs to be a string:

-    - "" (empty string): use the default backend.
+    - "" (empty string): use the default backend ("inductor" on CUDA-alike
+      platforms).
     - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
     - "full.module.name": a qualified name which can be used to import the
@@ -196,7 +198,12 @@
     distributed setting. When the compilation level is 1 or 2, the backend is
     used for the compilation directly (it sees the whole graph). When the
     compilation level is 3, the backend is used for the piecewise compilation
-    (it sees a part of the graph)."""
+    (it sees a part of the graph). The backend can not be custom for compilation
+    level 3, i.e. the backend must be either eager or inductor. Furthermore,
+    compilation is only piecewise if splitting ops is set accordingly and
+    use_inductor_cudagraphs_partition is off. Note that the default options for
+    splitting ops are sufficient for piecewise compilation.
+    """
     custom_ops: list[str] = field(default_factory=list)
     """Fine-grained control over which custom ops to enable/disable. Use 'all'
     to enable all, 'none' to disable all. Also specify a list of custom op
@@ -229,8 +236,12 @@
     If empty list [], no ops are excluded (suitable for full cudagraphs)."""

     # Inductor capture
-    use_inductor: bool = True
-    """Whether to use inductor compilation:
+    use_inductor: bool | None = None
+    """
+    Whether to use inductor compilation.
+
+    This flag is deprecated and will be removed in the next release 0.12.0.
+    Please use the 'backend' option instead.

     - False: inductor compilation is not used. graph runs in eager
       (custom_ops enabled by default).
@@ -238,7 +249,11 @@
     One graph for symbolic shape and one graph per size in compile_sizes
     are compiled using configurations in inductor_compile_config.

-    This setting is ignored if level<PIECEWISE."""
+    This setting is ignored if level<PIECEWISE.
+
+    For future compatibility:
+    If use_inductor is True, backend="inductor" otherwise backend="eager".
+    """
     compile_sizes: list[int | str] | None = None
     """Sizes to compile for inductor. In addition
     to integers, it also supports "cudagraph_capture_sizes" to
@@ -545,7 +560,43 @@ class CompilationConfig:
                 "(where 'op' is the registered op name)"
             )

+        # Currently only eager and inductor backend are supported.
+        # for piecewise compilation. Custom backends are not suppported for
+        # piecewise compilation. Update when more backends are supported.
+        if self.level == CompilationLevel.PIECEWISE and self.backend not in [
+            "",
+            "eager",
+            "inductor",
+        ]:
+            raise ValueError(
+                f"Invalid backend for piecewise compilation: {self.backend}"
+            )
+
+        if self.use_inductor is not None:
+            logger.warning_once(
+                "The 'use_inductor' flag is deprecated and will be "
+                "removed in the next release (v0.12.0). "
+                "Please use the 'backend' option instead.",
+            )
+            self.backend = "inductor" if self.use_inductor else "eager"
+
+        if self.backend == "":
+            self.backend = current_platform.simple_compile_backend
+
     def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
+        """
+        Initialize the backend for the compilation config from a vllm config.
+        Arguments:
+            vllm_config: The vllm config to initialize the backend from.
+        Returns:
+            The backend for the compilation config.
+        """
+        if self.level is None:
+            raise ValueError(
+                "No compilation level is set. This method should only be \
+                called via vllm config where the level is set if none is \
+                provided."
+            )
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")
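For illustration, a minimal sketch of the backend validation and defaulting added to __post_init__ above (not taken from the diff; it assumes a CUDA-alike platform whose simple_compile_backend is "inductor", per the docstring earlier in this file):

    from vllm.config import CompilationConfig, CompilationLevel

    # A backend other than ""/"eager"/"inductor" is rejected at PIECEWISE level.
    try:
        CompilationConfig(level=CompilationLevel.PIECEWISE, backend="openxla")
    except ValueError as err:
        print(err)  # Invalid backend for piecewise compilation: openxla

    # An empty backend falls back to the platform default.
    cfg = CompilationConfig(level=CompilationLevel.PIECEWISE)
    print(cfg.backend)  # e.g. "inductor" on CUDA-alike platforms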
@@ -553,15 +604,15 @@ class CompilationConfig:

         torch_backends = list_backends(exclude_tags=tuple())
         if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
-            if self.backend == "":
-                return "eager"
             if self.backend in torch_backends:
                 return self.backend
             return resolve_obj_by_qualname(self.backend)

-        # TODO: pass user-specified backend to piecewise compilation
-        # merge with the config use_inductor
         assert self.level == CompilationLevel.PIECEWISE
+        if self.backend not in ["eager", "inductor"]:
+            raise ValueError(
+                f"Invalid backend for piecewise compilation: {self.backend}"
+            )

         from vllm.compilation.backends import VllmBackend

@@ -710,7 +761,9 @@ class CompilationConfig:
             return self.level == CompilationLevel.PIECEWISE

         # Inductor partition case
-        return self.level > CompilationLevel.NO_COMPILATION and self.use_inductor
+        return (
+            self.backend == "inductor" and self.level > CompilationLevel.NO_COMPILATION
+        )

     def custom_op_log_check(self):
         """
@@ -322,6 +322,20 @@ class VllmConfig:
             # NB: Passing both --enforce-eager and a compilation level
             # in V0 means the compilation level wins out.
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
+        else:
+            assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION
+            assert self.compilation_config.level <= CompilationLevel.PIECEWISE
+
+        # If user does not set custom ops via none or all set it here based on
+        # compilation level and backend.
+        if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
+            if (
+                self.compilation_config.backend == "inductor"
+                and self.compilation_config.level > CompilationLevel.NO_COMPILATION
+            ):
+                self.compilation_config.custom_ops.append("none")
+            else:
+                self.compilation_config.custom_ops.append("all")

         # async tp is built on top of sequence parallelism
         # and requires it to be enabled.
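For illustration, a minimal sketch of the custom-ops defaulting added above, mirroring the updated test_enabled_ops test (not taken from the diff): when the user passes neither "all" nor "none", VllmConfig fills one in from the backend and compilation level.

    from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE, backend="inductor"
        )
    )
    # Inductor backend with compilation enabled => custom ops off by default.
    assert "none" in vllm_config.compilation_config.custom_ops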
@@ -113,7 +113,9 @@ class CustomOp(nn.Module):
         custom_ops = compilation_config.custom_ops
         if not hasattr(cls, "name"):
             logger.warning_once(
-                "Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.",  # noqa: E501
+                "Custom op %s was not registered, which means it won't appear "
+                "in the op registry. It will be enabled/disabled based on the "
+                "global settings.",
                 cls.__name__,
             )
             return CustomOp.default_on()
@@ -127,19 +129,17 @@ class CustomOp(nn.Module):
     @staticmethod
     def default_on() -> bool:
         """
-        On by default if PyTorch Inductor is not used.
-        Specifying 'all' or 'none' in custom_op takes precedence.
+        Behavior controlled by `CompilationConfig.custom_ops`: On by default if
+        'all', off by default if 'none'.
+        When PyTorch Inductor is used, 'none' is the default value,
+        otherwise 'all'.
         """
-        from vllm.config import CompilationLevel
-
         compilation_config = get_cached_compilation_config()
-        default_on = (
-            compilation_config.level < CompilationLevel.PIECEWISE
-            or not compilation_config.use_inductor
-        )
         count_none = compilation_config.custom_ops.count("none")
         count_all = compilation_config.custom_ops.count("all")
-        return default_on and not count_none > 0 or count_all > 0
+        assert count_none + count_all == 1
+
+        return not count_none > 0 or count_all > 0

     # Dictionary of all custom ops (classes, indexed by registered name).
     # To check if an op with a name is enabled, call .enabled() on the class.
@@ -275,8 +275,6 @@ class CpuPlatform(Platform):
                 "epilogue_fusion": True,
             }
         )
-        if compilation_config.use_inductor:
-            compilation_config.custom_ops = ["none"]

         if vllm_config.lora_config is not None:
             compilation_config.level = CompilationLevel.NO_COMPILATION