[Core] Allow full cudagraph with separate attention routines, orthogonal to compilation; add support for FA2 and FlashInfer (#20059)

Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Author: fhl2000
Date: 2025-08-15 22:01:39 +08:00
Committed by: GitHub
Parent: a0632a3e03
Commit: 74f441f4b5
34 changed files with 1839 additions and 597 deletions

View File

@ -3,7 +3,8 @@
import contextlib
import os
import weakref
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Optional
import pytest
@ -32,27 +33,130 @@ def temporary_environ(env_vars):
os.environ[k] = v
@dataclass
class BackendConfig:
name: str
env_vars: dict
comp_config: dict
specific_gpu_arch: Optional[tuple] = None
# Define all backend configurations of full cudagraph to be tested
backend_configs = {
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
comp_config={
"cudagraph_mode": "FULL",
},
specific_gpu_arch=(9, 0)),
# FlashMLA on Hopper
"FlashMLA":
BackendConfig(name="FlashMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# Cutlass MLA on Blackwell
"CutlassMLA":
BackendConfig(
name="CutlassMLA",
env_vars={
"VLLM_USE_V1": "1",
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
"FORCE_NUM_KV_SPLITS":
"1", # TODO: remove this when hang issue is fixed
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
"cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
},
specific_gpu_arch=(10, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
comp_config={
"cudagraph_mode": "FULL",
}),
# Triton Attention
"TritonAttn":
BackendConfig(name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
comp_config={
"cudagraph_mode": "FULL",
}),
# FlashInfer
"FlashInfer":
BackendConfig(name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
}
test_params_full_cudagraph = []
# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "CutlassMLA"]
for mla_backend in MLA_backends:
test_params_full_cudagraph.append(
pytest.param(
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))
# Qwen/Qwen2-1.5B-Instruct with other backends
other_backend_configs = [
backend_configs[c] for c in backend_configs if c not in MLA_backends
]
for backend_config in other_backend_configs:
test_params_full_cudagraph.append(
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
@pytest.fixture(scope="class")
def llm_pair(request):
model = request.param
model, backend_config = request.param
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}):
# Dynamically skip test if GPU capability is not met
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
!= current_platform.get_device_capability():
if backend_config.specific_gpu_arch == (9, 0):
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
elif backend_config.specific_gpu_arch == (10, 0):
pytest.skip("Only Blackwell GPUs support Cutlass MLA")
env_vars = {
"VLLM_USE_V1": "1",
# Force native sampler to avoid potential nondeterminism in FlashInfer
# when per-request generators are not used in V1.
"VLLM_USE_FLASHINFER_SAMPLER": "0",
**backend_config.env_vars,
}
with temporary_environ(env_vars):
full = LLM(
model=model,
gpu_memory_utilization=0.45,
gpu_memory_utilization=0.43,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(full_cuda_graph=True),
max_num_seqs=128,
compilation_config=\
CompilationConfig(**backend_config.comp_config),
generation_config="vllm",
seed=42,
)
piecewise = LLM(
model=model,
gpu_memory_utilization=0.45,
gpu_memory_utilization=0.43,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(),
max_num_seqs=128,
compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
generation_config="vllm",
seed=42,
)
# PyTest caches the fixture values so we use weakref.proxy to enable GC
@ -66,90 +170,7 @@ def llm_pair(request):
)
@pytest.fixture(scope="class")
def cutlass_mla_llm_pair(request):
model = request.param
# force V1 engine and Cutlass MLA backend
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
"FORCE_NUM_KV_SPLITS":
"1", # TODO: remove this when hang issue is fixed
}):
full = LLM(
model=model,
gpu_memory_utilization=0.45,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(
full_cuda_graph=True,
cudagraph_capture_sizes=[16, 32, 64, 128, 256, 512],
),
)
piecewise = LLM(
model=model,
gpu_memory_utilization=0.45,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(),
)
yield weakref.proxy(full), weakref.proxy(piecewise)
del full
del piecewise
wait_for_gpu_memory_to_clear(
devices=[0],
threshold_ratio=0.1,
)
@pytest.mark.parametrize(
"cutlass_mla_llm_pair",
[
# use an MLA model
"deepseek-ai/DeepSeek-V2-Lite",
],
indirect=True)
@pytest.mark.skipif(current_platform.get_device_capability() != (10, 0),
reason="Only Blackwell GPUs support Cutlass MLA")
class TestFullCUDAGraphCutlassMLA:
"""
Validate full CUDA Graph with Cutlass MLA (decode-only capture).
"""
@pytest.mark.parametrize(("batch_size", "max_tokens"), [
(8, 8),
])
def test_full_cudagraph_sm100_cutlass_mla(
self, batch_size, max_tokens, cutlass_mla_llm_pair: tuple[LLM,
LLM]):
piecewise_llm, full_cudagraph_llm = cutlass_mla_llm_pair
prompts = ["Hello, my name is"] * batch_size
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=0.95)
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
for piecewise_res, full_res in zip(piecewise_responses,
full_responses):
assert piecewise_res.outputs[0].text == full_res.outputs[0].text
@pytest.mark.parametrize(
"llm_pair",
[
# Model names for the llm_pair fixture
"deepseek-ai/DeepSeek-V2-Lite",
"Qwen/Qwen2-1.5B-Instruct"
],
indirect=True)
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
reason="Only Hopper GPUs support FA3 and FlashMLA")
@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
class TestFullCUDAGraph:
"""
Use a class such that an llm pair is constructed once for all
@ -178,12 +199,14 @@ class TestFullCUDAGraph:
full cudagraph compilation works for padded cases too.
"""
piecewise_llm, full_cudagraph_llm = llm_pair
full_cudagraph_llm, piecewise_llm = llm_pair
prompts = ["Hello, my name is"] * batch_size
prompts = ["the quick brown fox"] * batch_size
# Use purely greedy decoding to avoid top-p truncation sensitivity
# that can amplify tiny numeric differences across runtimes.
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=0.95)
top_p=1.0)
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
@ -191,42 +214,16 @@ class TestFullCUDAGraph:
# Check that all responses are the same
for piecewise_res, full_res in zip(piecewise_responses,
full_responses):
assert piecewise_res.outputs[0].text == full_res.outputs[0].text
@pytest.mark.parametrize(
"model, supported",
[
("Qwen/Qwen2-1.5B-Instruct", True),
# MLA does not support capturing CUDA Graphs with size > max_num_seqs
("deepseek-ai/DeepSeek-V2-Lite", False),
])
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
reason="Only Hopper GPUs support FA3 and FlashMLA")
def test_lower_max_num_seqs(model, supported):
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}), ExitStack() as stack:
if not supported:
stack.enter_context(pytest.raises(RuntimeError))
llm = LLM(model=model,
max_num_seqs=256,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(
full_cuda_graph=True,
cudagraph_capture_sizes=[64, 256, 512]))
llm.generate(["Hello, my name is"] * 10)
assert piecewise_res.outputs[0].text.lower() == \
full_res.outputs[0].text.lower()
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION":
"2" #FA2 not supported with full_cuda_graph
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
# Flex_Attention is not supported with full cuda graph
}), pytest.raises(RuntimeError):
LLM(model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(full_cuda_graph=True))
compilation_config=CompilationConfig(cudagraph_mode="FULL"))
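
For reference, the user-facing knob exercised by this test file is the new `cudagraph_mode` field on `CompilationConfig`. A minimal sketch, reusing the model and arguments from the fixture above (it assumes a CUDA GPU and an attention backend that supports the requested mode; FA2 and FlashInfer may internally fall back, as noted in the combo tests later in this commit):

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
          max_model_len=1024,
          max_num_seqs=128,
          gpu_memory_utilization=0.43,
          compilation_config=CompilationConfig(
              cudagraph_mode="FULL_AND_PIECEWISE"))
outputs = llm.generate(["Hello, my name is"] * 8,
                       SamplingParams(temperature=0.0, max_tokens=8))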

View File

@ -11,10 +11,10 @@ from torch.library import Library
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import set_forward_context
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op
global_counter = 0
@ -101,16 +101,33 @@ def test_simple_piecewise_compile(use_inductor):
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
), set_forward_context({}, vllm_config=vllm_config):
), set_forward_context(None,
vllm_config=vllm_config): # background context
# warm up with background context
model(inputs)
model(torch.randn(2).cuda())
model(torch.randn(1).cuda())
# capturing/replaying should be done under the context of cudagraph dispatching
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )):
model(torch.randn(2).cuda())
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=1, )):
model(torch.randn(1).cuda())
input = torch.zeros(2).cuda()
global global_counter
global_counter = 0
output = model(input)
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )):
output = model(input)
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))

View File

@ -18,9 +18,9 @@ from torch.library import Library
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.forward_context import set_forward_context
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op
# create a library to hold the custom op
@ -276,9 +276,11 @@ def run_model(llama_config,
)
if split_attn:
compilation_config.splitting_ops = ["silly.attention"]
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, )
cudagraph_runtime_mode = CUDAGraphMode.NONE
vllm_config = VllmConfig(compilation_config=compilation_config,
additional_config=llama_config)
@ -287,17 +289,37 @@ def run_model(llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda()
with set_forward_context({}, vllm_config=vllm_config):
with set_forward_context({},
vllm_config=vllm_config): # background context
B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
positions = torch.arange(B).cuda()
# warmup for the model with cudagraph_mode NONE
model(input_ids, positions)
model(input_ids[:2], positions[:2])
model(input_ids[:1], positions[:1])
# simulate cudagraph capturing
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
model(input_ids[:2], positions[:2])
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1, )):
model(input_ids[:1], positions[:1])
input_ids[:2].zero_()
output = model(input_ids[:2], positions[:2])
# simulate cudagraph replay
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
output = model(input_ids[:2], positions[:2])
output = output.cpu()

View File

@ -0,0 +1,406 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
from tests.utils import create_new_process_for_each_test
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
ParallelConfig, SchedulerConfig, VllmConfig)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# Helper MLP for testing
class SimpleMLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 10)
self.fc2 = nn.Linear(10, 10)
def forward(self, x):
return self.fc2(self.fc1(x))
def _create_vllm_config(compilation_config: CompilationConfig,
max_num_seqs: int = 8) -> MagicMock:
mock_config = MagicMock(spec=VllmConfig)
mock_config.compilation_config = compilation_config
mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
mock_config.parallel_config = ParallelConfig()
# Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.level == CompilationLevel.PIECEWISE:
compilation_config.set_splitting_ops_for_v1()
return mock_config
class TestCudagraphDispatcher:
@pytest.mark.parametrize(
"params",
[
# Test case 0: Full CG for mixed batches, no separate routine
{
"case_id": 0,
"cudagraph_mode": "FULL",
"compilation_level": CompilationLevel.NO_COMPILATION,
},
# Test case 1: Full CG for uniform batches, piecewise for mixed
{
"case_id": 1,
"cudagraph_mode": "FULL_AND_PIECEWISE",
"compilation_level": CompilationLevel.PIECEWISE,
},
# Test case 2: Full CG for uniform batches, no CG for mixed
{
"case_id": 2,
"cudagraph_mode": "FULL_DECODE_ONLY",
"compilation_level": CompilationLevel.NO_COMPILATION,
},
# Test case 3: Piecewise for all
{
"case_id": 3,
"cudagraph_mode": "PIECEWISE",
"compilation_level": CompilationLevel.PIECEWISE,
},
])
def test_dispatcher(self, params):
# Setup dispatcher
comp_config = CompilationConfig(
cudagraph_mode=params["cudagraph_mode"],
level=params["compilation_level"],
cudagraph_capture_sizes=[1, 8])
config = _create_vllm_config(comp_config, max_num_seqs=8)
dispatcher = CudagraphDispatcher(config)
dispatcher.initialize_cudagraph_keys(
cudagraph_mode=comp_config.cudagraph_mode,
uniform_decode_query_len=1)
# Verify the key is initialized correctly
if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
# Test dispatch logic
# 1. non-uniform batch, size in cudagraph size list
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc_full_exact)
if params["cudagraph_mode"] == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_full_exact
elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact
else:
assert rt_mode == CUDAGraphMode.NONE
# 2. uniform decode batch, size in cudagraph size list
desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
if params["cudagraph_mode"] == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact.non_uniform
elif params["cudagraph_mode"] in [
"FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"
]:
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact
elif params["cudagraph_mode"] == "PIECEWISE":
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_uniform_exact.non_uniform
else:
assert rt_mode == CUDAGraphMode.NONE
# 3. No key match
desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc_no_match)
assert rt_mode == CUDAGraphMode.NONE
assert key is None
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
def setup_method(self):
self.vllm_config = _create_vllm_config(CompilationConfig())
self.model = SimpleMLP().to("cuda")
self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
self.input_tensor = torch.randn(1, 10, device="cuda")
@create_new_process_for_each_test("spawn")
def test_capture_and_replay(self):
wrapper = CUDAGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
batch_descriptor = BatchDescriptor(num_tokens=10)
# 0. global warmup
with set_forward_context(attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=None):
wrapper(self.input_tensor)
# 1. Capture
with set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
batch_descriptor=batch_descriptor),\
patch("torch.cuda.graph",
wraps=torch.cuda.graph) as mock_cuda_graph:
output1 = wrapper(self.input_tensor)
# capturing phase should generate a zero output
assert torch.allclose(output1, torch.zeros_like(output1))
mock_cuda_graph.assert_called_once()
assert batch_descriptor in wrapper.concrete_cudagraph_entries
entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
assert entry.cudagraph is not None
# 2. Replay
with set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
batch_descriptor=batch_descriptor),\
patch.object(entry.cudagraph, 'replay',
wraps=entry.cudagraph.replay) as mock_replay:
output2 = wrapper(self.input_tensor)
mock_replay.assert_called_once()
# Compare with eager output
eager_output = self.model(self.input_tensor)
torch.testing.assert_close(eager_output, output2)
@create_new_process_for_each_test("spawn")
def test_bypass_on_mode_mismatch(self):
wrapper = CUDAGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
batch_descriptor = BatchDescriptor(num_tokens=10)
with set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=batch_descriptor), \
patch('torch.cuda.graph',
wraps=torch.cuda.graph) as mock_cuda_graph, \
patch.object(self.model, 'forward',
wraps=self.model.forward) as mock_forward:
wrapper(self.input_tensor)
mock_cuda_graph.assert_not_called()
mock_forward.assert_called_once()
assert not wrapper.concrete_cudagraph_entries
@create_new_process_for_each_test("spawn")
def test_bypass_on_mode_none(self):
wrapper = CUDAGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
batch_descriptor = BatchDescriptor(num_tokens=10)
with set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=batch_descriptor), \
patch('torch.cuda.graph',
wraps=torch.cuda.graph) as mock_cuda_graph:
wrapper(self.input_tensor)
mock_cuda_graph.assert_not_called()
assert not wrapper.concrete_cudagraph_entries
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCudagraphIntegration:
def setup_method(self):
# only FULL mode for non-uniform batches
self.comp_config = CompilationConfig(level=CompilationLevel.PIECEWISE,
cudagraph_mode="FULL",
cudagraph_capture_sizes=[10, 20])
self.vllm_config = _create_vllm_config(self.comp_config)
self.dispatcher = CudagraphDispatcher(self.vllm_config)
self.dispatcher.initialize_cudagraph_keys(
self.comp_config.cudagraph_mode, uniform_decode_query_len=1)
def _run_and_monitor_call(self, wrapper, input_tensor, runtime_mode,
batch_descriptor):
"""Helper to run a single call and monitor the action."""
with patch('torch.cuda.graph',
wraps=torch.cuda.graph) as mock_graph_context, \
patch.object(wrapper, 'runnable',
wraps=wrapper.runnable) as mock_runnable:
entry = wrapper.concrete_cudagraph_entries.get(
batch_descriptor, None)
context = set_forward_context(attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=runtime_mode,
batch_descriptor=batch_descriptor)
mock_replay = MagicMock()
if entry and entry.cudagraph:
with context, \
patch.object(entry.cudagraph, 'replay',
new_callable=MagicMock) as mock_replay:
wrapper(input_tensor)
else:
with context:
wrapper(input_tensor)
if mock_graph_context.called:
# note that torch.cuda.graph is globally mocked, so a capture is detected
# whether it is triggered by the inner or the outer wrapper
return "capture_global"
if mock_replay.called:
# only for outer wrapper
return "replay"
if mock_runnable.call_count > 0:
# only for outer wrapper
return "bypass"
return "unknown"
@create_new_process_for_each_test("spawn")
def test_capture_replay_bypass_logic(self):
model = SimpleMLP().to("cuda")
full_wrapper = CUDAGraphWrapper(model, self.vllm_config,
CUDAGraphMode.FULL)
max_bs = 16
persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
input_1 = persistent_input_buffer[:1]
input_2 = persistent_input_buffer[:2]
input_3 = persistent_input_buffer[:3]
desc_1 = BatchDescriptor(num_tokens=1)
desc_2 = BatchDescriptor(num_tokens=2)
desc_3_unseen = BatchDescriptor(num_tokens=3)
# 0. global warmup
with set_forward_context(attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=None):
full_wrapper(input_1)
rt_mode, key = self.dispatcher.dispatch(desc_1)
# 1. Capture first shape
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode,
key)
assert action == "capture_global"
# 2. Replay first shape
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode,
key)
assert action == "replay"
rt_mode, key = self.dispatcher.dispatch(desc_2)
# 3. Capture second shape
action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode,
key)
assert action == "capture_global"
# 4. Replay second shape
action = self._run_and_monitor_call(full_wrapper, input_2,
CUDAGraphMode.FULL, desc_2)
assert action == "replay"
# 5. Bypass if no key match
rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
assert rt_mode == CUDAGraphMode.NONE
action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode,
key)
assert action == "bypass"
# capturing an unseen shape is not allowed after capturing is disabled
set_cudagraph_capturing_enabled(False)
with pytest.raises(RuntimeError):
self._run_and_monitor_call(full_wrapper, input_3,
CUDAGraphMode.FULL, desc_3_unseen)
set_cudagraph_capturing_enabled(True)
@create_new_process_for_each_test("spawn")
def test_nested_wrappers(self):
"""Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
model = SimpleMLP().to("cuda")
full_wrapper = CUDAGraphWrapper(model, self.vllm_config,
CUDAGraphMode.FULL)
input_1 = torch.randn(1, 10, device="cuda")
# Setup: Inner model is wrapped with PIECEWISE, outer with FULL
inner_model = SimpleMLP().to("cuda")
piecewise_wrapper = CUDAGraphWrapper(inner_model, self.vllm_config,
CUDAGraphMode.PIECEWISE)
inner_model.forward = MagicMock(wraps=inner_model.forward)
outer_model = SimpleMLP().to("cuda")
# When outer model is called, it calls the piecewise_wrapper
outer_model.forward = MagicMock(wraps=outer_model.forward,
side_effect=piecewise_wrapper)
full_wrapper = CUDAGraphWrapper(outer_model, self.vllm_config,
CUDAGraphMode.FULL)
desc_1 = BatchDescriptor(num_tokens=1)
# 0. global warmup
with set_forward_context(attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=None):
full_wrapper(input_1)
# --- Test runtime mode FULL---
# Run with FULL mode context. Expect outer wrapper to capture.
# The inner mock should be called once inside the graph capture.
outer_model.forward.reset_mock()
inner_model.forward.reset_mock()
action = self._run_and_monitor_call(full_wrapper, input_1,
CUDAGraphMode.FULL, desc_1)
assert action == "capture_global"
assert outer_model.forward.call_count == 1
assert inner_model.forward.call_count == 1
# Run again. Expect outer wrapper to replay.
# The outer model should NOT be called because the whole graph
# is replayed.
action = self._run_and_monitor_call(full_wrapper, input_1,
CUDAGraphMode.FULL, desc_1)
assert action == "replay"
assert outer_model.forward.call_count == 1 # No new call
assert inner_model.forward.call_count == 1
# --- Test runtime mode PIECEWISE ---
outer_model.forward.reset_mock()
inner_model.forward.reset_mock()
# Run with PIECEWISE mode context.
# Expect outer wrapper to bypass and call inner wrapper.
# Inner wrapper should capture.
action = self._run_and_monitor_call(full_wrapper, input_1,
CUDAGraphMode.PIECEWISE, desc_1)
assert action == "capture_global"
assert outer_model.forward.call_count == 1
assert inner_model.forward.call_count == 1
# Run again with PIECEWISE.
# Outer bypasses, inner replays.
action = self._run_and_monitor_call(full_wrapper, input_1,
CUDAGraphMode.PIECEWISE, desc_1)
assert action == "bypass"
assert outer_model.forward.call_count == 2
assert inner_model.forward.call_count == 1
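
Condensing the dispatcher test above into a standalone sketch (it reuses the test's MagicMock-based config shim instead of building a real VllmConfig, and assumes the constructor and method signatures exactly as shown in this commit):

from unittest.mock import MagicMock

from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                         ParallelConfig, SchedulerConfig, VllmConfig)
from vllm.forward_context import BatchDescriptor
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher

# Mirror test case 1 above: FULL graphs for uniform decode batches,
# PIECEWISE graphs for mixed prefill/decode batches.
comp_config = CompilationConfig(cudagraph_mode="FULL_AND_PIECEWISE",
                                level=CompilationLevel.PIECEWISE,
                                cudagraph_capture_sizes=[1, 8])
comp_config.set_splitting_ops_for_v1()  # mimic VllmConfig.__post_init__()

vllm_config = MagicMock(spec=VllmConfig)
vllm_config.compilation_config = comp_config
vllm_config.scheduler_config = SchedulerConfig(max_num_seqs=8)
vllm_config.parallel_config = ParallelConfig()

dispatcher = CudagraphDispatcher(vllm_config)
dispatcher.initialize_cudagraph_keys(cudagraph_mode=comp_config.cudagraph_mode,
                                     uniform_decode_query_len=1)

# Mixed batch whose size is in the capture list -> PIECEWISE graph.
rt_mode, key = dispatcher.dispatch(
    BatchDescriptor(num_tokens=8, uniform_decode=False))
assert rt_mode == CUDAGraphMode.PIECEWISE

# Uniform decode batch of the same size -> FULL graph.
rt_mode, key = dispatcher.dispatch(
    BatchDescriptor(num_tokens=8, uniform_decode=True))
assert rt_mode == CUDAGraphMode.FULL

# A size that was never captured -> no cudagraph at all.
rt_mode, key = dispatcher.dispatch(
    BatchDescriptor(num_tokens=15, uniform_decode=False))
assert rt_mode == CUDAGraphMode.NONE and key is None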

View File

@ -0,0 +1,187 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Optional
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from vllm import LLM
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
@contextlib.contextmanager
def temporary_environ(env_vars):
"""
Temporarily set environment variables and restore them afterward.
We have to do this instead of using monkeypatch because monkeypatch
doesn't work with "module"-scoped fixtures.
"""
original_env = {k: os.environ.get(k) for k in env_vars}
try:
os.environ.update(env_vars)
yield
finally:
for k, v in original_env.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v
@dataclass
class BackendConfig:
name: str
env_vars: dict
comp_config: dict
specific_gpu_arch: Optional[tuple] = None
# Define all backend configurations of full cudagraph to be tested
backend_configs = {
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
comp_config={
"cudagraph_mode": "FULL",
},
specific_gpu_arch=(9, 0)),
# FlashMLA on Hopper
"FlashMLA":
BackendConfig(name="FlashMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
# Triton Attention
"TritonAttn":
BackendConfig(name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
# FlashInfer
"FlashInfer":
BackendConfig(name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
}
# test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported)
combo_cases_1 = [
("FA3", "FULL", True),
("FA3", "FULL_AND_PIECEWISE", True),
("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
("FA2", "FULL_AND_PIECEWISE", True),
("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
("FlashInfer", "FULL_AND_PIECEWISE", True),
]
@pytest.mark.parametrize("combo_case", combo_cases_1)
def test_backend_and_cudagraph_mode_combo(combo_case):
backend_name, cudagraph_mode, supported = combo_case
if backend_name == "FlashInfer":
try:
import flashinfer # noqa: F401
except ImportError:
pytest.skip("FlashInfer is not installed")
backend_config = backend_configs[backend_name]
# Dynamically skip test if GPU capability is not met
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
!= current_platform.get_device_capability():
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars}
with temporary_environ(env_vars), ExitStack() as stack:
if not supported:
stack.enter_context(pytest.raises(Exception))
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
max_num_seqs=256,
trust_remote_code=True,
gpu_memory_utilization=0.45,
max_model_len=1024,
compilation_config=CompilationConfig(
level=3, cudagraph_mode=cudagraph_mode))
llm.generate(["Hello, my name is"] * 10)
try:
llm = weakref.proxy(llm)
del llm
except UnboundLocalError:
pass
wait_for_gpu_memory_to_clear(
devices=[0],
threshold_ratio=0.1,
)
# test cudagraph_mode with different compilation levels.
# (backend_name, cudagraph_mode, compilation_level, supported)
combo_cases_2 = [
("FA2", "FULL", 0, True), # no compilation + full cudagraph
("FA2", "FULL", 3, True), # piecewise compilation + full cudagraph
("FA2", "PIECEWISE", 0, False), # no compilation + piecewise cudagraph
("FA2", "PIECEWISE", 3,
True), # piecewise compilation + piecewise cudagraph
("FA2", "FULL_AND_PIECEWISE", 0,
False), # piecewise cudagraph not supported without piecewise compilation
("FA2", "FULL_AND_PIECEWISE", 3, True),
("FA2", "FULL_DECODE_ONLY", 0, True),
("FA2", "FULL_DECODE_ONLY", 3, True),
("FA2", "NONE", 0, True), # no compilation + no cudagraph
("FA2", "NONE", 3, True), # piecewise compilation + no cudagraph
]
@pytest.mark.parametrize("combo_case", combo_cases_2)
def test_cudagraph_compilation_combo(combo_case):
backend_name, cudagraph_mode, compilation_level, supported\
= combo_case
env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars}
with temporary_environ(env_vars), ExitStack() as stack:
if not supported:
stack.enter_context(pytest.raises(Exception))
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
max_num_seqs=256,
trust_remote_code=True,
gpu_memory_utilization=0.45,
max_model_len=1024,
compilation_config=CompilationConfig(
level=compilation_level, cudagraph_mode=cudagraph_mode))
llm.generate(["Hello, my name is"] * 10)
try:
llm = weakref.proxy(llm)
del llm
except UnboundLocalError:
pass
finally:
wait_for_gpu_memory_to_clear(
devices=[0],
threshold_ratio=0.1,
)
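
Spelling out one unsupported row of combo_cases_2 as a standalone sketch (same model and arguments as the test; per the config check added in vllm/config.py later in this commit, requesting piecewise cudagraphs without piecewise compilation is expected to raise while building the engine):

import pytest

from vllm import LLM
from vllm.config import CompilationConfig

# cudagraph_mode="PIECEWISE" with level=0 (no compilation) is one of the
# unsupported rows above and should fail during LLM construction.
with pytest.raises(Exception):
    LLM(model="Qwen/Qwen2-1.5B-Instruct",
        max_model_len=1024,
        gpu_memory_utilization=0.45,
        compilation_config=CompilationConfig(level=0,
                                             cudagraph_mode="PIECEWISE"))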

View File

@ -15,7 +15,7 @@ import torch.fx as fx
from torch._dispatch.python import enable_python_dispatcher
import vllm.envs as envs
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
@ -277,9 +277,6 @@ def split_graph(graph: fx.GraphModule,
return split_gm, outputs
# we share the global graph pool among all the backends
global_graph_pool = None
compilation_start_time = 0.0
@ -339,14 +336,37 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None)
# Lazy import here to avoid circular import
from .cuda_graph import CUDAGraphOptions
from .cuda_piecewise_backend import PiecewiseBackend
piecewise_backend = resolve_obj_by_qualname(
current_platform.get_piecewise_backend_cls())
self.module.__dict__[target] = piecewise_backend(
submod, self.vllm_config, self.graph_pool, index,
piecewise_backend = PiecewiseBackend(
submod, self.vllm_config, index,
len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_dynamic_shape, self.vllm_backend)
if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
# class) as platform dependent.
static_graph_wrapper_class = resolve_obj_by_qualname(
current_platform.get_static_graph_wrapper_cls())
# Always assign PIECEWISE runtime mode to the
# CUDAGraphWrapper for piecewise_backend, to distinguish
# it from the FULL cudagraph runtime mode, regardless of
# whether it wraps a full or a piecewise fx graph.
self.module.__dict__[target] = static_graph_wrapper_class(
runnable=piecewise_backend,
vllm_config=self.vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
graph_pool=self.graph_pool,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=piecewise_backend.is_first_graph,
gc_disable=not piecewise_backend.is_first_graph,
weak_ref_output=piecewise_backend.is_last_graph))
else:
self.module.__dict__[target] = piecewise_backend
compilation_counter.num_piecewise_capturable_graphs_seen += 1
return output
@ -413,9 +433,7 @@ class VllmBackend:
# them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag
global global_graph_pool
if global_graph_pool is None:
global_graph_pool = current_platform.graph_pool_handle()
global_graph_pool = current_platform.get_global_graph_pool()
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
@ -585,7 +603,7 @@ class VllmBackend:
self._called = True
if not self.compilation_config.use_cudagraph or \
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \
not self.compilation_config.cudagraph_copy_inputs:
return self.split_gm

View File

@ -1,72 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Protocol
import torch.fx as fx
from vllm.compilation.backends import VllmBackend
from vllm.config import VllmConfig
class AbstractPiecewiseBackend(Protocol):
"""
PiecewiseBackend interface that allows platforms to extend
piecewise static graph.
"""
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, piecewise_compile_index: int,
total_piecewise_compiles: int, sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend, **kwargs):
"""
Initializes the PiecewiseBackend class with compilation and
execution-related configurations.
This class handles piecewise compilation, graph capturing,
and dispatching for specific input shapes.
Args:
graph (fx.GraphModule): The graph represented in fx.
vllm_config (VllmConfig): Global configuration for vLLM.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
piecewise_compile_index (int):
Index of the current piecewise subgraph.
total_piecewise_compiles (int):
Total number of piecewise-compiled graphs.
sym_shape_indices (list[int]):
Indices of symbolic shape.
compiled_graph_for_general_shape (Callable):
Callable that executes the graph compiled for general shapes.
vllm_backend (VllmBackend):
Backend compiler that manages compilation and graph runtime
for vLLM.
Keyword Args:
kwargs: Additional keyword arguments reserved for future
extensions or custom platforms.
"""
raise NotImplementedError
def __call__(self, *args) -> Any:
"""Executes the compiled graph for given input args.
If this is the first invocation, executes the general compiled graph
and initiates the compilation process tracking. For subsequent calls,
dynamically dispatches execution to either a compiled graph or a static
graph based on the input shape.
Args:
*args: Variable length input arguments to be passed into the
graph. The symbolic shape is expected to be in position
`sym_shape_indices[0]`.
Returns:
Any: Output of the executed graph. This can be from the general
compiled graph, a specialized compiled version for the given shape,
or a replayed static graph.
"""
raise NotImplementedError

View File

@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Protocol
from vllm.config import CUDAGraphMode, VllmConfig
class AbstractStaticGraphWrapper(Protocol):
"""
StaticGraphWrapper interface that allows platforms to wrap a callable
to be captured as a static graph.
"""
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs):
"""
Initializes the StaticGraphWrapper class with graph capturing and
execution-related configurations.
Args:
runnable (Callable): The callable to be wrapped and captured.
vllm_config (VllmConfig): Global configuration for vLLM.
runtime_mode (CUDAGraphMode): The style of the static
graph runtime. See CUDAGraphMode in vllm/config.py.
Note that only the subset `NONE`, `PIECEWISE` and `FULL` is used
as concrete runtime modes for cudagraph dispatching.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
Keyword Args:
kwargs: Additional keyword arguments for platform-specific
configurations.
"""
raise NotImplementedError
def __call__(self, *args, **kwargs) -> Any:
"""
Executes the wrapped callable.
If the current runtime mode in the ForwardContext matches the runtime
mode of this instance, it replays the CUDAGraph or captures it using
the callable if it hasn't been captured yet. Otherwise, it calls the
original callable directly.
Args:
*args: Variable length input arguments to be passed into the
callable.
**kwargs: Keyword arguments to be passed into the callable.
Returns:
Any: Output of the executed callable.
"""
raise NotImplementedError
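
To make the protocol concrete, here is a hypothetical minimal implementation that satisfies the interface; the class and its name are invented purely for illustration (an eager fallback), while the real CUDA implementation is the CUDAGraphWrapper added in the next file:

from typing import Any, Callable

from vllm.config import CUDAGraphMode, VllmConfig


class EagerGraphWrapper:
    """Hypothetical AbstractStaticGraphWrapper implementation that never
    captures a graph and simply calls the wrapped runnable eagerly."""

    def __init__(self, runnable: Callable, vllm_config: VllmConfig,
                 runtime_mode: CUDAGraphMode, graph_pool: Any = None,
                 **kwargs):
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.runtime_mode = runtime_mode
        self.graph_pool = graph_pool

    def __call__(self, *args, **kwargs) -> Any:
        # A real implementation would consult get_forward_context() and
        # capture/replay when the runtime mode and batch descriptor match;
        # this stand-in just runs the callable.
        return self.runnable(*args, **kwargs)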

View File

@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Optional
from unittest.mock import patch
import torch
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import weak_ref_tensors
logger = init_logger(__name__)
@dataclasses.dataclass
class CUDAGraphEntry:
batch_descriptor: BatchDescriptor
cudagraph: Optional[torch.cuda.CUDAGraph] = None
output: Optional[Any] = None
# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None
@dataclasses.dataclass
class CUDAGraphOptions:
debug_log_enable: bool = True
gc_disable: bool = False
weak_ref_output: bool = True
class CUDAGraphWrapper:
"""Wraps a runnable to add CUDA graph capturing and replaying ability. And
provide attribute access to the underlying `runnable` via `__getattr__`.
The workflow of this wrapper in the cudagraph dispatching is as follows:
1. At initialization, a runtime mode is assigned to the wrapper (FULL or
PIECEWISE).
2. At runtime, the wrapper receives a runtime_mode and a
batch_descriptor(key) from the forward context and blindly trust them
for cudagraph dispatching.
3. If runtime_mode is NONE or runtime_mode does not match the mode of the
wrapper, just call the runnable directly.
4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
the wrapper will perform cudagraph capture(if key does not exist, create
a new entry and cache it) or replay (if key exists in the cache).
Note: CUDAGraphWrapper does not store persistent buffers or copy any
runtime inputs into that buffers for replay. We assume implementing them
is done outside of the wrapper. That is because we do not make any
assumption on the dynamic shape (batch size) of the runtime inputs, as a
trade-off for staying orthogonal to compilation logic. Nevertheless,
tracing and checking the input addresses to be consistent during replay is
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
"""
def __init__(self,
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
graph_pool: Any = None,
cudagraph_options: Optional[CUDAGraphOptions] = None):
self.runnable = runnable
self.vllm_config = vllm_config
self.graph_pool = graph_pool
self.runtime_mode = runtime_mode
self.compilation_config = vllm_config.compilation_config
self.first_run_finished = False
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
# assert runtime_mode is not NONE (no cudagraph); otherwise, we don't
# need to initialize a CUDAGraphWrapper.
assert self.runtime_mode != CUDAGraphMode.NONE
if self.graph_pool is None:
self.graph_pool = current_platform.get_global_graph_pool()
if cudagraph_options is None:
cudagraph_options = CUDAGraphOptions()
self.cudagraph_options = cudagraph_options
# the entries for different batch descriptors that we need to capture
# cudagraphs for.
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry]\
= {}
def __getattr__(self, key: str):
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
return getattr(self.runnable, key)
raise AttributeError(f"Attribute {key} not exists in the runnable of "
f"cudagraph wrapper: {self.runnable}")
def unwrap(self) -> Callable:
# in case we need to access the original runnable.
return self.runnable
def __call__(self, *args, **kwargs):
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
if cudagraph_runtime_mode == CUDAGraphMode.NONE or \
cudagraph_runtime_mode != self.runtime_mode:
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
# running without cudagraphs.
# We do not trigger capture/replay if the runtime mode does not
# match. This enables properly dispatching to the correct
# CUDAGraphWrapper when nesting multiple instances with different
# runtime modes.
return self.runnable(*args, **kwargs)
if batch_descriptor not in self.concrete_cudagraph_entries:
# create a new entry for this batch descriptor
self.concrete_cudagraph_entries[batch_descriptor] = \
CUDAGraphEntry(batch_descriptor=batch_descriptor)
entry = self.concrete_cudagraph_entries[batch_descriptor]
if entry.cudagraph is None:
if self.cudagraph_options.debug_log_enable:
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every
# shape. E.g. we only log it for the first subgraph in
# piecewise mode.
logger.debug("Capturing a cudagraph on (%s,%s)",
self.runtime_mode.name, entry.batch_descriptor)
# validate that cudagraph capturing is legal at this point.
validate_cudagraph_capturing_enabled()
input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()
with ExitStack() as stack:
if self.cudagraph_options.gc_disable:
# during every model forward for piecewise cudagraph
# mode, we will capture many pieces of cudagraphs
# (roughly one per layer). running gc again and again
# across layers will make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(
patch("torch.cuda.empty_cache", lambda: None))
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = self.runnable(*args, **kwargs)
if self.cudagraph_options.weak_ref_output:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph in piecewise cudagraph mode, because
# the output of the last graph will not be used by
# any other cuda graph.
output = weak_ref_tensors(output)
# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph
compilation_counter.num_cudagraph_captured += 1
# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output
if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
f"Input addresses for cudagraphs are different "
f"during replay. Expected {entry.input_addresses}, "
f"got {new_input_addresses}")
entry.cudagraph.replay()
return entry.output
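
Condensing the capture/replay contract from the class docstring and the unit tests above into a sketch (it reuses the tests' MagicMock-based VllmConfig shim and assumes a CUDA device; static input addresses remain the caller's responsibility):

import torch
import torch.nn as nn
from unittest.mock import MagicMock

from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import (CompilationConfig, CUDAGraphMode, ParallelConfig,
                         SchedulerConfig, VllmConfig)
from vllm.forward_context import BatchDescriptor, set_forward_context

# Same config shim as the unit tests above.
vllm_config = MagicMock(spec=VllmConfig)
vllm_config.compilation_config = CompilationConfig()
vllm_config.scheduler_config = SchedulerConfig(max_num_seqs=8)
vllm_config.parallel_config = ParallelConfig()

model = nn.Linear(16, 16).cuda()
wrapper = CUDAGraphWrapper(model, vllm_config, runtime_mode=CUDAGraphMode.FULL)

x = torch.randn(4, 16, device="cuda")  # keep the same buffer for replay
desc = BatchDescriptor(num_tokens=4)

# Warmup: a NONE runtime mode bypasses the wrapper entirely.
with set_forward_context(attn_metadata=None, vllm_config=vllm_config,
                         cudagraph_runtime_mode=CUDAGraphMode.NONE,
                         batch_descriptor=None):
    wrapper(x)

# First call with a matching mode captures a graph keyed by `desc`.
with set_forward_context(attn_metadata=None, vllm_config=vllm_config,
                         cudagraph_runtime_mode=CUDAGraphMode.FULL,
                         batch_descriptor=desc):
    wrapper(x)

# Subsequent calls with the same (mode, descriptor) replay the graph.
with set_forward_context(attn_metadata=None, vllm_config=vllm_config,
                         cudagraph_runtime_mode=CUDAGraphMode.FULL,
                         batch_descriptor=desc):
    out = wrapper(x)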

View File

@ -2,21 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Optional
from unittest.mock import patch
from typing import Any, Callable
import torch
import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors
logger = init_logger(__name__)
@ -24,44 +18,29 @@ logger = init_logger(__name__)
@dataclasses.dataclass
class ConcreteSizeEntry:
runtime_shape: int
need_to_compile: bool # the size is in compile_sizes
use_cudagraph: bool # the size is in cudagraph_capture_sizes
compiled: bool = False
runnable: Callable = None # type: ignore
num_finished_warmup: int = 0
cudagraph: Optional[torch.cuda.CUDAGraph] = None
output: Optional[Any] = None
# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None
class CUDAPiecewiseBackend:
class PiecewiseBackend:
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, piecewise_compile_index: int,
total_piecewise_compiles: int, sym_shape_indices: list[int],
piecewise_compile_index: int, total_piecewise_compiles: int,
sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend):
"""
The backend for piecewise compilation.
It mainly handles the compilation and cudagraph capturing.
It mainly handles the compilation of static shapes and
dispatching based on runtime shape.
We will compile `self.graph` once for the general shape,
and then compile for different shapes specified in
`compilation_config.compile_sizes`.
Independently, we will capture cudagraph for different shapes.
If a shape needs both compilation and cudagraph, we will
compile it first, and then capture cudagraph.
"""
self.graph = graph
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend
@ -70,11 +49,10 @@ class CUDAPiecewiseBackend:
self.is_last_graph = (
piecewise_compile_index == total_piecewise_compiles - 1)
self.is_full_graph = total_piecewise_compiles == 1
self.compile_sizes: set[int] = set(
self.compilation_config.compile_sizes)
self.cudagraph_capture_sizes: set[int] = set(
self.compilation_config.cudagraph_capture_sizes
) if self.compilation_config.use_cudagraph else set()
self.first_run_finished = False
@ -84,18 +62,18 @@ class CUDAPiecewiseBackend:
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
# the entries for different shapes that we need to either
# compile or capture cudagraph
# the entries for different shapes that we need to compile
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
# to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
# We only keep compilation management inside this class directly.
for shape in self.compile_sizes:
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
need_to_compile=shape in self.compile_sizes,
use_cudagraph=shape in self.cudagraph_capture_sizes,
runnable=self.compiled_graph_for_general_shape,
)
def check_for_ending_compilation(self):
@ -112,16 +90,14 @@ class CUDAPiecewiseBackend:
return self.compiled_graph_for_general_shape(*args)
runtime_shape = args[self.sym_shape_indices[0]]
if runtime_shape not in self.concrete_size_entries:
# we don't need to do anything for this shape
return self.compiled_graph_for_general_shape(*args)
entry = self.concrete_size_entries[runtime_shape]
if entry.runnable is None:
entry.runnable = self.compiled_graph_for_general_shape
if entry.need_to_compile and not entry.compiled:
if not entry.compiled:
entry.compiled = True
self.to_be_compiled_sizes.remove(runtime_shape)
# args are real arguments
@ -138,81 +114,4 @@ class CUDAPiecewiseBackend:
if self.is_last_graph and not self.to_be_compiled_sizes:
self.check_for_ending_compilation()
# Skip CUDA graphs if this entry doesn't use them OR
# if we're supposed to skip them globally
skip_cuda_graphs = get_forward_context().skip_cuda_graphs
if not entry.use_cudagraph or skip_cuda_graphs:
return entry.runnable(*args)
if entry.cudagraph is None:
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
entry.num_finished_warmup += 1
if self.is_first_graph:
logger.debug(
"Warming up %s/%s for shape %s",
entry.num_finished_warmup,
self.compilation_config.cudagraph_num_of_warmups,
runtime_shape)
return entry.runnable(*args)
if self.is_first_graph:
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every shape.
# We only log it in the debug mode.
logger.debug("Capturing a cudagraph for shape %s",
runtime_shape)
input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()
with ExitStack() as stack:
if not self.is_first_graph:
# during every model forward, we will capture
# many pieces of cudagraphs (roughly one per layer).
# running gc again and again across layers will
# make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(
patch("torch.cuda.empty_cache", lambda: None))
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = entry.runnable(*args)
if self.is_last_graph:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph, because the output of the last graph
# will not be used by any other cuda graph.
output = weak_ref_tensors(output)
# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph
compilation_counter.num_cudagraph_captured += 1
# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output
if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
"Input addresses for cudagraphs are different during replay."
f" Expected {entry.input_addresses}, got {new_input_addresses}"
)
entry.cudagraph.replay()
return entry.output
return entry.runnable(*args)

View File

@ -37,3 +37,21 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig):
if context_manager is not None:
context_manager.__exit__(None, None, None)
context_manager = None
cudagraph_capturing_enabled: bool = True
def validate_cudagraph_capturing_enabled():
# used to check whether cudagraph capturing is legal at runtime.
# should be called before any cudagraph capture.
# raises an error if capturing is attempted while it is disabled.
global cudagraph_capturing_enabled
if not cudagraph_capturing_enabled:
raise RuntimeError("CUDA graph capturing detected at an inappropriate "
"time. This operation is currently disabled.")
def set_cudagraph_capturing_enabled(enabled: bool):
global cudagraph_capturing_enabled
cudagraph_capturing_enabled = enabled
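
A small sketch of how the capture gate added above is intended to be used (the integration test earlier in this commit relies on it to assert that capturing an unseen shape after freezing raises):

from vllm.compilation.monitor import (set_cudagraph_capturing_enabled,
                                      validate_cudagraph_capturing_enabled)

validate_cudagraph_capturing_enabled()      # no-op while capturing is allowed

set_cudagraph_capturing_enabled(False)      # e.g. once warmup capture is done
try:
    validate_cudagraph_capturing_enabled()  # a late capture attempt raises
except RuntimeError as err:
    print(err)
finally:
    set_cudagraph_capturing_enabled(True)   # restore the default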

View File

@ -11,7 +11,8 @@ from typing import Callable, Optional
import torch
import vllm.envs as envs
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.config import (CompilationLevel, CUDAGraphMode,
get_current_vllm_config)
from vllm.logger import init_logger
logger = init_logger(__name__)
@ -115,8 +116,8 @@ class TorchCompileWrapperWithCustomDispatcher:
except Exception:
pass
if self.vllm_config.compilation_config.use_cudagraph and \
"update" in new_code.co_names:
if self.vllm_config.compilation_config.cudagraph_mode != \
CUDAGraphMode.NONE and "update" in new_code.co_names:
import depyf
src = depyf.decompile(new_code)
msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa

View File

@ -32,7 +32,7 @@ from vllm import version
from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
PrefixCachingHashAlgo)
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
PassConfig)
CUDAGraphMode, PassConfig)
from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
from vllm.config.utils import ConfigType, config
@ -3529,11 +3529,21 @@ class VllmConfig:
else:
self.compilation_config.level = \
CompilationLevel.NO_COMPILATION
else:
# NB: Passing both --enforce-eager and a compilation level
# in V0 means the compilation level wins out.
self.compilation_config.level = CompilationLevel.NO_COMPILATION
# if cudagraph_mode is not explicitly set by users, set default value
if self.compilation_config.cudagraph_mode is None:
if envs.VLLM_USE_V1 and self.compilation_config.level \
== CompilationLevel.PIECEWISE:
self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# async tp is built on top of sequence parallelism
# and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp:
@ -3541,12 +3551,13 @@ class VllmConfig:
True
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")
if envs.VLLM_USE_V1 and self.model_config is not None and \
not self.model_config.enforce_eager:
# By default, V1 uses piecewise CUDA graphs. If full_cuda_graph
# is set to True, full CUDA graphs will be used.
# disable cudagraphs when eager execution is enforced
if self.model_config is not None and self.model_config.enforce_eager:
logger.info("Cudagraph is disabled under eager mode")
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
elif envs.VLLM_USE_V1:
self.compilation_config.cudagraph_num_of_warmups = 1
self.compilation_config.set_splitting_ops_for_v1()
self._set_cudagraph_sizes()
@ -3566,12 +3577,6 @@ class VllmConfig:
"Disabling `torch.compile`.")
self.compilation_config.level = CompilationLevel.NO_COMPILATION
if self.compilation_config.full_cuda_graph and \
not self.model_config.disable_cascade_attn:
logger.info("full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
self.model_config.disable_cascade_attn = True
disable_chunked_prefill_reasons: list[str] = []
if self.model_config and self.model_config.pooler_config:
@ -3612,9 +3617,32 @@ class VllmConfig:
"to True to enable.")
current_platform.check_and_update_config(self)
# final check of cudagraph mode after platform-specific update
if envs.VLLM_USE_V1:
if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
and self.model_config is not None and \
not self.model_config.disable_cascade_attn:
logger.info("CUDAGraphMode.FULL is not supported with "
"cascade attention currently. Disabling cascade"
"attention.")
self.model_config.disable_cascade_attn = True
if self.compilation_config.cudagraph_mode\
.requires_piecewise_compilation():
assert self.compilation_config.level == \
CompilationLevel.PIECEWISE, \
"Compilation level should be CompilationLevel.PIECEWISE "\
"when cudagraph_mode piecewise cudagraphs is used, "\
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
if not self.instance_id:
self.instance_id = random_uuid()[:5]
# Do this after all the updates to compilation_config.level
if envs.VLLM_USE_V1 and \
self.compilation_config.level == CompilationLevel.PIECEWISE:
self.compilation_config.set_splitting_ops_for_v1()
if (envs.VLLM_USE_V1
and not self.scheduler_config.disable_hybrid_kv_cache_manager):
# logger should only print warning message for hybrid models. As we

View File

@ -1,12 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import hashlib
from collections import Counter
from dataclasses import asdict, field
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union
from pydantic import TypeAdapter
from pydantic import TypeAdapter, field_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs
@ -31,6 +32,40 @@ class CompilationLevel:
PIECEWISE = 3
class CUDAGraphMode(enum.Enum):
""" Constants for the cudagraph mode in CompilationConfig.
The subset of values `NONE`, `PIECEWISE` and `FULL` are also
treated as concrete runtime modes for cudagraph runtime dispatching.
"""
NONE = 0
PIECEWISE = 1
FULL = 2
FULL_DECODE_ONLY = (FULL, NONE)
FULL_AND_PIECEWISE = (FULL, PIECEWISE)
def decode_mode(self) -> 'CUDAGraphMode':
return CUDAGraphMode(self.value[0]) if \
self.separate_routine() else self
def mixed_mode(self) -> 'CUDAGraphMode':
return CUDAGraphMode(self.value[1]) if \
self.separate_routine() else self
def requires_piecewise_compilation(self) -> bool:
return (self.decode_mode() == CUDAGraphMode.PIECEWISE
or self.mixed_mode() == CUDAGraphMode.PIECEWISE)
def max_cudagraph_mode(self) -> 'CUDAGraphMode':
return CUDAGraphMode(max(
self.value)) if self.separate_routine() else self
def has_full_cudagraphs(self) -> bool:
return self.max_cudagraph_mode() == CUDAGraphMode.FULL
def separate_routine(self) -> bool:
return isinstance(self.value, tuple)
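Illustrative sketch (not part of this diff) of how the composite modes above decompose into concrete runtime modes, assuming CUDAGraphMode is importable from vllm.config as done elsewhere in this change:
from vllm.config import CUDAGraphMode

mode = CUDAGraphMode.FULL_AND_PIECEWISE
assert mode.separate_routine()                       # value is a tuple
assert mode.decode_mode() == CUDAGraphMode.FULL      # uniform-decode batches
assert mode.mixed_mode() == CUDAGraphMode.PIECEWISE  # mixed prefill-decode
assert mode.max_cudagraph_mode() == CUDAGraphMode.FULL
assert mode.requires_piecewise_compilation()

simple = CUDAGraphMode.PIECEWISE
assert not simple.separate_routine()
assert simple.decode_mode() == simple.mixed_mode() == CUDAGraphMode.PIECEWISE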
@config
@dataclass
class PassConfig:
@ -91,6 +126,7 @@ class CompilationConfig:
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
- CudaGraph capture:
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
- [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
- [`cudagraph_capture_sizes`]
[vllm.config.CompilationConfig.cudagraph_capture_sizes]
- [`cudagraph_num_of_warmups`]
@ -157,7 +193,7 @@ class CompilationConfig:
By default, all custom ops are enabled when running without Inductor and
disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
Inductor generates (fused) Triton kernels for disabled custom ops."""
splitting_ops: list[str] = field(default_factory=list)
splitting_ops: Optional[list[str]] = None
"""A list of ops to split the full graph into subgraphs, used in piecewise
compilation."""
@ -187,7 +223,43 @@ class CompilationConfig:
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
# CudaGraph compilation
use_cudagraph: bool = field(default_factory=lambda: envs.VLLM_USE_V1)
cudagraph_mode: Optional[CUDAGraphMode] = None
"""
The mode of the cudagraph.
- NONE, no cudagraph capture.
- PIECEWISE. (v1 default)
- FULL.
- FULL_DECODE_ONLY.
- FULL_AND_PIECEWISE.
PIECEWISE mode builds piecewise cudagraphs only, keeping the
cudagraph-incompatible ops (e.g. some attention ops) outside the cudagraph
for general flexibility.
This is the default mode.
FULL mode: Capture full cudagraph for all batches. Can be good for small
models or workloads with small prompts; not supported by many backends.
Generally for performance FULL_AND_PIECEWISE is better.
FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only.
Mixed prefill-decode batches are run without cudagraphs. Can be good for
decode instances in a P/D setup where prefill is not as important so we
can save some memory.
FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
piecewise cudagraph for prefill and mixed prefill-decode batches.
This is usually the most performant mode for most models.
Currently, the cudagraph mode is only used for the v1 engine.
Note that the cudagraph logic is generally orthogonal to the
compilation logic. While piecewise cudagraphs require piecewise
compilation (level=PIECEWISE and non-empty splitting_ops), full
cudagraphs are supported with and without compilation.
Warning: This flag is new and subject to change; in addition,
more modes may be added.
"""
use_cudagraph: bool = True
"""Whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
@ -197,8 +269,9 @@ class CompilationConfig:
CompilationLevel.PIECEWISE (aka -O3).
Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future."""
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
"""
cudagraph_num_of_warmups: int = 0
"""Number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
@ -213,12 +286,17 @@ class CompilationConfig:
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False."""
full_cuda_graph: bool = False
internally managed buffer. Default is False.
Note that this flag is only effective when cudagraph_mode is PIECEWISE.
"""
full_cuda_graph: Optional[bool] = False
"""whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models."""
performance benefits for smaller models.
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
"""
pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details"""
@ -253,6 +331,13 @@ class CompilationConfig:
Map from layer name to layer objects that need to be accessed outside
model code, e.g., Attention, FusedMOE when dp_size>1."""
# Attention ops; used for piecewise cudagraphs
_attention_ops: ClassVar[list[str]] = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
]
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@ -297,13 +382,26 @@ class CompilationConfig:
if pass_config_exclude:
exclude["pass_config"] = pass_config_exclude
return TypeAdapter(CompilationConfig).dump_json(
self,
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True).decode()
# The cast to string is necessary because Pydantic is mocked in docs
# builds and sphinx-argparse doesn't know the return type of decode()
return str(
TypeAdapter(CompilationConfig).dump_json(
self,
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True).decode())
__str__ = __repr__
@field_validator("cudagraph_mode", mode="before")
@classmethod
def validate_cudagraph_mode_before(cls, value: Any) -> Any:
"""
Enable parsing the `cudagraph_mode` enum type from a string.
"""
if isinstance(value, str):
return CUDAGraphMode[value.upper()]
return value
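Illustrative effect of this validator (a sketch, not code from this diff): strings are mapped case-insensitively to enum members before validation.
from vllm.config import CompilationConfig, CUDAGraphMode

cfg = CompilationConfig(cudagraph_mode="full_and_piecewise")
assert cfg.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE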
def __post_init__(self) -> None:
count_none = self.custom_ops.count("none")
count_all = self.custom_ops.count("all")
@ -341,7 +439,26 @@ class CompilationConfig:
if isinstance(self.pass_config, dict):
self.pass_config = PassConfig(**self.pass_config)
def init_backend(self, vllm_config: VllmConfig) -> Union[str, Callable]:
# migrate the deprecated flags
if not self.use_cudagraph:
logger.warning("use_cudagraph is deprecated, use "
"cudagraph_mode=NONE instead.")
if self.cudagraph_mode is not None:
raise ValueError(
"use_cudagraph and cudagraph_mode are mutually"
" exclusive, prefer cudagraph_mode since "
"use_cudagraph is deprecated.")
self.cudagraph_mode = CUDAGraphMode.NONE
if self.full_cuda_graph:
logger.warning("full_cuda_graph is deprecated, use "
"cudagraph_mode=FULL instead.")
if self.cudagraph_mode is not None:
raise ValueError("full_cuda_graph and cudagraph_mode are "
"mutually exclusive, prefer cudagraph_mode "
"since full_cuda_graph is deprecated.")
self.cudagraph_mode = CUDAGraphMode.FULL
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
@ -414,15 +531,34 @@ class CompilationConfig:
self.max_capture_size] = self.max_capture_size
def set_splitting_ops_for_v1(self):
# NOTE: this function needs to be called
if self.splitting_ops and self.full_cuda_graph:
raise ValueError("full_cuda_graph cannot be used together with "
"splitting_ops, as Full CUDA graph will override "
f"the splitting_ops: {self.splitting_ops}")
# NOTE: this function needs to be called only when level is
# CompilationLevel.PIECEWISE
assert self.level == CompilationLevel.PIECEWISE, (
"set_splitting_ops_for_v1 should only be called when "
"level is CompilationLevel.PIECEWISE")
if not self.splitting_ops:
self.splitting_ops = [] if self.full_cuda_graph else [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
]
if self.splitting_ops is None:
# NOTE: When using full cudagraphs, instead of setting an empty
# list and capturing the full cudagraph inside the flattened fx
# graph, we keep the piecewise fx graph structure but capture the
# full cudagraph outside the fx graph. This reduces some CPU
# overhead when the runtime batch_size is not cudagraph captured.
# see https://github.com/vllm-project/vllm/pull/20059 for details.
self.splitting_ops = self._attention_ops
elif len(self.splitting_ops) == 0:
logger.warning_once("Using piecewise compilation with empty "
"splitting_ops.")
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once(
"When compilation level is piecewise with empty "
"splitting_ops, PIECEWISE cudagraph_mode will be "
"treated as FULL cudagraph_mode. Please ensure you are "
"using attention backends that support cudagraph or set "
"cudagraph_mode to NONE explicitly if encountering "
"any problems.")
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
def splitting_ops_contain_attention(self) -> bool:
return self.splitting_ops is not None and all(
op in self.splitting_ops for op in self._attention_ops)
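A hedged sketch of the default described in the NOTE above, assuming the imports used elsewhere in this diff:
from vllm.config import CompilationConfig, CompilationLevel

cfg = CompilationConfig(level=CompilationLevel.PIECEWISE,
                        cudagraph_mode="FULL_AND_PIECEWISE")
cfg.set_splitting_ops_for_v1()
# splitting_ops defaults to the attention/mamba ops, so the piecewise fx
# graph structure is kept even though full cudagraphs are captured outside.
assert cfg.splitting_ops_contain_attention()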

View File

@ -5,13 +5,13 @@ import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
if TYPE_CHECKING:
@ -26,6 +26,27 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time: defaultdict = defaultdict(list)
class BatchDescriptor(NamedTuple):
"""
Batch descriptor for cudagraph dispatching. We should keep the number of
items as small as possible to properly and uniquely describe the padded
batch for cudagraph.
"""
num_tokens: int
uniform_decode: bool = False
"""
False can also be used for a uniform decode batch to dispatch to the
cudagraph supporting non-uniform batches.
"""
@property
def non_uniform(self) -> "BatchDescriptor":
"""
Return a non-uniform version of current batch descriptor.
"""
return BatchDescriptor(self.num_tokens, uniform_decode=False)
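A small illustrative sketch of the descriptor semantics:
from vllm.forward_context import BatchDescriptor

desc = BatchDescriptor(num_tokens=256, uniform_decode=True)
# Relaxing the key lets a uniform-decode batch fall back to a cudagraph
# captured for general (non-uniform) batches of the same padded size.
assert desc.non_uniform == BatchDescriptor(num_tokens=256,
                                           uniform_decode=False)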
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
max_num_tokens: int,
chunk_idx: int) -> list[int]:
@ -152,7 +173,15 @@ class ForwardContext:
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
dp_metadata: Optional[DPMetadata] = None
skip_cuda_graphs: bool = False
# determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
# by default NONE, no cudagraph is used.
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
batch_descriptor: Optional[BatchDescriptor] = None
def __post_init__(self):
assert self.cudagraph_runtime_mode in [
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
_forward_context: Optional[ForwardContext] = None
@ -168,13 +197,13 @@ def get_forward_context() -> ForwardContext:
@contextmanager
def set_forward_context(
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
skip_cuda_graphs: bool = False,
):
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
@ -198,7 +227,8 @@ def set_forward_context(
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata,
skip_cuda_graphs=skip_cuda_graphs,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
)
try:

View File

@ -177,17 +177,20 @@ class CudaPlatformBase(Platform):
logger.info("Forcing kv cache block size to 128 for "
"CUTLASS_MLA backend.")
# lazy import to avoid circular import
from vllm.config import CUDAGraphMode
compilation_config = vllm_config.compilation_config
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1
and compilation_config.use_cudagraph):
and compilation_config.cudagraph_mode != CUDAGraphMode.NONE):
logger.info(
"Data Parallel: Forcing enforce eager to be True since DP "
"Data Parallel: disabling cudagraphs since DP "
"with DeepEP high-throughput kernels are not CUDA Graph "
"compatible. The DeepEP low-latency kernels are CUDA Graph "
"compatible. Set the all_to_all backend to deepep_low_latency "
"to use those kernels instead.")
compilation_config.use_cudagraph = False
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
if model_config is not None:
model_config.enforce_eager = True
@ -454,8 +457,8 @@ class CudaPlatformBase(Platform):
return True
@classmethod
def get_piecewise_backend_cls(cls) -> str:
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa
def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod
def stateless_init_device_torch_dist_pg(

View File

@ -7,7 +7,7 @@ import random
import sys
from datetime import timedelta
from platform import uname
from typing import TYPE_CHECKING, NamedTuple, Optional, Union
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
import numpy as np
import torch
@ -137,6 +137,8 @@ class Platform:
additional_env_vars: list[str] = []
_global_graph_pool: Optional[Any] = None
@property
def supported_dtypes(self) -> list[torch.dtype]:
"""Returns the supported dtypes for the current platform."""
@ -522,6 +524,15 @@ class Platform:
" attribute.", self.device_type, key)
return None
def get_global_graph_pool(self) -> Any:
"""
Return the global graph pool for this platform.
"""
cls = self.__class__
if cls._global_graph_pool is None:
cls._global_graph_pool = self.graph_pool_handle()
return cls._global_graph_pool
@classmethod
def get_cu_count(cls, device_id: int = 0) -> int:
"""
@ -530,11 +541,11 @@ class Platform:
raise NotImplementedError
@classmethod
def get_piecewise_backend_cls(cls) -> str:
def get_static_graph_wrapper_cls(cls) -> str:
"""
Get piecewise backend class for piecewise graph.
Get static graph wrapper class for static graph.
"""
return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa
return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper"
@classmethod
def stateless_init_device_torch_dist_pg(

View File

@ -421,8 +421,8 @@ class RocmPlatform(Platform):
return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName
@classmethod
def get_piecewise_backend_cls(cls) -> str:
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa
def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod
def stateless_init_device_torch_dist_pg(

View File

@ -99,7 +99,7 @@ class TpuPlatform(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
from vllm.config import CompilationLevel
from vllm.config import CompilationLevel, CUDAGraphMode
cache_config = vllm_config.cache_config
# For v0, the default block size is 16.
@ -109,9 +109,17 @@ class TpuPlatform(Platform):
# TPU only supports DYNAMO_ONCE compilation level
if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
logger.info("[TPU] Forcing DYNAMO_ONCE compilation level, and "
"disabling cudagraph.")
compilation_config.level = CompilationLevel.DYNAMO_ONCE
if compilation_config.cudagraph_mode is None or \
compilation_config.cudagraph_mode.max_cudagraph_mode() \
!= CUDAGraphMode.NONE:
logger.info("[TPU] CUDA graph is not supported on TPU, "
"disabling cudagraphs.")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
if compilation_config.backend == "":
compilation_config.backend = "openxla"

View File

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Optional
import torch
import vllm.envs as envs
from vllm.config import CUDAGraphMode
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
@ -100,16 +101,17 @@ class XPUPlatform(Platform):
# Instances created using VllmConfig() typically have model_config as
# None by default. The modification adds a check to prevent potential
# null exceptions before checking and updating the model config.
if model_config is not None:
if model_config.dtype == torch.bfloat16:
bf16_supported = cls.device_support_bf16()
if not bf16_supported:
model_config.dtype = torch.float16
if not model_config.enforce_eager:
logger.warning(
"CUDA graph is not supported on XPU, fallback to the eager "
"mode.")
model_config.enforce_eager = True
if model_config is not None and model_config.dtype == torch.bfloat16 \
and not cls.device_support_bf16():
model_config.dtype = torch.float16
compilation_config = vllm_config.compilation_config
if compilation_config.cudagraph_mode is None or \
compilation_config.cudagraph_mode.max_cudagraph_mode() \
!= CUDAGraphMode.NONE:
logger.info("[XPU] CUDA graph is not supported on XPU, "
"disabling cudagraphs.")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# check and update parallel config
parallel_config = vllm_config.parallel_config

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional
import numpy as np
import torch
@ -154,9 +154,26 @@ def _get_sliding_window_configs(
class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
else AttentionCGSupport.ALWAYS
# FA3:
# Supports full cudagraphs for all cases.
#
# FA2:
# For FA2, if a graph is captured with max_query_len=1 (which is what we
# capture by default for num_tokens <= max_num_seqs when there is no
# spec-decode), then these graphs will not work for mixed prefill-decode
# batches (unlike FA3). This is due to the special max_query_len=1
# packed-GQA handling in FA2.
# In summary, if we are running with spec-decode the graphs would work for
# mixed prefill-decode and uniform-decode batches, but for non-spec decode
# the graphs would not work for mixed prefill-decode; roughly the inverse
# of UNIFORM_SINGLE_TOKEN_DECODE.
# There's probably a better way to describe this using `AttentionCGSupport`,
# but for now just set it to `UNIFORM_BATCH` to get us to drop down
# to FULL_AND_PIECEWISE.
# TODO(luka, lucas): audit FA2 as part of:
# https://github.com/vllm-project/vllm/issues/22945
cudagraph_support = AttentionCGSupport.ALWAYS \
if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@ -177,17 +194,13 @@ class FlashAttentionMetadataBuilder(
self.max_num_splits = 0 # No upper bound on the number of splits.
self.aot_schedule = (get_flash_attn_version() == 3)
self.use_full_cuda_graph = self.compilation_config.full_cuda_graph
if self.use_full_cuda_graph:
if not self.aot_schedule:
raise ValueError(
"AoT scheduling is required for full cuda graph.")
capture_sizes = self.compilation_config.cudagraph_capture_sizes
if not capture_sizes:
raise ValueError(
"cudagraph_capture_sizes should not be None when "
"full_cuda_graph is True.")
self.max_cudagraph_size = max(capture_sizes)
self.use_full_cuda_graph = \
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
if self.use_full_cuda_graph and self.aot_schedule:
self.max_cudagraph_size = self.compilation_config.max_capture_size
if self.max_cudagraph_size > 992:
# This condition derives from FA3's internal heuristic.
# TODO(woosuk): Support larger cudagraph sizes.
@ -310,9 +323,9 @@ class FlashAttentionMetadataBuilder(
seqlens=seq_lens,
max_seq_len=max_seq_len,
causal=causal)
if self.use_full_cuda_graph:
assert scheduler_metadata is not None
# For FA3 + full cudagraph
max_num_splits = 0
if self.use_full_cuda_graph and scheduler_metadata is not None:
n = scheduler_metadata.shape[0]
self.scheduler_metadata[:n] = scheduler_metadata
# NOTE(woosuk): We should zero out the rest of the scheduler
@ -322,14 +335,12 @@ class FlashAttentionMetadataBuilder(
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]
max_num_splits = 0
if (self.use_full_cuda_graph
and num_actual_tokens <= self.max_cudagraph_size):
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
if num_actual_tokens <= self.max_cudagraph_size:
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
@ -350,11 +361,6 @@ class FlashAttentionMetadataBuilder(
causal=causal)
return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
# Full CUDA Graph always supported (FA2 support checked separately)
return True
def use_cascade_attention(self, *args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs)

View File

@ -17,7 +17,7 @@ from flashinfer.prefill import trtllm_batch_context_with_kv_cache
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.config import VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import use_trtllm_attention
@ -183,8 +183,8 @@ class FlashInferMetadata:
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: ClassVar[int] = 1
@ -203,7 +203,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
self.enable_cuda_graph = self.compilation_config.full_cuda_graph
self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
decode_mode() == CUDAGraphMode.FULL
if self.enable_cuda_graph:
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
@ -586,10 +587,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
return self.build(0, m)
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting

View File

@ -89,8 +89,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: ClassVar[int] = 1
@ -203,7 +203,3 @@ class Mamba2AttentionMetadataBuilder(
m.max_query_len = 1 # decode-only
return self.build(0, m)
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1

View File

@ -575,7 +575,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only
assert m.max_query_len == 1 # decode-only
return self.build(0, m)
@ -728,10 +728,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
"""

View File

@ -22,7 +22,7 @@ logger = init_logger(__name__)
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable full CUDA Graph support for decode-only capture
attn_cudagraph_support: ClassVar[
AttentionCGSupport] = AttentionCGSupport.PURE_DECODE_ONLY
AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
class CutlassMLABackend(MLACommonBackend):

View File

@ -55,8 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@ -73,7 +73,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
device_properties = torch.cuda.get_device_properties(self.device)
num_sms = device_properties.multi_processor_count
if self.compilation_config.full_cuda_graph:
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.cg_buf_tile_scheduler_metadata = torch.zeros(
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
# TileSchedulerMetaDataSize = 8
@ -95,7 +95,10 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
1, # MQA for the decode path
)
if self.compilation_config.full_cuda_graph:
# TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
assert self.cg_buf_tile_scheduler_metadata is not None
assert self.cg_buf_num_splits is not None

View File

@ -65,8 +65,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
# TODO(luka, lucas): audit this as part of:
# https://github.com/vllm-project/vllm/issues/22945
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@ -82,7 +84,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_num_pages = max_num_reqs * max_num_pages_per_req
# Preparing persistent buffers
if vllm_config.compilation_config.full_cuda_graph:
# TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=device)
@ -120,7 +125,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])
if self.compilation_config.full_cuda_graph:
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0)

View File

@ -311,11 +311,6 @@ class AiterFlashAttentionMetadataBuilder(
)
return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
# Full CUDA Graph always supported (FA2 support checked separately)
return True
def use_cascade_attention(self, *args, **kwargs) -> bool:
return False

View File

@ -58,8 +58,7 @@ class TritonAttentionMetadata:
class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@ -132,11 +131,6 @@ class TritonAttentionMetadataBuilder(
)
return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
# Full CUDA Graph always supported
return True
class TritonAttentionBackend(AttentionBackend):

View File

@ -158,18 +158,21 @@ class AttentionCGSupport(enum.Enum):
Here we do not consider cascade attention, as it is currently
never supported with cudagraphs.
ALWAYS = 3
"""Cudagraph always supported; supports mixed-prefill-decode"""
UNIFORM_BATCH = 2
"""Cudagraph supported for batches the only contain query lengths that are
the same, this can be used for spec-decode
i.e. "decodes" are 1 + num_speculative_tokens"""
UNIFORM_SINGLE_TOKEN_DECODE = 1
"""Cudagraph supported for batches the only contain query_len==1 decodes"""
NEVER = 0
"""NO cudagraph support"""
PURE_DECODE_ONLY = 1
"""Cudagraph supported for pure decode, need to run without
cudagraph for mixed prefill-decode batches"""
ALWAYS = 2
"""Cudagraph always supported"""
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention.
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
# Does this backend/builder support CUDA Graphs for attention (default: no).
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
@ -199,13 +202,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
"""
raise NotImplementedError
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
"""
Can this batch (with given metadata) use CUDA Graphs for attention.
"""
return False
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M:
"""

View File

@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger
logger = init_logger(__name__)
class CudagraphDispatcher:
"""
Runtime cudagraph dispatcher that dispatches keys for multiple sets of
cudagraphs.
The dispatcher stores two sets of dispatch keys, one for the PIECEWISE and
one for the FULL cudagraph runtime mode. The keys are initialized depending
on attention support and on the cudagraph mode set in CompilationConfig. The
keys stored in the dispatcher are the only source of truth for the valid
cudagraphs that can be dispatched at runtime.
At runtime, the dispatch method generates the runtime cudagraph mode (FULL,
PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor)
based on the input key. After dispatching (communicated via the forward
context), the cudagraph wrappers will trust the dispatched key to either
capture or replay (if the mode matches), or pass through to the underlying
runnable without cudagraph (if the mode does not match or is NONE).
"""
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.cudagraph_mode = self.compilation_config.cudagraph_mode
# Dict to store valid cudagraph dispatching keys.
self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
CUDAGraphMode.PIECEWISE: set(),
CUDAGraphMode.FULL: set(),
}
assert not self.cudagraph_mode.requires_piecewise_compilation() or \
(self.compilation_config.level == CompilationLevel.PIECEWISE and
self.compilation_config.splitting_ops_contain_attention()), \
"Compilation level should be CompilationLevel.PIECEWISE when "\
"cudagraph_mode piecewise cudagraphs is used, "\
f"cudagraph_mode={self.cudagraph_mode}, "\
f"compilation_level={self.compilation_config.level}, "\
f"splitting_ops={self.compilation_config.splitting_ops}"
self.keys_initialized = False
def add_cudagraph_key(self, runtime_mode: CUDAGraphMode,
batch_descriptor: BatchDescriptor):
assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
f"Invalid cudagraph runtime mode: {runtime_mode}"
self.cudagraph_keys[runtime_mode].add(batch_descriptor)
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
uniform_decode_query_len: int):
# This should be called only after attention backend is initialized.
# Note: we create all possible valid keys for cudagraphs, but do not
# guarantee that every key will be used. For example, we create keys for
# piecewise cudagraphs whenever piecewise compilation is used, which is
# always valid, but for attention backends that support a unified routine
# we may not trigger capturing/replaying of the piecewise cudagraphs,
# depending on CompilationConfig.cudagraph_mode. In addition, if we allow
# lazy capturing in a future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
for bs in self.compilation_config.cudagraph_capture_sizes:
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
BatchDescriptor(num_tokens=bs, uniform_decode=False))
# If the decode cudagraph mode is FULL and we don't already have
# mixed-mode full cudagraphs, then add them here.
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL \
and cudagraph_mode.separate_routine():
max_num_tokens = uniform_decode_query_len * \
self.vllm_config.scheduler_config.max_num_seqs
cudagraph_capture_sizes_for_decode = [
x for x in self.compilation_config.cudagraph_capture_sizes
if x <= max_num_tokens and x >= uniform_decode_query_len
]
for bs in cudagraph_capture_sizes_for_decode:
self.add_cudagraph_key(
CUDAGraphMode.FULL,
BatchDescriptor(num_tokens=bs, uniform_decode=True))
self.keys_initialized = True
def dispatch(
self, batch_descriptor: BatchDescriptor
) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
"""
Given a batch descriptor, dispatch to a cudagraph mode.
A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform).
"""
# if not initialized, just skip dispatching.
if not self.keys_initialized:
logger.warning_once("cudagraph dispatching keys are not "
"initialized. No cudagraph will be used.")
return CUDAGraphMode.NONE, None
# check if key exists for full cudagraph
if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_descriptor
# otherwise, check if non-uniform key exists
non_uniform_key = batch_descriptor.non_uniform
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, non_uniform_key
# also check if non-uniform key exists for more "general"
# piecewise cudagraph
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, non_uniform_key
# finally, just return no cudagraphs
return CUDAGraphMode.NONE, None
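A hedged end-to-end sketch of the dispatch flow; vllm_config is assumed to be configured with cudagraph_mode=FULL_AND_PIECEWISE, a PIECEWISE compilation level with attention splitting ops, and capture sizes that include 16:
from vllm.config import CUDAGraphMode
from vllm.forward_context import BatchDescriptor
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher

def sketch_dispatch(vllm_config):
    dispatcher = CudagraphDispatcher(vllm_config)
    dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL_AND_PIECEWISE,
                                         uniform_decode_query_len=1)
    # A padded uniform-decode batch of 16 tokens hits the FULL decode graph.
    full_mode, _ = dispatcher.dispatch(
        BatchDescriptor(num_tokens=16, uniform_decode=True))
    # A mixed prefill-decode batch of the same padded size falls back to the
    # non-uniform key and the PIECEWISE graph.
    piecewise_mode, _ = dispatcher.dispatch(
        BatchDescriptor(num_tokens=16, uniform_decode=False))
    return full_mode, piecewise_mode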

View File

@ -21,7 +21,9 @@ from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationLevel, VllmConfig,
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config, update_config)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@ -29,7 +31,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
from vllm.distributed.parallel_state import (
get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
prepare_communication_buffer_for_model)
from vllm.forward_context import DPMetadata, set_forward_context
from vllm.forward_context import (BatchDescriptor, DPMetadata,
set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@ -48,13 +51,15 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up, supports_dynamo)
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
get_dtype_size, is_pin_memory_available, round_up,
supports_dynamo)
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata,
reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (AttentionSpec,
ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheConfig,
@ -218,11 +223,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
is_spec_decode=bool(self.vllm_config.speculative_config),
)
self.use_cuda_graph = (
self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and self.vllm_config.compilation_config.use_cudagraph
and not self.model_config.enforce_eager)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different.
# self.cudagraph_batch_sizes sorts in ascending order.
@ -230,8 +230,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.cudagraph_batch_sizes = list(
reversed(self.compilation_config.cudagraph_capture_sizes))
self.full_cuda_graph = self.compilation_config.full_cuda_graph
# Cache the device properties.
self._init_device_properties()
@ -326,6 +324,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device=self.device)
self.uniform_decode_query_len = 1 if not self.speculative_config else \
1 + self.speculative_config.num_speculative_tokens
# Cudagraph dispatcher for runtime cudagraph dispatching.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self.mm_budget = (MultiModalBudget(
self.model_config,
self.scheduler_config,
@ -471,7 +475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert (task := pooling_params.task) is not None, (
"You did not set `task` in the API")
model = cast(VllmModelForPooling, self.model)
model = cast(VllmModelForPooling, self.get_model())
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(pooling_params)
@ -679,13 +683,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[dict[str,
Any], bool, torch.Tensor, Optional[SpecDecodeMetadata],
np.ndarray, Optional[CommonAttentionMetadata]]:
) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata],
np.ndarray, Optional[CommonAttentionMetadata], int]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
attention_cuda_graphs: whether attention can run in cudagraph
logits_indices, spec_decode_metadata
]
"""
@ -820,7 +822,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# valid, we fill the padded indices with the last index.
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
logits_indices[-1].item())
if (self.use_cuda_graph
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and num_logits <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
@ -925,17 +927,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
continue
attn_metadata[layer_name] = attn_metadata_i
attention_cuda_graphs = all(
g.metadata_builder.can_run_in_cudagraph(common_attn_metadata)
for g in self._attn_group_iterator())
# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
return (attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata, num_scheduled_tokens,
spec_decode_common_attn_metadata)
return (attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens, spec_decode_common_attn_metadata,
max_num_scheduled_tokens)
def _compute_cascade_attn_prefix_len(
self,
@ -1259,6 +1257,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return mm_embeds
def get_model(self) -> nn.Module:
# get raw model out of the cudagraph wrapper.
if isinstance(self.model, CUDAGraphWrapper):
return self.model.unwrap()
return self.model
def get_supported_generation_tasks(self) -> list[GenerationTask]:
@ -1415,9 +1416,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return
assert self.eplb_state is not None
assert is_mixture_of_experts(self.model)
model = self.get_model()
assert is_mixture_of_experts(model)
self.eplb_state.step(
self.model,
model,
is_dummy,
is_profile,
log_stats=self.parallel_config.eplb_log_balancedness,
@ -1507,15 +1509,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config)
# Prepare the decoder inputs.
(attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata, num_scheduled_tokens_np,
spec_decode_common_attn_metadata) = (
self._prepare_inputs(scheduler_output))
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
max_query_len) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Use CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
@ -1581,10 +1582,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)
# Some attention backends only support CUDA Graphs in pure decode.
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
num_scheduled_tokens == self.input_batch.num_reqs * max_query_len)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=uniform_decode)
cudagraph_runtime_mode, batch_descriptor = \
self.cudagraph_dispatcher.dispatch(batch_descriptor)
# Run the model.
# Use persistent buffers for CUDA graphs.
@ -1593,10 +1596,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
), self.maybe_get_kv_connector_output(
scheduler_output) as kv_connector_output:
model_output = self.model(
input_ids=input_ids,
positions=positions,
@ -2021,20 +2024,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.model.compile(
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=backend)
return
# for other compilation levels, cudagraph behavior is controlled by
# the CUDAGraphWrapper and CudagraphDispatcher of vllm.
# wrap the model with full cudagraph wrapper if needed.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.model = CUDAGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
def reload_weights(self) -> None:
assert getattr(self, "model", None) is not None, \
"Cannot reload weights before model is loaded."
model_loader = get_model_loader(self.load_config)
logger.info("Reloading weights inplace...")
model_loader.load_weights(self.model, model_config=self.model_config)
model = self.get_model()
model_loader.load_weights(model, model_config=self.model_config)
def save_tensorized_model(
self,
tensorizer_config: "TensorizerConfig",
) -> None:
model = self.get_model()
TensorizerLoader.save_model(
self.model,
model,
tensorizer_config=tensorizer_config,
model_config=self.model_config,
)
@ -2210,31 +2224,82 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _dummy_run(
self,
num_tokens: int,
capture_attn_cudagraph: bool = False,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
force_attention: bool = False,
uniform_decode: bool = False,
skip_eplb: bool = False,
is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up, profile, or capture CUDA graphs
for the model.
Args:
num_tokens: Number of tokens to run the dummy forward pass.
cudagraph_runtime_mode: used to control the behavior.
- CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
- CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
- CUDAGraphMode.FULL: Full cudagraph, attention metadata is
needed.
force_attention: If True, always create attention metadata. Used to
warm up attention backend when mode is NONE.
uniform_decode: If True, the batch is a uniform decode batch.
skip_eplb: If True, skip EPLB state update.
is_profile: If True, this is a profile run.
"""
assert cudagraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
}
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(), we are using different graphs and/or
# modes for mixed prefill-decode batches vs. uniform decode batches.
# A uniform decode batch means that all requests have an identical query
# length, except for a potentially shorter virtual request in the batch
# that accounts for padding.
# A uniform decode batch can either be a common pure decode, where
# max_query_len == 1, or a speculative decode, where
# max_query_len == 1 + num_spec_decode_tokens.
# When setting max_query_len = 1, we switch to and capture the optimized
# FA2 routine for pure decode, i.e., FlashDecode + an optimization
# for GQA/MQA.
max_query_len = self.uniform_decode_query_len if uniform_decode else \
num_tokens
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
if uniform_decode:
num_reqs = cdiv(num_tokens, max_query_len)
assert num_reqs <= max_num_reqs, \
"Do not capture num_reqs > max_num_reqs for uniform batch"
num_scheduled_tokens_list = [max_query_len] * num_reqs
if num_tokens % max_query_len != 0:
num_scheduled_tokens_list[-1] = num_tokens % max_query_len
else:
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
attn_metadata: Optional[dict[str, Any]] = None
if capture_attn_cudagraph:
# If force_attention is True, we always build attention metadata. Otherwise,
# it is only built when cudagraph_runtime_mode is FULL.
if force_attention or cudagraph_runtime_mode == \
CUDAGraphMode.FULL:
attn_metadata = {}
# Make sure max_model_len is used at the graph capture time.
@ -2255,7 +2320,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_computed_tokens_cpu_tensor[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
max_query_len=max_query_len,
block_table_tensor=self.input_batch.block_table[
kv_cache_group_id].get_device_tensor()[:num_reqs],
slot_mapping=self.input_batch.
@ -2299,12 +2364,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_tokens, None, False)
if cudagraph_runtime_mode == CUDAGraphMode.NONE:
batch_descriptor = None
else:
# resolve the valid batch descriptor through the dispatcher
_cg_mode, batch_descriptor = \
self.cudagraph_dispatcher.dispatch(
BatchDescriptor(num_tokens=num_tokens,
uniform_decode=uniform_decode))
# sanity check
assert cudagraph_runtime_mode == _cg_mode, (
f"Cudagraph runtime mode mismatch at dummy_run. "
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
with self.maybe_randomize_inputs(input_ids), set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor):
outputs = self.model(
input_ids=input_ids,
positions=positions,
@ -2436,7 +2515,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int32,
device=self.device)
model = cast(VllmModelForPooling, self.model)
model = cast(VllmModelForPooling, self.get_model())
dummy_pooling_params = PoolingParams(task=task)
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(dummy_pooling_params)
@ -2546,12 +2625,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
gc.collect()
def capture_model(self) -> None:
if not self.use_cuda_graph:
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
logger.warning(
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
"set -O %s and ensure `use_cudagraph` was not manually set to "
"False", CompilationLevel.PIECEWISE)
"ensure `cudagraph_mode` was not manually set to `NONE`")
return
else:
self.initialize_cudagraph_capture()
compilation_counter.num_gpu_runner_capture_triggers += 1
@ -2576,25 +2656,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
set_cudagraph_capturing_enabled(True)
with freeze_gc(), graph_capture(device=self.device):
full_cg = self.full_cuda_graph
# Only rank 0 should print progress bar during capture
compilation_cases = reversed(self.cudagraph_batch_sizes)
if is_global_first_rank():
compilation_cases = tqdm(
list(compilation_cases),
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing CUDA graph shapes")
for num_tokens in compilation_cases:
# We skip EPLB here since we don't want to record dummy metrics
for _ in range(
self.compilation_config.cudagraph_num_of_warmups):
self._dummy_run(num_tokens,
capture_attn_cudagraph=full_cg,
skip_eplb=True)
self._dummy_run(num_tokens,
capture_attn_cudagraph=full_cg,
skip_eplb=True)
cudagraph_mode = self.compilation_config.cudagraph_mode
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
compilation_cases = list(reversed(self.cudagraph_batch_sizes))
self._capture_cudagraphs(
compilation_cases,
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=False)
# Capture full cudagraphs for uniform decode batches if we don't
# already have full mixed prefill-decode cudagraphs.
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \
cudagraph_mode.separate_routine():
max_num_tokens = self.scheduler_config.max_num_seqs * \
self.uniform_decode_query_len
decode_cudagraph_batch_sizes = [
x for x in self.cudagraph_batch_sizes if
x <= max_num_tokens and x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(
reversed(decode_cudagraph_batch_sizes))
self._capture_cudagraphs(
compilation_cases=compilation_cases_decode,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
uniform_decode=True)
# Disable cudagraph capturing globally, so any unexpected cudagraph
# capturing will be detected and raise an error after here.
# Note: we don't put this into the graph_capture context manager because
# we may do lazy capturing in the future, which still allows capturing
# after this point.
set_cudagraph_capturing_enabled(False)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
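Worked example (hypothetical numbers) of the decode capture-size filtering above: with max_num_seqs=4 and uniform_decode_query_len=2 (spec-decode with one draft token), only capture sizes between 2 and 8 are re-captured as uniform-decode FULL graphs.
capture_sizes = [1, 2, 4, 8, 16]               # assumed cudagraph_capture_sizes
uniform_decode_query_len = 2                   # 1 + num_speculative_tokens
max_num_tokens = 4 * uniform_decode_query_len  # max_num_seqs * query_len
decode_sizes = [x for x in capture_sizes
                if uniform_decode_query_len <= x <= max_num_tokens]
assert decode_sizes == [2, 4, 8]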
@ -2604,6 +2700,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))
def _capture_cudagraphs(self, compilation_cases: list[int],
cudagraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool):
assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \
cudagraph_runtime_mode in [CUDAGraphMode.FULL,
CUDAGraphMode.PIECEWISE]
# Only rank 0 should print progress bar during capture
if is_global_first_rank():
compilation_cases = tqdm(
compilation_cases,
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing CUDA graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode",
cudagraph_runtime_mode.name))
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens in compilation_cases:
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
# Use CUDAGraphMode.NONE (default) for warmup.
# But be careful: warming up with `NONE` is orthogonal to
# whether we want to warm up attention or not. This is
# different from capture, where `FULL` implies capturing
# attention while `PIECEWISE` implies no attention.
force_attention = (
cudagraph_runtime_mode == CUDAGraphMode.FULL)
self._dummy_run(num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode,
skip_eplb=True)
self._dummy_run(num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=uniform_decode,
skip_eplb=True)
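A rough sketch of the per-size warmup-then-capture pattern above, using a hypothetical dummy_run stand-in rather than the runner's actual _dummy_run: warmup always runs eagerly (mode NONE), but attention is still forced on when the upcoming capture is FULL, so the attention-metadata path is exercised before it is recorded into a graph.

from enum import Enum

class GraphMode(Enum):  # simplified stand-in for vLLM's CUDAGraphMode
    NONE = 0
    PIECEWISE = 1
    FULL = 2

def dummy_run(num_tokens: int, mode: GraphMode, force_attention: bool,
              uniform_decode: bool) -> None:
    # Placeholder for the runner's dummy forward pass.
    print(f"dummy run: {num_tokens=} mode={mode.name} "
          f"{force_attention=} {uniform_decode=}")

def capture_one_size(num_tokens: int, runtime_mode: GraphMode,
                     uniform_decode: bool, num_warmups: int = 1) -> None:
    # Warm up eagerly; only force attention when a FULL graph will be captured.
    force_attention = runtime_mode == GraphMode.FULL
    for _ in range(num_warmups):
        dummy_run(num_tokens, GraphMode.NONE, force_attention, uniform_decode)
    # The actual capture happens inside this call when the mode is not NONE.
    dummy_run(num_tokens, runtime_mode, force_attention, uniform_decode)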
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the attention backends and attention metadata builders.
@ -2648,25 +2779,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata_builder_i,
layer_names)
attn_groups.append(attn_group)
if self.full_cuda_graph:
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.NEVER:
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend.__name__}. Turn off "
f"CompilationConfig.full_cuda_graph or use a "
f" different attention backend.")
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.PURE_DECODE_ONLY:
# Limit the max cudagraph size to the max number of
# sequences for pure decode only cudagraph backend,
# whose max_query_len is 1.
self.cudagraph_batch_sizes = [
size for size in self.cudagraph_batch_sizes
if size <= self.scheduler_config.max_num_seqs
]
return attn_groups
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
@ -2734,6 +2846,75 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"All or none of the layers are expected to be encoder-only"
self.is_encoder_only_model = True
def initialize_cudagraph_capture(self) -> None:
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_builder_name = None
for attn_group in self._attn_group_iterator():
builder = attn_group.metadata_builder
if builder.cudagraph_support.value < min_cg_support.value:
min_cg_support = builder.cudagraph_support
min_cg_builder_name = builder.__class__.__name__
# Flexibly resolve the cudagraph mode
cudagraph_mode = self.compilation_config.cudagraph_mode
# check that full cudagraph for mixed batches is supported
if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \
and min_cg_support != AttentionCGSupport.ALWAYS:
msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_builder_name} backend (support: "
f"{min_cg_support})")
if min_cg_support == AttentionCGSupport.NEVER:
# if full cudagraphs are not supported at all, just raise.
msg += "; please try cudagraph_mode=PIECEWISE, and "\
"make sure compilation level is piecewise"
raise ValueError(msg)
# otherwise, try to resolve a compatible full-cudagraph mode
if self.compilation_config.splitting_ops_contain_attention():
msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.FULL_AND_PIECEWISE
else:
msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.FULL_DECODE_ONLY
logger.warning(msg)
# check that the combination of spec-decode and full decode
# cudagraphs is supported by the attention backend
if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and self.uniform_decode_query_len > 1 and min_cg_support.value
< AttentionCGSupport.UNIFORM_BATCH.value):
msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
f" with spec-decode for attention backend "
f"{min_cg_builder_name} (support: {min_cg_support})")
if self.compilation_config.splitting_ops_contain_attention():
msg += "; setting cudagraph_mode=PIECEWISE"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
msg += "; setting cudagraph_mode=NONE"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.NONE
logger.warning(msg)
# double-check that we can support full cudagraphs if they are requested,
# even after automatic downgrades
if cudagraph_mode.has_full_cudagraphs() \
and min_cg_support == AttentionCGSupport.NEVER:
raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not "
f"supported with {min_cg_builder_name} backend ("
f"support:{min_cg_support}) "
"; please try cudagraph_mode=PIECEWISE, "
"and make sure compilation level is piecewise")
# Trigger cudagraph dispatching keys initialization here (after
# initializing attn backends).
self.cudagraph_dispatcher.initialize_cudagraph_keys(
self.compilation_config.cudagraph_mode,
self.uniform_decode_query_len)
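The downgrade logic above can be viewed as a pure function of the requested mode and the backend's support level. A simplified sketch, with a reduced stand-in for AttentionCGSupport and string mode names instead of the real enums:

from enum import Enum

class CGSupport(Enum):  # reduced stand-in, ordered by capability
    NEVER = 0
    UNIFORM_BATCH = 1
    ALWAYS = 2

def resolve_mode(requested: str, support: CGSupport,
                 splits_attention: bool, spec_decode: bool) -> str:
    mode = requested
    # Full mixed prefill-decode graphs require full backend support.
    if mode == "FULL" and support != CGSupport.ALWAYS:
        if support == CGSupport.NEVER:
            raise ValueError("try cudagraph_mode=PIECEWISE")
        mode = ("FULL_AND_PIECEWISE" if splits_attention
                else "FULL_DECODE_ONLY")
    # Full decode graphs with spec-decode (query_len > 1) need at least
    # uniform-batch support; otherwise fall back further.
    if (mode in ("FULL", "FULL_DECODE_ONLY", "FULL_AND_PIECEWISE")
            and spec_decode
            and support.value < CGSupport.UNIFORM_BATCH.value):
        mode = "PIECEWISE" if splits_attention else "NONE"
    return mode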
def calculate_reorder_batch_threshold(self) -> None:
"""
Check that if any backends reorder batches; that the reordering

View File

@ -322,16 +322,11 @@ class Worker(WorkerBase):
if get_pp_group().is_last_rank:
max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens)
# activate building attn_metadata for this dummy run to avoid
# potential illegal memory access for full cudagraph relay.
attn_cudagraph = self.compilation_config.full_cuda_graph and\
not self.model_config.enforce_eager
# We skip EPLB here since we don't want to record dummy metrics
hidden_states, last_hidden_states = \
self.model_runner._dummy_run(
num_tokens=max_num_reqs,
capture_attn_cudagraph=attn_cudagraph,
skip_eplb=True,
)
if self.model_runner.is_pooling_model: