[torch.compile] Fix tests for torch==2.9 inductor partition (#26116)

Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Authored by Luka Govedič on 2025-10-14 19:55:02 -04:00, committed by GitHub
parent 579d2e5458
commit 2dcd12d357
8 changed files with 138 additions and 72 deletions

View File

@@ -11,6 +11,7 @@ from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
@contextlib.contextmanager
@@ -32,13 +33,13 @@ def temporary_environ(env_vars):
os.environ[k] = v
test_params_full_cudagraph = []
model_backends_full_cudagraph = []
# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
for mla_backend in MLA_backends:
test_params_full_cudagraph.append(
pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))
model_backends_full_cudagraph.append(
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])
)
# Qwen/Qwen2-1.5B-Instruct with other backends
@@ -46,14 +47,18 @@ other_backend_configs = [
backend_configs[c] for c in backend_configs if c not in MLA_backends
]
for backend_config in other_backend_configs:
test_params_full_cudagraph.append(
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))
)
model_backends_full_cudagraph.append(("Qwen/Qwen2-1.5B-Instruct", backend_config))
@pytest.fixture(scope="class")
def llm_pair(request):
model, backend_config = request.param
model, backend_config, use_inductor_graph_partition = request.param
backend_config.comp_config["use_inductor_graph_partition"] = (
use_inductor_graph_partition
)
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition only supported in torch>=2.9")
# Dynamically skip test if GPU capability is not met
if (
@@ -104,7 +109,15 @@ def llm_pair(request):
)
@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
@pytest.mark.parametrize(
"llm_pair",
[
pytest.param((model, backend_config, use_inductor_graph_partition))
for model, backend_config in model_backends_full_cudagraph
for use_inductor_graph_partition in [True, False]
],
indirect=True,
)
class TestFullCUDAGraph:
"""
Use a class such that an llm pair is constructed once for all
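For readers less familiar with the pytest pattern above: the test matrix is now built as plain (model, backend_config) tuples, expanded with a use_inductor_graph_partition flag at the parametrize site, and routed into the class-scoped llm_pair fixture via indirect=True. A minimal, self-contained sketch of that pattern follows; the model names, config dicts, and version helper are illustrative stand-ins, not vLLM's actual values:

import pytest


def _torch_at_least(min_version: str) -> bool:
    # Hypothetical stand-in for vllm.utils.is_torch_equal_or_newer.
    import torch
    from packaging.version import Version

    return Version(torch.__version__.split("+")[0]) >= Version(min_version)


# Illustrative (model, backend_config) pairs; the real list is built above.
MODEL_BACKENDS = [("some/model", {"attention_backend": "FLASH_ATTN"})]


@pytest.fixture(scope="class")
def llm_pair(request):
    # indirect=True routes each parametrize tuple here instead of to the test.
    model, backend_config, use_inductor_graph_partition = request.param
    if use_inductor_graph_partition and not _torch_at_least("2.9.0.dev"):
        pytest.skip("Inductor graph partition only supported in torch>=2.9")
    # ... build and yield the (full-cudagraph LLM, reference LLM) pair here ...
    yield model, backend_config


@pytest.mark.parametrize(
    "llm_pair",
    [
        pytest.param((model, cfg, use_partition))
        for model, cfg in MODEL_BACKENDS
        for use_partition in (True, False)
    ],
    indirect=True,
)
class TestSketch:
    def test_pair_is_constructed(self, llm_pair):
        assert llm_pair is not None

Class scope means the expensive LLM pair is constructed once per parametrization and shared by every test in the class.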

View File

@@ -5,6 +5,7 @@ Test (piecewise) compilation with a simple model where multiple submodules
are compiled and graph captured separately.
"""
import pytest
import torch
from torch import nn
@@ -190,7 +191,12 @@ def run_model(
return output.cpu()
def test_multi_graph_piecewise_compile_outputs_equal():
@pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
if use_inductor_graph_partition:
# FIXME(luka/boyuan): this currently fails
pytest.skip("Inductor graph partition not supported with multi-graph")
outputs = []
# piecewise compile
@@ -200,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -220,16 +227,24 @@ def test_multi_graph_piecewise_compile_outputs_equal():
# static tensor addresses
inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()
with compilation_counter.expect(
num_graphs_seen=2, # two graphs for the model
num_piecewise_graphs_seen=6,
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
if use_inductor_graph_partition:
# Splitting happens at the Inductor lowering level, so the total number
# of piecewise fx graphs equals the total number of graphs
num_piecewise_fx = 2
num_piecewise_capturable_fx = 2
else:
# attn_one and attn_two each have 3 piecewise graphs
# (pre attn, post attn, silly_attention)
num_piecewise_fx = 6
# attn_one and attn_two each have a pre-attn and a post-attn piece, total=4
num_piecewise_capturable_fx = 4
with compilation_counter.expect(
num_graphs_seen=2, # two graphs for the model
num_piecewise_graphs_seen=num_piecewise_fx,
num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
num_backend_compilations=num_piecewise_capturable_fx,
num_cudagraph_captured=8, # num_cudagraph_sizes * num_partitions
):
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
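Restating the expected counters above as standalone arithmetic may help; this is an illustrative sketch with invented variable names, not vLLM code:

num_models = 2            # attn_one and attn_two
attn_ops_per_model = 1    # one silly::attention call each
num_capture_sizes = 2     # cudagraph_capture_sizes=[1, 2]

# Dynamo-level splitting: each attention op splits its FX graph into
# pre-attn / attention / post-attn, and only the non-attention pieces
# are CUDA-graph capturable.
num_piecewise_fx = num_models * (2 * attn_ops_per_model + 1)          # 6
num_piecewise_capturable_fx = num_models * (attn_ops_per_model + 1)   # 4

# Inductor-level partitioning leaves the FX graphs whole, so both FX-level
# counters drop to the number of compiled graphs (2), but the runtime still
# produces 4 partitions, hence the same number of captured CUDA graphs.
num_cudagraph_captured = num_capture_sizes * 4                        # 8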
@@ -268,6 +283,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
level=CompilationLevel.PIECEWISE,
use_cudagraph=False,
splitting_ops=["silly::attention"],
use_inductor_graph_partition=use_inductor_graph_partition,
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -286,9 +302,9 @@ def test_multi_graph_piecewise_compile_outputs_equal():
with compilation_counter.expect(
num_graphs_seen=2,
num_piecewise_graphs_seen=6,
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4,
num_piecewise_graphs_seen=num_piecewise_fx,
num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
num_backend_compilations=num_piecewise_capturable_fx,
num_cudagraph_captured=0, # no cudagraph captured
):
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
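compilation_counter.expect is used throughout these tests as a delta assertion: snapshot the counters, run the body, then check that each named counter grew by exactly the given amount. A generic sketch of that pattern, not vLLM's actual implementation:

import contextlib
from dataclasses import dataclass, fields


@dataclass
class CompilationCounterSketch:
    num_graphs_seen: int = 0
    num_cudagraph_captured: int = 0

    @contextlib.contextmanager
    def expect(self, **expected_deltas: int):
        # Snapshot, run the body, then assert each counter grew by the expected amount.
        before = {f.name: getattr(self, f.name) for f in fields(self)}
        yield
        for name, delta in expected_deltas.items():
            actual = getattr(self, name) - before[name]
            assert actual == delta, f"{name}: expected +{delta}, got +{actual}"


counter = CompilationCounterSketch()
with counter.expect(num_graphs_seen=1):
    counter.num_graphs_seen += 1  # stands in for a graph being compiled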

View File

@@ -9,6 +9,7 @@ if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
from copy import deepcopy
from dataclasses import dataclass
from typing import Any
@@ -26,6 +27,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention # noqa: F401
@@ -257,27 +259,13 @@ def tractable_computation(
@torch.inference_mode
def run_model(
llama_config, use_compile: bool, backend: str, split_attn: bool = False
) -> torch.Tensor:
if use_compile:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
backend=backend,
cudagraph_capture_sizes=[1, 2],
)
if split_attn:
compilation_config.splitting_ops = ["silly::attention"]
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION,
)
cudagraph_runtime_mode = CUDAGraphMode.NONE
def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
# Start with a fresh copy to make sure there's no cache dir sharing
compile_config = deepcopy(compile_config)
cudagraph_runtime_mode = compile_config.cudagraph_mode
vllm_config = VllmConfig(
compilation_config=compilation_config, additional_config=llama_config
compilation_config=compile_config, additional_config=llama_config
)
with set_current_vllm_config(vllm_config):
model = (
@@ -338,8 +326,25 @@ def run_model(
return output.cpu()
@pytest.mark.parametrize("backend", ["inductor", "eager"])
def test_toy_llama(backend: str):
@pytest.mark.parametrize(
"backend, use_inductor_graph_partition",
[
("eager", False), # No inductor
("inductor", False), # Inductor, Dynamo partition
("inductor", True), # Inductor, Inductor partition
],
)
def test_toy_llama(
backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
):
# We disable the vLLM compile cache and compile into a fresh tmp dir for 2 reasons:
# 1. To make sure we can properly track the number of Inductor compilations.
# 2. Inductor partitioning does not play nicely with Autograd cache (below)
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition only supported in torch>=2.9")
# compare output with and without piecewise compilation
llama_config = LlamaConfig(
@@ -350,6 +355,32 @@ def test_toy_llama(
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
)
compile_config_no_compile = CompilationConfig(
level=CompilationLevel.NO_COMPILATION,
cudagraph_mode=CUDAGraphMode.NONE,
backend="eager",
)
compile_config_no_split = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
backend=backend,
cudagraph_capture_sizes=[1, 2],
)
# FIXME(luka/boyuan): the graph from the previous test case
# (no inductor partition) gets cached by AOTAutograd, so the compilation
# with inductor partitioning incorrectly loads an unpartitioned graph and
# never partitions. This looks like a bug in custom inductor partitioning,
# but it does not affect vLLM more generally, since vLLM uses its own cache
# (which takes inductor partitioning into account).
if use_inductor_graph_partition:
compile_config_no_split.inductor_compile_config["force_disable_caches"] = True
compile_config_split = deepcopy(compile_config_no_split)
compile_config_split.splitting_ops = ["silly::attention"]
outputs = []
with compilation_counter.expect(
num_graphs_seen=0,
@@ -358,8 +389,9 @@ def test_toy_llama(
num_backend_compilations=0,
num_cudagraph_captured=0,
):
outputs.append(run_model(llama_config, backend="eager", use_compile=False))
run_model(tractable_config, backend="eager", use_compile=False)
outputs.append(run_model(llama_config, compile_config_no_compile))
run_model(tractable_config, compile_config_no_compile)
if backend == "inductor":
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
@@ -367,35 +399,34 @@ def test_toy_llama(
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
with compilation_counter.expect(
# One graph for the model
num_graphs_seen=1,
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=1,
num_piecewise_capturable_graphs_seen=1,
# num_piecewise_capturable_graphs_seen
num_backend_compilations=1,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=2,
**kwargs,
):
outputs.append(run_model(llama_config, backend=backend, use_compile=True))
run_model(tractable_config, backend=backend, use_compile=True)
outputs.append(run_model(llama_config, compile_config_no_split))
run_model(tractable_config, compile_config_no_split)
if use_inductor_graph_partition:
num_piecewise_fx = 1
num_piecewise_capturable_fx = 1
else:
num_piecewise_fx = 2 * llama_config.num_layers + 1
num_piecewise_capturable_fx = 1 + llama_config.num_layers
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=2 * llama_config.num_layers + 1, # 2 * num_layers + 1
num_piecewise_capturable_graphs_seen=1
+ llama_config.num_layers, # 1 + num_layers
num_backend_compilations=1
+ llama_config.num_layers, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=2
* (
1 + llama_config.num_layers
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_piecewise_graphs_seen=num_piecewise_fx,
num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
num_backend_compilations=num_piecewise_capturable_fx,
# num_cudagraph_sizes * num_partitions
num_cudagraph_captured=2 * (1 + llama_config.num_layers),
):
outputs.append(
run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
)
run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
outputs.append(run_model(llama_config, compile_config_split))
run_model(tractable_config, compile_config_split)
for i in range(1, len(outputs)):
assert torch.allclose(outputs[0], outputs[i])
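For concreteness, the expected counters above work out as follows for num_layers=2; this is illustrative arithmetic with invented variable names, not vLLM code:

num_layers = 2
num_capture_sizes = 2    # cudagraph_capture_sizes=[1, 2]

# Dynamo-level splitting on silly::attention:
num_piecewise_fx = 2 * num_layers + 1          # 5: 2 attention pieces + 3 other pieces
num_piecewise_capturable_fx = 1 + num_layers   # 3: the non-attention pieces

# Inductor-level partitioning keeps the FX graph whole:
num_piecewise_fx_partition = num_piecewise_capturable_fx_partition = 1

# Both modes capture one CUDA graph per partition per capture size:
num_cudagraph_captured = num_capture_sizes * (1 + num_layers)  # 6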

View File

@@ -62,5 +62,4 @@ direct_register_custom_op(
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
tags=(torch._C.Tag.cudagraph_unsafe,),
)

View File

@@ -73,6 +73,7 @@ def test_ignore_torch_compile_decorator():
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=False, # TODO test both?
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -188,6 +189,7 @@ def test_conditional_compile_enable_if():
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=False, # TODO test both
),
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -220,6 +222,7 @@ def test_conditional_compile_enable_if():
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=False, # TODO test both?
),
)
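The decorator tests above currently pin use_inductor_graph_partition=False. As background on what they exercise, here is a simplified, generic sketch of an enable_if-style conditional-compile decorator; the names and mechanics are assumptions for illustration, not vLLM's actual decorators:

import torch
from torch import nn


def compile_if(predicate):
    """Compile a module's forward only when the predicate holds for its config."""

    def wrap(cls):
        orig_init = cls.__init__

        def __init__(self, *args, config=None, **kwargs):
            orig_init(self, *args, **kwargs)
            if config is not None and predicate(config):
                # Bind a compiled forward on this instance; otherwise stay eager.
                self.forward = torch.compile(self.forward, backend="eager")

        cls.__init__ = __init__
        return cls

    return wrap


@compile_if(lambda cfg: cfg.get("compile", False))
class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.fc(x))


eager_model = TinyMLP()                              # predicate not met: stays eager
compiled_model = TinyMLP(config={"compile": True})   # forward wrapped by torch.compile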

View File

@@ -38,10 +38,6 @@ from vllm.utils import GiB_bytes, direct_register_custom_op
logger = init_logger(__name__)
USE_XFORMERS_OPS = None
try:
tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe,)
except AttributeError:
tag_cudagraph_unsafe = () # type: ignore[assignment]
def check_xformers_availability():
@@ -879,7 +875,6 @@ direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
fake_impl=unified_attention_fake,
tags=tag_cudagraph_unsafe,
)
@@ -931,7 +926,6 @@ direct_register_custom_op(
op_func=unified_attention_with_output,
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
tags=tag_cudagraph_unsafe,
)

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import logging
from typing import TYPE_CHECKING
from torch._library.utils import lookup_op
@@ -38,8 +39,16 @@ def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
resolved.append(lookup_op(op_name))
except Exception:
# Skip operators that don't exist (e.g., model-specific ops)
logger.warning(
"Failed to resolve operator for Inductor partition: %s", op_name
# Do not warn for attention ops, warn for others
# (most likely manually specified)
from vllm.config import CompilationConfig
logger.log(
logging.DEBUG
if op_name in CompilationConfig._attention_ops
else logging.WARNING,
"Failed to resolve operator for CUDAGraph partition: %s",
op_name,
)
continue
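The change above downgrades the log level for attention ops that fail to resolve, since those can be legitimately absent, while still warning for other (most likely manually specified) ops. A condensed, self-contained sketch of that resolution loop, with an assumed op set rather than the exact vLLM code:

import logging

from torch._library.utils import lookup_op

logger = logging.getLogger(__name__)

# Illustrative set; the real list is CompilationConfig._attention_ops.
ATTENTION_OPS = {"vllm::unified_attention", "vllm::unified_attention_with_output"}


def resolve_defined_ops_sketch(op_names: list[str]) -> list:
    resolved = []
    for op_name in op_names:
        try:
            resolved.append(lookup_op(op_name))
        except Exception:
            # Attention ops may legitimately be absent, so keep those at DEBUG;
            # anything else was most likely specified manually, so WARN.
            level = logging.DEBUG if op_name in ATTENTION_OPS else logging.WARNING
            logger.log(level, "Failed to resolve operator for CUDAGraph partition: %s", op_name)
    return resolved

With both ops missing, resolve_defined_ops_sketch(["vllm::unified_attention", "mylib::custom_op"]) would log the first at DEBUG and the second at WARNING.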

View File

@@ -201,7 +201,7 @@ class CompilationConfig:
(it sees a part of the graph). The backend can not be custom for compilation
level 3, i.e. the backend must be either eager or inductor. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and
use_inductor_cudagraphs_partition is off. Note that the default options for
use_inductor_graph_partition is off. Note that the default options for
splitting ops are sufficient for piecewise compilation.
"""
custom_ops: list[str] = field(default_factory=list)
@@ -431,6 +431,7 @@ class CompilationConfig:
factors.append(self.custom_ops)
factors.append(self.splitting_ops)
factors.append(self.use_inductor)
factors.append(self.use_inductor_graph_partition)
factors.append(self.inductor_compile_config)
factors.append(self.inductor_passes)
factors.append(self.pass_config.uuid())
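Adding use_inductor_graph_partition to the hash factors matters because any option that changes the compiled artifact must be part of the compile-cache key; a toy sketch of that idea follows (vLLM's actual hashing differs in detail):

import hashlib
import json


def compile_cache_key(factors: list) -> str:
    # Every behavior-changing option must be folded into the key; otherwise two
    # different configs map to the same cache entry and silently share artifacts.
    payload = json.dumps(factors, sort_keys=True, default=str).encode()
    return hashlib.sha256(payload).hexdigest()


# [splitting_ops, use_inductor_graph_partition] as two of the factors:
without_partition = [["silly::attention"], False]
with_partition = [["silly::attention"], True]

assert compile_cache_key(without_partition) != compile_cache_key(with_partition)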