[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer (#20059)
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
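Overview of the user-facing change: CUDA graph behavior is now selected via `CompilationConfig.cudagraph_mode` (`NONE`, `PIECEWISE`, `FULL`, `FULL_DECODE_ONLY`, `FULL_AND_PIECEWISE`), superseding the old `full_cuda_graph=True` flag seen in the removed test code below, and full-cudagraph capture/replay is dispatched per batch at runtime, orthogonally to torch.compile piecewise compilation. A minimal usage sketch based on the tests in this diff (the model choice and sampling settings are illustrative only):

```python
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

# FULL_AND_PIECEWISE: full cudagraphs for uniform decode batches,
# piecewise cudagraphs for mixed prefill/decode batches.
llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",  # illustrative model
    max_model_len=1024,
    compilation_config=CompilationConfig(cudagraph_mode="FULL_AND_PIECEWISE"),
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)
```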
@@ -3,7 +3,8 @@
|
||||
import contextlib
|
||||
import os
|
||||
import weakref
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -32,27 +33,130 @@ def temporary_environ(env_vars):
|
||||
os.environ[k] = v
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackendConfig:
|
||||
name: str
|
||||
env_vars: dict
|
||||
comp_config: dict
|
||||
specific_gpu_arch: Optional[tuple] = None
|
||||
|
||||
|
||||
# Define all backend configurations of full cudagraph to be tested
|
||||
backend_configs = {
|
||||
# FA3 on Hopper
|
||||
"FA3":
|
||||
BackendConfig(name="FA3",
|
||||
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
},
|
||||
specific_gpu_arch=(9, 0)),
|
||||
# FlashMLA on Hopper
|
||||
"FlashMLA":
|
||||
BackendConfig(name="FlashMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
},
|
||||
specific_gpu_arch=(9, 0)),
|
||||
# Cutlass MLA on Blackwell
|
||||
"CutlassMLA":
|
||||
BackendConfig(
|
||||
name="CutlassMLA",
|
||||
env_vars={
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
|
||||
"FORCE_NUM_KV_SPLITS":
|
||||
"1", # TODO: remove this when hang issue is fixed
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
"cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
|
||||
},
|
||||
specific_gpu_arch=(10, 0)),
|
||||
# FA2
|
||||
"FA2":
|
||||
BackendConfig(name="FA2",
|
||||
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
}),
|
||||
# Triton Attention
|
||||
"TritonAttn":
|
||||
BackendConfig(name="TritonAttn",
|
||||
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
}),
|
||||
# FlashInfer
|
||||
"FlashInfer":
|
||||
BackendConfig(name="FlashInfer",
|
||||
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
}),
|
||||
}
|
||||
|
||||
test_params_full_cudagraph = []
|
||||
|
||||
# deepseek-ai/DeepSeek-V2-Lite with MLA
|
||||
MLA_backends = ["FlashMLA", "CutlassMLA"]
|
||||
for mla_backend in MLA_backends:
|
||||
test_params_full_cudagraph.append(
|
||||
pytest.param(
|
||||
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))
|
||||
|
||||
# Qwen/Qwen2-1.5B-Instruct with other backends
|
||||
other_backend_configs = [
|
||||
backend_configs[c] for c in backend_configs if c not in MLA_backends
|
||||
]
|
||||
for backend_config in other_backend_configs:
|
||||
test_params_full_cudagraph.append(
|
||||
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def llm_pair(request):
|
||||
model = request.param
|
||||
model, backend_config = request.param
|
||||
|
||||
with temporary_environ({
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_FLASH_ATTN_VERSION": "3"
|
||||
}):
|
||||
# Dynamically skip test if GPU capability is not met
|
||||
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
|
||||
!= current_platform.get_device_capability():
|
||||
if backend_config.specific_gpu_arch == (9, 0):
|
||||
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
|
||||
elif backend_config.specific_gpu_arch == (10, 0):
|
||||
pytest.skip("Only Blackwell GPUs support Cutlass MLA")
|
||||
|
||||
env_vars = {
|
||||
"VLLM_USE_V1": "1",
|
||||
# Force native sampler to avoid potential nondeterminism in FlashInfer
|
||||
# when per-request generators are not used in V1.
|
||||
"VLLM_USE_FLASHINFER_SAMPLER": "0",
|
||||
**backend_config.env_vars,
|
||||
}
|
||||
with temporary_environ(env_vars):
|
||||
full = LLM(
|
||||
model=model,
|
||||
gpu_memory_utilization=0.45,
|
||||
gpu_memory_utilization=0.43,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(full_cuda_graph=True),
|
||||
max_num_seqs=128,
|
||||
compilation_config=\
|
||||
CompilationConfig(**backend_config.comp_config),
|
||||
generation_config="vllm",
|
||||
seed=42,
|
||||
)
|
||||
piecewise = LLM(
|
||||
model=model,
|
||||
gpu_memory_utilization=0.45,
|
||||
gpu_memory_utilization=0.43,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(),
|
||||
max_num_seqs=128,
|
||||
compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
|
||||
generation_config="vllm",
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# PyTest caches the fixture values so we use weakref.proxy to enable GC
|
||||
@@ -66,90 +170,7 @@ def llm_pair(request):
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def cutlass_mla_llm_pair(request):
|
||||
model = request.param
|
||||
|
||||
# force V1 engine and Cutlass MLA backend
|
||||
with temporary_environ({
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
|
||||
"FORCE_NUM_KV_SPLITS":
|
||||
"1", # TODO: remove this when hang issue is fixed
|
||||
}):
|
||||
full = LLM(
|
||||
model=model,
|
||||
gpu_memory_utilization=0.45,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(
|
||||
full_cuda_graph=True,
|
||||
cudagraph_capture_sizes=[16, 32, 64, 128, 256, 512],
|
||||
),
|
||||
)
|
||||
piecewise = LLM(
|
||||
model=model,
|
||||
gpu_memory_utilization=0.45,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(),
|
||||
)
|
||||
|
||||
yield weakref.proxy(full), weakref.proxy(piecewise)
|
||||
del full
|
||||
del piecewise
|
||||
|
||||
wait_for_gpu_memory_to_clear(
|
||||
devices=[0],
|
||||
threshold_ratio=0.1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cutlass_mla_llm_pair",
|
||||
[
|
||||
# use an MLA model
|
||||
"deepseek-ai/DeepSeek-V2-Lite",
|
||||
],
|
||||
indirect=True)
|
||||
@pytest.mark.skipif(current_platform.get_device_capability() != (10, 0),
|
||||
reason="Only Blackwell GPUs support Cutlass MLA")
|
||||
class TestFullCUDAGraphCutlassMLA:
|
||||
"""
|
||||
Validate full CUDA Graph with Cutlass MLA (decode-only capture).
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(("batch_size", "max_tokens"), [
|
||||
(8, 8),
|
||||
])
|
||||
def test_full_cudagraph_sm100_cutlass_mla(
|
||||
self, batch_size, max_tokens, cutlass_mla_llm_pair: tuple[LLM,
|
||||
LLM]):
|
||||
piecewise_llm, full_cudagraph_llm = cutlass_mla_llm_pair
|
||||
|
||||
prompts = ["Hello, my name is"] * batch_size
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=max_tokens,
|
||||
top_p=0.95)
|
||||
|
||||
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
|
||||
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
|
||||
|
||||
for piecewise_res, full_res in zip(piecewise_responses,
|
||||
full_responses):
|
||||
assert piecewise_res.outputs[0].text == full_res.outputs[0].text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_pair",
|
||||
[
|
||||
# Model names for the llm_pair fixture
|
||||
"deepseek-ai/DeepSeek-V2-Lite",
|
||||
"Qwen/Qwen2-1.5B-Instruct"
|
||||
],
|
||||
indirect=True)
|
||||
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
|
||||
reason="Only Hopper GPUs support FA3 and FlashMLA")
|
||||
@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
|
||||
class TestFullCUDAGraph:
|
||||
"""
|
||||
Use a class such that an llm pair is constructed once for all
|
||||
@@ -178,12 +199,14 @@ class TestFullCUDAGraph:
|
||||
full cudagraph compilation works for padded cases too.
|
||||
"""
|
||||
|
||||
piecewise_llm, full_cudagraph_llm = llm_pair
|
||||
full_cudagraph_llm, piecewise_llm = llm_pair
|
||||
|
||||
prompts = ["Hello, my name is"] * batch_size
|
||||
prompts = ["the quick brown fox"] * batch_size
|
||||
# Use purely greedy decoding to avoid top-p truncation sensitivity
|
||||
# that can amplify tiny numeric differences across runtimes.
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=max_tokens,
|
||||
top_p=0.95)
|
||||
top_p=1.0)
|
||||
|
||||
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
|
||||
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
|
||||
@@ -191,42 +214,16 @@ class TestFullCUDAGraph:
|
||||
# Check that all responses are the same
|
||||
for piecewise_res, full_res in zip(piecewise_responses,
|
||||
full_responses):
|
||||
assert piecewise_res.outputs[0].text == full_res.outputs[0].text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, supported",
|
||||
[
|
||||
("Qwen/Qwen2-1.5B-Instruct", True),
|
||||
# MLA does not support capturing CUDA Graphs with size > max_num_seqs
|
||||
("deepseek-ai/DeepSeek-V2-Lite", False),
|
||||
])
|
||||
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
|
||||
reason="Only Hopper GPUs support FA3 and FlashMLA")
|
||||
def test_lower_max_num_seqs(model, supported):
|
||||
with temporary_environ({
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_FLASH_ATTN_VERSION": "3"
|
||||
}), ExitStack() as stack:
|
||||
if not supported:
|
||||
stack.enter_context(pytest.raises(RuntimeError))
|
||||
|
||||
llm = LLM(model=model,
|
||||
max_num_seqs=256,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(
|
||||
full_cuda_graph=True,
|
||||
cudagraph_capture_sizes=[64, 256, 512]))
|
||||
llm.generate(["Hello, my name is"] * 10)
|
||||
assert piecewise_res.outputs[0].text.lower() == \
|
||||
full_res.outputs[0].text.lower()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
def test_full_cudagraph_with_invalid_backend():
|
||||
with temporary_environ({
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_FLASH_ATTN_VERSION":
|
||||
"2" #FA2 not supported with full_cuda_graph
|
||||
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
|
||||
# Flex_Attention is not supported with full cuda graph
|
||||
}), pytest.raises(RuntimeError):
|
||||
LLM(model="Qwen/Qwen2-1.5B-Instruct",
|
||||
compilation_config=CompilationConfig(full_cuda_graph=True))
|
||||
compilation_config=CompilationConfig(cudagraph_mode="FULL"))
|
||||
|
@@ -11,10 +11,10 @@ from torch.library import Library
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
global_counter = 0
|
||||
@@ -101,16 +101,33 @@ def test_simple_piecewise_compile(use_inductor):
|
||||
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_captured=
|
||||
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
), set_forward_context({}, vllm_config=vllm_config):
|
||||
|
||||
), set_forward_context(None,
|
||||
vllm_config=vllm_config): # background context
|
||||
# warm up with background context
|
||||
model(inputs)
|
||||
|
||||
model(torch.randn(2).cuda())
|
||||
model(torch.randn(1).cuda())
|
||||
# capturing/replaying should run under the context of cudagraph dispatching
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
||||
model(torch.randn(2).cuda())
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(num_tokens=1, )):
|
||||
model(torch.randn(1).cuda())
|
||||
|
||||
input = torch.zeros(2).cuda()
|
||||
global global_counter
|
||||
global_counter = 0
|
||||
output = model(input)
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
||||
output = model(input)
|
||||
assert global_counter == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||
|
@@ -18,9 +18,9 @@ from torch.library import Library
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
@@ -276,9 +276,11 @@ def run_model(llama_config,
|
||||
)
|
||||
if split_attn:
|
||||
compilation_config.splitting_ops = ["silly.attention"]
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
else:
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.NO_COMPILATION, )
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=compilation_config,
|
||||
additional_config=llama_config)
|
||||
@@ -287,17 +289,37 @@ def run_model(llama_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix="").eval().cuda()
|
||||
|
||||
with set_forward_context({}, vllm_config=vllm_config):
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config): # background context
|
||||
B = 16 # max batch size
|
||||
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
|
||||
positions = torch.arange(B).cuda()
|
||||
|
||||
# warmup for the model with cudagraph_mode NONE
|
||||
model(input_ids, positions)
|
||||
model(input_ids[:2], positions[:2])
|
||||
model(input_ids[:1], positions[:1])
|
||||
|
||||
# simulate cudagraphs capturing
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
model(input_ids[:2], positions[:2])
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1, )):
|
||||
model(input_ids[:1], positions[:1])
|
||||
|
||||
input_ids[:2].zero_()
|
||||
output = model(input_ids[:2], positions[:2])
|
||||
# simulate cudagraphs replay
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
output = model(input_ids[:2], positions[:2])
|
||||
|
||||
output = output.cpu()
|
||||
|
||||
|
tests/v1/cudagraph/__init__.py (new file, 0 lines)
tests/v1/cudagraph/test_cudagraph_dispatch.py (new file, 406 lines)
@@ -0,0 +1,406 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from tests.utils import create_new_process_for_each_test
|
||||
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
||||
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
ParallelConfig, SchedulerConfig, VllmConfig)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
|
||||
|
||||
# Helper MLP for testing
|
||||
class SimpleMLP(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(10, 10)
|
||||
self.fc2 = nn.Linear(10, 10)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(self.fc1(x))
|
||||
|
||||
|
||||
def _create_vllm_config(compilation_config: CompilationConfig,
|
||||
max_num_seqs: int = 8) -> MagicMock:
|
||||
mock_config = MagicMock(spec=VllmConfig)
|
||||
mock_config.compilation_config = compilation_config
|
||||
mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
|
||||
mock_config.parallel_config = ParallelConfig()
|
||||
|
||||
# Mimic the behavior of VllmConfig.__post_init__()
|
||||
if compilation_config.level == CompilationLevel.PIECEWISE:
|
||||
compilation_config.set_splitting_ops_for_v1()
|
||||
|
||||
return mock_config
|
||||
|
||||
|
||||
class TestCudagraphDispatcher:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"params",
|
||||
[
|
||||
# Test case 0: Full CG for mixed batches, no separate routine
|
||||
{
|
||||
"case_id": 0,
|
||||
"cudagraph_mode": "FULL",
|
||||
"compilation_level": CompilationLevel.NO_COMPILATION,
|
||||
},
|
||||
# Test case 1: Full CG for uniform batches, piecewise for mixed
|
||||
{
|
||||
"case_id": 1,
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
"compilation_level": CompilationLevel.PIECEWISE,
|
||||
},
|
||||
# Test case 2: Full CG for uniform batches, no CG for mixed
|
||||
{
|
||||
"case_id": 2,
|
||||
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||
"compilation_level": CompilationLevel.NO_COMPILATION,
|
||||
},
|
||||
# Test case 3: Piecewise for all
|
||||
{
|
||||
"case_id": 3,
|
||||
"cudagraph_mode": "PIECEWISE",
|
||||
"compilation_level": CompilationLevel.PIECEWISE,
|
||||
},
|
||||
])
|
||||
def test_dispatcher(self, params):
|
||||
# Setup dispatcher
|
||||
comp_config = CompilationConfig(
|
||||
cudagraph_mode=params["cudagraph_mode"],
|
||||
level=params["compilation_level"],
|
||||
cudagraph_capture_sizes=[1, 8])
|
||||
|
||||
config = _create_vllm_config(comp_config, max_num_seqs=8)
|
||||
dispatcher = CudagraphDispatcher(config)
|
||||
dispatcher.initialize_cudagraph_keys(
|
||||
cudagraph_mode=comp_config.cudagraph_mode,
|
||||
uniform_decode_query_len=1)
|
||||
|
||||
# Verify the key is initialized correctly
|
||||
if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
|
||||
else:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
|
||||
if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
|
||||
else:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
|
||||
|
||||
# Test dispatch logic
|
||||
# 1. non-uniform batch, size in cudagraph size list
|
||||
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
|
||||
rt_mode, key = dispatcher.dispatch(desc_full_exact)
|
||||
if params["cudagraph_mode"] == "FULL":
|
||||
assert rt_mode == CUDAGraphMode.FULL
|
||||
assert key == desc_full_exact
|
||||
elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
|
||||
assert rt_mode == CUDAGraphMode.PIECEWISE
|
||||
assert key == desc_full_exact
|
||||
else:
|
||||
assert rt_mode == CUDAGraphMode.NONE
|
||||
|
||||
# 2. uniform decode batch, size in cudagraph size list
|
||||
desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
|
||||
rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
|
||||
if params["cudagraph_mode"] == "FULL":
|
||||
assert rt_mode == CUDAGraphMode.FULL
|
||||
assert key == desc_uniform_exact.non_uniform
|
||||
elif params["cudagraph_mode"] in [
|
||||
"FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"
|
||||
]:
|
||||
assert rt_mode == CUDAGraphMode.FULL
|
||||
assert key == desc_uniform_exact
|
||||
elif params["cudagraph_mode"] == "PIECEWISE":
|
||||
assert rt_mode == CUDAGraphMode.PIECEWISE
|
||||
assert key == desc_uniform_exact.non_uniform
|
||||
else:
|
||||
assert rt_mode == CUDAGraphMode.NONE
|
||||
|
||||
# 3. No key match
|
||||
desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False)
|
||||
rt_mode, key = dispatcher.dispatch(desc_no_match)
|
||||
assert rt_mode == CUDAGraphMode.NONE
|
||||
assert key is None
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
class TestCUDAGraphWrapper:
|
||||
|
||||
def setup_method(self):
|
||||
self.vllm_config = _create_vllm_config(CompilationConfig())
|
||||
self.model = SimpleMLP().to("cuda")
|
||||
self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
|
||||
self.input_tensor = torch.randn(1, 10, device="cuda")
|
||||
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_capture_and_replay(self):
|
||||
wrapper = CUDAGraphWrapper(self.model,
|
||||
self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL)
|
||||
batch_descriptor = BatchDescriptor(num_tokens=10)
|
||||
|
||||
# 0. global warmup
|
||||
with set_forward_context(attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
batch_descriptor=None):
|
||||
wrapper(self.input_tensor)
|
||||
|
||||
# 1. Capture
|
||||
with set_forward_context(
|
||||
attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.FULL,
|
||||
batch_descriptor=batch_descriptor),\
|
||||
patch("torch.cuda.graph",
|
||||
wraps=torch.cuda.graph) as mock_cuda_graph:
|
||||
output1 = wrapper(self.input_tensor)
|
||||
# capturing phase should generate a zero output
|
||||
assert torch.allclose(output1, torch.zeros_like(output1))
|
||||
mock_cuda_graph.assert_called_once()
|
||||
|
||||
assert batch_descriptor in wrapper.concrete_cudagraph_entries
|
||||
entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
|
||||
assert entry.cudagraph is not None
|
||||
|
||||
# 2. Replay
|
||||
with set_forward_context(
|
||||
attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.FULL,
|
||||
batch_descriptor=batch_descriptor),\
|
||||
patch.object(entry.cudagraph, 'replay',
|
||||
wraps=entry.cudagraph.replay) as mock_replay:
|
||||
output2 = wrapper(self.input_tensor)
|
||||
mock_replay.assert_called_once()
|
||||
|
||||
# Compare with eager output
|
||||
eager_output = self.model(self.input_tensor)
|
||||
torch.testing.assert_close(eager_output, output2)
|
||||
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_bypass_on_mode_mismatch(self):
|
||||
wrapper = CUDAGraphWrapper(self.model,
|
||||
self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL)
|
||||
batch_descriptor = BatchDescriptor(num_tokens=10)
|
||||
|
||||
with set_forward_context(
|
||||
attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=batch_descriptor), \
|
||||
patch('torch.cuda.graph',
|
||||
wraps=torch.cuda.graph) as mock_cuda_graph, \
|
||||
patch.object(self.model, 'forward',
|
||||
wraps=self.model.forward) as mock_forward:
|
||||
wrapper(self.input_tensor)
|
||||
mock_cuda_graph.assert_not_called()
|
||||
mock_forward.assert_called_once()
|
||||
assert not wrapper.concrete_cudagraph_entries
|
||||
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_bypass_on_mode_none(self):
|
||||
wrapper = CUDAGraphWrapper(self.model,
|
||||
self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL)
|
||||
batch_descriptor = BatchDescriptor(num_tokens=10)
|
||||
|
||||
with set_forward_context(
|
||||
attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
batch_descriptor=batch_descriptor), \
|
||||
patch('torch.cuda.graph',
|
||||
wraps=torch.cuda.graph) as mock_cuda_graph:
|
||||
wrapper(self.input_tensor)
|
||||
mock_cuda_graph.assert_not_called()
|
||||
assert not wrapper.concrete_cudagraph_entries
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
class TestCudagraphIntegration:
|
||||
|
||||
def setup_method(self):
|
||||
# only FULL mode for non-uniform batches
|
||||
self.comp_config = CompilationConfig(level=CompilationLevel.PIECEWISE,
|
||||
cudagraph_mode="FULL",
|
||||
cudagraph_capture_sizes=[10, 20])
|
||||
self.vllm_config = _create_vllm_config(self.comp_config)
|
||||
self.dispatcher = CudagraphDispatcher(self.vllm_config)
|
||||
self.dispatcher.initialize_cudagraph_keys(
|
||||
self.comp_config.cudagraph_mode, uniform_decode_query_len=1)
|
||||
|
||||
def _run_and_monitor_call(self, wrapper, input_tensor, runtime_mode,
|
||||
batch_descriptor):
|
||||
"""Helper to run a single call and monitor the action."""
|
||||
|
||||
with patch('torch.cuda.graph',
|
||||
wraps=torch.cuda.graph) as mock_graph_context, \
|
||||
patch.object(wrapper, 'runnable',
|
||||
wraps=wrapper.runnable) as mock_runnable:
|
||||
|
||||
entry = wrapper.concrete_cudagraph_entries.get(
|
||||
batch_descriptor, None)
|
||||
|
||||
context = set_forward_context(attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=runtime_mode,
|
||||
batch_descriptor=batch_descriptor)
|
||||
mock_replay = MagicMock()
|
||||
if entry and entry.cudagraph:
|
||||
with context, \
|
||||
patch.object(entry.cudagraph, 'replay',
|
||||
new_callable=MagicMock) as mock_replay:
|
||||
wrapper(input_tensor)
|
||||
else:
|
||||
with context:
|
||||
wrapper(input_tensor)
|
||||
|
||||
if mock_graph_context.called:
|
||||
# note that this is globally mocked, so it will be detected
|
||||
# even when called by the inner or outer wrapper
|
||||
return "capture_global"
|
||||
if mock_replay.called:
|
||||
# only for outer wrapper
|
||||
return "replay"
|
||||
if mock_runnable.call_count > 0:
|
||||
# only for outer wrapper
|
||||
return "bypass"
|
||||
return "unknown"
|
||||
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_capture_replay_bypass_logic(self):
|
||||
model = SimpleMLP().to("cuda")
|
||||
full_wrapper = CUDAGraphWrapper(model, self.vllm_config,
|
||||
CUDAGraphMode.FULL)
|
||||
max_bs = 16
|
||||
persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
|
||||
input_1 = persistent_input_buffer[:1]
|
||||
input_2 = persistent_input_buffer[:2]
|
||||
input_3 = persistent_input_buffer[:3]
|
||||
|
||||
desc_1 = BatchDescriptor(num_tokens=1)
|
||||
desc_2 = BatchDescriptor(num_tokens=2)
|
||||
desc_3_unseen = BatchDescriptor(num_tokens=3)
|
||||
|
||||
# 0. global warmup
|
||||
with set_forward_context(attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
batch_descriptor=None):
|
||||
full_wrapper(input_1)
|
||||
|
||||
rt_mode, key = self.dispatcher.dispatch(desc_1)
|
||||
# 1. Capture first shape
|
||||
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode,
|
||||
key)
|
||||
assert action == "capture_global"
|
||||
|
||||
# 2. Replay first shape
|
||||
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode,
|
||||
key)
|
||||
assert action == "replay"
|
||||
|
||||
rt_mode, key = self.dispatcher.dispatch(desc_2)
|
||||
# 3. Capture second shape
|
||||
action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode,
|
||||
key)
|
||||
assert action == "capture_global"
|
||||
|
||||
# 4. Replay second shape
|
||||
action = self._run_and_monitor_call(full_wrapper, input_2,
|
||||
CUDAGraphMode.FULL, desc_2)
|
||||
assert action == "replay"
|
||||
|
||||
# 5. Bypass if no key match
|
||||
rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
|
||||
assert rt_mode == CUDAGraphMode.NONE
|
||||
action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode,
|
||||
key)
|
||||
assert action == "bypass"
|
||||
|
||||
# capturing an unseen shape is not allowed after capturing is disabled
|
||||
set_cudagraph_capturing_enabled(False)
|
||||
with pytest.raises(RuntimeError):
|
||||
self._run_and_monitor_call(full_wrapper, input_3,
|
||||
CUDAGraphMode.FULL, desc_3_unseen)
|
||||
set_cudagraph_capturing_enabled(True)
|
||||
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_nested_wrappers(self):
|
||||
"""Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
|
||||
model = SimpleMLP().to("cuda")
|
||||
full_wrapper = CUDAGraphWrapper(model, self.vllm_config,
|
||||
CUDAGraphMode.FULL)
|
||||
input_1 = torch.randn(1, 10, device="cuda")
|
||||
|
||||
# Setup: Inner model is wrapped with PIECEWISE, outer with FULL
|
||||
inner_model = SimpleMLP().to("cuda")
|
||||
piecewise_wrapper = CUDAGraphWrapper(inner_model, self.vllm_config,
|
||||
CUDAGraphMode.PIECEWISE)
|
||||
inner_model.forward = MagicMock(wraps=inner_model.forward)
|
||||
outer_model = SimpleMLP().to("cuda")
|
||||
# When outer model is called, it calls the piecewise_wrapper
|
||||
outer_model.forward = MagicMock(wraps=outer_model.forward,
|
||||
side_effect=piecewise_wrapper)
|
||||
full_wrapper = CUDAGraphWrapper(outer_model, self.vllm_config,
|
||||
CUDAGraphMode.FULL)
|
||||
|
||||
desc_1 = BatchDescriptor(num_tokens=1)
|
||||
|
||||
# 0. global warmup
|
||||
with set_forward_context(attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
batch_descriptor=None):
|
||||
full_wrapper(input_1)
|
||||
|
||||
# --- Test runtime mode FULL---
|
||||
# Run with FULL mode context. Expect outer wrapper to capture.
|
||||
# The inner mock should be called once inside the graph capture.
|
||||
outer_model.forward.reset_mock()
|
||||
inner_model.forward.reset_mock()
|
||||
action = self._run_and_monitor_call(full_wrapper, input_1,
|
||||
CUDAGraphMode.FULL, desc_1)
|
||||
assert action == "capture_global"
|
||||
assert outer_model.forward.call_count == 1
|
||||
assert inner_model.forward.call_count == 1
|
||||
|
||||
# Run again. Expect outer wrapper to replay.
|
||||
# The outer model should NOT be called because the whole graph
|
||||
# is replayed.
|
||||
action = self._run_and_monitor_call(full_wrapper, input_1,
|
||||
CUDAGraphMode.FULL, desc_1)
|
||||
assert action == "replay"
|
||||
assert outer_model.forward.call_count == 1 # No new call
|
||||
assert inner_model.forward.call_count == 1
|
||||
|
||||
# --- Test runtime mode PIECEWISE ---
|
||||
outer_model.forward.reset_mock()
|
||||
inner_model.forward.reset_mock()
|
||||
# Run with PIECEWISE mode context.
|
||||
# Expect outer wrapper to bypass and call inner wrapper.
|
||||
# Inner wrapper should capture.
|
||||
action = self._run_and_monitor_call(full_wrapper, input_1,
|
||||
CUDAGraphMode.PIECEWISE, desc_1)
|
||||
assert action == "capture_global"
|
||||
assert outer_model.forward.call_count == 1
|
||||
assert inner_model.forward.call_count == 1
|
||||
|
||||
# Run again with PIECEWISE.
|
||||
# Outer bypasses, inner replays.
|
||||
action = self._run_and_monitor_call(full_wrapper, input_1,
|
||||
CUDAGraphMode.PIECEWISE, desc_1)
|
||||
assert action == "bypass"
|
||||
assert outer_model.forward.call_count == 2
|
||||
assert inner_model.forward.call_count == 1
|
tests/v1/cudagraph/test_cudagraph_mode.py (new file, 187 lines)
@@ -0,0 +1,187 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import os
|
||||
import weakref
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import wait_for_gpu_memory_to_clear
|
||||
from vllm import LLM
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temporary_environ(env_vars):
|
||||
"""
|
||||
Temporarily set environment variables and restore them afterward.
|
||||
We have to do this instead of monkeypatch because monkeypatch doesn't
work with "module"-scoped fixtures.
|
||||
"""
|
||||
original_env = {k: os.environ.get(k) for k in env_vars}
|
||||
try:
|
||||
os.environ.update(env_vars)
|
||||
yield
|
||||
finally:
|
||||
for k, v in original_env.items():
|
||||
if v is None:
|
||||
os.environ.pop(k, None)
|
||||
else:
|
||||
os.environ[k] = v
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackendConfig:
|
||||
name: str
|
||||
env_vars: dict
|
||||
comp_config: dict
|
||||
specific_gpu_arch: Optional[tuple] = None
|
||||
|
||||
|
||||
# Define all backend configurations of full cudagraph to be tested
|
||||
backend_configs = {
|
||||
# FA3 on Hopper
|
||||
"FA3":
|
||||
BackendConfig(name="FA3",
|
||||
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
},
|
||||
specific_gpu_arch=(9, 0)),
|
||||
# FlashMLA on Hopper
|
||||
"FlashMLA":
|
||||
BackendConfig(name="FlashMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
},
|
||||
specific_gpu_arch=(9, 0)),
|
||||
# FA2
|
||||
"FA2":
|
||||
BackendConfig(name="FA2",
|
||||
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
}),
|
||||
# Triton Attention
|
||||
"TritonAttn":
|
||||
BackendConfig(name="TritonAttn",
|
||||
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
}),
|
||||
# FlashInfer
|
||||
"FlashInfer":
|
||||
BackendConfig(name="FlashInfer",
|
||||
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
}),
|
||||
}
|
||||
|
||||
# test attention backend and cudagraph_mode combo
|
||||
# (backend_name, cudagraph_mode, supported)
|
||||
combo_cases_1 = [
|
||||
("FA3", "FULL", True),
|
||||
("FA3", "FULL_AND_PIECEWISE", True),
|
||||
("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
|
||||
("FA2", "FULL_AND_PIECEWISE", True),
|
||||
("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
|
||||
("FlashInfer", "FULL_AND_PIECEWISE", True),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("combo_case", combo_cases_1)
|
||||
def test_backend_and_cudagraph_mode_combo(combo_case):
|
||||
backend_name, cudagraph_mode, supported = combo_case
|
||||
if backend_name == "FlashInfer":
|
||||
try:
|
||||
import flashinfer # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("FlashInfer is not installed")
|
||||
backend_config = backend_configs[backend_name]
|
||||
# Dynamically skip test if GPU capability is not met
|
||||
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
|
||||
!= current_platform.get_device_capability():
|
||||
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
|
||||
|
||||
env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars}
|
||||
|
||||
with temporary_environ(env_vars), ExitStack() as stack:
|
||||
if not supported:
|
||||
stack.enter_context(pytest.raises(Exception))
|
||||
|
||||
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
|
||||
max_num_seqs=256,
|
||||
trust_remote_code=True,
|
||||
gpu_memory_utilization=0.45,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(
|
||||
level=3, cudagraph_mode=cudagraph_mode))
|
||||
llm.generate(["Hello, my name is"] * 10)
|
||||
|
||||
try:
|
||||
llm = weakref.proxy(llm)
|
||||
del llm
|
||||
except UnboundLocalError:
|
||||
pass
|
||||
|
||||
wait_for_gpu_memory_to_clear(
|
||||
devices=[0],
|
||||
threshold_ratio=0.1,
|
||||
)
|
||||
|
||||
|
||||
# test cudagraph_mode with different compilation level.
|
||||
# (backend_name, cudagraph_mode, compilation_level, supported)
|
||||
combo_cases_2 = [
|
||||
("FA2", "FULL", 0, True), # no compilation + full cudagraph
|
||||
("FA2", "FULL", 3, True), # piecewise compilation + full cudagraph
|
||||
("FA2", "PIECEWISE", 0, False), # no compilation + piecewise cudagraph
|
||||
("FA2", "PIECEWISE", 3,
|
||||
True), # piecewise compilation + piecewise cudagraph
|
||||
("FA2", "FULL_AND_PIECEWISE", 0,
|
||||
False), # piecewise cudagraph not supported without piecewise compilation
|
||||
("FA2", "FULL_AND_PIECEWISE", 3, True),
|
||||
("FA2", "FULL_DECODE_ONLY", 0, True),
|
||||
("FA2", "FULL_DECODE_ONLY", 3, True),
|
||||
("FA2", "NONE", 0, True), # no compilation + no cudagraph
|
||||
("FA2", "NONE", 3, True), # piecewise compilation + no cudagraph
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("combo_case", combo_cases_2)
|
||||
def test_cudagraph_compilation_combo(combo_case):
|
||||
backend_name, cudagraph_mode, compilation_level, supported\
|
||||
= combo_case
|
||||
|
||||
env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars}
|
||||
|
||||
with temporary_environ(env_vars), ExitStack() as stack:
|
||||
if not supported:
|
||||
stack.enter_context(pytest.raises(Exception))
|
||||
|
||||
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
|
||||
max_num_seqs=256,
|
||||
trust_remote_code=True,
|
||||
gpu_memory_utilization=0.45,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(
|
||||
level=compilation_level, cudagraph_mode=cudagraph_mode))
|
||||
llm.generate(["Hello, my name is"] * 10)
|
||||
try:
|
||||
llm = weakref.proxy(llm)
|
||||
del llm
|
||||
except UnboundLocalError:
|
||||
pass
|
||||
finally:
|
||||
wait_for_gpu_memory_to_clear(
|
||||
devices=[0],
|
||||
threshold_ratio=0.1,
|
||||
)
|
@@ -15,7 +15,7 @@ import torch.fx as fx
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CompilationConfig, VllmConfig
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
|
||||
@@ -277,9 +277,6 @@ def split_graph(graph: fx.GraphModule,
|
||||
return split_gm, outputs
|
||||
|
||||
|
||||
# we share the global graph pool among all the backends
|
||||
global_graph_pool = None
|
||||
|
||||
compilation_start_time = 0.0
|
||||
|
||||
|
||||
@@ -339,14 +336,37 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
graph_index=index,
|
||||
num_graphs=len(self.compile_submod_names),
|
||||
runtime_shape=None)
|
||||
# Lazy import here to avoid circular import
|
||||
from .cuda_graph import CUDAGraphOptions
|
||||
from .cuda_piecewise_backend import PiecewiseBackend
|
||||
|
||||
piecewise_backend = resolve_obj_by_qualname(
|
||||
current_platform.get_piecewise_backend_cls())
|
||||
self.module.__dict__[target] = piecewise_backend(
|
||||
submod, self.vllm_config, self.graph_pool, index,
|
||||
piecewise_backend = PiecewiseBackend(
|
||||
submod, self.vllm_config, index,
|
||||
len(self.compile_submod_names), sym_shape_indices,
|
||||
compiled_graph_for_dynamic_shape, self.vllm_backend)
|
||||
|
||||
if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
|
||||
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
|
||||
# class) as platform dependent.
|
||||
static_graph_wrapper_class = resolve_obj_by_qualname(
|
||||
current_platform.get_static_graph_wrapper_cls())
|
||||
|
||||
# Always assign PIECEWISE runtime mode to the
|
||||
# CUDAGraphWrapper for piecewise_backend, to distinguish
|
||||
# it from the FULL cudagraph runtime mode, regardless of
# whether it wraps a full or a piecewise fx graph.
|
||||
self.module.__dict__[target] = static_graph_wrapper_class(
|
||||
runnable=piecewise_backend,
|
||||
vllm_config=self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
graph_pool=self.graph_pool,
|
||||
cudagraph_options=CUDAGraphOptions(
|
||||
debug_log_enable=piecewise_backend.is_first_graph,
|
||||
gc_disable=not piecewise_backend.is_first_graph,
|
||||
weak_ref_output=piecewise_backend.is_last_graph))
|
||||
else:
|
||||
self.module.__dict__[target] = piecewise_backend
|
||||
|
||||
compilation_counter.num_piecewise_capturable_graphs_seen += 1
|
||||
|
||||
return output
|
||||
@@ -413,9 +433,7 @@ class VllmBackend:
|
||||
# them, e.g. backbone (default), eagle_head, etc.
|
||||
self.prefix = prefix or model_tag
|
||||
|
||||
global global_graph_pool
|
||||
if global_graph_pool is None:
|
||||
global_graph_pool = current_platform.graph_pool_handle()
|
||||
global_graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
# TODO: in the future, if we want to use multiple
|
||||
# streams, it might not be safe to share a global pool.
|
||||
@@ -585,7 +603,7 @@ class VllmBackend:
|
||||
|
||||
self._called = True
|
||||
|
||||
if not self.compilation_config.use_cudagraph or \
|
||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \
|
||||
not self.compilation_config.cudagraph_copy_inputs:
|
||||
return self.split_gm
|
||||
|
||||
|
@@ -1,72 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Protocol
|
||||
|
||||
import torch.fx as fx
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
|
||||
class AbstractPiecewiseBackend(Protocol):
|
||||
"""
|
||||
PiecewiseBackend interface that allows platforms to extend
|
||||
piecewise static graph.
|
||||
"""
|
||||
|
||||
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
|
||||
graph_pool: Any, piecewise_compile_index: int,
|
||||
total_piecewise_compiles: int, sym_shape_indices: list[int],
|
||||
compiled_graph_for_general_shape: Callable,
|
||||
vllm_backend: VllmBackend, **kwargs):
|
||||
"""
|
||||
Initializes the PiecewiseBackend class with compilation and
|
||||
execution-related configurations.
|
||||
|
||||
This class handles piecewise compilation, graph capturing,
|
||||
and dispatching for specific input shapes.
|
||||
|
||||
Args:
|
||||
graph (fx.GraphModule): The graph represented in fx.
|
||||
vllm_config (VllmConfig): Global configuration for vLLM.
|
||||
graph_pool (Any):
|
||||
Graph memory pool handle, e.g.,
|
||||
`torch.cuda.graph_pool_handle()`.
|
||||
piecewise_compile_index (int):
|
||||
Index of the current piecewise subgraph.
|
||||
total_piecewise_compiles (int):
|
||||
Total number of piecewise-compiled graphs.
|
||||
sym_shape_indices (list[int]):
|
||||
Indices of symbolic shape.
|
||||
compiled_graph_for_general_shape (Callable):
|
||||
Callable that executes the graph compiled for general shapes.
|
||||
vllm_backend (VllmBackend):
|
||||
Backend compiler that manages compilation and graph runtime
|
||||
for vLLM.
|
||||
|
||||
Keyword Args:
|
||||
kwargs: Additional keyword arguments reserved for future
|
||||
extensions or custom platforms.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
"""Executes the compiled graph for given input args.
|
||||
|
||||
If this is the first invocation, executes the general compiled graph
|
||||
and initiates the compilation process tracking. For subsequent calls,
|
||||
dynamically dispatches execution to either a compiled graph or a static
|
||||
graph based on the input shape.
|
||||
|
||||
Args:
|
||||
*args: Variable length input arguments to be passed into the
|
||||
graph. The symbolic shape is expected to be in position
|
||||
`sym_shape_indices[0]`.
|
||||
|
||||
Returns:
|
||||
Any: Output of the executed graph. This can be from the general
|
||||
compiled graph, a specialized compiled version for the given shape,
|
||||
or a replayed static graph.
|
||||
"""
|
||||
raise NotImplementedError
|
vllm/compilation/base_static_graph.py (new file, 54 lines)
@@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Protocol
|
||||
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
|
||||
|
||||
class AbstractStaticGraphWrapper(Protocol):
|
||||
"""
|
||||
StaticGraphWrapper interface that allows platforms to wrap a callable
|
||||
to be captured as a static graph.
|
||||
"""
|
||||
|
||||
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs):
|
||||
"""
|
||||
Initializes the StaticGraphWrapper class with graph capturing and
|
||||
execution-related configurations.
|
||||
|
||||
Args:
|
||||
runnable (Callable): The callable to be wrapped and captured.
|
||||
vllm_config (VllmConfig): Global configuration for vLLM.
|
||||
runtime_mode (CUDAGraphMode): The style of the static
|
||||
graph runtime. See CUDAGraphMode in vllm/config.py.
|
||||
Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
|
||||
are used as concrete runtime mode for cudagraph dispatching.
|
||||
graph_pool (Any):
|
||||
Graph memory pool handle, e.g.,
|
||||
`torch.cuda.graph_pool_handle()`.
|
||||
Keyword Args:
|
||||
kwargs: Additional keyword arguments for platform-specific
|
||||
configurations.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Executes the wrapped callable.
|
||||
|
||||
If the current runtime mode in the ForwardContext matches the runtime
|
||||
mode of this instance, it replays the CUDAGraph or captures it using
|
||||
the callable if it hasn't been captured yet. Otherwise, it calls the
|
||||
original callable directly.
|
||||
|
||||
Args:
|
||||
*args: Variable length input arguments to be passed into the
|
||||
callable.
|
||||
**kwargs: Keyword arguments to be passed into the callable.
|
||||
|
||||
Returns:
|
||||
Any: Output of the executed callable.
|
||||
"""
|
||||
raise NotImplementedError
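Not part of the diff: a minimal, hypothetical class satisfying the protocol above, to make the expected interface concrete; it never captures anything and simply forwards calls. The real CUDA implementation added in `vllm/compilation/cuda_graph.py` below captures or replays based on the runtime mode and batch descriptor found in the forward context.

```python
from typing import Any, Callable

from vllm.config import CUDAGraphMode, VllmConfig


class EagerGraphWrapper:
    """Hypothetical no-op AbstractStaticGraphWrapper: no static-graph
    capture, every call is forwarded to the wrapped runnable."""

    def __init__(self, runnable: Callable, vllm_config: VllmConfig,
                 runtime_mode: CUDAGraphMode, graph_pool: Any = None,
                 **kwargs):
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.runtime_mode = runtime_mode
        self.graph_pool = graph_pool

    def __call__(self, *args, **kwargs) -> Any:
        # Always run eagerly; a real wrapper would capture/replay here.
        return self.runnable(*args, **kwargs)
```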
|
vllm/compilation/cuda_graph.py (new file, 193 lines)
@@ -0,0 +1,193 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import weak_ref_tensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CUDAGraphEntry:
|
||||
batch_descriptor: BatchDescriptor
|
||||
cudagraph: Optional[torch.cuda.CUDAGraph] = None
|
||||
output: Optional[Any] = None
|
||||
|
||||
# for cudagraph debugging, track the input addresses
|
||||
# during capture, and check if they are the same during replay
|
||||
input_addresses: Optional[list[int]] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CUDAGraphOptions:
|
||||
debug_log_enable: bool = True
|
||||
gc_disable: bool = False
|
||||
weak_ref_output: bool = True
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
|
||||
"""Wraps a runnable to add CUDA graph capturing and replaying ability. And
|
||||
provide attribute access to the underlying `runnable` via `__getattr__`.
|
||||
|
||||
The workflow of this wrapper in the cudagraph dispatching is as follows:
|
||||
1. At initialization, a runtime mode is assigned to the wrapper (FULL or
|
||||
PIECEWISE).
|
||||
2. At runtime, the wrapper receives a runtime_mode and a
|
||||
batch_descriptor(key) from the forward context and blindly trust them
|
||||
for cudagraph dispatching.
|
||||
3. If runtime_mode is NONE or runtime_mode does not match the mode of the
|
||||
wrapper, just call the runnable directly.
|
||||
4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
|
||||
the wrapper will perform cudagraph capture(if key does not exist, create
|
||||
a new entry and cache it) or replay (if key exists in the cache).
|
||||
|
||||
Note: CUDAGraphWrapper does not store persistent buffers or copy any
|
||||
runtime inputs into that buffers for replay. We assume implementing them
|
||||
is done outside of the wrapper. That is because we do not make any
|
||||
assumption on the dynamic shape (batch size) of the runtime inputs, as a
|
||||
trade-off for staying orthogonal to compilation logic. Nevertheless,
|
||||
tracing and checking the input addresses to be consistent during replay is
|
||||
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
runnable: Callable,
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
graph_pool: Any = None,
|
||||
cudagraph_options: Optional[CUDAGraphOptions] = None):
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.graph_pool = graph_pool
|
||||
self.runtime_mode = runtime_mode
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
|
||||
self.first_run_finished = False
|
||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||
|
||||
# assert runtime_mode is not NONE (no cudagraph); otherwise, we don't
|
||||
# need to initialize a CUDAGraphWrapper.
|
||||
assert self.runtime_mode != CUDAGraphMode.NONE
|
||||
if self.graph_pool is None:
|
||||
self.graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
if cudagraph_options is None:
|
||||
cudagraph_options = CUDAGraphOptions()
|
||||
self.cudagraph_options = cudagraph_options
|
||||
# the entries for different batch descriptors that we need to capture
|
||||
# cudagraphs for.
|
||||
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry]\
|
||||
= {}
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
raise AttributeError(f"Attribute {key} not exists in the runnable of "
|
||||
f"cudagraph wrapper: {self.runnable}")
|
||||
|
||||
def unwrap(self) -> Callable:
|
||||
# in case we need to access the original runnable.
|
||||
return self.runnable
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
forward_context = get_forward_context()
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
if cudagraph_runtime_mode == CUDAGraphMode.NONE or \
|
||||
cudagraph_runtime_mode != self.runtime_mode:
|
||||
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
|
||||
# running without cudagraphs.
|
||||
# We do not trigger capture/replay if the runtime mode does not
# match. This enables properly dispatching to the correct
|
||||
# CUDAGraphWrapper when nesting multiple instances with different
|
||||
# runtime modes.
|
||||
return self.runnable(*args, **kwargs)
|
||||
|
||||
if batch_descriptor not in self.concrete_cudagraph_entries:
|
||||
# create a new entry for this batch descriptor
|
||||
self.concrete_cudagraph_entries[batch_descriptor] = \
|
||||
CUDAGraphEntry(batch_descriptor=batch_descriptor)
|
||||
|
||||
entry = self.concrete_cudagraph_entries[batch_descriptor]
|
||||
|
||||
if entry.cudagraph is None:
|
||||
if self.cudagraph_options.debug_log_enable:
|
||||
# Since we capture cudagraph for many different shapes and
|
||||
# capturing is fast, we don't need to log it for every
|
||||
# shape. E.g. we only log it for the first subgraph in
|
||||
# piecewise mode.
|
||||
logger.debug("Capturing a cudagraph on (%s,%s)",
|
||||
self.runtime_mode.name, entry.batch_descriptor)
|
||||
# validate that cudagraph capturing is legal at this point.
|
||||
validate_cudagraph_capturing_enabled()
|
||||
|
||||
input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
entry.input_addresses = input_addresses
|
||||
cudagraph = torch.cuda.CUDAGraph()
|
||||
|
||||
with ExitStack() as stack:
|
||||
if self.cudagraph_options.gc_disable:
|
||||
# during every model forward for piecewise cudagraph
|
||||
# mode, we will capture many pieces of cudagraphs
|
||||
# (roughly one per layer). running gc again and again
|
||||
# across layers will make the cudagraph capture very slow.
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(
|
||||
patch("torch.cuda.empty_cache", lambda: None))
|
||||
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's cudagraph pool
|
||||
output = self.runnable(*args, **kwargs)
|
||||
if self.cudagraph_options.weak_ref_output:
|
||||
# by converting it to weak ref,
|
||||
# the original `output` will immediately be released
|
||||
# to save memory. It is only safe to do this for
|
||||
# the last graph in piecewise cudagraph mode, because
|
||||
# the output of the last graph will not be used by
|
||||
# any other cuda graph.
|
||||
output = weak_ref_tensors(output)
|
||||
|
||||
# here we always use weak ref for the output
|
||||
# to save memory
|
||||
entry.output = weak_ref_tensors(output)
|
||||
entry.cudagraph = cudagraph
|
||||
|
||||
compilation_counter.num_cudagraph_captured += 1
|
||||
|
||||
# important: we need to return the output, rather than
|
||||
# the weak ref of the output, so that pytorch can correctly
|
||||
# manage the memory during cuda graph capture
|
||||
return output
|
||||
|
||||
if self.is_debugging_mode:
|
||||
# check if the input addresses are the same
|
||||
new_input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
assert new_input_addresses == entry.input_addresses, (
|
||||
f"Input addresses for cudagraphs are different "
|
||||
f"during replay. Expected {entry.input_addresses}, "
|
||||
f"got {new_input_addresses}")
|
||||
|
||||
entry.cudagraph.replay()
|
||||
return entry.output
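For reference, a usage sketch of the wrapper above, modeled on `TestCUDAGraphWrapper` earlier in this diff. It is a sketch only: it needs a CUDA device and assumes a default `VllmConfig()` is sufficient for the wrapper, which the real tests avoid by mocking the config.

```python
import torch
import torch.nn as nn

from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, set_forward_context

vllm_config = VllmConfig()  # assumed adequate for a standalone demo
model = nn.Linear(10, 10).cuda()
wrapper = CUDAGraphWrapper(model, vllm_config, runtime_mode=CUDAGraphMode.FULL)
x = torch.randn(4, 10, device="cuda")
desc = BatchDescriptor(num_tokens=4)

# Warmup: NONE runtime mode never matches the wrapper, so it runs eagerly.
with set_forward_context(None, vllm_config=vllm_config,
                         cudagraph_runtime_mode=CUDAGraphMode.NONE,
                         batch_descriptor=None):
    wrapper(x)

# First matching call captures a graph for `desc`; the second replays it.
for _ in range(2):
    with set_forward_context(None, vllm_config=vllm_config,
                             cudagraph_runtime_mode=CUDAGraphMode.FULL,
                             batch_descriptor=desc):
        out = wrapper(x)
```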
|
@@ -2,21 +2,15 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import end_monitoring_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import weak_ref_tensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -24,44 +18,29 @@ logger = init_logger(__name__)
|
||||
@dataclasses.dataclass
|
||||
class ConcreteSizeEntry:
|
||||
runtime_shape: int
|
||||
need_to_compile: bool # the size is in compile_sizes
|
||||
use_cudagraph: bool # the size is in cudagraph_capture_sizes
|
||||
|
||||
compiled: bool = False
|
||||
runnable: Callable = None # type: ignore
|
||||
num_finished_warmup: int = 0
|
||||
cudagraph: Optional[torch.cuda.CUDAGraph] = None
|
||||
output: Optional[Any] = None
|
||||
|
||||
# for cudagraph debugging, track the input addresses
|
||||
# during capture, and check if they are the same during replay
|
||||
input_addresses: Optional[list[int]] = None
|
||||
|
||||
|
||||
class CUDAPiecewiseBackend:
|
||||
class PiecewiseBackend:
|
||||
|
||||
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
|
||||
graph_pool: Any, piecewise_compile_index: int,
|
||||
total_piecewise_compiles: int, sym_shape_indices: list[int],
|
||||
piecewise_compile_index: int, total_piecewise_compiles: int,
|
||||
sym_shape_indices: list[int],
|
||||
compiled_graph_for_general_shape: Callable,
|
||||
vllm_backend: VllmBackend):
|
||||
"""
|
||||
The backend for piecewise compilation.
|
||||
It mainly handles the compilation and cudagraph capturing.
|
||||
It mainly handles the compilation of static shapes and
|
||||
dispatching based on runtime shape.
|
||||
|
||||
We will compile `self.graph` once for the general shape,
|
||||
and then compile for different shapes specified in
|
||||
`compilation_config.compile_sizes`.
|
||||
|
||||
Independently, we will capture cudagraph for different shapes.
|
||||
|
||||
If a shape needs both compilation and cudagraph, we will
|
||||
compile it first, and then capture cudagraph.
|
||||
"""
|
||||
self.graph = graph
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.graph_pool = graph_pool
|
||||
self.piecewise_compile_index = piecewise_compile_index
|
||||
self.total_piecewise_compiles = total_piecewise_compiles
|
||||
self.vllm_backend = vllm_backend
|
||||
@ -70,11 +49,10 @@ class CUDAPiecewiseBackend:
|
||||
self.is_last_graph = (
|
||||
piecewise_compile_index == total_piecewise_compiles - 1)
|
||||
|
||||
self.is_full_graph = total_piecewise_compiles == 1
|
||||
|
||||
self.compile_sizes: set[int] = set(
|
||||
self.compilation_config.compile_sizes)
|
||||
self.cudagraph_capture_sizes: set[int] = set(
|
||||
self.compilation_config.cudagraph_capture_sizes
|
||||
) if self.compilation_config.use_cudagraph else set()
|
||||
|
||||
self.first_run_finished = False
|
||||
|
||||
@ -84,18 +62,18 @@ class CUDAPiecewiseBackend:
|
||||
|
||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||
|
||||
# the entries for different shapes that we need to either
|
||||
# compile or capture cudagraph
|
||||
# the entries for different shapes that we need to compile
|
||||
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
|
||||
|
||||
# to_be_compiled_sizes tracks the remaining sizes to compile,
|
||||
# and updates during the compilation process, so we need to copy it
|
||||
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
|
||||
for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
|
||||
|
||||
# We only keep compilation management inside this class directly.
|
||||
for shape in self.compile_sizes:
|
||||
self.concrete_size_entries[shape] = ConcreteSizeEntry(
|
||||
runtime_shape=shape,
|
||||
need_to_compile=shape in self.compile_sizes,
|
||||
use_cudagraph=shape in self.cudagraph_capture_sizes,
|
||||
runnable=self.compiled_graph_for_general_shape,
|
||||
)
|
||||
|
||||
def check_for_ending_compilation(self):
|
||||
@ -112,16 +90,14 @@ class CUDAPiecewiseBackend:
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
|
||||
runtime_shape = args[self.sym_shape_indices[0]]
|
||||
|
||||
if runtime_shape not in self.concrete_size_entries:
|
||||
# we don't need to do anything for this shape
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
|
||||
entry = self.concrete_size_entries[runtime_shape]
|
||||
|
||||
if entry.runnable is None:
|
||||
entry.runnable = self.compiled_graph_for_general_shape
|
||||
|
||||
if entry.need_to_compile and not entry.compiled:
|
||||
if not entry.compiled:
|
||||
entry.compiled = True
|
||||
self.to_be_compiled_sizes.remove(runtime_shape)
|
||||
# args are real arguments
|
||||
@ -138,81 +114,4 @@ class CUDAPiecewiseBackend:
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
self.check_for_ending_compilation()
|
||||
|
||||
# Skip CUDA graphs if this entry doesn't use them OR
|
||||
# if we're supposed to skip them globally
|
||||
skip_cuda_graphs = get_forward_context().skip_cuda_graphs
|
||||
if not entry.use_cudagraph or skip_cuda_graphs:
|
||||
return entry.runnable(*args)
|
||||
|
||||
if entry.cudagraph is None:
|
||||
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
|
||||
entry.num_finished_warmup += 1
|
||||
if self.is_first_graph:
|
||||
logger.debug(
|
||||
"Warming up %s/%s for shape %s",
|
||||
entry.num_finished_warmup,
|
||||
self.compilation_config.cudagraph_num_of_warmups,
|
||||
runtime_shape)
|
||||
return entry.runnable(*args)
|
||||
|
||||
if self.is_first_graph:
|
||||
# Since we capture cudagraph for many different shapes and
|
||||
# capturing is fast, we don't need to log it for every shape.
|
||||
# We only log it in the debug mode.
|
||||
logger.debug("Capturing a cudagraph for shape %s",
|
||||
runtime_shape)
|
||||
|
||||
input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
entry.input_addresses = input_addresses
|
||||
cudagraph = torch.cuda.CUDAGraph()
|
||||
|
||||
with ExitStack() as stack:
|
||||
if not self.is_first_graph:
|
||||
# during every model forward, we will capture
|
||||
# many pieces of cudagraphs (roughly one per layer).
|
||||
# running gc again and again across layers will
|
||||
# make the cudagraph capture very slow.
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(
|
||||
patch("torch.cuda.empty_cache", lambda: None))
|
||||
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's cudagraph pool
|
||||
output = entry.runnable(*args)
|
||||
if self.is_last_graph:
|
||||
# by converting it to weak ref,
|
||||
# the original `output` will immediately be released
|
||||
# to save memory. It is only safe to do this for
|
||||
# the last graph, because the output of the last graph
|
||||
# will not be used by any other cuda graph.
|
||||
output = weak_ref_tensors(output)
|
||||
|
||||
# here we always use weak ref for the output
|
||||
# to save memory
|
||||
entry.output = weak_ref_tensors(output)
|
||||
entry.cudagraph = cudagraph
|
||||
|
||||
compilation_counter.num_cudagraph_captured += 1
|
||||
|
||||
# important: we need to return the output, rather than
|
||||
# the weak ref of the output, so that pytorch can correctly
|
||||
# manage the memory during cuda graph capture
|
||||
return output
|
||||
|
||||
if self.is_debugging_mode:
|
||||
# check if the input addresses are the same
|
||||
new_input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
assert new_input_addresses == entry.input_addresses, (
|
||||
"Input addresses for cudagraphs are different during replay."
|
||||
f" Expected {entry.input_addresses}, got {new_input_addresses}"
|
||||
)
|
||||
|
||||
entry.cudagraph.replay()
|
||||
return entry.output
|
||||
return entry.runnable(*args)
|
||||
|
@ -37,3 +37,21 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||
if context_manager is not None:
|
||||
context_manager.__exit__(None, None, None)
|
||||
context_manager = None
|
||||
|
||||
|
||||
cudagraph_capturing_enabled: bool = True
|
||||
|
||||
|
||||
def validate_cudagraph_capturing_enabled():
|
||||
# used to monitor whether a cudagraph capture is legal at runtime.
|
||||
# should be called before any cudagraph capturing.
|
||||
# if an illegal cudagraph capturing happens, raise an error.
|
||||
global cudagraph_capturing_enabled
|
||||
if not cudagraph_capturing_enabled:
|
||||
raise RuntimeError("CUDA graph capturing detected at an inappropriate "
|
||||
"time. This operation is currently disabled.")
|
||||
|
||||
|
||||
def set_cudagraph_capturing_enabled(enabled: bool):
|
||||
global cudagraph_capturing_enabled
|
||||
cudagraph_capturing_enabled = enabled
|
||||
|
@ -11,7 +11,8 @@ from typing import Callable, Optional
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CompilationLevel, get_current_vllm_config
|
||||
from vllm.config import (CompilationLevel, CUDAGraphMode,
|
||||
get_current_vllm_config)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -115,8 +116,8 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self.vllm_config.compilation_config.use_cudagraph and \
|
||||
"update" in new_code.co_names:
|
||||
if self.vllm_config.compilation_config.cudagraph_mode != \
|
||||
CUDAGraphMode.NONE and "update" in new_code.co_names:
|
||||
import depyf
|
||||
src = depyf.decompile(new_code)
|
||||
msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa
|
||||
|
@ -32,7 +32,7 @@ from vllm import version
|
||||
from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
|
||||
PrefixCachingHashAlgo)
|
||||
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
||||
PassConfig)
|
||||
CUDAGraphMode, PassConfig)
|
||||
from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig
|
||||
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
||||
from vllm.config.utils import ConfigType, config
|
||||
@ -3529,11 +3529,21 @@ class VllmConfig:
|
||||
else:
|
||||
self.compilation_config.level = \
|
||||
CompilationLevel.NO_COMPILATION
|
||||
|
||||
else:
|
||||
# NB: Passing both --enforce-eager and a compilation level
|
||||
# in V0 means the compilation level wins out.
|
||||
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
|
||||
# if cudagraph_mode is not explicitly set by users, set default value
|
||||
if self.compilation_config.cudagraph_mode is None:
|
||||
if envs.VLLM_USE_V1 and self.compilation_config.level \
|
||||
== CompilationLevel.PIECEWISE:
|
||||
self.compilation_config.cudagraph_mode = \
|
||||
CUDAGraphMode.PIECEWISE
|
||||
else:
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
# async tp is built on top of sequence parallelism
|
||||
# and requires it to be enabled.
|
||||
if self.compilation_config.pass_config.enable_async_tp:
|
||||
@ -3541,12 +3551,13 @@ class VllmConfig:
|
||||
True
|
||||
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||
self.compilation_config.custom_ops.append("+rms_norm")
|
||||
if envs.VLLM_USE_V1 and self.model_config is not None and \
|
||||
not self.model_config.enforce_eager:
|
||||
# By default, V1 uses piecewise CUDA graphs. If full_cuda_graph
|
||||
# is set to True, full CUDA graphs will be used.
|
||||
|
||||
# disable cudagraph when enforce eager execution
|
||||
if self.model_config is not None and self.model_config.enforce_eager:
|
||||
logger.info("Cudagraph is disabled under eager mode")
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
elif envs.VLLM_USE_V1:
|
||||
self.compilation_config.cudagraph_num_of_warmups = 1
|
||||
self.compilation_config.set_splitting_ops_for_v1()
|
||||
|
||||
self._set_cudagraph_sizes()
|
||||
|
||||
@ -3566,12 +3577,6 @@ class VllmConfig:
|
||||
"Disabling `torch.compile`.")
|
||||
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
|
||||
if self.compilation_config.full_cuda_graph and \
|
||||
not self.model_config.disable_cascade_attn:
|
||||
logger.info("full_cuda_graph is not supported with "
|
||||
"cascade attention. Disabling cascade attention.")
|
||||
self.model_config.disable_cascade_attn = True
|
||||
|
||||
disable_chunked_prefill_reasons: list[str] = []
|
||||
|
||||
if self.model_config and self.model_config.pooler_config:
|
||||
@ -3612,9 +3617,32 @@ class VllmConfig:
|
||||
"to True to enable.")
|
||||
current_platform.check_and_update_config(self)
|
||||
|
||||
# final check of cudagraph mode after platform-specific update
|
||||
if envs.VLLM_USE_V1:
|
||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
|
||||
and self.model_config is not None and \
|
||||
not self.model_config.disable_cascade_attn:
|
||||
logger.info("CUDAGraphMode.FULL is not supported with "
|
||||
"cascade attention currently. Disabling cascade"
|
||||
"attention.")
|
||||
self.model_config.disable_cascade_attn = True
|
||||
|
||||
if self.compilation_config.cudagraph_mode\
|
||||
.requires_piecewise_compilation():
|
||||
assert self.compilation_config.level == \
|
||||
CompilationLevel.PIECEWISE, \
|
||||
"Compilation level should be CompilationLevel.PIECEWISE "\
|
||||
"when cudagraph_mode piecewise cudagraphs is used, "\
|
||||
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
|
||||
|
||||
if not self.instance_id:
|
||||
self.instance_id = random_uuid()[:5]
|
||||
|
||||
# Do this after all the updates to compilation_config.level
|
||||
if envs.VLLM_USE_V1 and \
|
||||
self.compilation_config.level == CompilationLevel.PIECEWISE:
|
||||
self.compilation_config.set_splitting_ops_for_v1()
|
||||
|
||||
if (envs.VLLM_USE_V1
|
||||
and not self.scheduler_config.disable_hybrid_kv_cache_manager):
|
||||
# logger should only print warning message for hybrid models. As we
|
||||
|
@ -1,12 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
import hashlib
|
||||
from collections import Counter
|
||||
from dataclasses import asdict, field
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from pydantic import TypeAdapter, field_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
import vllm.envs as envs
|
||||
@ -31,6 +32,40 @@ class CompilationLevel:
|
||||
PIECEWISE = 3
|
||||
|
||||
|
||||
class CUDAGraphMode(enum.Enum):
|
||||
""" Constants for the cudagraph mode in CompilationConfig.
|
||||
The subset `NONE`, `PIECEWISE`, and `FULL` also serve as concrete
runtime modes for cudagraph runtime dispatching.
|
||||
"""
|
||||
NONE = 0
|
||||
PIECEWISE = 1
|
||||
FULL = 2
|
||||
FULL_DECODE_ONLY = (FULL, NONE)
|
||||
FULL_AND_PIECEWISE = (FULL, PIECEWISE)
|
||||
|
||||
def decode_mode(self) -> 'CUDAGraphMode':
|
||||
return CUDAGraphMode(self.value[0]) if \
|
||||
self.separate_routine() else self
|
||||
|
||||
def mixed_mode(self) -> 'CUDAGraphMode':
|
||||
return CUDAGraphMode(self.value[1]) if \
|
||||
self.separate_routine() else self
|
||||
|
||||
def requires_piecewise_compilation(self) -> bool:
|
||||
return (self.decode_mode() == CUDAGraphMode.PIECEWISE
|
||||
or self.mixed_mode() == CUDAGraphMode.PIECEWISE)
|
||||
|
||||
def max_cudagraph_mode(self) -> 'CUDAGraphMode':
|
||||
return CUDAGraphMode(max(
|
||||
self.value)) if self.separate_routine() else self
|
||||
|
||||
def has_full_cudagraphs(self) -> bool:
|
||||
return self.max_cudagraph_mode() == CUDAGraphMode.FULL
|
||||
|
||||
def separate_routine(self) -> bool:
|
||||
return isinstance(self.value, tuple)
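# Illustrative sketch (editor's note, not part of the diff): how the composite
# modes decompose into concrete runtime modes, given the member values above.
from vllm.config import CUDAGraphMode

mode = CUDAGraphMode.FULL_AND_PIECEWISE
assert mode.separate_routine()                       # value is a tuple
assert mode.decode_mode() == CUDAGraphMode.FULL      # uniform-decode batches
assert mode.mixed_mode() == CUDAGraphMode.PIECEWISE  # mixed prefill-decode batches
assert mode.has_full_cudagraphs()
assert not CUDAGraphMode.PIECEWISE.separate_routine()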
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class PassConfig:
|
||||
@ -91,6 +126,7 @@ class CompilationConfig:
|
||||
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
|
||||
- CudaGraph capture:
|
||||
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
|
||||
- [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
|
||||
- [`cudagraph_capture_sizes`]
|
||||
[vllm.config.CompilationConfig.cudagraph_capture_sizes]
|
||||
- [`cudagraph_num_of_warmups`]
|
||||
@ -157,7 +193,7 @@ class CompilationConfig:
|
||||
By default, all custom ops are enabled when running without Inductor and
|
||||
disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
|
||||
Inductor generates (fused) Triton kernels for disabled custom ops."""
|
||||
splitting_ops: list[str] = field(default_factory=list)
|
||||
splitting_ops: Optional[list[str]] = None
|
||||
"""A list of ops to split the full graph into subgraphs, used in piecewise
|
||||
compilation."""
|
||||
|
||||
@ -187,7 +223,43 @@ class CompilationConfig:
|
||||
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
|
||||
|
||||
# CudaGraph compilation
|
||||
use_cudagraph: bool = field(default_factory=lambda: envs.VLLM_USE_V1)
|
||||
cudagraph_mode: Optional[CUDAGraphMode] = None
|
||||
"""
|
||||
The mode of the cudagraph.
|
||||
- NONE, no cudagraph capture.
|
||||
- PIECEWISE. (v1 default)
|
||||
- FULL.
|
||||
- FULL_DECODE_ONLY.
|
||||
- FULL_AND_PIECEWISE.
|
||||
|
||||
PIECEWISE mode builds piecewise cudagraphs only, keeping the
cudagraph-incompatible ops (i.e. some attention ops) outside the cudagraph
for general flexibility.
This is the default mode.
|
||||
|
||||
FULL mode: Capture full cudagraph for all batches. Can be good for small
|
||||
models or workloads with small prompts; not supported by many backends.
|
||||
Generally for performance FULL_AND_PIECEWISE is better.
|
||||
|
||||
FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only.
|
||||
Mixed prefill-decode batches are run without cudagraphs. Can be good for
|
||||
decode instances in a P/D setup where prefill is not as important so we
|
||||
can save some memory.
|
||||
|
||||
FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
|
||||
piecewise cudagraph for prefill and mixed prefill-decode batches.
|
||||
This is usually the most performant mode for most models.
|
||||
|
||||
Currently, the cudagraph mode is only used for the v1 engine.
|
||||
Note that the cudagraph logic is generally orthogonal to the
|
||||
compilation logic. While piecewise cudagraphs require piecewise
|
||||
compilation (level=PIECEWISE and non-empty splitting_ops), full
|
||||
cudagraphs are supported with and without compilation.
|
||||
|
||||
Warning: This flag is new and subject to change; in addition,
more modes may be added.
|
||||
"""
|
||||
use_cudagraph: bool = True
|
||||
"""Whether to use cudagraph inside compilation.
|
||||
- False: cudagraph inside compilation is not used.
|
||||
- True: cudagraph inside compilation is used. It requires
|
||||
@ -197,8 +269,9 @@ class CompilationConfig:
|
||||
CompilationLevel.PIECEWISE (aka -O3).
|
||||
Note that this is orthogonal to the cudagraph capture logic
|
||||
outside of compilation.
|
||||
TODO: move outside cudagraph logic into compilation.
|
||||
torch.compile will handle cudagraph capture logic in the future."""
|
||||
Warning: This flag is deprecated and will be removed in the next major or
|
||||
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
|
||||
"""
|
||||
cudagraph_num_of_warmups: int = 0
|
||||
"""Number of warmup runs for cudagraph.
|
||||
It means the first several runs will be treated as warmup runs.
|
||||
@ -213,12 +286,17 @@ class CompilationConfig:
|
||||
cudagraph. If the caller can guarantee that the same input buffers
|
||||
are always used, it can set this to False. Otherwise, it should
|
||||
set this to True, and the compiler will copy the input to an
|
||||
internally managed buffer. Default is False."""
|
||||
full_cuda_graph: bool = False
|
||||
internally managed buffer. Default is False.
|
||||
Note that this flag is only effective when cudagraph_mode is PIECEWISE.
|
||||
"""
|
||||
full_cuda_graph: Optional[bool] = False
|
||||
"""whether to use a full cuda graph for the entire forward pass rather than
|
||||
splitting certain operations such as attention into subgraphs. Thus this
|
||||
flag cannot be used together with splitting_ops. This may provide
|
||||
performance benefits for smaller models."""
|
||||
performance benefits for smaller models.
|
||||
Warning: This flag is deprecated and will be removed in the next major or
|
||||
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
|
||||
"""
|
||||
|
||||
pass_config: PassConfig = field(default_factory=PassConfig)
|
||||
"""Custom inductor passes, see PassConfig for more details"""
|
||||
@ -253,6 +331,13 @@ class CompilationConfig:
|
||||
Map from layer name to layer objects that need to be accessed outside
|
||||
model code, e.g., Attention, FusedMOE when dp_size>1."""
|
||||
|
||||
# Attention ops; used for piecewise cudagraphs
|
||||
_attention_ops: ClassVar[list[str]] = [
|
||||
"vllm.unified_attention",
|
||||
"vllm.unified_attention_with_output",
|
||||
"vllm.mamba_mixer2",
|
||||
]
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
@ -297,13 +382,26 @@ class CompilationConfig:
|
||||
if pass_config_exclude:
|
||||
exclude["pass_config"] = pass_config_exclude
|
||||
|
||||
return TypeAdapter(CompilationConfig).dump_json(
|
||||
self,
|
||||
exclude=exclude, # type: ignore[arg-type]
|
||||
exclude_unset=True).decode()
|
||||
# The cast to string is necessary because Pydantic is mocked in docs
|
||||
# builds and sphinx-argparse doesn't know the return type of decode()
|
||||
return str(
|
||||
TypeAdapter(CompilationConfig).dump_json(
|
||||
self,
|
||||
exclude=exclude, # type: ignore[arg-type]
|
||||
exclude_unset=True).decode())
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
@field_validator("cudagraph_mode", mode="before")
|
||||
@classmethod
|
||||
def validate_cudagraph_mode_before(cls, value: Any) -> Any:
|
||||
"""
|
||||
Enable parsing the `cudagraph_mode` enum type from a string.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
return CUDAGraphMode[value.upper()]
|
||||
return value
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
count_none = self.custom_ops.count("none")
|
||||
count_all = self.custom_ops.count("all")
|
||||
@ -341,7 +439,26 @@ class CompilationConfig:
|
||||
if isinstance(self.pass_config, dict):
|
||||
self.pass_config = PassConfig(**self.pass_config)
|
||||
|
||||
def init_backend(self, vllm_config: VllmConfig) -> Union[str, Callable]:
|
||||
# migrate the deprecated flags
|
||||
if not self.use_cudagraph:
|
||||
logger.warning("use_cudagraph is deprecated, use "
|
||||
"cudagraph_mode=NONE instead.")
|
||||
if self.cudagraph_mode is not None:
|
||||
raise ValueError(
|
||||
"use_cudagraph and cudagraph_mode are mutually"
|
||||
" exclusive, prefer cudagraph_mode since "
|
||||
"use_cudagraph is deprecated.")
|
||||
self.cudagraph_mode = CUDAGraphMode.NONE
|
||||
if self.full_cuda_graph:
|
||||
logger.warning("full_cuda_graph is deprecated, use "
|
||||
"cudagraph_mode=FULL instead.")
|
||||
if self.cudagraph_mode is not None:
|
||||
raise ValueError("full_cuda_graph and cudagraph_mode are "
|
||||
"mutually exclusive, prefer cudagraph_mode "
|
||||
"since full_cuda_graph is deprecated.")
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
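# Illustrative sketch (editor's note, not part of the diff): how the deprecated
# flags migrate, assuming neither flag is combined with an explicit
# cudagraph_mode (which raises a ValueError above).
from vllm.config import CompilationConfig, CUDAGraphMode

assert CompilationConfig(use_cudagraph=False).cudagraph_mode == CUDAGraphMode.NONE
assert CompilationConfig(full_cuda_graph=True).cudagraph_mode == CUDAGraphMode.FULL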
|
||||
|
||||
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
|
||||
if self.level == CompilationLevel.NO_COMPILATION:
|
||||
raise ValueError("No compilation level is set.")
|
||||
|
||||
@ -414,15 +531,34 @@ class CompilationConfig:
|
||||
self.max_capture_size] = self.max_capture_size
|
||||
|
||||
def set_splitting_ops_for_v1(self):
|
||||
# NOTE: this function needs to be called
|
||||
if self.splitting_ops and self.full_cuda_graph:
|
||||
raise ValueError("full_cuda_graph cannot be used together with "
|
||||
"splitting_ops, as Full CUDA graph will override "
|
||||
f"the splitting_ops: {self.splitting_ops}")
|
||||
# NOTE: this function needs to be called only when level is
|
||||
# CompilationLevel.PIECEWISE
|
||||
assert self.level == CompilationLevel.PIECEWISE, (
|
||||
"set_splitting_ops_for_v1 should only be called when "
|
||||
"level is CompilationLevel.PIECEWISE")
|
||||
|
||||
if not self.splitting_ops:
|
||||
self.splitting_ops = [] if self.full_cuda_graph else [
|
||||
"vllm.unified_attention",
|
||||
"vllm.unified_attention_with_output",
|
||||
"vllm.mamba_mixer2",
|
||||
]
|
||||
if self.splitting_ops is None:
|
||||
# NOTE: When using full cudagraph, instead of setting an empty
# list and capturing the full cudagraph inside the flattened fx
|
||||
# graph, we keep the piecewise fx graph structure but capture the
|
||||
# full cudagraph outside the fx graph. This reduces some cpu
|
||||
# overhead when the runtime batch_size is not cudagraph captured.
|
||||
# see https://github.com/vllm-project/vllm/pull/20059 for details.
|
||||
self.splitting_ops = self._attention_ops
|
||||
elif len(self.splitting_ops) == 0:
|
||||
logger.warning_once("Using piecewise compilation with empty "
|
||||
"splitting_ops.")
|
||||
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
|
||||
logger.warning_once(
|
||||
"When compilation level is piecewise with empty "
|
||||
"splitting_ops, PIECEWISE cudagraph_mode will be "
|
||||
"treated as FULL cudagraph_mode. Please ensure you are "
|
||||
"using attention backends that support cudagraph or set "
|
||||
"cudagraph_mode to NONE explicitly if encountering "
|
||||
"any problems.")
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
self.splitting_ops = []
|
||||
|
||||
def splitting_ops_contain_attention(self) -> bool:
|
||||
return self.splitting_ops is not None and all(
|
||||
op in self.splitting_ops for op in self._attention_ops)
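# Illustrative sketch (editor's note, not part of the diff): with piecewise
# compilation and no explicit splitting_ops, the attention ops become the split
# points, so the piecewise fx graph structure is kept even when full cudagraphs
# are captured around it.
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode

cfg = CompilationConfig(level=CompilationLevel.PIECEWISE,
                        cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE)
cfg.set_splitting_ops_for_v1()  # normally invoked from VllmConfig.__post_init__
assert cfg.splitting_ops_contain_attention()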
|
||||
|
@ -5,13 +5,13 @@ import time
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -26,6 +26,27 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
|
||||
batchsize_forward_time: defaultdict = defaultdict(list)
|
||||
|
||||
|
||||
class BatchDescriptor(NamedTuple):
|
||||
"""
|
||||
Batch descriptor for cudagraph dispatching. We should keep the number of
items as small as possible to properly and uniquely describe the padded
|
||||
batch for cudagraph.
|
||||
"""
|
||||
num_tokens: int
|
||||
uniform_decode: bool = False
|
||||
"""
|
||||
False can also be used for a uniform decode batch to dispatch to the
|
||||
cudagraph supporting non-uniform batches.
|
||||
"""
|
||||
|
||||
@property
|
||||
def non_uniform(self) -> "BatchDescriptor":
|
||||
"""
|
||||
Return a non-uniform version of the current batch descriptor.
|
||||
"""
|
||||
return BatchDescriptor(self.num_tokens, uniform_decode=False)
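# Illustrative sketch (editor's note, not part of the diff): a uniform-decode
# descriptor relaxes to its non-uniform counterpart so it can also match the
# more general cudagraphs.
from vllm.forward_context import BatchDescriptor

desc = BatchDescriptor(num_tokens=64, uniform_decode=True)
assert desc.non_uniform == BatchDescriptor(num_tokens=64)  # uniform_decode defaults to False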
|
||||
|
||||
|
||||
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
|
||||
max_num_tokens: int,
|
||||
chunk_idx: int) -> list[int]:
|
||||
@ -152,7 +173,15 @@ class ForwardContext:
|
||||
virtual_engine: int # set dynamically for each forward pass
|
||||
# set dynamically for each forward pass
|
||||
dp_metadata: Optional[DPMetadata] = None
|
||||
skip_cuda_graphs: bool = False
|
||||
# determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
|
||||
# by default NONE, no cudagraph is used.
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
|
||||
batch_descriptor: Optional[BatchDescriptor] = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.cudagraph_runtime_mode in [
|
||||
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
|
||||
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
|
||||
|
||||
|
||||
_forward_context: Optional[ForwardContext] = None
|
||||
@ -168,13 +197,13 @@ def get_forward_context() -> ForwardContext:
|
||||
|
||||
@contextmanager
|
||||
def set_forward_context(
|
||||
attn_metadata: Any,
|
||||
vllm_config: VllmConfig,
|
||||
virtual_engine: int = 0,
|
||||
num_tokens: Optional[int] = None,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||
skip_cuda_graphs: bool = False,
|
||||
):
|
||||
attn_metadata: Any,
|
||||
vllm_config: VllmConfig,
|
||||
virtual_engine: int = 0,
|
||||
num_tokens: Optional[int] = None,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: Optional[BatchDescriptor] = None):
|
||||
"""A context manager that stores the current forward context,
|
||||
can be attention metadata, etc.
|
||||
Here we can inject common logic for every model forward pass.
|
||||
@ -198,7 +227,8 @@ def set_forward_context(
|
||||
virtual_engine=virtual_engine,
|
||||
attn_metadata=attn_metadata,
|
||||
dp_metadata=dp_metadata,
|
||||
skip_cuda_graphs=skip_cuda_graphs,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -177,17 +177,20 @@ class CudaPlatformBase(Platform):
|
||||
logger.info("Forcing kv cache block size to 128 for "
|
||||
"CUTLASS_MLA backend.")
|
||||
|
||||
# lazy import to avoid circular import
|
||||
from vllm.config import CUDAGraphMode
|
||||
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
|
||||
and parallel_config.data_parallel_size > 1
|
||||
and compilation_config.use_cudagraph):
|
||||
and compilation_config.cudagraph_mode != CUDAGraphMode.NONE):
|
||||
logger.info(
|
||||
"Data Parallel: Forcing enforce eager to be True since DP "
|
||||
"Data Parallel: disabling cudagraphs since DP "
|
||||
"with DeepEP high-throughput kernels are not CUDA Graph "
|
||||
"compatible. The DeepEP low-latency kernels are CUDA Graph "
|
||||
"compatible. Set the all_to_all backend to deepep_low_latency "
|
||||
"to use those kernels instead.")
|
||||
compilation_config.use_cudagraph = False
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
if model_config is not None:
|
||||
model_config.enforce_eager = True
|
||||
|
||||
@ -454,8 +457,8 @@ class CudaPlatformBase(Platform):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_piecewise_backend_cls(cls) -> str:
|
||||
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa
|
||||
def get_static_graph_wrapper_cls(cls) -> str:
|
||||
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
|
||||
|
||||
@classmethod
|
||||
def stateless_init_device_torch_dist_pg(
|
||||
|
@ -7,7 +7,7 @@ import random
|
||||
import sys
|
||||
from datetime import timedelta
|
||||
from platform import uname
|
||||
from typing import TYPE_CHECKING, NamedTuple, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -137,6 +137,8 @@ class Platform:
|
||||
|
||||
additional_env_vars: list[str] = []
|
||||
|
||||
_global_graph_pool: Optional[Any] = None
|
||||
|
||||
@property
|
||||
def supported_dtypes(self) -> list[torch.dtype]:
|
||||
"""Returns the supported dtypes for the current platform."""
|
||||
@ -522,6 +524,15 @@ class Platform:
|
||||
" attribute.", self.device_type, key)
|
||||
return None
|
||||
|
||||
def get_global_graph_pool(self) -> Any:
|
||||
"""
|
||||
Return the global graph pool for this platform.
|
||||
"""
|
||||
cls = self.__class__
|
||||
if cls._global_graph_pool is None:
|
||||
cls._global_graph_pool = self.graph_pool_handle()
|
||||
return cls._global_graph_pool
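# Illustrative sketch (editor's note, not part of the diff): the pool is
# created lazily once per platform class and shared by every cudagraph capture
# on that platform; this assumes a platform where graph_pool_handle() is
# actually implemented (e.g. CUDA).
from vllm.platforms import current_platform

pool = current_platform.get_global_graph_pool()
assert pool is current_platform.get_global_graph_pool()  # cached singleton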
|
||||
|
||||
@classmethod
|
||||
def get_cu_count(cls, device_id: int = 0) -> int:
|
||||
"""
|
||||
@ -530,11 +541,11 @@ class Platform:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_piecewise_backend_cls(cls) -> str:
|
||||
def get_static_graph_wrapper_cls(cls) -> str:
|
||||
"""
|
||||
Get piecewise backend class for piecewise graph.
|
||||
Get static graph wrapper class for static graph.
|
||||
"""
|
||||
return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa
|
||||
return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper"
|
||||
|
||||
@classmethod
|
||||
def stateless_init_device_torch_dist_pg(
|
||||
|
@ -421,8 +421,8 @@ class RocmPlatform(Platform):
|
||||
return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName
|
||||
|
||||
@classmethod
|
||||
def get_piecewise_backend_cls(cls) -> str:
|
||||
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa
|
||||
def get_static_graph_wrapper_cls(cls) -> str:
|
||||
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
|
||||
|
||||
@classmethod
|
||||
def stateless_init_device_torch_dist_pg(
|
||||
|
@ -99,7 +99,7 @@ class TpuPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
from vllm.config import CompilationLevel
|
||||
from vllm.config import CompilationLevel, CUDAGraphMode
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
# For v0, the default block size is 16.
|
||||
@ -109,9 +109,17 @@ class TpuPlatform(Platform):
|
||||
|
||||
# TPU only supports DYNAMO_ONCE compilation level
|
||||
if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
|
||||
logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
|
||||
logger.info("[TPU] Forcing DYNAMO_ONCE compilation level, and "
|
||||
"disabling cudagraph.")
|
||||
compilation_config.level = CompilationLevel.DYNAMO_ONCE
|
||||
|
||||
if compilation_config.cudagraph_mode is None or \
|
||||
compilation_config.cudagraph_mode.max_cudagraph_mode() \
|
||||
!= CUDAGraphMode.NONE:
|
||||
logger.info("[TPU] CUDA graph is not supported on TPU, "
|
||||
"disabling cudagraphs.")
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
if compilation_config.backend == "":
|
||||
compilation_config.backend = "openxla"
|
||||
|
||||
|
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CUDAGraphMode
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||
|
||||
@ -100,16 +101,17 @@ class XPUPlatform(Platform):
|
||||
# Instances created using VllmConfig() typically have model_config as
|
||||
# None by default. The modification involves adding a check to prevent
|
||||
# potential null exceptions. Check and update the model config.
|
||||
if model_config is not None:
|
||||
if model_config.dtype == torch.bfloat16:
|
||||
bf16_supported = cls.device_support_bf16()
|
||||
if not bf16_supported:
|
||||
model_config.dtype = torch.float16
|
||||
if not model_config.enforce_eager:
|
||||
logger.warning(
|
||||
"CUDA graph is not supported on XPU, fallback to the eager "
|
||||
"mode.")
|
||||
model_config.enforce_eager = True
|
||||
if model_config is not None and model_config.dtype == torch.bfloat16 \
|
||||
and not cls.device_support_bf16():
|
||||
model_config.dtype = torch.float16
|
||||
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if compilation_config.cudagraph_mode is None or \
|
||||
compilation_config.cudagraph_mode.max_cudagraph_mode() \
|
||||
!= CUDAGraphMode.NONE:
|
||||
logger.info("[XPU] CUDA graph is not supported on XPU, "
|
||||
"disabling cudagraphs.")
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
# check and update parallel config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with FlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -154,9 +154,26 @@ def _get_sliding_window_configs(
|
||||
|
||||
class FlashAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlashAttentionMetadata]):
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
|
||||
else AttentionCGSupport.ALWAYS
|
||||
# FA3:
|
||||
# Supports full cudagraphs for all cases.
|
||||
#
|
||||
# FA2:
|
||||
# For FA2, a graph is captured with max_query_len=1 (which is what we
# capture by default for num_tokens <= max_num_seqs when there is no
# spec-decode), so these graphs will not work for mixed prefill-decode
# batches (unlike FA3). This is due to FA2's special max_query_len=1
# packed-GQA handling.
# In summary, when running with spec decodes the graphs work for both
# mixed prefill-decode and uniform-decode batches, but without spec decodes
# they do not work for mixed prefill-decode; roughly the inverse
# of UNIFORM_SINGLE_TOKEN_DECODE.
# There's probably a better way to describe this using `AttentionCGSupport`,
# but for now just set it to `UNIFORM_BATCH` to get us to drop down
# to FULL_AND_PIECEWISE.
|
||||
# TODO(luka, lucas): audit FA2 as part of:
|
||||
# https://github.com/vllm-project/vllm/issues/22945
|
||||
cudagraph_support = AttentionCGSupport.ALWAYS \
|
||||
if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH
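# Illustrative sketch (editor's note, not part of the diff): UNIFORM_BATCH
# means a captured graph is only reusable when every request in the batch has
# the same query length (e.g. 1 + num_speculative_tokens). A hypothetical
# check (not a vLLM API) could look like:
def _is_uniform_batch(query_lens: list[int]) -> bool:
    return len(set(query_lens)) == 1

assert _is_uniform_batch([3, 3, 3])      # spec-decode style decode batch
assert not _is_uniform_batch([1, 7, 2])  # mixed prefill-decode batch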
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
@ -177,17 +194,13 @@ class FlashAttentionMetadataBuilder(
|
||||
|
||||
self.max_num_splits = 0 # No upper bound on the number of splits.
|
||||
self.aot_schedule = (get_flash_attn_version() == 3)
|
||||
self.use_full_cuda_graph = self.compilation_config.full_cuda_graph
|
||||
if self.use_full_cuda_graph:
|
||||
if not self.aot_schedule:
|
||||
raise ValueError(
|
||||
"AoT scheduling is required for full cuda graph.")
|
||||
capture_sizes = self.compilation_config.cudagraph_capture_sizes
|
||||
if not capture_sizes:
|
||||
raise ValueError(
|
||||
"cudagraph_capture_sizes should not be None when "
|
||||
"full_cuda_graph is True.")
|
||||
self.max_cudagraph_size = max(capture_sizes)
|
||||
|
||||
self.use_full_cuda_graph = \
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
|
||||
if self.use_full_cuda_graph and self.aot_schedule:
|
||||
self.max_cudagraph_size = self.compilation_config.max_capture_size
|
||||
|
||||
if self.max_cudagraph_size > 992:
|
||||
# This condition derives from FA3's internal heuristic.
|
||||
# TODO(woosuk): Support larger cudagraph sizes.
|
||||
@ -310,9 +323,9 @@ class FlashAttentionMetadataBuilder(
|
||||
seqlens=seq_lens,
|
||||
max_seq_len=max_seq_len,
|
||||
causal=causal)
|
||||
|
||||
if self.use_full_cuda_graph:
|
||||
assert scheduler_metadata is not None
|
||||
# For FA3 + full cudagraph
|
||||
max_num_splits = 0
|
||||
if self.use_full_cuda_graph and scheduler_metadata is not None:
|
||||
n = scheduler_metadata.shape[0]
|
||||
self.scheduler_metadata[:n] = scheduler_metadata
|
||||
# NOTE(woosuk): We should zero out the rest of the scheduler
|
||||
@ -322,14 +335,12 @@ class FlashAttentionMetadataBuilder(
|
||||
self.scheduler_metadata[n:] = 0
|
||||
scheduler_metadata = self.scheduler_metadata[:n]
|
||||
|
||||
max_num_splits = 0
|
||||
if (self.use_full_cuda_graph
|
||||
and num_actual_tokens <= self.max_cudagraph_size):
|
||||
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
|
||||
# usage, because the intermediate buffers of size [num_splits,
|
||||
# num_heads, num_tokens, head_size] are allocated. Therefore,
|
||||
# we only set num_splits when using cuda graphs.
|
||||
max_num_splits = self.max_num_splits
|
||||
if num_actual_tokens <= self.max_cudagraph_size:
|
||||
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
|
||||
# usage, because the intermediate buffers of size [num_splits,
|
||||
# num_heads, num_tokens, head_size] are allocated. Therefore,
|
||||
# we only set num_splits when using cuda graphs.
|
||||
max_num_splits = self.max_num_splits
|
||||
|
||||
attn_metadata = FlashAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@ -350,11 +361,6 @@ class FlashAttentionMetadataBuilder(
|
||||
causal=causal)
|
||||
return attn_metadata
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
# Full CUDA Graph always supported (FA2 support checked separately)
|
||||
return True
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return use_cascade_attention(*args, **kwargs)
|
||||
|
||||
|
@ -17,7 +17,7 @@ from flashinfer.prefill import trtllm_batch_context_with_kv_cache
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionType)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv, is_pin_memory_available
|
||||
from vllm.utils.flashinfer import use_trtllm_attention
|
||||
@ -183,8 +183,8 @@ class FlashInferMetadata:
|
||||
|
||||
|
||||
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.PURE_DECODE_ONLY
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
|
||||
reorder_batch_threshold: ClassVar[int] = 1
|
||||
|
||||
@ -203,7 +203,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self.kv_cache_spec.block_size)
|
||||
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_pages = max_num_reqs * max_num_pages_per_req
|
||||
self.enable_cuda_graph = self.compilation_config.full_cuda_graph
|
||||
self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
|
||||
decode_mode() == CUDAGraphMode.FULL
|
||||
if self.enable_cuda_graph:
|
||||
# For full cudagraph capture, one `decode_wrapper` for each batch
|
||||
# size is needed for FlashInfer.
|
||||
@ -586,10 +587,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
|
||||
return self.build(0, m)
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
return common_attn_metadata.max_query_len == 1
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
|
||||
# TODO: The cascade wrapper currently does not support setting
|
||||
|
@ -89,8 +89,8 @@ class Mamba2AttentionMetadata:
|
||||
|
||||
class Mamba2AttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.PURE_DECODE_ONLY
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
|
||||
reorder_batch_threshold: ClassVar[int] = 1
|
||||
|
||||
@ -203,7 +203,3 @@ class Mamba2AttentionMetadataBuilder(
|
||||
m.max_query_len = 1 # decode-only
|
||||
|
||||
return self.build(0, m)
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
return common_attn_metadata.max_query_len == 1
|
||||
|
@ -575,7 +575,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
"MLA only supports decode-only full CUDAGraph capture. " \
|
||||
"Make sure all cudagraph capture sizes <= max_num_seq."
|
||||
|
||||
m.max_query_len = 1 # decode-only
|
||||
assert m.max_query_len == 1 # decode-only
|
||||
|
||||
return self.build(0, m)
|
||||
|
||||
@ -728,10 +728,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
return common_attn_metadata.max_query_len == 1
|
||||
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
"""
|
||||
|
@ -22,7 +22,7 @@ logger = init_logger(__name__)
|
||||
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||
# enable full CUDA Graph support for decode-only capture
|
||||
attn_cudagraph_support: ClassVar[
|
||||
AttentionCGSupport] = AttentionCGSupport.PURE_DECODE_ONLY
|
||||
AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
|
||||
|
||||
class CutlassMLABackend(MLACommonBackend):
|
||||
|
@ -55,8 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.PURE_DECODE_ONLY
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
@ -73,7 +73,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
device_properties = torch.cuda.get_device_properties(self.device)
|
||||
num_sms = device_properties.multi_processor_count
|
||||
|
||||
if self.compilation_config.full_cuda_graph:
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.cg_buf_tile_scheduler_metadata = torch.zeros(
|
||||
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
|
||||
# TileSchedulerMetaDataSize = 8
|
||||
@ -95,7 +95,10 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
if self.compilation_config.full_cuda_graph:
|
||||
# TODO: we can disambiguate between decode and mixed prefill-decode here
|
||||
# so we can only use the persistent buffer if a cudagraph is actually
|
||||
# being used.
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
assert self.cg_buf_tile_scheduler_metadata is not None
|
||||
assert self.cg_buf_num_splits is not None
|
||||
|
||||
|
@ -65,8 +65,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
||||
|
||||
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.PURE_DECODE_ONLY
|
||||
# TODO(luka, lucas): audit this as part of:
|
||||
# https://github.com/vllm-project/vllm/issues/22945
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
@ -82,7 +84,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
max_num_pages = max_num_reqs * max_num_pages_per_req
|
||||
|
||||
# Preparing persistent buffers
|
||||
if vllm_config.compilation_config.full_cuda_graph:
|
||||
# TODO: we can disambiguate between decode and mixed prefill-decode here
|
||||
# so we can only use the persistent buffer if a cudagraph is actually
|
||||
# being used.
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
@ -120,7 +125,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
|
||||
])
|
||||
|
||||
if self.compilation_config.full_cuda_graph:
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
|
||||
num_actual_pages = paged_kv_indices.size(0)
|
||||
|
||||
|
@ -311,11 +311,6 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
# Full CUDA Graph always supported (FA2 support checked separately)
|
||||
return True
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
|
@ -58,8 +58,7 @@ class TritonAttentionMetadata:
|
||||
|
||||
class TritonAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[TritonAttentionMetadata]):
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.ALWAYS
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
@ -132,11 +131,6 @@ class TritonAttentionMetadataBuilder(
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
# Full CUDA Graph always supported
|
||||
return True
|
||||
|
||||
|
||||
class TritonAttentionBackend(AttentionBackend):
|
||||
|
||||
|
@ -158,18 +158,21 @@ class AttentionCGSupport(enum.Enum):
|
||||
Here we do not consider the cascade attention, as currently
|
||||
it is never cudagraph supported."""
|
||||
|
||||
ALWAYS = 3
|
||||
"""Cudagraph always supported; supports mixed-prefill-decode"""
|
||||
UNIFORM_BATCH = 2
|
||||
"""Cudagraph supported for batches the only contain query lengths that are
|
||||
the same, this can be used for spec-decode
|
||||
i.e. "decodes" are 1 + num_speculative_tokens"""
|
||||
UNIFORM_SINGLE_TOKEN_DECODE = 1
|
||||
"""Cudagraph supported for batches the only contain query_len==1 decodes"""
|
||||
NEVER = 0
|
||||
"""NO cudagraph support"""
|
||||
PURE_DECODE_ONLY = 1
|
||||
"""Cudagraph supported for pure decode, need to run without
|
||||
cudagraph for mixed prefill-decode batches"""
|
||||
ALWAYS = 2
|
||||
"""Cudagraph always supported"""
|
||||
|
||||
|
||||
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
# Does this backend/builder support CUDA Graphs for attention.
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
# Does this backend/builder support CUDA Graphs for attention (default: no).
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.NEVER
|
||||
# Does this backend/builder reorder the batch?
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
@ -199,13 +202,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
"""
|
||||
Can this batch (with given metadata) use CUDA Graphs for attention.
|
||||
"""
|
||||
return False
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> M:
|
||||
"""
|
||||
|
120
vllm/v1/cudagraph_dispatcher.py
Normal file
120
vllm/v1/cudagraph_dispatcher.py
Normal file
@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
|
||||
from vllm.forward_context import BatchDescriptor
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CudagraphDispatcher:
|
||||
"""
|
||||
Runtime cudagraph dispatcher to dispatch keys for multiple sets of cudagraphs.
|
||||
|
||||
The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one
|
||||
for FULL cudagraph runtime mode. The keys are initialized depending on
|
||||
attention support and what cudagraph mode is set in CompilationConfig. The
|
||||
keys stored in the dispatcher are the only source of truth for valid
|
||||
cudagraphs that can be dispatched at runtime.
|
||||
|
||||
At runtime, the dispatch method generates the runtime cudagraph mode (FULL,
|
||||
PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor)
|
||||
based on the input key. After dispatching (communicated via the forward
context), the cudagraph wrappers will trust the dispatch key to do either
capturing or replaying (if the mode matches), or pass through to the
underlying runnable without cudagraph (if the mode does not match or is NONE).
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
|
||||
# Dict to store valid cudagraph dispatching keys.
|
||||
self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
|
||||
CUDAGraphMode.PIECEWISE: set(),
|
||||
CUDAGraphMode.FULL: set(),
|
||||
}
|
||||
|
||||
assert not self.cudagraph_mode.requires_piecewise_compilation() or \
|
||||
(self.compilation_config.level == CompilationLevel.PIECEWISE and
|
||||
self.compilation_config.splitting_ops_contain_attention()), \
|
||||
"Compilation level should be CompilationLevel.PIECEWISE when "\
|
||||
"cudagraph_mode piecewise cudagraphs is used, "\
|
||||
f"cudagraph_mode={self.cudagraph_mode}, "\
|
||||
f"compilation_level={self.compilation_config.level}, "\
|
||||
f"splitting_ops={self.compilation_config.splitting_ops}"
|
||||
|
||||
self.keys_initialized = False
|
||||
|
||||
def add_cudagraph_key(self, runtime_mode: CUDAGraphMode,
|
||||
batch_descriptor: BatchDescriptor):
|
||||
assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
|
||||
f"Invalid cudagraph runtime mode: {runtime_mode}"
|
||||
self.cudagraph_keys[runtime_mode].add(batch_descriptor)
|
||||
|
||||
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
|
||||
uniform_decode_query_len: int):
|
||||
# This should be called only after attention backend is initialized.
|
||||
|
||||
# Note: we create all possible valid keys for cudagraph but do not
# guarantee that all keys will be used. For example, we create keys for
# piecewise cudagraphs whenever piecewise compilation is enabled, which is
# always valid, but for attention backends that support a unified routine
# we may not trigger capturing/replaying the piecewise cudagraphs, depending
# on CompilationConfig.cudagraph_mode. In addition, if we allow lazy
# capturing in a future PR, some keys may never be triggered.
|
||||
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
|
||||
for bs in self.compilation_config.cudagraph_capture_sizes:
|
||||
self.add_cudagraph_key(
|
||||
cudagraph_mode.mixed_mode(),
|
||||
BatchDescriptor(num_tokens=bs, uniform_decode=False))
|
||||
|
||||
# if decode cudagraph mode is FULL, and we don't already have mixed
|
||||
# mode full cudagraphs then add them here.
|
||||
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL \
|
||||
and cudagraph_mode.separate_routine():
|
||||
max_num_tokens = uniform_decode_query_len * \
|
||||
self.vllm_config.scheduler_config.max_num_seqs
|
||||
cudagraph_capture_sizes_for_decode = [
|
||||
x for x in self.compilation_config.cudagraph_capture_sizes
|
||||
if x <= max_num_tokens and x >= uniform_decode_query_len
|
||||
]
|
||||
for bs in cudagraph_capture_sizes_for_decode:
|
||||
self.add_cudagraph_key(
|
||||
CUDAGraphMode.FULL,
|
||||
BatchDescriptor(num_tokens=bs, uniform_decode=True))
|
||||
self.keys_initialized = True
|
||||
|
||||
def dispatch(
|
||||
self, batch_descriptor: BatchDescriptor
|
||||
) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
|
||||
"""
|
||||
Given a batch descriptor, dispatch to a cudagraph mode.
|
||||
A new batch descriptor is returned as we might dispatch a uniform batch
|
||||
to a graph that supports a more general batch (uniform to non-uniform).
|
||||
"""
|
||||
# if not initialized, just skip dispatching.
|
||||
if not self.keys_initialized:
|
||||
logger.warning_once("cudagraph dispatching keys are not "
|
||||
"initialized. No cudagraph will be used.")
|
||||
return CUDAGraphMode.NONE, None
|
||||
|
||||
# check if key exists for full cudagraph
|
||||
if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
|
||||
return CUDAGraphMode.FULL, batch_descriptor
|
||||
|
||||
# otherwise, check if non-uniform key exists
|
||||
non_uniform_key = batch_descriptor.non_uniform
|
||||
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
|
||||
return CUDAGraphMode.FULL, non_uniform_key
|
||||
|
||||
# also check if non-uniform key exists for more "general"
|
||||
# piecewise cudagraph
|
||||
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
|
||||
return CUDAGraphMode.PIECEWISE, non_uniform_key
|
||||
|
||||
# finally, just return no cudagraphs
|
||||
return CUDAGraphMode.NONE, None
|
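A minimal, self-contained sketch (not from the patch) of the dispatch precedence implemented above: an exact uniform-decode FULL key wins, then the more general non-uniform FULL key, then a PIECEWISE key, and finally no cudagraph. `Mode` and `Desc` are simplified stand-ins for CUDAGraphMode and BatchDescriptor.

from enum import Enum
from typing import NamedTuple, Optional


class Mode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


class Desc(NamedTuple):
    num_tokens: int
    uniform_decode: bool = False

    @property
    def non_uniform(self) -> "Desc":
        return Desc(self.num_tokens, uniform_decode=False)


def dispatch(keys: dict[Mode, set[Desc]],
             desc: Desc) -> tuple[Mode, Optional[Desc]]:
    # Prefer a full graph captured for this exact (uniform) shape.
    if desc in keys[Mode.FULL]:
        return Mode.FULL, desc
    # Fall back to the more general non-uniform full graph ...
    if desc.non_uniform in keys[Mode.FULL]:
        return Mode.FULL, desc.non_uniform
    # ... then to a piecewise graph ...
    if desc.non_uniform in keys[Mode.PIECEWISE]:
        return Mode.PIECEWISE, desc.non_uniform
    # ... and finally to eager execution.
    return Mode.NONE, None


keys = {Mode.FULL: {Desc(8, True)}, Mode.PIECEWISE: {Desc(8), Desc(16)}}
assert dispatch(keys, Desc(8, True)) == (Mode.FULL, Desc(8, True))
assert dispatch(keys, Desc(16, True)) == (Mode.PIECEWISE, Desc(16))
assert dispatch(keys, Desc(32)) == (Mode.NONE, None)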
@ -21,7 +21,9 @@ from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationLevel, VllmConfig,
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                         get_layers_from_vllm_config, update_config)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@ -29,7 +31,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
from vllm.distributed.parallel_state import (
    get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
    prepare_communication_buffer_for_model)
from vllm.forward_context import DPMetadata, set_forward_context
from vllm.forward_context import (BatchDescriptor, DPMetadata,
                                  set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@ -48,13 +51,15 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                        is_pin_memory_available, round_up, supports_dynamo)
                        GiB_bytes, LazyLoader, cdiv, check_use_alibi,
                        get_dtype_size, is_pin_memory_available, round_up,
                        supports_dynamo)
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
    make_kv_sharing_fast_prefill_attention_metadata,
    reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (AttentionSpec,
                                        ChunkedLocalAttentionSpec,
                                        FullAttentionSpec, KVCacheConfig,
@ -218,11 +223,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            is_spec_decode=bool(self.vllm_config.speculative_config),
        )

        self.use_cuda_graph = (
            self.vllm_config.compilation_config.level
            == CompilationLevel.PIECEWISE
            and self.vllm_config.compilation_config.use_cudagraph
            and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.
        # self.cudagraph_batch_sizes sorts in ascending order.
@ -230,8 +230,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.cudagraph_batch_sizes = list(
            reversed(self.compilation_config.cudagraph_capture_sizes))

        self.full_cuda_graph = self.compilation_config.full_cuda_graph

        # Cache the device properties.
        self._init_device_properties()

@ -326,6 +324,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=self.device)

        self.uniform_decode_query_len = 1 if not self.speculative_config else \
            1 + self.speculative_config.num_speculative_tokens

        # Cudagraph dispatcher for runtime cudagraph dispatching.
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        self.mm_budget = (MultiModalBudget(
            self.model_config,
            self.scheduler_config,
@ -471,7 +475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        assert (task := pooling_params.task) is not None, (
            "You did not set `task` in the API")

        model = cast(VllmModelForPooling, self.model)
        model = cast(VllmModelForPooling, self.get_model())
        to_update = model.pooler.get_pooling_updates(task)
        to_update.apply(pooling_params)

@ -679,13 +683,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[dict[str,
                    Any], bool, torch.Tensor, Optional[SpecDecodeMetadata],
               np.ndarray, Optional[CommonAttentionMetadata]]:
    ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata],
               np.ndarray, Optional[CommonAttentionMetadata], int]:
        """
        :return: tuple[
            attn_metadata: layer-to-attention_metadata mapping,
            attention_cuda_graphs: whether attention can run in cudagraph
            logits_indices, spec_decode_metadata
        ]
        """
@ -820,7 +822,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            # valid, we fill the padded indices with the last index.
            self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
                logits_indices[-1].item())
            if (self.use_cuda_graph
            if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
                    and num_logits <= self.cudagraph_batch_sizes[-1]):
                # Use piecewise CUDA graphs.
                # Add padding to the batch size.
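The padding step mentioned above, sketched as a hypothetical helper (not the vLLM API): `pad_for_cudagraph` is assumed here to round a token count up to the smallest captured size that fits, and to leave it unchanged when it already exceeds the largest captured size.

import bisect


def pad_for_cudagraph(num_tokens: int, capture_sizes: list[int]) -> int:
    # capture_sizes must be sorted ascending; pick the smallest captured
    # size that can hold the batch, else fall back to the raw size.
    idx = bisect.bisect_left(capture_sizes, num_tokens)
    return capture_sizes[idx] if idx < len(capture_sizes) else num_tokens


assert pad_for_cudagraph(3, [1, 2, 4, 8]) == 4
assert pad_for_cudagraph(8, [1, 2, 4, 8]) == 8
assert pad_for_cudagraph(9, [1, 2, 4, 8]) == 9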
@ -925,17 +927,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                continue
            attn_metadata[layer_name] = attn_metadata_i

        attention_cuda_graphs = all(
            g.metadata_builder.can_run_in_cudagraph(common_attn_metadata)
            for g in self._attn_group_iterator())

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return (attn_metadata, attention_cuda_graphs, logits_indices,
                spec_decode_metadata, num_scheduled_tokens,
                spec_decode_common_attn_metadata)
        return (attn_metadata, logits_indices, spec_decode_metadata,
                num_scheduled_tokens, spec_decode_common_attn_metadata,
                max_num_scheduled_tokens)

    def _compute_cascade_attn_prefix_len(
        self,
@ -1259,6 +1257,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        return mm_embeds

    def get_model(self) -> nn.Module:
        # get raw model out of the cudagraph wrapper.
        if isinstance(self.model, CUDAGraphWrapper):
            return self.model.unwrap()
        return self.model

    def get_supported_generation_tasks(self) -> list[GenerationTask]:
@ -1415,9 +1416,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            return

        assert self.eplb_state is not None
        assert is_mixture_of_experts(self.model)
        model = self.get_model()
        assert is_mixture_of_experts(model)
        self.eplb_state.step(
            self.model,
            model,
            is_dummy,
            is_profile,
            log_stats=self.parallel_config.eplb_log_balancedness,
@ -1507,15 +1509,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                self.vllm_config)

        # Prepare the decoder inputs.
        (attn_metadata, attention_cuda_graphs, logits_indices,
         spec_decode_metadata, num_scheduled_tokens_np,
         spec_decode_common_attn_metadata) = (
             self._prepare_inputs(scheduler_output))
        (attn_metadata, logits_indices, spec_decode_metadata,
         num_scheduled_tokens_np, spec_decode_common_attn_metadata,
         max_query_len) = (self._prepare_inputs(scheduler_output))

        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if (self.use_cuda_graph
        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Use piecewise CUDA graphs.
            # Use CUDA graphs.
            # Add padding to the batch size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled_tokens)
@ -1581,10 +1582,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_input_tokens, intermediate_tensors, True)

        # Some attention backends only support CUDA Graphs in pure decode.
        # If attention doesn't support CUDA Graphs for this batch, but we
        # compiled with full CUDA graphs, we have to skip them entirely.
        skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
        uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
            num_scheduled_tokens == self.input_batch.num_reqs * max_query_len)
        batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                           uniform_decode=uniform_decode)
        cudagraph_runtime_mode, batch_descriptor = \
            self.cudagraph_dispatcher.dispatch(batch_descriptor)

        # Run the model.
        # Use persistent buffers for CUDA graphs.
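For illustration only (the helper below is hypothetical, not part of the patch): the uniform-decode test used above treats a batch as uniform decode when every request contributes exactly uniform_decode_query_len tokens, which is 1 for plain decode or 1 + num_speculative_tokens with speculative decoding.

def is_uniform_decode(max_query_len: int, num_scheduled_tokens: int,
                      num_reqs: int, uniform_decode_query_len: int) -> bool:
    # Every request must contribute exactly uniform_decode_query_len tokens.
    return (max_query_len == uniform_decode_query_len
            and num_scheduled_tokens == num_reqs * max_query_len)


# Pure decode: 4 requests, 1 token each.
assert is_uniform_decode(1, 4, 4, 1)
# Spec decode with 2 draft tokens: 4 requests, 3 tokens each.
assert is_uniform_decode(3, 12, 4, 3)
# Mixed prefill-decode: one request has a longer query.
assert not is_uniform_decode(17, 20, 4, 1)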
@ -1593,10 +1596,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                skip_cuda_graphs=skip_cuda_graphs,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                batch_descriptor=batch_descriptor,
        ), self.maybe_get_kv_connector_output(
                scheduler_output) as kv_connector_output:

            model_output = self.model(
                input_ids=input_ids,
                positions=positions,
@ -2021,20 +2024,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            self.model.compile(
                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                backend=backend)
            return
        # For other compilation levels, cudagraph behavior is controlled by
        # vLLM's CUDAGraphWrapper and CudagraphDispatcher.

        # Wrap the model with the full cudagraph wrapper if needed.
        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            self.model = CUDAGraphWrapper(self.model,
                                          self.vllm_config,
                                          runtime_mode=CUDAGraphMode.FULL)

    def reload_weights(self) -> None:
        assert getattr(self, "model", None) is not None, \
            "Cannot reload weights before model is loaded."
        model_loader = get_model_loader(self.load_config)
        logger.info("Reloading weights inplace...")
        model_loader.load_weights(self.model, model_config=self.model_config)
        model = self.get_model()
        model_loader.load_weights(model, model_config=self.model_config)

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        model = self.get_model()
        TensorizerLoader.save_model(
            self.model,
            model,
            tensorizer_config=tensorizer_config,
            model_config=self.model_config,
        )
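A toy sketch (not from the patch) of the wrap/unwrap pattern that get_model() and reload_weights() rely on above; `Wrapper` is a simplified stand-in for CUDAGraphWrapper and does no real graph capture.

class Wrapper:
    """Delegates calls to the wrapped callable and can hand it back."""

    def __init__(self, runnable):
        self._runnable = runnable

    def unwrap(self):
        return self._runnable

    def __call__(self, *args, **kwargs):
        # A real wrapper would capture/replay a CUDA graph around this call.
        return self._runnable(*args, **kwargs)


def get_model(model):
    # Mirrors GPUModelRunner.get_model(): always return the raw model.
    return model.unwrap() if isinstance(model, Wrapper) else model


def raw_model(x):
    return x + 1


wrapped = Wrapper(raw_model)
assert wrapped(1) == 2
assert get_model(wrapped) is raw_model
assert get_model(raw_model) is raw_model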
@ -2210,31 +2224,82 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
    def _dummy_run(
        self,
        num_tokens: int,
        capture_attn_cudagraph: bool = False,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        force_attention: bool = False,
        uniform_decode: bool = False,
        skip_eplb: bool = False,
        is_profile: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run a dummy forward pass to warm up/profile run or capture the
        CUDA graph for the model.

        Args:
            num_tokens: Number of tokens to run the dummy forward pass.
            cudagraph_runtime_mode: used to control the behavior.
                - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
                - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
                - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
                    needed.
            force_attention: If True, always create attention metadata. Used to
                warm up attention backend when mode is NONE.
            uniform_decode: If True, the batch is a uniform decode batch.
            skip_eplb: If True, skip EPLB state update.
            is_profile: If True, this is a profile run.
        """
        assert cudagraph_runtime_mode in {
            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
        }

        # Padding for DP
        num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
        num_tokens += num_pad

        # If cudagraph_mode.decode_mode() == FULL and
        # cudagraph_mode.separate_routine(), we are using different graphs
        # and/or modes for mixed prefill-decode batches vs. uniform decode
        # batches. A uniform decode batch means that all requests have
        # identical query length, except a potential virtual request (shorter)
        # in the batch to account for padding.
        # A uniform decode batch can either be common pure decode, where
        # max_query_len == 1, or speculative decode, where
        # max_query_len == 1 + num_spec_decode_tokens.

        # When setting max_query_len = 1, we switch to and capture the
        # optimized routine of FA2 for pure decode, i.e., Flashdecode + an
        # optimization for GQA/MQA.
        max_query_len = self.uniform_decode_query_len if uniform_decode else \
            num_tokens

        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
        # for dummy run with LoRA so that the num_reqs collectively
        # has num_tokens in total.
        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
        max_num_reqs = self.scheduler_config.max_num_seqs
        num_reqs = min(num_tokens, max_num_reqs)
        min_tokens_per_req = num_tokens // num_reqs
        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
        if uniform_decode:
            num_reqs = cdiv(num_tokens, max_query_len)
            assert num_reqs <= max_num_reqs, \
                "Do not capture num_reqs > max_num_reqs for uniform batch"
            num_scheduled_tokens_list = [max_query_len] * num_reqs
            if num_tokens % max_query_len != 0:
                num_scheduled_tokens_list[-1] = num_tokens % max_query_len
        else:
            num_reqs = min(num_tokens, max_num_reqs)
            min_tokens_per_req = num_tokens // num_reqs
            num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
            num_scheduled_tokens_list[-1] += num_tokens % num_reqs

        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs
        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                        dtype=np.int32)

        attn_metadata: Optional[dict[str, Any]] = None
        if capture_attn_cudagraph:

        # If force_attention is True, we always capture attention. Otherwise,
        # it only happens for cudagraph_runtime_mode=FULL.
        if force_attention or cudagraph_runtime_mode == \
                CUDAGraphMode.FULL:
            attn_metadata = {}

            # Make sure max_model_len is used at the graph capture time.
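A worked example (not from the patch) of how the dummy batch is split above; `cdiv` mirrors the ceiling-division helper imported from vllm.utils, and the numbers are purely illustrative.

def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def split_dummy_batch(num_tokens: int, max_query_len: int,
                      uniform_decode: bool, max_num_reqs: int) -> list[int]:
    if uniform_decode:
        # Every request gets max_query_len tokens; the last may be shorter.
        num_reqs = cdiv(num_tokens, max_query_len)
        assert num_reqs <= max_num_reqs
        tokens = [max_query_len] * num_reqs
        if num_tokens % max_query_len != 0:
            tokens[-1] = num_tokens % max_query_len
    else:
        # Spread tokens roughly evenly over at most max_num_reqs requests.
        num_reqs = min(num_tokens, max_num_reqs)
        tokens = [num_tokens // num_reqs] * num_reqs
        tokens[-1] += num_tokens % num_reqs
    assert sum(tokens) == num_tokens
    return tokens


# 70 tokens, spec-decode query length 3 -> 24 uniform requests, last one short.
assert split_dummy_batch(70, 3, True, 128) == [3] * 23 + [1]
# Mixed batch of 70 tokens over at most 8 requests.
assert split_dummy_batch(70, 70, False, 8) == [8, 8, 8, 8, 8, 8, 8, 14]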
@ -2255,7 +2320,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                    num_computed_tokens_cpu_tensor[:num_reqs],
                    num_reqs=num_reqs,
                    num_actual_tokens=num_tokens,
                    max_query_len=num_tokens,
                    max_query_len=max_query_len,
                    block_table_tensor=self.input_batch.block_table[
                        kv_cache_group_id].get_device_tensor()[:num_reqs],
                    slot_mapping=self.input_batch.
@ -2299,12 +2364,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_tokens, None, False)
        if cudagraph_runtime_mode == CUDAGraphMode.NONE:
            batch_descriptor = None
        else:
            # Resolve the valid batch descriptor via the dispatcher.
            _cg_mode, batch_descriptor = \
                self.cudagraph_dispatcher.dispatch(
                    BatchDescriptor(num_tokens=num_tokens,
                                    uniform_decode=uniform_decode))
            # sanity check
            assert cudagraph_runtime_mode == _cg_mode, (
                f"Cudagraph runtime mode mismatch at dummy_run. "
                f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")

        with self.maybe_randomize_inputs(input_ids), set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=num_tokens,
                num_tokens_across_dp=num_tokens_across_dp):
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                batch_descriptor=batch_descriptor):
            outputs = self.model(
                input_ids=input_ids,
                positions=positions,
@ -2436,7 +2515,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                           dtype=torch.int32,
                                           device=self.device)

            model = cast(VllmModelForPooling, self.model)
            model = cast(VllmModelForPooling, self.get_model())
            dummy_pooling_params = PoolingParams(task=task)
            to_update = model.pooler.get_pooling_updates(task)
            to_update.apply(dummy_pooling_params)
@ -2546,12 +2625,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        gc.collect()

    def capture_model(self) -> None:
        if not self.use_cuda_graph:
        if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "set -O %s and ensure `use_cudagraph` was not manually set to "
                "False", CompilationLevel.PIECEWISE)
                "ensure `cudagraph_mode` was not manually set to `NONE`")
            return
        else:
            self.initialize_cudagraph_capture()

        compilation_counter.num_gpu_runner_capture_triggers += 1

@ -2576,25 +2656,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        set_cudagraph_capturing_enabled(True)
        with freeze_gc(), graph_capture(device=self.device):
            full_cg = self.full_cuda_graph
            # Only rank 0 should print progress bar during capture
            compilation_cases = reversed(self.cudagraph_batch_sizes)
            if is_global_first_rank():
                compilation_cases = tqdm(
                    list(compilation_cases),
                    disable=not self.load_config.use_tqdm_on_load,
                    desc="Capturing CUDA graph shapes")
            for num_tokens in compilation_cases:
                # We skip EPLB here since we don't want to record dummy metrics
                for _ in range(
                        self.compilation_config.cudagraph_num_of_warmups):
                    self._dummy_run(num_tokens,
                                    capture_attn_cudagraph=full_cg,
                                    skip_eplb=True)
                self._dummy_run(num_tokens,
                                capture_attn_cudagraph=full_cg,
                                skip_eplb=True)
            cudagraph_mode = self.compilation_config.cudagraph_mode
            if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
                cudagraph_runtime_mode = cudagraph_mode.mixed_mode()

                compilation_cases = list(reversed(self.cudagraph_batch_sizes))
                self._capture_cudagraphs(
                    compilation_cases,
                    cudagraph_runtime_mode=cudagraph_runtime_mode,
                    uniform_decode=False)

            # Capture full cudagraphs for uniform decode batches if we don't
            # already have full mixed prefill-decode cudagraphs.
            if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \
                    cudagraph_mode.separate_routine():
                max_num_tokens = self.scheduler_config.max_num_seqs * \
                    self.uniform_decode_query_len
                decode_cudagraph_batch_sizes = [
                    x for x in self.cudagraph_batch_sizes if
                    x <= max_num_tokens and x >= self.uniform_decode_query_len
                ]
                compilation_cases_decode = list(
                    reversed(decode_cudagraph_batch_sizes))
                self._capture_cudagraphs(
                    compilation_cases=compilation_cases_decode,
                    cudagraph_runtime_mode=CUDAGraphMode.FULL,
                    uniform_decode=True)

        # Disable cudagraph capturing globally, so any unexpected cudagraph
        # capturing will be detected and raise an error after here.
        # Note: We don't put it into the graph_capture context manager because
        # we may do lazy capturing in the future that still allows capturing
        # after here.
        set_cudagraph_capturing_enabled(False)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
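A small sketch (not from the patch) of how the decode-only capture list above is derived from the general capture sizes; the numbers below are made up for illustration.

def decode_capture_sizes(capture_sizes: list[int], max_num_seqs: int,
                         uniform_decode_query_len: int) -> list[int]:
    # Uniform decode graphs never need more than
    # max_num_seqs * uniform_decode_query_len tokens, and never fewer than
    # one request's worth of tokens.
    max_num_tokens = max_num_seqs * uniform_decode_query_len
    sizes = [
        x for x in capture_sizes
        if uniform_decode_query_len <= x <= max_num_tokens
    ]
    # Capture large shapes first so smaller ones reuse the memory pool.
    return sorted(sizes, reverse=True)


assert decode_capture_sizes([1, 2, 4, 8, 16, 32, 64, 128, 256],
                            max_num_seqs=64,
                            uniform_decode_query_len=2) == [
                                128, 64, 32, 16, 8, 4, 2
                            ]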
@ -2604,6 +2700,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))

    def _capture_cudagraphs(self, compilation_cases: list[int],
                            cudagraph_runtime_mode: CUDAGraphMode,
                            uniform_decode: bool):
        assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \
            cudagraph_runtime_mode in [CUDAGraphMode.FULL,
                                       CUDAGraphMode.PIECEWISE]

        # Only rank 0 should print progress bar during capture
        if is_global_first_rank():
            compilation_cases = tqdm(
                compilation_cases,
                disable=not self.load_config.use_tqdm_on_load,
                desc="Capturing CUDA graphs ({}, {})".format(
                    "decode" if uniform_decode else "mixed prefill-decode",
                    cudagraph_runtime_mode.name))
        # We skip EPLB here since we don't want to record dummy metrics
        for num_tokens in compilation_cases:
            for _ in range(self.compilation_config.cudagraph_num_of_warmups):
                # Use CUDAGraphMode.NONE (default) for warmup.
                # But be careful: warming up with `NONE` is orthogonal to
                # whether we want to warm up attention or not. This is
                # different from the capture case, where `FULL` implies
                # capturing attention while `PIECEWISE` implies no attention.
                force_attention = (
                    cudagraph_runtime_mode == CUDAGraphMode.FULL)
                self._dummy_run(num_tokens,
                                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                                force_attention=force_attention,
                                uniform_decode=uniform_decode,
                                skip_eplb=True)
            self._dummy_run(num_tokens,
                            cudagraph_runtime_mode=cudagraph_runtime_mode,
                            uniform_decode=uniform_decode,
                            skip_eplb=True)

    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the attention backends and attention metadata builders.
@ -2648,25 +2779,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                attn_metadata_builder_i,
                layer_names)
            attn_groups.append(attn_group)

            if self.full_cuda_graph:
                if attn_metadata_builder_i.attn_cudagraph_support == \
                        AttentionCGSupport.NEVER:
                    raise ValueError(
                        f"Full CUDAGraph not supported for "
                        f"{attn_backend.__name__}. Turn off "
                        f"CompilationConfig.full_cuda_graph or use a "
                        f" different attention backend.")
                if attn_metadata_builder_i.attn_cudagraph_support == \
                        AttentionCGSupport.PURE_DECODE_ONLY:
                    # Limit the max cudagraph size to the max number of
                    # sequences for pure decode only cudagraph backend,
                    # whose max_query_len is 1.
                    self.cudagraph_batch_sizes = [
                        size for size in self.cudagraph_batch_sizes
                        if size <= self.scheduler_config.max_num_seqs
                    ]

            return attn_groups

        for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
@ -2734,6 +2846,75 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                "All or none of the layers are expected to be encoder-only"
            self.is_encoder_only_model = True

    def initialize_cudagraph_capture(self) -> None:
        min_cg_support = AttentionCGSupport.ALWAYS
        min_cg_builder_name = None

        for attn_group in self._attn_group_iterator():
            builder = attn_group.metadata_builder
            if builder.cudagraph_support.value < min_cg_support.value:
                min_cg_support = builder.cudagraph_support
                min_cg_builder_name = builder.__class__.__name__

        # Flexibly resolve the cudagraph mode.
        cudagraph_mode = self.compilation_config.cudagraph_mode
        # Check that full cudagraphs for mixed batches are supported.
        if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \
                and min_cg_support != AttentionCGSupport.ALWAYS:
            msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
                   f"with {min_cg_builder_name} backend (support: "
                   f"{min_cg_support})")
            if min_cg_support == AttentionCGSupport.NEVER:
                # If no full cudagraphs are supported at all, just raise.
                msg += "; please try cudagraph_mode=PIECEWISE, and "\
                    "make sure compilation level is piecewise"
                raise ValueError(msg)

            # Attempt to resolve the full-cudagraph-related mode.
            if self.compilation_config.splitting_ops_contain_attention():
                msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
                cudagraph_mode = self.compilation_config.cudagraph_mode = \
                    CUDAGraphMode.FULL_AND_PIECEWISE
            else:
                msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
                cudagraph_mode = self.compilation_config.cudagraph_mode = \
                    CUDAGraphMode.FULL_DECODE_ONLY
            logger.warning(msg)

        # Check that spec-decode + decode full-cudagraphs is supported,
        # if requested.
        if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
                and self.uniform_decode_query_len > 1 and min_cg_support.value
                < AttentionCGSupport.UNIFORM_BATCH.value):
            msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
                   f" with spec-decode for attention backend "
                   f"{min_cg_builder_name} (support: {min_cg_support})")
            if self.compilation_config.splitting_ops_contain_attention():
                msg += "; setting cudagraph_mode=PIECEWISE"
                cudagraph_mode = self.compilation_config.cudagraph_mode = \
                    CUDAGraphMode.PIECEWISE
            else:
                msg += "; setting cudagraph_mode=NONE"
                cudagraph_mode = self.compilation_config.cudagraph_mode = \
                    CUDAGraphMode.NONE
            logger.warning(msg)

        # Double check that we can support full cudagraphs if they are
        # requested, even after automatic downgrades.
        if cudagraph_mode.has_full_cudagraphs() \
                and min_cg_support == AttentionCGSupport.NEVER:
            raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not "
                             f"supported with {min_cg_builder_name} backend ("
                             f"support:{min_cg_support}) "
                             "; please try cudagraph_mode=PIECEWISE, "
                             "and make sure compilation level is piecewise")

        # Trigger cudagraph dispatching keys initialization here (after
        # initializing attn backends).
        self.cudagraph_dispatcher.initialize_cudagraph_keys(
            self.compilation_config.cudagraph_mode,
            self.uniform_decode_query_len)
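A condensed sketch (not from the patch) of the downgrade policy in initialize_cudagraph_capture above; `Support` and the string mode names loosely mirror AttentionCGSupport and CUDAGraphMode, but `resolve_mode` is a simplified illustration covering only the mixed-batch check.

from enum import IntEnum


class Support(IntEnum):
    NEVER = 0
    UNIFORM_BATCH = 1
    ALWAYS = 2


def resolve_mode(requested: str, support: Support,
                 piecewise_splitting: bool) -> str:
    # Requested FULL for mixed batches but the backend cannot handle it:
    # fall back to FULL_AND_PIECEWISE or FULL_DECODE_ONLY, or fail outright.
    if requested == "FULL" and support is not Support.ALWAYS:
        if support is Support.NEVER:
            raise ValueError("backend has no full-cudagraph support")
        return ("FULL_AND_PIECEWISE"
                if piecewise_splitting else "FULL_DECODE_ONLY")
    return requested


assert resolve_mode("FULL", Support.ALWAYS, True) == "FULL"
assert resolve_mode("FULL", Support.UNIFORM_BATCH, True) == "FULL_AND_PIECEWISE"
assert resolve_mode("FULL", Support.UNIFORM_BATCH, False) == "FULL_DECODE_ONLY"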
    def calculate_reorder_batch_threshold(self) -> None:
        """
        Check that if any backends reorder batches; that the reordering

@ -322,16 +322,11 @@ class Worker(WorkerBase):
        if get_pp_group().is_last_rank:
            max_num_reqs = min(self.scheduler_config.max_num_seqs,
                               self.scheduler_config.max_num_batched_tokens)
            # activate building attn_metadata for this dummy run to avoid
            # potential illegal memory access for full cudagraph relay.
            attn_cudagraph = self.compilation_config.full_cuda_graph and\
                not self.model_config.enforce_eager

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = \
                self.model_runner._dummy_run(
                    num_tokens=max_num_reqs,
                    capture_attn_cudagraph=attn_cudagraph,
                    skip_eplb=True,
                )
            if self.model_runner.is_pooling_model: