Mirror of https://github.com/vllm-project/vllm.git
[torch.compile] Fix tests for torch==2.9 inductor partition (#26116)
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
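Note on the change under test: `CompilationConfig.use_inductor_graph_partition` moves the piecewise CUDA-graph split from Dynamo-level FX graph splitting (via `splitting_ops`) into Inductor's own graph partitioning, and it is only available on torch >= 2.9. A minimal sketch of toggling it, using only names that appear in the diffs below (the exact import surface is an assumption):

    # Sketch only; CompilationConfig/CompilationLevel/CUDAGraphMode are assumed
    # importable from vllm.config, as the tests below use them.
    from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode
    from vllm.utils import is_torch_equal_or_newer

    def make_piecewise_config(use_inductor_graph_partition: bool) -> CompilationConfig:
        # Inductor-level partitioning needs torch >= 2.9, mirroring the skips below.
        if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
            raise RuntimeError("Inductor graph partition requires torch >= 2.9")
        return CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            backend="inductor",
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
            cudagraph_capture_sizes=[1, 2],
            splitting_ops=["silly::attention"],  # test-only op registered by the test helpers
            use_inductor_graph_partition=use_inductor_graph_partition,
        )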
@@ -11,6 +11,7 @@ from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer


@contextlib.contextmanager
@@ -32,13 +33,13 @@ def temporary_environ(env_vars):
os.environ[k] = v


test_params_full_cudagraph = []
model_backends_full_cudagraph = []

# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
for mla_backend in MLA_backends:
test_params_full_cudagraph.append(
pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))
model_backends_full_cudagraph.append(
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])
)

# Qwen/Qwen2-1.5B-Instruct with other backends
@@ -46,14 +47,18 @@ other_backend_configs = [
backend_configs[c] for c in backend_configs if c not in MLA_backends
]
for backend_config in other_backend_configs:
test_params_full_cudagraph.append(
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))
)
model_backends_full_cudagraph.append(("Qwen/Qwen2-1.5B-Instruct", backend_config))


@pytest.fixture(scope="class")
def llm_pair(request):
model, backend_config = request.param
model, backend_config, use_inductor_graph_partition = request.param
backend_config.comp_config["use_inductor_graph_partition"] = (
use_inductor_graph_partition
)

if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition only supported in torch>=2.9")

# Dynamically skip test if GPU capability is not met
if (
@@ -104,7 +109,15 @@ def llm_pair(request):
)


@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
@pytest.mark.parametrize(
"llm_pair",
[
pytest.param((model, backend_config, use_inductor_graph_partition))
for model, backend_config in model_backends_full_cudagraph
for use_inductor_graph_partition in [True, False]
],
indirect=True,
)
class TestFullCUDAGraph:
"""
Use a class such that an llm pair is constructed once for all
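A note on the fixture above: `llm_pair` is now parametrized indirectly with a 3-tuple, so every `(model, backend_config)` entry in `model_backends_full_cudagraph` runs both with and without Inductor graph partitioning, and unsupported combinations skip themselves inside the fixture rather than in each test. A stripped-down sketch of the same pattern, with hypothetical stand-ins (`engine_pair`, `make_engine`) in place of the real LLM construction:

    import pytest

    def make_engine(model: str, config: dict, use_inductor_graph_partition: bool):
        # Hypothetical stand-in for building the real engine with a merged config.
        return (model, {**config, "use_inductor_graph_partition": use_inductor_graph_partition})

    @pytest.fixture(scope="class")
    def engine_pair(request):
        # indirect=True routes the parametrized tuple here via request.param.
        model, config, use_inductor_graph_partition = request.param
        full = make_engine(model, config, use_inductor_graph_partition)
        baseline = make_engine(model, config, False)
        yield full, baseline

    @pytest.mark.parametrize(
        "engine_pair",
        [
            pytest.param((model, cfg, partition))
            for model, cfg in [("model-a", {}), ("model-b", {})]
            for partition in [True, False]
        ],
        indirect=True,
    )
    class TestSketch:
        def test_same_model(self, engine_pair):
            full, baseline = engine_pair
            assert full[0] == baseline[0]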
@@ -5,6 +5,7 @@ Test (piecewise) compilation with a simple model where multiple submodules
are compiled and graph captured separately.
"""

import pytest
import torch
from torch import nn

@@ -190,7 +191,12 @@ def run_model(
return output.cpu()


def test_multi_graph_piecewise_compile_outputs_equal():
@pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
if use_inductor_graph_partition:
# FIXME(luka/boyuan): this currently fails
pytest.skip("Inductor graph partition not supported with multi-graph")

outputs = []

# piecewise compile
@@ -200,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -220,16 +227,24 @@ def test_multi_graph_piecewise_compile_outputs_equal():
# static tensor addresses
inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()

with compilation_counter.expect(
num_graphs_seen=2, # two graphs for the model
num_piecewise_graphs_seen=6,
if use_inductor_graph_partition:
# Splitting happens at Inductor lowering level,
# total piecewise fx graphs is equal to total graphs
num_piecewise_fx = 2
num_piecewise_capturable_fx = 2
else:
# attn_one, attn_two each has 3 piecewise graphs
# (pre attn, post attn, silly_attention) each
num_piecewise_capturable_graphs_seen=4,
num_piecewise_fx = 6
# attn_one, attn_two has pre attn and post attn each, total=4
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_piecewise_capturable_fx = 4

with compilation_counter.expect(
num_graphs_seen=2, # two graphs for the model
num_piecewise_graphs_seen=num_piecewise_fx,
num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
num_backend_compilations=num_piecewise_capturable_fx,
num_cudagraph_captured=8, # num_cudagraph_sizes * num_partitions
):
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))

@@ -268,6 +283,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
level=CompilationLevel.PIECEWISE,
use_cudagraph=False,
splitting_ops=["silly::attention"],
use_inductor_graph_partition=use_inductor_graph_partition,
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -286,9 +302,9 @@ def test_multi_graph_piecewise_compile_outputs_equal():

with compilation_counter.expect(
num_graphs_seen=2,
num_piecewise_graphs_seen=6,
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4,
num_piecewise_graphs_seen=num_piecewise_fx,
num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
num_backend_compilations=num_piecewise_capturable_fx,
num_cudagraph_captured=0, # no cudagraph captured
):
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
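To make the expected counters above concrete: with Dynamo-level splitting, each of the two FX graphs is split around its single `silly::attention` call into pre-attention, attention, and post-attention pieces, giving 6 piecewise graphs of which 4 are CUDA-graph-capturable; with Inductor-level partitioning the FX graphs are left whole, so the piecewise FX count equals the number of graphs (2), while the number of captured CUDA graphs stays 8 because 2 capture sizes times 4 runtime partitions are still recorded. A small sketch of that arithmetic (pure illustration, not vLLM code):

    def expected_counters(num_graphs: int, attn_calls_per_graph: int,
                          num_capture_sizes: int, inductor_partition: bool) -> dict:
        """Illustrative arithmetic for the counters asserted in the test above."""
        if inductor_partition:
            # FX graphs are not split; Inductor partitions at lowering time.
            piecewise_fx = num_graphs
            capturable_fx = num_graphs
        else:
            # Each attention call splits its graph into attn + pre/post pieces.
            piecewise_fx = num_graphs * (2 * attn_calls_per_graph + 1)
            capturable_fx = num_graphs * (attn_calls_per_graph + 1)
        # Runtime partitions captured per capture size are the same either way.
        runtime_partitions = num_graphs * (attn_calls_per_graph + 1)
        return {
            "num_piecewise_graphs_seen": piecewise_fx,
            "num_piecewise_capturable_graphs_seen": capturable_fx,
            "num_backend_compilations": capturable_fx,
            "num_cudagraph_captured": num_capture_sizes * runtime_partitions,
        }

    # Dynamo split: 6 piecewise, 4 capturable, 8 cudagraphs captured.
    print(expected_counters(2, 1, 2, inductor_partition=False))
    # Inductor split: 2 piecewise, 2 capturable, still 8 cudagraphs captured.
    print(expected_counters(2, 1, 2, inductor_partition=True))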
@@ -9,6 +9,7 @@ if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""

from copy import deepcopy
from dataclasses import dataclass
from typing import Any

@@ -26,6 +27,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer

# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention # noqa: F401
@@ -257,27 +259,13 @@ def tractable_computation(


@torch.inference_mode
def run_model(
llama_config, use_compile: bool, backend: str, split_attn: bool = False
) -> torch.Tensor:
if use_compile:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
backend=backend,
cudagraph_capture_sizes=[1, 2],
)
if split_attn:
compilation_config.splitting_ops = ["silly::attention"]
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION,
)
cudagraph_runtime_mode = CUDAGraphMode.NONE
def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
# Start with a fresh copy to make sure there's no cache dir sharing
compile_config = deepcopy(compile_config)
cudagraph_runtime_mode = compile_config.cudagraph_mode

vllm_config = VllmConfig(
compilation_config=compilation_config, additional_config=llama_config
compilation_config=compile_config, additional_config=llama_config
)
with set_current_vllm_config(vllm_config):
model = (
@@ -338,8 +326,25 @@ def run_model(
return output.cpu()


@pytest.mark.parametrize("backend", ["inductor", "eager"])
def test_toy_llama(backend: str):
@pytest.mark.parametrize(
"backend, use_inductor_graph_partition",
[
("eager", False), # No inductor
("inductor", False), # Inductor, Dynamo partition
("inductor", True), # Inductor, Inductor partition
],
)
def test_toy_llama(
backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
):
# We disable the vLLM compile cache into a new tmp dir for 2 reasons:
# 1. To make sure we can properly track the number of Inductor compilations.
# 2. Inductor partitioning does not play nicely with Autograd cache (below)
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition only supported in torch>=2.9")

# compare output with and without piecewise compilation

llama_config = LlamaConfig(
@@ -350,6 +355,32 @@ def test_toy_llama(backend: str):
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
)

compile_config_no_compile = CompilationConfig(
level=CompilationLevel.NO_COMPILATION,
cudagraph_mode=CUDAGraphMode.NONE,
backend="eager",
)

compile_config_no_split = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
backend=backend,
cudagraph_capture_sizes=[1, 2],
)

# FIXME(luka/boyuan): the graph from the previous test case
# (no inductor partition) gets cached by AotAutograd so then the
# compilation with inductor partitioning incorrectly loads an unpartitioned
# graph and never partitions. I think this is a bug with custom inductor
# partitioning but does not affect vLLM more generally as vLLM uses its own
# cache (which takes inductor partitioning into account).
if use_inductor_graph_partition:
compile_config_no_split.inductor_compile_config["force_disable_caches"] = True

compile_config_split = deepcopy(compile_config_no_split)
compile_config_split.splitting_ops = ["silly::attention"]

outputs = []
with compilation_counter.expect(
num_graphs_seen=0,
@@ -358,8 +389,9 @@
num_backend_compilations=0,
num_cudagraph_captured=0,
):
outputs.append(run_model(llama_config, backend="eager", use_compile=False))
run_model(tractable_config, backend="eager", use_compile=False)
outputs.append(run_model(llama_config, compile_config_no_compile))

run_model(tractable_config, compile_config_no_compile)

if backend == "inductor":
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
@@ -367,35 +399,34 @@
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}

with compilation_counter.expect(
# One graph for the model
num_graphs_seen=1,
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=1,
num_piecewise_capturable_graphs_seen=1,
# num_piecewise_capturable_graphs_seen
num_backend_compilations=1,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=2,
**kwargs,
):
outputs.append(run_model(llama_config, backend=backend, use_compile=True))
run_model(tractable_config, backend=backend, use_compile=True)
outputs.append(run_model(llama_config, compile_config_no_split))

run_model(tractable_config, compile_config_no_split)

if use_inductor_graph_partition:
num_piecewise_fx = 1
num_piecewise_capturable_fx = 1
else:
num_piecewise_fx = 2 * llama_config.num_layers + 1
num_piecewise_capturable_fx = 1 + llama_config.num_layers

with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=2 * llama_config.num_layers + 1, # 2 * num_layers + 1
num_piecewise_capturable_graphs_seen=1
+ llama_config.num_layers, # 1 + num_layers
num_backend_compilations=1
+ llama_config.num_layers, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=2
* (
1 + llama_config.num_layers
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_piecewise_graphs_seen=num_piecewise_fx,
num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
num_backend_compilations=num_piecewise_capturable_fx,
# num_cudagraph_sizes * num_partitions
num_cudagraph_captured=2 * (1 + llama_config.num_layers),
):
outputs.append(
run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
)
run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
outputs.append(run_model(llama_config, compile_config_split))
run_model(tractable_config, compile_config_split)

for i in range(1, len(outputs)):
assert torch.allclose(outputs[0], outputs[i])
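The `run_model` refactor above replaces the `use_compile`/`backend`/`split_attn` flags with a single `CompilationConfig` argument: the CUDA-graph runtime mode is read from `compile_config.cudagraph_mode`, and the config is deep-copied so the three configs built in `test_toy_llama` never share an Inductor cache directory. For the `num_layers=2` config used above, Dynamo splitting yields 2*2+1 = 5 piecewise FX graphs and 1+2 = 3 capturable ones, versus 1 and 1 under Inductor partitioning, with 2*(1+2) = 6 CUDA graphs captured either way. A hedged sketch of the new call shape (the kwargs mirror the configs built in the test; anything beyond them is an assumption):

    from copy import deepcopy
    from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode

    piecewise = CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        backend="inductor",
        cudagraph_mode=CUDAGraphMode.PIECEWISE,
        cudagraph_capture_sizes=[1, 2],
        use_inductor_graph_partition=False,
    )
    split = deepcopy(piecewise)
    split.splitting_ops = ["silly::attention"]

    # Both calls can reuse the same config objects safely, because run_model
    # deep-copies its argument before building the VllmConfig:
    #   outputs.append(run_model(llama_config, piecewise))
    #   outputs.append(run_model(llama_config, split))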
@@ -62,5 +62,4 @@ direct_register_custom_op(
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
tags=(torch._C.Tag.cudagraph_unsafe,),
)
@@ -73,6 +73,7 @@ def test_ignore_torch_compile_decorator():
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=False, # TODO test both?
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -188,6 +189,7 @@ def test_conditional_compile_enable_if():
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=False, # TODO test both
),
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -220,6 +222,7 @@
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=False, # TODO test both?
),
)
@@ -38,10 +38,6 @@ from vllm.utils import GiB_bytes, direct_register_custom_op

logger = init_logger(__name__)
USE_XFORMERS_OPS = None
try:
tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe,)
except AttributeError:
tag_cudagraph_unsafe = () # type: ignore[assignment]


def check_xformers_availability():
@@ -879,7 +875,6 @@ direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
fake_impl=unified_attention_fake,
tags=tag_cudagraph_unsafe,
)


@@ -931,7 +926,6 @@ direct_register_custom_op(
op_func=unified_attention_with_output,
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
tags=tag_cudagraph_unsafe,
)
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import logging
from typing import TYPE_CHECKING

from torch._library.utils import lookup_op
@@ -38,8 +39,16 @@ def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
resolved.append(lookup_op(op_name))
except Exception:
# Skip operators that don't exist (e.g., model-specific ops)
logger.warning(
"Failed to resolve operator for Inductor partition: %s", op_name
# Do not warn for attention ops, warn for others
# (most likely manually specified)
from vllm.config import CompilationConfig

logger.log(
logging.DEBUG
if op_name in CompilationConfig._attention_ops
else logging.WARNING,
"Failed to resolve operator for CUDAGraph partition: %s",
op_name,
)
continue
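The `resolve_defined_ops` change above keeps skipping operators that cannot be looked up, but logs at DEBUG when the missing name is one of `CompilationConfig._attention_ops` (expected to be absent for some models) and at WARNING otherwise, since other names were most likely specified by hand. A hedged usage sketch (the module path and op names are assumptions for illustration):

    # Assumed import path for the function shown in this hunk.
    from vllm.compilation.partition_rules import resolve_defined_ops

    resolved = resolve_defined_ops([
        "silly::attention",       # resolves only if the test helper registering it was imported
        "mylib::does_not_exist",  # unknown, user-specified -> skipped and logged at WARNING
    ])
    # Names listed in CompilationConfig._attention_ops that fail to resolve are
    # skipped as well, but logged at DEBUG instead of WARNING.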
@@ -201,7 +201,7 @@ class CompilationConfig:
(it sees a part of the graph). The backend can not be custom for compilation
level 3, i.e. the backend must be either eager or inductor. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and
use_inductor_cudagraphs_partition is off. Note that the default options for
use_inductor_graph_partition is off. Note that the default options for
splitting ops are sufficient for piecewise compilation.
"""
custom_ops: list[str] = field(default_factory=list)
@@ -431,6 +431,7 @@ class CompilationConfig:
factors.append(self.custom_ops)
factors.append(self.splitting_ops)
factors.append(self.use_inductor)
factors.append(self.use_inductor_graph_partition)
factors.append(self.inductor_compile_config)
factors.append(self.inductor_passes)
factors.append(self.pass_config.uuid())
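Appending `use_inductor_graph_partition` to `factors` matters because these factors are hashed into the compilation cache key: two configs that differ only in partitioning mode must not reuse each other's compiled artifacts (the same concern the toy-llama test works around for the Inductor/AotAutograd caches). A toy illustration of factor-based keying, not the actual vLLM hashing code:

    import hashlib

    def cache_key(factors: list) -> str:
        # Toy stand-in: hash the factor list into a short cache key.
        return hashlib.sha256(repr(factors).encode()).hexdigest()[:16]

    base = [["none"], ["silly::attention"], True]  # custom_ops, splitting_ops, use_inductor
    # Including use_inductor_graph_partition changes the key, so the two modes
    # can never load each other's compiled artifacts.
    assert cache_key(base + [True]) != cache_key(base + [False])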