# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref
from dataclasses import dataclass
from typing import Optional

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Temporarily set environment variables and restore them afterward.

    We use this instead of pytest's monkeypatch because monkeypatch is
    function-scoped and cannot be used from "class"/"module" scoped fixtures.
    """
    original_env = {k: os.environ.get(k) for k in env_vars}
    try:
        os.environ.update(env_vars)
        yield
    finally:
        for k, v in original_env.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v


@dataclass
class BackendConfig:
    name: str
    env_vars: dict
    comp_config: dict
    specific_gpu_arch: Optional[tuple] = None


# Define all full-cudagraph backend configurations to be tested.
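# Each entry pins an attention backend via environment variables and selects the
# full-cudagraph compilation mode to exercise; the llm_pair fixture below builds
# an engine with that config plus a PIECEWISE baseline, and the tests compare
# their outputs.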
backend_configs = {
    # FA3 on Hopper
    "FA3":
    BackendConfig(name="FA3",
                  env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
                  comp_config={
                      "cudagraph_mode": "FULL",
                  },
                  specific_gpu_arch=(9, 0)),
    # FlashMLA on Hopper
    "FlashMLA":
    BackendConfig(name="FlashMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  },
                  specific_gpu_arch=(9, 0)),
    # Cutlass MLA on Blackwell
    "CutlassMLA":
    BackendConfig(
        name="CutlassMLA",
        env_vars={
            "VLLM_USE_V1": "1",
            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
            # TODO: remove this when the hang issue is fixed
            "FORCE_NUM_KV_SPLITS": "1",
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
            "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
        },
        specific_gpu_arch=(10, 0)),
    # FA2
    "FA2":
    BackendConfig(name="FA2",
                  env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
                  comp_config={
                      "cudagraph_mode": "FULL",
                  }),
    # Triton Attention
    "TritonAttn":
    BackendConfig(name="TritonAttn",
                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
                  comp_config={
                      "cudagraph_mode": "FULL",
                  }),
    # FlashInfer
    "FlashInfer":
    BackendConfig(name="FlashInfer",
                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
}

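# Build the (model, backend) parameter grid: MLA backends run against an MLA
# model (DeepSeek-V2-Lite), all remaining backends against a dense model
# (Qwen2-1.5B-Instruct).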
test_params_full_cudagraph = []

# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "CutlassMLA"]
for mla_backend in MLA_backends:
    test_params_full_cudagraph.append(
        pytest.param(
            ("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))

# Qwen/Qwen2-1.5B-Instruct with other backends
other_backend_configs = [
    backend_configs[c] for c in backend_configs if c not in MLA_backends
]
for backend_config in other_backend_configs:
    test_params_full_cudagraph.append(
        pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))


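# Note: both engines below are alive at the same time, which is presumably why
# each one requests gpu_memory_utilization=0.43, i.e. a bit less than half of
# the device memory.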
@pytest.fixture(scope="class")
def llm_pair(request):
    model, backend_config = request.param

    # Dynamically skip test if GPU capability is not met
    if (backend_config.specific_gpu_arch and backend_config.specific_gpu_arch
            != current_platform.get_device_capability()):
        if backend_config.specific_gpu_arch == (9, 0):
            pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
        elif backend_config.specific_gpu_arch == (10, 0):
            pytest.skip("Only Blackwell GPUs support Cutlass MLA")

    env_vars = {
        "VLLM_USE_V1": "1",
        # Force native sampler to avoid potential nondeterminism in FlashInfer
        # when per-request generators are not used in V1.
        "VLLM_USE_FLASHINFER_SAMPLER": "0",
        **backend_config.env_vars,
    }
    with temporary_environ(env_vars):
        full = LLM(
            model=model,
            gpu_memory_utilization=0.43,
            trust_remote_code=True,
            max_model_len=1024,
            max_num_seqs=128,
            compilation_config=CompilationConfig(
                **backend_config.comp_config),
            generation_config="vllm",
            seed=42,
        )
        piecewise = LLM(
            model=model,
            gpu_memory_utilization=0.43,
            trust_remote_code=True,
            max_model_len=1024,
            max_num_seqs=128,
            compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
            generation_config="vllm",
            seed=42,
        )

    # PyTest caches the fixture values so we use weakref.proxy to enable GC
    yield weakref.proxy(full), weakref.proxy(piecewise)
    del full
    del piecewise

    wait_for_gpu_memory_to_clear(
        devices=[0],
        threshold_ratio=0.1,
    )


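# Parametrizing at class level pairs with the class-scoped llm_pair fixture, so
# each (model, backend) combination builds its engine pair exactly once and
# releases it before the next combination runs.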
@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
class TestFullCUDAGraph:
    """
    Use a class so that an LLM pair is constructed once for all
    batch_size/max_tokens combinations and released immediately afterward.

    A module-scoped fixture would stick around for the whole test module,
    meaning multiple LLM instances would hog memory simultaneously.
    """

    @pytest.mark.parametrize(("batch_size", "max_tokens"), [
        (1, 10),
        (7, 10),
        (16, 10),
        (25, 10),
        (32, 10),
        (45, 10),
        (64, 10),
        (123, 10),
        (8, 5),
        (8, 30),
    ])
    def test_full_cudagraph(self, batch_size, max_tokens,
                            llm_pair: tuple[LLM, LLM]):
        """
        Test various batch sizes and max_tokens to ensure that the
        full cudagraph compilation works for padded cases too.
        """
        full_cudagraph_llm, piecewise_llm = llm_pair

        prompts = ["the quick brown fox"] * batch_size
        # Use purely greedy decoding to avoid top-p truncation sensitivity
        # that can amplify tiny numeric differences across runtimes.
        sampling_params = SamplingParams(temperature=0.0,
                                         max_tokens=max_tokens,
                                         top_p=1.0)

        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)

        # Check that all responses are the same
        for piecewise_res, full_res in zip(piecewise_responses,
                                           full_responses):
            assert (piecewise_res.outputs[0].text.lower() ==
                    full_res.outputs[0].text.lower())


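# FLEX_ATTENTION does not support full cuda graphs, so constructing an engine
# with cudagraph_mode="FULL" on that backend is expected to raise.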
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
    with temporary_environ({
            "VLLM_USE_V1": "1",
            # Flex_Attention is not supported with full cuda graph
            "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
    }), pytest.raises(RuntimeError):
        LLM(model="Qwen/Qwen2-1.5B-Instruct",
            compilation_config=CompilationConfig(cudagraph_mode="FULL"))