Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Harry Mellor on 2025-10-05 15:06:22 +01:00, committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@ -1,46 +0,0 @@
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed
[tool.ruff]
line-length = 88
[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]
[tool.ruff.lint]
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
"I",
# flake8-logging-format
"G",
]
ignore = [
# star imports
"F405", "F403",
# lambda expression assignment
"E731",
# Loop control variable not used within loop body
"B007",
# f-string format
"UP032",
# Can remove once 3.10+ is the minimum Python version
"UP007",
]
[tool.ruff.format]
docstring-code-format = true
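
Ruff resolves the closest `pyproject.toml` for each file it checks, so during the migration the directories carrying a local file like this one were formatted at 88 columns while the rest of the tree still followed the older 80-column yapf settings at the repository root. As a hedged before/after sketch (made-up names, not code from this diff) of what switching to the Black-compatible `ruff format` at `line-length = 88` means in practice:

```python
# Hypothetical example: yapf's aligned continuation style (pre-migration).
output = run_attention_kernel(query_states,
                              key_states,
                              value_states,
                              block_size=256,
                              num_warps=4)

# The same call after `ruff format` with line-length = 88: the call is joined,
# and because the one-line form exceeds 88 columns, the arguments move to a
# single indented line with the closing parenthesis on its own line.
output = run_attention_kernel(
    query_states, key_states, value_states, block_size=256, num_warps=4
)
```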

View File

@@ -6,28 +6,16 @@ default_stages:
 - manual # Run in CI
 exclude: 'vllm/third_party/.*'
 repos:
-- repo: https://github.com/google/yapf
-  rev: v0.43.0
-  hooks:
-  - id: yapf
-    args: [--in-place, --verbose]
-    # Keep the same list from yapfignore here to avoid yapf failing without any inputs
-    exclude: '(.buildkite|benchmarks|build|examples)/.*'
 - repo: https://github.com/astral-sh/ruff-pre-commit
   rev: v0.11.7
   hooks:
   - id: ruff
     args: [--output-format, github, --fix]
   - id: ruff-format
-    files: ^(.buildkite|benchmarks|examples)/.*
 - repo: https://github.com/crate-ci/typos
   rev: v1.35.5
   hooks:
   - id: typos
-- repo: https://github.com/PyCQA/isort
-  rev: 6.0.1
-  hooks:
-  - id: isort
 - repo: https://github.com/pre-commit/mirrors-clang-format
   rev: v20.1.3
   hooks:
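
With this change, ruff (linting, including import sorting once the "I" rule is enabled) and ruff-format are the only Python style hooks left in pre-commit, replacing the removed yapf and isort hooks. A minimal sketch of running the same two steps outside pre-commit, assuming `ruff` is installed and you are at the repository root:

```python
# Hedged sketch (not part of this diff): invoke the same two ruff steps that
# the pre-commit hooks above run.
import subprocess

# Lint with autofix; with rule "I" selected this also sorts imports,
# covering what the removed isort hook used to do.
subprocess.run(["ruff", "check", "--fix", "."], check=True)

# Black-style formatting, covering what the removed yapf hook used to do.
subprocess.run(["ruff", "format", "."], check=True)
```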

View File

@ -2,9 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc import gc
from benchmark_utils import TimeCollector
from tabulate import tabulate from tabulate import tabulate
from benchmark_utils import TimeCollector
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool

View File

@ -5,9 +5,9 @@ import time
from unittest import mock from unittest import mock
import numpy as np import numpy as np
from benchmark_utils import TimeCollector
from tabulate import tabulate from tabulate import tabulate
from benchmark_utils import TimeCollector
from vllm.config import ( from vllm.config import (
CacheConfig, CacheConfig,
DeviceConfig, DeviceConfig,

View File

@ -37,14 +37,13 @@ from typing import Optional
import datasets import datasets
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
from backend_request_func import ( from backend_request_func import (
ASYNC_REQUEST_FUNCS, ASYNC_REQUEST_FUNCS,
RequestFuncInput, RequestFuncInput,
RequestFuncOutput, RequestFuncOutput,
) )
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
try: try:
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer

View File

@ -1,49 +0,0 @@
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed
[tool.ruff]
line-length = 88
[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]
[tool.ruff.lint]
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
"I",
# flake8-logging-format
"G",
]
ignore = [
# star imports
"F405", "F403",
# lambda expression assignment
"E731",
# Loop control variable not used within loop body
"B007",
# f-string format
"UP032",
# Can remove once 3.10+ is the minimum Python version
"UP007",
]
[tool.ruff.lint.isort]
known-first-party = ["vllm"]
[tool.ruff.format]
docstring-code-format = true
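
This copy of the local config additionally pins `known-first-party = ["vllm"]` for ruff's import sorting (rule "I"). As a hedged illustration, using imports that appear elsewhere in this diff but grouping chosen by me, the resulting order is standard library, then third-party, then first-party `vllm`, each group separated by a blank line; local helper modules not covered by this setting are classified as third-party under ruff's defaults.

```python
# Hedged illustration of the grouping produced by rule "I" with
# known-first-party = ["vllm"].
import os
import time

import torch
from transformers import PreTrainedTokenizerBase

from vllm.utils import FlexibleArgumentParser
```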

View File

@ -16,7 +16,7 @@ import shutil
from torch.utils.hipify.hipify_python import hipify from torch.utils.hipify.hipify_python import hipify
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# Project directory where all the source + include files live. # Project directory where all the source + include files live.
@ -34,15 +34,14 @@ if __name__ == '__main__':
) )
# Source files to convert. # Source files to convert.
parser.add_argument("sources", parser.add_argument(
help="Source files to hipify.", "sources", help="Source files to hipify.", nargs="*", default=[]
nargs="*", )
default=[])
args = parser.parse_args() args = parser.parse_args()
# Limit include scope to project_dir only # Limit include scope to project_dir only
includes = [os.path.join(args.project_dir, '*')] includes = [os.path.join(args.project_dir, "*")]
# Get absolute path for all source files. # Get absolute path for all source files.
extra_files = [os.path.abspath(s) for s in args.sources] extra_files = [os.path.abspath(s) for s in args.sources]
@ -51,25 +50,31 @@ if __name__ == '__main__':
# The directory might already exist to hold object files so we ignore that. # The directory might already exist to hold object files so we ignore that.
shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True) shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
hipify_result = hipify(project_directory=args.project_dir, hipify_result = hipify(
output_directory=args.output_dir, project_directory=args.project_dir,
header_include_dirs=[], output_directory=args.output_dir,
includes=includes, header_include_dirs=[],
extra_files=extra_files, includes=includes,
show_detailed=True, extra_files=extra_files,
is_pytorch_extension=True, show_detailed=True,
hipify_extra_files_only=True) is_pytorch_extension=True,
hipify_extra_files_only=True,
)
hipified_sources = [] hipified_sources = []
for source in args.sources: for source in args.sources:
s_abs = os.path.abspath(source) s_abs = os.path.abspath(source)
hipified_s_abs = (hipify_result[s_abs].hipified_path if hipified_s_abs = (
(s_abs in hipify_result hipify_result[s_abs].hipified_path
and hipify_result[s_abs].hipified_path is not None) if (
else s_abs) s_abs in hipify_result
and hipify_result[s_abs].hipified_path is not None
)
else s_abs
)
hipified_sources.append(hipified_s_abs) hipified_sources.append(hipified_s_abs)
assert (len(hipified_sources) == len(args.sources)) assert len(hipified_sources) == len(args.sources)
# Print hipified source files. # Print hipified source files.
print("\n".join(hipified_sources)) print("\n".join(hipified_sources))

View File

@ -27,7 +27,7 @@ VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = {
**{ **{
VLLMDataType.u4b8: "u4b8", VLLMDataType.u4b8: "u4b8",
VLLMDataType.u8b128: "u8b128", VLLMDataType.u8b128: "u8b128",
} },
} }
VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
@ -35,7 +35,7 @@ VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
**{ **{
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t", VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
} },
} }
VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = { VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
@ -43,7 +43,7 @@ VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
**{ **{
VLLMDataType.u4b8: 4, VLLMDataType.u4b8: 4,
VLLMDataType.u8b128: 8, VLLMDataType.u8b128: 8,
} },
} }
VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = { VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
@ -67,15 +67,13 @@ VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
DataType.f32: "at::ScalarType::Float", DataType.f32: "at::ScalarType::Float",
} }
VLLMKernelScheduleTag: dict[Union[ VLLMKernelScheduleTag: dict[
MixedInputKernelScheduleType, KernelScheduleType], str] = { Union[MixedInputKernelScheduleType, KernelScheduleType], str
**KernelScheduleTag, # type: ignore ] = {
**{ **KernelScheduleTag, # type: ignore
MixedInputKernelScheduleType.TmaWarpSpecialized: **{
"cutlass::gemm::KernelTmaWarpSpecialized", MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized",
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
"cutlass::gemm::KernelTmaWarpSpecializedPingpong", MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: },
"cutlass::gemm::KernelTmaWarpSpecializedCooperative", }
}
}

View File

@ -17,25 +17,30 @@ FILE_HEAD = """
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
""".strip() """.strip()
TEMPLATE = ("template __global__ void Marlin<" TEMPLATE = (
"{{scalar_t}}, " "template __global__ void Marlin<"
"{{w_type_id}}, " "{{scalar_t}}, "
"{{s_type_id}}, " "{{w_type_id}}, "
"{{threads}}, " "{{s_type_id}}, "
"{{thread_m_blocks}}, " "{{threads}}, "
"{{thread_n_blocks}}, " "{{thread_m_blocks}}, "
"{{thread_k_blocks}}, " "{{thread_n_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, " "{{thread_k_blocks}}, "
"{{stages}}, " "{{'true' if m_block_size_8 else 'false'}}, "
"{{group_blocks}}, " "{{stages}}, "
"{{'true' if is_zp_float else 'false'}}>" "{{group_blocks}}, "
"( MARLIN_KERNEL_PARAMS );") "{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# int8 with zero point case (vllm::kU8) is also supported, # int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size. # we don't add it to reduce wheel size.
SCALAR_TYPES = [ SCALAR_TYPES = [
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", "vllm::kU4",
"vllm::kFE2M1f" "vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
] ]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
@ -58,11 +63,12 @@ def generate_new_kernels():
all_template_str_list = [] all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product( for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
):
# act order case only support gptq-int4 and gptq-int8 # act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [ if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8", "vllm::kU8B128" "vllm::kU4B8",
"vllm::kU8B128",
]: ]:
continue continue
if thread_configs[2] == 256: if thread_configs[2] == 256:

View File

@ -17,28 +17,32 @@ FILE_HEAD = """
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
""".strip() """.strip()
TEMPLATE = ("template __global__ void Marlin<" TEMPLATE = (
"{{scalar_t}}, " "template __global__ void Marlin<"
"{{w_type_id}}, " "{{scalar_t}}, "
"{{s_type_id}}, " "{{w_type_id}}, "
"{{threads}}, " "{{s_type_id}}, "
"{{thread_m_blocks}}, " "{{threads}}, "
"{{thread_n_blocks}}, " "{{thread_m_blocks}}, "
"{{thread_k_blocks}}, " "{{thread_n_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, " "{{thread_k_blocks}}, "
"{{stages}}, " "{{'true' if m_block_size_8 else 'false'}}, "
"{{group_blocks}}, " "{{stages}}, "
"{{'true' if is_zp_float else 'false'}}>" "{{group_blocks}}, "
"( MARLIN_KERNEL_PARAMS );") "{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# int8 with zero point case (vllm::kU8) is also supported, # int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size. # we don't add it to reduce wheel size.
SCALAR_TYPES = [ SCALAR_TYPES = [
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", "vllm::kU4",
"vllm::kFE2M1f" "vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
] ]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
(128, 64, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks: # group_blocks:
@ -59,11 +63,12 @@ def generate_new_kernels():
all_template_str_list = [] all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product( for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
):
# act order case only support gptq-int4 and gptq-int8 # act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [ if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8", "vllm::kU8B128" "vllm::kU4B8",
"vllm::kU8B128",
]: ]:
continue continue
if thread_configs[2] == 256: if thread_configs[2] == 256:
@ -93,8 +98,7 @@ def generate_new_kernels():
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
is_zp_float_list = [False] is_zp_float_list = [False]
if dtype == "fp16" and scalar_type == "vllm::kU4" and \ if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4:
group_blocks == 4:
# HQQ (is_zp_float = true) only supports # HQQ (is_zp_float = true) only supports
# 4bit quantization and fp16 # 4bit quantization and fp16
is_zp_float_list.append(True) is_zp_float_list.append(True)

View File

@ -12,18 +12,24 @@ from functools import reduce
from typing import Optional, Union from typing import Optional, Union
import jinja2 import jinja2
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag, from vllm_cutlass_library_extension import (
EpilogueScheduleType, DataType,
MixedInputKernelScheduleType, EpilogueScheduleTag,
TileSchedulerTag, EpilogueScheduleType,
TileSchedulerType, VLLMDataType, MixedInputKernelScheduleType,
VLLMDataTypeNames, TileSchedulerTag,
VLLMDataTypeSize, VLLMDataTypeTag, TileSchedulerType,
VLLMDataTypeTorchDataTypeTag, VLLMDataType,
VLLMDataTypeVLLMScalarTypeTag, VLLMDataTypeNames,
VLLMKernelScheduleTag) VLLMDataTypeSize,
VLLMDataTypeTag,
VLLMDataTypeTorchDataTypeTag,
VLLMDataTypeVLLMScalarTypeTag,
VLLMKernelScheduleTag,
)
# yapf: enable # yapf: enable
@ -286,18 +292,23 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
tile_shape = ( tile_shape = (
f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}" f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
) )
cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" + cluster_shape = (
f"x{schedule_config.cluster_shape_mnk[1]}" + f"{schedule_config.cluster_shape_mnk[0]}"
f"x{schedule_config.cluster_shape_mnk[2]}") + f"x{schedule_config.cluster_shape_mnk[1]}"
kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\ + f"x{schedule_config.cluster_shape_mnk[2]}"
.split("::")[-1] )
epilogue_schedule = EpilogueScheduleTag[ kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule].split(
schedule_config.epilogue_schedule].split("::")[-1] "::"
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\ )[-1]
.split("::")[-1] epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split(
"::"
)[-1]
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1]
return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + return (
f"_{epilogue_schedule}_{tile_scheduler}") f"{tile_shape}_{cluster_shape}_{kernel_schedule}"
+ f"_{epilogue_schedule}_{tile_scheduler}"
)
# mostly unique shorter sch_sig # mostly unique shorter sch_sig
@ -316,18 +327,24 @@ def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
# unique type_name # unique type_name
def generate_type_signature(kernel_types: TypeConfig): def generate_type_signature(kernel_types: TypeConfig):
return str("".join([ return str(
VLLMDataTypeNames[getattr(kernel_types, field.name)] "".join(
for field in fields(TypeConfig) [
])) VLLMDataTypeNames[getattr(kernel_types, field.name)]
for field in fields(TypeConfig)
]
)
)
def generate_type_option_name(kernel_types: TypeConfig): def generate_type_option_name(kernel_types: TypeConfig):
return ", ".join([ return ", ".join(
f"{field.name.replace('b_', 'with_')+'_type'}=" + [
VLLMDataTypeNames[getattr(kernel_types, field.name)] f"{field.name.replace('b_', 'with_') + '_type'}="
for field in fields(TypeConfig) + VLLMDataTypeNames[getattr(kernel_types, field.name)]
]) for field in fields(TypeConfig)
]
)
def is_power_of_two(n): def is_power_of_two(n):
@ -335,7 +352,6 @@ def is_power_of_two(n):
def to_cute_constant(value: list[int]): def to_cute_constant(value: list[int]):
def _to_cute_constant(value: int): def _to_cute_constant(value: int):
if is_power_of_two(value): if is_power_of_two(value):
return f"_{value}" return f"_{value}"
@ -350,11 +366,11 @@ def to_cute_constant(value: list[int]):
def unique_schedules(impl_configs: list[ImplConfig]): def unique_schedules(impl_configs: list[ImplConfig]):
# Use dict over set for deterministic ordering # Use dict over set for deterministic ordering
return list({ return list(
sch: None {
for impl_config in impl_configs sch: None for impl_config in impl_configs for sch in impl_config.schedules
for sch in impl_config.schedules }.keys()
}.keys()) )
def unsigned_type_with_bitwidth(num_bits): def unsigned_type_with_bitwidth(num_bits):
@ -380,7 +396,7 @@ template_globals = {
"gen_type_sig": generate_type_signature, "gen_type_sig": generate_type_signature,
"unique_schedules": unique_schedules, "unique_schedules": unique_schedules,
"unsigned_type_with_bitwidth": unsigned_type_with_bitwidth, "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
"gen_type_option_name": generate_type_option_name "gen_type_option_name": generate_type_option_name,
} }
@ -398,23 +414,28 @@ prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
sources = [] sources = []
sources.append(( sources.append(
"machete_mm_dispatch", (
mm_dispatch_template.render(impl_configs=impl_configs), "machete_mm_dispatch",
)) mm_dispatch_template.render(impl_configs=impl_configs),
)
)
prepack_types = [] prepack_types = []
for impl_config in impl_configs: for impl_config in impl_configs:
convert_type = impl_config.types.a \ convert_type = (
if impl_config.types.b_group_scale == DataType.void \ impl_config.types.a
else impl_config.types.b_group_scale if impl_config.types.b_group_scale == DataType.void
else impl_config.types.b_group_scale
)
prepack_types.append( prepack_types.append(
PrepackTypeConfig( PrepackTypeConfig(
a=impl_config.types.a, a=impl_config.types.a,
b_num_bits=VLLMDataTypeSize[impl_config.types.b], b_num_bits=VLLMDataTypeSize[impl_config.types.b],
convert=convert_type, convert=convert_type,
accumulator=impl_config.types.accumulator, accumulator=impl_config.types.accumulator,
)) )
)
def prepacked_type_key(prepack_type: PrepackTypeConfig): def prepacked_type_key(prepack_type: PrepackTypeConfig):
# For now, we can just use the first accumulator type seen since # For now, we can just use the first accumulator type seen since
@ -430,10 +451,14 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
unique_prepack_types.append(prepack_type) unique_prepack_types.append(prepack_type)
prepack_types_seen.add(key) prepack_types_seen.add(key)
sources.append(( sources.append(
"machete_prepack", (
prepack_dispatch_template.render(types=unique_prepack_types, ), "machete_prepack",
)) prepack_dispatch_template.render(
types=unique_prepack_types,
),
)
)
# Split up impls across files # Split up impls across files
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0) num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
@ -466,10 +491,12 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
curr_impl_in_file += len(files_impls[-1][-1].schedules) curr_impl_in_file += len(files_impls[-1][-1].schedules)
for part, file_impls in enumerate(files_impls): for part, file_impls in enumerate(files_impls):
sources.append(( sources.append(
f"machete_mm_impl_part{part+1}", (
mm_impl_template.render(impl_configs=file_impls), f"machete_mm_impl_part{part + 1}",
)) mm_impl_template.render(impl_configs=file_impls),
)
)
return sources return sources
@ -514,8 +541,7 @@ def generate():
# For now we use the same heuristic for all types # For now we use the same heuristic for all types
# Heuristic is currently tuned for H100s # Heuristic is currently tuned for H100s
default_heuristic = [ default_heuristic = [
(cond, ScheduleConfig(*tile_config, (cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore
**sch_common_params)) # type: ignore
for cond, tile_config in default_tile_heuristic_config.items() for cond, tile_config in default_tile_heuristic_config.items()
] ]
@ -541,14 +567,18 @@ def generate():
a_token_scale=DataType.void, a_token_scale=DataType.void,
out=a, out=a,
accumulator=DataType.f32, accumulator=DataType.f32,
) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) )
for a in (DataType.f16, DataType.bf16)) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
for a in (DataType.f16, DataType.bf16)
)
impl_configs += [ impl_configs += [
ImplConfig(x[0], x[1], x[2]) ImplConfig(x[0], x[1], x[2])
for x in zip(GPTQ_kernel_type_configs, for x in zip(
itertools.repeat(get_unique_schedules(default_heuristic)), GPTQ_kernel_type_configs,
itertools.repeat(default_heuristic)) itertools.repeat(get_unique_schedules(default_heuristic)),
itertools.repeat(default_heuristic),
)
] ]
AWQ_kernel_type_configs = list( AWQ_kernel_type_configs = list(
@ -561,14 +591,18 @@ def generate():
a_token_scale=DataType.void, a_token_scale=DataType.void,
out=a, out=a,
accumulator=DataType.f32, accumulator=DataType.f32,
) for b in (DataType.u4, DataType.u8) )
for a in (DataType.f16, DataType.bf16)) for b in (DataType.u4, DataType.u8)
for a in (DataType.f16, DataType.bf16)
)
impl_configs += [ impl_configs += [
ImplConfig(x[0], x[1], x[2]) ImplConfig(x[0], x[1], x[2])
for x in zip(AWQ_kernel_type_configs, for x in zip(
itertools.repeat(get_unique_schedules(default_heuristic)), AWQ_kernel_type_configs,
itertools.repeat(default_heuristic)) itertools.repeat(get_unique_schedules(default_heuristic)),
itertools.repeat(default_heuristic),
)
] ]
# TODO: Support W4A8 when ready # TODO: Support W4A8 when ready

View File

@ -33,8 +33,11 @@ def auto_mock(module, attr, max_mocks=50):
try: try:
# First treat attr as an attr, then as a submodule # First treat attr as an attr, then as a submodule
with patch("importlib.metadata.version", return_value="0.0.0"): with patch("importlib.metadata.version", return_value="0.0.0"):
return getattr(importlib.import_module(module), attr, return getattr(
importlib.import_module(f"{module}.{attr}")) importlib.import_module(module),
attr,
importlib.import_module(f"{module}.{attr}"),
)
except importlib.metadata.PackageNotFoundError as e: except importlib.metadata.PackageNotFoundError as e:
raise e raise e
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
@ -42,7 +45,8 @@ def auto_mock(module, attr, max_mocks=50):
sys.modules[e.name] = PydanticMagicMock() sys.modules[e.name] = PydanticMagicMock()
raise ImportError( raise ImportError(
f"Failed to import {module}.{attr} after mocking {max_mocks} imports") f"Failed to import {module}.{attr} after mocking {max_mocks} imports"
)
latency = auto_mock("vllm.benchmarks", "latency") latency = auto_mock("vllm.benchmarks", "latency")
@ -61,9 +65,7 @@ class MarkdownFormatter(HelpFormatter):
"""Custom formatter that generates markdown for argument groups.""" """Custom formatter that generates markdown for argument groups."""
def __init__(self, prog, starting_heading_level=3): def __init__(self, prog, starting_heading_level=3):
super().__init__(prog, super().__init__(prog, max_help_position=float("inf"), width=float("inf"))
max_help_position=float('inf'),
width=float('inf'))
self._section_heading_prefix = "#" * starting_heading_level self._section_heading_prefix = "#" * starting_heading_level
self._argument_heading_prefix = "#" * (starting_heading_level + 1) self._argument_heading_prefix = "#" * (starting_heading_level + 1)
self._markdown_output = [] self._markdown_output = []
@ -85,23 +87,19 @@ class MarkdownFormatter(HelpFormatter):
def add_arguments(self, actions): def add_arguments(self, actions):
for action in actions: for action in actions:
if (len(action.option_strings) == 0 if len(action.option_strings) == 0 or "--help" in action.option_strings:
or "--help" in action.option_strings):
continue continue
option_strings = f'`{"`, `".join(action.option_strings)}`' option_strings = f"`{'`, `'.join(action.option_strings)}`"
heading_md = f"{self._argument_heading_prefix} {option_strings}\n\n" heading_md = f"{self._argument_heading_prefix} {option_strings}\n\n"
self._markdown_output.append(heading_md) self._markdown_output.append(heading_md)
if choices := action.choices: if choices := action.choices:
choices = f'`{"`, `".join(str(c) for c in choices)}`' choices = f"`{'`, `'.join(str(c) for c in choices)}`"
self._markdown_output.append( self._markdown_output.append(f"Possible choices: {choices}\n\n")
f"Possible choices: {choices}\n\n") elif (metavar := action.metavar) and isinstance(metavar, (list, tuple)):
elif ((metavar := action.metavar) metavar = f"`{'`, `'.join(str(m) for m in metavar)}`"
and isinstance(metavar, (list, tuple))): self._markdown_output.append(f"Possible choices: {metavar}\n\n")
metavar = f'`{"`, `".join(str(m) for m in metavar)}`'
self._markdown_output.append(
f"Possible choices: {metavar}\n\n")
if action.help: if action.help:
self._markdown_output.append(f"{action.help}\n\n") self._markdown_output.append(f"{action.help}\n\n")
@ -143,24 +141,17 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
# Create parsers to document # Create parsers to document
parsers = { parsers = {
"engine_args": "engine_args": create_parser(EngineArgs.add_cli_args),
create_parser(EngineArgs.add_cli_args), "async_engine_args": create_parser(
"async_engine_args": AsyncEngineArgs.add_cli_args, async_args_only=True
create_parser(AsyncEngineArgs.add_cli_args, async_args_only=True), ),
"serve": "serve": create_parser(cli_args.make_arg_parser),
create_parser(cli_args.make_arg_parser), "chat": create_parser(ChatCommand.add_cli_args),
"chat": "complete": create_parser(CompleteCommand.add_cli_args),
create_parser(ChatCommand.add_cli_args), "bench_latency": create_parser(latency.add_cli_args),
"complete": "bench_throughput": create_parser(throughput.add_cli_args),
create_parser(CompleteCommand.add_cli_args), "bench_serve": create_parser(serve.add_cli_args),
"bench_latency": "run-batch": create_parser(run_batch.make_arg_parser),
create_parser(latency.add_cli_args),
"bench_throughput":
create_parser(throughput.add_cli_args),
"bench_serve":
create_parser(serve.add_cli_args),
"run-batch":
create_parser(run_batch.make_arg_parser),
} }
# Generate documentation for each parser # Generate documentation for each parser

View File

@ -11,7 +11,7 @@ import regex as re
logger = logging.getLogger("mkdocs") logger = logging.getLogger("mkdocs")
ROOT_DIR = Path(__file__).parent.parent.parent.parent ROOT_DIR = Path(__file__).parent.parent.parent.parent
ROOT_DIR_RELATIVE = '../../../../..' ROOT_DIR_RELATIVE = "../../../../.."
EXAMPLE_DIR = ROOT_DIR / "examples" EXAMPLE_DIR = ROOT_DIR / "examples"
EXAMPLE_DOC_DIR = ROOT_DIR / "docs/examples" EXAMPLE_DOC_DIR = ROOT_DIR / "docs/examples"
@ -36,7 +36,7 @@ def fix_case(text: str) -> str:
r"int\d+": lambda x: x.group(0).upper(), # e.g. int8, int16 r"int\d+": lambda x: x.group(0).upper(), # e.g. int8, int16
} }
for pattern, repl in subs.items(): for pattern, repl in subs.items():
text = re.sub(rf'\b{pattern}\b', repl, text, flags=re.IGNORECASE) text = re.sub(rf"\b{pattern}\b", repl, text, flags=re.IGNORECASE)
return text return text
@ -58,7 +58,8 @@ class Example:
determine_other_files() -> list[Path]: Determines other files in the directory excluding the main file. determine_other_files() -> list[Path]: Determines other files in the directory excluding the main file.
determine_title() -> str: Determines the title of the document. determine_title() -> str: Determines the title of the document.
generate() -> str: Generates the documentation content. generate() -> str: Generates the documentation content.
""" # noqa: E501 """ # noqa: E501
path: Path path: Path
category: str = None category: str = None
main_file: Path = field(init=False) main_file: Path = field(init=False)
@ -84,9 +85,8 @@ class Example:
Markdown file found in the directory. Markdown file found in the directory.
Raises: Raises:
IndexError: If no Markdown files are found in the directory. IndexError: If no Markdown files are found in the directory.
""" # noqa: E501 """ # noqa: E501
return self.path if self.path.is_file() else list( return self.path if self.path.is_file() else list(self.path.glob("*.md")).pop()
self.path.glob("*.md")).pop()
def determine_other_files(self) -> list[Path]: def determine_other_files(self) -> list[Path]:
""" """
@ -98,7 +98,7 @@ class Example:
Returns: Returns:
list[Path]: A list of Path objects representing the other files in the directory. list[Path]: A list of Path objects representing the other files in the directory.
""" # noqa: E501 """ # noqa: E501
if self.path.is_file(): if self.path.is_file():
return [] return []
is_other_file = lambda file: file.is_file() and file != self.main_file is_other_file = lambda file: file.is_file() and file != self.main_file
@ -109,9 +109,9 @@ class Example:
# Specify encoding for building on Windows # Specify encoding for building on Windows
with open(self.main_file, encoding="utf-8") as f: with open(self.main_file, encoding="utf-8") as f:
first_line = f.readline().strip() first_line = f.readline().strip()
match = re.match(r'^#\s+(?P<title>.+)$', first_line) match = re.match(r"^#\s+(?P<title>.+)$", first_line)
if match: if match:
return match.group('title') return match.group("title")
return fix_case(self.path.stem.replace("_", " ").title()) return fix_case(self.path.stem.replace("_", " ").title())
def fix_relative_links(self, content: str) -> str: def fix_relative_links(self, content: str) -> str:
@ -127,7 +127,7 @@ class Example:
""" """
# Regex to match markdown links [text](relative_path) # Regex to match markdown links [text](relative_path)
# This matches links that don't start with http, https, ftp, or # # This matches links that don't start with http, https, ftp, or #
link_pattern = r'\[([^\]]*)\]\((?!(?:https?|ftp)://|#)([^)]+)\)' link_pattern = r"\[([^\]]*)\]\((?!(?:https?|ftp)://|#)([^)]+)\)"
def replace_link(match): def replace_link(match):
link_text = match.group(1) link_text = match.group(1)
@ -137,7 +137,7 @@ class Example:
gh_file = (self.main_file.parent / relative_path).resolve() gh_file = (self.main_file.parent / relative_path).resolve()
gh_file = gh_file.relative_to(ROOT_DIR) gh_file = gh_file.relative_to(ROOT_DIR)
return f'[{link_text}](gh-file:{gh_file})' return f"[{link_text}](gh-file:{gh_file})"
return re.sub(link_pattern, replace_link, content) return re.sub(link_pattern, replace_link, content)
@ -150,9 +150,11 @@ class Example:
code_fence = "``````" code_fence = "``````"
if self.is_code: if self.is_code:
content += (f"{code_fence}{self.main_file.suffix[1:]}\n" content += (
f'--8<-- "{self.main_file}"\n' f"{code_fence}{self.main_file.suffix[1:]}\n"
f"{code_fence}\n") f'--8<-- "{self.main_file}"\n'
f"{code_fence}\n"
)
else: else:
with open(self.main_file) as f: with open(self.main_file) as f:
# Skip the title from md snippets as it's been included above # Skip the title from md snippets as it's been included above

View File

@ -7,7 +7,7 @@ from typing import Literal
def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
# see https://docs.readthedocs.io/en/stable/reference/environment-variables.html # noqa # see https://docs.readthedocs.io/en/stable/reference/environment-variables.html # noqa
if os.getenv('READTHEDOCS_VERSION_TYPE') == "tag": if os.getenv("READTHEDOCS_VERSION_TYPE") == "tag":
# remove the warning banner if the version is a tagged release # remove the warning banner if the version is a tagged release
mkdocs_dir = Path(__file__).parent.parent mkdocs_dir = Path(__file__).parent.parent
announcement_path = mkdocs_dir / "overrides/main.html" announcement_path = mkdocs_dir / "overrides/main.html"

View File

@ -25,8 +25,9 @@ from mkdocs.structure.files import Files
from mkdocs.structure.pages import Page from mkdocs.structure.pages import Page
def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, def on_page_markdown(
files: Files) -> str: markdown: str, *, page: Page, config: MkDocsConfig, files: Files
) -> str:
""" """
Custom MkDocs plugin hook to rewrite special GitHub reference links Custom MkDocs plugin hook to rewrite special GitHub reference links
in Markdown. in Markdown.
@ -92,11 +93,11 @@ def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig,
Example: Example:
[My issue](gh-issue:123) → [:octicons-mark-github-16: My issue](https://github.com/vllm-project/vllm/issues/123) [My issue](gh-issue:123) → [:octicons-mark-github-16: My issue](https://github.com/vllm-project/vllm/issues/123)
""" """
url = f'{urls[match.group("type")]}/{match.group("path")}' url = f"{urls[match.group('type')]}/{match.group('path')}"
if fragment := match.group("fragment"): if fragment := match.group("fragment"):
url += f"#{fragment}" url += f"#{fragment}"
return f'[{gh_icon} {match.group("title")}]({url})' return f"[{gh_icon} {match.group('title')}]({url})"
def replace_auto_link(match: re.Match) -> str: def replace_auto_link(match: re.Match) -> str:
""" """

View File

@ -1,54 +0,0 @@
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed
[tool.ruff]
line-length = 88
exclude = [
# External file, leaving license intact
"examples/other/fp8/quantizer/quantize.py",
"vllm/vllm_flash_attn/flash_attn_interface.pyi"
]
[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]
[tool.ruff.lint]
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
"I",
# flake8-logging-format
"G",
]
ignore = [
# star imports
"F405", "F403",
# lambda expression assignment
"E731",
# Loop control variable not used within loop body
"B007",
# f-string format
"UP032",
# Can remove once 3.10+ is the minimum Python version
"UP007",
]
[tool.ruff.lint.isort]
known-first-party = ["vllm"]
[tool.ruff.format]
docstring-code-format = true

View File

@ -52,27 +52,106 @@ lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:regi
where = ["."] where = ["."]
include = ["vllm*"] include = ["vllm*"]
[tool.yapfignore]
ignore_patterns = [
".buildkite/**",
"benchmarks/**",
"build/**",
"examples/**",
]
[tool.ruff]
# Allow lines to be as long as 80.
line-length = 80
[tool.ruff.lint.per-file-ignores] [tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"] "vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"] "vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"] "vllm/_version.py" = ["ALL"]
# Python 3.8 typing - skip V0 code # TEMPORARY! These ignores will be fixed forward
"vllm/attention/**/*.py" = ["UP006", "UP035"] ## Line length violations
"vllm/engine/**/*.py" = ["UP006", "UP035"] "csrc/cutlass_extensions/vllm_cutlass_library_extension.py" = ["E501"]
"vllm/executor/**/*.py" = ["UP006", "UP035"] "tests/compile/piecewise/test_simple.py" = ["E501"]
"vllm/worker/**/*.py" = ["UP006", "UP035"] "tests/compile/piecewise/test_toy_llama.py" = ["E501", "B023"]
"tests/entrypoints/conftest.py" = ["E501"]
"tests/entrypoints/openai/test_audio.py" = ["E501"]
"tests/entrypoints/openai/test_chat.py" = ["E501"]
"tests/entrypoints/openai/test_chat_template.py" = ["E501"]
"tests/entrypoints/openai/test_chat_with_tool_reasoning.py" = ["E501"]
"tests/entrypoints/openai/test_completion_with_function_calling.py" = ["E501"]
"tests/entrypoints/openai/test_video.py" = ["E501"]
"tests/entrypoints/openai/test_vision.py" = ["E501"]
"tests/entrypoints/test_chat_utils.py" = ["E501"]
"tests/kernels/moe/modular_kernel_tools/common.py" = ["E501"]
"tests/models/language/generation/test_gemma.py" = ["E501"]
"tests/models/language/generation/test_mistral.py" = ["E501"]
"tests/models/multimodal/generation/test_ultravox.py" = ["E501"]
"tests/models/multimodal/generation/test_voxtral.py" = ["E501"]
"tests/models/multimodal/generation/vlm_utils/custom_inputs.py" = ["E501"]
"tests/tool_use/test_tool_choice_required.py" = ["E501"]
"tests/v1/attention/utils.py" = ["E501"]
"tests/v1/entrypoints/openai/responses/test_image.py" = ["E501"]
"tests/v1/kv_connector/nixl_integration/test_accuracy.py" = ["E501"]
"tests/v1/kv_connector/unit/test_offloading_connector.py" = ["E501"]
"tests/v1/logits_processors/test_custom_offline.py" = ["E501"]
"vllm/attention/ops/pallas_kv_cache_update.py" = ["E501"]
"vllm/compilation/collective_fusion.py" = ["E501"]
"vllm/compilation/wrapper.py" = ["E501"]
"vllm/config/vllm.py" = ["E501"]
"vllm/distributed/device_communicators/all2all.py" = ["E501"]
"vllm/entrypoints/openai/protocol.py" = ["E501"]
"vllm/lora/layers/vocal_parallel_embedding.py" = ["E501"]
"vllm/model_executor/model_loader/bitsandbytes_loader.py" = ["E501"]
"vllm/model_executor/models/bailing_moe.py" = ["E501"]
"vllm/model_executor/models/hyperclovax_vision.py" = ["E501"]
"vllm/model_executor/models/llama4_eagle.py" = ["E501"]
"vllm/model_executor/models/longcat_flash_mtp.py" = ["E501"]
"vllm/model_executor/models/phi4mm.py" = ["E501"]
"vllm/model_executor/models/qwen3_next.py" = ["E501"]
"vllm/model_executor/layers/quantization/ptpc_fp8.py" = ["E501"]
"vllm/v1/attention/backends/mla/common.py" = ["E501"]
"vllm/v1/engine/utils.py" = ["E501"]
"vllm/v1/utils.py" = ["E501"]
"vllm/v1/worker/gpu_model_runner.py" = ["E501"]
## Simplification rules
"tests/distributed/test_expert_placement.py" = ["SIM108"]
"tests/kernels/attention/test_cutlass_mla_decode.py" = ["SIM108"]
"tests/kernels/attention/test_flashmla.py" = ["SIM108"]
"tests/kernels/attention/test_lightning_attn.py" = ["SIM108"]
"tests/kernels/moe/test_pplx_moe.py" = ["SIM108"]
"tests/kernels/quantization/test_cutlass_scaled_mm.py" = ["SIM108"]
"tests/kernels/test_onednn.py" = ["SIM108"]
"tests/kernels/utils.py" = ["SIM108"]
"tests/multimodal/test_processing.py" = ["SIM108"]
"vllm/attention/ops/triton_reshape_and_cache_flash.py" = ["SIM108"]
"vllm/distributed/parallel_state.py" = ["SIM108"]
"vllm/entrypoints/chat_utils.py" = ["SIM108"]
"vllm/entrypoints/llm.py" = ["SIM108"]
"vllm/model_executor/layers/batch_invariant.py" = ["SIM108"]
"vllm/model_executor/layers/fla/ops/chunk_o.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/fused_moe.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/layer.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/modular_kernel.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py" = ["SIM108"]
"vllm/model_executor/layers/layernorm.py" = ["SIM108"]
"vllm/model_executor/layers/lightning_attn.py" = ["SIM108"]
"vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py" = ["SIM103"]
"vllm/model_executor/layers/quantization/compressed_tensors/utils.py" = ["SIM110"]
"vllm/model_executor/layers/quantization/quark/utils.py" = ["SIM110"]
"vllm/utils/__init__.py" = ["SIM108"]
"vllm/v1/sample/ops/bad_words.py" = ["SIM108"]
"vllm/v1/sample/rejection_sampler.py" = ["SIM108"]
"vllm/v1/worker/tpu_model_runner.py" = ["SIM108"]
"vllm/_custom_ops.py" = ["SIM108"]
"tools/profiler/print_layerwise_table.py" = ["SIM118"]
## Loop variable binding issues
"tests/kernels/mamba/test_mamba_ssm_ssd.py" = ["B023"]
## Type annotation modernization and other rules
"vllm/attention/backends/abstract.py" = ["UP035", "UP006"]
"vllm/attention/layer.py" = ["UP035", "UP006"]
"vllm/attention/layers/chunked_local_attention.py" = ["UP035", "UP006"]
"vllm/attention/ops/flashmla.py" = ["UP035", "UP006"]
"vllm/attention/ops/paged_attn.py" = ["UP035", "UP006"]
"vllm/engine/arg_utils.py" = ["UP035", "UP006"]
"vllm/engine/metrics.py" = ["UP035", "UP006"]
"vllm/engine/metrics_types.py" = ["UP035", "UP006"]
"vllm/executor/executor_base.py" = ["UP035", "UP006"]
"vllm/executor/msgspec_utils.py" = ["UP035", "UP006"]
"vllm/executor/ray_distributed_executor.py" = ["UP035", "UP006", "SIM108", "SIM112"]
"vllm/executor/ray_utils.py" = ["UP035", "UP006"]
"vllm/executor/uniproc_executor.py" = ["UP035", "UP006"]
"vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py" = ["UP035"]
## Type comparison issues
"vllm/multimodal/inputs.py" = ["E721"]
# End of temporary ignores
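
The ignore list above is explicitly temporary; each block corresponds to a rule family that existing files still violate (E501: lines over the configured length, SIM1xx: simplifiable conditionals and loops, B023: late binding of loop variables in closures, UP006/UP035: pre-PEP 585 typing imports, E721: type comparison with `==`). A hedged sketch, using made-up code rather than code from this diff, of what the most common of these rules flag:

```python
# Hedged examples of the temporarily ignored rule families above (invented code).
from typing import List  # UP035/UP006: prefer the builtin `list` over typing.List


def pick_dtype(use_fp16: bool) -> str:
    # SIM108: ruff suggests the ternary form
    # `dtype = "fp16" if use_fp16 else "fp32"`.
    if use_fp16:
        dtype = "fp16"
    else:
        dtype = "fp32"
    return dtype


def make_callbacks(names: List[str]):
    callbacks = []
    for name in names:
        # B023: the lambda captures `name` late, so every callback
        # ends up seeing the final loop value.
        callbacks.append(lambda: print(name))
    return callbacks
```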
[tool.ruff.lint] [tool.ruff.lint]
select = [ select = [
@ -87,7 +166,7 @@ select = [
# flake8-simplify # flake8-simplify
"SIM", "SIM",
# isort # isort
# "I", "I",
# flake8-logging-format # flake8-logging-format
"G", "G",
] ]
@ -104,21 +183,15 @@ ignore = [
"UP007", "UP007",
] ]
[tool.ruff.format]
docstring-code-format = true
[tool.mypy] [tool.mypy]
plugins = ['pydantic.mypy'] plugins = ['pydantic.mypy']
ignore_missing_imports = true ignore_missing_imports = true
check_untyped_defs = true check_untyped_defs = true
follow_imports = "silent" follow_imports = "silent"
[tool.isort]
skip_glob = [
".buildkite/*",
"benchmarks/*",
"examples/*",
]
use_parentheses = true
skip_gitignore = true
[tool.pytest.ini_options] [tool.pytest.ini_options]
markers = [ markers = [
"slow_test", "slow_test",

setup.py (255 changed lines)
View File

@ -34,32 +34,36 @@ logger = logging.getLogger(__name__)
# cannot import envs directly because it depends on vllm, # cannot import envs directly because it depends on vllm,
# which is not installed yet # which is not installed yet
envs = load_module_from_path('envs', os.path.join(ROOT_DIR, 'vllm', 'envs.py')) envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "vllm", "envs.py"))
VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE
if sys.platform.startswith("darwin") and VLLM_TARGET_DEVICE != "cpu": if sys.platform.startswith("darwin") and VLLM_TARGET_DEVICE != "cpu":
logger.warning( logger.warning("VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS")
"VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS")
VLLM_TARGET_DEVICE = "cpu" VLLM_TARGET_DEVICE = "cpu"
elif not (sys.platform.startswith("linux") elif not (sys.platform.startswith("linux") or sys.platform.startswith("darwin")):
or sys.platform.startswith("darwin")):
logger.warning( logger.warning(
"vLLM only supports Linux platform (including WSL) and MacOS." "vLLM only supports Linux platform (including WSL) and MacOS."
"Building on %s, " "Building on %s, "
"so vLLM may not be able to run correctly", sys.platform) "so vLLM may not be able to run correctly",
sys.platform,
)
VLLM_TARGET_DEVICE = "empty" VLLM_TARGET_DEVICE = "empty"
elif (sys.platform.startswith("linux") and torch.version.cuda is None elif (
and os.getenv("VLLM_TARGET_DEVICE") is None sys.platform.startswith("linux")
and torch.version.hip is None): and torch.version.cuda is None
and os.getenv("VLLM_TARGET_DEVICE") is None
and torch.version.hip is None
):
# if cuda or hip is not available and VLLM_TARGET_DEVICE is not set, # if cuda or hip is not available and VLLM_TARGET_DEVICE is not set,
# fallback to cpu # fallback to cpu
VLLM_TARGET_DEVICE = "cpu" VLLM_TARGET_DEVICE = "cpu"
def is_sccache_available() -> bool: def is_sccache_available() -> bool:
return which("sccache") is not None and \ return which("sccache") is not None and not bool(
not bool(int(os.getenv("VLLM_DISABLE_SCCACHE", "0"))) int(os.getenv("VLLM_DISABLE_SCCACHE", "0"))
)
def is_ccache_available() -> bool: def is_ccache_available() -> bool:
@ -83,8 +87,7 @@ def is_url_available(url: str) -> bool:
class CMakeExtension(Extension): class CMakeExtension(Extension):
def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None:
def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:
super().__init__(name, sources=[], py_limited_api=True, **kwa) super().__init__(name, sources=[], py_limited_api=True, **kwa)
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir) self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
@ -121,8 +124,8 @@ class cmake_build_ext(build_ext):
if nvcc_threads is not None: if nvcc_threads is not None:
nvcc_threads = int(nvcc_threads) nvcc_threads = int(nvcc_threads)
logger.info( logger.info(
"Using NVCC_THREADS=%d as the number of nvcc threads.", "Using NVCC_THREADS=%d as the number of nvcc threads.", nvcc_threads
nvcc_threads) )
else: else:
nvcc_threads = 1 nvcc_threads = 1
num_jobs = max(1, num_jobs // nvcc_threads) num_jobs = max(1, num_jobs // nvcc_threads)
@ -146,36 +149,36 @@ class cmake_build_ext(build_ext):
cfg = envs.CMAKE_BUILD_TYPE or default_cfg cfg = envs.CMAKE_BUILD_TYPE or default_cfg
cmake_args = [ cmake_args = [
'-DCMAKE_BUILD_TYPE={}'.format(cfg), "-DCMAKE_BUILD_TYPE={}".format(cfg),
'-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), "-DVLLM_TARGET_DEVICE={}".format(VLLM_TARGET_DEVICE),
] ]
verbose = envs.VERBOSE verbose = envs.VERBOSE
if verbose: if verbose:
cmake_args += ['-DCMAKE_VERBOSE_MAKEFILE=ON'] cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"]
if is_sccache_available(): if is_sccache_available():
cmake_args += [ cmake_args += [
'-DCMAKE_C_COMPILER_LAUNCHER=sccache', "-DCMAKE_C_COMPILER_LAUNCHER=sccache",
'-DCMAKE_CXX_COMPILER_LAUNCHER=sccache', "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache",
'-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache', "-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache",
'-DCMAKE_HIP_COMPILER_LAUNCHER=sccache', "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache",
] ]
elif is_ccache_available(): elif is_ccache_available():
cmake_args += [ cmake_args += [
'-DCMAKE_C_COMPILER_LAUNCHER=ccache', "-DCMAKE_C_COMPILER_LAUNCHER=ccache",
'-DCMAKE_CXX_COMPILER_LAUNCHER=ccache', "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
'-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache', "-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache",
'-DCMAKE_HIP_COMPILER_LAUNCHER=ccache', "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache",
] ]
# Pass the python executable to cmake so it can find an exact # Pass the python executable to cmake so it can find an exact
# match. # match.
cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)] cmake_args += ["-DVLLM_PYTHON_EXECUTABLE={}".format(sys.executable)]
# Pass the python path to cmake so it can reuse the build dependencies # Pass the python path to cmake so it can reuse the build dependencies
# on subsequent calls to python. # on subsequent calls to python.
cmake_args += ['-DVLLM_PYTHON_PATH={}'.format(":".join(sys.path))] cmake_args += ["-DVLLM_PYTHON_PATH={}".format(":".join(sys.path))]
# Override the base directory for FetchContent downloads to $ROOT/.deps # Override the base directory for FetchContent downloads to $ROOT/.deps
# This allows sharing dependencies between profiles, # This allows sharing dependencies between profiles,
@ -183,7 +186,7 @@ class cmake_build_ext(build_ext):
# To override this, set the FETCHCONTENT_BASE_DIR environment variable. # To override this, set the FETCHCONTENT_BASE_DIR environment variable.
fc_base_dir = os.path.join(ROOT_DIR, ".deps") fc_base_dir = os.path.join(ROOT_DIR, ".deps")
fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir) fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir)
cmake_args += ['-DFETCHCONTENT_BASE_DIR={}'.format(fc_base_dir)] cmake_args += ["-DFETCHCONTENT_BASE_DIR={}".format(fc_base_dir)]
# #
# Setup parallelism and build tool # Setup parallelism and build tool
@ -191,35 +194,36 @@ class cmake_build_ext(build_ext):
num_jobs, nvcc_threads = self.compute_num_jobs() num_jobs, nvcc_threads = self.compute_num_jobs()
if nvcc_threads: if nvcc_threads:
cmake_args += ['-DNVCC_THREADS={}'.format(nvcc_threads)] cmake_args += ["-DNVCC_THREADS={}".format(nvcc_threads)]
if is_ninja_available(): if is_ninja_available():
build_tool = ['-G', 'Ninja'] build_tool = ["-G", "Ninja"]
cmake_args += [ cmake_args += [
'-DCMAKE_JOB_POOL_COMPILE:STRING=compile', "-DCMAKE_JOB_POOL_COMPILE:STRING=compile",
'-DCMAKE_JOB_POOLS:STRING=compile={}'.format(num_jobs), "-DCMAKE_JOB_POOLS:STRING=compile={}".format(num_jobs),
] ]
else: else:
# Default build tool to whatever cmake picks. # Default build tool to whatever cmake picks.
build_tool = [] build_tool = []
# Make sure we use the nvcc from CUDA_HOME # Make sure we use the nvcc from CUDA_HOME
if _is_cuda(): if _is_cuda():
cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc'] cmake_args += [f"-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc"]
other_cmake_args = os.environ.get("CMAKE_ARGS") other_cmake_args = os.environ.get("CMAKE_ARGS")
if other_cmake_args: if other_cmake_args:
cmake_args += other_cmake_args.split() cmake_args += other_cmake_args.split()
subprocess.check_call( subprocess.check_call(
['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args], ["cmake", ext.cmake_lists_dir, *build_tool, *cmake_args],
cwd=self.build_temp) cwd=self.build_temp,
)
def build_extensions(self) -> None: def build_extensions(self) -> None:
# Ensure that CMake is present and working # Ensure that CMake is present and working
try: try:
subprocess.check_output(['cmake', '--version']) subprocess.check_output(["cmake", "--version"])
except OSError as e: except OSError as e:
raise RuntimeError('Cannot find CMake executable') from e raise RuntimeError("Cannot find CMake executable") from e
# Create build directory if it does not exist. # Create build directory if it does not exist.
if not os.path.exists(self.build_temp): if not os.path.exists(self.build_temp):
@ -258,13 +262,18 @@ class cmake_build_ext(build_ext):
# CMake appends the extension prefix to the install path, # CMake appends the extension prefix to the install path,
# and outdir already contains that prefix, so we need to remove it. # and outdir already contains that prefix, so we need to remove it.
prefix = outdir prefix = outdir
for _ in range(ext.name.count('.')): for _ in range(ext.name.count(".")):
prefix = prefix.parent prefix = prefix.parent
# prefix here should actually be the same for all components # prefix here should actually be the same for all components
install_args = [ install_args = [
"cmake", "--install", ".", "--prefix", prefix, "--component", "cmake",
target_name(ext.name) "--install",
".",
"--prefix",
prefix,
"--component",
target_name(ext.name),
] ]
subprocess.check_call(install_args, cwd=self.build_temp) subprocess.check_call(install_args, cwd=self.build_temp)
@ -275,12 +284,15 @@ class cmake_build_ext(build_ext):
# copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current # copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current
# directory so that they can be included in the editable build # directory so that they can be included in the editable build
import glob import glob
files = glob.glob(os.path.join(self.build_lib, "vllm",
"vllm_flash_attn", "**", "*.py"), files = glob.glob(
recursive=True) os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "**", "*.py"),
recursive=True,
)
for file in files: for file in files:
dst_file = os.path.join("vllm/vllm_flash_attn", dst_file = os.path.join(
file.split("vllm/vllm_flash_attn/")[-1]) "vllm/vllm_flash_attn", file.split("vllm/vllm_flash_attn/")[-1]
)
print(f"Copying {file} to {dst_file}") print(f"Copying {file} to {dst_file}")
os.makedirs(os.path.dirname(dst_file), exist_ok=True) os.makedirs(os.path.dirname(dst_file), exist_ok=True)
self.copy_file(file, dst_file) self.copy_file(file, dst_file)
@ -290,8 +302,7 @@ class precompiled_build_ext(build_ext):
"""Disables extension building when using precompiled binaries.""" """Disables extension building when using precompiled binaries."""
def run(self) -> None: def run(self) -> None:
assert _is_cuda( assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
def build_extensions(self) -> None: def build_extensions(self) -> None:
print("Skipping build_ext: using precompiled extensions.") print("Skipping build_ext: using precompiled extensions.")
@ -312,9 +323,9 @@ class precompiled_wheel_utils:
wheel_filename = wheel_url_or_path.split("/")[-1] wheel_filename = wheel_url_or_path.split("/")[-1]
temp_dir = tempfile.mkdtemp(prefix="vllm-wheels") temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
wheel_path = os.path.join(temp_dir, wheel_filename) wheel_path = os.path.join(temp_dir, wheel_filename)
print(f"Downloading wheel from {wheel_url_or_path} " print(f"Downloading wheel from {wheel_url_or_path} to {wheel_path}")
f"to {wheel_path}")
from urllib.request import urlretrieve from urllib.request import urlretrieve
urlretrieve(wheel_url_or_path, filename=wheel_path) urlretrieve(wheel_url_or_path, filename=wheel_path)
else: else:
wheel_path = wheel_url_or_path wheel_path = wheel_url_or_path
@ -335,25 +346,29 @@ class precompiled_wheel_utils:
] ]
compiled_regex = re.compile( compiled_regex = re.compile(
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py") r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
)
file_members = list( file_members = list(
filter(lambda x: x.filename in files_to_copy, filter(lambda x: x.filename in files_to_copy, wheel.filelist)
wheel.filelist)) )
file_members += list( file_members += list(
filter(lambda x: compiled_regex.match(x.filename), filter(lambda x: compiled_regex.match(x.filename), wheel.filelist)
wheel.filelist)) )
for file in file_members: for file in file_members:
print(f"[extract] {file.filename}") print(f"[extract] {file.filename}")
target_path = os.path.join(".", file.filename) target_path = os.path.join(".", file.filename)
os.makedirs(os.path.dirname(target_path), exist_ok=True) os.makedirs(os.path.dirname(target_path), exist_ok=True)
with wheel.open(file.filename) as src, open( with (
target_path, "wb") as dst: wheel.open(file.filename) as src,
open(target_path, "wb") as dst,
):
shutil.copyfileobj(src, dst) shutil.copyfileobj(src, dst)
pkg = os.path.dirname(file.filename).replace("/", ".") pkg = os.path.dirname(file.filename).replace("/", ".")
package_data_patch.setdefault(pkg, []).append( package_data_patch.setdefault(pkg, []).append(
os.path.basename(file.filename)) os.path.basename(file.filename)
)
return package_data_patch return package_data_patch
finally: finally:
@ -369,10 +384,13 @@ class precompiled_wheel_utils:
try: try:
# Get the latest commit hash of the upstream main branch. # Get the latest commit hash of the upstream main branch.
resp_json = subprocess.check_output([ resp_json = subprocess.check_output(
"curl", "-s", [
"https://api.github.com/repos/vllm-project/vllm/commits/main" "curl",
]).decode("utf-8") "-s",
"https://api.github.com/repos/vllm-project/vllm/commits/main",
]
).decode("utf-8")
upstream_main_commit = json.loads(resp_json)["sha"] upstream_main_commit = json.loads(resp_json)["sha"]
# In Docker build context, .git may be immutable or missing. # In Docker build context, .git may be immutable or missing.
@ -382,25 +400,32 @@ class precompiled_wheel_utils:
# Check if the upstream_main_commit exists in the local repo # Check if the upstream_main_commit exists in the local repo
try: try:
subprocess.check_output( subprocess.check_output(
["git", "cat-file", "-e", f"{upstream_main_commit}"]) ["git", "cat-file", "-e", f"{upstream_main_commit}"]
)
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
# If not present, fetch it from the remote repository. # If not present, fetch it from the remote repository.
# Note that this does not update any local branches, # Note that this does not update any local branches,
# but ensures that this commit ref and its history are # but ensures that this commit ref and its history are
# available in our local repo. # available in our local repo.
subprocess.check_call([ subprocess.check_call(
"git", "fetch", "https://github.com/vllm-project/vllm", ["git", "fetch", "https://github.com/vllm-project/vllm", "main"]
"main" )
])
# Then get the commit hash of the current branch that is the same as # Then get the commit hash of the current branch that is the same as
# the upstream main commit. # the upstream main commit.
current_branch = subprocess.check_output( current_branch = (
["git", "branch", "--show-current"]).decode("utf-8").strip() subprocess.check_output(["git", "branch", "--show-current"])
.decode("utf-8")
.strip()
)
base_commit = subprocess.check_output([ base_commit = (
"git", "merge-base", f"{upstream_main_commit}", current_branch subprocess.check_output(
]).decode("utf-8").strip() ["git", "merge-base", f"{upstream_main_commit}", current_branch]
)
.decode("utf-8")
.strip()
)
return base_commit return base_commit
except ValueError as err: except ValueError as err:
raise ValueError(err) from None raise ValueError(err) from None
@ -408,7 +433,9 @@ class precompiled_wheel_utils:
logger.warning( logger.warning(
"Failed to get the base commit in the main branch. " "Failed to get the base commit in the main branch. "
"Using the nightly wheel. The libraries in this " "Using the nightly wheel. The libraries in this "
"wheel may not be compatible with your dev branch: %s", err) "wheel may not be compatible with your dev branch: %s",
err,
)
return "nightly" return "nightly"
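
Pieced together from the hunk above, the wheel-selection logic is a GitHub API call plus two git lookups; the merge base with upstream main keys which precompiled wheel to reuse. A condensed sketch of that flow, mirroring the commands shown (requires network access and a git checkout):

    import json
    import subprocess

    def base_commit_with_upstream_main() -> str:
        """Rough sketch of the merge-base lookup used for precompiled wheels."""
        # Latest commit on upstream main, via the GitHub API.
        resp = subprocess.check_output(
            ["curl", "-s", "https://api.github.com/repos/vllm-project/vllm/commits/main"]
        ).decode("utf-8")
        upstream_main = json.loads(resp)["sha"]

        # Make sure that commit and its history exist locally.
        try:
            subprocess.check_output(["git", "cat-file", "-e", upstream_main])
        except subprocess.CalledProcessError:
            subprocess.check_call(
                ["git", "fetch", "https://github.com/vllm-project/vllm", "main"]
            )

        # The wheel is keyed by the merge base of the current branch and main.
        branch = (
            subprocess.check_output(["git", "branch", "--show-current"])
            .decode("utf-8")
            .strip()
        )
        return (
            subprocess.check_output(["git", "merge-base", upstream_main, branch])
            .decode("utf-8")
            .strip()
        )
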
@ -418,12 +445,13 @@ def _no_device() -> bool:
def _is_cuda() -> bool: def _is_cuda() -> bool:
has_cuda = torch.version.cuda is not None has_cuda = torch.version.cuda is not None
return (VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu()) return VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu()
def _is_hip() -> bool: def _is_hip() -> bool:
return (VLLM_TARGET_DEVICE == "cuda" return (
or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None VLLM_TARGET_DEVICE == "cuda" or VLLM_TARGET_DEVICE == "rocm"
) and torch.version.hip is not None
def _is_tpu() -> bool: def _is_tpu() -> bool:
@ -462,8 +490,12 @@ def get_rocm_version():
minor = ctypes.c_uint32() minor = ctypes.c_uint32()
patch = ctypes.c_uint32() patch = ctypes.c_uint32()
if (get_rocm_core_version(ctypes.byref(major), ctypes.byref(minor), if (
ctypes.byref(patch)) == 0): get_rocm_core_version(
ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch)
)
== 0
):
return f"{major.value}.{minor.value}.{patch.value}" return f"{major.value}.{minor.value}.{patch.value}"
return None return None
except Exception: except Exception:
@ -476,8 +508,9 @@ def get_nvcc_cuda_version() -> Version:
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
""" """
assert CUDA_HOME is not None, "CUDA_HOME is not set" assert CUDA_HOME is not None, "CUDA_HOME is not set"
nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], nvcc_output = subprocess.check_output(
universal_newlines=True) [CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True
)
output = nvcc_output.split() output = nvcc_output.split()
release_idx = output.index("release") + 1 release_idx = output.index("release") + 1
nvcc_cuda_version = parse(output[release_idx].split(",")[0]) nvcc_cuda_version = parse(output[release_idx].split(",")[0])
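
The nvcc probe above keys on the literal "release X.Y," token in the `nvcc -V` banner. A quick illustration of the parsing on a representative output string (the version shown is made up):

    from packaging.version import parse

    # Example nvcc -V banner; real output varies by toolkit, but the
    # "release X.Y," token is what the parsing keys on.
    sample = (
        "nvcc: NVIDIA (R) Cuda compiler driver\n"
        "Cuda compilation tools, release 12.4, V12.4.131"
    )

    tokens = sample.split()
    release_idx = tokens.index("release") + 1
    print(parse(tokens[release_idx].split(",")[0]))  # -> 12.4
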
@ -489,14 +522,20 @@ def get_gaudi_sw_version():
Returns the driver version. Returns the driver version.
""" """
# Enable console printing for `hl-smi` check # Enable console printing for `hl-smi` check
output = subprocess.run("hl-smi", output = subprocess.run(
shell=True, "hl-smi",
text=True, shell=True,
capture_output=True, text=True,
env={"ENABLE_CONSOLE": "true"}) capture_output=True,
env={"ENABLE_CONSOLE": "true"},
)
if output.returncode == 0 and output.stdout: if output.returncode == 0 and output.stdout:
return output.stdout.split("\n")[2].replace( return (
" ", "").split(":")[1][:-1].split("-")[0] output.stdout.split("\n")[2]
.replace(" ", "")
.split(":")[1][:-1]
.split("-")[0]
)
return "0.0.0" # when hl-smi is not available return "0.0.0" # when hl-smi is not available
@ -546,8 +585,11 @@ def get_requirements() -> list[str]:
for line in requirements: for line in requirements:
if line.startswith("-r "): if line.startswith("-r "):
resolved_requirements += _read_requirements(line.split()[1]) resolved_requirements += _read_requirements(line.split()[1])
elif not line.startswith("--") and not line.startswith( elif (
"#") and line.strip() != "": not line.startswith("--")
and not line.startswith("#")
and line.strip() != ""
):
resolved_requirements.append(line) resolved_requirements.append(line)
return resolved_requirements return resolved_requirements
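
The loop above flattens requirement files while following "-r" includes recursively and dropping flags, comments, and blank lines. A standalone sketch of the same resolution, assuming includes are relative to a single directory (the file layout here is hypothetical):

    from pathlib import Path

    def resolve_requirements(name: str, root: Path = Path("requirements")) -> list[str]:
        """Flatten a requirements file, following `-r other.txt` includes."""
        resolved: list[str] = []
        for line in (root / name).read_text().splitlines():
            line = line.strip()
            if line.startswith("-r "):
                resolved += resolve_requirements(line.split()[1], root)
            elif line and not line.startswith(("--", "#")):
                resolved.append(line)
        return resolved
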
@ -558,7 +600,7 @@ def get_requirements() -> list[str]:
cuda_major, cuda_minor = torch.version.cuda.split(".") cuda_major, cuda_minor = torch.version.cuda.split(".")
modified_requirements = [] modified_requirements = []
for req in requirements: for req in requirements:
if ("vllm-flash-attn" in req and cuda_major != "12"): if "vllm-flash-attn" in req and cuda_major != "12":
# vllm-flash-attn is built only for CUDA 12.x. # vllm-flash-attn is built only for CUDA 12.x.
# Skip for other versions. # Skip for other versions.
continue continue
@ -573,8 +615,7 @@ def get_requirements() -> list[str]:
elif _is_xpu(): elif _is_xpu():
requirements = _read_requirements("xpu.txt") requirements = _read_requirements("xpu.txt")
else: else:
raise ValueError( raise ValueError("Unsupported platform, please use CUDA, ROCm, or CPU.")
"Unsupported platform, please use CUDA, ROCm, or CPU.")
return requirements return requirements
@ -590,14 +631,13 @@ if _is_cuda():
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C")) ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"): if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
# FA3 requires CUDA 12.3 or later # FA3 requires CUDA 12.3 or later
ext_modules.append( ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
# Optional since this doesn't get built (produce an .so file) when # Optional since this doesn't get built (produce an .so file) when
# not targeting a hopper system # not targeting a hopper system
ext_modules.append(CMakeExtension(name="vllm._flashmla_C", optional=True))
ext_modules.append( ext_modules.append(
CMakeExtension(name="vllm._flashmla_C", optional=True)) CMakeExtension(name="vllm._flashmla_extension_C", optional=True)
ext_modules.append( )
CMakeExtension(name="vllm._flashmla_extension_C", optional=True))
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
if _build_custom_ops(): if _build_custom_ops():
@ -619,6 +659,7 @@ if envs.VLLM_USE_PRECOMPILED:
wheel_url = wheel_location wheel_url = wheel_location
else: else:
import platform import platform
arch = platform.machine() arch = platform.machine()
if arch == "x86_64": if arch == "x86_64":
wheel_tag = "manylinux1_x86_64" wheel_tag = "manylinux1_x86_64"
@ -628,8 +669,11 @@ if envs.VLLM_USE_PRECOMPILED:
raise ValueError(f"Unsupported architecture: {arch}") raise ValueError(f"Unsupported architecture: {arch}")
base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch() base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
nightly_wheel_url = f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" nightly_wheel_url = (
f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
)
from urllib.request import urlopen from urllib.request import urlopen
try: try:
with urlopen(wheel_url) as resp: with urlopen(wheel_url) as resp:
if resp.status != 200: if resp.status != 200:
@ -638,8 +682,7 @@ if envs.VLLM_USE_PRECOMPILED:
print(f"[warn] Falling back to nightly wheel: {e}") print(f"[warn] Falling back to nightly wheel: {e}")
wheel_url = nightly_wheel_url wheel_url = nightly_wheel_url
patch = precompiled_wheel_utils.extract_precompiled_and_patch_package( patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(wheel_url)
wheel_url)
for pkg, files in patch.items(): for pkg, files in patch.items():
package_data.setdefault(pkg, []).extend(files) package_data.setdefault(pkg, []).extend(files)
@ -650,8 +693,9 @@ if not ext_modules:
cmdclass = {} cmdclass = {}
else: else:
cmdclass = { cmdclass = {
"build_ext": "build_ext": precompiled_build_ext
precompiled_build_ext if envs.VLLM_USE_PRECOMPILED else cmake_build_ext if envs.VLLM_USE_PRECOMPILED
else cmake_build_ext
} }
setup( setup(
@ -664,8 +708,11 @@ setup(
"tensorizer": ["tensorizer==2.10.1"], "tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.1.10"], "fastsafetensors": ["fastsafetensors >= 0.1.10"],
"runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"], "runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"],
"audio": ["librosa", "soundfile", "audio": [
"mistral_common[audio]"], # Required for audio processing "librosa",
"soundfile",
"mistral_common[audio]",
], # Required for audio processing
"video": [], # Kept for backwards compatibility "video": [], # Kept for backwards compatibility
# FlashInfer should be updated together with the Dockerfile # FlashInfer should be updated together with the Dockerfile
"flashinfer": ["flashinfer-python==0.3.1"], "flashinfer": ["flashinfer-python==0.3.1"],


@ -4,6 +4,7 @@
Run `pytest tests/basic_correctness/test_basic_correctness.py`. Run `pytest tests/basic_correctness/test_basic_correctness.py`.
""" """
import os import os
import weakref import weakref
from unittest.mock import Mock from unittest.mock import Mock
@ -37,16 +38,21 @@ def test_vllm_gc_ed():
def _fix_prompt_embed_outputs( def _fix_prompt_embed_outputs(
vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner, vllm_outputs: list[tuple[list[int], str]],
example_prompts: list[str]) -> list[tuple[list[int], str]]: hf_model: HfRunner,
example_prompts: list[str],
) -> list[tuple[list[int], str]]:
fixed_vllm_outputs = [] fixed_vllm_outputs = []
for vllm_output, hf_input, prompt in zip( for vllm_output, hf_input, prompt in zip(
vllm_outputs, hf_model.get_inputs(example_prompts), vllm_outputs, hf_model.get_inputs(example_prompts), example_prompts
example_prompts): ):
hf_input_ids = hf_input["input_ids"].tolist()[0] hf_input_ids = hf_input["input_ids"].tolist()[0]
fixed_vllm_outputs.append( fixed_vllm_outputs.append(
(hf_input_ids + vllm_output[0][len(hf_input_ids):], (
prompt + vllm_output[1])) hf_input_ids + vllm_output[0][len(hf_input_ids) :],
prompt + vllm_output[1],
)
)
return fixed_vllm_outputs return fixed_vllm_outputs
@ -69,8 +75,7 @@ def test_models(
enable_prompt_embeds: bool, enable_prompt_embeds: bool,
) -> None: ) -> None:
if backend == "XFORMERS" and model == "google/gemma-2-2b-it": if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
pytest.skip( pytest.skip(f"{backend} does not support gemma2 with full context length.")
f"{backend} does not support gemma2 with full context length.")
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", backend) m.setenv("VLLM_ATTENTION_BACKEND", backend)
@ -78,34 +83,35 @@ def test_models(
# 5042 tokens for gemma2 # 5042 tokens for gemma2
# gemma2 has alternating sliding window size of 4096 # gemma2 has alternating sliding window size of 4096
# we need a prompt with more than 4096 tokens to test the sliding window # we need a prompt with more than 4096 tokens to test the sliding window
prompt = "The following numbers of the sequence " + ", ".join( prompt = (
str(i) for i in range(1024)) + " are:" "The following numbers of the sequence "
+ ", ".join(str(i) for i in range(1024))
+ " are:"
)
example_prompts = [prompt] example_prompts = [prompt]
with hf_runner(model) as hf_model: with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
if enable_prompt_embeds: if enable_prompt_embeds:
with torch.no_grad(): with torch.no_grad():
prompt_embeds = hf_model.get_prompt_embeddings( prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
example_prompts)
with VllmRunner( with VllmRunner(
model, model,
max_model_len=8192, max_model_len=8192,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
enable_prompt_embeds=enable_prompt_embeds, enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
async_scheduling=async_scheduling, async_scheduling=async_scheduling,
distributed_executor_backend=model_executor, distributed_executor_backend=model_executor,
) as vllm_model: ) as vllm_model:
if enable_prompt_embeds: if enable_prompt_embeds:
vllm_outputs = vllm_model.generate_greedy( vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs( vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts) vllm_outputs, hf_model, example_prompts
)
else: else:
vllm_outputs = vllm_model.generate_greedy( vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
example_prompts, max_tokens)
check_outputs_equal( check_outputs_equal(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
@ -117,21 +123,18 @@ def test_models(
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model, distributed_executor_backend, attention_backend, " "model, distributed_executor_backend, attention_backend, test_suite, extra_env",
"test_suite, extra_env", [ [
("distilbert/distilgpt2", "ray", "", "L4", {}), ("distilbert/distilgpt2", "ray", "", "L4", {}),
("distilbert/distilgpt2", "mp", "", "L4", {}), ("distilbert/distilgpt2", "mp", "", "L4", {}),
("distilbert/distilgpt2", "ray", "", "L4", { ("distilbert/distilgpt2", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
"VLLM_SLEEP_WHEN_IDLE": "1" ("distilbert/distilgpt2", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
}),
("distilbert/distilgpt2", "mp", "", "L4", {
"VLLM_SLEEP_WHEN_IDLE": "1"
}),
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
("distilbert/distilgpt2", "ray", "", "A100", {}), ("distilbert/distilgpt2", "ray", "", "A100", {}),
("distilbert/distilgpt2", "mp", "", "A100", {}), ("distilbert/distilgpt2", "mp", "", "A100", {}),
]) ],
)
@pytest.mark.parametrize("enable_prompt_embeds", [True, False]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models_distributed( def test_models_distributed(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
@ -149,11 +152,14 @@ def test_models_distributed(
pytest.skip(f"Skip test for {test_suite}") pytest.skip(f"Skip test for {test_suite}")
with monkeypatch.context() as monkeypatch_context: with monkeypatch.context() as monkeypatch_context:
if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa if (
model == "meta-llama/Llama-3.2-1B-Instruct"
and distributed_executor_backend == "ray"
and attention_backend == ""
and test_suite == "L4"
): # noqa
if enable_prompt_embeds: if enable_prompt_embeds:
pytest.skip( pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
"enable_prompt_embeds does not work with ray compiled dag."
)
monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1") monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
@ -175,30 +181,26 @@ def test_models_distributed(
# will hurt multiprocessing backend with fork method # will hurt multiprocessing backend with fork method
# (the default method). # (the default method).
with vllm_runner( with vllm_runner(
model, model,
dtype=dtype, dtype=dtype,
tensor_parallel_size=2, tensor_parallel_size=2,
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
enable_prompt_embeds=enable_prompt_embeds, enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
) as vllm_model: ) as vllm_model:
if enable_prompt_embeds: if enable_prompt_embeds:
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
with torch.no_grad(): with torch.no_grad():
prompt_embeds = hf_model.get_prompt_embeddings( prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
example_prompts) vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs( vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts) vllm_outputs, hf_model, example_prompts
hf_outputs = hf_model.generate_greedy( )
example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
else: else:
vllm_outputs = vllm_model.generate_greedy( vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
example_prompts, max_tokens)
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy( hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
example_prompts, max_tokens)
check_outputs_equal( check_outputs_equal(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
@ -209,27 +211,23 @@ def test_models_distributed(
def test_failed_model_execution(vllm_runner, monkeypatch) -> None: def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
from vllm.envs import VLLM_USE_V1 from vllm.envs import VLLM_USE_V1
if not VLLM_USE_V1: if not VLLM_USE_V1:
pytest.skip("Skipping V0 test, dump input not supported") pytest.skip("Skipping V0 test, dump input not supported")
# Needed to mock an error in the same process # Needed to mock an error in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model: with vllm_runner("facebook/opt-125m", enforce_eager=True) as vllm_model:
if isinstance(vllm_model.llm.llm_engine, LLMEngineV1): if isinstance(vllm_model.llm.llm_engine, LLMEngineV1):
v1_test_failed_model_execution(vllm_model) v1_test_failed_model_execution(vllm_model)
def v1_test_failed_model_execution(vllm_model): def v1_test_failed_model_execution(vllm_model):
engine = vllm_model.llm.llm_engine engine = vllm_model.llm.llm_engine
mocked_execute_model = Mock( mocked_execute_model = Mock(side_effect=RuntimeError("Mocked Critical Error"))
side_effect=RuntimeError("Mocked Critical Error")) engine.engine_core.engine_core.model_executor.execute_model = mocked_execute_model
engine.engine_core.engine_core.model_executor.execute_model =\
mocked_execute_model
with pytest.raises(RuntimeError) as exc_info: with pytest.raises(RuntimeError) as exc_info:
prompts = [ prompts = [


@ -5,5 +5,6 @@ from ..utils import compare_two_settings
def test_cpu_offload(): def test_cpu_offload():
compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", [], compare_two_settings(
["--cpu-offload-gb", "1"]) "meta-llama/Llama-3.2-1B-Instruct", [], ["--cpu-offload-gb", "1"]
)


@ -23,13 +23,13 @@ def test_python_error():
tensors = [] tensors = []
with allocator.use_memory_pool(): with allocator.use_memory_pool():
# allocate 70% of the total memory # allocate 70% of the total memory
x = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda') x = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
tensors.append(x) tensors.append(x)
# release the memory # release the memory
allocator.sleep() allocator.sleep()
# allocate more memory than the total memory # allocate more memory than the total memory
y = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda') y = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
tensors.append(y) tensors.append(y)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
# when the allocator is woken up, it should raise an error # when the allocator is woken up, it should raise an error
@ -41,17 +41,17 @@ def test_python_error():
def test_basic_cumem(): def test_basic_cumem():
# some tensors from default memory pool # some tensors from default memory pool
shape = (1024, 1024) shape = (1024, 1024)
x = torch.empty(shape, device='cuda') x = torch.empty(shape, device="cuda")
x.zero_() x.zero_()
# some tensors from custom memory pool # some tensors from custom memory pool
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
with allocator.use_memory_pool(): with allocator.use_memory_pool():
# custom memory pool # custom memory pool
y = torch.empty(shape, device='cuda') y = torch.empty(shape, device="cuda")
y.zero_() y.zero_()
y += 1 y += 1
z = torch.empty(shape, device='cuda') z = torch.empty(shape, device="cuda")
z.zero_() z.zero_()
z += 2 z += 2
@ -74,16 +74,16 @@ def test_basic_cumem():
def test_cumem_with_cudagraph(): def test_cumem_with_cudagraph():
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
with allocator.use_memory_pool(): with allocator.use_memory_pool():
weight = torch.eye(1024, device='cuda') weight = torch.eye(1024, device="cuda")
with allocator.use_memory_pool(tag="discard"): with allocator.use_memory_pool(tag="discard"):
cache = torch.empty(1024, 1024, device='cuda') cache = torch.empty(1024, 1024, device="cuda")
def model(x): def model(x):
out = x @ weight out = x @ weight
cache[:out.size(0)].copy_(out) cache[: out.size(0)].copy_(out)
return out + 1 return out + 1
x = torch.empty(128, 1024, device='cuda') x = torch.empty(128, 1024, device="cuda")
# warmup # warmup
model(x) model(x)
@ -109,7 +109,7 @@ def test_cumem_with_cudagraph():
model_graph.replay() model_graph.replay()
# cache content is as expected # cache content is as expected
assert torch.allclose(x, cache[:x.size(0)]) assert torch.allclose(x, cache[: x.size(0)])
# output content is as expected # output content is as expected
assert torch.allclose(y, x + 1) assert torch.allclose(y, x + 1)
@ -123,7 +123,8 @@ def test_cumem_with_cudagraph():
("meta-llama/Llama-3.2-1B", True), ("meta-llama/Llama-3.2-1B", True),
# sleep mode with pytorch checkpoint # sleep mode with pytorch checkpoint
("facebook/opt-125m", True), ("facebook/opt-125m", True),
]) ],
)
def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool): def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):
with monkeypatch.context() as m: with monkeypatch.context() as m:
assert use_v1 assert use_v1


@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.mark.benchmark @pytest.mark.benchmark
def test_bench_latency(): def test_bench_latency():
command = [ command = [
"vllm", "bench", "latency", "--model", MODEL_NAME, "--input-len", "32", "vllm",
"--output-len", "1", "--enforce-eager", "--load-format", "dummy" "bench",
"latency",
"--model",
MODEL_NAME,
"--input-len",
"32",
"--output-len",
"1",
"--enforce-eager",
"--load-format",
"dummy",
] ]
result = subprocess.run(command, capture_output=True, text=True) result = subprocess.run(command, capture_output=True, text=True)
print(result.stdout) print(result.stdout)


@ -7,8 +7,11 @@ import numpy as np
import pytest import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset, from vllm.benchmarks.datasets import (
SampleRequest) RandomDataset,
RandomMultiModalDataset,
SampleRequest,
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -27,11 +30,9 @@ class Params(NamedTuple):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def random_dataset_params() -> Params: def random_dataset_params() -> Params:
return Params(num_requests=16, return Params(
prefix_len=7, num_requests=16, prefix_len=7, range_ratio=0.3, input_len=50, output_len=20
range_ratio=0.3, )
input_len=50,
output_len=20)
def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]: def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
@ -39,13 +40,15 @@ def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
return (req.prompt, req.prompt_len, req.expected_output_len) return (req.prompt, req.prompt_len, req.expected_output_len)
def _collect_samples(dataset: RandomDataset, def _collect_samples(
tokenizer: PreTrainedTokenizerBase, dataset: RandomDataset,
num_requests: int = 16, tokenizer: PreTrainedTokenizerBase,
prefix_len: int = 7, num_requests: int = 16,
range_ratio: float = 0.3, prefix_len: int = 7,
input_len: int = 50, range_ratio: float = 0.3,
output_len: int = 20) -> list[tuple[str, int, int]]: input_len: int = 50,
output_len: int = 20,
) -> list[tuple[str, int, int]]:
samples = dataset.sample( samples = dataset.sample(
tokenizer=tokenizer, tokenizer=tokenizer,
num_requests=num_requests, num_requests=num_requests,
@ -59,8 +62,8 @@ def _collect_samples(dataset: RandomDataset,
@pytest.mark.benchmark @pytest.mark.benchmark
def test_random_dataset_same_seed( def test_random_dataset_same_seed(
hf_tokenizer: PreTrainedTokenizerBase, hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params
random_dataset_params: Params) -> None: ) -> None:
"""Same seed should yield identical outputs, even if global RNGs change. """Same seed should yield identical outputs, even if global RNGs change.
This guards against accidental reliance on Python's random or np.random This guards against accidental reliance on Python's random or np.random
@ -70,13 +73,15 @@ def test_random_dataset_same_seed(
common_seed = 123 common_seed = 123
dataset_a = RandomDataset(random_seed=common_seed) dataset_a = RandomDataset(random_seed=common_seed)
dataset_b = RandomDataset(random_seed=common_seed) dataset_b = RandomDataset(random_seed=common_seed)
a = _collect_samples(dataset_a, a = _collect_samples(
hf_tokenizer, dataset_a,
num_requests=p.num_requests, hf_tokenizer,
prefix_len=p.prefix_len, num_requests=p.num_requests,
range_ratio=p.range_ratio, prefix_len=p.prefix_len,
input_len=p.input_len, range_ratio=p.range_ratio,
output_len=p.output_len) input_len=p.input_len,
output_len=p.output_len,
)
# Perturb global RNG state to ensure isolation # Perturb global RNG state to ensure isolation
random.seed(999) random.seed(999)
@ -84,43 +89,50 @@ def test_random_dataset_same_seed(
np.random.seed(888) np.random.seed(888)
_ = [np.random.random() for _ in range(100)] _ = [np.random.random() for _ in range(100)]
b = _collect_samples(dataset_b, b = _collect_samples(
hf_tokenizer, dataset_b,
num_requests=p.num_requests, hf_tokenizer,
prefix_len=p.prefix_len, num_requests=p.num_requests,
range_ratio=p.range_ratio, prefix_len=p.prefix_len,
input_len=p.input_len, range_ratio=p.range_ratio,
output_len=p.output_len) input_len=p.input_len,
output_len=p.output_len,
)
assert a == b assert a == b
@pytest.mark.benchmark @pytest.mark.benchmark
def test_random_dataset_different_seeds( def test_random_dataset_different_seeds(
hf_tokenizer: PreTrainedTokenizerBase, hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params
random_dataset_params: Params) -> None: ) -> None:
"""Different seeds should change outputs with overwhelming likelihood.""" """Different seeds should change outputs with overwhelming likelihood."""
p = random_dataset_params p = random_dataset_params
seed_a = 0 seed_a = 0
dataset_a = RandomDataset(random_seed=seed_a) dataset_a = RandomDataset(random_seed=seed_a)
a = _collect_samples(dataset_a, a = _collect_samples(
hf_tokenizer, dataset_a,
num_requests=p.num_requests, hf_tokenizer,
prefix_len=p.prefix_len, num_requests=p.num_requests,
range_ratio=p.range_ratio, prefix_len=p.prefix_len,
input_len=p.input_len, range_ratio=p.range_ratio,
output_len=p.output_len) input_len=p.input_len,
output_len=p.output_len,
)
seed_b = 999 seed_b = 999
dataset_b = RandomDataset(random_seed=seed_b) dataset_b = RandomDataset(random_seed=seed_b)
# Perturb global RNG with same seed as dataset_a to ensure isolation # Perturb global RNG with same seed as dataset_a to ensure isolation
random.seed(seed_a) random.seed(seed_a)
np.random.seed(seed_a) np.random.seed(seed_a)
b = _collect_samples(dataset_b, b = _collect_samples(
hf_tokenizer, dataset_b,
num_requests=p.num_requests, hf_tokenizer,
prefix_len=p.prefix_len, num_requests=p.num_requests,
range_ratio=p.range_ratio, prefix_len=p.prefix_len,
input_len=p.input_len, range_ratio=p.range_ratio,
output_len=p.output_len) input_len=p.input_len,
output_len=p.output_len,
)
assert a != b assert a != b
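
Both seed tests hinge on the dataset owning its RNG rather than reading the global one, so reseeding random/np.random between calls must not change the samples. A minimal illustration of that isolation pattern, independent of vLLM's dataset classes:

    import random

    import numpy as np

    rng = np.random.default_rng(123)  # generator owned by the "dataset"
    first = rng.integers(0, 1_000, size=3).tolist()

    random.seed(999)   # perturb global RNG state...
    np.random.seed(888)

    rng2 = np.random.default_rng(123)  # ...a fresh generator with the same seed
    assert rng2.integers(0, 1_000, size=3).tolist() == first
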
@ -128,6 +140,7 @@ def test_random_dataset_different_seeds(
# RandomMultiModalDataset tests # RandomMultiModalDataset tests
# ----------------------------- # -----------------------------
def _mm_fingerprint_sample( def _mm_fingerprint_sample(
req: SampleRequest, req: SampleRequest,
) -> tuple[str, int, int, int, list[str]]: ) -> tuple[str, int, int, int, list[str]]:
@ -152,8 +165,13 @@ def _mm_fingerprint_sample(
item_prefixes.append(f"video:{url[:22]}") item_prefixes.append(f"video:{url[:22]}")
else: else:
item_prefixes.append("unknown:") item_prefixes.append("unknown:")
return (req.prompt, req.prompt_len, req.expected_output_len, len(items), return (
item_prefixes) req.prompt,
req.prompt_len,
req.expected_output_len,
len(items),
item_prefixes,
)
def _collect_mm_samples( def _collect_mm_samples(
@ -214,6 +232,7 @@ def test_random_mm_different_seeds(
fb = [_mm_fingerprint_sample(s) for s in b] fb = [_mm_fingerprint_sample(s) for s in b]
assert fa != fb assert fa != fb
@pytest.mark.benchmark @pytest.mark.benchmark
def test_random_mm_respects_limits( def test_random_mm_respects_limits(
hf_tokenizer: PreTrainedTokenizerBase, hf_tokenizer: PreTrainedTokenizerBase,
@ -271,9 +290,9 @@ def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
for s in samples: for s in samples:
assert s.multi_modal_data == [] assert s.multi_modal_data == []
@pytest.mark.benchmark @pytest.mark.benchmark
def test_random_mm_num_items_per_prompt( def test_random_mm_num_items_per_prompt(hf_tokenizer: PreTrainedTokenizerBase) -> None:
hf_tokenizer: PreTrainedTokenizerBase) -> None:
ds = RandomMultiModalDataset(random_seed=0) ds = RandomMultiModalDataset(random_seed=0)
# Fixed number of images per prompt # Fixed number of images per prompt
# set num_mm_items_range_ratio to 0.0 # set num_mm_items_range_ratio to 0.0
@ -300,7 +319,6 @@ def test_random_mm_num_items_per_prompt(
def test_random_mm_bucket_config_not_mutated( def test_random_mm_bucket_config_not_mutated(
hf_tokenizer: PreTrainedTokenizerBase, hf_tokenizer: PreTrainedTokenizerBase,
) -> None: ) -> None:
ds = RandomMultiModalDataset(random_seed=0) ds = RandomMultiModalDataset(random_seed=0)
# This bucket config is not normalized to sum to 1 # This bucket config is not normalized to sum to 1
# and has more buckets than requested images # and has more buckets than requested images
@ -321,7 +339,6 @@ def test_random_mm_bucket_config_not_mutated(
# Ensure the original dict content is unchanged # Ensure the original dict content is unchanged
assert original == snapshot assert original == snapshot
# Vary number of mm items per prompt # Vary number of mm items per prompt
# set num_mm_items_range_ratio to 0.5 # set num_mm_items_range_ratio to 0.5
samples_varying_items = _collect_mm_samples( samples_varying_items = _collect_mm_samples(


@ -11,9 +11,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
args = [ args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"]
"--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server yield remote_server
@ -46,6 +44,7 @@ def test_bench_serve(server):
assert result.returncode == 0, f"Benchmark failed: {result.stderr}" assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
@pytest.mark.benchmark @pytest.mark.benchmark
def test_bench_serve_chat(server): def test_bench_serve_chat(server):
command = [ command = [


@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.mark.benchmark @pytest.mark.benchmark
def test_bench_throughput(): def test_bench_throughput():
command = [ command = [
"vllm", "bench", "throughput", "--model", MODEL_NAME, "--input-len", "vllm",
"32", "--output-len", "1", "--enforce-eager", "--load-format", "dummy" "bench",
"throughput",
"--model",
MODEL_NAME,
"--input-len",
"32",
"--output-len",
"1",
"--enforce-eager",
"--load-format",
"dummy",
] ]
result = subprocess.run(command, capture_output=True, text=True) result = subprocess.run(command, capture_output=True, text=True)
print(result.stdout) print(result.stdout)

View File

@ -23,8 +23,7 @@ class LazyInitPass(InductorPass):
and then immediately invoke it. and then immediately invoke it.
""" """
def __init__(self, pass_cls: type[VllmInductorPass], def __init__(self, pass_cls: type[VllmInductorPass], vllm_config: VllmConfig):
vllm_config: VllmConfig):
self.pass_cls = pass_cls self.pass_cls = pass_cls
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
@ -45,20 +44,18 @@ class TestBackend:
Inductor config is default-initialized from VllmConfig.CompilationConfig. Inductor config is default-initialized from VllmConfig.CompilationConfig.
""" """
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]):
None]]):
self.custom_passes = list(passes) self.custom_passes = list(passes)
compile_config = get_current_vllm_config().compilation_config compile_config = get_current_vllm_config().compilation_config
self.inductor_config = compile_config.inductor_compile_config self.inductor_config = compile_config.inductor_compile_config
self.inductor_config['force_disable_caches'] = True self.inductor_config["force_disable_caches"] = True
self.inductor_config['post_grad_custom_post_pass'] = self.post_pass self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
def __call__(self, graph: fx.GraphModule, example_inputs): def __call__(self, graph: fx.GraphModule, example_inputs):
self.graph_pre_compile = deepcopy(graph) self.graph_pre_compile = deepcopy(graph)
from torch._inductor.compile_fx import compile_fx from torch._inductor.compile_fx import compile_fx
return compile_fx(graph,
example_inputs, return compile_fx(graph, example_inputs, config_patches=self.inductor_config)
config_patches=self.inductor_config)
@with_pattern_match_debug @with_pattern_match_debug
def post_pass(self, graph: fx.Graph): def post_pass(self, graph: fx.Graph):
@ -82,8 +79,7 @@ class TestBackend:
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph" assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
assert num_pre > num_post, f"All nodes remain for op {op.name()}" assert num_pre > num_post, f"All nodes remain for op {op.name()}"
if fully_replaced: if fully_replaced:
assert num_post == 0, \ assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
f"Unexpected op {op.name()} in post-pass graph"
def check_after_ops(self, ops: Sequence[OpOverload]): def check_after_ops(self, ops: Sequence[OpOverload]):
for op in ops: for op in ops:


@ -38,8 +38,8 @@ test_params_full_cudagraph = []
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"] MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
for mla_backend in MLA_backends: for mla_backend in MLA_backends:
test_params_full_cudagraph.append( test_params_full_cudagraph.append(
pytest.param( pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))) )
# Qwen/Qwen2-1.5B-Instruct with other backends # Qwen/Qwen2-1.5B-Instruct with other backends
other_backend_configs = [ other_backend_configs = [
@ -47,7 +47,8 @@ other_backend_configs = [
] ]
for backend_config in other_backend_configs: for backend_config in other_backend_configs:
test_params_full_cudagraph.append( test_params_full_cudagraph.append(
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))) pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))
)
@pytest.fixture(scope="class") @pytest.fixture(scope="class")
@ -55,8 +56,10 @@ def llm_pair(request):
model, backend_config = request.param model, backend_config = request.param
# Dynamically skip test if GPU capability is not met # Dynamically skip test if GPU capability is not met
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\ if (
!= current_platform.get_device_capability(): backend_config.specific_gpu_arch
and backend_config.specific_gpu_arch != current_platform.get_device_capability()
):
if backend_config.specific_gpu_arch == (9, 0): if backend_config.specific_gpu_arch == (9, 0):
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA") pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
elif backend_config.specific_gpu_arch == (10, 0): elif backend_config.specific_gpu_arch == (10, 0):
@ -76,8 +79,7 @@ def llm_pair(request):
trust_remote_code=True, trust_remote_code=True,
max_model_len=1024, max_model_len=1024,
max_num_seqs=128, max_num_seqs=128,
compilation_config=\ compilation_config=CompilationConfig(**backend_config.comp_config),
CompilationConfig(**backend_config.comp_config),
generation_config="vllm", generation_config="vllm",
seed=42, seed=42,
) )
@ -113,20 +115,22 @@ class TestFullCUDAGraph:
meaning there would be multiple LLM instances hogging memory simultaneously. meaning there would be multiple LLM instances hogging memory simultaneously.
""" """
@pytest.mark.parametrize(("batch_size", "max_tokens"), [ @pytest.mark.parametrize(
(1, 10), ("batch_size", "max_tokens"),
(7, 10), [
(16, 10), (1, 10),
(25, 10), (7, 10),
(32, 10), (16, 10),
(45, 10), (25, 10),
(64, 10), (32, 10),
(123, 10), (45, 10),
(8, 5), (64, 10),
(8, 30), (123, 10),
]) (8, 5),
def test_full_cudagraph(self, batch_size, max_tokens, (8, 30),
llm_pair: tuple[LLM, LLM]): ],
)
def test_full_cudagraph(self, batch_size, max_tokens, llm_pair: tuple[LLM, LLM]):
""" """
Test various batch sizes and max_tokens to ensure that the Test various batch sizes and max_tokens to ensure that the
full cudagraph compilation works for padded cases too. full cudagraph compilation works for padded cases too.
@ -137,26 +141,34 @@ class TestFullCUDAGraph:
prompts = ["the quick brown fox"] * batch_size prompts = ["the quick brown fox"] * batch_size
# Use purely greedy decoding to avoid top-p truncation sensitivity # Use purely greedy decoding to avoid top-p truncation sensitivity
# that can amplify tiny numeric differences across runtimes. # that can amplify tiny numeric differences across runtimes.
sampling_params = SamplingParams(temperature=0.0, sampling_params = SamplingParams(
max_tokens=max_tokens, temperature=0.0, max_tokens=max_tokens, top_p=1.0
top_p=1.0) )
piecewise_responses = piecewise_llm.generate(prompts, sampling_params) piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
full_responses = full_cudagraph_llm.generate(prompts, sampling_params) full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
# Check that all responses are the same # Check that all responses are the same
for piecewise_res, full_res in zip(piecewise_responses, for piecewise_res, full_res in zip(piecewise_responses, full_responses):
full_responses): assert (
assert piecewise_res.outputs[0].text.lower() == \ piecewise_res.outputs[0].text.lower()
full_res.outputs[0].text.lower() == full_res.outputs[0].text.lower()
)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend(): def test_full_cudagraph_with_invalid_backend():
with temporary_environ({ with (
"VLLM_USE_V1": "1", temporary_environ(
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION" {
# Flex_Attention is not supported with full cuda graph "VLLM_USE_V1": "1",
}), pytest.raises(RuntimeError): "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
LLM(model="Qwen/Qwen2-1.5B-Instruct", # Flex_Attention is not supported with full cuda graph
compilation_config=CompilationConfig(cudagraph_mode="FULL")) }
),
pytest.raises(RuntimeError),
):
LLM(
model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(cudagraph_mode="FULL"),
)


@ -10,10 +10,14 @@ from torch import nn
from vllm.compilation.backends import set_model_tag from vllm.compilation.backends import set_model_tag
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile, from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
support_torch_compile) from vllm.config import (
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, CompilationConfig,
VllmConfig, set_current_vllm_config) CompilationLevel,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
# This import automatically registers `torch.ops.silly.attention` # This import automatically registers `torch.ops.silly.attention`
@ -27,12 +31,7 @@ RANDOM_SEED = 0
@support_torch_compile @support_torch_compile
class ParentModel(nn.Module): class ParentModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__() super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -40,7 +39,6 @@ class ParentModel(nn.Module):
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, mlp_size: int, hidden_size: int) -> None: def __init__(self, mlp_size: int, hidden_size: int) -> None:
super().__init__() super().__init__()
self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False) self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
@ -51,17 +49,21 @@ class Attention(nn.Module):
nn.init.xavier_normal_( nn.init.xavier_normal_(
self.pre_attn.weight.data, self.pre_attn.weight.data,
generator=torch.Generator().manual_seed(RANDOM_SEED), generator=torch.Generator().manual_seed(RANDOM_SEED),
gain=0.001) gain=0.001,
)
nn.init.xavier_normal_( nn.init.xavier_normal_(
self.post_attn.weight.data, self.post_attn.weight.data,
generator=torch.Generator().manual_seed(RANDOM_SEED), generator=torch.Generator().manual_seed(RANDOM_SEED),
gain=0.001) gain=0.001,
)
def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor: def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
x_f32 = x.float() x_f32 = x.float()
return (x_f32 * torch.rsqrt( return (
torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) * x_f32
self.rms_norm_weight).to(x.dtype) * torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6)
* self.rms_norm_weight
).to(x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pre_attn(x) x = self.pre_attn(x)
@ -76,14 +78,15 @@ class Attention(nn.Module):
@support_torch_compile @support_torch_compile
class CompiledAttention(nn.Module): class CompiledAttention(nn.Module):
def __init__(
def __init__(self, self,
*, *,
mlp_size: int, mlp_size: int,
hidden_size: int, hidden_size: int,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = '', prefix: str = "",
**kwargs) -> None: **kwargs,
) -> None:
super().__init__() super().__init__()
self.attn = Attention(mlp_size, hidden_size) self.attn = Attention(mlp_size, hidden_size)
@ -93,21 +96,21 @@ class CompiledAttention(nn.Module):
@support_torch_compile @support_torch_compile
class CompiledAttentionTwo(CompiledAttention): class CompiledAttentionTwo(CompiledAttention):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.attn(x) + x return self.attn(x) + x
@ignore_torch_compile @ignore_torch_compile
class SimpleModelWithTwoGraphs(ParentModel): class SimpleModelWithTwoGraphs(ParentModel):
def __init__(
def __init__(self, self,
*, *,
mlp_size: int, mlp_size: int,
hidden_size: int, hidden_size: int,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = '', prefix: str = "",
**kwargs) -> None: **kwargs,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
# Test will fail without set_model_tag here with error: # Test will fail without set_model_tag here with error:
# "ValueError: too many values to unpack (expected 3)" # "ValueError: too many values to unpack (expected 3)"
@ -142,32 +145,45 @@ class SimpleModelWithTwoGraphs(ParentModel):
@torch.inference_mode @torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor, def run_model(
cudagraph_runtime_mode: CUDAGraphMode): vllm_config: VllmConfig,
model: nn.Module,
inputs: torch.Tensor,
cudagraph_runtime_mode: CUDAGraphMode,
):
with set_forward_context({}, vllm_config=vllm_config): with set_forward_context({}, vllm_config=vllm_config):
# warmup for the model with cudagraph_mode NONE # warmup for the model with cudagraph_mode NONE
model(inputs) model(inputs)
# simulate cudagraphs capturing # simulate cudagraphs capturing
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(inputs[:2]) model(inputs[:2])
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=1, )): batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(inputs[:1]) model(inputs[:1])
# simulate cudagraphs replay # simulate cudagraphs replay
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(inputs[:2]) output = model(inputs[:2])
output = output.cpu() output = output.cpu()
@ -178,82 +194,104 @@ def test_multi_graph_piecewise_compile_outputs_equal():
outputs = [] outputs = []
# piecewise compile # piecewise compile
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(
level=CompilationLevel.PIECEWISE, compilation_config=CompilationConfig(
use_cudagraph=True, level=CompilationLevel.PIECEWISE,
splitting_ops=["silly.attention"], use_cudagraph=True,
cudagraph_capture_sizes=[1, 2], splitting_ops=["silly.attention"],
)) cudagraph_capture_sizes=[1, 2],
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, model = (
hidden_size=HIDDEN_SIZE, SimpleModelWithTwoGraphs(
vllm_config=vllm_config, mlp_size=MLP_SIZE,
prefix='').eval().cuda() hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix="",
)
.eval()
.cuda()
)
# Pre-allocate memory for CUDAGraph which expects # Pre-allocate memory for CUDAGraph which expects
# static tensor addresses # static tensor addresses
inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda() inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=2, # two graphs for the model num_graphs_seen=2, # two graphs for the model
num_piecewise_graphs_seen=6, num_piecewise_graphs_seen=6,
# attn_one, attn_two each has 3 piecewise graphs # attn_one, attn_two each has 3 piecewise graphs
# (pre attn, post attn, silly_attention) each # (pre attn, post attn, silly_attention) each
num_piecewise_capturable_graphs_seen=4, num_piecewise_capturable_graphs_seen=4,
# attn_one, attn_two has pre attn and post attn each, total=4 # attn_one, attn_two has pre attn and post attn each, total=4
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=8, num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
): ):
outputs.append( outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
# no compile or cudagraph # no compile or cudagraph
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(
level=CompilationLevel.NO_COMPILATION, )) compilation_config=CompilationConfig(
level=CompilationLevel.NO_COMPILATION,
)
)
cudagraph_runtime_mode = CUDAGraphMode.NONE cudagraph_runtime_mode = CUDAGraphMode.NONE
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, model = (
hidden_size=HIDDEN_SIZE, SimpleModelWithTwoGraphs(
vllm_config=vllm_config, mlp_size=MLP_SIZE,
prefix='').eval().cuda() hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix="",
)
.eval()
.cuda()
)
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=0, num_graphs_seen=0,
num_piecewise_graphs_seen=0, num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0, num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0, num_backend_compilations=0,
num_cudagraph_captured=0, num_cudagraph_captured=0,
): ):
outputs.append( outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
# piecewise compile without CUDA graph # piecewise compile without CUDA graph
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(
level=CompilationLevel.PIECEWISE, compilation_config=CompilationConfig(
use_cudagraph=False, level=CompilationLevel.PIECEWISE,
splitting_ops=["silly.attention"], use_cudagraph=False,
)) splitting_ops=["silly.attention"],
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, model = (
hidden_size=HIDDEN_SIZE, SimpleModelWithTwoGraphs(
vllm_config=vllm_config, mlp_size=MLP_SIZE,
prefix='').eval().cuda() hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix="",
)
.eval()
.cuda()
)
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=2, num_graphs_seen=2,
num_piecewise_graphs_seen=6, num_piecewise_graphs_seen=6,
num_piecewise_capturable_graphs_seen=4, num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4, num_backend_compilations=4,
num_cudagraph_captured=0, # no cudagraph captured num_cudagraph_captured=0, # no cudagraph captured
): ):
outputs.append( outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
# Generally don't expect outputs with and without inductor # Generally don't expect outputs with and without inductor
# to be bitwise equivalent # to be bitwise equivalent


@ -11,8 +11,13 @@ from torch import nn
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, from vllm.config import (
VllmConfig, set_current_vllm_config) CompilationConfig,
CompilationLevel,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.envs import VLLM_USE_V1 from vllm.envs import VLLM_USE_V1
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer from vllm.utils import is_torch_equal_or_newer
@ -23,12 +28,7 @@ from ..silly_attention import get_global_counter, reset_global_counter
@support_torch_compile @support_torch_compile
class SillyModel(nn.Module): class SillyModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__() super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -60,53 +60,65 @@ def _run_simple_model(
expected_num_backend_compilations, expected_num_backend_compilations,
expected_num_cudagraph_captured, expected_num_cudagraph_captured,
): ):
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(
level=CompilationLevel.PIECEWISE, compilation_config=CompilationConfig(
use_cudagraph=True, level=CompilationLevel.PIECEWISE,
use_inductor=use_inductor, use_cudagraph=True,
splitting_ops=splitting_ops, use_inductor=use_inductor,
use_inductor_graph_partition=use_inductor_graph_partition, splitting_ops=splitting_ops,
cudagraph_copy_inputs=True, use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_capture_sizes=[1, 2], cudagraph_copy_inputs=True,
)) cudagraph_capture_sizes=[1, 2],
)
)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='') model = SillyModel(vllm_config=vllm_config, prefix="")
inputs = torch.randn(100).cuda() inputs = torch.randn(100).cuda()
with compilation_counter.expect( with (
compilation_counter.expect(
num_graphs_seen=1, # one graph for the model num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen, num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen= num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations, num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=expected_num_cudagraph_captured, num_cudagraph_captured=expected_num_cudagraph_captured,
), set_forward_context(None, ),
vllm_config=vllm_config): # background context set_forward_context(None, vllm_config=vllm_config),
): # background context
# warm up with background context # warm up with background context
model(inputs) model(inputs)
# capturing/replaying should under context of cudagraph dispatching # capturing/replaying should under context of cudagraph dispatching
with set_forward_context( with set_forward_context(
None, None,
vllm_config=vllm_config, vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(torch.randn(2).cuda()) model(torch.randn(2).cuda())
with set_forward_context( with set_forward_context(
None, None,
vllm_config=vllm_config, vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=1, )): batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(torch.randn(1).cuda()) model(torch.randn(1).cuda())
input = torch.zeros(2).cuda() input = torch.zeros(2).cuda()
reset_global_counter() reset_global_counter()
with set_forward_context( with set_forward_context(
None, None,
vllm_config=vllm_config, vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(input) output = model(input)
assert get_global_counter() == 2 assert get_global_counter() == 2
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0])) assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
@ -122,10 +134,8 @@ def test_simple_piecewise_compile(use_inductor):
use_inductor=use_inductor, use_inductor=use_inductor,
expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1 expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1
expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
expected_num_backend_compilations= expected_num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
3, # num_piecewise_capturable_graphs_seen expected_num_cudagraph_captured=6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
expected_num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
) )
@ -134,8 +144,7 @@ def test_simple_piecewise_compile(use_inductor):
def test_simple_inductor_graph_partition(splitting_ops): def test_simple_inductor_graph_partition(splitting_ops):
assert VLLM_USE_V1 assert VLLM_USE_V1
if not is_torch_equal_or_newer("2.9.0.dev"): if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available " pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
"in PyTorch 2.9+")
_run_simple_model( _run_simple_model(
# inductor graph partition automatically resets splitting_ops # inductor graph partition automatically resets splitting_ops
@ -143,13 +152,9 @@ def test_simple_inductor_graph_partition(splitting_ops):
splitting_ops=splitting_ops, splitting_ops=splitting_ops,
use_inductor_graph_partition=True, use_inductor_graph_partition=True,
use_inductor=True, use_inductor=True,
expected_num_piecewise_graphs_seen= expected_num_piecewise_graphs_seen=1, # since not splitting at fx graph level
1, # since not splitting at fx graph level expected_num_piecewise_capturable_graphs_seen=1, # since not splitting at fx graph level
expected_num_piecewise_capturable_graphs_seen= expected_num_backend_compilations=1, # since not splitting at fx graph level
1, # since not splitting at fx graph level expected_num_cudagraph_captured=6, # inductor graph partition still captures 6
expected_num_backend_compilations=
1, # since not splitting at fx graph level
expected_num_cudagraph_captured=
6, # inductor graph partition still captures 6
# graph, same as fx graph partition. # graph, same as fx graph partition.
) )
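The counter expectations in test_simple_piecewise_compile above are simple arithmetic over the values stated in the inline comments; as a sanity-check sketch (the layer count is inferred from the "2 * num_layers + 1" comment and cudagraph_capture_sizes=[1, 2], not stated elsewhere in this hunk):

num_layers = 2            # inferred from 2 * num_layers + 1 == 5
num_cudagraph_sizes = 2   # cudagraph_capture_sizes = [1, 2]
assert 2 * num_layers + 1 == 5                       # piecewise graphs seen
assert 1 + num_layers == 3                           # capturable piecewise graphs
assert num_cudagraph_sizes * (1 + num_layers) == 6   # cudagraphs captured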

View File

@ -8,6 +8,7 @@ This is a tractable model, the weights and computation are specially designed
if the config `tractable_init` is set to True. Otherwise, the weights are if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed. initialized randomly with a fixed seed.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional from typing import Any, Optional
@ -17,8 +18,13 @@ from torch import nn
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, from vllm.config import (
VllmConfig, set_current_vllm_config) CompilationConfig,
CompilationLevel,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
# This import automatically registers `torch.ops.silly.attention` # This import automatically registers `torch.ops.silly.attention`
@ -43,15 +49,14 @@ class LlamaConfig:
factors.append((k, v)) factors.append((k, v))
factors.sort() factors.sort()
import hashlib import hashlib
return hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest() return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
def __post_init__(self): def __post_init__(self):
assert self.mlp_size >= self.hidden_size assert self.mlp_size >= self.hidden_size
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
def __init__(self, config: LlamaConfig) -> None: def __init__(self, config: LlamaConfig) -> None:
super().__init__() super().__init__()
self.gate_up_projection = nn.Linear( self.gate_up_projection = nn.Linear(
@ -66,31 +71,31 @@ class LlamaMLP(nn.Module):
) )
if config.tractable_init: if config.tractable_init:
nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size]) nn.init.eye_(self.gate_up_projection.weight.data[: config.mlp_size])
nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:]) nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size :])
nn.init.eye_(self.down_projection.weight.data) nn.init.eye_(self.down_projection.weight.data)
else: else:
nn.init.xavier_normal_(self.gate_up_projection.weight.data, nn.init.xavier_normal_(
generator=torch.Generator().manual_seed( self.gate_up_projection.weight.data,
config.random_seed), generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001) gain=0.001,
nn.init.xavier_normal_(self.down_projection.weight.data, )
generator=torch.Generator().manual_seed( nn.init.xavier_normal_(
config.random_seed), self.down_projection.weight.data,
gain=0.001) generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001,
)
def forward(self, x): def forward(self, x):
# for tractable_init and positive input, this is # for tractable_init and positive input, this is
# essentially an elementwise-square # essentially an elementwise-square
x = self.gate_up_projection(x) x = self.gate_up_projection(x)
x = x[:, :x.size(1) // 2] * torch.nn.functional.relu( x = x[:, : x.size(1) // 2] * torch.nn.functional.relu(x[:, x.size(1) // 2 :])
x[:, x.size(1) // 2:])
x = self.down_projection(x) x = self.down_projection(x)
return x return x
class LlamaAttention(nn.Module): class LlamaAttention(nn.Module):
def __init__(self, config: LlamaConfig) -> None: def __init__(self, config: LlamaConfig) -> None:
super().__init__() super().__init__()
self.qkv_projection = nn.Linear( self.qkv_projection = nn.Linear(
@ -106,21 +111,25 @@ class LlamaAttention(nn.Module):
) )
if config.tractable_init: if config.tractable_init:
nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size]) nn.init.eye_(self.qkv_projection.weight.data[: config.hidden_size])
nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 * nn.init.eye_(
config.hidden_size]) self.qkv_projection.weight.data[
nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size : 2 * config.hidden_size
config.hidden_size:]) ]
)
nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size :])
nn.init.eye_(self.output_projection.weight.data) nn.init.eye_(self.output_projection.weight.data)
else: else:
nn.init.xavier_normal_(self.qkv_projection.weight.data, nn.init.xavier_normal_(
generator=torch.Generator().manual_seed( self.qkv_projection.weight.data,
config.random_seed), generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001) gain=0.001,
nn.init.xavier_normal_(self.output_projection.weight.data, )
generator=torch.Generator().manual_seed( nn.init.xavier_normal_(
config.random_seed), self.output_projection.weight.data,
gain=0.001) generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001,
)
def forward( def forward(
self, self,
@ -144,7 +153,6 @@ class LlamaAttention(nn.Module):
class LlamaDecoderLayer(nn.Module): class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig) -> None: def __init__(self, config: LlamaConfig) -> None:
super().__init__() super().__init__()
self.self_attention = LlamaAttention(config) self.self_attention = LlamaAttention(config)
@ -164,7 +172,7 @@ class LlamaDecoderLayer(nn.Module):
- if residual is not None, the outputs are: - if residual is not None, the outputs are:
- residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3 - residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
- hidden_states = (residual + 1) ** 2 - hidden_states = (residual + 1) ** 2
""" # noqa """ # noqa
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
hidden_states = hidden_states + 1 hidden_states = hidden_states + 1
@ -173,8 +181,9 @@ class LlamaDecoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = hidden_states + 1 hidden_states = hidden_states + 1
hidden_states = self.self_attention(positions=positions, hidden_states = self.self_attention(
hidden_states=hidden_states) positions=positions, hidden_states=hidden_states
)
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
residual = hidden_states residual = hidden_states
@ -186,20 +195,22 @@ class LlamaDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class LlamaModel(nn.Module): class LlamaModel(nn.Module):
def __init__(
def __init__(self, self,
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
config: LlamaConfig, config: LlamaConfig,
prefix: str = '', prefix: str = "",
**kwargs) -> None: **kwargs,
) -> None:
super().__init__() super().__init__()
self.embedding_tokens = nn.Embedding( self.embedding_tokens = nn.Embedding(
num_embeddings=config.vocab_size, num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size, embedding_dim=config.hidden_size,
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[LlamaDecoderLayer(config) for _ in range(config.num_layers)]) [LlamaDecoderLayer(config) for _ in range(config.num_layers)]
)
# this is the initial value of the hidden states # this is the initial value of the hidden states
self.embedding_tokens.weight.data.fill_(config.init_value) self.embedding_tokens.weight.data.fill_(config.init_value)
@ -216,34 +227,39 @@ class LlamaModel(nn.Module):
return hidden_states return hidden_states
def tractable_computation(input_ids: torch.Tensor, def tractable_computation(
positions: torch.Tensor, input_ids: torch.Tensor,
config: LlamaConfig, positions: torch.Tensor,
init_value: float = 1.0) -> torch.Tensor: config: LlamaConfig,
hidden_states = torch.ones(input_ids.size(0), init_value: float = 1.0,
config.hidden_size, ) -> torch.Tensor:
device=input_ids.device, hidden_states = (
dtype=input_ids.dtype) * init_value torch.ones(
input_ids.size(0),
config.hidden_size,
device=input_ids.device,
dtype=input_ids.dtype,
)
* init_value
)
# first layer # first layer
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3 residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
hidden_states = (residual + 1)**2 hidden_states = (residual + 1) ** 2
# following layers # following layers
for _ in range(config.num_layers - 1): for _ in range(config.num_layers - 1):
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3 residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
hidden_states = (residual + 1)**2 hidden_states = (residual + 1) ** 2
return hidden_states return hidden_states
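For reference, the recurrence in tractable_computation can be hand-checked for a single token; the numbers below assume the default init_value=1.0 and position 0, mirroring the loop above (a worked sketch, not part of the test):

init_value, p = 1.0, 0
residual = init_value * 4 + p * 2 + 3   # 7.0 after the first layer
hidden = (residual + 1) ** 2            # 64.0
hidden = hidden + residual              # 71.0, one "following" layer
residual = hidden * 4 + p * 2 + 3       # 287.0
hidden = (residual + 1) ** 2            # 82944.0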
@torch.inference_mode @torch.inference_mode
def run_model(llama_config, def run_model(
use_compile: bool, llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
use_inductor: bool, ) -> torch.Tensor:
split_attn: bool = False) -> torch.Tensor:
if use_compile: if use_compile:
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE, level=CompilationLevel.PIECEWISE,
@ -256,54 +272,66 @@ def run_model(llama_config,
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else: else:
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, ) level=CompilationLevel.NO_COMPILATION,
)
cudagraph_runtime_mode = CUDAGraphMode.NONE cudagraph_runtime_mode = CUDAGraphMode.NONE
vllm_config = VllmConfig(compilation_config=compilation_config, vllm_config = VllmConfig(
additional_config=llama_config) compilation_config=compilation_config, additional_config=llama_config
)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config, model = (
vllm_config=vllm_config, LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
prefix="").eval().cuda() .eval()
.cuda()
)
with set_forward_context({}, with set_forward_context({}, vllm_config=vllm_config): # background context
vllm_config=vllm_config): # background context
B = 16 # max batch size B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
positions = torch.arange(B).cuda() positions = torch.arange(B).cuda()
# warmup for the model with cudagraph_mode NONE # warmup for the model with cudagraph_mode NONE
model(input_ids, positions) model(input_ids, positions)
# simulate cudagraphs capturing # simulate cudagraphs capturing
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(input_ids[:2], positions[:2]) model(input_ids[:2], positions[:2])
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=1, )): batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(input_ids[:1], positions[:1]) model(input_ids[:1], positions[:1])
input_ids[:2].zero_() input_ids[:2].zero_()
# simulate cudagraphs replay # simulate cudagraphs replay
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(input_ids[:2], positions[:2]) output = model(input_ids[:2], positions[:2])
output = output.cpu() output = output.cpu()
if llama_config.tractable_init: if llama_config.tractable_init:
expected_output = tractable_computation(input_ids[:2], expected_output = tractable_computation(
positions[:2], input_ids[:2], positions[:2], llama_config
llama_config).cpu() ).cpu()
assert torch.allclose(output, expected_output) assert torch.allclose(output, expected_output)
else: else:
@ -314,27 +342,23 @@ def run_model(llama_config,
def test_toy_llama(use_inductor: bool): def test_toy_llama(use_inductor: bool):
# compare output with and without piecewise compilation # compare output with and without piecewise compilation
llama_config = LlamaConfig(hidden_size=128, llama_config = LlamaConfig(
mlp_size=256, hidden_size=128, mlp_size=256, vocab_size=128, num_layers=12
vocab_size=128, )
num_layers=12)
tractable_config = LlamaConfig(hidden_size=128, tractable_config = LlamaConfig(
mlp_size=256, hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
vocab_size=128, )
num_layers=2,
tractable_init=True)
outputs = [] outputs = []
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=0, num_graphs_seen=0,
num_piecewise_graphs_seen=0, num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0, num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0, num_backend_compilations=0,
num_cudagraph_captured=0, num_cudagraph_captured=0,
): ):
outputs.append( outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
run_model(llama_config, use_inductor=False, use_compile=False))
run_model(tractable_config, use_inductor=False, use_compile=False) run_model(tractable_config, use_inductor=False, use_compile=False)
if use_inductor: if use_inductor:
@ -343,41 +367,41 @@ def test_toy_llama(use_inductor: bool):
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0} kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=1, num_piecewise_graphs_seen=1,
num_piecewise_capturable_graphs_seen=1, num_piecewise_capturable_graphs_seen=1,
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured= num_cudagraph_captured=2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen **kwargs,
**kwargs,
): ):
outputs.append( outputs.append(
run_model(llama_config, run_model(llama_config, use_inductor=use_inductor, use_compile=True)
use_inductor=use_inductor, )
use_compile=True))
run_model(tractable_config, use_inductor=use_inductor, use_compile=True) run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=2 * llama_config.num_layers + num_piecewise_graphs_seen=2 * llama_config.num_layers + 1, # 2 * num_layers + 1
1, # 2 * num_layers + 1 num_piecewise_capturable_graphs_seen=1
num_piecewise_capturable_graphs_seen=1 + + llama_config.num_layers, # 1 + num_layers
llama_config.num_layers, # 1 + num_layers num_backend_compilations=1
num_backend_compilations=1 + + llama_config.num_layers, # num_piecewise_capturable_graphs_seen
llama_config.num_layers, # num_piecewise_capturable_graphs_seen num_cudagraph_captured=2
num_cudagraph_captured=2 * * (
(1 + llama_config.num_layers 1 + llama_config.num_layers
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
): ):
outputs.append( outputs.append(
run_model(llama_config, run_model(
use_inductor=use_inductor, llama_config,
use_compile=True, use_inductor=use_inductor,
split_attn=True)) use_compile=True,
run_model(tractable_config, split_attn=True,
use_inductor=use_inductor, )
use_compile=True, )
split_attn=True) run_model(
tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
)
for i in range(1, len(outputs)): for i in range(1, len(outputs)):
assert torch.allclose(outputs[0], outputs[i]) assert torch.allclose(outputs[0], outputs[i])
@ -388,17 +412,15 @@ def benchmark():
from triton.testing import do_bench from triton.testing import do_bench
# similar to llama 3.1-8B # similar to llama 3.1-8B
llama_config = LlamaConfig(hidden_size=4096, llama_config = LlamaConfig(
mlp_size=14336, hidden_size=4096, mlp_size=14336, vocab_size=128 * 1024, num_layers=32
vocab_size=128 * 1024, )
num_layers=32)
# a tiny model to measure the overhead # a tiny model to measure the overhead
# of piecewise cudagraph # of piecewise cudagraph
llama_config = LlamaConfig(hidden_size=40, llama_config = LlamaConfig(
mlp_size=80, hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2
vocab_size=128, )
num_layers=2)
cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)] cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]
@ -424,12 +446,15 @@ def benchmark():
vllm_config = VllmConfig(compilation_config=compilation_config) vllm_config = VllmConfig(compilation_config=compilation_config)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config, model = (
vllm_config=vllm_config, LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
prefix="").eval().cuda().to(torch.bfloat16) .eval()
.cuda()
.to(torch.bfloat16)
)
B = 256 # max batch size B = 256 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
positions = torch.arange(B).cuda().to(torch.bfloat16) positions = torch.arange(B).cuda().to(torch.bfloat16)
graphs = {} graphs = {}
@ -451,21 +476,25 @@ def benchmark():
# and use it later, because it will look up the name `b` in the # and use it later, because it will look up the name `b` in the
# enclosing scope, and the value of `b` will always be 256. # enclosing scope, and the value of `b` will always be 256.
# it is fine here, because we only use the lambda function once. # it is fine here, because we only use the lambda function once.
runtime = do_bench(lambda: graphs[b][0] # noqa runtime = do_bench(
(input_ids[:b], positions[:b])) # noqa lambda: graphs[b][0]( # noqa
input_ids[:b], positions[:b]
)
) # noqa
piecewise_cudagraph_time[b] = runtime piecewise_cudagraph_time[b] = runtime
else: else:
runtime = do_bench(lambda: graphs[b][0].replay()) # noqa runtime = do_bench(lambda: graphs[b][0].replay()) # noqa
eager_runtime = do_bench( eager_runtime = do_bench(lambda: model(input_ids[:b], positions[:b])) # noqa
lambda: model(input_ids[:b], positions[:b])) # noqa
full_cudagraph_time[b] = runtime full_cudagraph_time[b] = runtime
eager_time[b] = eager_runtime eager_time[b] = eager_runtime
# print in tabular format # print in tabular format
print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph") print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
for b in cudagraph_sizes: for b in cudagraph_sizes:
print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}" print(
f"\t{piecewise_cudagraph_time[b]:.3f}") f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
f"\t{piecewise_cudagraph_time[b]:.3f}"
)
if __name__ == "__main__": if __name__ == "__main__":
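The comment inside benchmark() about the lambda looking up the name `b` in the enclosing scope refers to Python's late binding of closures; a minimal standalone illustration, unrelated to vLLM itself:

fns = [lambda: b for b in range(3)]
print([f() for f in fns])          # [2, 2, 2] -- every lambda sees the final value of b
fns = [lambda b=b: b for b in range(3)]
print([f() for f in fns])          # [0, 1, 2] -- a default argument binds the current value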

View File

@ -31,8 +31,9 @@ def reset_global_counter():
_global_counter = 0 _global_counter = 0
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, def silly_attention(
out: torch.Tensor) -> None: q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
""" """
Unified attention implementation that depends on Unified attention implementation that depends on
all inputs and affects the output. all inputs and affects the output.
@ -47,8 +48,9 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out.copy_(q + k + v) out.copy_(q + k + v)
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, def silly_attention_fake(
out: torch.Tensor) -> None: q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
"""Fake implementation for testing""" """Fake implementation for testing"""
return return
@ -60,5 +62,5 @@ direct_register_custom_op(
mutates_args=["out"], mutates_args=["out"],
fake_impl=silly_attention_fake, fake_impl=silly_attention_fake,
target_lib=silly_lib, target_lib=silly_lib,
tags=(torch._C.Tag.cudagraph_unsafe, ), tags=(torch._C.Tag.cudagraph_unsafe,),
) )
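The same pattern (a mutating custom op plus a fake implementation so torch.compile can trace it) can also be written with the plain torch.library API; a minimal sketch assuming PyTorch 2.4+ and a hypothetical silly::attention namespace, rather than vLLM's direct_register_custom_op helper:

import torch
from torch.library import custom_op

@custom_op("silly::attention", mutates_args=("out",))
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    # real implementation: depends on all inputs, mutates the output in place
    out.copy_(q + k + v)

@silly_attention.register_fake
def _(q, k, v, out) -> None:
    # fake/meta implementation used during tracing; nothing to compute
    return

q = k = v = torch.ones(4)
out = torch.empty(4)
torch.ops.silly.attention(q, k, v, out)   # out == 3.0 everywhere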

View File

@ -8,18 +8,30 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.collective_fusion import AsyncTPPass from vllm.compilation.collective_fusion import AsyncTPPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, from vllm.config import (
PassConfig, VllmConfig) CompilationConfig,
from vllm.distributed import (tensor_model_parallel_all_gather, DeviceConfig,
tensor_model_parallel_reduce_scatter) ModelConfig,
from vllm.distributed.parallel_state import (init_distributed_environment, PassConfig,
initialize_model_parallel) VllmConfig,
)
from vllm.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter,
)
from vllm.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import update_environment_variables from vllm.utils import update_environment_variables
from ..models.registry import HF_EXAMPLE_MODELS from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import (compare_two_settings, create_new_process_for_each_test, from ..utils import (
multi_gpu_test) compare_two_settings,
create_new_process_for_each_test,
multi_gpu_test,
)
from .backend import TestBackend from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
@ -33,14 +45,13 @@ prompts = [
class TestMMRSModel(torch.nn.Module): class TestMMRSModel(torch.nn.Module):
def __init__(self, hidden_size=16, dtype=torch.float16): def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.dtype = dtype self.dtype = dtype
self.gate_proj = torch.nn.Parameter(torch.empty( self.gate_proj = torch.nn.Parameter(
(self.hidden_size * 2, hidden_size)), torch.empty((self.hidden_size * 2, hidden_size)), requires_grad=False
requires_grad=False) )
# Initialize weights # Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02) torch.nn.init.normal_(self.gate_proj, std=0.02)
@ -66,14 +77,13 @@ class TestMMRSModel(torch.nn.Module):
class TestAGMMModel(torch.nn.Module): class TestAGMMModel(torch.nn.Module):
def __init__(self, hidden_size=16, dtype=torch.float16): def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.dtype = dtype self.dtype = dtype
self.weight = torch.nn.Parameter(torch.empty( self.weight = torch.nn.Parameter(
(hidden_size, hidden_size)), torch.empty((hidden_size, hidden_size)), requires_grad=False
requires_grad=False) )
# Initialize weights # Initialize weights
torch.nn.init.normal_(self.weight, std=0.02) torch.nn.init.normal_(self.weight, std=0.02)
@ -96,20 +106,21 @@ class TestAGMMModel(torch.nn.Module):
class _BaseScaledMMModel(torch.nn.Module): class _BaseScaledMMModel(torch.nn.Module):
def __init__(self, hidden_size=16, dtype=torch.float16): def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.dtype = dtype self.dtype = dtype
self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\ self.weight = (
.contiguous().transpose(0, 1) torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
# Initialize scale_b for _scaled_mm. # Initialize scale_b for _scaled_mm.
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32) self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
class TestScaledMMRSModel(_BaseScaledMMModel): class TestScaledMMRSModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor): def forward(self, input: torch.Tensor):
""" """
Forward pass implementing the scaled_mm + reduce scatter in the FX graph Forward pass implementing the scaled_mm + reduce scatter in the FX graph
@ -117,11 +128,13 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
""" """
fp8_input = input.to(FP8_DTYPE) fp8_input = input.to(FP8_DTYPE)
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32) scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
scaled_mm = torch._scaled_mm(fp8_input, scaled_mm = torch._scaled_mm(
self.weight, fp8_input,
scale_a=scale_a, self.weight,
scale_b=self.scale_b, scale_a=scale_a,
out_dtype=self.dtype) scale_b=self.scale_b,
out_dtype=self.dtype,
)
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0) reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
return reduce_scatter return reduce_scatter
@ -133,7 +146,6 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
class TestAGScaledMMModel(_BaseScaledMMModel): class TestAGScaledMMModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor): def forward(self, input: torch.Tensor):
""" """
Forward pass implementing the all gather + scaled_mm in the FX graph Forward pass implementing the all gather + scaled_mm in the FX graph
@ -143,11 +155,13 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0) all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32) scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
scaled_mm = torch._scaled_mm(all_gather, scaled_mm = torch._scaled_mm(
self.weight, all_gather,
scale_a=scale_a, self.weight,
scale_b=self.scale_b, scale_a=scale_a,
out_dtype=self.dtype) scale_b=self.scale_b,
out_dtype=self.dtype,
)
return scaled_mm return scaled_mm
def ops_in_model_before(self): def ops_in_model_before(self):
@ -158,7 +172,6 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
class TestCutlassScaledMMRSModel(_BaseScaledMMModel): class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor): def forward(self, input: torch.Tensor):
""" """
Forward pass implementing the cutlass_scaled_mm + reduce scatter Forward pass implementing the cutlass_scaled_mm + reduce scatter
@ -167,11 +180,14 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
""" """
fp8_input = input.to(FP8_DTYPE) fp8_input = input.to(FP8_DTYPE)
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32) scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]), mm_out = torch.empty(
dtype=self.dtype, (fp8_input.shape[0], self.weight.shape[1]),
device=input.device) dtype=self.dtype,
torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a, device=input.device,
self.scale_b, None) )
torch.ops._C.cutlass_scaled_mm(
mm_out, fp8_input, self.weight, scale_a, self.scale_b, None
)
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0) reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
return reduce_scatter return reduce_scatter
@ -183,7 +199,6 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
class TestAGCutlassScaledMMModel(_BaseScaledMMModel): class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor): def forward(self, input: torch.Tensor):
""" """
Forward pass implementing the all gather + cutlass_scaled_mm Forward pass implementing the all gather + cutlass_scaled_mm
@ -195,11 +210,14 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32) scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]), mm_out = torch.empty(
dtype=self.dtype, (all_gather.shape[0], self.weight.shape[1]),
device=all_gather.device) dtype=self.dtype,
torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight, device=all_gather.device,
scale_a, self.scale_b, None) )
torch.ops._C.cutlass_scaled_mm(
mm_out, all_gather, self.weight, scale_a, self.scale_b, None
)
return mm_out return mm_out
def ops_in_model_before(self): def ops_in_model_before(self):
@ -210,23 +228,37 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model", [ @pytest.mark.parametrize(
TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel, "test_model",
TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel [
]) TestMMRSModel,
TestAGMMModel,
TestScaledMMRSModel,
TestAGScaledMMModel,
TestCutlassScaledMMRSModel,
TestAGCutlassScaledMMModel,
],
)
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
reason="Only test on CUDA") def test_async_tp_pass_replace(
def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int, test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype
hidden_size: int, dtype: torch.dtype): ):
if test_model in (TestScaledMMRSModel, TestAGScaledMMModel, if (
TestCutlassScaledMMRSModel, test_model
TestAGCutlassScaledMMModel) and dtype == torch.float16: in (
TestScaledMMRSModel,
TestAGScaledMMModel,
TestCutlassScaledMMRSModel,
TestAGCutlassScaledMMModel,
)
and dtype == torch.float16
):
pytest.skip( pytest.skip(
"Only bf16 high precision output types are supported for " \ "Only bf16 high precision output types are supported for "
"per-token (row-wise) scaling" "per-token (row-wise) scaling"
) )
@ -235,19 +267,24 @@ def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
def run_torch_spawn(fn, nprocs): def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn, otherwise we will have problems with # need to use torch.mp.spawn, otherwise we will have problems with
# torch.distributed and cuda # torch.distributed and cuda
torch.multiprocessing.spawn(fn, torch.multiprocessing.spawn(
args=(num_processes, test_model, fn,
batch_size, seq_len, hidden_size, args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
dtype), nprocs=nprocs,
nprocs=nprocs) )
run_torch_spawn(async_tp_pass_on_test_model, num_processes) run_torch_spawn(async_tp_pass_on_test_model, num_processes)
def async_tp_pass_on_test_model(local_rank: int, world_size: int, def async_tp_pass_on_test_model(
test_model_cls: torch.nn.Module, local_rank: int,
batch_size: int, seq_len: int, world_size: int,
hidden_size: int, dtype: torch.dtype): test_model_cls: torch.nn.Module,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
):
current_platform.seed_everything(0) current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
@ -255,13 +292,15 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
torch.set_default_device(device) torch.set_default_device(device)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
update_environment_variables({ update_environment_variables(
'RANK': str(local_rank), {
'LOCAL_RANK': str(local_rank), "RANK": str(local_rank),
'WORLD_SIZE': str(world_size), "LOCAL_RANK": str(local_rank),
'MASTER_ADDR': 'localhost', "WORLD_SIZE": str(world_size),
'MASTER_PORT': '12345', "MASTER_ADDR": "localhost",
}) "MASTER_PORT": "12345",
}
)
# initialize distributed # initialize distributed
init_distributed_environment() init_distributed_environment()
@ -269,27 +308,28 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
# configure vllm config for AsyncTPPass # configure vllm config for AsyncTPPass
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( vllm_config.compilation_config = CompilationConfig(
enable_async_tp=True, ), ) pass_config=PassConfig(
enable_async_tp=True,
),
)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config # this is a fake model name to construct the model config
# in the vllm_config, it's not really used. # in the vllm_config, it's not really used.
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config.model_config = ModelConfig(model=model_name, vllm_config.model_config = ModelConfig(
trust_remote_code=True, model=model_name, trust_remote_code=True, dtype=dtype, seed=42
dtype=dtype, )
seed=42)
async_tp_pass = AsyncTPPass(vllm_config) async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass) backend = TestBackend(async_tp_pass)
model = test_model_cls(hidden_size, model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
dtype) # Pass dtype to model constructor
hidden_states = torch.randn((batch_size * seq_len, hidden_size), hidden_states = torch.randn(
dtype=dtype, (batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
requires_grad=False) )
compiled_model = torch.compile(model, backend=backend) compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states) compiled_model(hidden_states)
@ -306,10 +346,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
@create_new_process_for_each_test() @create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", [ @pytest.mark.parametrize(
"meta-llama/Llama-3.2-1B-Instruct", "model_id",
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" ["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"],
]) )
@pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("async_tp_enabled", [True]) @pytest.mark.parametrize("async_tp_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"]) @pytest.mark.parametrize("distributed_backend", ["mp"])
@ -342,12 +382,10 @@ def test_async_tp_pass_correctness(
common_args.append("--enforce-eager") common_args.append("--enforce-eager")
compilation_config = { compilation_config = {
'level': 3, "level": 3,
'compile_sizes': [2, 4, 8], "compile_sizes": [2, 4, 8],
'splitting_ops': [], "splitting_ops": [],
'pass_config': { "pass_config": {"enable_async_tp": async_tp_enabled},
'enable_async_tp': async_tp_enabled
},
} }
async_tp_env = tp_env = { async_tp_env = tp_env = {
@ -372,9 +410,6 @@ def test_async_tp_pass_correctness(
"mp", "mp",
] ]
compare_two_settings(model_id, compare_two_settings(
async_tp_args, model_id, async_tp_args, tp_args, async_tp_env, tp_env, method="generate"
tp_args, )
async_tp_env,
tp_env,
method="generate")

View File

@ -103,23 +103,28 @@ def test_compile_correctness(
attn_backend = test_setting.attn_backend attn_backend = test_setting.attn_backend
method = test_setting.method method = test_setting.method
if cuda_device_count_stateless() < pp_size * tp_size: if cuda_device_count_stateless() < pp_size * tp_size:
pytest.skip(f"Need at least {pp_size}*{tp_size} CUDA gpus but got " pytest.skip(
f"{cuda_device_count_stateless()}") f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
f"{cuda_device_count_stateless()}"
)
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
final_args = [ final_args = [
"--enforce-eager", *model_args, "-pp", "--enforce-eager",
str(pp_size), "-tp", *model_args,
str(tp_size) "-pp",
str(pp_size),
"-tp",
str(tp_size),
] ]
all_args: list[list[str]] = [] all_args: list[list[str]] = []
all_envs: list[dict[str, str] | None] = [] all_envs: list[dict[str, str] | None] = []
for level in [ for level in [
CompilationLevel.NO_COMPILATION, CompilationLevel.NO_COMPILATION,
CompilationLevel.PIECEWISE, CompilationLevel.PIECEWISE,
]: ]:
all_args.append(final_args + [f"-O{level}"]) all_args.append(final_args + [f"-O{level}"])
all_envs.append({}) all_envs.append({})
@ -130,14 +135,15 @@ def test_compile_correctness(
model, model,
all_args, all_args,
all_envs, all_envs,
method=method if method != "generate" else "generate_close") method=method if method != "generate" else "generate_close",
)
all_envs.clear() all_envs.clear()
all_args.clear() all_args.clear()
for level in [ for level in [
CompilationLevel.NO_COMPILATION, CompilationLevel.NO_COMPILATION,
CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE, CompilationLevel.DYNAMO_ONCE,
]: ]:
all_args.append(final_args + [f"-O{level}"]) all_args.append(final_args + [f"-O{level}"])
all_envs.append({}) all_envs.append({})

View File

@ -9,11 +9,11 @@ from vllm.utils import _is_torch_equal_or_newer
def test_version(): def test_version():
assert _is_torch_equal_or_newer('2.8.0.dev20250624+cu128', '2.8.0.dev') assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
assert _is_torch_equal_or_newer('2.8.0a0+gitc82a174', '2.8.0.dev') assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
assert _is_torch_equal_or_newer('2.8.0', '2.8.0.dev') assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
assert _is_torch_equal_or_newer('2.8.1', '2.8.0.dev') assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev")
assert not _is_torch_equal_or_newer('2.7.1', '2.8.0.dev') assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
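These assertions follow PEP 440 ordering, where a .dev release sorts before pre-releases and final releases of the same version; a rough equivalent using the packaging library (the helper's actual implementation may differ):

from packaging.version import Version

assert Version("2.8.0.dev20250624+cu128") >= Version("2.8.0.dev")
assert Version("2.8.0a0+gitc82a174") >= Version("2.8.0.dev")   # alpha sorts after .dev
assert Version("2.8.1") >= Version("2.8.0.dev")
assert not Version("2.7.1") >= Version("2.8.0.dev")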
def test_use_cudagraphs_dynamic(monkeypatch): def test_use_cudagraphs_dynamic(monkeypatch):
@ -21,7 +21,7 @@ def test_use_cudagraphs_dynamic(monkeypatch):
vllm_config = VllmConfig() vllm_config = VllmConfig()
assert vllm_config.compilation_config.use_cudagraph assert vllm_config.compilation_config.use_cudagraph
monkeypatch.setenv('VLLM_USE_V1', '0') monkeypatch.setenv("VLLM_USE_V1", "0")
vllm_config = VllmConfig() vllm_config = VllmConfig()
assert not vllm_config.compilation_config.use_cudagraph assert not vllm_config.compilation_config.use_cudagraph
@ -44,19 +44,23 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
assert vllm.envs.VLLM_USE_V1 assert vllm.envs.VLLM_USE_V1
# Disable multiprocessing so that the counter is in the same process # Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val) monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)
compilation_config = { compilation_config = {
"use_cudagraph": False, # speed things up a bit "use_cudagraph": False, # speed things up a bit
} }
with ( with (
compilation_counter.expect(num_cache_entries_updated=0, compilation_counter.expect(
num_compiled_artifacts_saved=0), num_cache_entries_updated=0, num_compiled_artifacts_saved=0
# loading the model causes compilation (if enabled) to happen ),
vllm_runner('facebook/opt-125m', # loading the model causes compilation (if enabled) to happen
compilation_config=compilation_config, vllm_runner(
gpu_memory_utilization=0.4) as _): "facebook/opt-125m",
compilation_config=compilation_config,
gpu_memory_utilization=0.4,
) as _,
):
pass pass
@ -67,22 +71,25 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
assert vllm.envs.VLLM_USE_V1 assert vllm.envs.VLLM_USE_V1
# Disable multiprocessing so that the counter is in the same process # Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
compilation_config = { compilation_config = {
"cudagraph_capture_sizes": [100], "cudagraph_capture_sizes": [100],
"use_cudagraph": enabled, "use_cudagraph": enabled,
} }
with ( with (
compilation_counter.expect( compilation_counter.expect(
num_graphs_seen=1, num_graphs_seen=1,
num_gpu_runner_capture_triggers=1 if enabled else 0, num_gpu_runner_capture_triggers=1 if enabled else 0,
num_cudagraph_captured=13 if enabled else 0, num_cudagraph_captured=13 if enabled else 0,
), ),
# loading the model causes compilation (if enabled) to happen # loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m', vllm_runner(
compilation_config=compilation_config, "facebook/opt-125m",
gpu_memory_utilization=0.4) as _): compilation_config=compilation_config,
gpu_memory_utilization=0.4,
) as _,
):
pass pass
@ -90,14 +97,17 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
@pytest.mark.forked @pytest.mark.forked
def test_dynamo_as_is(vllm_runner, monkeypatch): def test_dynamo_as_is(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process # Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with ( with (
compilation_counter.expect(dynamo_as_is_count=1), compilation_counter.expect(dynamo_as_is_count=1),
# loading the model causes compilation (if enabled) to happen # loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m', vllm_runner(
compilation_config={"level": 1}, "facebook/opt-125m",
gpu_memory_utilization=0.4) as _): compilation_config={"level": 1},
gpu_memory_utilization=0.4,
) as _,
):
pass pass
@ -105,14 +115,16 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
@pytest.mark.forked @pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch): def test_no_compilation(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process # Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with ( with (
compilation_counter.expect(num_graphs_seen=0, compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
dynamo_as_is_count=0), # loading the model causes compilation (if enabled) to happen
# loading the model causes compilation (if enabled) to happen vllm_runner(
vllm_runner('facebook/opt-125m', "facebook/opt-125m",
compilation_config={"level": 0}, compilation_config={"level": 0},
gpu_memory_utilization=0.4) as _): gpu_memory_utilization=0.4,
) as _,
):
pass pass
@ -120,77 +132,73 @@ def test_no_compilation(vllm_runner, monkeypatch):
@pytest.mark.forked @pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch): def test_enforce_eager(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process # Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with ( with (
compilation_counter.expect(num_graphs_seen=0, compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
dynamo_as_is_count=0), # loading the model causes compilation (if enabled) to happen
# loading the model causes compilation (if enabled) to happen vllm_runner(
vllm_runner('facebook/opt-125m', "facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
enforce_eager=True, ) as _,
gpu_memory_utilization=0.4) as _): ):
pass pass
def test_splitting_ops_dynamic(): def test_splitting_ops_dynamic():
# Default config # Default config
config = VllmConfig() config = VllmConfig()
assert config.compilation_config.cudagraph_mode == \ assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
CUDAGraphMode.FULL_AND_PIECEWISE
assert config.compilation_config.splitting_ops_contain_attention() assert config.compilation_config.splitting_ops_contain_attention()
# When use_inductor_graph_partition=True # When use_inductor_graph_partition=True
if _is_torch_equal_or_newer('2.9.0.dev'): if _is_torch_equal_or_newer("2.9.0.dev"):
# inductor graph partition is only available in PyTorch 2.9+. # inductor graph partition is only available in PyTorch 2.9+.
# this is a fast config check so we are not using pytest.skip. # this is a fast config check so we are not using pytest.skip.
config = VllmConfig(compilation_config=CompilationConfig( config = VllmConfig(
use_inductor_graph_partition=True, compilation_config=CompilationConfig(
splitting_ops=["silly_attention"])) use_inductor_graph_partition=True, splitting_ops=["silly_attention"]
)
)
# should ignore splitting_ops # should ignore splitting_ops
assert config.compilation_config.splitting_ops == [] assert config.compilation_config.splitting_ops == []
# When attn_fusion pass enabled. # When attn_fusion pass enabled.
config = VllmConfig(compilation_config=CompilationConfig( config = VllmConfig(
pass_config={ compilation_config=CompilationConfig(
"enable_attn_fusion": True, pass_config={"enable_attn_fusion": True, "enable_noop": True},
"enable_noop": True custom_ops=["+quant_fp8"],
}, cudagraph_mode=CUDAGraphMode.PIECEWISE,
custom_ops=["+quant_fp8"], )
cudagraph_mode=CUDAGraphMode.PIECEWISE, )
))
assert config.compilation_config.splitting_ops == [] assert config.compilation_config.splitting_ops == []
# cudagraph mode also falls back to FULL # cudagraph mode also falls back to FULL
assert config.compilation_config.cudagraph_mode == \ assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
CUDAGraphMode.FULL
# splitting_ops cannot contain attention ops when the attn_fusion # splitting_ops cannot contain attention ops when the attn_fusion
# pass is enabled. # pass is enabled.
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = VllmConfig(compilation_config=CompilationConfig( config = VllmConfig(
pass_config={ compilation_config=CompilationConfig(
"enable_attn_fusion": True, pass_config={"enable_attn_fusion": True, "enable_noop": True},
"enable_noop": True custom_ops=["+quant_fp8"],
}, cudagraph_mode=CUDAGraphMode.PIECEWISE,
custom_ops=["+quant_fp8"], # work around for accessing all attntion ops
cudagraph_mode=CUDAGraphMode.PIECEWISE, # workaround for accessing all attention ops
# workaround for accessing all attention ops splitting_ops=CompilationConfig()._attention_ops,
splitting_ops=CompilationConfig()._attention_ops, )
))
# When both use_inductor_graph_partition and attn_fusion pass enabled. # When both use_inductor_graph_partition and attn_fusion pass enabled.
if _is_torch_equal_or_newer('2.9.0.dev'): if _is_torch_equal_or_newer("2.9.0.dev"):
config = VllmConfig(compilation_config=CompilationConfig( config = VllmConfig(
use_inductor_graph_partition=True, compilation_config=CompilationConfig(
pass_config={ use_inductor_graph_partition=True,
"enable_attn_fusion": True, pass_config={"enable_attn_fusion": True, "enable_noop": True},
"enable_noop": True custom_ops=["+quant_fp8"],
}, cudagraph_mode=CUDAGraphMode.PIECEWISE,
custom_ops=["+quant_fp8"], )
cudagraph_mode=CUDAGraphMode.PIECEWISE, )
))
assert config.compilation_config.splitting_ops == [] assert config.compilation_config.splitting_ops == []
# enable_attn_fusion is directly supported under # enable_attn_fusion is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode # use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged. # is unchanged.
assert config.compilation_config.cudagraph_mode == \ assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
CUDAGraphMode.PIECEWISE

View File

@ -4,10 +4,15 @@ import torch
from torch import nn from torch import nn
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile, from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
support_torch_compile) from vllm.config import (
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, CacheConfig,
CUDAGraphMode, VllmConfig, set_current_vllm_config) CompilationConfig,
CompilationLevel,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
# This import automatically registers `torch.ops.silly.attention` # This import automatically registers `torch.ops.silly.attention`
@ -18,32 +23,42 @@ MLP_SIZE = 128
@torch.inference_mode @torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, def run_model(
cudagraph_runtime_mode: CUDAGraphMode): vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode
):
with set_forward_context({}, vllm_config=vllm_config): with set_forward_context({}, vllm_config=vllm_config):
# warmup for the model with cudagraph_mode NONE # warmup for the model with cudagraph_mode NONE
model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
# simulate cudagraphs capturing # simulate cudagraphs capturing
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(torch.randn(2, MLP_SIZE).cuda()) model(torch.randn(2, MLP_SIZE).cuda())
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=1, )): batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(torch.randn(1, MLP_SIZE).cuda()) model(torch.randn(1, MLP_SIZE).cuda())
# simulate cudagraphs replay # simulate cudagraphs replay
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(torch.randn(2, MLP_SIZE).cuda()) output = model(torch.randn(2, MLP_SIZE).cuda())
output = output.cpu() output = output.cpu()
@ -52,22 +67,21 @@ def run_model(vllm_config: VllmConfig, model: nn.Module,
def test_ignore_torch_compile_decorator(): def test_ignore_torch_compile_decorator():
# piecewise # piecewise
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(
level=CompilationLevel.PIECEWISE, compilation_config=CompilationConfig(
use_cudagraph=True, level=CompilationLevel.PIECEWISE,
splitting_ops=["silly.attention"], use_cudagraph=True,
cudagraph_capture_sizes=[1, 2], splitting_ops=["silly.attention"],
)) cudagraph_capture_sizes=[1, 2],
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@support_torch_compile @support_torch_compile
class A(nn.Module): class A(nn.Module):
def __init__(
def __init__(self, self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs
*, ) -> None:
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__() super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -79,66 +93,60 @@ def test_ignore_torch_compile_decorator():
return x return x
@ignore_torch_compile @ignore_torch_compile
class B(A): class B(A): ...
...
@support_torch_compile @support_torch_compile
class C(B): class C(B): ...
...
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
# A has support_torch_compile # A has support_torch_compile
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=1, num_graphs_seen=1,
num_piecewise_graphs_seen=3, num_piecewise_graphs_seen=3,
num_piecewise_capturable_graphs_seen=2, num_piecewise_capturable_graphs_seen=2,
num_backend_compilations=2, num_backend_compilations=2,
num_cudagraph_captured=4, num_cudagraph_captured=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
): ):
run_model(vllm_config, mod_A, cudagraph_runtime_mode) run_model(vllm_config, mod_A, cudagraph_runtime_mode)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() mod_B = B(vllm_config=vllm_config, prefix="").eval().cuda()
# B's ignore_torch_compile should override A's support_torch_compile # B's ignore_torch_compile should override A's support_torch_compile
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=0, num_graphs_seen=0,
num_piecewise_graphs_seen=0, num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0, num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0, num_backend_compilations=0,
num_cudagraph_captured=0, num_cudagraph_captured=0,
): ):
run_model(vllm_config, mod_B, cudagraph_runtime_mode) run_model(vllm_config, mod_B, cudagraph_runtime_mode)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() mod_C = C(vllm_config=vllm_config, prefix="").eval().cuda()
# C's support_torch_compile should override B's ignore_torch_compile # C's support_torch_compile should override B's ignore_torch_compile
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=1, num_graphs_seen=1,
num_piecewise_graphs_seen=3, num_piecewise_graphs_seen=3,
num_piecewise_capturable_graphs_seen=2, num_piecewise_capturable_graphs_seen=2,
num_backend_compilations=2, num_backend_compilations=2,
num_cudagraph_captured=4, num_cudagraph_captured=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
): ):
run_model(vllm_config, mod_C, cudagraph_runtime_mode) run_model(vllm_config, mod_C, cudagraph_runtime_mode)
# Only enable torch.compile if # Only enable torch.compile if
# vllm_config.cache_config.kv_sharing_fast_prefill=True # vllm_config.cache_config.kv_sharing_fast_prefill=True
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. @support_torch_compile(
kv_sharing_fast_prefill) enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class B(nn.Module): class B(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__() super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -152,15 +160,11 @@ class B(nn.Module):
# Only enable torch.compile if # Only enable torch.compile if
# vllm_config.cache_config.kv_sharing_fast_prefill=False # vllm_config.cache_config.kv_sharing_fast_prefill=False
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config. @support_torch_compile(
cache_config.kv_sharing_fast_prefill) enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
)
class A(nn.Module): class A(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__() super().__init__()
self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs) self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs) self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
@ -175,54 +179,60 @@ class A(nn.Module):
def test_conditional_compile_enable_if(): def test_conditional_compile_enable_if():
vllm_config = VllmConfig(cache_config=CacheConfig( vllm_config = VllmConfig(
kv_sharing_fast_prefill=True, ), cache_config=CacheConfig(
compilation_config=CompilationConfig( kv_sharing_fast_prefill=True,
level=CompilationLevel.PIECEWISE, ),
use_cudagraph=True, compilation_config=CompilationConfig(
splitting_ops=["silly.attention"], level=CompilationLevel.PIECEWISE,
cudagraph_capture_sizes=[1, 2], use_cudagraph=True,
)) splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
),
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
# A has support_torch_compile but enable_if fn returns False # A has support_torch_compile but enable_if fn returns False
# enable_if will be True for B, so we expect mod1 and mod2 # enable_if will be True for B, so we expect mod1 and mod2
# to be compiled # to be compiled
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=2, num_graphs_seen=2,
num_piecewise_graphs_seen=6, num_piecewise_graphs_seen=6,
# 3 piecewise graphs per instance of B() # 3 piecewise graphs per instance of B()
num_piecewise_capturable_graphs_seen=4, num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4, num_backend_compilations=4,
num_cudagraph_captured=8, num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
): ):
run_model(vllm_config, mod_A, cudagraph_runtime_mode) run_model(vllm_config, mod_A, cudagraph_runtime_mode)
# Set kv_sharing_fast_prefill=False # Set kv_sharing_fast_prefill=False
# which will cause A to be compiled and B to not be compiled # which will cause A to be compiled and B to not be compiled
vllm_config = VllmConfig(cache_config=CacheConfig( vllm_config = VllmConfig(
kv_sharing_fast_prefill=False, ), cache_config=CacheConfig(
compilation_config=CompilationConfig( kv_sharing_fast_prefill=False,
level=CompilationLevel.PIECEWISE, ),
use_cudagraph=True, compilation_config=CompilationConfig(
splitting_ops=["silly.attention"], level=CompilationLevel.PIECEWISE,
cudagraph_capture_sizes=[1, 2], use_cudagraph=True,
)) splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
),
)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=1, num_graphs_seen=1,
num_piecewise_graphs_seen=7, num_piecewise_graphs_seen=7,
# 3 attn ops and 4 non-attn ops # 3 attn ops and 4 non-attn ops
num_piecewise_capturable_graphs_seen=4, num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4, num_backend_compilations=4,
num_cudagraph_captured=8, num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
): ):
run_model(vllm_config, mod_A, cudagraph_runtime_mode) run_model(vllm_config, mod_A, cudagraph_runtime_mode)
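
The pattern running through the hunks above is the core of the yapf-to-ruff migration: yapf wraps a long call by aligning every argument under the opening parenthesis, while ruff format (Black style) breaks the call onto its own lines with a four-space hanging indent and a trailing comma. The sketch below is illustrative only; make_config is a hypothetical stand-in, not an API from this repository, and in practice ruff only splits calls that exceed the configured line length.

def make_config(**kwargs: object) -> dict[str, object]:
    # Hypothetical constructor used purely to demonstrate the two layouts.
    return dict(kwargs)

# Old yapf-style continuation: arguments aligned under the opening parenthesis.
config = make_config(level="piecewise",
                     use_cudagraph=True,
                     capture_sizes=[1, 2])

# New ruff-format layout: hanging indent plus a magic trailing comma, which
# keeps one argument per line on subsequent edits.
config = make_config(
    level="piecewise",
    use_cudagraph=True,
    capture_sizes=[1, 2],
)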


@ -14,8 +14,7 @@ from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig
PassConfig)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer from vllm.utils import is_torch_equal_or_newer
@ -25,43 +24,54 @@ from ..utils import create_new_process_for_each_test
def models_list(*, all: bool = True, keywords: Optional[list[str]] = None): def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [ TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
("facebook/opt-125m", {}), ("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { (
"dtype": torch.float16, "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
}), {
("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", { "dtype": torch.float16,
"dtype": torch.float16, },
}), ),
(
"neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
{
"dtype": torch.float16,
},
),
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
("meta-llama/Llama-3.2-1B-Instruct", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}),
] ]
if all: if all:
# TODO: figure out why this fails. # TODO: figure out why this fails.
if False and is_quant_method_supported("gguf"): # noqa: SIM223 if False and is_quant_method_supported("gguf"): # noqa: SIM223
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", { TEST_MODELS.append(
"quantization": "gguf" ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {"quantization": "gguf"})
})) )
if is_quant_method_supported("gptq"): if is_quant_method_supported("gptq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", { TEST_MODELS.append(
"quantization": "gptq" ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {"quantization": "gptq"})
})) )
if is_quant_method_supported("gptq_marlin"): if is_quant_method_supported("gptq_marlin"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", { TEST_MODELS.append(
"quantization": "gptq_marlin" (
})) "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
{"quantization": "gptq_marlin"},
)
)
if is_quant_method_supported("gptq_marlin_24"): if is_quant_method_supported("gptq_marlin_24"):
TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", { TEST_MODELS.append(
"quantization": "gptq_marlin_24" (
})) "alexm-nm/tinyllama-24-marlin24-4bit-g128",
{"quantization": "gptq_marlin_24"},
)
)
if not current_platform.is_rocm() and is_quant_method_supported("awq"): if not current_platform.is_rocm() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { TEST_MODELS.append(
"quantization": "AWQ" ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"})
})) )
if keywords is None: if keywords is None:
return TEST_MODELS return TEST_MODELS
@ -95,22 +105,34 @@ def test_full_graph(
"compilation_config, model_info", "compilation_config, model_info",
[ [
# additional compile sizes, only some of the models # additional compile sizes, only some of the models
(CompilationConfig(level=CompilationLevel.PIECEWISE, (
compile_sizes=[1, 2]), model) CompilationConfig(level=CompilationLevel.PIECEWISE, compile_sizes=[1, 2]),
model,
)
for model in models_list(all=False) for model in models_list(all=False)
] + [ ]
+ [
# RMSNorm + quant fusion, only 8-bit quant models # RMSNorm + quant fusion, only 8-bit quant models
(CompilationConfig(level=CompilationLevel.PIECEWISE, (
custom_ops=["+rms_norm"], CompilationConfig(
pass_config=PassConfig(enable_fusion=True, level=CompilationLevel.PIECEWISE,
enable_noop=True)), model) custom_ops=["+rms_norm"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
),
model,
)
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
] + [ ]
+ [
# Test depyf integration works # Test depyf integration works
(CompilationConfig(level=CompilationLevel.PIECEWISE, (
debug_dump_path=tempfile.gettempdir()), CompilationConfig(
("facebook/opt-125m", {})), level=CompilationLevel.PIECEWISE, debug_dump_path=tempfile.gettempdir()
] + [ ),
("facebook/opt-125m", {}),
),
]
+ [
# graph inductor partition # graph inductor partition
( (
CompilationConfig( CompilationConfig(
@ -119,20 +141,24 @@ def test_full_graph(
# torch._C.Tag.cudagraph_unsafe to specify splitting ops # torch._C.Tag.cudagraph_unsafe to specify splitting ops
use_inductor_graph_partition=True, use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE, cudagraph_mode=CUDAGraphMode.PIECEWISE,
compile_sizes=[1, 2]), compile_sizes=[1, 2],
model) for model in models_list(all=False) ),
model,
)
for model in models_list(all=False)
if is_torch_equal_or_newer("2.9.0.dev") if is_torch_equal_or_newer("2.9.0.dev")
]) ],
)
# only test some of the models # only test some of the models
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_custom_compile_config( def test_custom_compile_config(
compilation_config: CompilationConfig, compilation_config: CompilationConfig,
model_info: tuple[str, dict[str, Any]], model_info: tuple[str, dict[str, Any]],
): ):
if (compilation_config.use_inductor_graph_partition if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
and not is_torch_equal_or_newer("2.9.0.dev")): "2.9.0.dev"
pytest.skip("inductor graph partition is only available " ):
"in PyTorch 2.9+") pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
model, model_kwargs = model_info model, model_kwargs = model_info
print(f"MODEL={model}") print(f"MODEL={model}")
@ -156,8 +182,7 @@ def test_fp8_kv_scale_compile(optimization_level: int):
def test_inductor_graph_partition_attn_fusion(caplog_vllm): def test_inductor_graph_partition_attn_fusion(caplog_vllm):
if not is_torch_equal_or_newer("2.9.0.dev"): if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available " pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
"in PyTorch 2.9+")
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
@ -171,14 +196,16 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm):
"kv_cache_dtype": "fp8", "kv_cache_dtype": "fp8",
"max_model_len": 1024, "max_model_len": 1024,
} }
with caplog_vllm.at_level( with (
logging.DEBUG), global_force_attn_backend_context_manager( caplog_vllm.at_level(logging.DEBUG),
_Backend.FLASHINFER): global_force_attn_backend_context_manager(_Backend.FLASHINFER),
):
run_model(compilation_config, model, model_kwargs) run_model(compilation_config, model, model_kwargs)
try: try:
assert ("Fused quantization onto 48 attention nodes" assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, (
in caplog_vllm.text), caplog_vllm.text caplog_vllm.text
)
except AssertionError: except AssertionError:
# Note: this message is only triggered when the compilation goes # Note: this message is only triggered when the compilation goes
# through the custom pass. Due to multiple layers of cache on # through the custom pass. Due to multiple layers of cache on
@ -189,8 +216,11 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm):
assert "Fused quantization" not in caplog_vllm.text assert "Fused quantization" not in caplog_vllm.text
def run_model(compile_config: Union[int, CompilationConfig], model: str, def run_model(
model_kwargs: dict[str, Any]): compile_config: Union[int, CompilationConfig],
model: str,
model_kwargs: dict[str, Any],
):
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",


@ -14,10 +14,8 @@ from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -28,7 +26,6 @@ FP8_DTYPE = current_platform.fp8_dtype()
class TestSiluMul(torch.nn.Module): class TestSiluMul(torch.nn.Module):
def __init__(self, hidden_size: int = 128): def __init__(self, hidden_size: int = 128):
super().__init__() super().__init__()
self.silu_and_mul = SiluAndMul() self.silu_and_mul = SiluAndMul()
@ -36,8 +33,7 @@ class TestSiluMul(torch.nn.Module):
self.scale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32)
if TEST_FP8: if TEST_FP8:
self.w = torch.rand(hidden_size, self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = Fp8LinearOp( self.fp8_linear = Fp8LinearOp(
act_quant_static=True, act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR, act_quant_group_shape=GroupShape.PER_TENSOR,
@ -46,17 +42,14 @@ class TestSiluMul(torch.nn.Module):
def forward(self, x): def forward(self, x):
y = self.silu_and_mul(x) y = self.silu_and_mul(x)
if TEST_FP8: if TEST_FP8:
x2 = self.fp8_linear.apply(y, x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
self.w,
self.wscale,
input_scale=self.wscale)
return x2 return x2
else: else:
return y return y
def example_inputs(self, num_tokens=32, hidden_size=128): def example_inputs(self, num_tokens=32, hidden_size=128):
dtype = torch.float16 if TEST_FP8 else torch.float32 dtype = torch.float16 if TEST_FP8 else torch.float32
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), ) return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
def ops_in_model(self, do_fusion): def ops_in_model(self, do_fusion):
if TEST_FP8 and do_fusion: if TEST_FP8 and do_fusion:
@ -69,7 +62,6 @@ class TestSiluMul(torch.nn.Module):
class TestFusedAddRMSNorm(torch.nn.Module): class TestFusedAddRMSNorm(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32): def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -78,10 +70,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
dtype = torch.float16 if TEST_FP8 else torch.float32 dtype = torch.float16 if TEST_FP8 else torch.float32
self.gate_proj = torch.nn.Parameter( self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size), dtype=dtype)) torch.empty((intermediate_size, hidden_size), dtype=dtype)
)
self.norm = RMSNorm(intermediate_size, 1e-05) self.norm = RMSNorm(intermediate_size, 1e-05)
self.norm.weight = torch.nn.Parameter( self.norm.weight = torch.nn.Parameter(
torch.ones(intermediate_size, dtype=dtype)) torch.ones(intermediate_size, dtype=dtype)
)
torch.nn.init.normal_(self.gate_proj, std=0.02) torch.nn.init.normal_(self.gate_proj, std=0.02)
@ -89,8 +83,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
self.fp8_linear = Fp8LinearOp(act_quant_static=True) self.fp8_linear = Fp8LinearOp(act_quant_static=True)
self.scale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32)
self.w = torch.rand(hidden_size, self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32) self.wscale = torch.rand(1, dtype=torch.float32)
def forward(self, hidden_states, residual): def forward(self, hidden_states, residual):
@ -120,10 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16): def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
dtype = torch.float16 if TEST_FP8 else torch.float32 dtype = torch.float16 if TEST_FP8 else torch.float32
hidden_states = torch.randn((batch_size * seq_len, hidden_size), hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
dtype=dtype) residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype)
return (hidden_states, residual) return (hidden_states, residual)
def ops_in_model(self, do_fusion): def ops_in_model(self, do_fusion):
@ -137,12 +128,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
class TestRotaryEmbedding(torch.nn.Module): class TestRotaryEmbedding(torch.nn.Module):
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
def __init__(self,
head_dim=64,
rotary_dim=None,
max_position=2048,
base=10000):
super().__init__() super().__init__()
self.head_dim = head_dim self.head_dim = head_dim
self.rotary_dim = rotary_dim or head_dim self.rotary_dim = rotary_dim or head_dim
@ -173,21 +159,15 @@ class TestRotaryEmbedding(torch.nn.Module):
class TestRotaryEmbeddingSliceScatter(torch.nn.Module): class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
def __init__(self,
head_dim=64,
num_heads=4,
max_position=2048,
base=10000):
super().__init__() super().__init__()
self.head_dim = head_dim self.head_dim = head_dim
self.num_heads = num_heads self.num_heads = num_heads
self.hidden_size = head_dim * num_heads self.hidden_size = head_dim * num_heads
self.qkv_proj = torch.nn.Linear(self.hidden_size, self.qkv_proj = torch.nn.Linear(
self.hidden_size * 3, self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
bias=False, )
dtype=torch.float16)
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
@ -233,21 +213,24 @@ MODELS = [
@pytest.mark.parametrize("model_class", MODELS) @pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False]) @pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
reason="Only test on CUDA")
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool): def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
torch.set_default_device("cuda") torch.set_default_device("cuda")
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig( vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)) pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
)
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config) fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass] passes = (
if do_fusion else [noop_pass, cleanup_pass]) [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion
else [noop_pass, cleanup_pass]
)
func_pass = FixFunctionalizationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass) backend_func = TestBackend(*passes, func_pass)
@ -260,8 +243,7 @@ def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
# check if the functionalization pass is applied # check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion): for op in model.ops_in_model(do_fusion):
find_auto_fn(backend_no_func.graph_post_pass.nodes, op) find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501
is None) # noqa: E501
# make sure the ops were all de-functionalized # make sure the ops were all de-functionalized
found = dict() found = dict()
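
The file above also shows how ruff rewrites expressions that no longer fit on one line: a conditional expression is kept inside a single pair of parentheses with each branch on its own line, and a long assert gets its message in a parenthesized block instead of a backslash continuation. A minimal sketch under those assumptions, with hypothetical variable names:

do_fusion = True
available_passes = ["noop", "fusion", "act_quant_fusion", "post_cleanup"]

# Conditional expression wrapped as one parenthesized block, one branch per
# line (ruff only does this when the expression exceeds the line length).
selected_passes = (
    available_passes
    if do_fusion
    else [available_passes[0], available_passes[-1]]
)

# A long assert message moves into its own parenthesized block.
assert len(selected_passes) > 0, (
    f"expected at least one compilation pass, got {selected_passes!r}"
)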


@ -5,17 +5,26 @@ import pytest
import torch import torch
import vllm.plugins import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, from vllm.compilation.fusion import (
RMSNormQuantFusionPass) FUSED_OPS,
QUANT_OPS,
FusedRMSQuantKey,
RMSNormQuantFusionPass,
)
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, QuantKey, ScaleDesc) GroupShape,
QuantKey,
ScaleDesc,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity) Fp8LinearOp,
cutlass_fp8_supported,
maybe_create_device_identity,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import override_cutlass_fp8_supported from ..utils import override_cutlass_fp8_supported
@ -25,9 +34,15 @@ FP8_DTYPE = current_platform.fp8_dtype()
class TestModel(torch.nn.Module): class TestModel(torch.nn.Module):
def __init__(
def __init__(self, hidden_size: int, eps: float, static: bool, self,
cuda_force_torch: bool, *args, **kwargs): hidden_size: int,
eps: float,
static: bool,
cuda_force_torch: bool,
*args,
**kwargs,
):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cuda_force_torch = cuda_force_torch self.cuda_force_torch = cuda_force_torch
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
@ -54,17 +69,15 @@ class TestModel(torch.nn.Module):
resid = torch.sqrt(x) resid = torch.sqrt(x)
y = self.norm[0](x) y = self.norm[0](x)
x2 = self.fp8_linear.apply(y, x2 = self.fp8_linear.apply(
self.w[0], y, self.w[0], self.wscale[0], input_scale=self.scale[0]
self.wscale[0], )
input_scale=self.scale[0])
# make sure resid is used for replacement to work # make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid) y2, resid = self.norm[1](x2, resid)
x3 = self.fp8_linear.apply(y2, x3 = self.fp8_linear.apply(
self.w[1], y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
self.wscale[1], )
input_scale=self.scale[1])
y3, resid = self.norm[2](x3, resid) # use resid here y3, resid = self.norm[2](x3, resid) # use resid here
return y3 return y3
@ -74,7 +87,7 @@ class TestModel(torch.nn.Module):
def ops_in_model_after(self): def ops_in_model_after(self):
return [ return [
FUSED_OPS[FusedRMSQuantKey(self.key, False)], FUSED_OPS[FusedRMSQuantKey(self.key, False)],
FUSED_OPS[FusedRMSQuantKey(self.key, True)] FUSED_OPS[FusedRMSQuantKey(self.key, True)],
] ]
@ -85,22 +98,27 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("static", [True, False]) @pytest.mark.parametrize("static", [True, False])
# cuda_force_torch used to test torch code path on platforms that # cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True. # cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch", @pytest.mark.parametrize(
[True, False] if cutlass_fp8_supported() else [True]) "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
@pytest.mark.skipif(not current_platform.is_cuda_alike(), )
reason="Only test on CUDA and ROCm") @pytest.mark.skipif(
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
cuda_force_torch): )
def test_fusion_rmsnorm_quant(
dtype, hidden_size, num_tokens, eps, static, cuda_force_torch
):
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(1) torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(
level=CompilationLevel.PIECEWISE, compilation_config=CompilationConfig(
custom_ops=["+rms_norm", "+quant_fp8"], level=CompilationLevel.PIECEWISE,
pass_config=PassConfig(enable_fusion=True, enable_noop=True), custom_ops=["+rms_norm", "+quant_fp8"],
)) pass_config=PassConfig(enable_fusion=True, enable_noop=True),
)
)
with vllm.config.set_current_vllm_config(vllm_config): with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work # Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
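
The decorators in the hunk above follow the same rule: a @pytest.mark.parametrize or @pytest.mark.skipif call that exceeds the configured line length gets one argument per line, and an overlong test signature is wrapped the same way. A self-contained sketch with invented parameters (none of these values come from this diff):

import pytest


@pytest.mark.parametrize(
    "hidden_size, num_tokens",
    [(64, 8), (256, 32), (1024, 128), (4096, 512), (8192, 1024)],
)
@pytest.mark.skipif(
    not hasattr(pytest, "importorskip"),
    reason="illustrative guard; never triggers on a working pytest install",
)
def test_example_shapes(
    hidden_size: int,
    num_tokens: int,
    use_a_deliberately_long_flag_name_to_force_wrapping: bool = False,
):
    # Trivial checks so the sketch runs as a real test.
    assert hidden_size % 8 == 0
    assert num_tokens > 0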


@ -10,14 +10,24 @@ from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig, from vllm.config import (
ModelConfig, PassConfig, VllmConfig) CompilationConfig,
CompilationLevel,
DeviceConfig,
ModelConfig,
PassConfig,
VllmConfig,
)
from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment, from vllm.distributed.parallel_state import (
initialize_model_parallel) init_distributed_environment,
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
GroupShape, QuantFP8) GroupShape,
QuantFP8,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import update_environment_variables from vllm.utils import update_environment_variables
@ -26,7 +36,6 @@ from .backend import TestBackend
class TestAllReduceRMSNormModel(torch.nn.Module): class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6): def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -47,7 +56,6 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6): def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -68,25 +76,22 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6): def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.eps = eps self.eps = eps
self.norm = RMSNorm(hidden_size, eps) self.norm = RMSNorm(hidden_size, eps)
self.quant_fp8 = QuantFP8(static=True, self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
group_shape=GroupShape.PER_TENSOR)
self.scale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size), self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
dtype=torch.float32)
def forward(self, hidden_states, residual): def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size) view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view) all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual) norm_output, residual_output = self.norm(all_reduce, residual)
torch.ops._C.static_scaled_fp8_quant(self.output, torch.ops._C.static_scaled_fp8_quant(
norm_output.contiguous(), self.output, norm_output.contiguous(), self.scale
self.scale) )
return self.output, residual_output return self.output, residual_output
def ops_in_model_after(self): def ops_in_model_after(self):
@ -95,35 +100,33 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
def ops_in_model_before(self): def ops_in_model_before(self):
return [ return [
torch.ops.vllm.all_reduce.default, torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default torch.ops._C.static_scaled_fp8_quant.default,
] ]
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6): def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.eps = eps self.eps = eps
self.norm = RMSNorm(hidden_size, eps) self.norm = RMSNorm(hidden_size, eps)
self.scale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size), self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
dtype=torch.float32)
round_up = lambda x, y: (x + y - 1) // y * y round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(token_num, 128) rounded_m = round_up(token_num, 128)
scale_n = hidden_size // 16 scale_n = hidden_size // 16
rounded_n = round_up(scale_n, 4) rounded_n = round_up(scale_n, 4)
self.output_scale = torch.empty((rounded_m, rounded_n // 4), self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32)
dtype=torch.int32)
def forward(self, hidden_states, residual): def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size) view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view) all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual) norm_output, residual_output = self.norm(all_reduce, residual)
norm_output = norm_output.reshape(-1, norm_output.shape[-1]) norm_output = norm_output.reshape(-1, norm_output.shape[-1])
torch.ops._C.scaled_fp4_quant(self.output, norm_output, torch.ops._C.scaled_fp4_quant(
self.output_scale, self.scale) self.output, norm_output, self.output_scale, self.scale
)
return self.output, residual_output, self.output_scale return self.output, residual_output, self.output_scale
def ops_in_model_after(self): def ops_in_model_after(self):
@ -132,7 +135,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def ops_in_model_before(self): def ops_in_model_before(self):
return [ return [
torch.ops.vllm.all_reduce.default, torch.ops.vllm.all_reduce.default,
torch.ops._C.scaled_fp4_quant.default torch.ops._C.scaled_fp4_quant.default,
] ]
@ -145,41 +148,55 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
TestAllReduceFusedAddRMSNormStaticQuantFP8Model, TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
# TODO: Enable with torch==2.8.0 # TODO: Enable with torch==2.8.0
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model, # TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
]) ],
)
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8]) @pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
reason="Only test on CUDA")
@pytest.mark.skipif( @pytest.mark.skipif(
not find_spec("flashinfer") not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"), or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
reason="flashinfer is not found or flashinfer " reason="flashinfer is not found or flashinfer "
"is not compiled with trtllm_allreduce_fusion") "is not compiled with trtllm_allreduce_fusion",
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module, )
batch_size: int, seq_len: int, def test_all_reduce_fusion_pass_replace(
hidden_size: int, dtype: torch.dtype): test_model: torch.nn.Module,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
):
num_processes = 2 num_processes = 2
if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model if (
and not current_platform.has_device_capability(100)): test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
pytest.skip("Skip as nvfp4 is only supported on " and not current_platform.has_device_capability(100)
"devices with compute capability 10.0 (Blackwell)") ):
pytest.skip(
"Skip as nvfp4 is only supported on "
"devices with compute capability 10.0 (Blackwell)"
)
def run_torch_spawn(fn, nprocs): def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(fn, torch.multiprocessing.spawn(
args=(num_processes, test_model, fn,
batch_size, seq_len, hidden_size, args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
dtype), nprocs=nprocs,
nprocs=nprocs) )
run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes) run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, def all_reduce_fusion_pass_on_test_model(
test_model_cls: torch.nn.Module, local_rank: int,
batch_size: int, seq_len: int, world_size: int,
hidden_size: int, dtype: torch.dtype): test_model_cls: torch.nn.Module,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
):
current_platform.seed_everything(0) current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
@ -187,39 +204,42 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
torch.set_default_device(device) torch.set_default_device(device)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
update_environment_variables({ update_environment_variables(
'RANK': str(local_rank), {
'LOCAL_RANK': str(local_rank), "RANK": str(local_rank),
'WORLD_SIZE': str(world_size), "LOCAL_RANK": str(local_rank),
'MASTER_ADDR': 'localhost', "WORLD_SIZE": str(world_size),
'MASTER_PORT': '12345', "MASTER_ADDR": "localhost",
}) "MASTER_PORT": "12345",
}
)
init_distributed_environment() init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(
level=CompilationLevel.PIECEWISE, compilation_config=CompilationConfig(
custom_ops=["+rms_norm", "+quant_fp8"])) level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"]
)
)
vllm_config.compilation_config.pass_config = PassConfig( vllm_config.compilation_config.pass_config = PassConfig(
enable_fi_allreduce_fusion=True, enable_noop=True) enable_fi_allreduce_fusion=True, enable_noop=True
)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config # this is a fake model name to construct the model config
# in the vllm_config, it's not really used. # in the vllm_config, it's not really used.
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config.model_config = ModelConfig(model=model_name, vllm_config.model_config = ModelConfig(
trust_remote_code=True, model=model_name, trust_remote_code=True, dtype=dtype, seed=42
dtype=dtype, )
seed=42)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass)
cleanup_pass)
token_num = batch_size * seq_len token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num) model = test_model_cls(hidden_size, token_num)
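
One more change visible in the environment-variable hunk above: ruff format normalizes string literals to double quotes and reflows dict literals so each key/value pair sits on its own line inside the wrapped call. The helper below is a hypothetical placeholder, not the utility used in this diff; only the key names are taken from the hunk above.

import os


def apply_env(overrides: dict[str, str]) -> None:
    # Placeholder for an environment-update helper; illustrative only.
    os.environ.update(overrides)


# After ruff format: double quotes throughout, hanging indent, one pair per line.
apply_env(
    {
        "RANK": "0",
        "LOCAL_RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "12345",
    }
)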


@ -19,14 +19,23 @@ from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, from vllm.config import (
ModelConfig, PassConfig, SchedulerConfig, VllmConfig, CacheConfig,
set_current_vllm_config) CompilationConfig,
CompilationLevel,
ModelConfig,
PassConfig,
SchedulerConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import get_forward_context, set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8StaticTensorSym, kNvfp4Quant) QuantKey,
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( kFp8StaticTensorSym,
Fp8LinearOp) kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer from vllm.utils import is_torch_equal_or_newer
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
@ -40,14 +49,16 @@ backend_unfused: Optional[TestBackend] = None
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model, quant_key", "model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]) )
@pytest.mark.parametrize("use_triton_fa", [True, False]) @pytest.mark.parametrize("use_triton_fa", [True, False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_rocm(), @pytest.mark.skipif(
reason="V0 attn quant fusion only on ROCm") not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm"
def test_attention_fusion_v0(example_prompts, monkeypatch, model: str, )
quant_key: QuantKey, use_triton_fa: bool): def test_attention_fusion_v0(
example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool
):
# Clean Dynamo cache to avoid reusing other test cases # Clean Dynamo cache to avoid reusing other test cases
# (for some reason the reset at the end is not enough) # (for some reason the reset at the end is not enough)
torch._dynamo.reset() torch._dynamo.reset()
@ -69,22 +80,24 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
backend="tests.compile.test_fusion_attn.backend_unfused", backend="tests.compile.test_fusion_attn.backend_unfused",
custom_ops=["+quant_fp8"], custom_ops=["+quant_fp8"],
) )
vllm_config = VllmConfig(compilation_config=compile_config, vllm_config = VllmConfig(
model_config=ModelConfig( compilation_config=compile_config,
model=model, model_config=ModelConfig(
dtype=torch.bfloat16, model=model,
)) dtype=torch.bfloat16,
),
)
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
llm = LLM(model, llm = LLM(
enforce_eager=True, model,
compilation_config=compile_config, enforce_eager=True,
gpu_memory_utilization=0.5, compilation_config=compile_config,
max_model_len=2048) gpu_memory_utilization=0.5,
max_model_len=2048,
)
sampling_params = SamplingParams(temperature=0.0, sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95)
max_tokens=10,
top_p=0.95)
unfused_output = llm.generate(prompts, sampling_params) unfused_output = llm.generate(prompts, sampling_params)
backend_unfused = None # Reset backend to make sure llm gets released backend_unfused = None # Reset backend to make sure llm gets released
@ -97,21 +110,25 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
backend="tests.compile.test_fusion_attn.backend", backend="tests.compile.test_fusion_attn.backend",
custom_ops=["+quant_fp8"], custom_ops=["+quant_fp8"],
) )
vllm_config = VllmConfig(compilation_config=compile_config, vllm_config = VllmConfig(
model_config=ModelConfig( compilation_config=compile_config,
model=model, model_config=ModelConfig(
dtype=torch.bfloat16, model=model,
)) dtype=torch.bfloat16,
),
)
# AttnFusionPass needs attention layers to be registered in config upon init # AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation. # so we initialize it during compilation.
attn_pass = LazyInitPass(AttnFusionPass, vllm_config) attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass) backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
llm2 = LLM(model, llm2 = LLM(
enforce_eager=True, model,
compilation_config=compile_config, enforce_eager=True,
gpu_memory_utilization=0.5, compilation_config=compile_config,
max_model_len=2048) gpu_memory_utilization=0.5,
max_model_len=2048,
)
# check support # check support
attn_fusion_supported = [ attn_fusion_supported = [
@ -132,9 +149,9 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
for i in range(len(attn_nodes_pre)): for i in range(len(attn_nodes_pre)):
assert attn_nodes_pre[i].kwargs["output_scale"] is None assert attn_nodes_pre[i].kwargs["output_scale"] is None
fused = attn_nodes_post[i].kwargs["output_scale"] is not None fused = attn_nodes_post[i].kwargs["output_scale"] is not None
assert fused == attn_fusion_supported[i], \ assert fused == attn_fusion_supported[i], (
f"Node {i} {'' if fused else 'not '} expected " \ f"Node {i} {'' if fused else 'not '} expected to have fused output quant"
f"to have fused output quant" )
# check outputs # check outputs
fused_output = llm2.generate(prompts, sampling_params) fused_output = llm2.generate(prompts, sampling_params)
@ -160,9 +177,16 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
class AttentionQuantPatternModel(torch.nn.Module): class AttentionQuantPatternModel(torch.nn.Module):
"""Base model for AttentionQuantPattern fusion.""" """Base model for AttentionQuantPattern fusion."""
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, def __init__(
kv_cache_dtype: torch.dtype, device: torch.device, self,
vllm_config: VllmConfig, **kwargs): num_qo_heads: int,
num_kv_heads: int,
head_size: int,
kv_cache_dtype: torch.dtype,
device: torch.device,
vllm_config: VllmConfig,
**kwargs,
):
super().__init__() super().__init__()
self.num_qo_heads = num_qo_heads self.num_qo_heads = num_qo_heads
self.num_kv_heads = num_kv_heads self.num_kv_heads = num_kv_heads
@ -197,33 +221,30 @@ class AttentionQuantPatternModel(torch.nn.Module):
device=self.device, device=self.device,
) )
def build_attn_metadata(self, batch_size: int, use_hnd: bool) \ def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata:
-> AttentionMetadata:
"""Initialize attention metadata.""" """Initialize attention metadata."""
# Create common attn metadata # Create common attn metadata
batch_spec = BatchSpec(seq_lens=[1] * batch_size, batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
query_lens=[1] * batch_size)
common_attn_metadata = create_common_attn_metadata( common_attn_metadata = create_common_attn_metadata(
batch_spec, batch_spec, self.block_size, self.device, arange_block_indices=True
self.block_size, )
self.device,
arange_block_indices=True)
max_blocks = (max(batch_spec.seq_lens) + self.block_size - max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
1) // self.block_size
num_blocks = batch_size * max_blocks num_blocks = batch_size * max_blocks
# Create dummy KV cache for FlashInfer TRTLLM # Create dummy KV cache for FlashInfer TRTLLM
# - NHD: [num_blocks, block_size, num_kv_heads, head_size] # - NHD: [num_blocks, block_size, num_kv_heads, head_size]
# - HND: [num_blocks, num_kv_heads, block_size, head_size] # - HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros(num_blocks, kv_cache = torch.zeros(
2, num_blocks,
self.num_kv_heads, 2,
self.block_size, self.num_kv_heads,
self.head_size, self.block_size,
dtype=self.kv_cache_dtype, self.head_size,
device=self.device) dtype=self.kv_cache_dtype,
device=self.device,
)
if current_platform.is_rocm(): if current_platform.is_rocm():
# k/v as 1st dimension # k/v as 1st dimension
if use_hnd: if use_hnd:
@ -239,7 +260,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
# Build attn metadata # Build attn metadata
self.attn_metadata = self.builder.build( self.attn_metadata = self.builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata) common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
return self.attn_metadata return self.attn_metadata
@ -254,27 +276,30 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
self.fp8_linear = Fp8LinearOp( self.fp8_linear = Fp8LinearOp(
act_quant_static=self.quant_key.scale.static, act_quant_static=self.quant_key.scale.static,
act_quant_group_shape=self.quant_key.scale.group_shape) act_quant_group_shape=self.quant_key.scale.group_shape,
)
hidden_size = self.num_qo_heads * self.head_size hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get( self.w = kwargs.get(
"w", { "w",
"weight": {
torch.randn(hidden_size, hidden_size).to( "weight": torch.randn(hidden_size, hidden_size)
dtype=FP8_DTYPE, device=self.device).t(), .to(dtype=FP8_DTYPE, device=self.device)
"wscale": .t(),
torch.tensor([1.0], dtype=torch.float32, device=self.device), "wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
"scale": "scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
torch.tensor([1.0], dtype=torch.float32, device=self.device), },
}) )
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused.""" """Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
return self.fp8_linear.apply(input=attn_output, return self.fp8_linear.apply(
weight=self.w["weight"], input=attn_output,
weight_scale=self.w["wscale"], weight=self.w["weight"],
input_scale=self.w["scale"]) weight_scale=self.w["wscale"],
input_scale=self.w["scale"],
)
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
@ -287,42 +312,54 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
hidden_size = self.num_qo_heads * self.head_size hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get( self.w = kwargs.get(
"w", { "w",
"weight": {
torch.randint(256, (hidden_size, hidden_size // 2), "weight": torch.randint(
dtype=FP4_DTYPE, 256,
device=self.device), (hidden_size, hidden_size // 2),
"wscale_swizzled": dtype=FP4_DTYPE,
torch.randn(hidden_size, hidden_size // 16).to( device=self.device,
dtype=FP8_DTYPE, device=self.device), ),
"wscale": "wscale_swizzled": torch.randn(hidden_size, hidden_size // 16).to(
torch.tensor([500], dtype=torch.float32, device=self.device), dtype=FP8_DTYPE, device=self.device
"scale": ),
torch.tensor([0.002], dtype=torch.float32, device=self.device), "wscale": torch.tensor([500], dtype=torch.float32, device=self.device),
}) "scale": torch.tensor([0.002], dtype=torch.float32, device=self.device),
},
)
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused.""" """Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
quant_output, output_block_scale = scaled_fp4_quant( quant_output, output_block_scale = scaled_fp4_quant(
attn_output, 1 / self.w["scale"]) attn_output, 1 / self.w["scale"]
return cutlass_scaled_fp4_mm(a=quant_output, )
b=self.w["weight"], return cutlass_scaled_fp4_mm(
block_scale_a=output_block_scale, a=quant_output,
block_scale_b=self.w["wscale_swizzled"], b=self.w["weight"],
alpha=self.w["scale"] * self.w["wscale"], block_scale_a=output_block_scale,
out_dtype=attn_output.dtype) block_scale_b=self.w["wscale_swizzled"],
alpha=self.w["scale"] * self.w["wscale"],
out_dtype=attn_output.dtype,
)
if current_platform.is_cuda(): if current_platform.is_cuda():
MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", MODELS = [
TestAttentionFp8StaticQuantPatternModel), (
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionNvfp4QuantPatternModel)] TestAttentionFp8StaticQuantPatternModel,
),
(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel,
),
]
HEADS = [(64, 8), (40, 8)] HEADS = [(64, 8), (40, 8)]
elif current_platform.is_rocm(): elif current_platform.is_rocm():
MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV", MODELS = [
TestAttentionFp8StaticQuantPatternModel)] ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
]
HEADS = [(32, 8), (40, 8)] HEADS = [(32, 8), (40, 8)]
else: else:
MODELS = [] MODELS = []
@@ -331,41 +368,53 @@ else:
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize(
    "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize(
    "backend",
    [_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN],
)
@pytest.mark.parametrize(
    "split_attention", [False, True] if current_platform.is_rocm() else [False]
)
# TODO(boyuan): test inductor graph partition on rocm
@pytest.mark.parametrize(
    "use_inductor_graph_partition",
    [False] if current_platform.is_rocm() else [False, True],
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(
    current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)),
    reason="On CUDA only test on SM100(Blackwell)",
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
def test_attention_quant_pattern(
    num_qo_heads: int,
    num_kv_heads: int,
    head_size: int,
    batch_size: int,
    dtype: torch.dtype,
    model_name: str,
    model_class: type[AttentionQuantPatternModel],
    backend: _Backend,
    split_attention: bool,
    use_inductor_graph_partition: bool,
    monkeypatch,
    dist_init,
    caplog_vllm,
):
    """Test AttentionStaticQuantPattern fusion pass"""

    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

    monkeypatch.setenv("VLLM_USE_V1", "1")
    if split_attention:
@@ -386,21 +435,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
            custom_ops=["+quant_fp8"],
            use_inductor_graph_partition=use_inductor_graph_partition,
        ),
        cache_config=CacheConfig(cache_dtype="fp8"),
    )

    # Create test inputs
    q = torch.randn(batch_size, num_qo_heads * head_size, dtype=dtype, device=device)
    k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
    v = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)

    # Mark first dimension as dynamic for realistic testing
    torch._dynamo.mark_dynamic(q, 0)
@@ -409,42 +450,53 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
    # Run model directly without compilation and fusion
    vllm_config_unfused = copy.deepcopy(vllm_config)
    with (
        set_current_vllm_config(vllm_config_unfused),
        set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
        global_force_attn_backend_context_manager(backend),
    ):
        model_unfused = model_class(
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            kv_cache_dtype=FP8_DTYPE,
            device=device,
            vllm_config=vllm_config_unfused,
        )
        model_unfused = model_unfused.to(device)

        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
            batch_size, use_hnd=split_attention
        )

        # Run model directly without compilation and fusion
        result_unfused = model_unfused(q, k, v)

    # Run model with attn fusion enabled
    vllm_config.compilation_config.pass_config = PassConfig(
        enable_attn_fusion=True, enable_noop=True
    )
    with (
        set_current_vllm_config(vllm_config),
        set_forward_context(attn_metadata=None, vllm_config=vllm_config),
        global_force_attn_backend_context_manager(backend),
    ):
        model_fused = model_class(
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            kv_cache_dtype=FP8_DTYPE,
            device=device,
            vllm_config=vllm_config,
            w=model_unfused.w,
        )
        model_fused = model_fused.to(device)

        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_fused.build_attn_metadata(
            batch_size, use_hnd=split_attention
        )

        # Create test backend with fusion passes enabled
        noop_pass = NoOpEliminationPass(vllm_config)
@@ -454,9 +506,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
        test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)

        # Compile model with fusion enabled
        model_compiled = torch.compile(
            model_fused, backend=test_backend, fullgraph=True
        )
        assert model_compiled.attn._o_scale_float is None
        result_fused_1 = model_compiled(q, k, v)
@@ -471,49 +523,49 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
            assert model_compiled.attn._o_scale_float is not None

            torch.testing.assert_close(
                result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
            )

        # Check attn fusion support
        quant_key = model_class.quant_key
        attn_fusion_supported = [
            layer.impl.fused_output_quant_supported(quant_key)
            for key, layer in vllm_config.compilation_config.static_forward_context.items()
        ]
        if any(attn_fusion_supported):
            # Check quantization ops in the graph before and after fusion
            test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)

            # access the underlying `AttnFusionPass` on the `LazyInitPass`
            assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

            # Check attention ops in the graph before and after fusion
            attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
            attn_nodes_post = list(find_op_nodes(ATTN_OP, test_backend.graph_post_pass))
            assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
            assert len(attn_nodes_pre) == len(attn_nodes_post), (
                "Should have same number of attention nodes before and after fusion"
            )
            assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
                "Attention should not have output_scale before fusion"
            )
            assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
                "Attention should have output_scale after fusion"
            )

            assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
                "Attention should not have output_block_scale before fusion"
            )
            if quant_key.dtype == FP8_DTYPE:
                assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
                    "Attention should not have output_block_scale after FP8 fusion"
                )
            elif quant_key.dtype == FP4_DTYPE:
                assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
                    "Attention should have output_block_scale after FP4 fusion"
                )  # noqa: E501

        # Check that results are close
        torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2)
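The checks above hinge on compiling the module with a test backend that snapshots the FX graph before and after the custom passes run. As a rough illustration of that mechanism (a simplified stand-in, not vLLM's TestBackend; CountingBackend and f are made-up names), a torch.compile backend is just a callable that receives the traced GraphModule and can inspect its nodes before returning something to execute:

import torch


class CountingBackend:
    """Simplified sketch of a graph-inspecting torch.compile backend."""

    def __init__(self):
        self.op_counts: dict[str, int] = {}

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        # Count every call_function node in the captured FX graph.
        for node in gm.graph.nodes:
            if node.op == "call_function":
                name = str(node.target)
                self.op_counts[name] = self.op_counts.get(name, 0) + 1
        # Return the graph unchanged so execution still works.
        return gm.forward


def f(x):
    return torch.relu(x) + 1


backend = CountingBackend()
compiled = torch.compile(f, backend=backend, fullgraph=True)
compiled(torch.randn(4))
print(backend.op_counts)  # e.g. counts for the relu and add calls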


@@ -6,14 +6,12 @@
import torch

import vllm
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig

from .backend import TestBackend


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("num_tokens", [256, 1024])
@pytest.mark.parametrize("hidden_size", [64, 4096])
def test_noop_elimination(dtype, num_tokens, hidden_size):
@@ -22,7 +20,6 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
    torch.manual_seed(1)

    class Model(torch.nn.Module):
        def forward(self, x):
            # Chain of reshapes
            y = x.reshape(-1, 128, 32)
@@ -32,7 +29,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
            # Final reshape that should remain
            b = a.reshape(-1, 128, 32)
            # No-op slice
            c = b[0 : b.shape[0]]
            # The pass should replace the result of this op with `c`.
            d = torch.slice_scatter(
                torch.ones_like(c),  # Dummy tensor to be scattered into
@@ -43,10 +40,12 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
            )
            return d

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            pass_config=PassConfig(enable_noop=True),
        )
    )
    with vllm.config.set_current_vllm_config(vllm_config):
        noop_pass = NoOpEliminationPass(vllm_config)
@@ -82,17 +81,18 @@ def test_non_noop_slice_preserved():
    x = torch.randn(16, 16)

    class SliceModel(torch.nn.Module):
        def forward(self, x):
            base = x.clone()
            src = torch.ones(15, 16)
            y = torch.slice_scatter(base, src, dim=0, start=0, end=-1)
            return x[0:-1, :], y

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            pass_config=PassConfig(enable_noop=True),
        )
    )
    with vllm.config.set_current_vllm_config(vllm_config):
        noop_pass = NoOpEliminationPass(vllm_config)
        backend = TestBackend(noop_pass)
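For intuition, the patterns the test builds are value-preserving no-ops, which is why the pass is allowed to drop them. A minimal standalone check (shapes here are illustrative, not the test's):

import torch

x = torch.randn(256, 4096)
# Round-trip reshape chain: ends with the original shape and values.
y = x.reshape(-1, 128, 32).reshape(-1, 4096)
# Full-range slice: selects every row, so nothing changes.
z = y[0 : y.shape[0]]
assert torch.equal(z, x)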


@@ -28,7 +28,6 @@ def test_bad_callable():

# Pass that inherits from InductorPass
class ProperPass(InductorPass):
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        pass

@@ -39,8 +38,7 @@ class ProperPass(InductorPass):
        ProperPass(),
        # Can also wrap callables in CallableInductorPass for compliance
        CallableInductorPass(simple_callable),
        CallableInductorPass(simple_callable, InductorPass.hash_source(__file__)),
    ],
)
def test_pass_manager_uuid(callable):
@@ -65,8 +63,9 @@ def test_pass_manager_uuid(callable):
    # UUID should be different due to config change
    config2 = copy.deepcopy(config)
    config2.compilation_config.pass_config.enable_fusion = (
        not config2.compilation_config.pass_config.enable_fusion
    )
    pass_manager3 = PostGradPassManager()
    pass_manager3.configure(config2)
    pass_manager3.add(callable)
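The UUID assertions above boil down to "a pass's identity changes when its source or its config changes". A rough sketch of the source-hashing half of that idea (hash_source here is a hypothetical helper, not the InductorPass implementation):

import hashlib
import inspect


def hash_source(obj) -> str:
    # Hash the object's source text so editing the pass changes its identity.
    src = inspect.getsource(obj)
    return hashlib.sha256(src.encode("utf-8")).hexdigest()


def simple_callable(graph):
    return graph


print(hash_source(simple_callable))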


@ -12,14 +12,20 @@ from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, from vllm.config import (
PassConfig, VllmConfig) CompilationConfig,
DeviceConfig,
ModelConfig,
PassConfig,
VllmConfig,
)
from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment, from vllm.distributed.parallel_state import (
initialize_model_parallel) init_distributed_environment,
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
Fp8LinearOp)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import update_environment_variables from vllm.utils import update_environment_variables
@ -36,16 +42,15 @@ prompts = [
class TestModel(torch.nn.Module): class TestModel(torch.nn.Module):
def __init__(
def __init__(self, self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
hidden_size=16, ):
intermediate_size=32,
vllm_config: VllmConfig = None):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.intermediate_size = intermediate_size self.intermediate_size = intermediate_size
self.gate_proj = torch.nn.Parameter( self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size))) torch.empty((intermediate_size, hidden_size))
)
self.norm = RMSNorm(intermediate_size, 1e-05) self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights # Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02) torch.nn.init.normal_(self.gate_proj, std=0.02)
@ -64,7 +69,7 @@ class TestModel(torch.nn.Module):
# Reshape input # Reshape input
view = hidden_states.reshape(-1, self.hidden_size) view = hidden_states.reshape(-1, self.hidden_size)
#matrix multiplication # matrix multiplication
permute = self.gate_proj.permute(1, 0) permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute) mm = torch.mm(view, permute)
@ -82,7 +87,7 @@ class TestModel(torch.nn.Module):
def ops_in_model_after(self): def ops_in_model_after(self):
return [ return [
torch.ops.vllm.reduce_scatter.default, torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default torch.ops.vllm.all_gather.default,
] ]
def ops_in_model(self): def ops_in_model(self):
@ -90,18 +95,16 @@ class TestModel(torch.nn.Module):
class TestQuantModel(torch.nn.Module): class TestQuantModel(torch.nn.Module):
def __init__(
def __init__(self, self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
hidden_size=16, ):
intermediate_size=32,
vllm_config: VllmConfig = None):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.intermediate_size = intermediate_size self.intermediate_size = intermediate_size
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.gate_proj = torch.nn.Parameter(torch.empty( self.gate_proj = torch.nn.Parameter(
(intermediate_size, hidden_size)), torch.empty((intermediate_size, hidden_size)), requires_grad=False
requires_grad=False) )
self.norm = RMSNorm(intermediate_size, 1e-05) self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights # Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02) torch.nn.init.normal_(self.gate_proj, std=0.02)
@ -111,8 +114,7 @@ class TestQuantModel(torch.nn.Module):
self.scale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm, # Create a weight that is compatible with torch._scaled_mm,
# which expects a column-major layout. # which expects a column-major layout.
self.w = torch.rand(hidden_size, self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32) self.wscale = torch.rand(1, dtype=torch.float32)
def forward(self, hidden_states, residual): def forward(self, hidden_states, residual):
@ -129,7 +131,7 @@ class TestQuantModel(torch.nn.Module):
# Reshape input # Reshape input
view = hidden_states.reshape(-1, self.hidden_size) view = hidden_states.reshape(-1, self.hidden_size)
#matrix multiplication # matrix multiplication
permute = self.gate_proj.permute(1, 0) permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute) mm = torch.mm(view, permute)
@ -140,45 +142,51 @@ class TestQuantModel(torch.nn.Module):
norm_output, residual_output = self.norm(all_reduce, residual) norm_output, residual_output = self.norm(all_reduce, residual)
# scaled_mm with static input quantization # scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(norm_output, fp8_linear_result = self.fp8_linear.apply(
self.w, norm_output,
self.wscale, self.w,
input_scale=self.scale.to( self.wscale,
norm_output.device)) input_scale=self.scale.to(norm_output.device),
)
return fp8_linear_result, residual_output return fp8_linear_result, residual_output
def ops_in_model_before(self): def ops_in_model_before(self):
ops_to_remove = [torch.ops.vllm.all_reduce.default ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP
] # Always removed by SP
# The following are only removed if fusion happens # The following are only removed if fusion happens
if self.vllm_config and self.vllm_config.compilation_config \ if (
.pass_config.enable_fusion: self.vllm_config
ops_to_remove.extend([ and self.vllm_config.compilation_config.pass_config.enable_fusion
torch.ops._C.fused_add_rms_norm.default, ):
torch.ops._C.static_scaled_fp8_quant.default, ops_to_remove.extend(
]) [
torch.ops._C.fused_add_rms_norm.default,
torch.ops._C.static_scaled_fp8_quant.default,
]
)
return ops_to_remove return ops_to_remove
def ops_in_model_after(self): def ops_in_model_after(self):
ops_to_add = [ ops_to_add = [
torch.ops.vllm.reduce_scatter.default, torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default torch.ops.vllm.all_gather.default,
] ]
# The following is only added if fusion happens # The following is only added if fusion happens
if self.vllm_config and self.vllm_config.compilation_config \ if (
.pass_config.enable_fusion: self.vllm_config
ops_to_add.append( and self.vllm_config.compilation_config.pass_config.enable_fusion
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default) ):
ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
return ops_to_add return ops_to_add
def ops_in_model(self): def ops_in_model(self):
if self.vllm_config and self.vllm_config.compilation_config \ if (
.pass_config.enable_fusion: self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
# If fusion happens, the fused op is the one # If fusion happens, the fused op is the one
# we check for (de)functionalization # we check for (de)functionalization
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] # noqa: E501
] # noqa: E501
else: else:
# If no fusion, the original ops are checked # If no fusion, the original ops are checked
return [ return [
@ -195,30 +203,47 @@ class TestQuantModel(torch.nn.Module):
@pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("enable_fusion", [True, False]) @pytest.mark.parametrize("enable_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
reason="Only test on CUDA") def test_sequence_parallelism_pass(
def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module], test_model_cls: type[torch.nn.Module],
batch_size: int, seq_len: int, batch_size: int,
hidden_size: int, dtype: torch.dtype, seq_len: int,
enable_fusion: bool): hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
):
num_processes = 2 num_processes = 2
def run_torch_spawn(fn, nprocs): def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn otherwise will have problems with # need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda # torch.distributed and cuda
torch.multiprocessing.spawn(fn, torch.multiprocessing.spawn(
args=(num_processes, test_model_cls, fn,
batch_size, seq_len, hidden_size, args=(
dtype, enable_fusion), num_processes,
nprocs=nprocs) test_model_cls,
batch_size,
seq_len,
hidden_size,
dtype,
enable_fusion,
),
nprocs=nprocs,
)
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes) run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
def sequence_parallelism_pass_on_test_model( def sequence_parallelism_pass_on_test_model(
local_rank: int, world_size: int, local_rank: int,
test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int, world_size: int,
hidden_size: int, dtype: torch.dtype, enable_fusion: bool): test_model_cls: type[torch.nn.Module],
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
):
current_platform.seed_everything(0) current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
@ -226,13 +251,15 @@ def sequence_parallelism_pass_on_test_model(
torch.set_default_device(device) torch.set_default_device(device)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
update_environment_variables({ update_environment_variables(
'RANK': str(local_rank), {
'LOCAL_RANK': str(local_rank), "RANK": str(local_rank),
'WORLD_SIZE': str(world_size), "LOCAL_RANK": str(local_rank),
'MASTER_ADDR': 'localhost', "WORLD_SIZE": str(world_size),
'MASTER_PORT': '12345', "MASTER_ADDR": "localhost",
}) "MASTER_PORT": "12345",
}
)
# initialize distributed # initialize distributed
init_distributed_environment() init_distributed_environment()
@ -240,27 +267,28 @@ def sequence_parallelism_pass_on_test_model(
# configure vllm config for SequenceParallelismPass # configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( vllm_config.compilation_config = CompilationConfig(
enable_sequence_parallelism=True, pass_config=PassConfig(
enable_fusion=enable_fusion, enable_sequence_parallelism=True,
enable_noop=True)) # NoOp needed for fusion enable_fusion=enable_fusion,
enable_noop=True,
)
) # NoOp needed for fusion
vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config # this is a fake model name to construct the model config
# in the vllm_config, it's not really used. # in the vllm_config, it's not really used.
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config.model_config = ModelConfig(model=model_name, vllm_config.model_config = ModelConfig(
trust_remote_code=True, model=model_name, trust_remote_code=True, dtype=dtype, seed=42
dtype=dtype, )
seed=42)
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config) sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
passes_for_backend: list[VllmInductorPass] = \ passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass]
[noop_pass, sequence_parallelism_pass]
if enable_fusion: if enable_fusion:
fusion_pass = RMSNormQuantFusionPass(vllm_config) fusion_pass = RMSNormQuantFusionPass(vllm_config)
@ -271,12 +299,9 @@ def sequence_parallelism_pass_on_test_model(
backend_no_func = TestBackend(*passes_for_backend) backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass) backend_func = TestBackend(*passes_for_backend, func_pass)
model = test_model_cls(hidden_size, model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config)
hidden_size * 2,
vllm_config=vllm_config)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
compiled_model_no_func = torch.compile(model, backend=backend_no_func) compiled_model_no_func = torch.compile(model, backend=backend_no_func)
@ -297,8 +322,7 @@ def sequence_parallelism_pass_on_test_model(
# check if the functionalization pass is applied # check if the functionalization pass is applied
for op in model.ops_in_model(): for op in model.ops_in_model():
find_auto_fn(backend_no_func.graph_post_pass.nodes, op) find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501
op) is None # noqa: E501
# make sure the ops were all de-functionalized # make sure the ops were all de-functionalized
found = dict() found = dict()


@ -8,10 +8,15 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.compilation.activation_quant_fusion import ( from vllm.compilation.activation_quant_fusion import (
FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass) FUSED_OPS,
SILU_MUL_OP,
ActivationQuantFusionPass,
)
# yapf: enable # yapf: enable
from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
@ -19,9 +24,14 @@ from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, kFp8StaticTensorSym, kNvfp4Quant) GroupShape,
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, cutlass_fp8_supported) Fp8LinearOp,
cutlass_fp8_supported,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import override_cutlass_fp8_supported from ..utils import override_cutlass_fp8_supported
@ -36,7 +46,6 @@ def is_nvfp4_supported():
class TestSiluMulFp8QuantModel(torch.nn.Module): class TestSiluMulFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
super().__init__() super().__init__()
self.silu_and_mul = SiluAndMul() self.silu_and_mul = SiluAndMul()
@ -53,10 +62,7 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
def forward(self, x): def forward(self, x):
y = self.silu_and_mul(x) y = self.silu_and_mul(x)
x2 = self.fp8_linear.apply(y, x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
self.w,
self.wscale,
input_scale=self.wscale)
return x2 return x2
def ops_in_model_before(self): def ops_in_model_before(self):
@ -67,11 +73,12 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
class TestSiluMulNvfp4QuantModel(torch.nn.Module): class TestSiluMulNvfp4QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs): def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
super().__init__() super().__init__()
from vllm.compilation.activation_quant_fusion import ( from vllm.compilation.activation_quant_fusion import (
silu_and_mul_nvfp4_quant_supported) silu_and_mul_nvfp4_quant_supported,
)
assert silu_and_mul_nvfp4_quant_supported assert silu_and_mul_nvfp4_quant_supported
self.silu_and_mul = SiluAndMul() self.silu_and_mul = SiluAndMul()
@ -88,12 +95,14 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
def forward(self, x): def forward(self, x):
y = self.silu_and_mul(x) y = self.silu_and_mul(x)
y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale) y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale)
out = cutlass_scaled_fp4_mm(a=y_quant, out = cutlass_scaled_fp4_mm(
b=self.w, a=y_quant,
block_scale_a=y_block_scale, b=self.w,
block_scale_b=self.w_block_scale, block_scale_a=y_block_scale,
alpha=self.alpha, block_scale_b=self.w_block_scale,
out_dtype=y.dtype) alpha=self.alpha,
out_dtype=y.dtype,
)
return out return out
def ops_in_model_before(self): def ops_in_model_before(self):
@ -108,16 +117,24 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_class", "model_class",
cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel] cast(
if is_nvfp4_supported() else [TestSiluMulFp8QuantModel])) list[type],
[TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
if is_nvfp4_supported()
else [TestSiluMulFp8QuantModel],
),
)
# cuda_force_torch used to test torch code path on platforms that # cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True. # cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch", @pytest.mark.parametrize(
[True, False] if cutlass_fp8_supported() else [True]) "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], )
reason="Only test on CUDA and ROCm") @pytest.mark.skipif(
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class, envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
cuda_force_torch): )
def test_fusion_silu_and_mul_quant(
num_tokens, hidden_size, dtype, model_class, cuda_force_torch
):
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch: if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
pytest.skip("Duplicate tests for NVFP4") pytest.skip("Duplicate tests for NVFP4")
@ -129,17 +146,13 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
# Reshape pass is needed for the fusion pass to work # Reshape pass is needed for the fusion pass to work
config = VllmConfig() config = VllmConfig()
config.compilation_config = CompilationConfig( config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=True, enable_noop=True)) pass_config=PassConfig(enable_fusion=True, enable_noop=True)
)
fusion_pass = ActivationQuantFusionPass(config) fusion_pass = ActivationQuantFusionPass(config)
passes = [ passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
NoOpEliminationPass(config), fusion_pass,
PostCleanupPass(config)
]
backend = TestBackend(*passes) backend = TestBackend(*passes)
model = model_class(hidden_size=hidden_size, model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x)
cuda_force_torch=cuda_force_torch,
x=x)
# First dimension dynamic # First dimension dynamic
torch._dynamo.mark_dynamic(x, 0) torch._dynamo.mark_dynamic(x, 0)
@ -155,10 +168,9 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
elif model_class == TestSiluMulNvfp4QuantModel: elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1 atol, rtol = 1e-1, 1e-1
torch.testing.assert_close(result[0].to(dtype=dtype), torch.testing.assert_close(
result2[0].to(dtype=dtype), result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
atol=atol, )
rtol=rtol)
assert fusion_pass.matched_count == 1 assert fusion_pass.matched_count == 1
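For reference, the unfused computation this pass looks for is SiluAndMul followed by a quantization of its output; a minimal eager-mode sketch (assuming SiluAndMul halves the last dimension, with an illustrative FP8 scale rather than the models' calibrated one):

import torch
import torch.nn.functional as F


def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half: silu(first half) * second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


x = torch.randn(4, 2 * 128, dtype=torch.float16)
y = silu_and_mul_ref(x)
# Static quantization of the activation to FP8 (scale chosen for illustration).
scale = y.abs().amax().float() / 448.0
y_q = (y.float() / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)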


@@ -10,7 +10,6 @@ from vllm.config import CompilationLevel


class MyMod(torch.nn.Module):
    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        if cache is not None:
            return x + cache
@@ -18,12 +17,12 @@


class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
    def __init__(self, model):
        self.model = model
        compiled_callable = torch.compile(self.forward, backend="eager")
        super().__init__(
            compiled_callable, compilation_level=CompilationLevel.DYNAMO_ONCE
        )

    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        # this is the function to be compiled
@@ -54,10 +53,8 @@ def test_torch_compile_wrapper():
    # for new input, dispatch to the compiled code directly
    new_x = torch.tensor([3])
    assert wrapper(new_x, None).item() == 6  # dispatch to the first compiled code
    assert wrapper(new_x, cache).item() == 5  # dispatch to the second compiled code

    for wrapper in wrappers:
        # make sure they have independent compiled codes


@ -14,8 +14,9 @@ def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
""" """
def create_config(): def create_config():
engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite", engine_args = EngineArgs(
trust_remote_code=True) model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True
)
return engine_args.create_engine_config() return engine_args.create_engine_config()
# Create config with CUDA_VISIBLE_DEVICES set normally # Create config with CUDA_VISIBLE_DEVICES set normally
@ -34,16 +35,18 @@ def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
empty_config_dict.pop("instance_id", None) empty_config_dict.pop("instance_id", None)
assert deep_compare(normal_config_dict, empty_config_dict), ( assert deep_compare(normal_config_dict, empty_config_dict), (
"Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\"" 'Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=""'
" should be equivalent") " should be equivalent"
)
def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch): def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
# In testing, this method needs to be nested inside as ray does not # In testing, this method needs to be nested inside as ray does not
# see the test module. # see the test module.
def create_config(): def create_config():
engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite", engine_args = EngineArgs(
trust_remote_code=True) model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True
)
return engine_args.create_engine_config() return engine_args.create_engine_config()
config = create_config() config = create_config()
@ -51,6 +54,7 @@ def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
assert parallel_config.ray_runtime_env is None assert parallel_config.ray_runtime_env is None
import ray import ray
ray.init() ray.init()
runtime_env = { runtime_env = {
@ -59,13 +63,13 @@ def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
}, },
} }
config_ref = ray.remote(create_config).options( config_ref = ray.remote(create_config).options(runtime_env=runtime_env).remote()
runtime_env=runtime_env).remote()
config = ray.get(config_ref) config = ray.get(config_ref)
parallel_config = config.parallel_config parallel_config = config.parallel_config
assert parallel_config.ray_runtime_env is not None assert parallel_config.ray_runtime_env is not None
assert parallel_config.ray_runtime_env.env_vars().get( assert (
"TEST_ENV_VAR") == "test_value" parallel_config.ray_runtime_env.env_vars().get("TEST_ENV_VAR") == "test_value"
)
ray.shutdown() ray.shutdown()


@@ -16,13 +16,13 @@ def test_mp_reducer(monkeypatch):
    """
    # Use V1 AsyncLLM which calls maybe_register_config_serialize_by_value
    monkeypatch.setenv("VLLM_USE_V1", "1")

    # Ensure transformers_modules is not in sys.modules
    if "transformers_modules" in sys.modules:
        del sys.modules["transformers_modules"]

    with patch("multiprocessing.reducer.register") as mock_register:
        engine_args = AsyncEngineArgs(
            model="facebook/opt-125m",
            max_model_len=32,
@@ -36,7 +36,8 @@ def test_mp_reducer(monkeypatch):
        )

        assert mock_register.called, (
            "multiprocessing.reducer.register should have been called"
        )

        vllm_config_registered = False
        for call_args in mock_register.call_args_list:
@@ -45,8 +46,7 @@ def test_mp_reducer(monkeypatch):
                vllm_config_registered = True

                reducer_func = call_args[0][1]
                assert callable(reducer_func), "Reducer function should be callable"
                break

        assert vllm_config_registered, (


@ -30,22 +30,27 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from PIL import Image from PIL import Image
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, from transformers import (
BatchEncoding, BatchFeature) AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BatchEncoding,
BatchFeature,
)
from transformers.models.auto.auto_factory import _BaseAutoModelClass from transformers.models.auto.auto_factory import _BaseAutoModelClass
from tests.models.utils import (TokensTextLogprobs, from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
TokensTextLogprobsPromptLogprobs)
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset from vllm.assets.video import VideoAsset
from vllm.config.model import (ConvertOption, RunnerOption, from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype
_get_and_verify_dtype)
from vllm.connections import global_http_connection from vllm.connections import global_http_connection
from vllm.distributed import (cleanup_dist_env_and_memory, from vllm.distributed import (
init_distributed_environment, cleanup_dist_env_and_memory,
initialize_model_parallel) init_distributed_environment,
initialize_model_parallel,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.multimodal.utils import fetch_image from vllm.multimodal.utils import fetch_image
@ -82,12 +87,13 @@ class ImageAssetPrompts(TypedDict):
class ImageTestAssets(list[ImageAsset]): class ImageTestAssets(list[ImageAsset]):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__([ super().__init__(
ImageAsset("stop_sign"), [
ImageAsset("cherry_blossom"), ImageAsset("stop_sign"),
]) ImageAsset("cherry_blossom"),
]
)
def prompts(self, prompts: ImageAssetPrompts) -> list[str]: def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
""" """
@ -104,11 +110,12 @@ class VideoAssetPrompts(TypedDict):
class VideoTestAssets(list[VideoAsset]): class VideoTestAssets(list[VideoAsset]):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__([ super().__init__(
VideoAsset("baby_reading"), [
]) VideoAsset("baby_reading"),
]
)
def prompts(self, prompts: VideoAssetPrompts) -> list[str]: def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
return [prompts["baby_reading"]] return [prompts["baby_reading"]]
@ -120,12 +127,13 @@ class AudioAssetPrompts(TypedDict):
class AudioTestAssets(list[AudioAsset]): class AudioTestAssets(list[AudioAsset]):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__([ super().__init__(
AudioAsset("mary_had_lamb"), [
AudioAsset("winning_call"), AudioAsset("mary_had_lamb"),
]) AudioAsset("winning_call"),
]
)
def prompts(self, prompts: AudioAssetPrompts) -> list[str]: def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
return [prompts["mary_had_lamb"], prompts["winning_call"]] return [prompts["mary_had_lamb"], prompts["winning_call"]]
@ -220,6 +228,7 @@ def example_system_message() -> str:
class DecoderPromptType(Enum): class DecoderPromptType(Enum):
"""For encoder/decoder models only.""" """For encoder/decoder models only."""
CUSTOM = 1 CUSTOM = 1
NONE = 2 NONE = 2
EMPTY_STR = 3 EMPTY_STR = 3
@ -253,15 +262,13 @@ _R = TypeVar("_R")
class HfRunner: class HfRunner:
def get_default_device(self): def get_default_device(self):
from vllm.platforms import current_platform from vllm.platforms import current_platform
return ("cpu" return "cpu" if current_platform.is_cpu() else current_platform.device_type
if current_platform.is_cpu() else current_platform.device_type)
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T: def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
if x is None or isinstance(x, (bool, )): if x is None or isinstance(x, (bool,)):
return x return x
if device is None: if device is None:
@ -289,8 +296,11 @@ class HfRunner:
# Set this to avoid hanging issue # Set this to avoid hanging issue
default_torch_num_threads: Optional[int] = None, default_torch_num_threads: Optional[int] = None,
) -> None: ) -> None:
init_ctx = (nullcontext() if default_torch_num_threads is None else init_ctx = (
set_default_torch_num_threads(default_torch_num_threads)) nullcontext()
if default_torch_num_threads is None
else set_default_torch_num_threads(default_torch_num_threads)
)
with init_ctx: with init_ctx:
self._init( self._init(
@ -362,14 +372,15 @@ class HfRunner:
) )
# in case some unquantized custom models are not in same dtype # in case some unquantized custom models are not in same dtype
if (getattr(model, "quantization_method", None) is None if getattr(model, "quantization_method", None) is None and any(
and any(p.dtype != self.dtype p.dtype != self.dtype for p in model.parameters()
for p in model.parameters())): ):
model = model.to(dtype=self.dtype) model = model.to(dtype=self.dtype)
if (getattr(model, "quantization_method", None) != "bitsandbytes" if (
and len({p.device getattr(model, "quantization_method", None) != "bitsandbytes"
for p in model.parameters()}) < 2): and len({p.device for p in model.parameters()}) < 2
):
model = model.to(device=self.device) model = model.to(device=self.device)
self.model = model self.model = model
@ -384,6 +395,7 @@ class HfRunner:
# don't put this import at the top level # don't put this import at the top level
# it will call torch.cuda.device_count() # it will call torch.cuda.device_count()
from transformers import AutoProcessor # noqa: F401 from transformers import AutoProcessor # noqa: F401
self.processor = AutoProcessor.from_pretrained( self.processor = AutoProcessor.from_pretrained(
model_name, model_name,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
@ -471,10 +483,9 @@ class HfRunner:
audios: Optional[PromptAudioInput] = None, audios: Optional[PromptAudioInput] = None,
**kwargs: Any, **kwargs: Any,
) -> list[tuple[list[list[int]], list[str]]]: ) -> list[tuple[list[list[int]], list[str]]]:
all_inputs = self.get_inputs(prompts, all_inputs = self.get_inputs(
images=images, prompts, images=images, videos=videos, audios=audios
videos=videos, )
audios=audios)
outputs: list[tuple[list[list[int]], list[str]]] = [] outputs: list[tuple[list[list[int]], list[str]]] = []
for inputs in all_inputs: for inputs in all_inputs:
@ -501,16 +512,17 @@ class HfRunner:
audios: Optional[PromptAudioInput] = None, audios: Optional[PromptAudioInput] = None,
**kwargs: Any, **kwargs: Any,
) -> list[tuple[list[int], str]]: ) -> list[tuple[list[int], str]]:
outputs = self.generate(prompts, outputs = self.generate(
do_sample=False, prompts,
max_new_tokens=max_tokens, do_sample=False,
images=images, max_new_tokens=max_tokens,
videos=videos, images=images,
audios=audios, videos=videos,
**kwargs) audios=audios,
**kwargs,
)
return [(output_ids[0], output_str[0]) return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]
for output_ids, output_str in outputs]
def generate_beam_search( def generate_beam_search(
self, self,
@ -521,21 +533,22 @@ class HfRunner:
videos: Optional[PromptVideoInput] = None, videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None, audios: Optional[PromptAudioInput] = None,
) -> list[tuple[list[list[int]], list[str]]]: ) -> list[tuple[list[list[int]], list[str]]]:
outputs = self.generate(prompts, outputs = self.generate(
do_sample=False, prompts,
max_new_tokens=max_tokens, do_sample=False,
num_beams=beam_width, max_new_tokens=max_tokens,
num_return_sequences=beam_width, num_beams=beam_width,
images=images, num_return_sequences=beam_width,
videos=videos, images=images,
audios=audios) videos=videos,
audios=audios,
)
for i in range(len(outputs)): for i in range(len(outputs)):
output_ids, output_str = outputs[i] output_ids, output_str = outputs[i]
for j in range(len(output_ids)): for j in range(len(output_ids)):
output_ids[j] = [ output_ids[j] = [
x for x in output_ids[j] x for x in output_ids[j] if x != self.tokenizer.pad_token_id
if x != self.tokenizer.pad_token_id
] ]
outputs[i] = (output_ids, output_str) outputs[i] = (output_ids, output_str)
return outputs return outputs
@ -549,10 +562,9 @@ class HfRunner:
audios: Optional[PromptAudioInput] = None, audios: Optional[PromptAudioInput] = None,
**kwargs: Any, **kwargs: Any,
) -> list[list[torch.Tensor]]: ) -> list[list[torch.Tensor]]:
all_inputs = self.get_inputs(prompts, all_inputs = self.get_inputs(
images=images, prompts, images=images, videos=videos, audios=audios
videos=videos, )
audios=audios)
all_logprobs: list[list[torch.Tensor]] = [] all_logprobs: list[list[torch.Tensor]] = []
for inputs in all_inputs: for inputs in all_inputs:
@ -565,8 +577,7 @@ class HfRunner:
return_dict_in_generate=True, return_dict_in_generate=True,
**kwargs, **kwargs,
) )
seq_logprobs = self._hidden_states_to_seq_logprobs( seq_logprobs = self._hidden_states_to_seq_logprobs(output.hidden_states)
output.hidden_states)
all_logprobs.append(seq_logprobs) all_logprobs.append(seq_logprobs)
return all_logprobs return all_logprobs
@ -630,10 +641,9 @@ class HfRunner:
videos: Optional[PromptVideoInput] = None, videos: Optional[PromptVideoInput] = None,
**kwargs: Any, **kwargs: Any,
) -> list[TokensTextLogprobs]: ) -> list[TokensTextLogprobs]:
all_inputs = self.get_inputs(prompts, all_inputs = self.get_inputs(
images=images, prompts, images=images, videos=videos, audios=audios
videos=videos, )
audios=audios)
all_logprobs: list[list[dict[int, float]]] = [] all_logprobs: list[list[dict[int, float]]] = []
all_output_ids: list[list[int]] = [] all_output_ids: list[list[int]] = []
@ -653,8 +663,7 @@ class HfRunner:
( (
seq_logprobs_lst, seq_logprobs_lst,
output_len, output_len,
) = self._hidden_states_to_logprobs(output.hidden_states, ) = self._hidden_states_to_logprobs(output.hidden_states, num_logprobs)
num_logprobs)
all_logprobs.append(seq_logprobs_lst) all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0] seq_ids = output.sequences[0]
@ -664,19 +673,16 @@ class HfRunner:
all_output_strs.append(self.tokenizer.decode(output_ids)) all_output_strs.append(self.tokenizer.decode(output_ids))
outputs = zip(all_output_ids, all_output_strs, all_logprobs) outputs = zip(all_output_ids, all_output_strs, all_logprobs)
return [(output_ids, output_str, output_logprobs) return [
for output_ids, output_str, output_logprobs in outputs] (output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs
]
def encode(self, prompts: list[str], *args, def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]:
**kwargs) -> list[list[torch.Tensor]]:
return self.model.encode(prompts, *args, **kwargs) return self.model.encode(prompts, *args, **kwargs)
def predict(self, prompts: list[list[str]], *args, def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
**kwargs) -> torch.Tensor: return self.model.predict(prompts, *args, convert_to_tensor=True, **kwargs)
return self.model.predict(prompts,
*args,
convert_to_tensor=True,
**kwargs)
def __enter__(self): def __enter__(self):
return self return self
@ -727,8 +733,11 @@ class VllmRunner:
default_torch_num_threads: Optional[int] = None, default_torch_num_threads: Optional[int] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
init_ctx = (nullcontext() if default_torch_num_threads is None else init_ctx = (
set_default_torch_num_threads(default_torch_num_threads)) nullcontext()
if default_torch_num_threads is None
else set_default_torch_num_threads(default_torch_num_threads)
)
if not kwargs.get("compilation_config", None): if not kwargs.get("compilation_config", None):
kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]} kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]}
@ -760,11 +769,12 @@ class VllmRunner:
videos: Optional[PromptVideoInput] = None, videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None, audios: Optional[PromptAudioInput] = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
if any(x is not None and len(x) != len(prompts) if any(
for x in [images, videos, audios]): x is not None and len(x) != len(prompts) for x in [images, videos, audios]
):
raise ValueError( raise ValueError(
"All non-None multimodal inputs must have the same length as " "All non-None multimodal inputs must have the same length as prompts"
"prompts") )
inputs = list[dict[str, Any]]() inputs = list[dict[str, Any]]()
for i, prompt in enumerate(prompts): for i, prompt in enumerate(prompts):
@ -800,14 +810,11 @@ class VllmRunner:
audios: Optional[PromptAudioInput] = None, audios: Optional[PromptAudioInput] = None,
**kwargs: Any, **kwargs: Any,
) -> list[tuple[list[list[int]], list[str]]]: ) -> list[tuple[list[list[int]], list[str]]]:
inputs = self.get_inputs(prompts, inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
images=images,
videos=videos,
audios=audios)
req_outputs = self.llm.generate(inputs, req_outputs = self.llm.generate(
sampling_params=sampling_params, inputs, sampling_params=sampling_params, **kwargs
**kwargs) )
outputs: list[tuple[list[list[int]], list[str]]] = [] outputs: list[tuple[list[list[int]], list[str]]] = []
for req_output in req_outputs: for req_output in req_outputs:
@ -834,8 +841,9 @@ class VllmRunner:
output_str = sample.text output_str = sample.text
output_ids = list(sample.token_ids) output_ids = list(sample.token_ids)
output_logprobs = sample.logprobs output_logprobs = sample.logprobs
outputs.append((output_ids, output_str, output_logprobs, outputs.append(
req_output.prompt_logprobs)) (output_ids, output_str, output_logprobs, req_output.prompt_logprobs)
)
return outputs return outputs
def generate_w_logprobs( def generate_w_logprobs(
@ -846,23 +854,22 @@ class VllmRunner:
audios: Optional[PromptAudioInput] = None, audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None, videos: Optional[PromptVideoInput] = None,
**kwargs: Any, **kwargs: Any,
) -> Union[list[TokensTextLogprobs], ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]:
list[TokensTextLogprobsPromptLogprobs]]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
req_outputs = self.llm.generate(inputs, req_outputs = self.llm.generate(
sampling_params=sampling_params, inputs, sampling_params=sampling_params, **kwargs
**kwargs) )
toks_str_logsprobs_prompt_logprobs = ( toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(
self._final_steps_generate_w_logprobs(req_outputs)) req_outputs
)
# Omit prompt logprobs if not required by sampling params # Omit prompt logprobs if not required by sampling params
return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] return (
if sampling_params.prompt_logprobs is None else [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
toks_str_logsprobs_prompt_logprobs) if sampling_params.prompt_logprobs is None
else toks_str_logsprobs_prompt_logprobs
)
def generate_greedy( def generate_greedy(
self, self,
@ -874,14 +881,15 @@ class VllmRunner:
**kwargs: Any, **kwargs: Any,
) -> list[tuple[list[int], str]]: ) -> list[tuple[list[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, outputs = self.generate(
greedy_params, prompts,
images=images, greedy_params,
videos=videos, images=images,
audios=audios, videos=videos,
**kwargs) audios=audios,
return [(output_ids[0], output_str[0]) **kwargs,
for output_ids, output_str in outputs] )
return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]
def generate_greedy_logprobs( def generate_greedy_logprobs(
self, self,
@ -895,22 +903,24 @@ class VllmRunner:
stop_token_ids: Optional[list[int]] = None, stop_token_ids: Optional[list[int]] = None,
stop: Optional[list[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Union[list[TokensTextLogprobs], ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]:
list[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams( greedy_logprobs_params = SamplingParams(
temperature=0.0, temperature=0.0,
max_tokens=max_tokens, max_tokens=max_tokens,
logprobs=num_logprobs, logprobs=num_logprobs,
prompt_logprobs=num_prompt_logprobs, prompt_logprobs=num_prompt_logprobs,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
stop=stop) stop=stop,
)
return self.generate_w_logprobs(prompts, return self.generate_w_logprobs(
greedy_logprobs_params, prompts,
images=images, greedy_logprobs_params,
audios=audios, images=images,
videos=videos, audios=audios,
**kwargs) videos=videos,
**kwargs,
)
def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]: def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
""" """
@ -919,10 +929,9 @@ class VllmRunner:
:param prompts: list of prompts to score :param prompts: list of prompts to score
:return: perplexity score of each prompt :return: perplexity score of each prompt
""" """
outputs = self.generate_greedy_logprobs(prompts, outputs = self.generate_greedy_logprobs(
max_tokens=1, prompts, max_tokens=1, num_logprobs=None, num_prompt_logprobs=0
num_logprobs=None, )
num_prompt_logprobs=0)
perplexities = [] perplexities = []
for output in outputs: for output in outputs:
@ -951,15 +960,13 @@ class VllmRunner:
audios: Optional[PromptAudioInput] = None, audios: Optional[PromptAudioInput] = None,
concurrency_limit: Optional[int] = None, concurrency_limit: Optional[int] = None,
) -> list[tuple[list[list[int]], list[str]]]: ) -> list[tuple[list[list[int]], list[str]]]:
inputs = self.get_inputs(prompts, inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
images=images,
videos=videos,
audios=audios)
outputs = self.llm.beam_search(inputs, outputs = self.llm.beam_search(
BeamSearchParams(beam_width=beam_width, inputs,
max_tokens=max_tokens), BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens),
concurrency_limit=concurrency_limit) concurrency_limit=concurrency_limit,
)
returned_outputs = [] returned_outputs = []
for output in outputs: for output in outputs:
token_ids = [x.tokens for x in output.sequences] token_ids = [x.tokens for x in output.sequences]
@ -971,17 +978,16 @@ class VllmRunner:
req_outputs = self.llm.classify(prompts) req_outputs = self.llm.classify(prompts)
return [req_output.outputs.probs for req_output in req_outputs] return [req_output.outputs.probs for req_output in req_outputs]
def embed(self, def embed(
prompts: list[str], self,
images: Optional[PromptImageInput] = None, prompts: list[str],
videos: Optional[PromptVideoInput] = None, images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None,
*args, audios: Optional[PromptAudioInput] = None,
**kwargs) -> list[list[float]]: *args,
inputs = self.get_inputs(prompts, **kwargs,
images=images, ) -> list[list[float]]:
videos=videos, inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
audios=audios)
req_outputs = self.llm.embed(inputs, *args, **kwargs) req_outputs = self.llm.embed(inputs, *args, **kwargs)
return [req_output.outputs.embedding for req_output in req_outputs] return [req_output.outputs.embedding for req_output in req_outputs]
@ -1026,6 +1032,7 @@ def vllm_runner():
@pytest.fixture() @pytest.fixture()
def temporary_enable_log_propagate(): def temporary_enable_log_propagate():
import logging import logging
logger = logging.getLogger("vllm") logger = logging.getLogger("vllm")
logger.propagate = True logger.propagate = True
yield yield
@ -1045,6 +1052,7 @@ def num_gpus_available():
in current process.""" in current process."""
from vllm.platforms import current_platform from vllm.platforms import current_platform
return current_platform.device_count() return current_platform.device_count()
@ -1058,12 +1066,11 @@ _dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
def dummy_opt_path(): def dummy_opt_path():
json_path = os.path.join(_dummy_opt_path, "config.json") json_path = os.path.join(_dummy_opt_path, "config.json")
if not os.path.exists(_dummy_opt_path): if not os.path.exists(_dummy_opt_path):
snapshot_download(repo_id="facebook/opt-125m", snapshot_download(
local_dir=_dummy_opt_path, repo_id="facebook/opt-125m",
ignore_patterns=[ local_dir=_dummy_opt_path,
"*.bin", "*.bin.index.json", "*.pt", "*.h5", ignore_patterns=["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"],
"*.msgpack" )
])
assert os.path.exists(json_path) assert os.path.exists(json_path)
with open(json_path) as f: with open(json_path) as f:
config = json.load(f) config = json.load(f)
@ -1077,12 +1084,18 @@ def dummy_opt_path():
def dummy_llava_path(): def dummy_llava_path():
json_path = os.path.join(_dummy_llava_path, "config.json") json_path = os.path.join(_dummy_llava_path, "config.json")
if not os.path.exists(_dummy_llava_path): if not os.path.exists(_dummy_llava_path):
snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf", snapshot_download(
local_dir=_dummy_llava_path, repo_id="llava-hf/llava-1.5-7b-hf",
ignore_patterns=[ local_dir=_dummy_llava_path,
"*.bin", "*.bin.index.json", "*.pt", "*.h5", ignore_patterns=[
"*.msgpack", "*.safetensors" "*.bin",
]) "*.bin.index.json",
"*.pt",
"*.h5",
"*.msgpack",
"*.safetensors",
],
)
assert os.path.exists(json_path) assert os.path.exists(json_path)
with open(json_path) as f: with open(json_path) as f:
config = json.load(f) config = json.load(f)
@ -1096,12 +1109,18 @@ def dummy_llava_path():
def dummy_gemma2_embedding_path(): def dummy_gemma2_embedding_path():
json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json") json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
if not os.path.exists(_dummy_gemma2_embedding_path): if not os.path.exists(_dummy_gemma2_embedding_path):
snapshot_download(repo_id="BAAI/bge-multilingual-gemma2", snapshot_download(
local_dir=_dummy_gemma2_embedding_path, repo_id="BAAI/bge-multilingual-gemma2",
ignore_patterns=[ local_dir=_dummy_gemma2_embedding_path,
"*.bin", "*.bin.index.json", "*.pt", "*.h5", ignore_patterns=[
"*.msgpack", "*.safetensors" "*.bin",
]) "*.bin.index.json",
"*.pt",
"*.h5",
"*.msgpack",
"*.safetensors",
],
)
assert os.path.exists(json_path) assert os.path.exists(json_path)
with open(json_path) as f: with open(json_path) as f:
config = json.load(f) config = json.load(f)
@ -1114,10 +1133,9 @@ def dummy_gemma2_embedding_path():
# Add the flag `--optional` to allow run tests # Add the flag `--optional` to allow run tests
# that are marked with @pytest.mark.optional # that are marked with @pytest.mark.optional
def pytest_addoption(parser): def pytest_addoption(parser):
parser.addoption("--optional", parser.addoption(
action="store_true", "--optional", action="store_true", default=False, help="run optional test"
default=False, )
help="run optional test")
def pytest_collection_modifyitems(config, items): def pytest_collection_modifyitems(config, items):
@ -1185,7 +1203,6 @@ def _find_free_port() -> int:
class LocalAssetServer: class LocalAssetServer:
address: str address: str
port: int port: int
server: Optional[http.server.ThreadingHTTPServer] server: Optional[http.server.ThreadingHTTPServer]
@ -1200,9 +1217,9 @@ class LocalAssetServer:
def __enter__(self): def __enter__(self):
self.port = _find_free_port() self.port = _find_free_port()
self.server = http.server.ThreadingHTTPServer( self.server = http.server.ThreadingHTTPServer(
(self.address, self.port), AssetHandler) (self.address, self.port), AssetHandler
self.thread = threading.Thread(target=self.server.serve_forever, )
daemon=True) self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
self.thread.start() self.thread.start()
return self return self

@ -13,7 +13,7 @@ from vllm.platforms import current_platform
def check_cuda_context(): def check_cuda_context():
"""Check CUDA driver context status""" """Check CUDA driver context status"""
try: try:
cuda = ctypes.CDLL('libcuda.so') cuda = ctypes.CDLL("libcuda.so")
device = ctypes.c_int() device = ctypes.c_int()
result = cuda.cuCtxGetDevice(ctypes.byref(device)) result = cuda.cuCtxGetDevice(ctypes.byref(device))
return (True, device.value) if result == 0 else (False, None) return (True, device.value) if result == 0 else (False, None)
@ -27,9 +27,11 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
# New thread should have no CUDA context initially # New thread should have no CUDA context initially
valid_before, device_before = check_cuda_context() valid_before, device_before = check_cuda_context()
if valid_before: if valid_before:
return False, \ return (
"CUDA context should not exist in new thread, " \ False,
f"got device {device_before}" "CUDA context should not exist in new thread, "
f"got device {device_before}",
)
# Test setting CUDA context # Test setting CUDA context
current_platform.set_device(device_input) current_platform.set_device(device_input)
@ -39,8 +41,7 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
if not valid_after: if not valid_after:
return False, "CUDA context should be valid after set_cuda_context" return False, "CUDA context should be valid after set_cuda_context"
if device_id != expected_device_id: if device_id != expected_device_id:
return False, \ return False, f"Expected device {expected_device_id}, got {device_id}"
f"Expected device {expected_device_id}, got {device_id}"
return True, "Success" return True, "Success"
except Exception as e: except Exception as e:
@ -50,30 +51,30 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
class TestSetCudaContext: class TestSetCudaContext:
"""Test suite for the set_cuda_context function.""" """Test suite for the set_cuda_context function."""
@pytest.mark.skipif(not current_platform.is_cuda(), @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
reason="CUDA not available") @pytest.mark.parametrize(
@pytest.mark.parametrize(argnames="device_input,expected_device_id", argnames="device_input,expected_device_id",
argvalues=[ argvalues=[
(0, 0), (0, 0),
(torch.device('cuda:0'), 0), (torch.device("cuda:0"), 0),
('cuda:0', 0), ("cuda:0", 0),
], ],
ids=["int", "torch_device", "string"]) ids=["int", "torch_device", "string"],
def test_set_cuda_context_parametrized(self, device_input, )
expected_device_id): def test_set_cuda_context_parametrized(self, device_input, expected_device_id):
"""Test setting CUDA context in isolated threads.""" """Test setting CUDA context in isolated threads."""
with ThreadPoolExecutor(max_workers=1) as executor: with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_cuda_test_in_thread, device_input, future = executor.submit(
expected_device_id) run_cuda_test_in_thread, device_input, expected_device_id
)
success, message = future.result(timeout=30) success, message = future.result(timeout=30)
assert success, message assert success, message
@pytest.mark.skipif(not current_platform.is_cuda(), @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
reason="CUDA not available")
def test_set_cuda_context_invalid_device_type(self): def test_set_cuda_context_invalid_device_type(self):
"""Test error handling for invalid device type.""" """Test error handling for invalid device type."""
with pytest.raises(ValueError, match="Expected a cuda device"): with pytest.raises(ValueError, match="Expected a cuda device"):
current_platform.set_device(torch.device('cpu')) current_platform.set_device(torch.device("cpu"))
if __name__ == "__main__": if __name__ == "__main__":

@ -17,20 +17,16 @@ def test_computed_prefix_blocks(model: str):
prompt = ( prompt = (
"You are a helpful assistant. How do I build a car from cardboard and " "You are a helpful assistant. How do I build a car from cardboard and "
"paper clips? Is there an easy to follow video tutorial available " "paper clips? Is there an easy to follow video tutorial available "
"online for free?") "online for free?"
)
llm = LLM(model=model) llm = LLM(model=model)
sampling_params = SamplingParams(max_tokens=10, sampling_params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False)
temperature=0.0,
detokenize=False)
outputs_no_detokenization = llm.generate(prompt, outputs_no_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0]
sampling_params)[0].outputs[0]
sampling_params.detokenize = True sampling_params.detokenize = True
outputs_with_detokenization = llm.generate(prompt, outputs_with_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0]
sampling_params)[0].outputs[0]
assert outputs_no_detokenization.text == '' assert outputs_no_detokenization.text == ""
assert outputs_with_detokenization.text != '' assert outputs_with_detokenization.text != ""
assert outputs_no_detokenization.token_ids == \ assert outputs_no_detokenization.token_ids == outputs_with_detokenization.token_ids
outputs_with_detokenization.token_ids

@ -8,15 +8,17 @@ from vllm import SamplingParams
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import FastIncrementalDetokenizer from vllm.v1.engine.detokenizer import FastIncrementalDetokenizer
PROMPT = "Hello, my name is Lee, and I'm a student in the " + \ PROMPT = "Hello, my name is Lee, and I'm a student in the " + "college of engineering"
"college of engineering"
@pytest.mark.parametrize("min_tokens,stop,truth", [ @pytest.mark.parametrize(
(0, None, " is Lee, and I'm a student in the college of engineering"), "min_tokens,stop,truth",
(0, "e", " is L"), [
(5, "e", " is Lee, and I'm a stud"), (0, None, " is Lee, and I'm a student in the college of engineering"),
]) (0, "e", " is L"),
(5, "e", " is Lee, and I'm a stud"),
],
)
def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str): def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
"""Test for a specific min_tokens and stop. """Test for a specific min_tokens and stop.
@ -31,16 +33,18 @@ def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
stop=stop, stop=stop,
min_tokens=min_tokens, min_tokens=min_tokens,
) )
request = EngineCoreRequest(request_id="", request = EngineCoreRequest(
prompt_token_ids=prompt_token_ids, request_id="",
mm_features=None, prompt_token_ids=prompt_token_ids,
sampling_params=params, mm_features=None,
pooling_params=None, sampling_params=params,
eos_token_id=None, pooling_params=None,
arrival_time=0.0, eos_token_id=None,
lora_request=None, arrival_time=0.0,
cache_salt=None, lora_request=None,
data_parallel_rank=None) cache_salt=None,
data_parallel_rank=None,
)
detokenizer = FastIncrementalDetokenizer(tokenizer, request) detokenizer = FastIncrementalDetokenizer(tokenizer, request)

@ -31,34 +31,39 @@ def test_stop_reason(vllm_model, example_prompts):
llm = vllm_model.llm llm = vllm_model.llm
# test stop token # test stop token
outputs = llm.generate(example_prompts, outputs = llm.generate(
sampling_params=SamplingParams( example_prompts,
ignore_eos=True, sampling_params=SamplingParams(
seed=SEED, ignore_eos=True,
max_tokens=MAX_TOKENS, seed=SEED,
stop_token_ids=[stop_token_id])) max_tokens=MAX_TOKENS,
stop_token_ids=[stop_token_id],
),
)
for output in outputs: for output in outputs:
output = output.outputs[0] output = output.outputs[0]
assert output.finish_reason == "stop" assert output.finish_reason == "stop"
assert output.stop_reason == stop_token_id assert output.stop_reason == stop_token_id
# test stop string # test stop string
outputs = llm.generate(example_prompts, outputs = llm.generate(
sampling_params=SamplingParams( example_prompts,
ignore_eos=True, sampling_params=SamplingParams(
seed=SEED, ignore_eos=True, seed=SEED, max_tokens=MAX_TOKENS, stop="."
max_tokens=MAX_TOKENS, ),
stop=".")) )
for output in outputs: for output in outputs:
output = output.outputs[0] output = output.outputs[0]
assert output.finish_reason == "stop" assert output.finish_reason == "stop"
assert output.stop_reason == STOP_STR assert output.stop_reason == STOP_STR
# test EOS token # test EOS token
outputs = llm.generate(example_prompts, outputs = llm.generate(
sampling_params=SamplingParams( example_prompts,
seed=SEED, max_tokens=MAX_TOKENS)) sampling_params=SamplingParams(seed=SEED, max_tokens=MAX_TOKENS),
)
for output in outputs: for output in outputs:
output = output.outputs[0] output = output.outputs[0]
assert output.finish_reason == "length" or ( assert output.finish_reason == "length" or (
output.finish_reason == "stop" and output.stop_reason is None) output.finish_reason == "stop" and output.stop_reason is None
)

@ -14,7 +14,6 @@ def include_stop_str_in_output(request):
class _DummyDetokenizer(BaseIncrementalDetokenizer): class _DummyDetokenizer(BaseIncrementalDetokenizer):
def __init__(self, request: EngineCoreRequest): def __init__(self, request: EngineCoreRequest):
super().__init__(request) super().__init__(request)
@ -27,7 +26,8 @@ def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0):
params = SamplingParams( params = SamplingParams(
stop=stop, stop=stop,
include_stop_str_in_output=include_stop_str_in_output, include_stop_str_in_output=include_stop_str_in_output,
min_tokens=min_tokens) min_tokens=min_tokens,
)
# Keep other fields minimal for unit test purposes. # Keep other fields minimal for unit test purposes.
req = EngineCoreRequest( req = EngineCoreRequest(
request_id="test", request_id="test",
@ -44,8 +44,7 @@ def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0):
return req return req
def test_stop_string_while_stop_token_terminates( def test_stop_string_while_stop_token_terminates(include_stop_str_in_output: bool):
include_stop_str_in_output: bool):
""" """
This test verifies that the detokenizer correctly handles the case where This test verifies that the detokenizer correctly handles the case where
the generated token sequence contains both: the generated token sequence contains both:
@ -78,8 +77,9 @@ def test_stop_string_while_stop_token_terminates(
token_ids = [ord(c) for c in generated_text] token_ids = [ord(c) for c in generated_text]
# Create a request with the stop string and initialize the detokenizer. # Create a request with the stop string and initialize the detokenizer.
req = _make_request(stop=[stop_string], req = _make_request(
include_stop_str_in_output=include_stop_str_in_output) stop=[stop_string], include_stop_str_in_output=include_stop_str_in_output
)
detok = _DummyDetokenizer(req) detok = _DummyDetokenizer(req)
# Simulate that the last token ('Z') is a stop token (stop_terminated=True). # Simulate that the last token ('Z') is a stop token (stop_terminated=True).
@ -99,5 +99,4 @@ def test_stop_string_while_stop_token_terminates(
# get_next_output_text should return the full text when finished=True. # get_next_output_text should return the full text when finished=True.
# (Buffering only applies during streaming when finished=False.) # (Buffering only applies during streaming when finished=False.)
assert detok.get_next_output_text(finished=True, assert detok.get_next_output_text(finished=True, delta=False) == expected_text
delta=False) == expected_text

@ -11,12 +11,14 @@ MODEL = "meta-llama/llama-2-7b-hf"
MAX_TOKENS = 200 MAX_TOKENS = 200
def _test_stopping(llm: LLM, def _test_stopping(
expected_output: str, llm: LLM,
expected_reason: Any, expected_output: str,
stop: Optional[list[str]] = None, expected_reason: Any,
stop_token_ids: Optional[list[int]] = None, stop: Optional[list[str]] = None,
include_in_output: bool = False) -> None: stop_token_ids: Optional[list[int]] = None,
include_in_output: bool = False,
) -> None:
output = llm.generate( output = llm.generate(
"A story about vLLM:\n", "A story about vLLM:\n",
SamplingParams( SamplingParams(
@ -25,7 +27,8 @@ def _test_stopping(llm: LLM,
stop=stop, stop=stop,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_in_output, include_stop_str_in_output=include_in_output,
))[0].outputs[0] ),
)[0].outputs[0]
assert output is not None assert output is not None
assert output.text == expected_output assert output.text == expected_output
@ -33,17 +36,21 @@ def _test_stopping(llm: LLM,
def _stop_basic(llm): def _stop_basic(llm):
_test_stopping(llm, _test_stopping(
stop=["."], llm,
include_in_output=False, stop=["."],
expected_output="VLLM is a 100% volunteer organization", include_in_output=False,
expected_reason=".") expected_output="VLLM is a 100% volunteer organization",
expected_reason=".",
)
_test_stopping(llm, _test_stopping(
stop=["."], llm,
include_in_output=True, stop=["."],
expected_output="VLLM is a 100% volunteer organization.", include_in_output=True,
expected_reason=".") expected_output="VLLM is a 100% volunteer organization.",
expected_reason=".",
)
def _stop_multi_tokens(llm): def _stop_multi_tokens(llm):
@ -52,45 +59,54 @@ def _stop_multi_tokens(llm):
stop=["group of peo", "short"], stop=["group of peo", "short"],
include_in_output=False, include_in_output=False,
expected_output="VLLM is a 100% volunteer organization. We are a ", expected_output="VLLM is a 100% volunteer organization. We are a ",
expected_reason="group of peo") expected_reason="group of peo",
)
_test_stopping( _test_stopping(
llm, llm,
stop=["group of peo", "short"], stop=["group of peo", "short"],
include_in_output=True, include_in_output=True,
expected_output= expected_output="VLLM is a 100% volunteer organization. We are a group of peo",
"VLLM is a 100% volunteer organization. We are a group of peo", expected_reason="group of peo",
expected_reason="group of peo") )
def _stop_partial_token(llm): def _stop_partial_token(llm):
_test_stopping(llm, _test_stopping(
stop=["gani"], llm,
include_in_output=False, stop=["gani"],
expected_output="VLLM is a 100% volunteer or", include_in_output=False,
expected_reason="gani") expected_output="VLLM is a 100% volunteer or",
expected_reason="gani",
)
_test_stopping(llm, _test_stopping(
stop=["gani"], llm,
include_in_output=True, stop=["gani"],
expected_output="VLLM is a 100% volunteer organi", include_in_output=True,
expected_reason="gani") expected_output="VLLM is a 100% volunteer organi",
expected_reason="gani",
)
def _stop_token_id(llm): def _stop_token_id(llm):
# token id 13013 => " organization" # token id 13013 => " organization"
_test_stopping(llm, _test_stopping(
stop_token_ids=[13013], llm,
include_in_output=False, stop_token_ids=[13013],
expected_output="VLLM is a 100% volunteer", include_in_output=False,
expected_reason=13013) expected_output="VLLM is a 100% volunteer",
expected_reason=13013,
)
_test_stopping(llm, _test_stopping(
stop_token_ids=[13013], llm,
include_in_output=True, stop_token_ids=[13013],
expected_output="VLLM is a 100% volunteer organization", include_in_output=True,
expected_reason=13013) expected_output="VLLM is a 100% volunteer organization",
expected_reason=13013,
)
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup

@ -111,8 +111,7 @@ class MockSubscriber:
self.last_seq = -1 self.last_seq = -1
self.decoder = msgspec.msgpack.Decoder(type=decode_type) self.decoder = msgspec.msgpack.Decoder(type=decode_type)
def receive_one(self, def receive_one(self, timeout=1000) -> Union[tuple[int, SampleBatch], None]:
timeout=1000) -> Union[tuple[int, SampleBatch], None]:
"""Receive a single message with timeout""" """Receive a single message with timeout"""
if not self.sub.poll(timeout): if not self.sub.poll(timeout):
return None return None
@ -135,8 +134,7 @@ class MockSubscriber:
self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big")) self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))
def receive_replay(self, def receive_replay(self, socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
"""Receive replayed messages from a specific replay socket""" """Receive replayed messages from a specific replay socket"""
if not self.replay_sockets: if not self.replay_sockets:
raise ValueError("Replay sockets not initialized") raise ValueError("Replay sockets not initialized")

@ -12,7 +12,8 @@ import torch.distributed as dist
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa
CustomAllreduce) CustomAllreduce,
)
# create a cpu process group for communicating metadata (ipc handle) # create a cpu process group for communicating metadata (ipc handle)
dist.init_process_group(backend="gloo") dist.init_process_group(backend="gloo")
@ -52,7 +53,8 @@ for p in pointers:
assert ord(host_data[i]) == byte_value, ( assert ord(host_data[i]) == byte_value, (
f"Rank {rank} failed" f"Rank {rank} failed"
f" to verify buffer {p}. Expected {byte_value}, " f" to verify buffer {p}. Expected {byte_value}, "
f"got {ord(host_data[i])}") f"got {ord(host_data[i])}"
)
print(f"Rank {rank} verified all buffers") print(f"Rank {rank} verified all buffers")

@ -13,13 +13,19 @@ import pytest
import ray import ray
import torch import torch
from vllm.distributed import (broadcast_tensor_dict, get_pp_group, from vllm.distributed import (
tensor_model_parallel_all_gather, broadcast_tensor_dict,
tensor_model_parallel_all_reduce, get_pp_group,
tensor_model_parallel_reduce_scatter) tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter,
)
from ..utils import (init_test_distributed_environment, multi_gpu_test, from ..utils import (
multi_process_parallel) init_test_distributed_environment,
multi_gpu_test,
multi_process_parallel,
)
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
@ -37,12 +43,11 @@ def all_reduce_test_worker(
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank, init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
distributed_init_port)
num_elements = 8 num_elements = 8
all_tensors = [ all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="cuda") * torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1)
(r + 1) for r in range(tp_size) for r in range(tp_size)
] ]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank % tp_size] t = all_tensors[rank % tp_size]
@ -51,28 +56,31 @@ def all_reduce_test_worker(
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int, def reduce_scatter_test_worker(
pp_size: int, rank: int, monkeypatch: pytest.MonkeyPatch,
distributed_init_port: str): tp_size: int,
pp_size: int,
rank: int,
distributed_init_port: str,
):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs # so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU # they will be able to set the device to the correct GPU
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank, init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
distributed_init_port)
num_elements = 8 num_elements = 8
all_tensors = [ all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="cuda") * torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1)
(r + 1) for r in range(tp_size) for r in range(tp_size)
] ]
index = rank % tp_size index = rank % tp_size
partition_size = num_elements // tp_size partition_size = num_elements // tp_size
all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0) all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
expected = all_reduce[index * partition_size:(index + 1) * partition_size] expected = all_reduce[index * partition_size : (index + 1) * partition_size]
t = all_tensors[index] t = all_tensors[index]
t = tensor_model_parallel_reduce_scatter(t, 0) t = tensor_model_parallel_reduce_scatter(t, 0)
torch.testing.assert_close(t, expected) torch.testing.assert_close(t, expected)
@ -92,8 +100,7 @@ def all_gather_test_worker(
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank, init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
distributed_init_port)
num_dimensions = 3 num_dimensions = 3
tensor_size = list(range(2, num_dimensions + 2)) tensor_size = list(range(2, num_dimensions + 2))
total_size = 1 total_size = 1
@ -101,8 +108,10 @@ def all_gather_test_worker(
total_size *= s total_size *= s
for all_gather_dimension in range(num_dimensions): for all_gather_dimension in range(num_dimensions):
all_tensors = [ all_tensors = [
torch.arange(total_size, dtype=torch.float32, torch.arange(total_size, dtype=torch.float32, device="cuda").reshape(
device="cuda").reshape(tensor_size) * (r + 1) tensor_size
)
* (r + 1)
for r in range(tp_size) for r in range(tp_size)
] ]
expected = torch.cat(all_tensors, dim=all_gather_dimension) expected = torch.cat(all_tensors, dim=all_gather_dimension)
@ -125,8 +134,7 @@ def broadcast_tensor_dict_test_worker(
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank, init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
distributed_init_port)
test_dict = { test_dict = {
# device tensor # device tensor
"a": torch.arange(8, dtype=torch.float32, device="cuda"), "a": torch.arange(8, dtype=torch.float32, device="cuda"),
@ -134,10 +142,7 @@ def broadcast_tensor_dict_test_worker(
"b": torch.arange(16, dtype=torch.int8, device="cpu"), "b": torch.arange(16, dtype=torch.int8, device="cpu"),
"c": "test", "c": "test",
"d": [1, 2, 3], "d": [1, 2, 3],
"e": { "e": {"a": 1, "b": 2},
"a": 1,
"b": 2
},
# empty tensor # empty tensor
"f": torch.tensor([], dtype=torch.float32, device="cuda"), "f": torch.tensor([], dtype=torch.float32, device="cuda"),
} }
@ -166,8 +171,7 @@ def send_recv_tensor_dict_test_worker(
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank, init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
distributed_init_port)
test_dict = { test_dict = {
# device tensor # device tensor
@ -176,10 +180,7 @@ def send_recv_tensor_dict_test_worker(
"b": torch.arange(16, dtype=torch.int8, device="cpu"), "b": torch.arange(16, dtype=torch.int8, device="cpu"),
"c": "test", "c": "test",
"d": [1, 2, 3], "d": [1, 2, 3],
"e": { "e": {"a": 1, "b": 2},
"a": 1,
"b": 2
},
# empty tensor # empty tensor
"f": torch.tensor([], dtype=torch.float32, device="cuda"), "f": torch.tensor([], dtype=torch.float32, device="cuda"),
} }
@ -211,8 +212,7 @@ def send_recv_test_worker(
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank, init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
distributed_init_port)
size = 64 size = 64
test_tensor = torch.arange(64, dtype=torch.float32, device="cuda") test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
@ -229,10 +229,10 @@ def send_recv_test_worker(
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target", [ @pytest.mark.parametrize(
all_reduce_test_worker, all_gather_test_worker, "test_target",
broadcast_tensor_dict_test_worker [all_reduce_test_worker, all_gather_test_worker, broadcast_tensor_dict_test_worker],
]) )
def test_multi_process_tensor_parallel( def test_multi_process_tensor_parallel(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
tp_size: int, tp_size: int,
@ -244,7 +244,8 @@ def test_multi_process_tensor_parallel(
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("pp_size", [2]) @pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]) "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]
)
def test_multi_process_pipeline_parallel( def test_multi_process_pipeline_parallel(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
pp_size: int, pp_size: int,
@ -256,11 +257,16 @@ def test_multi_process_pipeline_parallel(
@multi_gpu_test(num_gpus=4) @multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pp_size", [2]) @pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize("test_target", [ @pytest.mark.parametrize(
send_recv_test_worker, send_recv_tensor_dict_test_worker, "test_target",
all_reduce_test_worker, all_gather_test_worker, [
broadcast_tensor_dict_test_worker send_recv_test_worker,
]) send_recv_tensor_dict_test_worker,
all_reduce_test_worker,
all_gather_test_worker,
broadcast_tensor_dict_test_worker,
],
)
def test_multi_process_tensor_parallel_pipeline_parallel( def test_multi_process_tensor_parallel_pipeline_parallel(
tp_size: int, tp_size: int,
pp_size: int, pp_size: int,

@ -7,6 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
all workers in a node other than the head node, which can cause the test all workers in a node other than the head node, which can cause the test
to fail. to fail.
""" """
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
@ -56,7 +57,8 @@ class CPTestSettings:
raise ValueError( raise ValueError(
f"Length mismatch: distributed_backends " f"Length mismatch: distributed_backends "
f"({len(self.distributed_backends)}) != " f"({len(self.distributed_backends)}) != "
f"vllm_major_versions ({len(self.vllm_major_versions)})") f"vllm_major_versions ({len(self.vllm_major_versions)})"
)
@staticmethod @staticmethod
def detailed( def detailed(
@ -74,29 +76,39 @@ class CPTestSettings:
for dcp_multiplier in [0.5, 1]: for dcp_multiplier in [0.5, 1]:
for chunked_prefill_val in [True]: for chunked_prefill_val in [True]:
parallel_setups.append( parallel_setups.append(
ParallelSetup(tp_size=tp_base, ParallelSetup(
pp_size=pp_multiplier * pp_base, tp_size=tp_base,
dcp_size=int(dcp_multiplier * pp_size=pp_multiplier * pp_base,
tp_base), dcp_size=int(dcp_multiplier * tp_base),
eager_mode=eager_mode_val, eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val)) chunked_prefill=chunked_prefill_val,
)
)
return CPTestSettings( return CPTestSettings(
parallel_setups=parallel_setups, parallel_setups=parallel_setups,
distributed_backends=["mp"], distributed_backends=["mp"],
vllm_major_versions=["1"], vllm_major_versions=["1"],
runner=runner, runner=runner,
test_options=CPTestOptions(multi_node_only=multi_node_only, test_options=CPTestOptions(
load_format=load_format), multi_node_only=multi_node_only, load_format=load_format
),
) )
def iter_params(self, model_id: str): def iter_params(self, model_id: str):
opts = self.test_options opts = self.test_options
for parallel_setup in self.parallel_setups: for parallel_setup in self.parallel_setups:
for backend, vllm_major_version in zip(self.distributed_backends, for backend, vllm_major_version in zip(
self.vllm_major_versions): self.distributed_backends, self.vllm_major_versions
yield (model_id, parallel_setup, backend, vllm_major_version, ):
self.runner, opts) yield (
model_id,
parallel_setup,
backend,
vllm_major_version,
self.runner,
opts,
)
def _compare_cp_with_tp( def _compare_cp_with_tp(
@ -148,8 +160,10 @@ def _compare_cp_with_tp(
if num_gpus_available < tp_size * pp_size: if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
if VLLM_MULTI_NODE and distributed_backend == "mp": if VLLM_MULTI_NODE and distributed_backend == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for " pytest.skip(
"multiprocessing distributed backend") "Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend"
)
if multi_node_only and not VLLM_MULTI_NODE: if multi_node_only and not VLLM_MULTI_NODE:
pytest.skip("Not in multi-node setting") pytest.skip("Not in multi-node setting")
@ -178,8 +192,7 @@ def _compare_cp_with_tp(
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
cp_env = tp_env = { cp_env = tp_env = {
"VLLM_USE_V1": "VLLM_USE_V1": vllm_major_version, # Note(hc): DCP only support V1 engine only
vllm_major_version, # Note(hc): DCP only support V1 engine only
} }
cp_args = [ cp_args = [
@ -205,13 +218,15 @@ def _compare_cp_with_tp(
] ]
try: try:
compare_two_settings(model_id, compare_two_settings(
cp_args, model_id,
tp_args, cp_args,
cp_env, tp_args,
tp_env, cp_env,
method=method, tp_env,
max_wait_seconds=720) method=method,
max_wait_seconds=720,
)
except Exception: except Exception:
testing_ray_compiled_graph = cp_env is not None testing_ray_compiled_graph = cp_env is not None
if testing_ray_compiled_graph and vllm_major_version == "0": if testing_ray_compiled_graph and vllm_major_version == "0":
@ -224,9 +239,10 @@ def _compare_cp_with_tp(
CP_TEXT_GENERATION_MODELS = { CP_TEXT_GENERATION_MODELS = {
# [MLA attention only] # [MLA attention only]
"deepseek-ai/DeepSeek-V2-Lite-Chat": "deepseek-ai/DeepSeek-V2-Lite-Chat": [
[CPTestSettings.detailed(), CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2)], CPTestSettings.detailed(tp_base=2),
],
} }
CP_TEST_MODELS = [ CP_TEST_MODELS = [
@ -237,11 +253,19 @@ CP_TEST_MODELS = [
@pytest.mark.parametrize( @pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", (
"runner", "test_options"), "model_id",
"parallel_setup",
"distributed_backend",
"vllm_major_version",
"runner",
"test_options",
),
[ [
params for model_id, settings in CP_TEXT_GENERATION_MODELS.items() params
for setting in settings for params in setting.iter_params(model_id) for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
for setting in settings
for params in setting.iter_params(model_id)
if model_id in CP_TEST_MODELS if model_id in CP_TEST_MODELS
], ],
) )
@ -255,12 +279,14 @@ def test_cp_generation(
test_options: CPTestOptions, test_options: CPTestOptions,
num_gpus_available, num_gpus_available,
): ):
_compare_cp_with_tp(model_id, _compare_cp_with_tp(
parallel_setup, model_id,
distributed_backend, parallel_setup,
vllm_major_version, distributed_backend,
runner, vllm_major_version,
test_options, runner,
num_gpus_available, test_options,
method="generate", num_gpus_available,
is_multimodal=False) method="generate",
is_multimodal=False,
)

@ -8,12 +8,14 @@ import ray
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.distributed.communication_op import ( # noqa from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tp_group, graph_capture from vllm.distributed.parallel_state import get_tp_group, graph_capture
from ..utils import (ensure_model_parallel_initialized, from ..utils import (
init_test_distributed_environment, multi_process_parallel) ensure_model_parallel_initialized,
init_test_distributed_environment,
multi_process_parallel,
)
random.seed(42) random.seed(42)
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)] test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
@ -33,8 +35,7 @@ def graph_allreduce(
m.delenv("CUDA_VISIBLE_DEVICES", raising=False) m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank, init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
distributed_init_port)
ensure_model_parallel_initialized(tp_size, pp_size) ensure_model_parallel_initialized(tp_size, pp_size)
group = get_tp_group().device_group group = get_tp_group().device_group
@ -60,18 +61,15 @@ def graph_allreduce(
for dtype in [torch.float32, torch.float16, torch.bfloat16]: for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with graph_capture(device=device) as graph_capture_context: with graph_capture(device=device) as graph_capture_context:
# use integers so result matches NCCL exactly # use integers so result matches NCCL exactly
inp1 = torch.randint(1, inp1 = torch.randint(
16, (sz, ), 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
dtype=dtype, )
device=torch.cuda.current_device()) inp2 = torch.randint(
inp2 = torch.randint(1, 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
16, (sz, ), )
dtype=dtype,
device=torch.cuda.current_device())
torch.cuda.synchronize() torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, with torch.cuda.graph(graph, stream=graph_capture_context.stream):
stream=graph_capture_context.stream):
for i in range(num_communication): for i in range(num_communication):
out1 = tensor_model_parallel_all_reduce(inp1) out1 = tensor_model_parallel_all_reduce(inp1)
# the input buffer is immediately modified to test # the input buffer is immediately modified to test
@ -96,8 +94,7 @@ def eager_allreduce(
m.delenv("CUDA_VISIBLE_DEVICES", raising=False) m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank, init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
distributed_init_port)
# we use the first group to communicate once # we use the first group to communicate once
# and the second group to communicate twice # and the second group to communicate twice
@ -132,5 +129,4 @@ def test_custom_allreduce(
world_size = tp_size * pipeline_parallel_size world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count(): if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.") pytest.skip("Not enough GPUs to run the test.")
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)
test_target)

@ -1,8 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from ..entrypoints.openai.test_oot_registration import ( from ..entrypoints.openai.test_oot_registration import run_and_test_dummy_opt_api_server
run_and_test_dummy_opt_api_server)
def test_distributed_oot(dummy_opt_path: str): def test_distributed_oot(dummy_opt_path: str):

@ -10,10 +10,12 @@ from vllm.distributed.eplb.rebalance_algo import rebalance_experts
def test_basic_rebalance(): def test_basic_rebalance():
"""Test basic rebalancing functionality""" """Test basic rebalancing functionality"""
# Example from https://github.com/deepseek-ai/eplb # Example from https://github.com/deepseek-ai/eplb
weight = torch.tensor([ weight = torch.tensor(
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], [
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
]) [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
]
)
num_layers = weight.shape[0] num_layers = weight.shape[0]
num_replicas = 16 num_replicas = 16
@ -21,45 +23,49 @@ def test_basic_rebalance():
num_nodes = 2 num_nodes = 2
num_gpus = 8 num_gpus = 8
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, phy2log, log2phy, logcnt = rebalance_experts(
num_groups, num_nodes, weight, num_replicas, num_groups, num_nodes, num_gpus
num_gpus) )
# Verify output shapes # Verify output shapes
assert phy2log.shape == ( assert phy2log.shape == (
2, 2,
16, 16,
), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}" ), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}"
assert (log2phy.shape[0] == 2 assert log2phy.shape[0] == 2, (
), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}" f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}"
assert ( )
log2phy.shape[1] == 12 assert log2phy.shape[1] == 12, (
), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}" f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}"
)
assert logcnt.shape == ( assert logcnt.shape == (
2, 2,
12, 12,
), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}" ), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}"
# Verify physical to logical expert mapping range is correct # Verify physical to logical expert mapping range is correct
assert torch.all(phy2log >= 0) and torch.all( assert torch.all(phy2log >= 0) and torch.all(phy2log < 12), (
phy2log < 12), "Physical to logical mapping should be in range [0, 12)" "Physical to logical mapping should be in range [0, 12)"
)
# Verify expert count reasonableness # Verify expert count reasonableness
assert torch.all( assert torch.all(logcnt >= 1), "Each logical expert should have at least 1 replica"
logcnt >= 1), "Each logical expert should have at least 1 replica" assert torch.sum(logcnt, dim=1).sum() == num_replicas * num_layers, (
assert ( f"Total replicas should be {num_replicas * num_layers}"
torch.sum(logcnt, dim=1).sum() == num_replicas * )
num_layers), f"Total replicas should be {num_replicas * num_layers}"
# Verify expected output # Verify expected output
expected_phy2log = torch.tensor([ expected_phy2log = torch.tensor(
[5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1], [
[7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1], [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1],
]) [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1],
]
)
assert torch.all(phy2log == expected_phy2log) assert torch.all(phy2log == expected_phy2log)
expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], expected_logcnt = torch.tensor(
[1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]]) [[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]]
)
assert torch.all(logcnt == expected_logcnt) assert torch.all(logcnt == expected_logcnt)
@ -71,9 +77,9 @@ def test_single_gpu_case():
num_nodes = 1 num_nodes = 1
num_gpus = 1 num_gpus = 1
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, phy2log, log2phy, logcnt = rebalance_experts(
num_groups, num_nodes, weight, num_replicas, num_groups, num_nodes, num_gpus
num_gpus) )
# Verify shapes # Verify shapes
assert phy2log.shape == (1, 4) assert phy2log.shape == (1, 4)
@ -93,19 +99,19 @@ def test_equal_weights():
num_nodes = 2 num_nodes = 2
num_gpus = 4 num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, phy2log, log2phy, logcnt = rebalance_experts(
num_groups, num_nodes, weight, num_replicas, num_groups, num_nodes, num_gpus
num_gpus) )
# Verify shapes # Verify shapes
assert phy2log.shape == (1, 8) assert phy2log.shape == (1, 8)
assert logcnt.shape == (1, 8) assert logcnt.shape == (1, 8)
# With equal weights, each expert should have exactly one replica # With equal weights, each expert should have exactly one replica
assert torch.all( assert torch.all(logcnt == 1), (
logcnt == 1 "With equal weights and no replication, "
), "With equal weights and no replication, " \ "each expert should have exactly 1 replica"
"each expert should have exactly 1 replica" )
def test_extreme_weight_imbalance(): def test_extreme_weight_imbalance():
@ -116,35 +122,37 @@ def test_extreme_weight_imbalance():
num_nodes = 2 num_nodes = 2
num_gpus = 4 num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, phy2log, log2phy, logcnt = rebalance_experts(
num_groups, num_nodes, weight, num_replicas, num_groups, num_nodes, num_gpus
num_gpus) )
# Verify shapes # Verify shapes
assert phy2log.shape == (1, 12) assert phy2log.shape == (1, 12)
assert logcnt.shape == (1, 8) assert logcnt.shape == (1, 8)
# Expert with highest weight (index 0) should have more replicas # Expert with highest weight (index 0) should have more replicas
assert ( assert logcnt[0, 0] > logcnt[0, 1], (
logcnt[0, 0] "Expert with highest weight should have more replicas"
> logcnt[0, 1]), "Expert with highest weight should have more replicas" )
def test_multiple_layers(): def test_multiple_layers():
"""Test multiple layers case""" """Test multiple layers case"""
weight = torch.tensor([ weight = torch.tensor(
[10, 20, 30, 40, 50, 60], # First layer [
[60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern) [10, 20, 30, 40, 50, 60], # First layer
[25, 25, 25, 25, 25, 25], # Third layer (equal weights) [60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern)
]) [25, 25, 25, 25, 25, 25], # Third layer (equal weights)
]
)
num_replicas = 8 num_replicas = 8
num_groups = 2 num_groups = 2
num_nodes = 2 num_nodes = 2
num_gpus = 4 num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, phy2log, log2phy, logcnt = rebalance_experts(
num_groups, num_nodes, weight, num_replicas, num_groups, num_nodes, num_gpus
num_gpus) )
# Verify shapes # Verify shapes
assert phy2log.shape == (3, 8) assert phy2log.shape == (3, 8)
@ -152,12 +160,12 @@ def test_multiple_layers():
# Verify expert allocation is reasonable for each layer # Verify expert allocation is reasonable for each layer
for layer in range(3): for layer in range(3):
assert torch.all(phy2log[layer] >= 0) and torch.all( assert torch.all(phy2log[layer] >= 0) and torch.all(phy2log[layer] < 6), (
phy2log[layer] < 6 f"Layer {layer} physical to logical mapping should be in range [0, 6)"
), f"Layer {layer} physical to logical mapping" \ )
"should be in range [0, 6)" assert torch.sum(logcnt[layer]) == num_replicas, (
assert (torch.sum(logcnt[layer]) == num_replicas f"Layer {layer} total replicas should be {num_replicas}"
), f"Layer {layer} total replicas should be {num_replicas}" )
def test_parameter_validation(): def test_parameter_validation():
@ -179,17 +187,19 @@ def test_parameter_validation():
def test_small_scale_hierarchical(): def test_small_scale_hierarchical():
"""Test small-scale hierarchical load balancing""" """Test small-scale hierarchical load balancing"""
weight = torch.tensor([ weight = torch.tensor(
[100, 50, 200, 75, 150, 25, 300, 80], # 8 experts [
]) [100, 50, 200, 75, 150, 25, 300, 80], # 8 experts
]
)
num_replicas = 12 num_replicas = 12
num_groups = 4 # 4 groups, 2 experts each num_groups = 4 # 4 groups, 2 experts each
num_nodes = 2 # 2 nodes num_nodes = 2 # 2 nodes
num_gpus = 4 # 4 GPUs num_gpus = 4 # 4 GPUs
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, phy2log, log2phy, logcnt = rebalance_experts(
num_groups, num_nodes, weight, num_replicas, num_groups, num_nodes, num_gpus
num_gpus) )
# Verify basic constraints # Verify basic constraints
assert phy2log.shape == (1, 12) assert phy2log.shape == (1, 12)
@ -199,8 +209,9 @@ def test_small_scale_hierarchical():
# Expert with highest weight should have more replicas # Expert with highest weight should have more replicas
max_weight_expert = torch.argmax(weight[0]) max_weight_expert = torch.argmax(weight[0])
assert (logcnt[0, max_weight_expert] assert logcnt[0, max_weight_expert] >= 2, (
>= 2), "Highest weight expert should have multiple replicas" "Highest weight expert should have multiple replicas"
)
def test_global_load_balance_fallback(): def test_global_load_balance_fallback():
@ -213,9 +224,9 @@ def test_global_load_balance_fallback():
num_nodes = 2 num_nodes = 2
num_gpus = 4 num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, phy2log, log2phy, logcnt = rebalance_experts(
num_groups, num_nodes, weight, num_replicas, num_groups, num_nodes, num_gpus
num_gpus) )
# Should work normally, just using global load balancing strategy # Should work normally, just using global load balancing strategy
assert phy2log.shape == (1, 8) assert phy2log.shape == (1, 8)
@ -235,9 +246,9 @@ def test_device_compatibility(device):
num_nodes = 1 num_nodes = 1
num_gpus = 2 num_gpus = 2
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, phy2log, log2phy, logcnt = rebalance_experts(
num_groups, num_nodes, weight, num_replicas, num_groups, num_nodes, num_gpus
num_gpus) )
# Function will convert to CPU internally, but should handle different # Function will convert to CPU internally, but should handle different
# device inputs normally # device inputs normally
@ -250,7 +261,8 @@ def test_additional_cases():
# Test case 1: Large-scale distributed setup # Test case 1: Large-scale distributed setup
weight1 = torch.tensor( weight1 = torch.tensor(
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]) [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
)
phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8) phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)
assert phy2log1.shape == (1, 24) assert phy2log1.shape == (1, 24)
@ -258,10 +270,12 @@ def test_additional_cases():
assert torch.sum(logcnt1) == 24 assert torch.sum(logcnt1) == 24
# Test case 2: Different weight distributions # Test case 2: Different weight distributions
weight2 = torch.tensor([ weight2 = torch.tensor(
[200, 150, 100, 50, 25, 12], # Decreasing weights [
[12, 25, 50, 100, 150, 200], # Increasing weights [200, 150, 100, 50, 25, 12], # Decreasing weights
]) [12, 25, 50, 100, 150, 200], # Increasing weights
]
)
phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2) phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)
assert phy2log2.shape == (2, 10) assert phy2log2.shape == (2, 10)
@ -274,19 +288,21 @@ def test_additional_cases():
if __name__ == "__main__": if __name__ == "__main__":
weight = torch.tensor([ weight = torch.tensor(
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], [
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
]) [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
]
)
num_replicas = 16 num_replicas = 16
num_groups = 4 num_groups = 4
num_nodes = 2 num_nodes = 2
num_gpus = 8 num_gpus = 8
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, phy2log, log2phy, logcnt = rebalance_experts(
num_groups, num_nodes, weight, num_replicas, num_groups, num_nodes, num_gpus
num_gpus) )
print(phy2log) print(phy2log)
test_basic_rebalance() test_basic_rebalance()

View File

@ -9,11 +9,12 @@ import pytest
import torch import torch
import torch.distributed import torch.distributed
from vllm.distributed.eplb.rebalance_execute import ( from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
rearrange_expert_weights_inplace) from vllm.distributed.parallel_state import (
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, ensure_model_parallel_initialized,
get_tp_group, get_tp_group,
init_distributed_environment) init_distributed_environment,
)
from vllm.utils import update_environment_variables from vllm.utils import update_environment_variables
@ -22,13 +23,13 @@ def distributed_run(fn, world_size):
processes: list[multiprocessing.Process] = [] processes: list[multiprocessing.Process] = []
for i in range(number_of_processes): for i in range(number_of_processes):
env: dict[str, str] = {} env: dict[str, str] = {}
env['RANK'] = str(i) env["RANK"] = str(i)
env['LOCAL_RANK'] = str(i) env["LOCAL_RANK"] = str(i)
env['WORLD_SIZE'] = str(number_of_processes) env["WORLD_SIZE"] = str(number_of_processes)
env['LOCAL_WORLD_SIZE'] = str(number_of_processes) env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env['MASTER_ADDR'] = 'localhost' env["MASTER_ADDR"] = "localhost"
env['MASTER_PORT'] = '12345' env["MASTER_PORT"] = "12345"
p = multiprocessing.Process(target=fn, args=(env, )) p = multiprocessing.Process(target=fn, args=(env,))
processes.append(p) processes.append(p)
p.start() p.start()
@ -45,7 +46,7 @@ def worker_fn_wrapper(fn):
# and update the environment variables in the function # and update the environment variables in the function
def wrapped_fn(env): def wrapped_fn(env):
update_environment_variables(env) update_environment_variables(env)
local_rank = os.environ['LOCAL_RANK'] local_rank = os.environ["LOCAL_RANK"]
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_distributed_environment() init_distributed_environment()
@ -60,10 +61,10 @@ def worker_fn_wrapper(fn):
def create_expert_indices_with_redundancy( def create_expert_indices_with_redundancy(
num_layers: int, num_layers: int,
num_logical_experts: int, num_logical_experts: int,
total_physical_experts: int, total_physical_experts: int,
redundancy_config: list[int], # redundancy for each logical expert redundancy_config: list[int], # redundancy for each logical expert
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Create expert indices with redundancy. Create expert indices with redundancy.
@ -120,27 +121,27 @@ def create_expert_weights(
for layer in range(num_layers): for layer in range(num_layers):
layer_weights = [] layer_weights = []
for weight_idx, hidden_size in enumerate(hidden_sizes): for weight_idx, hidden_size in enumerate(hidden_sizes):
weight_tensor = torch.zeros(num_local_experts, weight_tensor = torch.zeros(
hidden_size, num_local_experts, hidden_size, device=device, dtype=torch.float32
device=device, )
dtype=torch.float32)
for local_expert in range(num_local_experts): for local_expert in range(num_local_experts):
# Get the logical expert ID for this physical expert # Get the logical expert ID for this physical expert
global_pos = rank * num_local_experts + local_expert global_pos = rank * num_local_experts + local_expert
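# (illustrative: with 8 local experts per rank, rank 1 / local expert 3 -> global_pos 11)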
logical_expert_id = physical_to_logical_mapping[ logical_expert_id = physical_to_logical_mapping[
layer, global_pos].item() layer, global_pos
].item()
# Generate weights based on logical expert ID # Generate weights based on logical expert ID
# (so that all replicas of the same logical expert have the # (so that all replicas of the same logical expert have the
# same weights) # same weights)
base_value = (logical_expert_id * 1000 + layer * 100 + base_value = logical_expert_id * 1000 + layer * 100 + weight_idx * 10
weight_idx * 10) weight_tensor[local_expert] = torch.arange(
weight_tensor[local_expert] = torch.arange(base_value, base_value,
base_value + base_value + hidden_size,
hidden_size, device=device,
device=device, dtype=torch.float32,
dtype=torch.float32) )
layer_weights.append(weight_tensor) layer_weights.append(weight_tensor)
expert_weights.append(layer_weights) expert_weights.append(layer_weights)
@ -182,12 +183,15 @@ def verify_expert_weights_after_shuffle(
# Check if the weights are correct # Check if the weights are correct
actual_weights = weight_tensor[local_expert] actual_weights = weight_tensor[local_expert]
expected_base = (expected_logical_expert * 1000 + layer * 100 + expected_base = (
weight_idx * 10) expected_logical_expert * 1000 + layer * 100 + weight_idx * 10
expected_weights = torch.arange(expected_base, )
expected_base + hidden_size, expected_weights = torch.arange(
device=actual_weights.device, expected_base,
dtype=actual_weights.dtype) expected_base + hidden_size,
device=actual_weights.device,
dtype=actual_weights.dtype,
)
torch.testing.assert_close( torch.testing.assert_close(
actual_weights, actual_weights,
@ -195,7 +199,8 @@ def verify_expert_weights_after_shuffle(
msg=f"Layer {layer}, weight {weight_idx}," msg=f"Layer {layer}, weight {weight_idx},"
f"local expert {local_expert}: " f"local expert {local_expert}: "
f"weights do not match. " f"weights do not match. "
f"Expected logical expert {expected_logical_expert}") f"Expected logical expert {expected_logical_expert}",
)
def verify_redundant_experts_have_same_weights( def verify_redundant_experts_have_same_weights(
@ -222,23 +227,23 @@ def verify_redundant_experts_have_same_weights(
total_physical_experts, total_physical_experts,
hidden_size, hidden_size,
device=expert_weights[layer][weight_idx].device, device=expert_weights[layer][weight_idx].device,
dtype=expert_weights[layer][weight_idx].dtype) dtype=expert_weights[layer][weight_idx].dtype,
)
# Use all_gather to collect expert weights from current node # Use all_gather to collect expert weights from current node
# expert_weights[layer][weight_idx] shape: # expert_weights[layer][weight_idx] shape:
# [num_local_experts, hidden_size] # [num_local_experts, hidden_size]
local_weights = expert_weights[layer][ local_weights = expert_weights[layer][
weight_idx] # [num_local_experts, hidden_size] weight_idx
] # [num_local_experts, hidden_size]
# Split tensor along dim 0 into a list for all_gather # Split tensor along dim 0 into a list for all_gather
gathered_weights_list = torch.chunk(gathered_weights, gathered_weights_list = torch.chunk(gathered_weights, world_size, dim=0)
world_size,
dim=0)
torch.distributed.all_gather( torch.distributed.all_gather(
# Output list: each element corresponds to one rank's weights # Output list: each element corresponds to one rank's weights
list(gathered_weights_list), list(gathered_weights_list),
local_weights # Input: current rank's local weights local_weights, # Input: current rank's local weights
) )
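# NOTE: torch.chunk returns views of gathered_weights here, so all_gather writes the gathered results directly into the preallocated buffer.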
all_weights.append(gathered_weights) all_weights.append(gathered_weights)
@ -266,7 +271,8 @@ def verify_redundant_experts_have_same_weights(
msg=f"Layer {layer}, weight {weight_idx}," msg=f"Layer {layer}, weight {weight_idx},"
f"logical expert {logical_expert_id}: " f"logical expert {logical_expert_id}: "
f"Physical expert {physical_pos} has different weights" f"Physical expert {physical_pos} has different weights"
f"than expected") f"than expected",
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -290,10 +296,11 @@ def verify_redundant_experts_have_same_weights(
# 4 GPU, 8 experts per GPU # 4 GPU, 8 experts per GPU
# 16 logical experts, 32 physical experts, 16 redundant experts # 16 logical experts, 32 physical experts, 16 redundant experts
(4, 8, 8, 16), (4, 8, 8, 16),
]) ],
def test_rearrange_expert_weights_with_redundancy(world_size, num_layers, )
num_local_experts, def test_rearrange_expert_weights_with_redundancy(
num_logical_experts): world_size, num_layers, num_local_experts, num_logical_experts
):
"""Test the functionality of rearranging expert weights with redundancy.""" """Test the functionality of rearranging expert weights with redundancy."""
if torch.cuda.device_count() < world_size: if torch.cuda.device_count() < world_size:
@ -304,8 +311,8 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
# Initialize model parallel (using tensor parallel as an entrypoint # Initialize model parallel (using tensor parallel as an entrypoint
# to expert parallel) # to expert parallel)
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
pipeline_model_parallel_size=1) )
ep_group = get_tp_group().cpu_group ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank() ep_rank = torch.distributed.get_rank()
@ -316,8 +323,9 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
hidden_sizes = [32, 64] # Two different weight matrices hidden_sizes = [32, 64] # Two different weight matrices
# Create old expert indices (with redundancy) # Create old expert indices (with redundancy)
redundancy_config = create_redundancy_config(num_logical_experts, redundancy_config = create_redundancy_config(
total_physical_experts) num_logical_experts, total_physical_experts
)
old_indices = create_expert_indices_with_redundancy( old_indices = create_expert_indices_with_redundancy(
num_layers, num_layers,
@ -328,7 +336,8 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
# Create new expert indices (with redundancy) # Create new expert indices (with redundancy)
new_redundancy_config = create_redundancy_config( new_redundancy_config = create_redundancy_config(
num_logical_experts, total_physical_experts) num_logical_experts, total_physical_experts
)
new_indices = create_expert_indices_with_redundancy( new_indices = create_expert_indices_with_redundancy(
num_layers, num_layers,
num_logical_experts, num_logical_experts,
@ -337,9 +346,9 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
) )
# Create expert weights # Create expert weights
expert_weights = create_expert_weights(num_layers, num_local_experts, expert_weights = create_expert_weights(
hidden_sizes, ep_rank, device, num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
old_indices) )
# Execute weight rearrangement # Execute weight rearrangement
rearrange_expert_weights_inplace( rearrange_expert_weights_inplace(
@ -383,8 +392,8 @@ def test_rearrange_expert_weights_no_change(world_size):
@worker_fn_wrapper @worker_fn_wrapper
def worker_fn(): def worker_fn():
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
pipeline_model_parallel_size=1) )
ep_group = get_tp_group().cpu_group ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank() ep_rank = torch.distributed.get_rank()
@ -401,12 +410,12 @@ def test_rearrange_expert_weights_no_change(world_size):
# Same indices - no change # Same indices - no change
indices = create_expert_indices_with_redundancy( indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, num_layers, num_logical_experts, total_physical_experts, redundancy_config
redundancy_config) )
expert_weights = create_expert_weights(num_layers, num_local_experts, expert_weights = create_expert_weights(
hidden_sizes, ep_rank, device, num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
indices) )
# Save original weights # Save original weights
original_weights = [] original_weights = []
@ -422,7 +431,8 @@ def test_rearrange_expert_weights_no_change(world_size):
indices, # Same indices indices, # Same indices
expert_weights, expert_weights,
ep_group, ep_group,
is_profile=False) is_profile=False,
)
# Verify that the weights have not changed # Verify that the weights have not changed
for layer in range(num_layers): for layer in range(num_layers):
@ -430,8 +440,8 @@ def test_rearrange_expert_weights_no_change(world_size):
torch.testing.assert_close( torch.testing.assert_close(
expert_weights[layer][weight_idx], expert_weights[layer][weight_idx],
original_weights[layer][weight_idx], original_weights[layer][weight_idx],
msg=f"Layer {layer}, weight {weight_idx} should remain " msg=f"Layer {layer}, weight {weight_idx} should remain unchanged",
f"unchanged") )
distributed_run(worker_fn, world_size) distributed_run(worker_fn, world_size)
@ -446,8 +456,8 @@ def test_rearrange_expert_weights_profile_mode(world_size):
@worker_fn_wrapper @worker_fn_wrapper
def worker_fn(): def worker_fn():
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
pipeline_model_parallel_size=1) )
ep_group = get_tp_group().cpu_group ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank() ep_rank = torch.distributed.get_rank()
@ -460,21 +470,23 @@ def test_rearrange_expert_weights_profile_mode(world_size):
hidden_sizes = [32] hidden_sizes = [32]
# Create different index distributions # Create different index distributions
old_redundancy = create_redundancy_config(num_logical_experts, old_redundancy = create_redundancy_config(
total_physical_experts) num_logical_experts, total_physical_experts
new_redundancy = create_redundancy_config(num_logical_experts, )
total_physical_experts) new_redundancy = create_redundancy_config(
num_logical_experts, total_physical_experts
)
old_indices = create_expert_indices_with_redundancy( old_indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, num_layers, num_logical_experts, total_physical_experts, old_redundancy
old_redundancy) )
new_indices = create_expert_indices_with_redundancy( new_indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, num_layers, num_logical_experts, total_physical_experts, new_redundancy
new_redundancy) )
expert_weights = create_expert_weights(num_layers, num_local_experts, expert_weights = create_expert_weights(
hidden_sizes, ep_rank, device, num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
old_indices) )
# Save original weights # Save original weights
original_weights = [] original_weights = []
@ -490,7 +502,7 @@ def test_rearrange_expert_weights_profile_mode(world_size):
new_indices, new_indices,
expert_weights, expert_weights,
ep_group, ep_group,
is_profile=True # Profile mode is_profile=True, # Profile mode
) )
# In profile mode, the weights should remain unchanged # In profile mode, the weights should remain unchanged
@ -499,6 +511,7 @@ def test_rearrange_expert_weights_profile_mode(world_size):
torch.testing.assert_close( torch.testing.assert_close(
expert_weights[layer][weight_idx], expert_weights[layer][weight_idx],
original_weights[layer][weight_idx], original_weights[layer][weight_idx],
msg="In profile mode, the weights should remain unchanged") msg="In profile mode, the weights should remain unchanged",
)
distributed_run(worker_fn, world_size) distributed_run(worker_fn, world_size)

View File

@ -6,24 +6,29 @@ import time
import msgspec import msgspec
import pytest import pytest
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory, from vllm.distributed.kv_events import (
NullEventPublisher) EventBatch,
EventPublisherFactory,
NullEventPublisher,
)
DP_RANK = 0 DP_RANK = 0
class EventSample( class EventSample(
msgspec.Struct, msgspec.Struct,
tag=True, # type: ignore tag=True, # type: ignore
array_like=True # type: ignore array_like=True, # type: ignore
): ):
"""Test event for publisher testing""" """Test event for publisher testing"""
id: int id: int
value: str value: str
class SampleBatch(EventBatch): class SampleBatch(EventBatch):
"""Test event batch for publisher testing""" """Test event batch for publisher testing"""
events: list[EventSample] events: list[EventSample]
@ -44,10 +49,8 @@ def test_basic_publishing(publisher, subscriber):
seq, received = result seq, received = result
assert seq == 0, "Sequence number mismatch" assert seq == 0, "Sequence number mismatch"
assert received.ts == pytest.approx(test_batch.ts, assert received.ts == pytest.approx(test_batch.ts, abs=0.1), "Timestamp mismatch"
abs=0.1), ("Timestamp mismatch") assert len(received.events) == len(test_batch.events), "Number of events mismatch"
assert len(received.events) == len(
test_batch.events), ("Number of events mismatch")
for i, event in enumerate(received.events): for i, event in enumerate(received.events):
assert event.id == i, "Event id mismatch" assert event.id == i, "Event id mismatch"
@ -88,9 +91,9 @@ def test_replay_mechanism(publisher, subscriber):
assert len(replayed) > 0, "No replayed messages received" assert len(replayed) > 0, "No replayed messages received"
seqs = [seq for seq, _ in replayed] seqs = [seq for seq, _ in replayed]
assert all(seq >= 10 for seq in seqs), "Replayed messages not in order" assert all(seq >= 10 for seq in seqs), "Replayed messages not in order"
assert seqs == list(range(min(seqs), assert seqs == list(range(min(seqs), max(seqs) + 1)), (
max(seqs) + "Replayed messages not consecutive"
1)), ("Replayed messages not consecutive") )
def test_buffer_limit(publisher, subscriber, publisher_config): def test_buffer_limit(publisher, subscriber, publisher_config):
@ -126,6 +129,7 @@ def test_topic_filtering(publisher_config):
pub = EventPublisherFactory.create(publisher_config, DP_RANK) pub = EventPublisherFactory.create(publisher_config, DP_RANK)
from .conftest import MockSubscriber from .conftest import MockSubscriber
sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo") sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar") sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar")
@ -137,11 +141,13 @@ def test_topic_filtering(publisher_config):
foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)] foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
assert all(msg is not None for msg in foo_received), ( assert all(msg is not None for msg in foo_received), (
"Subscriber with matching topic should receive messages") "Subscriber with matching topic should receive messages"
)
bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)] bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
assert all(msg is None for msg in bar_received), ( assert all(msg is None for msg in bar_received), (
"Subscriber with non-matching topic should receive no messages") "Subscriber with non-matching topic should receive no messages"
)
finally: finally:
pub.shutdown() pub.shutdown()
sub_foo.close() sub_foo.close()
@ -178,8 +184,7 @@ def test_high_volume(publisher, subscriber):
publisher_thread.join() publisher_thread.join()
assert len(received) >= num_batches * 0.9, ( assert len(received) >= num_batches * 0.9, "We should have received most messages"
"We should have received most messages")
seqs = [seq for seq, _ in received] seqs = [seq for seq, _ in received]
assert sorted(seqs) == seqs, "Sequence numbers should be in order" assert sorted(seqs) == seqs, "Sequence numbers should be in order"
@ -209,13 +214,15 @@ def test_data_parallel_rank_tagging(publisher_config):
# For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558 # For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558
expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port
expected_endpoint_1 = base_endpoint.replace( expected_endpoint_1 = base_endpoint.replace(
":5557", ":5558") # rank 1 gets port + 1 ":5557", ":5558"
) # rank 1 gets port + 1
else: else:
# For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1 # For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1
expected_endpoint_0 = base_endpoint # rank 0 gets base expected_endpoint_0 = base_endpoint # rank 0 gets base
expected_endpoint_1 = base_endpoint + "_dp1" # rank 1 gets _dp1 expected_endpoint_1 = base_endpoint + "_dp1" # rank 1 gets _dp1
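# (illustrative extrapolation: DP rank n would map to port + n for TCP endpoints, or an "_dp{n}" suffix for inproc endpoints)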
from .conftest import MockSubscriber from .conftest import MockSubscriber
sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic) sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic)
sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic) sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic)
@ -241,15 +248,15 @@ def test_data_parallel_rank_tagging(publisher_config):
# Verify DP rank tagging # Verify DP rank tagging
assert received_0.data_parallel_rank == 0, ( assert received_0.data_parallel_rank == 0, (
f"Expected DP rank 0, got {received_0.data_parallel_rank}") f"Expected DP rank 0, got {received_0.data_parallel_rank}"
)
assert received_1.data_parallel_rank == 1, ( assert received_1.data_parallel_rank == 1, (
f"Expected DP rank 1, got {received_1.data_parallel_rank}") f"Expected DP rank 1, got {received_1.data_parallel_rank}"
)
# Verify event content is correct # Verify event content is correct
assert len( assert len(received_0.events) == 2, "Wrong number of events from rank 0"
received_0.events) == 2, "Wrong number of events from rank 0" assert len(received_1.events) == 3, "Wrong number of events from rank 1"
assert len(
received_1.events) == 3, "Wrong number of events from rank 1"
finally: finally:
pub_0.shutdown() pub_0.shutdown()

View File

@ -46,28 +46,24 @@ class EPTestSettings:
): ):
return EPTestSettings( return EPTestSettings(
parallel_setups=[ parallel_setups=[
ParallelSetup(tp_size=tp_base, ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=False),
eager_mode=False, ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=True),
chunked_prefill=False), ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False),
ParallelSetup(tp_size=tp_base, ParallelSetup(
eager_mode=False, tp_size=2 * tp_base, eager_mode=False, chunked_prefill=True
chunked_prefill=True), ),
ParallelSetup(tp_size=tp_base, ParallelSetup(
eager_mode=True, tp_size=2 * tp_base, eager_mode=True, chunked_prefill=False
chunked_prefill=False), ),
ParallelSetup(tp_size=2 * tp_base,
eager_mode=False,
chunked_prefill=True),
ParallelSetup(tp_size=2 * tp_base,
eager_mode=True,
chunked_prefill=False),
], ],
distributed_backends=["mp", "ray"], distributed_backends=["mp", "ray"],
runner=runner, runner=runner,
test_options=EPTestOptions(trust_remote_code=trust_remote_code, test_options=EPTestOptions(
tokenizer_mode=tokenizer_mode, trust_remote_code=trust_remote_code,
load_format=load_format, tokenizer_mode=tokenizer_mode,
hf_overrides=hf_overrides), load_format=load_format,
hf_overrides=hf_overrides,
),
) )
@staticmethod @staticmethod
@ -82,16 +78,16 @@ class EPTestSettings:
): ):
return EPTestSettings( return EPTestSettings(
parallel_setups=[ parallel_setups=[
ParallelSetup(tp_size=tp_base, ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False),
eager_mode=True,
chunked_prefill=False),
], ],
distributed_backends=["mp"], distributed_backends=["mp"],
runner=runner, runner=runner,
test_options=EPTestOptions(trust_remote_code=trust_remote_code, test_options=EPTestOptions(
tokenizer_mode=tokenizer_mode, trust_remote_code=trust_remote_code,
load_format=load_format, tokenizer_mode=tokenizer_mode,
hf_overrides=hf_overrides), load_format=load_format,
hf_overrides=hf_overrides,
),
) )
def iter_params(self, model_name: str): def iter_params(self, model_name: str):
@ -99,8 +95,13 @@ class EPTestSettings:
for parallel_setup in self.parallel_setups: for parallel_setup in self.parallel_setups:
for distributed_backend in self.distributed_backends: for distributed_backend in self.distributed_backends:
yield (model_name, parallel_setup, distributed_backend, yield (
self.runner, opts) model_name,
parallel_setup,
distributed_backend,
self.runner,
opts,
)
# NOTE: You can adjust tp_base locally to fit the model in GPU # NOTE: You can adjust tp_base locally to fit the model in GPU

View File

@ -6,8 +6,7 @@ import pytest
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
def verify_round_robin_pattern(expert_map, ep_rank, ep_size, def verify_round_robin_pattern(expert_map, ep_rank, ep_size, global_num_experts):
global_num_experts):
"""Verify that the expert map follows the round_robin pattern.""" """Verify that the expert map follows the round_robin pattern."""
# Calculate expected local experts (supporting non-divisible cases) # Calculate expected local experts (supporting non-divisible cases)
base_experts = global_num_experts // ep_size base_experts = global_num_experts // ep_size
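# (illustrative: 9 experts over 2 ranks gives base_experts = 4; the test cases below show rank 0 taking the extra expert, 5 vs 4)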
@ -30,24 +29,21 @@ def verify_round_robin_pattern(expert_map, ep_rank, ep_size,
if global_expert_id in expected_expert_ids: if global_expert_id in expected_expert_ids:
local_expert_id = expert_map[global_expert_id] local_expert_id = expert_map[global_expert_id]
expected_local_id = expected_expert_ids.index(global_expert_id) expected_local_id = expected_expert_ids.index(global_expert_id)
assert ( assert local_expert_id == expected_local_id, (
local_expert_id == expected_local_id f"Global expert {global_expert_id} should map to local expert "
), f"Global expert {global_expert_id} should map to local expert " \
f"{expected_local_id}, got {local_expert_id}" f"{expected_local_id}, got {local_expert_id}"
)
else: else:
assert ( assert expert_map[global_expert_id] == -1, (
expert_map[global_expert_id] == -1 f"Global expert {global_expert_id} should not be mapped to this rank"
), f"Global expert {global_expert_id} should not be mapped to " \ )
f"this rank"
# Verify that all local expert IDs are consecutive starting from 0 # Verify that all local expert IDs are consecutive starting from 0
local_expert_ids = [ local_expert_ids = [expert_map[global_id] for global_id in expected_expert_ids]
expert_map[global_id] for global_id in expected_expert_ids
]
expected_local_ids = list(range(local_num_experts)) expected_local_ids = list(range(local_num_experts))
assert ( assert local_expert_ids == expected_local_ids, (
local_expert_ids == expected_local_ids f"Expected local expert IDs {expected_local_ids}, got {local_expert_ids}"
), f"Expected local expert IDs {expected_local_ids}, got {local_expert_ids}" )
@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"]) @pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
@ -78,8 +74,9 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size):
for test_global_experts, test_ep_size in test_cases: for test_global_experts, test_ep_size in test_cases:
# Ensure ep_size matches world_size # Ensure ep_size matches world_size
assert (test_ep_size == world_size assert test_ep_size == world_size, (
), f"ep_size {test_ep_size} must equal world_size {world_size}" f"ep_size {test_ep_size} must equal world_size {world_size}"
)
# Test each rank # Test each rank
for ep_rank in range(world_size): for ep_rank in range(world_size):
@ -98,21 +95,22 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size):
expert_placement_strategy=expert_placement_strategy, expert_placement_strategy=expert_placement_strategy,
) )
assert ( assert test_local_experts == expected_test_local, (
test_local_experts == expected_test_local f"For {test_global_experts} experts on {test_ep_size} ranks, "
), f"For {test_global_experts} experts on {test_ep_size} ranks, " \ f"rank {ep_rank}: expected {expected_test_local} local"
f"rank {ep_rank}: expected {expected_test_local} local" \
f"experts, got {test_local_experts}" f"experts, got {test_local_experts}"
)
if test_expert_map is not None: if test_expert_map is not None:
assert test_expert_map.shape == ( assert test_expert_map.shape == (test_global_experts,), (
test_global_experts, f"Expected expert map shape ({test_global_experts},), "
), f"Expected expert map shape ({test_global_experts},), " \
f"got {test_expert_map.shape}" f"got {test_expert_map.shape}"
)
# Verify round_robin pattern for this test case # Verify round_robin pattern for this test case
verify_round_robin_pattern(test_expert_map, ep_rank, verify_round_robin_pattern(
test_ep_size, test_global_experts) test_expert_map, ep_rank, test_ep_size, test_global_experts
)
@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"]) @pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
@ -147,28 +145,81 @@ def test_determine_expert_map_comprehensive():
# expert_placement_strategy, expected_local, expected_map_pattern) # expert_placement_strategy, expected_local, expected_map_pattern)
test_cases = [ test_cases = [
# Round robin placement tests # Round robin placement tests
(2, 0, 8, "round_robin", 4, [0, -1, 1, -1, 2, -1, 3, (
-1]), # rank 0 gets even experts 2,
(2, 1, 8, "round_robin", 4, [-1, 0, -1, 1, -1, 2, -1, 0,
3]), # rank 1 gets odd experts 8,
(2, 0, 9, "round_robin", 5, [0, -1, 1, -1, 2, -1, 3, -1, 4 "round_robin",
]), # rank 0 gets 5 experts (even + last) 4,
(2, 1, 9, "round_robin", 4, [-1, 0, -1, 1, -1, 2, -1, 3, [0, -1, 1, -1, 2, -1, 3, -1],
-1]), # rank 1 gets 4 experts (odd) ), # rank 0 gets even experts
(
2,
1,
8,
"round_robin",
4,
[-1, 0, -1, 1, -1, 2, -1, 3],
), # rank 1 gets odd experts
(
2,
0,
9,
"round_robin",
5,
[0, -1, 1, -1, 2, -1, 3, -1, 4],
), # rank 0 gets 5 experts (even + last)
(
2,
1,
9,
"round_robin",
4,
[-1, 0, -1, 1, -1, 2, -1, 3, -1],
), # rank 1 gets 4 experts (odd)
# 4-rank tests # 4-rank tests
(4, 0, 8, "round_robin", 2, [0, -1, -1, -1, 1, -1, -1, (
-1]), # rank 0 gets experts 0, 4 4,
(4, 1, 8, "round_robin", 2, [-1, 0, -1, -1, -1, 1, -1, 0,
-1]), # rank 1 gets experts 1, 5 8,
(4, 2, 8, "round_robin", 2, [-1, -1, 0, -1, -1, -1, 1, "round_robin",
-1]), # rank 2 gets experts 2, 6 2,
(4, 3, 8, "round_robin", 2, [-1, -1, -1, 0, -1, -1, -1, [0, -1, -1, -1, 1, -1, -1, -1],
1]), # rank 3 gets experts 3, 7 ), # rank 0 gets experts 0, 4
(
4,
1,
8,
"round_robin",
2,
[-1, 0, -1, -1, -1, 1, -1, -1],
), # rank 1 gets experts 1, 5
(
4,
2,
8,
"round_robin",
2,
[-1, -1, 0, -1, -1, -1, 1, -1],
), # rank 2 gets experts 2, 6
(
4,
3,
8,
"round_robin",
2,
[-1, -1, -1, 0, -1, -1, -1, 1],
), # rank 3 gets experts 3, 7
] ]
for ep_size, ep_rank, global_num_experts, expert_placement_strategy, \ for (
expected_local, expected_map_pattern in test_cases: ep_size,
ep_rank,
global_num_experts,
expert_placement_strategy,
expected_local,
expected_map_pattern,
) in test_cases:
local_num_experts, expert_map = determine_expert_map( local_num_experts, expert_map = determine_expert_map(
ep_size=ep_size, ep_size=ep_size,
ep_rank=ep_rank, ep_rank=ep_rank,
@ -176,19 +227,21 @@ def test_determine_expert_map_comprehensive():
expert_placement_strategy=expert_placement_strategy, expert_placement_strategy=expert_placement_strategy,
) )
assert local_num_experts == expected_local, \ assert local_num_experts == expected_local, (
f"ep_size={ep_size}, ep_rank={ep_rank}, " \ f"ep_size={ep_size}, ep_rank={ep_rank}, "
f"global_num_experts={global_num_experts}, " \ f"global_num_experts={global_num_experts}, "
f"expert_placement_strategy={expert_placement_strategy}: " \ f"expert_placement_strategy={expert_placement_strategy}: "
f"expected {expected_local} local experts, got {local_num_experts}" f"expected {expected_local} local experts, got {local_num_experts}"
)
if expected_map_pattern is None: if expected_map_pattern is None:
assert expert_map is None, "Expected expert_map to be None" assert expert_map is None, "Expected expert_map to be None"
else: else:
assert expert_map is not None, "Expected expert_map to not be None" assert expert_map is not None, "Expected expert_map to not be None"
actual_map = expert_map.tolist() actual_map = expert_map.tolist()
assert actual_map == expected_map_pattern, \ assert actual_map == expected_map_pattern, (
f"ep_size={ep_size}, ep_rank={ep_rank}, " \ f"ep_size={ep_size}, ep_rank={ep_rank}, "
f"global_num_experts={global_num_experts}, " \ f"global_num_experts={global_num_experts}, "
f"expert_placement_strategy={expert_placement_strategy}: " \ f"expert_placement_strategy={expert_placement_strategy}: "
f"expected map {expected_map_pattern}, got {actual_map}" f"expected map {expected_map_pattern}, got {actual_map}"
)

View File

@ -1,10 +1,16 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import (DeviceConfig, KVTransferConfig, ModelConfig, from vllm.config import (
VllmConfig, set_current_vllm_config) DeviceConfig,
KVTransferConfig,
ModelConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.distributed.kv_transfer.kv_connector.utils import ( from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout) get_kv_connector_cache_layout,
)
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger("test_expert_parallel") logger = init_logger("test_expert_parallel")
@ -23,8 +29,9 @@ def test_get_kv_connector_cache_layout_with_lmcache_connector():
kv_connector="LMCacheConnectorV1", kv_connector="LMCacheConnectorV1",
kv_role="kv_both", kv_role="kv_both",
) )
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"), vllm_config = VllmConfig(
kv_transfer_config=kv_transfer_config) device_config=DeviceConfig("cpu"), kv_transfer_config=kv_transfer_config
)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
# Test with default settings # Test with default settings
layout = get_kv_connector_cache_layout() layout = get_kv_connector_cache_layout()
@ -37,9 +44,11 @@ def test_get_kv_connector_cache_layout_with_nixl_connector():
kv_role="kv_both", kv_role="kv_both",
) )
model_config = ModelConfig() model_config = ModelConfig()
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"), vllm_config = VllmConfig(
model_config=model_config, device_config=DeviceConfig("cpu"),
kv_transfer_config=kv_transfer_config) model_config=model_config,
kv_transfer_config=kv_transfer_config,
)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
# Test with default settings # Test with default settings
layout = get_kv_connector_cache_layout() layout = get_kv_connector_cache_layout()
@ -47,25 +56,22 @@ def test_get_kv_connector_cache_layout_with_nixl_connector():
def test_get_kv_connector_cache_layout_with_multi_connector(): def test_get_kv_connector_cache_layout_with_multi_connector():
kv_transfer_config = KVTransferConfig(kv_connector="MultiConnector", kv_transfer_config = KVTransferConfig(
kv_role="kv_both", kv_connector="MultiConnector",
kv_connector_extra_config={ kv_role="kv_both",
"connectors": [{ kv_connector_extra_config={
"kv_connector": "connectors": [
"SharedStorageConnector", {"kv_connector": "SharedStorageConnector", "kv_role": "kv_both"},
"kv_role": {"kv_connector": "NixlConnector", "kv_role": "kv_both"},
"kv_both" ]
}, { },
"kv_connector": )
"NixlConnector",
"kv_role":
"kv_both"
}]
})
model_config = ModelConfig() model_config = ModelConfig()
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"), vllm_config = VllmConfig(
model_config=model_config, device_config=DeviceConfig("cpu"),
kv_transfer_config=kv_transfer_config) model_config=model_config,
kv_transfer_config=kv_transfer_config,
)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
# Test with default settings # Test with default settings
layout = get_kv_connector_cache_layout() layout = get_kv_connector_cache_layout()

View File

@ -24,14 +24,13 @@ from vllm.utils import get_ip
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
@pytest.mark.skipif(not VLLM_MULTI_NODE, @pytest.mark.skipif(
reason="Need at least 2 nodes to run the test.") not VLLM_MULTI_NODE, reason="Need at least 2 nodes to run the test."
)
def test_multi_node_assignment() -> None: def test_multi_node_assignment() -> None:
# NOTE: important to keep this class definition here # NOTE: important to keep this class definition here
# to let ray use cloudpickle to serialize it. # to let ray use cloudpickle to serialize it.
class Actor: class Actor:
def get_ip(self): def get_ip(self):
return get_ip() return get_ip()
@ -41,8 +40,7 @@ def test_multi_node_assignment() -> None:
current_ip = get_ip() current_ip = get_ip()
workers = [] workers = []
for bundle_id, bundle in enumerate( for bundle_id, bundle in enumerate(config.placement_group.bundle_specs):
config.placement_group.bundle_specs):
if not bundle.get("GPU", 0): if not bundle.get("GPU", 0):
continue continue
scheduling_strategy = PlacementGroupSchedulingStrategy( scheduling_strategy = PlacementGroupSchedulingStrategy(

View File

@ -11,15 +11,17 @@ import torch.multiprocessing as mp
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.device_communicators.cuda_communicator import ( from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
CudaCommunicator) from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops
from vllm.distributed.device_communicators.pynccl import (
register_nccl_symmetric_ops)
from vllm.distributed.device_communicators.pynccl_allocator import ( from vllm.distributed.device_communicators.pynccl_allocator import (
get_nccl_mem_pool, is_symmetric_memory_enabled) get_nccl_mem_pool,
from vllm.distributed.parallel_state import (get_tp_group, is_symmetric_memory_enabled,
init_distributed_environment, )
initialize_model_parallel) from vllm.distributed.parallel_state import (
get_tp_group,
init_distributed_environment,
initialize_model_parallel,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import update_environment_variables from vllm.utils import update_environment_variables
@ -38,31 +40,32 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
torch.cuda.set_device(device) torch.cuda.set_device(device)
torch.set_default_device(device) torch.set_default_device(device)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
update_environment_variables({ update_environment_variables(
"RANK": str(local_rank), {
"LOCAL_RANK": str(local_rank), "RANK": str(local_rank),
"WORLD_SIZE": str(world_size), "LOCAL_RANK": str(local_rank),
"MASTER_ADDR": "localhost", "WORLD_SIZE": str(world_size),
"MASTER_PORT": "12345", "MASTER_ADDR": "localhost",
}) "MASTER_PORT": "12345",
}
)
init_distributed_environment() init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
cuda_communicator = typing.cast(CudaCommunicator, cuda_communicator = typing.cast(
get_tp_group().device_communicator) CudaCommunicator, get_tp_group().device_communicator
)
pynccl_comm = cuda_communicator.pynccl_comm pynccl_comm = cuda_communicator.pynccl_comm
if get_nccl_mem_pool() is None: if get_nccl_mem_pool() is None:
pytest.skip("NCCL allocator compilation failed " pytest.skip(
"(probably missing NCCL headers).") "NCCL allocator compilation failed (probably missing NCCL headers)."
)
if not is_symmetric_memory_enabled(): if not is_symmetric_memory_enabled():
pytest.skip("NCCL symmetric memory allreduce is disabled.") pytest.skip("NCCL symmetric memory allreduce is disabled.")
register_nccl_symmetric_ops(pynccl_comm) register_nccl_symmetric_ops(pynccl_comm)
input = torch.randint(1, input = torch.randint(1, 23, (test_size_elements,), dtype=dtype, device=device)
23, (test_size_elements, ),
dtype=dtype,
device=device)
input_clone = input.clone() input_clone = input.clone()
output = torch.ops.vllm.all_reduce_symmetric_with_copy(input) output = torch.ops.vllm.all_reduce_symmetric_with_copy(input)
assert output is not None assert output is not None
@ -77,8 +80,7 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
reason="NCCLSymmMemAllreduce is only available for CUDA platforms.", reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
) )
@pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize("world_size", [2])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
reason="Only test on CUDA")
def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size): def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
if world_size > torch.cuda.device_count(): if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.") pytest.skip("Not enough GPUs to run the test.")
@ -88,7 +90,5 @@ def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
monkeypatch.setenv("NCCL_NVLS_ENABLE", "1") monkeypatch.setenv("NCCL_NVLS_ENABLE", "1")
monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1") monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1")
mp.spawn(nccl_symm_mem_allreduce_worker, mp.spawn(nccl_symm_mem_allreduce_worker, args=(world_size,), nprocs=world_size)
args=(world_size, ),
nprocs=world_size)
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()

View File

@ -32,12 +32,15 @@ if __name__ == "__main__":
# Expected node count based on environment variable # Expected node count based on environment variable
expected = int(os.environ.get("NUM_NODES", "1")) expected = int(os.environ.get("NUM_NODES", "1"))
assert test_result == expected, \ assert test_result == expected, f"Expected {expected} nodes, got {test_result}"
f"Expected {expected} nodes, got {test_result}"
if pg == dist.group.WORLD: if pg == dist.group.WORLD:
print(f"Node count test passed! Got {test_result} nodes " print(
f"when using torch distributed!") f"Node count test passed! Got {test_result} nodes "
f"when using torch distributed!"
)
else: else:
print(f"Node count test passed! Got {test_result} nodes " print(
f"when using StatelessProcessGroup!") f"Node count test passed! Got {test_result} nodes "
f"when using StatelessProcessGroup!"
)

View File

@ -7,6 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
all workers in a node other than the head node, which can cause the test all workers in a node other than the head node, which can cause the test
to fail. to fail.
""" """
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
@ -55,26 +56,17 @@ class PPTestSettings:
): ):
return PPTestSettings( return PPTestSettings(
parallel_setups=[ parallel_setups=[
ParallelSetup(tp_size=tp_base, ParallelSetup(tp_size=tp_base, pp_size=pp_base, eager_mode=False),
pp_size=pp_base, ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, eager_mode=False),
eager_mode=False), ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, eager_mode=True),
ParallelSetup(tp_size=tp_base, ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, eager_mode=False),
pp_size=2 * pp_base, ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, eager_mode=True),
eager_mode=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
eager_mode=True),
ParallelSetup(tp_size=2 * tp_base,
pp_size=pp_base,
eager_mode=False),
ParallelSetup(tp_size=2 * tp_base,
pp_size=pp_base,
eager_mode=True),
], ],
distributed_backends=["mp", "ray"], distributed_backends=["mp", "ray"],
runner=runner, runner=runner,
test_options=PPTestOptions(multi_node_only=multi_node_only, test_options=PPTestOptions(
load_format=load_format), multi_node_only=multi_node_only, load_format=load_format
),
) )
@staticmethod @staticmethod
@ -86,17 +78,15 @@ class PPTestSettings:
multi_node_only: bool = False, multi_node_only: bool = False,
load_format: Optional[str] = None, load_format: Optional[str] = None,
): ):
return PPTestSettings( return PPTestSettings(
parallel_setups=[ parallel_setups=[
ParallelSetup(tp_size=tp_base, ParallelSetup(tp_size=tp_base, pp_size=pp_base, eager_mode=True),
pp_size=pp_base,
eager_mode=True),
], ],
distributed_backends=["mp"], distributed_backends=["mp"],
runner=runner, runner=runner,
test_options=PPTestOptions(multi_node_only=multi_node_only, test_options=PPTestOptions(
load_format=load_format), multi_node_only=multi_node_only, load_format=load_format
),
) )
def iter_params(self, model_id: str): def iter_params(self, model_id: str):
@ -281,8 +271,10 @@ def _compare_tp(
if num_gpus_available < tp_size * pp_size: if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
if VLLM_MULTI_NODE and distributed_backend == "mp": if VLLM_MULTI_NODE and distributed_backend == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for " pytest.skip(
"multiprocessing distributed backend") "Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend"
)
if multi_node_only and not VLLM_MULTI_NODE: if multi_node_only and not VLLM_MULTI_NODE:
pytest.skip("Not in multi-node setting") pytest.skip("Not in multi-node setting")
@ -357,20 +349,16 @@ def _compare_tp(
"mp", "mp",
] ]
compare_two_settings(model_id, compare_two_settings(model_id, pp_args, tp_args, pp_env, tp_env, method=method)
pp_args,
tp_args,
pp_env,
tp_env,
method=method)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "runner", ("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"),
"test_options"),
[ [
params for model_id, settings in TEXT_GENERATION_MODELS.items() params
for params in settings.iter_params(model_id) if model_id in TEST_MODELS for model_id, settings in TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_id)
if model_id in TEST_MODELS
], ],
) )
@create_new_process_for_each_test() @create_new_process_for_each_test()
@ -382,22 +370,25 @@ def test_tp_language_generation(
test_options: PPTestOptions, test_options: PPTestOptions,
num_gpus_available, num_gpus_available,
): ):
_compare_tp(model_id, _compare_tp(
parallel_setup, model_id,
distributed_backend, parallel_setup,
runner, distributed_backend,
test_options, runner,
num_gpus_available, test_options,
method="generate", num_gpus_available,
is_multimodal=False) method="generate",
is_multimodal=False,
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "runner", ("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"),
"test_options"),
[ [
params for model_id, settings in EMBEDDING_MODELS.items() params
for params in settings.iter_params(model_id) if model_id in TEST_MODELS for model_id, settings in EMBEDDING_MODELS.items()
for params in settings.iter_params(model_id)
if model_id in TEST_MODELS
], ],
) )
@create_new_process_for_each_test() @create_new_process_for_each_test()
@ -409,22 +400,25 @@ def test_tp_language_embedding(
test_options: PPTestOptions, test_options: PPTestOptions,
num_gpus_available, num_gpus_available,
): ):
_compare_tp(model_id, _compare_tp(
parallel_setup, model_id,
distributed_backend, parallel_setup,
runner, distributed_backend,
test_options, runner,
num_gpus_available, test_options,
method="encode", num_gpus_available,
is_multimodal=False) method="encode",
is_multimodal=False,
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "runner", ("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"),
"test_options"),
[ [
params for model_id, settings in MULTIMODAL_MODELS.items() params
for params in settings.iter_params(model_id) if model_id in TEST_MODELS for model_id, settings in MULTIMODAL_MODELS.items()
for params in settings.iter_params(model_id)
if model_id in TEST_MODELS
], ],
) )
@create_new_process_for_each_test() @create_new_process_for_each_test()
@ -436,11 +430,13 @@ def test_tp_multimodal_generation(
test_options: PPTestOptions, test_options: PPTestOptions,
num_gpus_available, num_gpus_available,
): ):
_compare_tp(model_id, _compare_tp(
parallel_setup, model_id,
distributed_backend, parallel_setup,
runner, distributed_backend,
test_options, runner,
num_gpus_available, test_options,
method="generate", num_gpus_available,
is_multimodal=True) method="generate",
is_multimodal=True,
)

View File

@ -9,7 +9,6 @@ from vllm.distributed.utils import get_pp_indices
def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch): def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
def _verify(partition_str, num_layers, pp_size, goldens): def _verify(partition_str, num_layers, pp_size, goldens):
@ -57,7 +56,8 @@ def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch):
(5, 3, 0, (0, 2)), (5, 3, 0, (0, 2)),
(5, 3, 1, (2, 4)), (5, 3, 1, (2, 4)),
(5, 3, 2, (4, 5)), (5, 3, 2, (4, 5)),
]) ],
)
def test_uneven_auto_partition( def test_uneven_auto_partition(
num_hidden_layers: int, num_hidden_layers: int,
pp_size: int, pp_size: int,

View File

@ -12,12 +12,18 @@ if TYPE_CHECKING:
from typing_extensions import LiteralString from typing_extensions import LiteralString
@pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [ @pytest.mark.parametrize(
(2, "JackFram/llama-160m"), "PP_SIZE, MODEL_NAME",
]) [
@pytest.mark.parametrize("ATTN_BACKEND", [ (2, "JackFram/llama-160m"),
"FLASH_ATTN", ],
]) )
@pytest.mark.parametrize(
"ATTN_BACKEND",
[
"FLASH_ATTN",
],
)
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_pp_cudagraph( def test_pp_cudagraph(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,

View File

@ -9,13 +9,15 @@ import pytest
import torch import torch
import torch.distributed import torch.distributed
from vllm.distributed.communication_op import ( # noqa from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
tensor_model_parallel_all_reduce)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, from vllm.distributed.parallel_state import (
get_world_group, graph_capture, ensure_model_parallel_initialized,
init_distributed_environment) get_world_group,
graph_capture,
init_distributed_environment,
)
from vllm.utils import update_environment_variables from vllm.utils import update_environment_variables
@ -24,13 +26,13 @@ def distributed_run(fn, world_size):
processes: list[multiprocessing.Process] = [] processes: list[multiprocessing.Process] = []
for i in range(number_of_processes): for i in range(number_of_processes):
env: dict[str, str] = {} env: dict[str, str] = {}
env['RANK'] = str(i) env["RANK"] = str(i)
env['LOCAL_RANK'] = str(i) env["LOCAL_RANK"] = str(i)
env['WORLD_SIZE'] = str(number_of_processes) env["WORLD_SIZE"] = str(number_of_processes)
env['LOCAL_WORLD_SIZE'] = str(number_of_processes) env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env['MASTER_ADDR'] = 'localhost' env["MASTER_ADDR"] = "localhost"
env['MASTER_PORT'] = '12345' env["MASTER_PORT"] = "12345"
p = multiprocessing.Process(target=fn, args=(env, )) p = multiprocessing.Process(target=fn, args=(env,))
processes.append(p) processes.append(p)
p.start() p.start()
@ -47,7 +49,7 @@ def worker_fn_wrapper(fn):
# and update the environment variables in the function # and update the environment variables in the function
def wrapped_fn(env): def wrapped_fn(env):
update_environment_variables(env) update_environment_variables(env)
local_rank = os.environ['LOCAL_RANK'] local_rank = os.environ["LOCAL_RANK"]
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_distributed_environment() init_distributed_environment()
@ -58,17 +60,18 @@ def worker_fn_wrapper(fn):
@worker_fn_wrapper @worker_fn_wrapper
def worker_fn(): def worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, pynccl_comm = PyNcclCommunicator(
device=get_world_group().device) get_world_group().cpu_group, device=get_world_group().device
tensor = torch.ones(16, 1024, 1024, )
dtype=torch.float32).cuda(pynccl_comm.rank) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
tensor = pynccl_comm.all_reduce(tensor) tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize() torch.cuda.synchronize()
assert torch.all(tensor == pynccl_comm.world_size).cpu().item() assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(
reason="Need at least 2 GPUs to run the test.") torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl(): def test_pynccl():
distributed_run(worker_fn, 2) distributed_run(worker_fn, 2)
@ -78,7 +81,7 @@ def multiple_allreduce_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}") device = torch.device(f"cuda:{torch.distributed.get_rank()}")
groups = [ groups = [
torch.distributed.new_group(ranks=[0, 1], backend="gloo"), torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
torch.distributed.new_group(ranks=[2, 3], backend="gloo") torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
] ]
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
pynccl_comm = PyNcclCommunicator(group=group, device=device) pynccl_comm = PyNcclCommunicator(group=group, device=device)
@ -95,8 +98,9 @@ def multiple_allreduce_worker_fn():
assert torch.all(tensor == 2).cpu().item() assert torch.all(tensor == 2).cpu().item()
@pytest.mark.skipif(torch.cuda.device_count() < 4, @pytest.mark.skipif(
reason="Need at least 4 GPUs to run the test.") torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_multiple_allreduce(): def test_pynccl_multiple_allreduce():
# this tests pynccl for multiple tp groups, in a standalone way # this tests pynccl for multiple tp groups, in a standalone way
# i.e. call `pynccl_comm.all_reduce` directly # i.e. call `pynccl_comm.all_reduce` directly
@ -121,8 +125,9 @@ def multiple_allreduce_with_vllm_worker_fn():
assert torch.all(tensor == 2).cpu().item() assert torch.all(tensor == 2).cpu().item()
@pytest.mark.skipif(torch.cuda.device_count() < 4, @pytest.mark.skipif(
reason="Need at least 4 GPUs to run the test.") torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_multiple_allreduce_with_vllm(): def test_pynccl_multiple_allreduce_with_vllm():
# this tests pynccl for multiple tp groups, together with vllm # this tests pynccl for multiple tp groups, together with vllm
# i.e. call `tensor_model_parallel_all_reduce` # i.e. call `tensor_model_parallel_all_reduce`
@ -133,10 +138,11 @@ def test_pynccl_multiple_allreduce_with_vllm():
def worker_fn_with_cudagraph(): def worker_fn_with_cudagraph():
with torch.no_grad(): with torch.no_grad():
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, pynccl_comm = PyNcclCommunicator(
device=get_world_group().device) get_world_group().cpu_group, device=get_world_group().device
)
# run something in the default stream to initialize torch engine # run something in the default stream to initialize torch engine
a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') a = torch.ones((4, 4), device=f"cuda:{pynccl_comm.rank}")
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.graph(graph): with torch.cuda.graph(graph):
a_out = pynccl_comm.all_reduce(a) a_out = pynccl_comm.all_reduce(a)
@ -148,84 +154,90 @@ def worker_fn_with_cudagraph():
@worker_fn_wrapper @worker_fn_wrapper
def all_gather_worker_fn(): def all_gather_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, pynccl_comm = PyNcclCommunicator(
device=get_world_group().device) get_world_group().cpu_group, device=get_world_group().device
)
rank = pynccl_comm.rank rank = pynccl_comm.rank
world_size = pynccl_comm.world_size world_size = pynccl_comm.world_size
device = f'cuda:{pynccl_comm.rank}' device = f"cuda:{pynccl_comm.rank}"
num_elems = 1000 num_elems = 1000
tensor = torch.arange(num_elems, dtype=torch.float32, tensor = (
device=device) + rank * num_elems torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems
result = torch.zeros(num_elems * world_size, )
dtype=torch.float32, result = torch.zeros(num_elems * world_size, dtype=torch.float32, device=device)
device=device)
expected = torch.cat([ expected = torch.cat(
torch.arange(num_elems, dtype=torch.float32) + r * num_elems [
for r in range(world_size) torch.arange(num_elems, dtype=torch.float32) + r * num_elems
]).to(device) for r in range(world_size)
]
).to(device)
pynccl_comm.all_gather(result, tensor) pynccl_comm.all_gather(result, tensor)
torch.cuda.synchronize() torch.cuda.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(
reason="Need at least 2 GPUs to run the test.") torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_all_gather(): def test_pynccl_all_gather():
distributed_run(all_gather_worker_fn, 2) distributed_run(all_gather_worker_fn, 2)
@worker_fn_wrapper @worker_fn_wrapper
def all_gatherv_worker_fn(): def all_gatherv_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, pynccl_comm = PyNcclCommunicator(
device=get_world_group().device) get_world_group().cpu_group, device=get_world_group().device
)
rank = pynccl_comm.rank rank = pynccl_comm.rank
world_size = pynccl_comm.world_size world_size = pynccl_comm.world_size
device = f'cuda:{pynccl_comm.rank}' device = f"cuda:{pynccl_comm.rank}"
assert world_size <= 8 assert world_size <= 8
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size] sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
num_elems = sizes[rank] num_elems = sizes[rank]
tensor = torch.arange(num_elems, dtype=torch.float32, tensor = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100
device=device) + rank * 100
result = torch.zeros(sum(sizes), dtype=torch.float32, device=device) result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
expected = torch.cat([ expected = torch.cat(
torch.arange(sizes[r], dtype=torch.float32) + r * 100 [
for r in range(world_size) torch.arange(sizes[r], dtype=torch.float32) + r * 100
]).to(device) for r in range(world_size)
]
).to(device)
pynccl_comm.all_gatherv(result, tensor, sizes=sizes) pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
torch.cuda.synchronize() torch.cuda.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(
reason="Need at least 2 GPUs to run the test.") torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_all_gatherv(): def test_pynccl_all_gatherv():
distributed_run(all_gatherv_worker_fn, 2) distributed_run(all_gatherv_worker_fn, 2)
@worker_fn_wrapper @worker_fn_wrapper
def reduce_scatter_worker_fn(): def reduce_scatter_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, pynccl_comm = PyNcclCommunicator(
device=get_world_group().device) get_world_group().cpu_group, device=get_world_group().device
)
rank = pynccl_comm.rank rank = pynccl_comm.rank
world_size = pynccl_comm.world_size world_size = pynccl_comm.world_size
device = f'cuda:{pynccl_comm.rank}' device = f"cuda:{pynccl_comm.rank}"
num_elems = 1000 num_elems = 1000
tensor = torch.arange(num_elems, dtype=torch.float32, tensor = (
device=device) + rank * num_elems torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems
assert (num_elems % world_size == 0) )
result = torch.zeros(num_elems // world_size, assert num_elems % world_size == 0
dtype=torch.float32, result = torch.zeros(num_elems // world_size, dtype=torch.float32, device=device)
device=device)
# Calculate expected result for this rank's chunk # Calculate expected result for this rank's chunk
scattered_size = num_elems // world_size scattered_size = num_elems // world_size
@ -233,34 +245,37 @@ def reduce_scatter_worker_fn():
torch.arange(num_elems, dtype=torch.float32) + r * num_elems torch.arange(num_elems, dtype=torch.float32) + r * num_elems
for r in range(world_size) for r in range(world_size)
] ]
expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size] expected = sum(
for tensor in all_tensors).to(device) tensor[rank * scattered_size : (rank + 1) * scattered_size]
for tensor in all_tensors
).to(device)
pynccl_comm.reduce_scatter(result, tensor) pynccl_comm.reduce_scatter(result, tensor)
torch.cuda.synchronize() torch.cuda.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(
reason="Need at least 2 GPUs to run the test.") torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_reduce_scatter(): def test_pynccl_reduce_scatter():
distributed_run(reduce_scatter_worker_fn, 2) distributed_run(reduce_scatter_worker_fn, 2)
@worker_fn_wrapper @worker_fn_wrapper
def reduce_scatterv_worker_fn(): def reduce_scatterv_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, pynccl_comm = PyNcclCommunicator(
device=get_world_group().device) get_world_group().cpu_group, device=get_world_group().device
)
rank = pynccl_comm.rank rank = pynccl_comm.rank
world_size = pynccl_comm.world_size world_size = pynccl_comm.world_size
device = f'cuda:{pynccl_comm.rank}' device = f"cuda:{pynccl_comm.rank}"
assert world_size <= 8 assert world_size <= 8
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size] sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
num_elems = sum(sizes) num_elems = sum(sizes)
tensor = torch.arange(num_elems, dtype=torch.float32, tensor = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100
device=device) + rank * 100
result = torch.zeros(sizes[rank], dtype=torch.float32, device=device) result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)
# Calculate expected result for this rank's chunk # Calculate expected result for this rank's chunk
@ -278,41 +293,41 @@ def reduce_scatterv_worker_fn():
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(
reason="Need at least 2 GPUs to run the test.") torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_reduce_scatterv(): def test_pynccl_reduce_scatterv():
distributed_run(reduce_scatterv_worker_fn, 2) distributed_run(reduce_scatterv_worker_fn, 2)
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(
reason="Need at least 2 GPUs to run the test.") torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_with_cudagraph(): def test_pynccl_with_cudagraph():
distributed_run(worker_fn_with_cudagraph, 2) distributed_run(worker_fn_with_cudagraph, 2)
@worker_fn_wrapper @worker_fn_wrapper
def send_recv_worker_fn(): def send_recv_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, pynccl_comm = PyNcclCommunicator(
device=get_world_group().device) get_world_group().cpu_group, device=get_world_group().device
)
if pynccl_comm.rank == 0: if pynccl_comm.rank == 0:
tensor = torch.ones(16, 1024, 1024, tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
dtype=torch.float32).cuda(pynccl_comm.rank)
else: else:
tensor = torch.empty(16, 1024, 1024, tensor = torch.empty(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
dtype=torch.float32).cuda(pynccl_comm.rank)
if pynccl_comm.rank == 0: if pynccl_comm.rank == 0:
pynccl_comm.send(tensor, pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
else: else:
pynccl_comm.recv(tensor, pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
torch.cuda.synchronize() torch.cuda.synchronize()
assert torch.all(tensor == 1).cpu().item() assert torch.all(tensor == 1).cpu().item()
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(
reason="Need at least 2 GPUs to run the test.") torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_send_recv(): def test_pynccl_send_recv():
distributed_run(send_recv_worker_fn, 2) distributed_run(send_recv_worker_fn, 2)
@ -322,27 +337,20 @@ def multiple_send_recv_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}") device = torch.device(f"cuda:{torch.distributed.get_rank()}")
groups = [ groups = [
torch.distributed.new_group(ranks=[0, 2], backend="gloo"), torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
torch.distributed.new_group(ranks=[1, 3], backend="gloo") torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
] ]
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1] group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
pynccl_comm = PyNcclCommunicator(group=group, device=device) pynccl_comm = PyNcclCommunicator(group=group, device=device)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
elif torch.distributed.get_rank() == 1: elif torch.distributed.get_rank() == 1:
tensor = 2 * torch.ones( tensor = 2 * torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
16, 1024, 1024, dtype=torch.float32, device=device)
else: else:
tensor = torch.empty(16, tensor = torch.empty(16, 1024, 1024, dtype=torch.float32, device=device)
1024,
1024,
dtype=torch.float32,
device=device)
if torch.distributed.get_rank() in [0, 1]: if torch.distributed.get_rank() in [0, 1]:
pynccl_comm.send(tensor, pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
else: else:
pynccl_comm.recv(tensor, pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
torch.cuda.synchronize() torch.cuda.synchronize()
if torch.distributed.get_rank() in [0, 2]: if torch.distributed.get_rank() in [0, 2]:
assert torch.all(tensor == 1).cpu().item() assert torch.all(tensor == 1).cpu().item()
@ -350,14 +358,16 @@ def multiple_send_recv_worker_fn():
assert torch.all(tensor == 2).cpu().item() assert torch.all(tensor == 2).cpu().item()
@pytest.mark.skipif(torch.cuda.device_count() < 4, @pytest.mark.skipif(
reason="Need at least 4 GPUs to run the test.") torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_multiple_send_recv(): def test_pynccl_multiple_send_recv():
distributed_run(multiple_send_recv_worker_fn, 4) distributed_run(multiple_send_recv_worker_fn, 4)
@pytest.mark.skipif(torch.cuda.device_count() < 4, @pytest.mark.skipif(
reason="Need at least 4 GPUs to run the test.") torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_broadcast(): def test_pynccl_broadcast():
distributed_run(broadcast_worker_fn, 4) distributed_run(broadcast_worker_fn, 4)
@ -366,19 +376,17 @@ def test_pynccl_broadcast():
def broadcast_worker_fn(): def broadcast_worker_fn():
# Test broadcast for every root rank. # Test broadcast for every root rank.
# Essentially this is an all-gather operation. # Essentially this is an all-gather operation.
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, pynccl_comm = PyNcclCommunicator(
device=get_world_group().device) get_world_group().cpu_group, device=get_world_group().device
)
recv_tensors = [ recv_tensors = [
torch.empty(16, torch.empty(16, 1024, 1024, dtype=torch.float32, device=pynccl_comm.device)
1024,
1024,
dtype=torch.float32,
device=pynccl_comm.device)
for i in range(pynccl_comm.world_size) for i in range(pynccl_comm.world_size)
] ]
recv_tensors[pynccl_comm.rank] = torch.ones( recv_tensors[pynccl_comm.rank] = (
16, 1024, 1024, dtype=torch.float32, torch.ones(16, 1024, 1024, dtype=torch.float32, device=pynccl_comm.device)
device=pynccl_comm.device) * pynccl_comm.rank * pynccl_comm.rank
)
for i in range(pynccl_comm.world_size): for i in range(pynccl_comm.world_size):
pynccl_comm.broadcast(recv_tensors[i], src=i) pynccl_comm.broadcast(recv_tensors[i], src=i)


@ -8,20 +8,20 @@ import ray
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.distributed.communication_op import ( # noqa from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tp_group, graph_capture from vllm.distributed.parallel_state import get_tp_group, graph_capture
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import (ensure_model_parallel_initialized, from ..utils import (
init_test_distributed_environment, multi_process_parallel) ensure_model_parallel_initialized,
init_test_distributed_environment,
multi_process_parallel,
)
torch.manual_seed(42) torch.manual_seed(42)
random.seed(44) random.seed(44)
# Size over 8MB is sufficient for custom quick allreduce. # Size over 8MB is sufficient for custom quick allreduce.
test_sizes = [ test_sizes = [random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)]
random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)
]
for i, v in enumerate(test_sizes): for i, v in enumerate(test_sizes):
test_sizes[i] -= v % 8 test_sizes[i] -= v % 8
@ -38,8 +38,7 @@ def graph_quickreduce(
m.delenv("CUDA_VISIBLE_DEVICES", raising=False) m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank, init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
distributed_init_port)
ensure_model_parallel_initialized(tp_size, pp_size) ensure_model_parallel_initialized(tp_size, pp_size)
group = get_tp_group().device_group group = get_tp_group().device_group
@ -64,18 +63,15 @@ def graph_quickreduce(
for sz in test_sizes: for sz in test_sizes:
for dtype in [torch.float16, torch.bfloat16]: for dtype in [torch.float16, torch.bfloat16]:
with graph_capture(device=device) as graph_capture_context: with graph_capture(device=device) as graph_capture_context:
inp1 = torch.randint(1, inp1 = torch.randint(
23, (sz, ), 1, 23, (sz,), dtype=dtype, device=torch.cuda.current_device()
dtype=dtype, )
device=torch.cuda.current_device()) inp2 = torch.randint(
inp2 = torch.randint(-23, -23, 1, (sz,), dtype=dtype, device=torch.cuda.current_device()
1, (sz, ), )
dtype=dtype,
device=torch.cuda.current_device())
torch.cuda.synchronize() torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, with torch.cuda.graph(graph, stream=graph_capture_context.stream):
stream=graph_capture_context.stream):
for _ in range(num_communication): for _ in range(num_communication):
out1 = tensor_model_parallel_all_reduce(inp1) out1 = tensor_model_parallel_all_reduce(inp1)
dist.all_reduce(inp1, group=group) dist.all_reduce(inp1, group=group)
@ -99,39 +95,42 @@ def eager_quickreduce(
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank, init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
distributed_init_port)
# Size over 8MB is sufficient for custom quick allreduce. # Size over 8MB is sufficient for custom quick allreduce.
sz = 16 * 1024 * 1024 sz = 16 * 1024 * 1024
fa = get_tp_group().device_communicator.qr_comm fa = get_tp_group().device_communicator.qr_comm
inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)], inp = torch.tensor(
dtype=torch.float16, [1.0 * ((i) % 23) for i in range(sz)], dtype=torch.float16, device=device
device=device) )
out = fa.quick_all_reduce(inp) out = fa.quick_all_reduce(inp)
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1) torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)], inp = torch.tensor(
dtype=torch.bfloat16, [1.0 * ((i) % 23) for i in range(sz)], dtype=torch.bfloat16, device=device
device=device) )
out = fa.quick_all_reduce(inp) out = fa.quick_all_reduce(inp)
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1) torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
@pytest.mark.skipif(not current_platform.is_rocm(), @pytest.mark.skipif(
reason="only test quick allreduce for rocm") not current_platform.is_rocm(), reason="only test quick allreduce for rocm"
)
@pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"]) @pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"])
@pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1, 2]) @pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
@pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce]) @pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce])
def test_custom_quick_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, def test_custom_quick_allreduce(
pipeline_parallel_size, test_target, monkeypatch: pytest.MonkeyPatch,
quant_mode): tp_size,
pipeline_parallel_size,
test_target,
quant_mode,
):
world_size = tp_size * pipeline_parallel_size world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count(): if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.") pytest.skip("Not enough GPUs to run the test.")
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode) monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)
test_target)


@ -22,15 +22,13 @@ if __name__ == "__main__":
dist.broadcast_object_list(recv, src=0) dist.broadcast_object_list(recv, src=0)
ip, port = recv ip, port = recv
stateless_pg = StatelessProcessGroup.create(ip, port, rank, stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())
dist.get_world_size())
for pg in [dist.group.WORLD, stateless_pg]: for pg in [dist.group.WORLD, stateless_pg]:
test_result = all(in_the_same_node_as(pg, source_rank=0)) test_result = all(in_the_same_node_as(pg, source_rank=0))
expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1" expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
assert test_result == expected, \ assert test_result == expected, f"Expected {expected}, got {test_result}"
f"Expected {expected}, got {test_result}"
if pg == dist.group.WORLD: if pg == dist.group.WORLD:
print("Same node test passed! when using torch distributed!") print("Same node test passed! when using torch distributed!")
else: else:


@ -7,6 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
all workers in a node other than the head node, which can cause the test all workers in a node other than the head node, which can cause the test
to fail. to fail.
""" """
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
@ -56,7 +57,8 @@ class SPTestSettings:
raise ValueError( raise ValueError(
f"Length mismatch: distributed_backends " f"Length mismatch: distributed_backends "
f"({len(self.distributed_backends)}) != " f"({len(self.distributed_backends)}) != "
f"vllm_major_versions ({len(self.vllm_major_versions)})") f"vllm_major_versions ({len(self.vllm_major_versions)})"
)
@staticmethod @staticmethod
def detailed( def detailed(
@ -72,18 +74,22 @@ class SPTestSettings:
for pp_multiplier in [1, 2]: for pp_multiplier in [1, 2]:
for chunked_prefill_val in [False, True]: for chunked_prefill_val in [False, True]:
parallel_setups.append( parallel_setups.append(
ParallelSetup(tp_size=tp_base, ParallelSetup(
pp_size=pp_multiplier * pp_base, tp_size=tp_base,
enable_fusion=False, pp_size=pp_multiplier * pp_base,
eager_mode=eager_mode_val, enable_fusion=False,
chunked_prefill=chunked_prefill_val)) eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
)
return SPTestSettings( return SPTestSettings(
parallel_setups=parallel_setups, parallel_setups=parallel_setups,
distributed_backends=["mp", "ray"], distributed_backends=["mp", "ray"],
vllm_major_versions=["1", "1"], vllm_major_versions=["1", "1"],
runner=runner, runner=runner,
test_options=SPTestOptions(multi_node_only=multi_node_only, test_options=SPTestOptions(
load_format=load_format), multi_node_only=multi_node_only, load_format=load_format
),
) )
@staticmethod @staticmethod
@ -100,18 +106,22 @@ class SPTestSettings:
for pp_multiplier in [1, 2]: for pp_multiplier in [1, 2]:
for chunked_prefill_val in [False, True]: for chunked_prefill_val in [False, True]:
parallel_setups.append( parallel_setups.append(
ParallelSetup(tp_size=tp_base, ParallelSetup(
pp_size=pp_multiplier * pp_base, tp_size=tp_base,
enable_fusion=False, pp_size=pp_multiplier * pp_base,
eager_mode=eager_mode_val, enable_fusion=False,
chunked_prefill=chunked_prefill_val)) eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
)
return SPTestSettings( return SPTestSettings(
parallel_setups=parallel_setups, parallel_setups=parallel_setups,
distributed_backends=["mp", "ray"], distributed_backends=["mp", "ray"],
vllm_major_versions=["1", "1"], vllm_major_versions=["1", "1"],
runner=runner, runner=runner,
test_options=SPTestOptions(multi_node_only=multi_node_only, test_options=SPTestOptions(
load_format=load_format), multi_node_only=multi_node_only, load_format=load_format
),
) )
@staticmethod @staticmethod
@ -126,28 +136,39 @@ class SPTestSettings:
parallel_setups = [] parallel_setups = []
for fusion_val in [False, True]: for fusion_val in [False, True]:
parallel_setups.append( parallel_setups.append(
ParallelSetup(tp_size=tp_base, ParallelSetup(
pp_size=pp_base, tp_size=tp_base,
enable_fusion=fusion_val, pp_size=pp_base,
eager_mode=True, enable_fusion=fusion_val,
chunked_prefill=False)) eager_mode=True,
chunked_prefill=False,
)
)
return SPTestSettings( return SPTestSettings(
parallel_setups=parallel_setups, parallel_setups=parallel_setups,
distributed_backends=["mp", "ray"], distributed_backends=["mp", "ray"],
vllm_major_versions=["1", "1"], vllm_major_versions=["1", "1"],
runner=runner, runner=runner,
test_options=SPTestOptions(multi_node_only=multi_node_only, test_options=SPTestOptions(
load_format=load_format), multi_node_only=multi_node_only, load_format=load_format
),
) )
def iter_params(self, model_id: str): def iter_params(self, model_id: str):
opts = self.test_options opts = self.test_options
for parallel_setup in self.parallel_setups: for parallel_setup in self.parallel_setups:
for backend, vllm_major_version in zip(self.distributed_backends, for backend, vllm_major_version in zip(
self.vllm_major_versions): self.distributed_backends, self.vllm_major_versions
yield (model_id, parallel_setup, backend, vllm_major_version, ):
self.runner, opts) yield (
model_id,
parallel_setup,
backend,
vllm_major_version,
self.runner,
opts,
)
def _compare_sp( def _compare_sp(
@ -200,8 +221,10 @@ def _compare_sp(
if num_gpus_available < tp_size * pp_size: if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
if VLLM_MULTI_NODE and distributed_backend == "mp": if VLLM_MULTI_NODE and distributed_backend == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for " pytest.skip(
"multiprocessing distributed backend") "Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend"
)
if multi_node_only and not VLLM_MULTI_NODE: if multi_node_only and not VLLM_MULTI_NODE:
pytest.skip("Not in multi-node setting") pytest.skip("Not in multi-node setting")
@ -232,13 +255,13 @@ def _compare_sp(
common_args.append("--skip-tokenizer-init") common_args.append("--skip-tokenizer-init")
compilation_config = { compilation_config = {
'level': 3, "level": 3,
'custom_ops': ["+rms_norm"], "custom_ops": ["+rms_norm"],
'compile_sizes': [4, 8], "compile_sizes": [4, 8],
'pass_config': { "pass_config": {
'enable_sequence_parallelism': True, "enable_sequence_parallelism": True,
'enable_fusion': enable_fusion, "enable_fusion": enable_fusion,
'enable_noop': True, "enable_noop": True,
}, },
} }
@ -270,12 +293,9 @@ def _compare_sp(
] ]
try: try:
compare_two_settings(model_id, compare_two_settings(
tp_sp_args, model_id, tp_sp_args, tp_args, tp_sp_env, tp_env, method=method
tp_args, )
tp_sp_env,
tp_env,
method=method)
except Exception: except Exception:
testing_ray_compiled_graph = tp_sp_env is not None testing_ray_compiled_graph = tp_sp_env is not None
if testing_ray_compiled_graph and vllm_major_version == "0": if testing_ray_compiled_graph and vllm_major_version == "0":
@ -301,10 +321,17 @@ SP_TEST_MODELS = [
@pytest.mark.parametrize( @pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", (
"runner", "test_options"), "model_id",
"parallel_setup",
"distributed_backend",
"vllm_major_version",
"runner",
"test_options",
),
[ [
params for model_id, settings in SP_TEXT_GENERATION_MODELS.items() params
for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_id) for params in settings.iter_params(model_id)
if model_id in SP_TEST_MODELS if model_id in SP_TEST_MODELS
], ],
@ -319,12 +346,14 @@ def test_tp_sp_generation(
test_options: SPTestOptions, test_options: SPTestOptions,
num_gpus_available, num_gpus_available,
): ):
_compare_sp(model_id, _compare_sp(
parallel_setup, model_id,
distributed_backend, parallel_setup,
vllm_major_version, distributed_backend,
runner, vllm_major_version,
test_options, runner,
num_gpus_available, test_options,
method="generate", num_gpus_available,
is_multimodal=False) method="generate",
is_multimodal=False,
)


@ -26,13 +26,13 @@ def distributed_run(fn, world_size):
processes = [] processes = []
for i in range(number_of_processes): for i in range(number_of_processes):
env = {} env = {}
env['RANK'] = str(i) env["RANK"] = str(i)
env['LOCAL_RANK'] = str(i) env["LOCAL_RANK"] = str(i)
env['WORLD_SIZE'] = str(number_of_processes) env["WORLD_SIZE"] = str(number_of_processes)
env['LOCAL_WORLD_SIZE'] = str(number_of_processes) env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env['MASTER_ADDR'] = 'localhost' env["MASTER_ADDR"] = "localhost"
env['MASTER_PORT'] = '12345' env["MASTER_PORT"] = "12345"
p = multiprocessing.Process(target=fn, args=(env, )) p = multiprocessing.Process(target=fn, args=(env,))
processes.append(p) processes.append(p)
p.start() p.start()
@ -57,25 +57,23 @@ def worker_fn_wrapper(fn):
@worker_fn_wrapper @worker_fn_wrapper
def worker_fn(): def worker_fn():
rank = dist.get_rank() rank = dist.get_rank()
if rank == 0: if rank == 0:
port = get_open_port() port = get_open_port()
ip = '127.0.0.1' ip = "127.0.0.1"
dist.broadcast_object_list([ip, port], src=0) dist.broadcast_object_list([ip, port], src=0)
else: else:
recv = [None, None] recv = [None, None]
dist.broadcast_object_list(recv, src=0) dist.broadcast_object_list(recv, src=0)
ip, port = recv # type: ignore ip, port = recv # type: ignore
stateless_pg = StatelessProcessGroup.create(ip, port, rank, stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())
dist.get_world_size())
for pg in [dist.group.WORLD, stateless_pg]: for pg in [dist.group.WORLD, stateless_pg]:
writer_rank = 2 writer_rank = 2
broadcaster = MessageQueue.create_from_process_group( broadcaster = MessageQueue.create_from_process_group(
pg, 40 * 1024, 2, writer_rank) pg, 40 * 1024, 2, writer_rank
)
if rank == writer_rank: if rank == writer_rank:
seed = random.randint(0, 1000) seed = random.randint(0, 1000)
dist.broadcast_object_list([seed], writer_rank) dist.broadcast_object_list([seed], writer_rank)


@ -5,7 +5,8 @@ import traceback
import unittest import unittest
from vllm.distributed.device_communicators.shm_object_storage import ( from vllm.distributed.device_communicators.shm_object_storage import (
SingleWriterShmRingBuffer) SingleWriterShmRingBuffer,
)
class TestSingleWriterShmRingBuffer(unittest.TestCase): class TestSingleWriterShmRingBuffer(unittest.TestCase):
@ -25,18 +26,21 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
"""Test opening an existing buffer""" """Test opening an existing buffer"""
# First create a buffer # First create a buffer
self.ring_buffer = SingleWriterShmRingBuffer( self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=self.buffer_size, create=True) data_buffer_size=self.buffer_size, create=True
)
# Then open it with another instance # Then open it with another instance
reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle()) reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle())
self.assertFalse(reader_buffer.is_writer) self.assertFalse(reader_buffer.is_writer)
self.assertEqual(reader_buffer.shared_memory.name, self.assertEqual(
self.ring_buffer.shared_memory.name) reader_buffer.shared_memory.name, self.ring_buffer.shared_memory.name
)
def test_buffer_access(self): def test_buffer_access(self):
"""Test accessing allocated buffers""" """Test accessing allocated buffers"""
self.ring_buffer = SingleWriterShmRingBuffer( self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=self.buffer_size, create=True) data_buffer_size=self.buffer_size, create=True
)
size = 100 size = 100
address, monotonic_id = self.ring_buffer.allocate_buf(size) address, monotonic_id = self.ring_buffer.allocate_buf(size)
@ -44,11 +48,11 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
# Write some test data # Write some test data
test_data = b"Hello, World!" * 7 # 91 bytes test_data = b"Hello, World!" * 7 # 91 bytes
with self.ring_buffer.access_buf(address) as (data_buf, metadata): with self.ring_buffer.access_buf(address) as (data_buf, metadata):
data_buf[0:len(test_data)] = test_data data_buf[0 : len(test_data)] = test_data
# Read it back # Read it back
with self.ring_buffer.access_buf(address) as (data_buf2, metadata2): with self.ring_buffer.access_buf(address) as (data_buf2, metadata2):
read_data = bytes(data_buf2[0:len(test_data)]) read_data = bytes(data_buf2[0 : len(test_data)])
read_id = metadata2[0] read_id = metadata2[0]
self.assertEqual(read_data, test_data) self.assertEqual(read_data, test_data)
@ -58,7 +62,8 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
"""Test that MemoryError is raised when buffer is full""" """Test that MemoryError is raised when buffer is full"""
small_buffer_size = 200 small_buffer_size = 200
self.ring_buffer = SingleWriterShmRingBuffer( self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=small_buffer_size, create=True) data_buffer_size=small_buffer_size, create=True
)
# Fill up the buffer # Fill up the buffer
self.ring_buffer.allocate_buf(100) self.ring_buffer.allocate_buf(100)
@ -72,7 +77,8 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
"""Test allocation and freeing of buffers""" """Test allocation and freeing of buffers"""
small_buffer_size = 200 small_buffer_size = 200
self.ring_buffer = SingleWriterShmRingBuffer( self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=small_buffer_size, create=True) data_buffer_size=small_buffer_size, create=True
)
size = 80 size = 80
# Write some data # Write some data
@ -81,7 +87,7 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
address, monotonic_id = self.ring_buffer.allocate_buf(size) address, monotonic_id = self.ring_buffer.allocate_buf(size)
with self.ring_buffer.access_buf(address) as (data_buf, metadata): with self.ring_buffer.access_buf(address) as (data_buf, metadata):
data_buf[0:4] = (0).to_bytes(4, "little") # 0 for not in-use data_buf[0:4] = (0).to_bytes(4, "little") # 0 for not in-use
data_buf[4:len(test_data) + 4] = test_data data_buf[4 : len(test_data) + 4] = test_data
print(self.ring_buffer.metadata) print(self.ring_buffer.metadata)
freed_ids = self.ring_buffer.free_buf(lambda *args: True) freed_ids = self.ring_buffer.free_buf(lambda *args: True)
print(f" Freed IDs: {freed_ids}") print(f" Freed IDs: {freed_ids}")
@ -90,7 +96,8 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
def test_clear_buffer(self): def test_clear_buffer(self):
"""Test clearing the buffer""" """Test clearing the buffer"""
self.ring_buffer = SingleWriterShmRingBuffer( self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=self.buffer_size, create=True) data_buffer_size=self.buffer_size, create=True
)
# Allocate some buffers # Allocate some buffers
for _ in range(3): for _ in range(3):
@ -121,8 +128,7 @@ def main():
# Manual demonstration # Manual demonstration
try: try:
print("Creating ring buffer...") print("Creating ring buffer...")
writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048, writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048, create=True)
create=True)
reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle()) reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle())
print(f"Buffer created with name: {writer_buffer.shared_memory.name}") print(f"Buffer created with name: {writer_buffer.shared_memory.name}")
@ -140,7 +146,7 @@ def main():
# Write some test data # Write some test data
with writer_buffer.access_buf(address) as (data_buf, metadata): with writer_buffer.access_buf(address) as (data_buf, metadata):
test_message = f"Test message {i}".encode() test_message = f"Test message {i}".encode()
data_buf[0:len(test_message)] = test_message data_buf[0 : len(test_message)] = test_message
except MemoryError as e: except MemoryError as e:
print(f" Failed to allocate {size} bytes: {e}") print(f" Failed to allocate {size} bytes: {e}")


@ -12,28 +12,33 @@ import torch
# Assuming these are imported from your module # Assuming these are imported from your module
from vllm.distributed.device_communicators.shm_object_storage import ( from vllm.distributed.device_communicators.shm_object_storage import (
MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer) MsgpackSerde,
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, SingleWriterShmObjectStorage,
MultiModalSharedField) SingleWriterShmRingBuffer,
)
from vllm.multimodal.inputs import (
MultiModalFieldElem,
MultiModalKwargsItem,
MultiModalSharedField,
)
def _dummy_elem(modality: str, key: str, size: int): def _dummy_elem(modality: str, key: str, size: int):
return MultiModalFieldElem( return MultiModalFieldElem(
modality=modality, modality=modality,
key=key, key=key,
data=torch.empty((size, ), dtype=torch.int8), data=torch.empty((size,), dtype=torch.int8),
field=MultiModalSharedField(1), field=MultiModalSharedField(1),
) )
def _dummy_item(modality: str, size_by_key: dict[str, int]): def _dummy_item(modality: str, size_by_key: dict[str, int]):
return MultiModalKwargsItem.from_elems([ return MultiModalKwargsItem.from_elems(
_dummy_elem(modality, key, size) for key, size in size_by_key.items() [_dummy_elem(modality, key, size) for key, size in size_by_key.items()]
]) )
class TestSingleWriterShmObjectStorage(unittest.TestCase): class TestSingleWriterShmObjectStorage(unittest.TestCase):
def setUp(self): def setUp(self):
"""Set up test fixtures before each test method.""" """Set up test fixtures before each test method."""
ring_buffer = SingleWriterShmRingBuffer( ring_buffer = SingleWriterShmRingBuffer(
@ -208,8 +213,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
with self.assertRaises(ValueError) as context: with self.assertRaises(ValueError) as context:
self.storage.get(address, monotonic_id + 100) self.storage.get(address, monotonic_id + 100)
self.assertIn("has been modified or is invalid", \ self.assertIn("has been modified or is invalid", str(context.exception))
str(context.exception))
def test_clear_storage(self): def test_clear_storage(self):
"""Test clearing the storage.""" """Test clearing the storage."""
@ -234,8 +238,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
# Reader process function # Reader process function
def reader_process(process_id, storage_handle, items_to_read): def reader_process(process_id, storage_handle, items_to_read):
"""Reader process that connects to existing shared memory and reads data.""" """Reader process that connects to existing shared memory and reads data."""
reader_storage = SingleWriterShmObjectStorage.create_from_handle( reader_storage = SingleWriterShmObjectStorage.create_from_handle(storage_handle)
storage_handle)
print(f"Reader {process_id} started") print(f"Reader {process_id} started")
@ -276,11 +279,7 @@ def run_multiprocess_example():
# Test basic data types # Test basic data types
test_data = [ test_data = [
("user_data", { ("user_data", {"name": "Alice", "age": 30, "scores": [95, 87, 92]}),
"name": "Alice",
"age": 30,
"scores": [95, 87, 92]
}),
("simple_string", "Hello, World!"), ("simple_string", "Hello, World!"),
("number", 42), ("number", 42),
("list_data", [1, 2, 3, "four", 5.0]), ("list_data", [1, 2, 3, "four", 5.0]),
@ -301,8 +300,9 @@ def run_multiprocess_example():
# initialize lock for reader processes # initialize lock for reader processes
handle.reader_lock = Lock() handle.reader_lock = Lock()
for i in range(storage.n_readers): for i in range(storage.n_readers):
p = multiprocessing.Process(target=reader_process, p = multiprocessing.Process(
args=(i, handle, stored_items)) target=reader_process, args=(i, handle, stored_items)
)
processes.append(p) processes.append(p)
p.start() p.start()


@ -14,11 +14,12 @@ import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.cuda_communicator import ( from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
CudaCommunicator) from vllm.distributed.parallel_state import (
from vllm.distributed.parallel_state import (get_tp_group, get_tp_group,
init_distributed_environment, init_distributed_environment,
initialize_model_parallel) initialize_model_parallel,
)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -32,8 +33,7 @@ test_size_elements = 1024 * 1024
def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue): def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
monkeypatch = pytest.MonkeyPatch() monkeypatch = pytest.MonkeyPatch()
config = VllmConfig(parallel_config=ParallelConfig( config = VllmConfig(parallel_config=ParallelConfig(tensor_parallel_size=world_size))
tensor_parallel_size=world_size))
with monkeypatch.context() as m, set_current_vllm_config(config): with monkeypatch.context() as m, set_current_vllm_config(config):
m.delenv("CUDA_VISIBLE_DEVICES", raising=False) m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
@ -42,34 +42,34 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
torch.cuda.set_device(device) torch.cuda.set_device(device)
torch.set_default_device(device) torch.set_default_device(device)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
update_environment_variables({ update_environment_variables(
'RANK': str(local_rank), {
'LOCAL_RANK': str(local_rank), "RANK": str(local_rank),
'WORLD_SIZE': str(world_size), "LOCAL_RANK": str(local_rank),
'MASTER_ADDR': 'localhost', "WORLD_SIZE": str(world_size),
'MASTER_PORT': '12345', "MASTER_ADDR": "localhost",
}) "MASTER_PORT": "12345",
}
)
init_distributed_environment() init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
cuda_communicator = typing.cast(CudaCommunicator, cuda_communicator = typing.cast(
get_tp_group().device_communicator) CudaCommunicator, get_tp_group().device_communicator
)
symm_mem_comm = cuda_communicator.symm_mem_comm symm_mem_comm = cuda_communicator.symm_mem_comm
if symm_mem_comm is None or symm_mem_comm.disabled: if symm_mem_comm is None or symm_mem_comm.disabled:
# can't use skip under multiprocessing # can't use skip under multiprocessing
q.put("SymmMemCommunicator is not available or disabled.") q.put("SymmMemCommunicator is not available or disabled.")
return return
inp_direct_symm_mem = torch.randint(1, inp_direct_symm_mem = torch.randint(
23, (test_size_elements, ), 1, 23, (test_size_elements,), dtype=dtype, device=device
dtype=dtype, )
device=device)
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem): if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
# can't use skip under multiprocessing # can't use skip under multiprocessing
q.put( q.put("SymmMemCommunicator isn't used for this world and input size.")
"SymmMemCommunicator isn't used for this world and input size."
)
return return
original_inp_direct_symm_mem = inp_direct_symm_mem.clone() original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
@ -78,42 +78,37 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
group = get_tp_group().device_group group = get_tp_group().device_group
dist.all_reduce(original_inp_direct_symm_mem, group=group) dist.all_reduce(original_inp_direct_symm_mem, group=group)
torch.testing.assert_close(out_direct_symm_mem, torch.testing.assert_close(
original_inp_direct_symm_mem, out_direct_symm_mem, original_inp_direct_symm_mem, atol=2.5, rtol=0.1
atol=2.5, )
rtol=0.1)
# Test tensor_model_parallel_all_reduce which should use symm_mem # Test tensor_model_parallel_all_reduce which should use symm_mem
inp_tensor_parallel = torch.randint(-23, inp_tensor_parallel = torch.randint(
1, (test_size_elements, ), -23, 1, (test_size_elements,), dtype=dtype, device=device
dtype=dtype, )
device=device)
original_inp_tensor_parallel = inp_tensor_parallel.clone() original_inp_tensor_parallel = inp_tensor_parallel.clone()
out_tensor_parallel = tensor_model_parallel_all_reduce( out_tensor_parallel = tensor_model_parallel_all_reduce(inp_tensor_parallel)
inp_tensor_parallel)
dist.all_reduce(original_inp_tensor_parallel, group=group) dist.all_reduce(original_inp_tensor_parallel, group=group)
torch.testing.assert_close(out_tensor_parallel, torch.testing.assert_close(
original_inp_tensor_parallel, out_tensor_parallel, original_inp_tensor_parallel, atol=2.5, rtol=0.1
atol=2.5, )
rtol=0.1)
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.is_cuda(), not current_platform.is_cuda(),
reason="SymmMemAllreduce is only available for CUDA platforms.") reason="SymmMemAllreduce is only available for CUDA platforms.",
)
@pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1]) @pytest.mark.parametrize("pipeline_parallel_size", [1])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
reason="Only test on CUDA") def test_symm_mem_allreduce(
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, monkeypatch: pytest.MonkeyPatch, tp_size, pipeline_parallel_size
pipeline_parallel_size): ):
world_size = tp_size * pipeline_parallel_size world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count(): if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.") pytest.skip("Not enough GPUs to run the test.")
q = mp.get_context('spawn').Queue() q = mp.get_context("spawn").Queue()
mp.spawn(symm_mem_allreduce_worker, mp.spawn(symm_mem_allreduce_worker, args=(world_size, q), nprocs=world_size)
args=(world_size, q),
nprocs=world_size)
try: try:
val = q.get(timeout=1) val = q.get(timeout=1)
except queue.Empty: except queue.Empty:
@ -126,18 +121,20 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.is_cuda(), not current_platform.is_cuda(),
reason="SymmMemAllreduce is only available for CUDA platforms.") reason="SymmMemAllreduce is only available for CUDA platforms.",
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], )
reason="Only test on CUDA") @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch): def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
world_size = 4 world_size = 4
if world_size > torch.cuda.device_count(): if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.") pytest.skip("Not enough GPUs to run the test.")
# Verify that the DataParallel runs without error # Verify that the DataParallel runs without error
engine_args = EngineArgs(model="distilbert/distilgpt2", engine_args = EngineArgs(
enforce_eager=True, model="distilbert/distilgpt2",
enable_prefix_caching=True, enforce_eager=True,
data_parallel_size=2, enable_prefix_caching=True,
tensor_parallel_size=2, data_parallel_size=2,
data_parallel_backend="mp") tensor_parallel_size=2,
data_parallel_backend="mp",
)
LLMEngine.from_engine_args(engine_args) LLMEngine.from_engine_args(engine_args)


@ -24,13 +24,15 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# set different `gpu_memory_utilization` and `swap_space` for different ranks, # set different `gpu_memory_utilization` and `swap_space` for different ranks,
# to test if all ranks agree on the same kv cache configuration. # to test if all ranks agree on the same kv cache configuration.
llm = LLM(model="facebook/opt-125m", llm = LLM(
tensor_parallel_size=2, model="facebook/opt-125m",
pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)), tensor_parallel_size=2,
distributed_executor_backend="external_launcher", pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)),
gpu_memory_utilization=random.uniform(0.7, 0.9), distributed_executor_backend="external_launcher",
swap_space=random.randint(1, 4), gpu_memory_utilization=random.uniform(0.7, 0.9),
seed=0) swap_space=random.randint(1, 4),
seed=0,
)
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
@ -48,15 +50,14 @@ def test_consistent_across_ranks(obj):
assert container[0] == obj assert container[0] == obj
test_consistent_across_ranks( test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks) test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
# make sure we can access the model parameters from the calling process # make sure we can access the model parameters from the calling process
# of the `LLM` instance. # of the `LLM` instance.
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner. params = list(
model.parameters()) llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters()
)
test_consistent_across_ranks(len(params)) test_consistent_across_ranks(len(params))
# all ranks should have the same outputs # all ranks should have the same outputs
@ -65,5 +66,4 @@ for output in outputs:
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
test_consistent_across_ranks(prompt) test_consistent_across_ranks(prompt)
test_consistent_across_ranks(generated_text) test_consistent_across_ranks(generated_text)
print(f"Rank {torch_rank}, Prompt: {prompt!r}, " print(f"Rank {torch_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")
f"Generated text: {generated_text!r}")


@ -24,23 +24,22 @@ dp_rank = int(os.getenv("DP_RANK", "0"))
if dp_size > 1: if dp_size > 1:
# distribute the prompts across the data parallel ranks # distribute the prompts across the data parallel ranks
prompts = [ prompts = [prompt for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank]
prompt for idx, prompt in enumerate(prompts)
if idx % dp_size == dp_rank
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# set different `gpu_memory_utilization` and `swap_space` for different ranks, # set different `gpu_memory_utilization` and `swap_space` for different ranks,
# to test if all ranks agree on the same kv cache configuration. # to test if all ranks agree on the same kv cache configuration.
llm = LLM(model="microsoft/Phi-mini-MoE-instruct", llm = LLM(
tensor_parallel_size=int(os.getenv("TP_SIZE", "1")), model="microsoft/Phi-mini-MoE-instruct",
pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")), tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1, pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")),
distributed_executor_backend="external_launcher", enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1,
gpu_memory_utilization=random.uniform(0.7, 0.9), distributed_executor_backend="external_launcher",
swap_space=random.randint(1, 4), gpu_memory_utilization=random.uniform(0.7, 0.9),
seed=0) swap_space=random.randint(1, 4),
seed=0,
)
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
@ -54,21 +53,18 @@ def test_consistent_across_ranks(obj):
dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group) dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group)
else: else:
container = [None] container = [None]
dist.broadcast_object_list(container, dist.broadcast_object_list(container, src=group.ranks[0], group=cpu_group)
src=group.ranks[0],
group=cpu_group)
assert container[0] == obj assert container[0] == obj
test_consistent_across_ranks( test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks) test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
# make sure we can access the model parameters from the calling process # make sure we can access the model parameters from the calling process
# of the `LLM` instance. # of the `LLM` instance.
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner. params = list(
model.parameters()) llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters()
)
test_consistent_across_ranks(len(params)) test_consistent_across_ranks(len(params))
# all ranks should have the same outputs # all ranks should have the same outputs
@ -77,5 +73,4 @@ for output in outputs:
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
test_consistent_across_ranks(prompt) test_consistent_across_ranks(prompt)
test_consistent_across_ranks(generated_text) test_consistent_across_ranks(generated_text)
print(f"Rank {group_rank}, Prompt: {prompt!r}, " print(f"Rank {group_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")
f"Generated text: {generated_text!r}")


@ -10,21 +10,22 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import (cuda_device_count_stateless, get_open_port, from vllm.utils import (
update_environment_variables) cuda_device_count_stateless,
get_open_port,
update_environment_variables,
)
from ..utils import multi_gpu_test from ..utils import multi_gpu_test
@ray.remote @ray.remote
class _CUDADeviceCountStatelessTestActor: class _CUDADeviceCountStatelessTestActor:
def get_count(self): def get_count(self):
return cuda_device_count_stateless() return cuda_device_count_stateless()
def set_cuda_visible_devices(self, cuda_visible_devices: str): def set_cuda_visible_devices(self, cuda_visible_devices: str):
update_environment_variables( update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
def get_cuda_visible_devices(self): def get_cuda_visible_devices(self):
return envs.CUDA_VISIBLE_DEVICES return envs.CUDA_VISIBLE_DEVICES
@ -34,10 +35,9 @@ def test_cuda_device_count_stateless():
"""Test that cuda_device_count_stateless changes return value if """Test that cuda_device_count_stateless changes return value if
CUDA_VISIBLE_DEVICES is changed.""" CUDA_VISIBLE_DEVICES is changed."""
actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore
num_gpus=2).remote() num_gpus=2
assert len( ).remote()
sorted(ray.get( assert len(sorted(ray.get(actor.get_cuda_visible_devices.remote()).split(","))) == 2
actor.get_cuda_visible_devices.remote()).split(","))) == 2
assert ray.get(actor.get_count.remote()) == 2 assert ray.get(actor.get_count.remote()) == 2
ray.get(actor.set_cuda_visible_devices.remote("0")) ray.get(actor.set_cuda_visible_devices.remote("0"))
assert ray.get(actor.get_count.remote()) == 1 assert ray.get(actor.get_count.remote()) == 1
@ -46,15 +46,13 @@ def test_cuda_device_count_stateless():
def cpu_worker(rank, WORLD_SIZE, port1, port2): def cpu_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(host="127.0.0.1", pg1 = StatelessProcessGroup.create(
port=port1, host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
rank=rank, )
world_size=WORLD_SIZE)
if rank <= 2: if rank <= 2:
pg2 = StatelessProcessGroup.create(host="127.0.0.1", pg2 = StatelessProcessGroup.create(
port=port2, host="127.0.0.1", port=port2, rank=rank, world_size=3
rank=rank, )
world_size=3)
data = torch.tensor([rank]) data = torch.tensor([rank])
data = pg1.broadcast_obj(data, src=2) data = pg1.broadcast_obj(data, src=2)
assert data.item() == 2 assert data.item() == 2
@ -68,16 +66,14 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):
def gpu_worker(rank, WORLD_SIZE, port1, port2): def gpu_worker(rank, WORLD_SIZE, port1, port2):
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
pg1 = StatelessProcessGroup.create(host="127.0.0.1", pg1 = StatelessProcessGroup.create(
port=port1, host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
rank=rank, )
world_size=WORLD_SIZE)
pynccl1 = PyNcclCommunicator(pg1, device=rank) pynccl1 = PyNcclCommunicator(pg1, device=rank)
if rank <= 2: if rank <= 2:
pg2 = StatelessProcessGroup.create(host="127.0.0.1", pg2 = StatelessProcessGroup.create(
port=port2, host="127.0.0.1", port=port2, rank=rank, world_size=3
rank=rank, )
world_size=3)
pynccl2 = PyNcclCommunicator(pg2, device=rank) pynccl2 = PyNcclCommunicator(pg2, device=rank)
data = torch.tensor([rank]).cuda() data = torch.tensor([rank]).cuda()
pynccl1.all_reduce(data) pynccl1.all_reduce(data)
@ -96,10 +92,9 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
def broadcast_worker(rank, WORLD_SIZE, port1, port2): def broadcast_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(host="127.0.0.1", pg1 = StatelessProcessGroup.create(
port=port1, host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
rank=rank, )
world_size=WORLD_SIZE)
if rank == 2: if rank == 2:
pg1.broadcast_obj("secret", src=2) pg1.broadcast_obj("secret", src=2)
else: else:
@ -109,10 +104,9 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):
def allgather_worker(rank, WORLD_SIZE, port1, port2): def allgather_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(host="127.0.0.1", pg1 = StatelessProcessGroup.create(
port=port1, host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
rank=rank, )
world_size=WORLD_SIZE)
data = pg1.all_gather_obj(rank) data = pg1.all_gather_obj(rank)
assert data == list(range(WORLD_SIZE)) assert data == list(range(WORLD_SIZE))
pg1.barrier() pg1.barrier()
@ -121,7 +115,8 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2):
@pytest.mark.skip(reason="This test is flaky and prone to hang.") @pytest.mark.skip(reason="This test is flaky and prone to hang.")
@multi_gpu_test(num_gpus=4) @multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker]) "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker]
)
def test_stateless_process_group(worker): def test_stateless_process_group(worker):
port1 = get_open_port() port1 = get_open_port()
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@ -129,12 +124,14 @@ def test_stateless_process_group(worker):
port2 = get_open_port() port2 = get_open_port()
WORLD_SIZE = 4 WORLD_SIZE = 4
from multiprocessing import get_context from multiprocessing import get_context
ctx = get_context("fork") ctx = get_context("fork")
processes = [] processes = []
for i in range(WORLD_SIZE): for i in range(WORLD_SIZE):
rank = i rank = i
processes.append( processes.append(
ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2))) ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2))
)
for p in processes: for p in processes:
p.start() p.start()
for p in processes: for p in processes:
View File
@ -10,22 +10,30 @@ from typing import Annotated, Literal, Optional, Union
import pytest import pytest
from vllm.config import CompilationConfig, config from vllm.config import CompilationConfig, config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, from vllm.engine.arg_utils import (
get_type, get_type_hints, is_not_builtin, EngineArgs,
is_type, literal_to_kwargs, optional_type, contains_type,
parse_type) get_kwargs,
get_type,
get_type_hints,
is_not_builtin,
is_type,
literal_to_kwargs,
optional_type,
parse_type,
)
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@pytest.mark.parametrize(("type", "value", "expected"), [ @pytest.mark.parametrize(
(int, "42", 42), ("type", "value", "expected"),
(float, "3.14", 3.14), [
(str, "Hello World!", "Hello World!"), (int, "42", 42),
(json.loads, '{"foo":1,"bar":2}', { (float, "3.14", 3.14),
"foo": 1, (str, "Hello World!", "Hello World!"),
"bar": 2 (json.loads, '{"foo":1,"bar":2}', {"foo": 1, "bar": 2}),
}), ],
]) )
def test_parse_type(type, value, expected): def test_parse_type(type, value, expected):
parse_type_func = parse_type(type) parse_type_func = parse_type(type)
assert parse_type_func(value) == expected assert parse_type_func(value) == expected
@ -37,50 +45,56 @@ def test_optional_type():
assert optional_type_func("42") == 42 assert optional_type_func("42") == 42
@pytest.mark.parametrize(("type_hint", "type", "expected"), [ @pytest.mark.parametrize(
(int, int, True), ("type_hint", "type", "expected"),
(int, float, False), [
(list[int], list, True), (int, int, True),
(list[int], tuple, False), (int, float, False),
(Literal[0, 1], Literal, True), (list[int], list, True),
]) (list[int], tuple, False),
(Literal[0, 1], Literal, True),
],
)
def test_is_type(type_hint, type, expected): def test_is_type(type_hint, type, expected):
assert is_type(type_hint, type) == expected assert is_type(type_hint, type) == expected
@pytest.mark.parametrize(("type_hints", "type", "expected"), [ @pytest.mark.parametrize(
({float, int}, int, True), ("type_hints", "type", "expected"),
({int, tuple}, int, True), [
({int, tuple[int]}, int, True), ({float, int}, int, True),
({int, tuple[int, ...]}, int, True), ({int, tuple}, int, True),
({int, tuple[int]}, float, False), ({int, tuple[int]}, int, True),
({int, tuple[int, ...]}, float, False), ({int, tuple[int, ...]}, int, True),
({str, Literal["x", "y"]}, Literal, True), ({int, tuple[int]}, float, False),
]) ({int, tuple[int, ...]}, float, False),
({str, Literal["x", "y"]}, Literal, True),
],
)
def test_contains_type(type_hints, type, expected): def test_contains_type(type_hints, type, expected):
assert contains_type(type_hints, type) == expected assert contains_type(type_hints, type) == expected
@pytest.mark.parametrize(("type_hints", "type", "expected"), [ @pytest.mark.parametrize(
({int, float}, int, int), ("type_hints", "type", "expected"),
({int, float}, str, None), [
({str, Literal["x", "y"]}, Literal, Literal["x", "y"]), ({int, float}, int, int),
]) ({int, float}, str, None),
({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
],
)
def test_get_type(type_hints, type, expected): def test_get_type(type_hints, type, expected):
assert get_type(type_hints, type) == expected assert get_type(type_hints, type) == expected
@pytest.mark.parametrize(("type_hints", "expected"), [ @pytest.mark.parametrize(
({Literal[1, 2]}, { ("type_hints", "expected"),
"type": int, [
"choices": [1, 2] ({Literal[1, 2]}, {"type": int, "choices": [1, 2]}),
}), ({str, Literal["x", "y"]}, {"type": str, "metavar": ["x", "y"]}),
({str, Literal["x", "y"]}, { ({Literal[1, "a"]}, Exception),
"type": str, ],
"metavar": ["x", "y"] )
}),
({Literal[1, "a"]}, Exception),
])
def test_literal_to_kwargs(type_hints, expected): def test_literal_to_kwargs(type_hints, expected):
context = nullcontext() context = nullcontext()
if expected is Exception: if expected is Exception:
@ -123,22 +137,27 @@ class DummyConfig:
"""Nested config""" """Nested config"""
@pytest.mark.parametrize(("type_hint", "expected"), [ @pytest.mark.parametrize(
(int, False), ("type_hint", "expected"),
(DummyConfig, True), [
]) (int, False),
(DummyConfig, True),
],
)
def test_is_not_builtin(type_hint, expected): def test_is_not_builtin(type_hint, expected):
assert is_not_builtin(type_hint) == expected assert is_not_builtin(type_hint) == expected
@pytest.mark.parametrize( @pytest.mark.parametrize(
("type_hint", "expected"), [ ("type_hint", "expected"),
[
(Annotated[int, "annotation"], {int}), (Annotated[int, "annotation"], {int}),
(Optional[int], {int, type(None)}), (Optional[int], {int, type(None)}),
(Annotated[Optional[int], "annotation"], {int, type(None)}), (Annotated[Optional[int], "annotation"], {int, type(None)}),
(Optional[Annotated[int, "annotation"]], {int, type(None)}), (Optional[Annotated[int, "annotation"]], {int, type(None)}),
], ],
ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"]) ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"],
)
def test_get_type_hints(type_hint, expected): def test_get_type_hints(type_hint, expected):
assert get_type_hints(type_hint) == expected assert get_type_hints(type_hint) == expected
@ -178,24 +197,16 @@ def test_get_kwargs():
("arg", "expected"), ("arg", "expected"),
[ [
(None, dict()), (None, dict()),
('{"video": {"num_frames": 123} }', { ('{"video": {"num_frames": 123} }', {"video": {"num_frames": 123}}),
"video": {
"num_frames": 123
}
}),
( (
'{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, "image": {"foo": "bar"} }', # noqa '{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, "image": {"foo": "bar"} }', # noqa
{ {
"video": { "video": {"num_frames": 123, "fps": 1.0, "foo": "bar"},
"num_frames": 123, "image": {"foo": "bar"},
"fps": 1.0, },
"foo": "bar" ),
}, ],
"image": { )
"foo": "bar"
}
}),
])
def test_media_io_kwargs_parser(arg, expected): def test_media_io_kwargs_parser(arg, expected):
parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
if arg is None: if arg is None:
@ -230,24 +241,32 @@ def test_compilation_config():
assert args.compilation_config.level == 3 assert args.compilation_config.level == 3
# set to string form of a dict # set to string form of a dict
args = parser.parse_args([ args = parser.parse_args(
"-O", [
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' "-O",
'"use_inductor": false}', '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
]) '"use_inductor": false}',
assert (args.compilation_config.level == 3 and ]
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] )
and not args.compilation_config.use_inductor) assert (
args.compilation_config.level == 3
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
and not args.compilation_config.use_inductor
)
# set to string form of a dict # set to string form of a dict
args = parser.parse_args([ args = parser.parse_args(
"--compilation-config=" [
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' "--compilation-config="
'"use_inductor": true}', '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
]) '"use_inductor": true}',
assert (args.compilation_config.level == 3 and ]
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] )
and args.compilation_config.use_inductor) assert (
args.compilation_config.level == 3
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
and args.compilation_config.use_inductor
)
def test_prefix_cache_default(): def test_prefix_cache_default():
@ -255,8 +274,7 @@ def test_prefix_cache_default():
args = parser.parse_args([]) args = parser.parse_args([])
engine_args = EngineArgs.from_cli_args(args=args) engine_args = EngineArgs.from_cli_args(args=args)
assert (not engine_args.enable_prefix_caching assert not engine_args.enable_prefix_caching, "prefix caching defaults to off."
), "prefix caching defaults to off."
# with flag to turn it on. # with flag to turn it on.
args = parser.parse_args(["--enable-prefix-caching"]) args = parser.parse_args(["--enable-prefix-caching"])
View File
@ -5,12 +5,12 @@ import pytest
from ..conftest import IMAGE_ASSETS from ..conftest import IMAGE_ASSETS
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
"stop_sign": {
"USER: <image>\nWhat's the content of the image?\nASSISTANT:", "stop_sign": "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
"cherry_blossom": "cherry_blossom": "USER: <image>\nWhat is the season?\nASSISTANT:",
"USER: <image>\nWhat is the season?\nASSISTANT:", }
}) )
models = ["llava-hf/llava-1.5-7b-hf"] models = ["llava-hf/llava-1.5-7b-hf"]
@ -19,8 +19,7 @@ models = ["llava-hf/llava-1.5-7b-hf"]
def test_context_length_too_short(vllm_runner, image_assets, model): def test_context_length_too_short(vllm_runner, image_assets, model):
images = [asset.pil_image for asset in image_assets] images = [asset.pil_image for asset in image_assets]
with pytest.raises(ValueError, with pytest.raises(ValueError, match="longer than the maximum model length"):
match="longer than the maximum model length"):
vllm_model = vllm_runner( vllm_model = vllm_runner(
model, model,
max_model_len=128, # LLaVA has a feature size of 576 max_model_len=128, # LLaVA has a feature size of 576
@ -29,6 +28,6 @@ def test_context_length_too_short(vllm_runner, image_assets, model):
) )
with vllm_model: with vllm_model:
vllm_model.generate_greedy([HF_IMAGE_PROMPTS[0]], vllm_model.generate_greedy(
max_tokens=1, [HF_IMAGE_PROMPTS[0]], max_tokens=1, images=[images[0]]
images=[images[0]]) )
View File
@ -26,8 +26,10 @@ def sample_token_ids():
@pytest.fixture @pytest.fixture
def sample_regex(): def sample_regex():
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" return (
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
)
@pytest.fixture @pytest.fixture
@ -35,40 +37,27 @@ def sample_json_schema():
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "name": {"type": "string"},
"type": "string" "age": {"type": "integer"},
},
"age": {
"type": "integer"
},
"skills": { "skills": {
"type": "array", "type": "array",
"items": { "items": {"type": "string", "maxLength": 10},
"type": "string", "minItems": 3,
"maxLength": 10
},
"minItems": 3
}, },
"work_history": { "work_history": {
"type": "array", "type": "array",
"items": { "items": {
"type": "object", "type": "object",
"properties": { "properties": {
"company": { "company": {"type": "string"},
"type": "string" "duration": {"type": "number"},
}, "position": {"type": "string"},
"duration": {
"type": "number"
},
"position": {
"type": "string"
}
}, },
"required": ["company", "position"] "required": ["company", "position"],
} },
} },
}, },
"required": ["name", "age", "skills", "work_history"] "required": ["name", "age", "skills", "work_history"],
} }
@ -80,65 +69,53 @@ def sample_complex_json_schema():
"score": { "score": {
"type": "integer", "type": "integer",
"minimum": 0, "minimum": 0,
"maximum": 100 # Numeric range "maximum": 100, # Numeric range
}, },
"grade": { "grade": {
"type": "string", "type": "string",
"pattern": "^[A-D]$" # Regex pattern "pattern": "^[A-D]$", # Regex pattern
}, },
"email": { "email": {
"type": "string", "type": "string",
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$",
}, },
"tags": { "tags": {
"type": "array", "type": "array",
"items": { "items": {
"type": "string", "type": "string",
"pattern": "pattern": "^[a-z]{1,10}$", # Combining length and pattern restrictions
"^[a-z]{1,10}$" # Combining length and pattern restrictions },
} },
}
}, },
"required": ["score", "grade", "email", "tags"] "required": ["score", "grade", "email", "tags"],
} }
@pytest.fixture @pytest.fixture
def sample_definition_json_schema(): def sample_definition_json_schema():
return { return {
'$defs': { "$defs": {
'Step': { "Step": {
'properties': { "properties": {
'explanation': { "explanation": {"title": "Explanation", "type": "string"},
'title': 'Explanation', "output": {"title": "Output", "type": "string"},
'type': 'string'
},
'output': {
'title': 'Output',
'type': 'string'
}
}, },
'required': ['explanation', 'output'], "required": ["explanation", "output"],
'title': 'Step', "title": "Step",
'type': 'object' "type": "object",
} }
}, },
'properties': { "properties": {
'steps': { "steps": {
'items': { "items": {"$ref": "#/$defs/Step"},
'$ref': '#/$defs/Step' "title": "Steps",
}, "type": "array",
'title': 'Steps',
'type': 'array'
}, },
'final_answer': { "final_answer": {"title": "Final Answer", "type": "string"},
'title': 'Final Answer',
'type': 'string'
}
}, },
'required': ['steps', 'final_answer'], "required": ["steps", "final_answer"],
'title': 'MathReasoning', "title": "MathReasoning",
'type': 'object' "type": "object",
} }
@ -149,64 +126,71 @@ def sample_enum_json_schema():
"properties": { "properties": {
"status": { "status": {
"type": "string", "type": "string",
"enum": ["active", "inactive", "enum": ["active", "inactive", "pending"], # Literal values using enum
"pending"] # Literal values using enum
}, },
"priority": { "priority": {
"type": "string", "type": "string",
"enum": ["low", "medium", "high", "critical"] "enum": ["low", "medium", "high", "critical"],
}, },
"category": { "category": {
"type": "object", "type": "object",
"properties": { "properties": {
"type": { "type": {
"type": "string", "type": "string",
"enum": ["bug", "feature", "improvement"] "enum": ["bug", "feature", "improvement"],
}, },
"severity": { "severity": {
"type": "integer", "type": "integer",
"enum": [1, 2, 3, 4, "enum": [1, 2, 3, 4, 5], # Enum can also contain numbers
5] # Enum can also contain numbers },
}
}, },
"required": ["type", "severity"] "required": ["type", "severity"],
}, },
"flags": { "flags": {
"type": "array", "type": "array",
"items": { "items": {
"type": "string", "type": "string",
"enum": ["urgent", "blocked", "needs_review", "approved"] "enum": ["urgent", "blocked", "needs_review", "approved"],
} },
} },
}, },
"required": ["status", "priority", "category", "flags"] "required": ["status", "priority", "category", "flags"],
} }
@pytest.fixture @pytest.fixture
def sample_structured_outputs_choices(): def sample_structured_outputs_choices():
return [ return [
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Python",
"Ruby", "Swift", "Kotlin" "Java",
"JavaScript",
"C++",
"C#",
"PHP",
"TypeScript",
"Ruby",
"Swift",
"Kotlin",
] ]
@pytest.fixture @pytest.fixture
def sample_sql_statements(): def sample_sql_statements():
return (""" return """
start: select_statement start: select_statement
select_statement: "SELECT" column "from" table "where" condition select_statement: "SELECT" column "from" table "where" condition
column: "col_1" | "col_2" column: "col_1" | "col_2"
table: "table_1" | "table_2" table: "table_1" | "table_2"
condition: column "=" number condition: column "=" number
number: "1" | "2" number: "1" | "2"
""") """
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def zephyr_lora_files(): def zephyr_lora_files():
"""Download zephyr LoRA files once per test session.""" """Download zephyr LoRA files once per test session."""
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora") return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")
@ -214,5 +198,5 @@ def zephyr_lora_files():
def opt125_lora_files() -> str: def opt125_lora_files() -> str:
"""Download opt-125m LoRA files once per test session.""" """Download opt-125m LoRA files once per test session."""
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
return snapshot_download(
repo_id="peft-internal-testing/opt-125m-dummy-lora") return snapshot_download(repo_id="peft-internal-testing/opt-125m-dummy-lora")
View File
@ -48,20 +48,23 @@ def run_test(model_name, more_args=None):
measured_value = results["results"][TASK][FILTER] measured_value = results["results"][TASK][FILTER]
assert model_name in EXPECTED_VALUES, ( assert model_name in EXPECTED_VALUES, (
f"Cannot find the expected value for the model {model_name=}") f"Cannot find the expected value for the model {model_name=}"
)
expected_value = EXPECTED_VALUES[model_name] expected_value = EXPECTED_VALUES[model_name]
assert (measured_value - RTOL < expected_value assert (
and measured_value + RTOL > expected_value measured_value - RTOL < expected_value
), f"Expected: {expected_value} | Measured: {measured_value}" and measured_value + RTOL > expected_value
), f"Expected: {expected_value} | Measured: {measured_value}"
# TODO: [AlexM] Fix it with new CI/CD tests # TODO: [AlexM] Fix it with new CI/CD tests
TPU_TP_TEST_STR = "" #"tensor_parallel_size=4" TPU_TP_TEST_STR = "" # "tensor_parallel_size=4"
@pytest.mark.skipif(not current_platform.is_cuda() @pytest.mark.skipif(
and not current_platform.is_tpu(), not current_platform.is_cuda() and not current_platform.is_tpu(),
reason="V1 is currently only supported on CUDA and TPU") reason="V1 is currently only supported on CUDA and TPU",
)
@pytest.mark.parametrize("model", MODEL_NAMES) @pytest.mark.parametrize("model", MODEL_NAMES)
def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch): def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
"""Run with the V1 Engine.""" """Run with the V1 Engine."""
@ -82,12 +85,14 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
run_test(model, more_args) run_test(model, more_args)
@pytest.mark.skipif(not current_platform.is_cuda() @pytest.mark.skipif(
and not current_platform.is_tpu(), not current_platform.is_cuda() and not current_platform.is_tpu(),
reason="V1 is currently only supported on CUDA and TPU") reason="V1 is currently only supported on CUDA and TPU",
)
@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES) @pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES)
def test_lm_eval_accuracy_v1_engine_fp8_kv_cache( def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
model, monkeypatch: pytest.MonkeyPatch): model, monkeypatch: pytest.MonkeyPatch
):
"""Run with the V1 Engine.""" """Run with the V1 Engine."""
with monkeypatch.context() as m: with monkeypatch.context() as m:
View File
@ -14,9 +14,7 @@ from ..openai.test_vision import TEST_IMAGE_ASSETS
def text_llm(): def text_llm():
# pytest caches the fixture so we use weakref.proxy to # pytest caches the fixture so we use weakref.proxy to
# enable garbage collection # enable garbage collection
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, seed=0)
enforce_eager=True,
seed=0)
yield weakref.proxy(llm) yield weakref.proxy(llm)
@ -28,14 +26,8 @@ def text_llm():
def test_chat(text_llm): def test_chat(text_llm):
prompt1 = "Explain the concept of entropy." prompt1 = "Explain the concept of entropy."
messages = [ messages = [
{ {"role": "system", "content": "You are a helpful assistant"},
"role": "system", {"role": "user", "content": prompt1},
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
] ]
outputs = text_llm.chat(messages) outputs = text_llm.chat(messages)
assert len(outputs) == 1 assert len(outputs) == 1
@ -46,25 +38,13 @@ def test_multi_chat(text_llm):
prompt2 = "Explain what among us is." prompt2 = "Explain what among us is."
conversation1 = [ conversation1 = [
{ {"role": "system", "content": "You are a helpful assistant"},
"role": "system", {"role": "user", "content": prompt1},
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
] ]
conversation2 = [ conversation2 = [
{ {"role": "system", "content": "You are a helpful assistant"},
"role": "system", {"role": "user", "content": prompt2},
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt2
},
] ]
messages = [conversation1, conversation2] messages = [conversation1, conversation2]
@ -94,26 +74,22 @@ def vision_llm():
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@pytest.mark.parametrize("image_urls", @pytest.mark.parametrize(
[[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]], "image_urls", [[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]], indirect=True
indirect=True) )
def test_chat_multi_image(vision_llm, image_urls: list[str]): def test_chat_multi_image(vision_llm, image_urls: list[str]):
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [ "content": [
*({ *(
"type": "image_url", {"type": "image_url", "image_url": {"url": image_url}}
"image_url": { for image_url in image_urls
"url": image_url ),
} {"type": "text", "text": "What's in this image?"},
} for image_url in image_urls), ],
{ }
"type": "text", ]
"text": "What's in this image?"
},
],
}]
outputs = vision_llm.chat(messages) outputs = vision_llm.chat(messages)
assert len(outputs) >= 0 assert len(outputs) >= 0
@ -124,14 +100,8 @@ def test_llm_chat_tokenization_no_double_bos(text_llm):
Check we get a single BOS token for llama chat. Check we get a single BOS token for llama chat.
""" """
messages = [ messages = [
{ {"role": "system", "content": "You are a helpful assistant"},
"role": "system", {"role": "user", "content": "Hello!"},
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "Hello!"
},
] ]
outputs = text_llm.chat(messages) outputs = text_llm.chat(messages)
assert len(outputs) == 1 assert len(outputs) == 1
@ -167,14 +137,8 @@ def thinking_llm():
@pytest.mark.parametrize("enable_thinking", [True, False]) @pytest.mark.parametrize("enable_thinking", [True, False])
def test_chat_extra_kwargs(thinking_llm, enable_thinking): def test_chat_extra_kwargs(thinking_llm, enable_thinking):
messages = [ messages = [
{ {"role": "system", "content": "You are a helpful assistant"},
"role": "system", {"role": "user", "content": "What is 1+1?"},
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "What is 1+1?"
},
] ]
outputs = thinking_llm.chat( outputs = thinking_llm.chat(
View File
@ -23,9 +23,11 @@ def test_collective_rpc(tp_size, backend, monkeypatch):
return self.rank return self.rank
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", llm = LLM(
enforce_eager=True, model="meta-llama/Llama-3.2-1B-Instruct",
load_format="dummy", enforce_eager=True,
tensor_parallel_size=tp_size, load_format="dummy",
distributed_executor_backend=backend) tensor_parallel_size=tp_size,
distributed_executor_backend=backend,
)
assert llm.collective_rpc(echo_rank) == list(range(tp_size)) assert llm.collective_rpc(echo_rank) == list(range(tp_size))
View File
@ -29,11 +29,13 @@ TOKEN_IDS = [
def llm(): def llm():
# pytest caches the fixture so we use weakref.proxy to # pytest caches the fixture so we use weakref.proxy to
# enable garbage collection # enable garbage collection
llm = LLM(model=MODEL_NAME, llm = LLM(
max_num_batched_tokens=4096, model=MODEL_NAME,
tensor_parallel_size=1, max_num_batched_tokens=4096,
gpu_memory_utilization=0.10, tensor_parallel_size=1,
enforce_eager=True) gpu_memory_utilization=0.10,
enforce_eager=True,
)
yield weakref.proxy(llm) yield weakref.proxy(llm)
@ -81,7 +83,8 @@ def test_max_model_len():
outputs = llm.generate(PROMPTS, sampling_params) outputs = llm.generate(PROMPTS, sampling_params)
for output in outputs: for output in outputs:
num_total_tokens = len(output.prompt_token_ids) + len( num_total_tokens = len(output.prompt_token_ids) + len(
output.outputs[0].token_ids) output.outputs[0].token_ids
)
# Total tokens must not exceed max_model_len + 1 (the last token can be # Total tokens must not exceed max_model_len + 1 (the last token can be
# generated with the context length equal to the max model length) # generated with the context length equal to the max model length)
# It can be less if generation finishes due to other reasons (e.g., EOS) # It can be less if generation finishes due to other reasons (e.g., EOS)
View File
@ -16,9 +16,8 @@ def test_gpu_memory_utilization():
# makes sure gpu_memory_utilization is per-instance limit, # makes sure gpu_memory_utilization is per-instance limit,
# not a global limit # not a global limit
llms = [ llms = [
LLM(model="facebook/opt-125m", LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3, enforce_eager=True)
gpu_memory_utilization=0.3, for i in range(3)
enforce_eager=True) for i in range(3)
] ]
for llm in llms: for llm in llms:
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
View File
@ -8,12 +8,12 @@ from vllm import LLM
def test_empty_prompt(): def test_empty_prompt():
llm = LLM(model="openai-community/gpt2", enforce_eager=True) llm = LLM(model="openai-community/gpt2", enforce_eager=True)
with pytest.raises(ValueError, match='decoder prompt cannot be empty'): with pytest.raises(ValueError, match="decoder prompt cannot be empty"):
llm.generate([""]) llm.generate([""])
@pytest.mark.skip_v1 @pytest.mark.skip_v1
def test_out_of_vocab_token(): def test_out_of_vocab_token():
llm = LLM(model="openai-community/gpt2", enforce_eager=True) llm = LLM(model="openai-community/gpt2", enforce_eager=True)
with pytest.raises(ValueError, match='out of vocabulary'): with pytest.raises(ValueError, match="out of vocabulary"):
llm.generate({"prompt_token_ids": [999999]}) llm.generate({"prompt_token_ids": [999999]})
View File
@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for HF_HUB_OFFLINE mode""" """Tests for HF_HUB_OFFLINE mode"""
import dataclasses import dataclasses
import importlib import importlib
import sys import sys
@ -91,12 +92,11 @@ def test_offline_mode(monkeypatch: pytest.MonkeyPatch):
def _re_import_modules(): def _re_import_modules():
hf_hub_module_names = [ hf_hub_module_names = [k for k in sys.modules if k.startswith("huggingface_hub")]
k for k in sys.modules if k.startswith("huggingface_hub")
]
transformers_module_names = [ transformers_module_names = [
k for k in sys.modules if k.startswith("transformers") k
and not k.startswith("transformers_modules") for k in sys.modules
if k.startswith("transformers") and not k.startswith("transformers_modules")
] ]
reload_exception = None reload_exception = None
View File
@ -7,14 +7,14 @@ from vllm.assets.audio import AudioAsset
@pytest.fixture @pytest.fixture
def mary_had_lamb(): def mary_had_lamb():
path = AudioAsset('mary_had_lamb').get_local_path() path = AudioAsset("mary_had_lamb").get_local_path()
with open(str(path), "rb") as f: with open(str(path), "rb") as f:
yield f yield f
@pytest.fixture @pytest.fixture
def winning_call(): def winning_call():
path = AudioAsset('winning_call').get_local_path() path = AudioAsset("winning_call").get_local_path()
with open(str(path), "rb") as f: with open(str(path), "rb") as f:
yield f yield f
@ -22,6 +22,6 @@ def winning_call():
@pytest.fixture @pytest.fixture
def foscolo(): def foscolo():
# Test translation it->en # Test translation it->en
path = AudioAsset('azacinto_foscolo').get_local_path() path = AudioAsset("azacinto_foscolo").get_local_path()
with open(str(path), "rb") as f: with open(str(path), "rb") as f:
yield f yield f
View File
@ -44,14 +44,15 @@ def run_test(more_args):
print(f"Running with: {args}") print(f"Running with: {args}")
with RemoteOpenAIServer( with RemoteOpenAIServer(
MODEL_NAME, args, MODEL_NAME, args, max_wait_seconds=MAX_WAIT_SECONDS
max_wait_seconds=MAX_WAIT_SECONDS) as remote_server: ) as remote_server:
url = f"{remote_server.url_for('v1')}/completions" url = f"{remote_server.url_for('v1')}/completions"
model_args = ( model_args = (
f"model={MODEL_NAME}," f"model={MODEL_NAME},"
f"base_url={url}," f"base_url={url},"
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False"
)
results = lm_eval.simple_evaluate( results = lm_eval.simple_evaluate(
model="local-completions", model="local-completions",
@ -60,15 +61,18 @@ def run_test(more_args):
) )
measured_value = results["results"][TASK][FILTER] measured_value = results["results"][TASK][FILTER]
assert (measured_value - RTOL < EXPECTED_VALUE assert (
and measured_value + RTOL > EXPECTED_VALUE measured_value - RTOL < EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" and measured_value + RTOL > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
@pytest.mark.skipif(not current_platform.is_cuda() @pytest.mark.skipif(
and not current_platform.is_tpu() not current_platform.is_cuda()
and not current_platform.is_xpu(), and not current_platform.is_tpu()
reason="V1 currently only supported on CUDA, XPU and TPU") and not current_platform.is_xpu(),
reason="V1 currently only supported on CUDA, XPU and TPU",
)
def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch): def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
"""Run with the V1 Engine.""" """Run with the V1 Engine."""
View File
@ -7,6 +7,7 @@ a baseline.
This simulates real work usage of the API and makes sure that the frontend and This simulates real work usage of the API and makes sure that the frontend and
AsyncLLMEngine are working correctly. AsyncLLMEngine are working correctly.
""" """
import asyncio import asyncio
import io import io
import time import time
@ -45,7 +46,8 @@ async def transcribe_audio(client, tokenizer, y, sr):
# NOTE there's no streaming in transcriptions, can't measure ttft # NOTE there's no streaming in transcriptions, can't measure ttft
latency = end_time - start_time latency = end_time - start_time
num_output_tokens = len( num_output_tokens = len(
tokenizer(transcription.text, add_special_tokens=False).input_ids) tokenizer(transcription.text, add_special_tokens=False).input_ids
)
return latency, num_output_tokens, transcription.text return latency, num_output_tokens, transcription.text
@ -73,8 +75,8 @@ async def process_dataset(model, client, data, concurrent_request):
for sample in data: for sample in data:
audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
task = asyncio.create_task( task = asyncio.create_task(
bound_transcribe(sem, client, tokenizer, (audio, sr), bound_transcribe(sem, client, tokenizer, (audio, sr), sample["text"])
sample["text"])) )
tasks.append(task) tasks.append(task)
return await asyncio.gather(*tasks) return await asyncio.gather(*tasks)
@ -98,34 +100,35 @@ def print_performance_metrics(results, total_time):
def add_duration(sample): def add_duration(sample):
y, sr = sample['audio']["array"], sample['audio']["sampling_rate"] y, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000 sample["duration_ms"] = librosa.get_duration(y=y, sr=sr) * 1000
return sample return sample
def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs): def load_hf_dataset(dataset_repo: str, split="validation", **hf_kwargs):
## Load and filter the dataset ## Load and filter the dataset
dataset = load_dataset(dataset_repo, split=split, **hf_kwargs) dataset = load_dataset(dataset_repo, split=split, **hf_kwargs)
if 'duration_ms' not in dataset[0]: if "duration_ms" not in dataset[0]:
# compute duration to filter # compute duration to filter
dataset = dataset.map(add_duration) dataset = dataset.map(add_duration)
# Whisper max supported duration # Whisper max supported duration
dataset = dataset.filter(lambda example: example['duration_ms'] < 30000) dataset = dataset.filter(lambda example: example["duration_ms"] < 30000)
return dataset return dataset
def run_evaluation(model: str, def run_evaluation(
client, model: str,
dataset, client,
max_concurrent_reqs: int, dataset,
n_examples: int = -1, max_concurrent_reqs: int,
print_metrics: bool = True): n_examples: int = -1,
print_metrics: bool = True,
):
if n_examples > 0: if n_examples > 0:
dataset = dataset.select(range(n_examples)) dataset = dataset.select(range(n_examples))
start = time.perf_counter() start = time.perf_counter()
results = asyncio.run( results = asyncio.run(process_dataset(model, client, dataset, max_concurrent_reqs))
process_dataset(model, client, dataset, max_concurrent_reqs))
end = time.perf_counter() end = time.perf_counter()
total_time = end - start total_time = end - start
print(f"Total Test Time: {total_time:.4f} seconds") print(f"Total Test Time: {total_time:.4f} seconds")
@ -135,8 +138,7 @@ def run_evaluation(model: str,
predictions = [res[2] for res in results] predictions = [res[2] for res in results]
references = [res[3] for res in results] references = [res[3] for res in results]
wer = load("wer") wer = load("wer")
wer_score = 100 * wer.compute(references=references, wer_score = 100 * wer.compute(references=references, predictions=predictions)
predictions=predictions)
print("WER:", wer_score) print("WER:", wer_score)
return wer_score return wer_score
@ -145,26 +147,25 @@ def run_evaluation(model: str,
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"]) @pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"])
# Original dataset is 20GB+ in size, hence we use a pre-filtered slice. # Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]) "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]
)
# NOTE: Expected WER measured with equivalent hf.transformers args: # NOTE: Expected WER measured with equivalent hf.transformers args:
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered. # whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
@pytest.mark.parametrize("expected_wer", [12.744980]) @pytest.mark.parametrize("expected_wer", [12.744980])
def test_wer_correctness(model_name, def test_wer_correctness(
dataset_repo, model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None
expected_wer, ):
n_examples=-1,
max_concurrent_request=None):
# TODO refactor to use `ASRDataset` # TODO refactor to use `ASRDataset`
with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server: with RemoteOpenAIServer(model_name, ["--enforce-eager"]) as remote_server:
dataset = load_hf_dataset(dataset_repo) dataset = load_hf_dataset(dataset_repo)
if not max_concurrent_request: if not max_concurrent_request:
# No max concurrency # No max concurrency
max_concurrent_request = n_examples if n_examples > 0\ max_concurrent_request = n_examples if n_examples > 0 else len(dataset)
else len(dataset)
client = remote_server.get_async_client() client = remote_server.get_async_client()
wer = run_evaluation(model_name, client, dataset, wer = run_evaluation(
max_concurrent_request, n_examples) model_name, client, dataset, max_concurrent_request, n_examples
)
if expected_wer: if expected_wer:
torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2) torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)
View File
@ -44,15 +44,11 @@ async def client(server):
ids=["completion", "chat"], ids=["completion", "chat"],
argnames=["create_func_gen", "content_body"], argnames=["create_func_gen", "content_body"],
argvalues=[ argvalues=[
(lambda x: x.completions.create, { (lambda x: x.completions.create, {"prompt": " ".join(["A"] * 10_000)}),
"prompt": " ".join(['A'] * 10_000) (
}), lambda x: x.chat.completions.create,
(lambda x: x.chat.completions.create, { {"messages": [{"role": "user", "content": " ".join(["A"] * 10_000)}]},
"messages": [{ ),
"role": "user",
"content": " ".join(['A'] * 10_000)
}]
}),
], ],
) )
async def test_with_and_without_truncate( async def test_with_and_without_truncate(
@ -65,15 +61,15 @@ async def test_with_and_without_truncate(
body = {"model": MODEL_NAME, **content_body, "max_tokens": 10} body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}
num_requests = 10 num_requests = 10
truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] * truncate_prompt_tokens = [1000] * (num_requests // 2) + [None] * (
(num_requests - num_requests // 2)) num_requests - num_requests // 2
)
random.shuffle(truncate_prompt_tokens) random.shuffle(truncate_prompt_tokens)
bodies = [{ bodies = [
**body, "extra_body": { {**body, "extra_body": {"truncate_prompt_tokens": t}}
'truncate_prompt_tokens': t for t in truncate_prompt_tokens
} ]
} for t in truncate_prompt_tokens]
async def get_status_code(**kwargs): async def get_status_code(**kwargs):
try: try:
View File
@ -56,24 +56,18 @@ def base64_encoded_audio() -> dict[str, str]:
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_single_chat_session_audio(client: openai.AsyncOpenAI, async def test_single_chat_session_audio(
model_name: str, audio_url: str): client: openai.AsyncOpenAI, model_name: str, audio_url: str
messages = [{ ):
"role": messages = [
"user", {
"content": [ "role": "user",
{ "content": [
"type": "audio_url", {"type": "audio_url", "audio_url": {"url": audio_url}},
"audio_url": { {"type": "text", "text": "What's happening in this audio?"},
"url": audio_url ],
} }
}, ]
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
@ -82,13 +76,15 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
max_completion_tokens=10, max_completion_tokens=10,
logprobs=True, logprobs=True,
temperature=0.0, temperature=0.0,
top_logprobs=5) top_logprobs=5,
)
assert len(chat_completion.choices) == 1 assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage( assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=202, total_tokens=212) completion_tokens=10, prompt_tokens=202, total_tokens=212
)
message = choice.message message = choice.message
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
@ -110,56 +106,52 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_error_on_invalid_audio_url_type(client: openai.AsyncOpenAI, async def test_error_on_invalid_audio_url_type(
model_name: str, client: openai.AsyncOpenAI, model_name: str, audio_url: str
audio_url: str): ):
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "audio_url", "audio_url": audio_url},
"type": "audio_url", {"type": "text", "text": "What's happening in this audio?"},
"audio_url": audio_url ],
}, }
{ ]
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
# audio_url should be a dict {"url": "some url"}, not directly a string # audio_url should be a dict {"url": "some url"}, not directly a string
with pytest.raises(openai.BadRequestError): with pytest.raises(openai.BadRequestError):
_ = await client.chat.completions.create(model=model_name, _ = await client.chat.completions.create(
messages=messages, model=model_name,
max_completion_tokens=10, messages=messages,
temperature=0.0) max_completion_tokens=10,
temperature=0.0,
)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_single_chat_session_audio_base64encoded( async def test_single_chat_session_audio_base64encoded(
client: openai.AsyncOpenAI, model_name: str, audio_url: str, client: openai.AsyncOpenAI,
base64_encoded_audio: dict[str, str]): model_name: str,
audio_url: str,
messages = [{ base64_encoded_audio: dict[str, str],
"role": ):
"user", messages = [
"content": [ {
{ "role": "user",
"type": "audio_url", "content": [
"audio_url": { {
"url": "type": "audio_url",
f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}" "audio_url": {
} "url": f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
}, },
{ },
"type": "text", {"type": "text", "text": "What's happening in this audio?"},
"text": "What's happening in this audio?" ],
}, }
], ]
}]
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
@ -168,13 +160,15 @@ async def test_single_chat_session_audio_base64encoded(
max_completion_tokens=10, max_completion_tokens=10,
logprobs=True, logprobs=True,
temperature=0.0, temperature=0.0,
top_logprobs=5) top_logprobs=5,
)
assert len(chat_completion.choices) == 1 assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage( assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=202, total_tokens=212) completion_tokens=10, prompt_tokens=202, total_tokens=212
)
message = choice.message message = choice.message
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
@ -198,25 +192,26 @@ async def test_single_chat_session_audio_base64encoded(
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_single_chat_session_input_audio( async def test_single_chat_session_input_audio(
client: openai.AsyncOpenAI, model_name: str, audio_url: str, client: openai.AsyncOpenAI,
base64_encoded_audio: dict[str, str]): model_name: str,
messages = [{ audio_url: str,
"role": base64_encoded_audio: dict[str, str],
"user", ):
"content": [ messages = [
{ {
"type": "input_audio", "role": "user",
"input_audio": { "content": [
"data": base64_encoded_audio[audio_url], {
"format": "wav" "type": "input_audio",
} "input_audio": {
}, "data": base64_encoded_audio[audio_url],
{ "format": "wav",
"type": "text", },
"text": "What's happening in this audio?" },
}, {"type": "text", "text": "What's happening in this audio?"},
], ],
}] }
]
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
@ -224,13 +219,15 @@ async def test_single_chat_session_input_audio(
messages=messages, messages=messages,
max_completion_tokens=10, max_completion_tokens=10,
logprobs=True, logprobs=True,
top_logprobs=5) top_logprobs=5,
)
assert len(chat_completion.choices) == 1 assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage( assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=202, total_tokens=212) completion_tokens=10, prompt_tokens=202, total_tokens=212
)
message = choice.message message = choice.message
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
@ -252,24 +249,18 @@ async def test_single_chat_session_input_audio(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_chat_streaming_audio(client: openai.AsyncOpenAI, async def test_chat_streaming_audio(
model_name: str, audio_url: str): client: openai.AsyncOpenAI, model_name: str, audio_url: str
messages = [{ ):
"role": messages = [
"user", {
"content": [ "role": "user",
{ "content": [
"type": "audio_url", {"type": "audio_url", "audio_url": {"url": audio_url}},
"audio_url": { {"type": "text", "text": "What's happening in this audio?"},
"url": audio_url ],
} }
}, ]
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
@ -309,27 +300,27 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, async def test_chat_streaming_input_audio(
model_name: str, audio_url: str, client: openai.AsyncOpenAI,
base64_encoded_audio: dict[str, model_name: str,
str]): audio_url: str,
messages = [{ base64_encoded_audio: dict[str, str],
"role": ):
"user", messages = [
"content": [ {
{ "role": "user",
"type": "input_audio", "content": [
"input_audio": { {
"data": base64_encoded_audio[audio_url], "type": "input_audio",
"format": "wav" "input_audio": {
} "data": base64_encoded_audio[audio_url],
}, "format": "wav",
{ },
"type": "text", },
"text": "What's happening in this audio?" {"type": "text", "text": "What's happening in this audio?"},
}, ],
], }
}] ]
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
@ -369,26 +360,23 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]]) "audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]]
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str, )
audio_urls: list[str]): async def test_multi_audio_input(
client: openai.AsyncOpenAI, model_name: str, audio_urls: list[str]
messages = [{ ):
"role": messages = [
"user", {
"content": [ "role": "user",
*({ "content": [
"type": "audio_url", *(
"audio_url": { {"type": "audio_url", "audio_url": {"url": audio_url}}
"url": audio_url for audio_url in audio_urls
} ),
} for audio_url in audio_urls), {"type": "text", "text": "What's happening in this audio?"},
{ ],
"type": "text", }
"text": "What's happening in this audio?" ]
},
],
}]
if len(audio_urls) > MAXIMUM_AUDIOS: if len(audio_urls) > MAXIMUM_AUDIOS:
with pytest.raises(openai.BadRequestError): # test multi-audio input with pytest.raises(openai.BadRequestError): # test multi-audio input
View File
@ -16,9 +16,9 @@ from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@pytest.fixture(scope='module') @pytest.fixture(scope="module")
def server_args(request: pytest.FixtureRequest) -> list[str]: def server_args(request: pytest.FixtureRequest) -> list[str]:
""" Provide extra arguments to the server via indirect parametrization """Provide extra arguments to the server via indirect parametrization
Usage: Usage:
@ -80,8 +80,10 @@ async def client(server):
"server_args", "server_args",
[ [
pytest.param([], id="default-frontend-multiprocessing"), pytest.param([], id="default-frontend-multiprocessing"),
pytest.param(["--disable-frontend-multiprocessing"], pytest.param(
id="disable-frontend-multiprocessing") ["--disable-frontend-multiprocessing"],
id="disable-frontend-multiprocessing",
),
], ],
indirect=True, indirect=True,
) )
@ -97,8 +99,10 @@ async def test_show_version(server: RemoteOpenAIServer):
"server_args", "server_args",
[ [
pytest.param([], id="default-frontend-multiprocessing"), pytest.param([], id="default-frontend-multiprocessing"),
pytest.param(["--disable-frontend-multiprocessing"], pytest.param(
id="disable-frontend-multiprocessing") ["--disable-frontend-multiprocessing"],
id="disable-frontend-multiprocessing",
),
], ],
indirect=True, indirect=True,
) )
@ -112,11 +116,13 @@ async def test_check_health(server: RemoteOpenAIServer):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"server_args", "server_args",
[ [
pytest.param(["--max-model-len", "10100"], pytest.param(
id="default-frontend-multiprocessing"), ["--max-model-len", "10100"], id="default-frontend-multiprocessing"
),
pytest.param( pytest.param(
["--disable-frontend-multiprocessing", "--max-model-len", "10100"], ["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
id="disable-frontend-multiprocessing") id="disable-frontend-multiprocessing",
),
], ],
indirect=True, indirect=True,
) )
@ -131,14 +137,16 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
# Request about 2 million tokens # Request about 2 million tokens
for _ in range(200): for _ in range(200):
task = asyncio.create_task( task = asyncio.create_task(
client.chat.completions.create(messages=chat_input, client.chat.completions.create(
model=MODEL_NAME, messages=chat_input,
max_tokens=10000, model=MODEL_NAME,
extra_body={"min_tokens": 10000})) max_tokens=10000,
extra_body={"min_tokens": 10000},
)
)
tasks.append(task) tasks.append(task)
done, pending = await asyncio.wait(tasks, done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
return_when=asyncio.ALL_COMPLETED)
# Make sure all requests were sent to the server and timed out # Make sure all requests were sent to the server and timed out
# (We don't want to hide other errors like 400s that would invalidate this # (We don't want to hide other errors like 400s that would invalidate this
@ -151,16 +159,15 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
# If the server had not cancelled all the other requests, then it would not # If the server had not cancelled all the other requests, then it would not
# be able to respond to this one within the timeout # be able to respond to this one within the timeout
client = server.get_async_client(timeout=5) client = server.get_async_client(timeout=5)
response = await client.chat.completions.create(messages=chat_input, response = await client.chat.completions.create(
model=MODEL_NAME, messages=chat_input, model=MODEL_NAME, max_tokens=10
max_tokens=10) )
assert len(response.choices) == 1 assert len(response.choices) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_request_wrong_content_type(server: RemoteOpenAIServer): async def test_request_wrong_content_type(server: RemoteOpenAIServer):
chat_input = [{"role": "user", "content": "Write a long story"}] chat_input = [{"role": "user", "content": "Write a long story"}]
client = server.get_async_client() client = server.get_async_client()
@ -169,17 +176,13 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer):
messages=chat_input, messages=chat_input,
model=MODEL_NAME, model=MODEL_NAME,
max_tokens=10000, max_tokens=10000,
extra_headers={ extra_headers={"Content-Type": "application/x-www-form-urlencoded"},
"Content-Type": "application/x-www-form-urlencoded" )
})
@pytest.mark.parametrize( @pytest.mark.parametrize(
"server_args", "server_args",
[ [pytest.param(["--enable-server-load-tracking"], id="enable-server-load-tracking")],
pytest.param(["--enable-server-load-tracking"],
id="enable-server-load-tracking")
],
indirect=True, indirect=True,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@ -202,7 +205,8 @@ async def test_server_load(server: RemoteOpenAIServer):
# Start the completion request in a background thread. # Start the completion request in a background thread.
completion_future = asyncio.create_task( completion_future = asyncio.create_task(
asyncio.to_thread(make_long_completion_request)) asyncio.to_thread(make_long_completion_request)
)
# Give a short delay to ensure the request has started. # Give a short delay to ensure the request has started.
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
File diff suppressed because it is too large.

Some files were not shown because too many files have changed in this diff.