diff --git a/.buildkite/pyproject.toml b/.buildkite/pyproject.toml deleted file mode 100644 index d5cad1c73c..0000000000 --- a/.buildkite/pyproject.toml +++ /dev/null @@ -1,46 +0,0 @@ -# This local pyproject file is part of the migration from yapf to ruff format. -# It uses the same core rules as the main pyproject.toml file, but with the -# following differences: -# - ruff line length is overridden to 88 -# - deprecated typing ignores (UP006, UP035) have been removed - -[tool.ruff] -line-length = 88 - -[tool.ruff.lint.per-file-ignores] -"vllm/third_party/**" = ["ALL"] -"vllm/version.py" = ["F401"] -"vllm/_version.py" = ["ALL"] - -[tool.ruff.lint] -select = [ - # pycodestyle - "E", - # Pyflakes - "F", - # pyupgrade - "UP", - # flake8-bugbear - "B", - # flake8-simplify - "SIM", - # isort - "I", - # flake8-logging-format - "G", -] -ignore = [ - # star imports - "F405", "F403", - # lambda expression assignment - "E731", - # Loop control variable not used within loop body - "B007", - # f-string format - "UP032", - # Can remove once 3.10+ is the minimum Python version - "UP007", -] - -[tool.ruff.format] -docstring-code-format = true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8ca414ee42..ea63ef1f52 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,28 +6,16 @@ default_stages: - manual # Run in CI exclude: 'vllm/third_party/.*' repos: -- repo: https://github.com/google/yapf - rev: v0.43.0 - hooks: - - id: yapf - args: [--in-place, --verbose] - # Keep the same list from yapfignore here to avoid yapf failing without any inputs - exclude: '(.buildkite|benchmarks|build|examples)/.*' - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.11.7 hooks: - id: ruff args: [--output-format, github, --fix] - id: ruff-format - files: ^(.buildkite|benchmarks|examples)/.* - repo: https://github.com/crate-ci/typos rev: v1.35.5 hooks: - id: typos -- repo: https://github.com/PyCQA/isort - rev: 6.0.1 - hooks: - - id: isort - repo: https://github.com/pre-commit/mirrors-clang-format rev: v20.1.3 hooks: diff --git a/benchmarks/benchmark_block_pool.py b/benchmarks/benchmark_block_pool.py index eae8d9927e..5434f8b6a4 100644 --- a/benchmarks/benchmark_block_pool.py +++ b/benchmarks/benchmark_block_pool.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import gc +from benchmark_utils import TimeCollector from tabulate import tabulate -from benchmark_utils import TimeCollector from vllm.utils import FlexibleArgumentParser from vllm.v1.core.block_pool import BlockPool diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py index d4b83edbd9..291d87d608 100644 --- a/benchmarks/benchmark_ngram_proposer.py +++ b/benchmarks/benchmark_ngram_proposer.py @@ -5,9 +5,9 @@ import time from unittest import mock import numpy as np +from benchmark_utils import TimeCollector from tabulate import tabulate -from benchmark_utils import TimeCollector from vllm.config import ( CacheConfig, DeviceConfig, diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index a035062549..f6b48ad524 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -37,14 +37,13 @@ from typing import Optional import datasets import numpy as np import pandas as pd -from tqdm.asyncio import tqdm -from transformers import PreTrainedTokenizerBase - from backend_request_func import ( ASYNC_REQUEST_FUNCS, RequestFuncInput, 
RequestFuncOutput, ) +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase try: from vllm.transformers_utils.tokenizer import get_tokenizer diff --git a/benchmarks/pyproject.toml b/benchmarks/pyproject.toml deleted file mode 100644 index 65b1e09a24..0000000000 --- a/benchmarks/pyproject.toml +++ /dev/null @@ -1,49 +0,0 @@ -# This local pyproject file is part of the migration from yapf to ruff format. -# It uses the same core rules as the main pyproject.toml file, but with the -# following differences: -# - ruff line length is overridden to 88 -# - deprecated typing ignores (UP006, UP035) have been removed - -[tool.ruff] -line-length = 88 - -[tool.ruff.lint.per-file-ignores] -"vllm/third_party/**" = ["ALL"] -"vllm/version.py" = ["F401"] -"vllm/_version.py" = ["ALL"] - -[tool.ruff.lint] -select = [ - # pycodestyle - "E", - # Pyflakes - "F", - # pyupgrade - "UP", - # flake8-bugbear - "B", - # flake8-simplify - "SIM", - # isort - "I", - # flake8-logging-format - "G", -] -ignore = [ - # star imports - "F405", "F403", - # lambda expression assignment - "E731", - # Loop control variable not used within loop body - "B007", - # f-string format - "UP032", - # Can remove once 3.10+ is the minimum Python version - "UP007", -] - -[tool.ruff.lint.isort] -known-first-party = ["vllm"] - -[tool.ruff.format] -docstring-code-format = true \ No newline at end of file diff --git a/cmake/hipify.py b/cmake/hipify.py index 55d378f5b1..8504f9defe 100755 --- a/cmake/hipify.py +++ b/cmake/hipify.py @@ -16,7 +16,7 @@ import shutil from torch.utils.hipify.hipify_python import hipify -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() # Project directory where all the source + include files live. @@ -34,15 +34,14 @@ if __name__ == '__main__': ) # Source files to convert. - parser.add_argument("sources", - help="Source files to hipify.", - nargs="*", - default=[]) + parser.add_argument( + "sources", help="Source files to hipify.", nargs="*", default=[] + ) args = parser.parse_args() # Limit include scope to project_dir only - includes = [os.path.join(args.project_dir, '*')] + includes = [os.path.join(args.project_dir, "*")] # Get absolute path for all source files. extra_files = [os.path.abspath(s) for s in args.sources] @@ -51,25 +50,31 @@ if __name__ == '__main__': # The directory might already exist to hold object files so we ignore that. 
shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True) - hipify_result = hipify(project_directory=args.project_dir, - output_directory=args.output_dir, - header_include_dirs=[], - includes=includes, - extra_files=extra_files, - show_detailed=True, - is_pytorch_extension=True, - hipify_extra_files_only=True) + hipify_result = hipify( + project_directory=args.project_dir, + output_directory=args.output_dir, + header_include_dirs=[], + includes=includes, + extra_files=extra_files, + show_detailed=True, + is_pytorch_extension=True, + hipify_extra_files_only=True, + ) hipified_sources = [] for source in args.sources: s_abs = os.path.abspath(source) - hipified_s_abs = (hipify_result[s_abs].hipified_path if - (s_abs in hipify_result - and hipify_result[s_abs].hipified_path is not None) - else s_abs) + hipified_s_abs = ( + hipify_result[s_abs].hipified_path + if ( + s_abs in hipify_result + and hipify_result[s_abs].hipified_path is not None + ) + else s_abs + ) hipified_sources.append(hipified_s_abs) - assert (len(hipified_sources) == len(args.sources)) + assert len(hipified_sources) == len(args.sources) # Print hipified source files. print("\n".join(hipified_sources)) diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index 1dd7101acc..7a81dd40c8 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -27,7 +27,7 @@ VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = { **{ VLLMDataType.u4b8: "u4b8", VLLMDataType.u8b128: "u8b128", - } + }, } VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { @@ -35,7 +35,7 @@ VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { **{ VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t", - } + }, } VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = { @@ -43,7 +43,7 @@ VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = { **{ VLLMDataType.u4b8: 4, VLLMDataType.u8b128: 8, - } + }, } VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = { @@ -67,15 +67,13 @@ VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { DataType.f32: "at::ScalarType::Float", } -VLLMKernelScheduleTag: dict[Union[ - MixedInputKernelScheduleType, KernelScheduleType], str] = { - **KernelScheduleTag, # type: ignore - **{ - MixedInputKernelScheduleType.TmaWarpSpecialized: - "cutlass::gemm::KernelTmaWarpSpecialized", - MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: - "cutlass::gemm::KernelTmaWarpSpecializedPingpong", - MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: - "cutlass::gemm::KernelTmaWarpSpecializedCooperative", - } - } +VLLMKernelScheduleTag: dict[ + Union[MixedInputKernelScheduleType, KernelScheduleType], str +] = { + **KernelScheduleTag, # type: ignore + **{ + MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized", + MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong", + MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative", + }, +} diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 698deb107c..be5b68cc53 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -17,25 +17,30 @@ FILE_HEAD = """ 
namespace MARLIN_NAMESPACE_NAME { """.strip() -TEMPLATE = ("template __global__ void Marlin<" - "{{scalar_t}}, " - "{{w_type_id}}, " - "{{s_type_id}}, " - "{{threads}}, " - "{{thread_m_blocks}}, " - "{{thread_n_blocks}}, " - "{{thread_k_blocks}}, " - "{{'true' if m_block_size_8 else 'false'}}, " - "{{stages}}, " - "{{group_blocks}}, " - "{{'true' if is_zp_float else 'false'}}>" - "( MARLIN_KERNEL_PARAMS );") +TEMPLATE = ( + "template __global__ void Marlin<" + "{{scalar_t}}, " + "{{w_type_id}}, " + "{{s_type_id}}, " + "{{threads}}, " + "{{thread_m_blocks}}, " + "{{thread_n_blocks}}, " + "{{thread_k_blocks}}, " + "{{'true' if m_block_size_8 else 'false'}}, " + "{{stages}}, " + "{{group_blocks}}, " + "{{'true' if is_zp_float else 'false'}}>" + "( MARLIN_KERNEL_PARAMS );" +) # int8 with zero point case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. SCALAR_TYPES = [ - "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", - "vllm::kFE2M1f" + "vllm::kU4", + "vllm::kU4B8", + "vllm::kU8B128", + "vllm::kFE4M3fn", + "vllm::kFE2M1f", ] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] @@ -58,11 +63,12 @@ def generate_new_kernels(): all_template_str_list = [] for group_blocks, m_blocks, thread_configs in itertools.product( - GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): - + GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS + ): # act order case only support gptq-int4 and gptq-int8 if group_blocks == 0 and scalar_type not in [ - "vllm::kU4B8", "vllm::kU8B128" + "vllm::kU4B8", + "vllm::kU8B128", ]: continue if thread_configs[2] == 256: diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 7576e0548a..42d3b45609 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -17,28 +17,32 @@ FILE_HEAD = """ namespace MARLIN_NAMESPACE_NAME { """.strip() -TEMPLATE = ("template __global__ void Marlin<" - "{{scalar_t}}, " - "{{w_type_id}}, " - "{{s_type_id}}, " - "{{threads}}, " - "{{thread_m_blocks}}, " - "{{thread_n_blocks}}, " - "{{thread_k_blocks}}, " - "{{'true' if m_block_size_8 else 'false'}}, " - "{{stages}}, " - "{{group_blocks}}, " - "{{'true' if is_zp_float else 'false'}}>" - "( MARLIN_KERNEL_PARAMS );") +TEMPLATE = ( + "template __global__ void Marlin<" + "{{scalar_t}}, " + "{{w_type_id}}, " + "{{s_type_id}}, " + "{{threads}}, " + "{{thread_m_blocks}}, " + "{{thread_n_blocks}}, " + "{{thread_k_blocks}}, " + "{{'true' if m_block_size_8 else 'false'}}, " + "{{stages}}, " + "{{group_blocks}}, " + "{{'true' if is_zp_float else 'false'}}>" + "( MARLIN_KERNEL_PARAMS );" +) # int8 with zero point case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. 
SCALAR_TYPES = [ - "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", - "vllm::kFE2M1f" + "vllm::kU4", + "vllm::kU4B8", + "vllm::kU8B128", + "vllm::kFE4M3fn", + "vllm::kFE2M1f", ] -THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), - (128, 64, 128)] +THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] # group_blocks: @@ -59,11 +63,12 @@ def generate_new_kernels(): all_template_str_list = [] for group_blocks, m_blocks, thread_configs in itertools.product( - GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): - + GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS + ): # act order case only support gptq-int4 and gptq-int8 if group_blocks == 0 and scalar_type not in [ - "vllm::kU4B8", "vllm::kU8B128" + "vllm::kU4B8", + "vllm::kU8B128", ]: continue if thread_configs[2] == 256: @@ -93,8 +98,7 @@ def generate_new_kernels(): c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" is_zp_float_list = [False] - if dtype == "fp16" and scalar_type == "vllm::kU4" and \ - group_blocks == 4: + if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4: # HQQ (is_zp_float = true) only supports # 4bit quantization and fp16 is_zp_float_list.append(True) diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 8fd536ef46..f7106f016b 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -12,18 +12,24 @@ from functools import reduce from typing import Optional, Union import jinja2 + # yapf conflicts with isort for this block # yapf: disable -from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag, - EpilogueScheduleType, - MixedInputKernelScheduleType, - TileSchedulerTag, - TileSchedulerType, VLLMDataType, - VLLMDataTypeNames, - VLLMDataTypeSize, VLLMDataTypeTag, - VLLMDataTypeTorchDataTypeTag, - VLLMDataTypeVLLMScalarTypeTag, - VLLMKernelScheduleTag) +from vllm_cutlass_library_extension import ( + DataType, + EpilogueScheduleTag, + EpilogueScheduleType, + MixedInputKernelScheduleType, + TileSchedulerTag, + TileSchedulerType, + VLLMDataType, + VLLMDataTypeNames, + VLLMDataTypeSize, + VLLMDataTypeTag, + VLLMDataTypeTorchDataTypeTag, + VLLMDataTypeVLLMScalarTypeTag, + VLLMKernelScheduleTag, +) # yapf: enable @@ -286,18 +292,23 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str: tile_shape = ( f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}" ) - cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" + - f"x{schedule_config.cluster_shape_mnk[1]}" + - f"x{schedule_config.cluster_shape_mnk[2]}") - kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\ - .split("::")[-1] - epilogue_schedule = EpilogueScheduleTag[ - schedule_config.epilogue_schedule].split("::")[-1] - tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\ - .split("::")[-1] + cluster_shape = ( + f"{schedule_config.cluster_shape_mnk[0]}" + + f"x{schedule_config.cluster_shape_mnk[1]}" + + f"x{schedule_config.cluster_shape_mnk[2]}" + ) + kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule].split( + "::" + )[-1] + epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split( + "::" + )[-1] + tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1] - return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + - f"_{epilogue_schedule}_{tile_scheduler}") + return ( + f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + + 
f"_{epilogue_schedule}_{tile_scheduler}" + ) # mostly unique shorter sch_sig @@ -316,18 +327,24 @@ def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str: # unique type_name def generate_type_signature(kernel_types: TypeConfig): - return str("".join([ - VLLMDataTypeNames[getattr(kernel_types, field.name)] - for field in fields(TypeConfig) - ])) + return str( + "".join( + [ + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ] + ) + ) def generate_type_option_name(kernel_types: TypeConfig): - return ", ".join([ - f"{field.name.replace('b_', 'with_')+'_type'}=" + - VLLMDataTypeNames[getattr(kernel_types, field.name)] - for field in fields(TypeConfig) - ]) + return ", ".join( + [ + f"{field.name.replace('b_', 'with_') + '_type'}=" + + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ] + ) def is_power_of_two(n): @@ -335,7 +352,6 @@ def is_power_of_two(n): def to_cute_constant(value: list[int]): - def _to_cute_constant(value: int): if is_power_of_two(value): return f"_{value}" @@ -350,11 +366,11 @@ def to_cute_constant(value: list[int]): def unique_schedules(impl_configs: list[ImplConfig]): # Use dict over set for deterministic ordering - return list({ - sch: None - for impl_config in impl_configs - for sch in impl_config.schedules - }.keys()) + return list( + { + sch: None for impl_config in impl_configs for sch in impl_config.schedules + }.keys() + ) def unsigned_type_with_bitwidth(num_bits): @@ -380,7 +396,7 @@ template_globals = { "gen_type_sig": generate_type_signature, "unique_schedules": unique_schedules, "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth, - "gen_type_option_name": generate_type_option_name + "gen_type_option_name": generate_type_option_name, } @@ -398,23 +414,28 @@ prepack_dispatch_template = create_template(PREPACK_TEMPLATE) def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): sources = [] - sources.append(( - "machete_mm_dispatch", - mm_dispatch_template.render(impl_configs=impl_configs), - )) + sources.append( + ( + "machete_mm_dispatch", + mm_dispatch_template.render(impl_configs=impl_configs), + ) + ) prepack_types = [] for impl_config in impl_configs: - convert_type = impl_config.types.a \ - if impl_config.types.b_group_scale == DataType.void \ - else impl_config.types.b_group_scale + convert_type = ( + impl_config.types.a + if impl_config.types.b_group_scale == DataType.void + else impl_config.types.b_group_scale + ) prepack_types.append( PrepackTypeConfig( a=impl_config.types.a, b_num_bits=VLLMDataTypeSize[impl_config.types.b], convert=convert_type, accumulator=impl_config.types.accumulator, - )) + ) + ) def prepacked_type_key(prepack_type: PrepackTypeConfig): # For now, we can just use the first accumulator type seen since @@ -430,10 +451,14 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): unique_prepack_types.append(prepack_type) prepack_types_seen.add(key) - sources.append(( - "machete_prepack", - prepack_dispatch_template.render(types=unique_prepack_types, ), - )) + sources.append( + ( + "machete_prepack", + prepack_dispatch_template.render( + types=unique_prepack_types, + ), + ) + ) # Split up impls across files num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0) @@ -466,10 +491,12 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): curr_impl_in_file += len(files_impls[-1][-1].schedules) for part, file_impls in enumerate(files_impls): - sources.append(( - 
f"machete_mm_impl_part{part+1}", - mm_impl_template.render(impl_configs=file_impls), - )) + sources.append( + ( + f"machete_mm_impl_part{part + 1}", + mm_impl_template.render(impl_configs=file_impls), + ) + ) return sources @@ -514,8 +541,7 @@ def generate(): # For now we use the same heuristic for all types # Heuristic is currently tuned for H100s default_heuristic = [ - (cond, ScheduleConfig(*tile_config, - **sch_common_params)) # type: ignore + (cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore for cond, tile_config in default_tile_heuristic_config.items() ] @@ -541,14 +567,18 @@ def generate(): a_token_scale=DataType.void, out=a, accumulator=DataType.f32, - ) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) - for a in (DataType.f16, DataType.bf16)) + ) + for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) + for a in (DataType.f16, DataType.bf16) + ) impl_configs += [ ImplConfig(x[0], x[1], x[2]) - for x in zip(GPTQ_kernel_type_configs, - itertools.repeat(get_unique_schedules(default_heuristic)), - itertools.repeat(default_heuristic)) + for x in zip( + GPTQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), + itertools.repeat(default_heuristic), + ) ] AWQ_kernel_type_configs = list( @@ -561,14 +591,18 @@ def generate(): a_token_scale=DataType.void, out=a, accumulator=DataType.f32, - ) for b in (DataType.u4, DataType.u8) - for a in (DataType.f16, DataType.bf16)) + ) + for b in (DataType.u4, DataType.u8) + for a in (DataType.f16, DataType.bf16) + ) impl_configs += [ ImplConfig(x[0], x[1], x[2]) - for x in zip(AWQ_kernel_type_configs, - itertools.repeat(get_unique_schedules(default_heuristic)), - itertools.repeat(default_heuristic)) + for x in zip( + AWQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), + itertools.repeat(default_heuristic), + ) ] # TODO: Support W4A8 when ready diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py index d026235dd9..ecd71ee1f3 100644 --- a/docs/mkdocs/hooks/generate_argparse.py +++ b/docs/mkdocs/hooks/generate_argparse.py @@ -33,8 +33,11 @@ def auto_mock(module, attr, max_mocks=50): try: # First treat attr as an attr, then as a submodule with patch("importlib.metadata.version", return_value="0.0.0"): - return getattr(importlib.import_module(module), attr, - importlib.import_module(f"{module}.{attr}")) + return getattr( + importlib.import_module(module), + attr, + importlib.import_module(f"{module}.{attr}"), + ) except importlib.metadata.PackageNotFoundError as e: raise e except ModuleNotFoundError as e: @@ -42,7 +45,8 @@ def auto_mock(module, attr, max_mocks=50): sys.modules[e.name] = PydanticMagicMock() raise ImportError( - f"Failed to import {module}.{attr} after mocking {max_mocks} imports") + f"Failed to import {module}.{attr} after mocking {max_mocks} imports" + ) latency = auto_mock("vllm.benchmarks", "latency") @@ -61,9 +65,7 @@ class MarkdownFormatter(HelpFormatter): """Custom formatter that generates markdown for argument groups.""" def __init__(self, prog, starting_heading_level=3): - super().__init__(prog, - max_help_position=float('inf'), - width=float('inf')) + super().__init__(prog, max_help_position=float("inf"), width=float("inf")) self._section_heading_prefix = "#" * starting_heading_level self._argument_heading_prefix = "#" * (starting_heading_level + 1) self._markdown_output = [] @@ -85,23 +87,19 @@ class MarkdownFormatter(HelpFormatter): def add_arguments(self, actions): for action in actions: - if 
(len(action.option_strings) == 0 - or "--help" in action.option_strings): + if len(action.option_strings) == 0 or "--help" in action.option_strings: continue - option_strings = f'`{"`, `".join(action.option_strings)}`' + option_strings = f"`{'`, `'.join(action.option_strings)}`" heading_md = f"{self._argument_heading_prefix} {option_strings}\n\n" self._markdown_output.append(heading_md) if choices := action.choices: - choices = f'`{"`, `".join(str(c) for c in choices)}`' - self._markdown_output.append( - f"Possible choices: {choices}\n\n") - elif ((metavar := action.metavar) - and isinstance(metavar, (list, tuple))): - metavar = f'`{"`, `".join(str(m) for m in metavar)}`' - self._markdown_output.append( - f"Possible choices: {metavar}\n\n") + choices = f"`{'`, `'.join(str(c) for c in choices)}`" + self._markdown_output.append(f"Possible choices: {choices}\n\n") + elif (metavar := action.metavar) and isinstance(metavar, (list, tuple)): + metavar = f"`{'`, `'.join(str(m) for m in metavar)}`" + self._markdown_output.append(f"Possible choices: {metavar}\n\n") if action.help: self._markdown_output.append(f"{action.help}\n\n") @@ -116,7 +114,7 @@ class MarkdownFormatter(HelpFormatter): def create_parser(add_cli_args, **kwargs) -> FlexibleArgumentParser: """Create a parser for the given class with markdown formatting. - + Args: cls: The class to create a parser for **kwargs: Additional keyword arguments to pass to `cls.add_cli_args`. @@ -143,24 +141,17 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): # Create parsers to document parsers = { - "engine_args": - create_parser(EngineArgs.add_cli_args), - "async_engine_args": - create_parser(AsyncEngineArgs.add_cli_args, async_args_only=True), - "serve": - create_parser(cli_args.make_arg_parser), - "chat": - create_parser(ChatCommand.add_cli_args), - "complete": - create_parser(CompleteCommand.add_cli_args), - "bench_latency": - create_parser(latency.add_cli_args), - "bench_throughput": - create_parser(throughput.add_cli_args), - "bench_serve": - create_parser(serve.add_cli_args), - "run-batch": - create_parser(run_batch.make_arg_parser), + "engine_args": create_parser(EngineArgs.add_cli_args), + "async_engine_args": create_parser( + AsyncEngineArgs.add_cli_args, async_args_only=True + ), + "serve": create_parser(cli_args.make_arg_parser), + "chat": create_parser(ChatCommand.add_cli_args), + "complete": create_parser(CompleteCommand.add_cli_args), + "bench_latency": create_parser(latency.add_cli_args), + "bench_throughput": create_parser(throughput.add_cli_args), + "bench_serve": create_parser(serve.add_cli_args), + "run-batch": create_parser(run_batch.make_arg_parser), } # Generate documentation for each parser diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py index 0cbaebb598..ed8277f628 100644 --- a/docs/mkdocs/hooks/generate_examples.py +++ b/docs/mkdocs/hooks/generate_examples.py @@ -11,7 +11,7 @@ import regex as re logger = logging.getLogger("mkdocs") ROOT_DIR = Path(__file__).parent.parent.parent.parent -ROOT_DIR_RELATIVE = '../../../../..' +ROOT_DIR_RELATIVE = "../../../../.." EXAMPLE_DIR = ROOT_DIR / "examples" EXAMPLE_DOC_DIR = ROOT_DIR / "docs/examples" @@ -36,7 +36,7 @@ def fix_case(text: str) -> str: r"int\d+": lambda x: x.group(0).upper(), # e.g. 
int8, int16 } for pattern, repl in subs.items(): - text = re.sub(rf'\b{pattern}\b', repl, text, flags=re.IGNORECASE) + text = re.sub(rf"\b{pattern}\b", repl, text, flags=re.IGNORECASE) return text @@ -58,7 +58,8 @@ class Example: determine_other_files() -> list[Path]: Determines other files in the directory excluding the main file. determine_title() -> str: Determines the title of the document. generate() -> str: Generates the documentation content. - """ # noqa: E501 + """ # noqa: E501 + path: Path category: str = None main_file: Path = field(init=False) @@ -84,9 +85,8 @@ class Example: Markdown file found in the directory. Raises: IndexError: If no Markdown files are found in the directory. - """ # noqa: E501 - return self.path if self.path.is_file() else list( - self.path.glob("*.md")).pop() + """ # noqa: E501 + return self.path if self.path.is_file() else list(self.path.glob("*.md")).pop() def determine_other_files(self) -> list[Path]: """ @@ -98,7 +98,7 @@ class Example: Returns: list[Path]: A list of Path objects representing the other files in the directory. - """ # noqa: E501 + """ # noqa: E501 if self.path.is_file(): return [] is_other_file = lambda file: file.is_file() and file != self.main_file @@ -109,25 +109,25 @@ class Example: # Specify encoding for building on Windows with open(self.main_file, encoding="utf-8") as f: first_line = f.readline().strip() - match = re.match(r'^#\s+(?P
int: def _detect_cloud_provider() -> str: # Try detecting through vendor file vendor_files = [ - "/sys/class/dmi/id/product_version", "/sys/class/dmi/id/bios_vendor", + "/sys/class/dmi/id/product_version", + "/sys/class/dmi/id/bios_vendor", "/sys/class/dmi/id/product_name", - "/sys/class/dmi/id/chassis_asset_tag", "/sys/class/dmi/id/sys_vendor" + "/sys/class/dmi/id/chassis_asset_tag", + "/sys/class/dmi/id/sys_vendor", ] # Mapping of identifiable strings to cloud providers cloud_identifiers = { @@ -152,39 +153,53 @@ class UsageMessage: self.log_time: Optional[int] = None self.source: Optional[str] = None - def report_usage(self, - model_architecture: str, - usage_context: UsageContext, - extra_kvs: Optional[dict[str, Any]] = None) -> None: - t = Thread(target=self._report_usage_worker, - args=(model_architecture, usage_context, extra_kvs or {}), - daemon=True) + def report_usage( + self, + model_architecture: str, + usage_context: UsageContext, + extra_kvs: Optional[dict[str, Any]] = None, + ) -> None: + t = Thread( + target=self._report_usage_worker, + args=(model_architecture, usage_context, extra_kvs or {}), + daemon=True, + ) t.start() - def _report_usage_worker(self, model_architecture: str, - usage_context: UsageContext, - extra_kvs: dict[str, Any]) -> None: + def _report_usage_worker( + self, + model_architecture: str, + usage_context: UsageContext, + extra_kvs: dict[str, Any], + ) -> None: self._report_usage_once(model_architecture, usage_context, extra_kvs) self._report_continuous_usage() - def _report_usage_once(self, model_architecture: str, - usage_context: UsageContext, - extra_kvs: dict[str, Any]) -> None: + def _report_usage_once( + self, + model_architecture: str, + usage_context: UsageContext, + extra_kvs: dict[str, Any], + ) -> None: # Platform information from vllm.platforms import current_platform + if current_platform.is_cuda_alike(): self.gpu_count = cuda_device_count_stateless() - self.gpu_type, self.gpu_memory_per_device = ( - cuda_get_device_properties(0, ("name", "total_memory"))) + self.gpu_type, self.gpu_memory_per_device = cuda_get_device_properties( + 0, ("name", "total_memory") + ) if current_platform.is_cuda(): self.cuda_runtime = torch.version.cuda if current_platform.is_tpu(): try: import torch_xla + self.gpu_count = torch_xla.runtime.world_size() self.gpu_type = torch_xla.tpu.get_tpu_type() - self.gpu_memory_per_device = ( - torch_xla.core.xla_model.get_memory_info()["bytes_limit"]) + self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[ + "bytes_limit" + ] except Exception: logger.exception("Failed to collect TPU information") self.provider = _detect_cloud_provider() @@ -195,11 +210,13 @@ class UsageMessage: info = cpuinfo.get_cpu_info() self.num_cpu = info.get("count", None) self.cpu_type = info.get("brand_raw", "") - self.cpu_family_model_stepping = ",".join([ - str(info.get("family", "")), - str(info.get("model", "")), - str(info.get("stepping", "")) - ]) + self.cpu_family_model_stepping = ",".join( + [ + str(info.get("family", "")), + str(info.get("model", "")), + str(info.get("stepping", "")), + ] + ) # vLLM information self.context = usage_context.value @@ -207,10 +224,9 @@ class UsageMessage: self.model_architecture = model_architecture # Environment variables - self.env_var_json = json.dumps({ - env_var: getattr(envs, env_var) - for env_var in _USAGE_ENV_VARS_TO_COLLECT - }) + self.env_var_json = json.dumps( + {env_var: getattr(envs, env_var) for env_var in _USAGE_ENV_VARS_TO_COLLECT} + ) # Metadata self.log_time = 
_get_current_timestamp_ns() diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 6b208bca69..c9999649b5 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -33,22 +33,48 @@ import types import uuid import warnings import weakref -from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser, - ArgumentTypeError, RawDescriptionHelpFormatter, - _ArgumentGroup) +from argparse import ( + Action, + ArgumentDefaultsHelpFormatter, + ArgumentParser, + ArgumentTypeError, + RawDescriptionHelpFormatter, + _ArgumentGroup, +) from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict -from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator, - Hashable, Iterable, Iterator, KeysView, Mapping, - Sequence) +from collections.abc import ( + AsyncGenerator, + Awaitable, + Collection, + Generator, + Hashable, + Iterable, + Iterator, + KeysView, + Mapping, + Sequence, +) from concurrent.futures import ThreadPoolExecutor from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps from pathlib import Path from types import MappingProxyType -from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, - Optional, TextIO, TypeVar, Union, cast, overload) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + NamedTuple, + Optional, + TextIO, + TypeVar, + Union, + cast, + overload, +) from urllib.parse import urlparse from uuid import uuid4 @@ -117,8 +143,8 @@ GiB_bytes = 1 << 30 """The number of bytes in one gibibyte (GiB).""" # ANSI color codes -CYAN = '\033[1;36m' -RESET = '\033[0;0m' +CYAN = "\033[1;36m" +RESET = "\033[0;0m" STR_DTYPE_TO_TORCH_DTYPE = { "float32": torch.float32, @@ -152,7 +178,7 @@ def set_default_torch_num_threads(num_threads: int): torch.set_num_threads(old_num_threads) -P = ParamSpec('P') +P = ParamSpec("P") T = TypeVar("T") U = TypeVar("U") @@ -161,8 +187,7 @@ _V = TypeVar("_V") _T = TypeVar("_T") -class _Sentinel: - ... +class _Sentinel: ... 
ALL_PINNED_SENTINEL = _Sentinel() @@ -179,7 +204,6 @@ class LayerBlockType(enum.Enum): class Counter: - def __init__(self, start: int = 0) -> None: self.counter = start @@ -193,7 +217,6 @@ class Counter: class _MappingOrderCacheView(UserDict[_K, _V]): - def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]): super().__init__(data) self.ordered_keys = ordered_keys @@ -224,10 +247,9 @@ class CacheInfo(NamedTuple): class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): - - def __init__(self, - capacity: float, - getsizeof: Optional[Callable[[_V], float]] = None): + def __init__( + self, capacity: float, getsizeof: Optional[Callable[[_V], float]] = None + ): super().__init__(capacity, getsizeof) self.pinned_items = set[_K]() @@ -247,8 +269,7 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): def __delitem__(self, key: _K) -> None: run_on_remove = key in self - value = self.__getitem__(key, - update_info=False) # type: ignore[call-arg] + value = self.__getitem__(key, update_info=False) # type: ignore[call-arg] super().__delitem__(key) if key in self.pinned_items: # Todo: add warning to inform that del pinned item @@ -261,7 +282,8 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): """Return the internal cache dictionary in order (read-only).""" return _MappingOrderCacheView( self._Cache__data, # type: ignore - self.order) + self.order, + ) @property def order(self) -> Mapping[_K, None]: @@ -302,22 +324,17 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): self._LRUCache__order[key] = None # type: ignore @overload - def get(self, key: _K, /) -> Optional[_V]: - ... + def get(self, key: _K, /) -> Optional[_V]: ... @overload - def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: - ... + def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ... - def get(self, - key: _K, - /, - default: Optional[Union[_V, - _T]] = None) -> Optional[Union[_V, _T]]: + def get( + self, key: _K, /, default: Optional[Union[_V, _T]] = None + ) -> Optional[Union[_V, _T]]: value: Optional[Union[_V, _T]] if key in self: - value = self.__getitem__( - key, update_info=False) # type: ignore[call-arg] + value = self.__getitem__(key, update_info=False) # type: ignore[call-arg] self._hits += 1 else: @@ -327,23 +344,19 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): return value @overload - def pop(self, key: _K) -> _V: - ... + def pop(self, key: _K) -> _V: ... @overload - def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: - ... + def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ... 
- def pop(self, - key: _K, - default: Optional[Union[_V, - _T]] = None) -> Optional[Union[_V, _T]]: + def pop( + self, key: _K, default: Optional[Union[_V, _T]] = None + ) -> Optional[Union[_V, _T]]: value: Optional[Union[_V, _T]] if key not in self: return default - value = self.__getitem__(key, - update_info=False) # type: ignore[call-arg] + value = self.__getitem__(key, update_info=False) # type: ignore[call-arg] self.__delitem__(key) return value @@ -385,10 +398,12 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): # pop the oldest item in the cache that is not pinned lru_key = next( (key for key in self.order if key not in self.pinned_items), - ALL_PINNED_SENTINEL) + ALL_PINNED_SENTINEL, + ) if lru_key is ALL_PINNED_SENTINEL: - raise RuntimeError("All items are pinned, " - "cannot remove oldest from the cache.") + raise RuntimeError( + "All items are pinned, cannot remove oldest from the cache." + ) else: lru_key = next(iter(self.order)) value = self.pop(cast(_K, lru_key)) @@ -436,8 +451,7 @@ class PyObjectCache: return obj def reset(self): - """Makes all cached-objects available for the next scheduler iteration. - """ + """Makes all cached-objects available for the next scheduler iteration.""" self._index = 0 @@ -445,8 +459,8 @@ class PyObjectCache: def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" from vllm import _custom_ops as ops - max_shared_mem = ( - ops.get_max_shared_memory_per_block_device_attribute(gpu)) + + max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu) # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py # will fail assert max_shared_mem > 0, "max_shared_mem can not be zero" @@ -481,11 +495,14 @@ class AsyncMicrobatchTokenizer: self.batch_wait_timeout_s = batch_wait_timeout_s self._loop = asyncio.get_running_loop() - self._queues: dict[tuple, - asyncio.Queue[Union[tuple[str, dict, - asyncio.Future], - tuple[list[int], - asyncio.Future]]]] = {} + self._queues: dict[ + tuple, + asyncio.Queue[ + Union[ + tuple[str, dict, asyncio.Future], tuple[list[int], asyncio.Future] + ] + ], + ] = {} self._batcher_tasks: list[asyncio.Task] = [] # Single-thread executor for blocking tokenizer calls. @@ -509,8 +526,9 @@ class AsyncMicrobatchTokenizer: # === Internal helpers === def _get_queue( self, loop: asyncio.AbstractEventLoop, key: tuple - ) -> asyncio.Queue[Union[tuple[str, dict, asyncio.Future], tuple[ - list[int], asyncio.Future]]]: + ) -> asyncio.Queue[ + Union[tuple[str, dict, asyncio.Future], tuple[list[int], asyncio.Future]] + ]: """Get the request queue for the given operation key, creating a new queue and batcher task if needed.""" queue = self._queues.get(key) @@ -520,8 +538,7 @@ class AsyncMicrobatchTokenizer: can_batch = key[1] != "other" coro = self._batch_encode_loop(queue, can_batch) else: - assert key[0] == "decode", \ - f"Unknown operation type: {key[0]}." + assert key[0] == "decode", f"Unknown operation type: {key[0]}." coro = self._batch_decode_loop(queue) self._batcher_tasks.append(loop.create_task(coro)) return queue @@ -541,7 +558,8 @@ class AsyncMicrobatchTokenizer: break try: prompt, kwargs, result_future = await asyncio.wait_for( - queue.get(), timeout) + queue.get(), timeout + ) prompts.append(prompt) result_futures.append(result_future) if not can_batch: @@ -553,10 +571,10 @@ class AsyncMicrobatchTokenizer: # If every request uses identical kwargs we can run a single # batched tokenizer call for a big speed-up. 
if can_batch and len(prompts) > 1: - batch_encode_fn = partial(self.tokenizer, prompts, - **kwargs) + batch_encode_fn = partial(self.tokenizer, prompts, **kwargs) results = await self._loop.run_in_executor( - self._executor, batch_encode_fn) + self._executor, batch_encode_fn + ) for i, fut in enumerate(result_futures): if not fut.done(): @@ -564,11 +582,11 @@ class AsyncMicrobatchTokenizer: fut.set_result(BatchEncoding(data)) else: encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [ - self.tokenizer(p, **kw) - for p, kw in zip(prompts, kwargs) + self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs) ] results = await self._loop.run_in_executor( - self._executor, encode_fn) + self._executor, encode_fn + ) for fut, res in zip(result_futures, results): if not fut.done(): @@ -592,7 +610,8 @@ class AsyncMicrobatchTokenizer: break try: token_ids, result_future = await asyncio.wait_for( - queue.get(), timeout) + queue.get(), timeout + ) token_ids_list.append(token_ids) result_futures.append(result_future) except asyncio.TimeoutError: @@ -601,8 +620,8 @@ class AsyncMicrobatchTokenizer: try: # Perform a single batched decode call for all requests results = await self._loop.run_in_executor( - self._executor, self.tokenizer.batch_decode, - token_ids_list) + self._executor, self.tokenizer.batch_decode, token_ids_list + ) for fut, res in zip(result_futures, results): if not fut.done(): fut.set_result(res) @@ -631,7 +650,7 @@ class AsyncMicrobatchTokenizer: """ if op == "decode": - return ("decode", ) + return ("decode",) add_special_tokens = kwargs.get("add_special_tokens", True) truncation = kwargs.get("truncation", False) @@ -641,16 +660,17 @@ class AsyncMicrobatchTokenizer: return "encode", add_special_tokens, False, None model_max = getattr(self.tokenizer, "model_max_length", None) - if max_length is None or (model_max is not None - and max_length == model_max): + if max_length is None or (model_max is not None and max_length == model_max): return "encode", add_special_tokens, True, "model_max" return "encode", "other" def __del__(self): - if ((tasks := getattr(self, "_batcher_tasks", None)) - and (loop := getattr(self, "_loop", None)) - and not loop.is_closed()): + if ( + (tasks := getattr(self, "_batcher_tasks", None)) + and (loop := getattr(self, "_loop", None)) + and not loop.is_closed() + ): def cancel_tasks(): for task in tasks: @@ -685,8 +705,7 @@ def in_loop(event_loop: AbstractEventLoop) -> bool: def make_async( - func: Callable[P, T], - executor: Optional[concurrent.futures.Executor] = None + func: Callable[P, T], executor: Optional[concurrent.futures.Executor] = None ) -> Callable[P, Awaitable[T]]: """Take a blocking function, and run it on in an executor thread. @@ -703,15 +722,14 @@ def make_async( return _async_wrapper -def _next_task(iterator: AsyncGenerator[T, None], - loop: AbstractEventLoop) -> Task: +def _next_task(iterator: AsyncGenerator[T, None], loop: AbstractEventLoop) -> Task: # Can use anext() in python >= 3.10 return loop.create_task(iterator.__anext__()) # type: ignore[arg-type] async def merge_async_iterators( - *iterators: AsyncGenerator[T, - None], ) -> AsyncGenerator[tuple[int, T], None]: + *iterators: AsyncGenerator[T, None], +) -> AsyncGenerator[tuple[int, T], None]: """Merge multiple asynchronous iterators into a single iterator. This method handle the case where some iterators finish before others. 
@@ -729,8 +747,7 @@ async def merge_async_iterators( awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)} try: while awaits: - done, _ = await asyncio.wait(awaits.keys(), - return_when=FIRST_COMPLETED) + done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED) for d in done: pair = awaits.pop(d) try: @@ -748,8 +765,7 @@ async def merge_async_iterators( await it.aclose() -async def collect_from_async_generator( - iterator: AsyncGenerator[T, None]) -> list[T]: +async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]: """Collect all items from an async generator into a list.""" items = [] async for item in iterator: @@ -765,7 +781,8 @@ def get_ip() -> str: " it is often used by Docker and other software to" " interact with the container's network stack. Please " "use VLLM_HOST_IP instead to set the IP address for vLLM processes" - " to communicate with each other.") + " to communicate with each other." + ) if host_ip: return host_ip @@ -793,7 +810,8 @@ def get_ip() -> str: "Failed to get the IP address, using 0.0.0.0 by default." "The value can be set by the environment variable" " VLLM_HOST_IP or HOST_IP.", - stacklevel=2) + stacklevel=2, + ) return "0.0.0.0" @@ -821,7 +839,8 @@ def get_loopback_ip() -> str: else: raise RuntimeError( "Neither 127.0.0.1 nor ::1 are bound to a local interface. " - "Set the VLLM_LOOPBACK_IP environment variable explicitly.") + "Set the VLLM_LOOPBACK_IP environment variable explicitly." + ) def is_valid_ipv6_address(address: str) -> bool: @@ -834,13 +853,13 @@ def is_valid_ipv6_address(address: str) -> bool: def split_host_port(host_port: str) -> tuple[str, int]: # ipv6 - if host_port.startswith('['): - host, port = host_port.rsplit(']', 1) + if host_port.startswith("["): + host, port = host_port.rsplit("]", 1) host = host[1:] - port = port.split(':')[1] + port = port.split(":")[1] return host, int(port) else: - host, port = host_port.split(':') + host, port = host_port.split(":") return host, int(port) @@ -908,8 +927,7 @@ def _get_open_port() -> int: return port except OSError: port += 1 # Increment port number if already in use - logger.info("Port %d is already in use, trying port %d", - port - 1, port) + logger.info("Port %d is already in use, trying port %d", port - 1, port) # try ipv4 try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -932,8 +950,7 @@ def find_process_using_port(port: int) -> Optional[psutil.Process]: our_pid = os.getpid() for conn in psutil.net_connections(): - if conn.laddr.port == port and (conn.pid is not None - and conn.pid != our_pid): + if conn.laddr.port == port and (conn.pid is not None and conn.pid != our_pid): try: return psutil.Process(conn.pid) except psutil.NoSuchProcess: @@ -945,15 +962,18 @@ def update_environment_variables(envs: dict[str, str]): for k, v in envs.items(): if k in os.environ and os.environ[k] != v: logger.warning( - "Overwriting environment variable %s " - "from '%s' to '%s'", k, os.environ[k], v) + "Overwriting environment variable %s from '%s' to '%s'", + k, + os.environ[k], + v, + ) os.environ[k] = v def chunk_list(lst: list[T], chunk_size: int): """Yield successive chunk_size chunks from lst.""" for i in range(0, len(lst), chunk_size): - yield lst[i:i + chunk_size] + yield lst[i : i + chunk_size] def cdiv(a: int, b: int) -> int: @@ -997,6 +1017,7 @@ def _generate_random_fp8( # Inf | N/A | s.11111.00 # NaN | s.1111.111 | s.11111.{01,10,11} from vllm import _custom_ops as ops + tensor_tmp = torch.empty_like(tensor, 
dtype=torch.float16) tensor_tmp.uniform_(low, high) ops.convert_fp8(tensor, tensor_tmp) @@ -1004,12 +1025,12 @@ def _generate_random_fp8( def get_kv_cache_torch_dtype( - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, +) -> torch.dtype: if isinstance(cache_dtype, str): if cache_dtype == "auto": - if isinstance(model_dtype, - str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: + if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] elif isinstance(model_dtype, torch.dtype): torch_dtype = model_dtype @@ -1039,32 +1060,30 @@ def create_kv_caches_with_random_flash( cache_layout: Optional[str] = "NHD", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: from vllm.platforms import current_platform + current_platform.seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) assert cache_layout in ("NHD", "HND") - stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, - 4) + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) - kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] - for i in stride_order) + kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order) scale = head_size**-0.5 key_caches: list[torch.Tensor] = [] value_caches: list[torch.Tensor] = [] for _ in range(num_layers): - key_value_cache = torch.empty(size=kv_cache_allocation_shape, - dtype=torch_dtype, - device=device).permute(*stride_order) + key_value_cache = torch.empty( + size=kv_cache_allocation_shape, dtype=torch_dtype, device=device + ).permute(*stride_order) if cache_dtype in ["auto", "half", "bfloat16", "float"]: key_value_cache.uniform_(-scale, scale) - elif cache_dtype == 'fp8': + elif cache_dtype == "fp8": _generate_random_fp8(key_value_cache, -scale, scale) else: - raise ValueError( - f"Does not support key cache of type {cache_dtype}") + raise ValueError(f"Does not support key cache of type {cache_dtype}") key_caches.append(key_value_cache[:, 0]) value_caches.append(key_value_cache[:, 1]) return key_caches, value_caches @@ -1086,6 +1105,7 @@ def create_kv_caches_with_random( f"Does not support key cache of type fp8 with head_size {head_size}" ) from vllm.platforms import current_platform + current_platform.seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) @@ -1095,31 +1115,27 @@ def create_kv_caches_with_random( key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_caches: list[torch.Tensor] = [] for _ in range(num_layers): - key_cache = torch.empty(size=key_cache_shape, - dtype=torch_dtype, - device=device) + key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device) if cache_dtype in ["auto", "half", "bfloat16", "float"]: key_cache.uniform_(-scale, scale) - elif cache_dtype == 'fp8': + elif cache_dtype == "fp8": _generate_random_fp8(key_cache, -scale, scale) else: - raise ValueError( - f"Does not support key cache of type {cache_dtype}") + raise ValueError(f"Does not support key cache of type {cache_dtype}") key_caches.append(key_cache) value_cache_shape = (num_blocks, num_heads, head_size, block_size) value_caches: list[torch.Tensor] = [] for _ in range(num_layers): - value_cache = 
torch.empty(size=value_cache_shape, - dtype=torch_dtype, - device=device) + value_cache = torch.empty( + size=value_cache_shape, dtype=torch_dtype, device=device + ) if cache_dtype in ["auto", "half", "bfloat16", "float"]: value_cache.uniform_(-scale, scale) - elif cache_dtype == 'fp8': + elif cache_dtype == "fp8": _generate_random_fp8(value_cache, -scale, scale) else: - raise ValueError( - f"Does not support value cache of type {cache_dtype}") + raise ValueError(f"Does not support value cache of type {cache_dtype}") value_caches.append(value_cache) return key_caches, value_caches @@ -1127,6 +1143,7 @@ def create_kv_caches_with_random( @cache def is_pin_memory_available() -> bool: from vllm.platforms import current_platform + return current_platform.is_pin_memory_available() @@ -1139,13 +1156,13 @@ def is_uva_available() -> bool: class DeviceMemoryProfiler: - def __init__(self, device: Optional[torch.types.Device] = None): self.device = device def current_memory_usage(self) -> float: # Return the memory usage in bytes. from vllm.platforms import current_platform + gc.collect() return current_platform.get_current_memory_usage(self.device) @@ -1182,7 +1199,7 @@ def make_ndarray_with_pad( padded_x = np.full((len(x), max_len), pad, dtype=dtype) for ind, blocktb in enumerate(x): assert len(blocktb) <= max_len - padded_x[ind, :len(blocktb)] = blocktb + padded_x[ind, : len(blocktb)] = blocktb return padded_x @@ -1231,8 +1248,7 @@ def get_dtype_size(dtype: torch.dtype) -> int: # bool = 0, int = 1, float = 2, complex = 3 def _get_precision_level(dtype: torch.dtype) -> int: # NOTE: Complex dtypes return `is_floating_point=False` - return ((dtype != torch.bool) + dtype.is_floating_point + - dtype.is_complex * 2) + return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2 def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): @@ -1260,8 +1276,11 @@ def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): # Compare floating-point types src_info = torch.finfo(src_dtype) tgt_info = torch.finfo(tgt_dtype) - return (src_info.min >= tgt_info.min and src_info.max <= tgt_info.max - and src_info.resolution >= tgt_info.resolution) + return ( + src_info.min >= tgt_info.min + and src_info.max <= tgt_info.max + and src_info.resolution >= tgt_info.resolution + ) def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): @@ -1329,6 +1348,7 @@ def init_cached_hf_modules() -> None: Lazy initialization of the Hugging Face modules. """ from transformers.dynamic_module_utils import init_hf_modules + init_hf_modules() @@ -1372,8 +1392,8 @@ def find_nccl_library() -> str: # manually load the nccl library if so_file: logger.info( - "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", - so_file) + "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file + ) else: if torch.version.cuda is not None: so_file = "libnccl.so.2" @@ -1388,8 +1408,8 @@ def find_nccl_library() -> str: def find_nccl_include_paths() -> Optional[list[str]]: """ We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH` - environment variable, or we find the library file brought by - nvidia-nccl-cuXX. load_inline by default uses + environment variable, or we find the library file brought by + nvidia-nccl-cuXX. 
load_inline by default uses torch.utils.cpp_extension.include_paths """ paths: list[str] = [] @@ -1399,6 +1419,7 @@ def find_nccl_include_paths() -> Optional[list[str]]: try: import importlib.util + spec = importlib.util.find_spec("nvidia.nccl") if spec and getattr(spec, "submodule_search_locations", None): for loc in spec.submodule_search_locations: @@ -1431,7 +1452,6 @@ torch.cuda.set_stream = _patched_set_stream class _StreamPlaceholder: - def __init__(self): self.synchronize = lambda: None @@ -1448,8 +1468,8 @@ def current_stream() -> torch.cuda.Stream: from C/C++ code. """ from vllm.platforms import current_platform - if not hasattr(_current_stream_tls, - "value") or _current_stream_tls.value is None: + + if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None: # when this function is called before any stream is set, # we return the default stream. # On ROCm using the default 0 stream in combination with RCCL @@ -1467,7 +1487,8 @@ def current_stream() -> torch.cuda.Stream: else: raise ValueError( "Fail to set current stream, current platform " - "may not support current_stream with torch API") + "may not support current_stream with torch API" + ) return _current_stream_tls.value @@ -1480,12 +1501,14 @@ def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: tmp_dir = tempfile.gettempdir() # add username to tmp_dir to avoid permission issues tmp_dir = os.path.join(tmp_dir, getpass.getuser()) - filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" - f"_thread_{threading.get_ident()}_" - f"at_{datetime.datetime.now()}.log").replace(" ", "_") - log_path = os.path.join(tmp_dir, "vllm", - f"vllm-instance-{vllm_config.instance_id}", - filename) + filename = ( + f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" + f"_thread_{threading.get_ident()}_" + f"at_{datetime.datetime.now()}.log" + ).replace(" ", "_") + log_path = os.path.join( + tmp_dir, "vllm", f"vllm-instance-{vllm_config.instance_id}", filename + ) os.makedirs(os.path.dirname(log_path), exist_ok=True) enable_trace_function_call(log_path) @@ -1496,7 +1519,7 @@ def identity(value: T, **kwargs) -> T: return value -F = TypeVar('F', bound=Callable[..., Any]) +F = TypeVar("F", bound=Callable[..., Any]) def deprecate_args( @@ -1508,24 +1531,22 @@ def deprecate_args( is_deprecated = partial(identity, is_deprecated) def wrapper(fn: F) -> F: - params = inspect.signature(fn).parameters pos_types = ( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, ) - pos_kws = [ - kw for kw, param in params.items() if param.kind in pos_types - ] + pos_kws = [kw for kw, param in params.items() if param.kind in pos_types] @wraps(fn) def inner(*args, **kwargs): if is_deprecated(): - deprecated_args = pos_kws[start_index:len(args)] + deprecated_args = pos_kws[start_index : len(args)] if deprecated_args: msg = ( f"The positional arguments {deprecated_args} are " - "deprecated and will be removed in a future update.") + "deprecated and will be removed in a future update." + ) if additional_message is not None: msg += f" {additional_message}" @@ -1552,7 +1573,6 @@ def deprecate_kwargs( is_deprecated = partial(identity, is_deprecated) def wrapper(fn: F) -> F: - @wraps(fn) def inner(*args, **kwargs): if is_deprecated(): @@ -1560,7 +1580,8 @@ def deprecate_kwargs( if deprecated_kwargs: msg = ( f"The keyword arguments {deprecated_kwargs} are " - "deprecated and will be removed in a future update.") + "deprecated and will be removed in a future update." 
+ ) if additional_message is not None: msg += f" {additional_message}" @@ -1577,8 +1598,7 @@ def deprecate_kwargs( @lru_cache(maxsize=8) -def _cuda_device_count_stateless( - cuda_visible_devices: Optional[str] = None) -> int: +def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int: # Note: cuda_visible_devices is not used, but we keep it as an argument for # LRU Cache purposes. @@ -1590,13 +1610,17 @@ def _cuda_device_count_stateless( import torch.version from vllm.platforms import current_platform + if not torch.cuda._is_compiled(): return 0 if current_platform.is_rocm(): # ROCm uses amdsmi instead of nvml for stateless device count # This requires a sufficiently modern version of Torch 2.4.0 - raw_count = torch.cuda._device_count_amdsmi() if (hasattr( - torch.cuda, "_device_count_amdsmi")) else -1 + raw_count = ( + torch.cuda._device_count_amdsmi() + if (hasattr(torch.cuda, "_device_count_amdsmi")) + else -1 + ) else: raw_count = torch.cuda._device_count_nvml() r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count @@ -1630,9 +1654,9 @@ def xpu_is_initialized() -> bool: return torch.xpu.is_initialized() -def cuda_get_device_properties(device, - names: Sequence[str], - init_cuda=False) -> tuple[Any, ...]: +def cuda_get_device_properties( + device, names: Sequence[str], init_cuda=False +) -> tuple[Any, ...]: """Get specified CUDA device property values without initializing CUDA in the current process.""" if init_cuda or cuda_is_initialized(): @@ -1642,11 +1666,12 @@ def cuda_get_device_properties(device, # Run in subprocess to avoid initializing CUDA as a side effect. mp_ctx = multiprocessing.get_context("fork") with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor: - return executor.submit(cuda_get_device_properties, device, names, - True).result() + return executor.submit(cuda_get_device_properties, device, names, True).result() -def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: +def weak_bind( + bound_method: Callable[..., Any], +) -> Callable[..., None]: """Make an instance method that weakly references its associated instance and no-ops once that instance is collected.""" @@ -1661,7 +1686,6 @@ def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: def run_once(f: Callable[P, None]) -> Callable[P, None]: - def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: if wrapper.has_run: # type: ignore[attr-defined] return @@ -1677,19 +1701,18 @@ def run_once(f: Callable[P, None]) -> Callable[P, None]: class StoreBoolean(Action): - def __call__(self, parser, namespace, values, option_string=None): if values.lower() == "true": setattr(namespace, self.dest, True) elif values.lower() == "false": setattr(namespace, self.dest, False) else: - raise ValueError(f"Invalid boolean value: {values}. " - "Expected 'true' or 'false'.") + raise ValueError( + f"Invalid boolean value: {values}. Expected 'true' or 'false'." + ) -class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, - RawDescriptionHelpFormatter): +class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter): """SortedHelpFormatter that sorts arguments by their option strings.""" def _split_lines(self, text, width): @@ -1701,7 +1724,7 @@ class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, # The patterns also include whitespace after the newline single_newline = re.compile(r"(? 
str: @@ -1901,28 +1930,29 @@ class FlexibleArgumentParser(ArgumentParser): processed_args = list[str]() for i, arg in enumerate(args): if arg.startswith("--help="): - FlexibleArgumentParser._search_keyword = arg.split( - '=', 1)[-1].lower() + FlexibleArgumentParser._search_keyword = arg.split("=", 1)[-1].lower() processed_args.append("--help") - elif arg.startswith('--'): - if '=' in arg: - key, value = arg.split('=', 1) + elif arg.startswith("--"): + if "=" in arg: + key, value = arg.split("=", 1) key = pattern.sub(repl, key, count=1) - processed_args.append(f'{key}={value}') + processed_args.append(f"{key}={value}") else: key = pattern.sub(repl, arg, count=1) processed_args.append(key) - elif arg.startswith('-O') and arg != '-O' and arg[2] != '.': + elif arg.startswith("-O") and arg != "-O" and arg[2] != ".": # allow -O flag to be used without space, e.g. -O3 or -Odecode # -O.<...> handled later # also handle -O=here - level = arg[3:] if arg[2] == '=' else arg[2:] - processed_args.append(f'-O.level={level}') - elif arg == '-O' and i + 1 < len(args) and args[i + 1] in { - "0", "1", "2", "3" - }: + level = arg[3:] if arg[2] == "=" else arg[2:] + processed_args.append(f"-O.level={level}") + elif ( + arg == "-O" + and i + 1 < len(args) + and args[i + 1] in {"0", "1", "2", "3"} + ): # Convert -O to -O.level - processed_args.append('-O.level') + processed_args.append("-O.level") else: processed_args.append(arg) @@ -1986,14 +2016,11 @@ class FlexibleArgumentParser(ArgumentParser): # Merge all values with the same key into a single dict arg_dict = create_nested_dict(keys, value) - arg_duplicates = recursive_dict_update(dict_args[key], - arg_dict) - duplicates |= {f'{key}.{d}' for d in arg_duplicates} + arg_duplicates = recursive_dict_update(dict_args[key], arg_dict) + duplicates |= {f"{key}.{d}" for d in arg_duplicates} delete.add(i) # Filter out the dict args we set to None - processed_args = [ - a for i, a in enumerate(processed_args) if i not in delete - ] + processed_args = [a for i, a in enumerate(processed_args) if i not in delete] if duplicates: logger.warning("Found duplicate keys %s", ", ".join(duplicates)) @@ -2050,13 +2077,14 @@ class FlexibleArgumentParser(ArgumentParser): this way the order of priorities is maintained when these are args parsed by super(). """ - assert args.count( - '--config') <= 1, "More than one config file specified!" + assert args.count("--config") <= 1, "More than one config file specified!" - index = args.index('--config') + index = args.index("--config") if index == len(args) - 1: - raise ValueError("No config file specified! \ - Please check your command-line arguments.") + raise ValueError( + "No config file specified! \ + Please check your command-line arguments." + ) file_path = args[index + 1] @@ -2068,29 +2096,33 @@ class FlexibleArgumentParser(ArgumentParser): # followed by rest of cli args. # maintaining this order will enforce the precedence # of cli > config > defaults - if args[0].startswith('-'): + if args[0].startswith("-"): # No sub command (e.g., api_server entry point) - args = config_args + args[0:index] + args[index + 2:] + args = config_args + args[0:index] + args[index + 2 :] elif args[0] == "serve": - model_in_cli = len(args) > 1 and not args[1].startswith('-') - model_in_config = any(arg == '--model' for arg in config_args) + model_in_cli = len(args) > 1 and not args[1].startswith("-") + model_in_config = any(arg == "--model" for arg in config_args) if not model_in_cli and not model_in_config: raise ValueError( "No model specified! 
Please specify model either " - "as a positional argument or in a config file.") + "as a positional argument or in a config file." + ) if model_in_cli: # Model specified as positional arg, keep CLI version - args = [args[0]] + [ - args[1] - ] + config_args + args[2:index] + args[index + 2:] + args = ( + [args[0]] + + [args[1]] + + config_args + + args[2:index] + + args[index + 2 :] + ) else: # No model in CLI, use config if available - args = [args[0] - ] + config_args + args[1:index] + args[index + 2:] + args = [args[0]] + config_args + args[1:index] + args[index + 2 :] else: - args = [args[0]] + config_args + args[1:index] + args[index + 2:] + args = [args[0]] + config_args + args[1:index] + args[index + 2 :] return args @@ -2107,11 +2139,13 @@ class FlexibleArgumentParser(ArgumentParser): '--tensor-parallel-size': '4' ] """ - extension: str = file_path.split('.')[-1] - if extension not in ('yaml', 'yml'): + extension: str = file_path.split(".")[-1] + if extension not in ("yaml", "yml"): raise ValueError( "Config file must be of a yaml/yml type.\ - %s supplied", extension) + %s supplied", + extension, + ) # only expecting a flat dictionary of atomic types processed_args: list[str] = [] @@ -2123,32 +2157,32 @@ class FlexibleArgumentParser(ArgumentParser): except Exception as ex: logger.error( "Unable to read the config file at %s. \ - Make sure path is correct", file_path) + Make sure path is correct", + file_path, + ) raise ex store_boolean_arguments = [ - action.dest for action in self._actions - if isinstance(action, StoreBoolean) + action.dest for action in self._actions if isinstance(action, StoreBoolean) ] for key, value in config.items(): if isinstance(value, bool) and key not in store_boolean_arguments: if value: - processed_args.append('--' + key) + processed_args.append("--" + key) elif isinstance(value, list): if value: - processed_args.append('--' + key) + processed_args.append("--" + key) for item in value: processed_args.append(str(item)) else: - processed_args.append('--' + key) + processed_args.append("--" + key) processed_args.append(str(value)) return processed_args -async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, - **kwargs): +async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs): """Utility function to run async task in a lock""" async with lock: return await task(*args, **kwargs) @@ -2172,19 +2206,26 @@ def supports_kw( param_val = params.get(kw_name) # Types where the it may be valid, i.e., explicitly defined & nonvariadic - passable_kw_types = set((inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY)) + passable_kw_types = set( + ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + ) if param_val: is_sig_param = param_val.kind in passable_kw_types # We want kwargs only, but this is passable as a positional arg - if (requires_kw_only and is_sig_param - and param_val.kind != inspect.Parameter.KEYWORD_ONLY): + if ( + requires_kw_only + and is_sig_param + and param_val.kind != inspect.Parameter.KEYWORD_ONLY + ): return False - if ((requires_kw_only - and param_val.kind == inspect.Parameter.KEYWORD_ONLY) - or (not requires_kw_only and is_sig_param)): + if (requires_kw_only and param_val.kind == inspect.Parameter.KEYWORD_ONLY) or ( + not requires_kw_only and is_sig_param + ): return True # If we're okay with var-kwargs, it's supported as long as @@ -2194,8 +2235,10 @@ def supports_kw( # 
mapping, but it wraps an ordered dict, and they appear in order. # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters last_param = params[next(reversed(params))] # type: ignore - return (last_param.kind == inspect.Parameter.VAR_KEYWORD - and last_param.name != kw_name) + return ( + last_param.kind == inspect.Parameter.VAR_KEYWORD + and last_param.name != kw_name + ) return False @@ -2234,10 +2277,12 @@ def get_allowed_kwarg_only_overrides( filtered_overrides = { kwarg_name: val for kwarg_name, val in overrides.items() - if supports_kw(callable, - kwarg_name, - requires_kw_only=requires_kw_only, - allow_var_kwargs=allow_var_kwargs) + if supports_kw( + callable, + kwarg_name, + requires_kw_only=requires_kw_only, + allow_var_kwargs=allow_var_kwargs, + ) } # If anything is dropped, log a warning @@ -2246,11 +2291,15 @@ def get_allowed_kwarg_only_overrides( if requires_kw_only: logger.warning( "The following intended overrides are not keyword-only args " - "and will be dropped: %s", dropped_keys) + "and will be dropped: %s", + dropped_keys, + ) else: logger.warning( "The following intended overrides are not keyword args " - "and will be dropped: %s", dropped_keys) + "and will be dropped: %s", + dropped_keys, + ) return filtered_overrides @@ -2265,8 +2314,9 @@ def supports_dynamo() -> bool: # Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform def supports_xccl() -> bool: - return is_torch_equal_or_newer( - "2.8.0.dev") and torch.distributed.is_xccl_available() + return ( + is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available() + ) # Some backends use pytorch version < 2.4.0 which doesn't @@ -2302,7 +2352,6 @@ class AtomicCounter: # Adapted from: https://stackoverflow.com/a/47212782/5082708 class LazyDict(Mapping[str, T], Generic[T]): - def __init__(self, factory: dict[str, Callable[[], T]]): self._factory = factory self._dict: dict[str, T] = {} @@ -2325,7 +2374,6 @@ class LazyDict(Mapping[str, T], Generic[T]): class ClassRegistry(UserDict[type[T], _V]): - def __getitem__(self, key: type[T]) -> _V: for cls in key.mro(): if cls in self.data: @@ -2359,8 +2407,9 @@ def weak_ref_tensor(tensor: Any) -> Any: def weak_ref_tensors( - tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor], - IntermediateTensors] + tensors: Union[ + torch.Tensor, list[torch.Tensor], tuple[torch.Tensor], IntermediateTensors + ], ) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: """ Convenience function to create weak references to tensors, @@ -2375,11 +2424,11 @@ def weak_ref_tensors( # For IntermediateTensors used in pipeline parallelism from vllm.sequence import IntermediateTensors + if isinstance(tensors, IntermediateTensors): - ret = IntermediateTensors({ - key: weak_ref_tensor(val) - for key, val in tensors.tensors.items() - }) + ret = IntermediateTensors( + {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()} + ) return ret raise ValueError("Invalid type for tensors") @@ -2419,7 +2468,8 @@ def get_vllm_optional_dependencies(): return { extra: [ - re.split(r";|>=|<=|==", req)[0] for req in requirements + re.split(r";|>=|<=|==", req)[0] + for req in requirements if req.endswith(f'extra == "{extra}"') ] for extra in extras @@ -2612,12 +2662,13 @@ class PlaceholderModule(_PlaceholderBase): raise exc - raise AssertionError("PlaceholderModule should not be used " - "when the original module can be imported") + raise AssertionError( + "PlaceholderModule should not be used " + "when the original module can be imported" + ) 
class _PlaceholderModuleAttr(_PlaceholderBase): - def __init__(self, module: PlaceholderModule, attr_path: str) -> None: super().__init__() @@ -2626,14 +2677,15 @@ class _PlaceholderModuleAttr(_PlaceholderBase): self.__attr_path = attr_path def placeholder_attr(self, attr_path: str): - return _PlaceholderModuleAttr(self.__module, - f"{self.__attr_path}.{attr_path}") + return _PlaceholderModuleAttr(self.__module, f"{self.__attr_path}.{attr_path}") def __getattr__(self, key: str): getattr(self.__module, f"{self.__attr_path}.{key}") - raise AssertionError("PlaceholderModule should not be used " - "when the original module can be imported") + raise AssertionError( + "PlaceholderModule should not be used " + "when the original module can be imported" + ) # create a library to hold the custom op @@ -2641,13 +2693,13 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa def direct_register_custom_op( - op_name: str, - op_func: Callable, - mutates_args: Optional[list[str]] = None, - fake_impl: Optional[Callable] = None, - target_lib: Optional[Library] = None, - dispatch_key: Optional[str] = None, - tags: tuple[torch.Tag, ...] = (), + op_name: str, + op_func: Callable, + mutates_args: Optional[list[str]] = None, + fake_impl: Optional[Callable] = None, + target_lib: Optional[Library] = None, + dispatch_key: Optional[str] = None, + tags: tuple[torch.Tag, ...] = (), ): """ `torch.library.custom_op` can have significant overhead because it @@ -2666,12 +2718,14 @@ def direct_register_custom_op( """ if not supports_custom_op(): from vllm.platforms import current_platform + assert not current_platform.is_cuda_alike(), ( "cuda platform needs torch>=2.4 to support custom op, " "chances are you are using an old version of pytorch " "or a custom build of pytorch. It is recommended to " "use vLLM in a fresh new environment and let it install " - "the required dependencies.") + "the required dependencies." + ) return if mutates_args is None: @@ -2679,15 +2733,17 @@ def direct_register_custom_op( if dispatch_key is None: from vllm.platforms import current_platform + dispatch_key = current_platform.dispatch_key import torch.library + if hasattr(torch.library, "infer_schema"): - schema_str = torch.library.infer_schema(op_func, - mutates_args=mutates_args) + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) else: # for pytorch 2.4 import torch._custom_op.impl + schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) my_lib = target_lib or vllm_lib my_lib.define(op_name + schema_str, tags=tags) @@ -2733,6 +2789,7 @@ def kill_process_tree(pid: int): @dataclass class MemorySnapshot: """Memory snapshot.""" + torch_peak: int = 0 free_memory: int = 0 total_memory: int = 0 @@ -2754,15 +2811,14 @@ class MemorySnapshot: # After `torch.cuda.reset_peak_memory_stats()`, # `torch.cuda.memory_reserved()` will keep growing, and only shrink # when we call `torch.cuda.empty_cache()` or OOM happens. 
- self.torch_peak = torch.cuda.memory_stats().get( - "allocated_bytes.all.peak", 0) + self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0) self.free_memory, self.total_memory = torch.cuda.mem_get_info() - shared_sysmem_device_mem_sms = ( - (8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark - if current_platform.is_cuda() and \ - current_platform.get_device_capability() in \ - shared_sysmem_device_mem_sms: + shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark + if ( + current_platform.is_cuda() + and current_platform.get_device_capability() in shared_sysmem_device_mem_sms + ): # On UMA (Orin, Thor and Spark) platform, # where both CPU and GPU rely on system memory, # the cudaMemGetInfo function shows the amount of free system memory @@ -2801,8 +2857,8 @@ class MemorySnapshot: @dataclass class MemoryProfilingResult: - """Memory profiling result. All numbers are in bytes. - """ + """Memory profiling result. All numbers are in bytes.""" + non_kv_cache_memory: int = 0 torch_peak_increase: int = 0 non_torch_increase: int = 0 @@ -2813,20 +2869,22 @@ class MemoryProfilingResult: profile_time: float = 0.0 def __repr__(self) -> str: - return (f"Memory profiling takes {self.profile_time:.2f} seconds. " - f"Total non KV cache memory: " - f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; " - f"torch peak memory increase: " - f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; " - f"non-torch forward increase memory: " - f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; " - f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB.") + return ( + f"Memory profiling takes {self.profile_time:.2f} seconds. " + f"Total non KV cache memory: " + f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; " + f"torch peak memory increase: " + f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; " + f"non-torch forward increase memory: " + f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; " + f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB." + ) @contextlib.contextmanager def memory_profiling( - baseline_snapshot: MemorySnapshot, - weights_memory: int) -> Generator[MemoryProfilingResult, None, None]: + baseline_snapshot: MemorySnapshot, weights_memory: int +) -> Generator[MemoryProfilingResult, None, None]: """Memory profiling context manager. baseline_snapshot: the memory snapshot before the current vLLM instance. weights_memory: memory used by PyTorch when loading the model weights. @@ -2900,29 +2958,34 @@ def memory_profiling( non_torch_memory = result.non_torch_increase peak_activation_memory = result.torch_peak_increase - result.non_kv_cache_memory = non_torch_memory + peak_activation_memory + result.weights_memory # noqa + result.non_kv_cache_memory = ( + non_torch_memory + peak_activation_memory + result.weights_memory + ) # noqa # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501 def set_ulimit(target_soft_limit=65535): - if sys.platform.startswith('win'): + if sys.platform.startswith("win"): logger.info("Windows detected, skipping ulimit adjustment.") return import resource + resource_type = resource.RLIMIT_NOFILE current_soft, current_hard = resource.getrlimit(resource_type) if current_soft < target_soft_limit: try: - resource.setrlimit(resource_type, - (target_soft_limit, current_hard)) + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) except ValueError as e: logger.warning( "Found ulimit of %s and failed to automatically increase " "with error %s. 
This can cause fd limit errors like " "`OSError: [Errno 24] Too many open files`. Consider " - "increasing with ulimit -n", current_soft, e) + "increasing with ulimit -n", + current_soft, + e, + ) # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501 @@ -3043,11 +3106,7 @@ def zmq_socket_ctx( ctx = zmq.Context() # type: ignore[attr-defined] try: - yield make_zmq_socket(ctx, - path, - socket_type, - bind=bind, - identity=identity) + yield make_zmq_socket(ctx, path, socket_type, bind=bind, identity=identity) except KeyboardInterrupt: logger.debug("Got Keyboard Interrupt.") @@ -3068,6 +3127,7 @@ def _maybe_force_spawn(): # to the subprocess so that it knows how to connect to the ray cluster. # env vars are inherited by subprocesses, even if we use spawn. import ray + os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address reasons.append("In a Ray actor and can only be spawned") @@ -3082,7 +3142,9 @@ def _maybe_force_spawn(): "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. " "See https://docs.vllm.ai/en/latest/usage/" "troubleshooting.html#python-multiprocessing " - "for more information. Reasons: %s", "; ".join(reasons)) + "for more information. Reasons: %s", + "; ".join(reasons), + ) os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" @@ -3101,7 +3163,7 @@ def get_mp_context(): def bind_kv_cache( ctx: dict[str, Any], kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] - shared_kv_cache_layers: Optional[dict[str, str]] = None + shared_kv_cache_layers: Optional[dict[str, str]] = None, ) -> None: # Bind the kv_cache tensor to Attention modules, similar to # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] @@ -3119,33 +3181,40 @@ def bind_kv_cache( shared_kv_cache_layers = {} from vllm.attention import AttentionType from vllm.model_executor.models.utils import extract_layer_index + layer_need_kv_cache = [ - layer_name for layer_name in ctx - if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type - in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) \ - and ctx[layer_name].kv_sharing_target_layer_name is None + layer_name + for layer_name in ctx + if ( + hasattr(ctx[layer_name], "attn_type") + and ctx[layer_name].attn_type + in (AttentionType.DECODER, AttentionType.ENCODER_DECODER) + ) + and ctx[layer_name].kv_sharing_target_layer_name is None ] layer_index_sorted = sorted( - set( - extract_layer_index(layer_name) - for layer_name in layer_need_kv_cache)) + set(extract_layer_index(layer_name) for layer_name in layer_need_kv_cache) + ) for layer_name in layer_need_kv_cache: - kv_cache_idx = layer_index_sorted.index( - extract_layer_index(layer_name)) + kv_cache_idx = layer_index_sorted.index(extract_layer_index(layer_name)) forward_ctx = ctx[layer_name] assert len(forward_ctx.kv_cache) == len(kv_cache) for ve, ve_kv_cache in enumerate(kv_cache): forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] if shared_kv_cache_layers is not None: for layer_name, target_layer_name in shared_kv_cache_layers.items(): - assert extract_layer_index(target_layer_name) < \ - extract_layer_index(layer_name), \ - "v0 doesn't support interleaving kv sharing" + assert extract_layer_index(target_layer_name) < extract_layer_index( + layer_name + ), "v0 doesn't support interleaving kv sharing" ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache -def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any], - kwargs: dict[str, Any]) -> Any: +def run_method( + obj: 
Any, + method: Union[str, bytes, Callable], + args: tuple[Any], + kwargs: dict[str, Any], +) -> Any: """ Run a method of an object with the given arguments and keyword arguments. If the method is string, it will be converted to a method using getattr. @@ -3159,8 +3228,9 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any], try: func = getattr(obj, method) except AttributeError: - raise NotImplementedError(f"Method {method!r} is not" - " implemented.") from None + raise NotImplementedError( + f"Method {method!r} is not implemented." + ) from None else: func = partial(method, obj) # type: ignore return func(*args, **kwargs) @@ -3194,6 +3264,7 @@ def import_pynvml(): module to our codebase, and use it directly. """ import vllm.third_party.pynvml as pynvml + return pynvml @@ -3213,7 +3284,7 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]: unimplemented_methods = [] for attr_name in dir(self): # bypass inner method - if attr_name.startswith('_'): + if attr_name.startswith("_"): continue try: @@ -3227,8 +3298,8 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]: if "NotImplementedError" in src: unimplemented_methods.append(attr_name) if unimplemented_methods: - method_names = ','.join(unimplemented_methods) - msg = (f"Methods {method_names} not implemented in {self}") + method_names = ",".join(unimplemented_methods) + msg = f"Methods {method_names} not implemented in {self}" logger.debug(msg) @wraps(original_init) @@ -3236,7 +3307,7 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]: original_init(self, *args, **kwargs) find_unimplemented_methods(self) - type.__setattr__(cls, '__init__', wrapped_init) + type.__setattr__(cls, "__init__", wrapped_init) return cls @@ -3340,7 +3411,6 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True): """ def decorator(func: Callable): - @wraps(func) def wrapper(*args, **kwargs): if not enabled: @@ -3358,16 +3428,26 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True): # Only relevant for models using ALiBi (e.g, MPT) def check_use_alibi(model_config: ModelConfig) -> bool: cfg = model_config.hf_text_config - return (getattr(cfg, "alibi", False) # Falcon - or ("BloomForCausalLM" in getattr(model_config.hf_config, - "architectures", [])) # Bloom - or getattr(cfg, "position_encoding_type", "") == - "alibi" # codellm_1b_alibi - or (hasattr(cfg, "attn_config") # MPT - and ((isinstance(cfg.attn_config, dict) - and cfg.attn_config.get("alibi", False)) or - (not isinstance(cfg.attn_config, dict) - and getattr(cfg.attn_config, "alibi", False))))) + return ( + getattr(cfg, "alibi", False) # Falcon + or ( + "BloomForCausalLM" in getattr(model_config.hf_config, "architectures", []) + ) # Bloom + or getattr(cfg, "position_encoding_type", "") == "alibi" # codellm_1b_alibi + or ( + hasattr(cfg, "attn_config") # MPT + and ( + ( + isinstance(cfg.attn_config, dict) + and cfg.attn_config.get("alibi", False) + ) + or ( + not isinstance(cfg.attn_config, dict) + and getattr(cfg.attn_config, "alibi", False) + ) + ) + ) + ) def sha256(input: Any) -> bytes: @@ -3435,7 +3515,7 @@ def is_torch_equal_or_newer(target: str) -> bool: return _is_torch_equal_or_newer(str(torch.__version__), target) except Exception: # Fallback to PKG-INFO to load the package info, needed by the doc gen. - return Version(importlib.metadata.version('torch')) >= Version(target) + return Version(importlib.metadata.version("torch")) >= Version(target) # Helper function used in testing. 
@@ -3484,9 +3564,9 @@ def has_tilelang() -> bool: return _has_module("tilelang") -def set_process_title(name: str, - suffix: str = "", - prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None: +def set_process_title( + name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX +) -> None: """ Set the current process title to a specific name with an optional suffix. @@ -3513,7 +3593,7 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: if file.start_new_line: # type: ignore[attr-defined] file_write(prefix) idx = 0 - while (next_idx := s.find('\n', idx)) != -1: + while (next_idx := s.find("\n", idx)) != -1: next_idx += 1 file_write(s[idx:next_idx]) if next_idx == len(s): @@ -3557,23 +3637,20 @@ def length_from_prompt_token_ids_or_embeds( """Calculate the request length (in number of tokens) give either prompt_token_ids or prompt_embeds. """ - prompt_token_len = None if prompt_token_ids is None else len( - prompt_token_ids) - prompt_embeds_len = \ - None if prompt_embeds is None else len(prompt_embeds) + prompt_token_len = None if prompt_token_ids is None else len(prompt_token_ids) + prompt_embeds_len = None if prompt_embeds is None else len(prompt_embeds) if prompt_token_len is None: if prompt_embeds_len is None: - raise ValueError( - "Neither prompt_token_ids nor prompt_embeds were defined.") + raise ValueError("Neither prompt_token_ids nor prompt_embeds were defined.") return prompt_embeds_len else: - if (prompt_embeds_len is not None - and prompt_embeds_len != prompt_token_len): + if prompt_embeds_len is not None and prompt_embeds_len != prompt_token_len: raise ValueError( "Prompt token ids and prompt embeds had different lengths" f" prompt_token_ids={prompt_token_len}" - f" prompt_embeds={prompt_embeds_len}") + f" prompt_embeds={prompt_embeds_len}" + ) return prompt_token_len diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 125508bc4a..ac4fcc0156 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -4,6 +4,7 @@ Users of vLLM should always import **only** these wrappers. """ + from __future__ import annotations import functools @@ -26,9 +27,14 @@ def is_deep_gemm_supported() -> bool: """ is_supported_arch = current_platform.is_cuda() and ( current_platform.is_device_capability(90) - or current_platform.is_device_capability(100)) - return (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch - and not envs.VLLM_USE_FLASHINFER_MOE_FP8) + or current_platform.is_device_capability(100) + ) + return ( + envs.VLLM_USE_DEEP_GEMM + and has_deep_gemm() + and is_supported_arch + and not envs.VLLM_USE_FLASHINFER_MOE_FP8 + ) @functools.cache @@ -38,7 +44,8 @@ def is_deep_gemm_e8m0_used() -> bool: """ if not is_deep_gemm_supported(): logger.debug_once( - "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.") + "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system." 
+ ) return False _lazy_init() @@ -51,13 +58,14 @@ def is_deep_gemm_e8m0_used() -> bool: logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.") return False - if current_platform.is_device_capability(100) and \ - envs.VLLM_USE_DEEP_GEMM_E8M0: + if current_platform.is_device_capability(100) and envs.VLLM_USE_DEEP_GEMM_E8M0: logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.") return True - if current_platform.is_device_capability(90) and \ - envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER: + if ( + current_platform.is_device_capability(90) + and envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER + ): logger.info_once("DeepGEMM E8M0 enabled on Hopper GPU.") return True @@ -69,7 +77,8 @@ def _missing(*_: Any, **__: Any) -> NoReturn: """Placeholder for unavailable DeepGEMM backend.""" raise RuntimeError( "DeepGEMM backend is not available or outdated. Please install or " - "update the `deep_gemm` to a newer version to enable FP8 kernels.") + "update the `deep_gemm` to a newer version to enable FP8 kernels." + ) _fp8_gemm_nt_impl: Callable[..., Any] | None = None @@ -89,21 +98,25 @@ def _lazy_init() -> None: global _get_mn_major_tma_aligned_tensor_impl # fast path - if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None - or _grouped_masked_impl is not None - or _fp8_mqa_logits_impl is not None - or _fp8_paged_mqa_logits_impl is not None - or _get_paged_mqa_logits_metadata_impl is not None): + if ( + _fp8_gemm_nt_impl is not None + or _grouped_impl is not None + or _grouped_masked_impl is not None + or _fp8_mqa_logits_impl is not None + or _fp8_paged_mqa_logits_impl is not None + or _get_paged_mqa_logits_metadata_impl is not None + ): return if not has_deep_gemm(): return # Set up deep_gemm cache path - DEEP_GEMM_JIT_CACHE_ENV_NAME = 'DG_JIT_CACHE_DIR' + DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR" if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None): os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join( - envs.VLLM_CACHE_ROOT, "deep_gemm") + envs.VLLM_CACHE_ROOT, "deep_gemm" + ) _dg = importlib.import_module("deep_gemm") @@ -113,9 +126,11 @@ def _lazy_init() -> None: _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None) _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None) _get_paged_mqa_logits_metadata_impl = getattr( - _dg, "get_paged_mqa_logits_metadata", None) + _dg, "get_paged_mqa_logits_metadata", None + ) _get_mn_major_tma_aligned_tensor_impl = getattr( - _dg, "get_mn_major_tma_aligned_tensor", None) + _dg, "get_mn_major_tma_aligned_tensor", None + ) def get_num_sms() -> int: @@ -148,9 +163,9 @@ def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): _lazy_init() if _grouped_impl is None: return _missing(*args, **kwargs) - return _grouped_impl(*args, - disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), - **kwargs) + return _grouped_impl( + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs + ) def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): @@ -158,7 +173,8 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): if _grouped_masked_impl is None: return _missing(*args, **kwargs) return _grouped_masked_impl( - *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs) + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs + ) def fp8_mqa_logits( @@ -191,8 +207,9 @@ def fp8_mqa_logits( return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) -def get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_size: int, - num_sms: int) -> torch.Tensor: +def get_paged_mqa_logits_metadata( + context_lens: 
torch.Tensor, block_size: int, num_sms: int +) -> torch.Tensor: """Build scheduling metadata for paged MQA logits. Args: @@ -208,8 +225,7 @@ def get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_size: int, _lazy_init() if _get_paged_mqa_logits_metadata_impl is None: return _missing() - return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, - num_sms) + return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms) def fp8_paged_mqa_logits( @@ -245,14 +261,16 @@ def fp8_paged_mqa_logits( _lazy_init() if _fp8_paged_mqa_logits_impl is None: return _missing() - return _fp8_paged_mqa_logits_impl(q_fp8, - kv_cache_fp8, - weights, - context_lens, - block_tables, - schedule_metadata, - max_model_len, - clean_logits=True) + return _fp8_paged_mqa_logits_impl( + q_fp8, + kv_cache_fp8, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=True, + ) def _ceil_to_ue8m0(x: torch.Tensor): @@ -269,15 +287,14 @@ DEFAULT_BLOCK_SIZE = [128, 128] # Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38 @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def per_block_cast_to_fp8( - x: torch.Tensor, - block_size: list[int] = DEFAULT_BLOCK_SIZE, - use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False +) -> tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape block_m, block_n = block_size - x_padded = torch.zeros((_align(m, block_m), _align(n, block_n)), - dtype=x.dtype, - device=x.device) + x_padded = torch.zeros( + (_align(m, block_m), _align(n, block_n)), dtype=x.dtype, device=x.device + ) x_padded[:m, :n] = x x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) @@ -285,7 +302,8 @@ def per_block_cast_to_fp8( sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( - x_view.size(0), x_view.size(2)) + x_view.size(0), x_view.size(2) + ) def calc_diff(x: torch.Tensor, y: torch.Tensor): @@ -305,13 +323,18 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): def should_use_deepgemm_for_fp8_linear( - output_dtype: torch.dtype, - weight: torch.Tensor, - supports_deep_gemm: Optional[bool] = None): + output_dtype: torch.dtype, + weight: torch.Tensor, + supports_deep_gemm: Optional[bool] = None, +): if supports_deep_gemm is None: supports_deep_gemm = is_deep_gemm_supported() - return (supports_deep_gemm and output_dtype == torch.bfloat16 - and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) + return ( + supports_deep_gemm + and output_dtype == torch.bfloat16 + and weight.shape[0] % 128 == 0 + and weight.shape[1] % 128 == 0 + ) __all__ = [ diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 734cd93879..22dfbe60f8 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -4,6 +4,7 @@ Users of vLLM should always import **only** these wrappers. """ + from __future__ import annotations import contextlib @@ -44,7 +45,8 @@ def _missing(*_: Any, **__: Any) -> NoReturn: raise RuntimeError( "FlashInfer backend is not available. 
Please install the package " "to enable FlashInfer kernels: " - "https://github.com/flashinfer-ai/flashinfer") + "https://github.com/flashinfer-ai/flashinfer" + ) def _get_submodule(module_name: str) -> Any | None: @@ -56,9 +58,9 @@ def _get_submodule(module_name: str) -> Any | None: # General lazy import wrapper -def _lazy_import_wrapper(module_name: str, - attr_name: str, - fallback_fn: Callable[..., Any] = _missing): +def _lazy_import_wrapper( + module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing +): """Create a lazy import wrapper for a specific function.""" @functools.cache @@ -79,29 +81,34 @@ def _lazy_import_wrapper(module_name: str, # Create lazy wrappers for each function flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper( - "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe") + "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe" +) flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper( - "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe") -flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe", - "cutlass_fused_moe") + "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe" +) +flashinfer_cutlass_fused_moe = _lazy_import_wrapper( + "flashinfer.fused_moe", "cutlass_fused_moe" +) fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") nvfp4_block_scale_interleave = _lazy_import_wrapper( - "flashinfer", "nvfp4_block_scale_interleave") + "flashinfer", "nvfp4_block_scale_interleave" +) trtllm_fp4_block_scale_moe = _lazy_import_wrapper( - "flashinfer", "trtllm_fp4_block_scale_moe") + "flashinfer", "trtllm_fp4_block_scale_moe" +) # Special case for autotune since it returns a context manager autotune = _lazy_import_wrapper( "flashinfer.autotuner", "autotune", - fallback_fn=lambda *args, **kwargs: contextlib.nullcontext()) + fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(), +) @functools.cache def has_flashinfer_comm() -> bool: """Return ``True`` if FlashInfer comm module is available.""" - return has_flashinfer() and importlib.util.find_spec( - "flashinfer.comm") is not None + return has_flashinfer() and importlib.util.find_spec("flashinfer.comm") is not None @functools.cache @@ -128,8 +135,10 @@ def has_flashinfer_all2all() -> bool: @functools.cache def has_flashinfer_moe() -> bool: """Return ``True`` if FlashInfer MoE module is available.""" - return has_flashinfer() and importlib.util.find_spec( - "flashinfer.fused_moe") is not None + return ( + has_flashinfer() + and importlib.util.find_spec("flashinfer.fused_moe") is not None + ) @functools.cache @@ -174,7 +183,8 @@ def has_nvidia_artifactory() -> bool: else: logger.warning_once( "NVIDIA artifactory returned failed status code: %d", - response.status_code) + response.status_code, + ) return accessible except Exception as e: logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e) @@ -188,8 +198,7 @@ def supports_trtllm_attention() -> bool: NVIDIA artifactory is accessible """ # Requires SM100 and NVIDIA artifactory to be accessible to download cubins - return current_platform.is_device_capability( - 100) and has_nvidia_artifactory() + return current_platform.is_device_capability(100) and has_nvidia_artifactory() @functools.cache @@ -238,7 +247,8 @@ def use_trtllm_attention( if force_use_trtllm: logger.warning_once( "TRTLLM attention is not supported on this platform, " - "but VLLM_USE_TRTLLM_ATTENTION is set to 1") + "but VLLM_USE_TRTLLM_ATTENTION is set to 1" + ) return False # The combination of query and key heads is 
not supported @@ -252,8 +262,7 @@ def use_trtllm_attention( if has_spec and not is_prefill: # Speculative decoding requires TRTLLM attention for decodes - logger.info_once( - "Using TRTLLM attention (enabled for speculative decoding).") + logger.info_once("Using TRTLLM attention (enabled for speculative decoding).") return True # Must use TRTLLM attention if query is FP8 quantized @@ -261,28 +270,28 @@ def use_trtllm_attention( if has_sinks: raise RuntimeError( "TRTLLM FP8-qkv kernel is not supported for attention sinks. " - "Use kv_cache_dtype=auto for now.") + "Use kv_cache_dtype=auto for now." + ) logger.info_once("Using TRTLLM attention (query is quantized).") return True # If sinks are being used, we must use TRTLLM attention as it's # the only backend that supports them if has_sinks: - logger.info_once( - "Using TRTLLM attention (required for attention sinks).") + logger.info_once("Using TRTLLM attention (required for attention sinks).") return True if force_use_trtllm is None: # Environment variable not set - use auto-detection - use_trtllm = (num_tokens <= 256 and max_seq_len <= 131072 - and kv_cache_dtype == "auto") + use_trtllm = ( + num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto" + ) if use_trtllm: logger.warning_once("Using TRTLLM attention (auto-detected).") return use_trtllm # Environment variable is set to 1 - respect it - logger.info_once( - "Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)") + logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)") return True @@ -303,16 +312,14 @@ if has_flashinfer(): backend: str, ) -> torch.Tensor: from flashinfer import mm_fp4 as flashinfer_mm_fp4_ - return flashinfer_mm_fp4_(A, - B, - A_scale, - B_scale, - g_scale, - dtype, - block_size=16, - backend=backend) - @torch.library.register_fake("vllm::flashinfer_mm_fp4", ) + return flashinfer_mm_fp4_( + A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend + ) + + @torch.library.register_fake( + "vllm::flashinfer_mm_fp4", + ) def flashinfer_mm_fp4_fake( A: torch.Tensor, B: torch.Tensor, @@ -322,10 +329,7 @@ if has_flashinfer(): dtype: torch.dtype, backend: str, ) -> torch.Tensor: - return torch.empty(A.shape[0], - B.shape[1], - dtype=dtype, - device=A.device) + return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device) @torch.library.custom_op( "vllm::bmm_fp8", @@ -341,9 +345,12 @@ if has_flashinfer(): backend: str, ) -> torch.Tensor: from flashinfer import bmm_fp8 as bmm_fp8_ + return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend) - @torch.library.register_fake("vllm::bmm_fp8", ) + @torch.library.register_fake( + "vllm::bmm_fp8", + ) def bmm_fp8_fake( A: torch.Tensor, B: torch.Tensor, @@ -352,18 +359,20 @@ if has_flashinfer(): dtype: torch.dtype, backend: str, ) -> torch.Tensor: - return torch.empty(A.shape[0], - A.shape[1], - B.shape[2], - dtype=dtype, - device=A.device) + return torch.empty( + A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device + ) -def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, - block_scale_a: torch.Tensor, - block_scale_b: torch.Tensor, alpha: torch.Tensor, - out_dtype: torch.dtype, - backend: str) -> torch.Tensor: +def flashinfer_scaled_fp4_mm( + a: torch.Tensor, + b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, + alpha: torch.Tensor, + out_dtype: torch.dtype, + backend: str, +) -> torch.Tensor: assert a.ndim == 2 and b.ndim == 2 assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 assert 
a.stride(-1) == 1 and b.stride(-1) == 1 @@ -387,12 +396,13 @@ def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, def flashinfer_scaled_fp8_mm( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: assert a.ndim == 2 and b.ndim == 2 assert a.shape[1] == b.shape[0] assert scale_a.numel() == 1 and scale_b.numel() == 1 diff --git a/vllm/utils/gc_utils.py b/vllm/utils/gc_utils.py index 8ce2c200e2..e3b5b61dd3 100644 --- a/vllm/utils/gc_utils.py +++ b/vllm/utils/gc_utils.py @@ -36,8 +36,7 @@ class GCDebugConfig: self.top_objects = json_conf.get("top_objects", -1) except Exception: self.enabled = False - logger.error("Failed to parse VLLM_GC_DEBUG(%s)", - VLLM_GC_DEBUG) + logger.error("Failed to parse VLLM_GC_DEBUG(%s)", VLLM_GC_DEBUG) logger.info("GC Debug Config. %s", str(self)) def __repr__(self) -> str: @@ -70,7 +69,8 @@ class GCDebugger: # and top collected objects self.start_time_ns = time.monotonic_ns() self.gc_top_collected_objects = _compute_top_gc_collected_objects( - gc.get_objects(generation), self.config.top_objects) + gc.get_objects(generation), self.config.top_objects + ) elif phase == "stop": # After GC finished, Record GC elapsed time and # optionally top collected objects @@ -81,8 +81,11 @@ class GCDebugger: elpased_ms, str(info.get("collected", "?")), generation, - (f" Top collected objects: \n{self.gc_top_collected_objects}" - if self.gc_top_collected_objects else ""), + ( + f" Top collected objects: \n{self.gc_top_collected_objects}" + if self.gc_top_collected_objects + else "" + ), ) @@ -125,4 +128,5 @@ def _compute_top_gc_collected_objects(objects: list[Any], top: int) -> str: object_types = [_compute_detailed_type(o) for o in objects] return "\n".join( f"{count:>5}:{object_type}" - for object_type, count in Counter(object_types).most_common(top)) + for object_type, count in Counter(object_types).most_common(top) + ) diff --git a/vllm/utils/jsontree.py b/vllm/utils/jsontree.py index 7eb58b5f5c..dcdc6ccb4c 100644 --- a/vllm/utils/jsontree.py +++ b/vllm/utils/jsontree.py @@ -52,40 +52,35 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]: def json_map_leaves( func: Callable[["torch.Tensor"], "torch.Tensor"], value: "BatchedTensorInputs", -) -> "BatchedTensorInputs": - ... +) -> "BatchedTensorInputs": ... @overload def json_map_leaves( func: Callable[[_T], _U], value: Union[_T, dict[str, _T]], -) -> Union[_U, dict[str, _U]]: - ... +) -> Union[_U, dict[str, _U]]: ... @overload def json_map_leaves( func: Callable[[_T], _U], value: Union[_T, list[_T]], -) -> Union[_U, list[_U]]: - ... +) -> Union[_U, list[_U]]: ... @overload def json_map_leaves( func: Callable[[_T], _U], value: Union[_T, tuple[_T, ...]], -) -> Union[_U, tuple[_U, ...]]: - ... +) -> Union[_U, tuple[_U, ...]]: ... @overload def json_map_leaves( func: Callable[[_T], _U], value: JSONTree[_T], -) -> JSONTree[_U]: - ... +) -> JSONTree[_U]: ... def json_map_leaves( @@ -111,8 +106,7 @@ def json_reduce_leaves( func: Callable[[_T, _T], _T], value: Union[_T, dict[str, _T]], /, -) -> _T: - ... +) -> _T: ... @overload @@ -120,8 +114,7 @@ def json_reduce_leaves( func: Callable[[_T, _T], _T], value: Union[_T, list[_T]], /, -) -> _T: - ... +) -> _T: ... 
@overload @@ -129,8 +122,7 @@ def json_reduce_leaves( func: Callable[[_T, _T], _T], value: Union[_T, tuple[_T, ...]], /, -) -> _T: - ... +) -> _T: ... @overload @@ -138,8 +130,7 @@ def json_reduce_leaves( func: Callable[[_T, _T], _T], value: JSONTree[_T], /, -) -> _T: - ... +) -> _T: ... @overload @@ -148,15 +139,14 @@ def json_reduce_leaves( value: JSONTree[_T], initial: _U, /, -) -> _U: - ... +) -> _U: ... def json_reduce_leaves( - func: Callable[..., Union[_T, _U]], - value: _JSONTree[_T], - initial: _U = cast(_U, ...), # noqa: B008 - /, + func: Callable[..., Union[_T, _U]], + value: _JSONTree[_T], + initial: _U = cast(_U, ...), # noqa: B008 + /, ) -> Union[_T, _U]: """ Apply a function of two arguments cumulatively to each leaf in a diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py index 81daca7dfb..e17676ccf7 100644 --- a/vllm/utils/tensor_schema.py +++ b/vllm/utils/tensor_schema.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import (Annotated, Any, Optional, Union, get_args, get_origin, - get_type_hints) +from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints import torch @@ -11,7 +10,6 @@ logger = init_logger(__name__) class TensorShape: - def __init__( self, *dims: Union[int, str], @@ -37,8 +35,7 @@ class TensorShape: for dim in self.dims: if isinstance(dim, str): if dim in self.dynamic_dims: - dim_strs.append( - f"{dim}*") # Mark dynamic dimensions with * + dim_strs.append(f"{dim}*") # Mark dynamic dimensions with * else: dim_strs.append(dim) else: @@ -47,7 +44,6 @@ class TensorShape: class TensorSchema: - def __init__( self, *, @@ -101,12 +97,12 @@ class TensorSchema: return str(list(idxs)) def _validate_field( - self, - value: object, - field_name: str, - expected_shape: tuple[Union[int, str], ...], - dynamic_dims: set[str], - leading_idxs: tuple[int, ...] = (), + self, + value: object, + field_name: str, + expected_shape: tuple[Union[int, str], ...], + dynamic_dims: set[str], + leading_idxs: tuple[int, ...] = (), ) -> tuple[int, ...]: """Validate a field and return the actual shape.""" if isinstance(value, (int, float)): @@ -118,11 +114,13 @@ class TensorSchema: raise TypeError( f"{field_name}{self._fmt_indexer(leading_idxs)} is not " f"one of the expected types: int, float, Tensor, list, tuple. 
" - f"Got: {type(value)}") + f"Got: {type(value)}" + ) if len(value) == 0: - raise ValueError(f"{field_name}{self._fmt_indexer(leading_idxs)} " - f"is an empty sequence") + raise ValueError( + f"{field_name}{self._fmt_indexer(leading_idxs)} is an empty sequence" + ) # Ensure all tensors in the list have the same # shape, besides dynamic dimensions @@ -132,25 +130,26 @@ class TensorSchema: field_name, expected_shape[1:], dynamic_dims, - leading_idxs=leading_idxs + (i, ), + leading_idxs=leading_idxs + (i,), ) if i == 0: first_shape = shape elif not self._match_shape_with_dynamic( - shape, - first_shape, - expected_shape, - dynamic_dims, + shape, + first_shape, + expected_shape, + dynamic_dims, ): raise ValueError( f"{field_name}{self._fmt_indexer(leading_idxs)} " f"contains inconsistent shapes: {first_shape} " - f"(index 0) vs {shape} (index {i})") + f"(index 0) vs {shape} (index {i})" + ) # Treat the list as a stacked tensor: # shape = (len(list), *tensor.shape) - return (len(value), ) + first_shape + return (len(value),) + first_shape def _validate_tensor_shape_expected( self, @@ -163,31 +162,38 @@ class TensorSchema: """Validate that the actual tensor shape matches the expected shape.""" if len(actual_shape) != len(expected_shape): - raise ValueError(f"{field_name} has rank {len(actual_shape)} " - f"but expected {len(expected_shape)}. " - f"Expected shape: {expected_shape}, " - f"but got {actual_shape}") + raise ValueError( + f"{field_name} has rank {len(actual_shape)} " + f"but expected {len(expected_shape)}. " + f"Expected shape: {expected_shape}, " + f"but got {actual_shape}" + ) for i, dim in enumerate(expected_shape): if dim in dynamic_dims: continue elif isinstance(dim, int): if actual_shape[i] != dim: - raise ValueError(f"{field_name} dim[{i}] expected " - f"{dim}, got {actual_shape[i]}. " - f"Expected shape: {expected_shape}, " - f"but got {actual_shape}") + raise ValueError( + f"{field_name} dim[{i}] expected " + f"{dim}, got {actual_shape[i]}. 
" + f"Expected shape: {expected_shape}, " + f"but got {actual_shape}" + ) elif isinstance(dim, str): if dim in shape_env: if actual_shape[i] != shape_env[dim]: - raise ValueError(f"{field_name} dim[{i}] expected " - f"'{dim}'={shape_env[dim]}, got " - f"{actual_shape[i]}") + raise ValueError( + f"{field_name} dim[{i}] expected " + f"'{dim}'={shape_env[dim]}, got " + f"{actual_shape[i]}" + ) else: shape_env[dim] = actual_shape[i] else: - raise TypeError(f"{field_name} dim[{i}] has unsupported " - f"type: {type(dim)}") + raise TypeError( + f"{field_name} dim[{i}] has unsupported type: {type(dim)}" + ) def validate(self) -> None: type_hints = get_type_hints(self.__class__, include_extras=True) @@ -195,8 +201,7 @@ class TensorSchema: for field_name, field_type in type_hints.items(): # Check if field is missing - if (not hasattr(self, field_name) - or getattr(self, field_name) is None): + if not hasattr(self, field_name) or getattr(self, field_name) is None: # Check if field is marked as optional actual_type = field_type if get_origin(field_type) is Annotated: @@ -228,8 +233,12 @@ class TensorSchema: ) self._validate_tensor_shape_expected( - actual_shape, expected_shape, field_name, - shape_env, arg.dynamic_dims) + actual_shape, + expected_shape, + field_name, + shape_env, + arg.dynamic_dims, + ) def print_shapes(self) -> None: """Print TensorShape annotations for debugging.""" diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 369f706200..6e27e93c91 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -7,19 +7,26 @@ import numpy as np import torch from torch.nn.functional import scaled_dot_product_attention -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionType, + is_quantized_kv_cache, +) from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec try: import intel_extension_for_pytorch.llm.modules as ipex_modules + _use_ipex = True # AttributeError is to handle a bug in ipex # https://github.com/intel/intel-extension-for-pytorch/pull/813 @@ -41,15 +48,15 @@ class TorchSDPABackend(AttentionBackend): @classmethod def validate_head_size(cls, head_size: int) -> None: attn_impl = _get_paged_attn_impl() - is_valid, supported_head_sizes = attn_impl.validate_head_size( - head_size) + is_valid, supported_head_sizes = attn_impl.validate_head_size(head_size) if not is_valid: attn_type = cls.__name__.removesuffix("Backend") raise ValueError( f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." 
+ ) @staticmethod def get_name() -> str: @@ -76,7 +83,8 @@ class TorchSDPABackend(AttentionBackend): cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return _get_paged_attn_impl().get_kv_cache_shape( - num_blocks, block_size, num_kv_heads, head_size) + num_blocks, block_size, num_kv_heads, head_size + ) @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: @@ -86,6 +94,7 @@ class TorchSDPABackend(AttentionBackend): @dataclass class TorchSDPAMetadata(AttentionMetadata): """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. num_prefills: int # Number of prefill tokens. @@ -156,23 +165,27 @@ class TorchSDPAMetadata(AttentionMetadata): @property def is_all_encoder_attn_metadata_set(self): - ''' + """ All attention metadata required for encoder attention is set. - ''' - return ((self.encoder_seq_lens is not None) - and (self.encoder_seq_lens_tensor is not None) - and (self.max_encoder_seq_len is not None)) + """ + return ( + (self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None) + ) @property def is_all_cross_attn_metadata_set(self): - ''' + """ All attention metadata required for enc/dec cross-attention is set. Superset of encoder attention required metadata. - ''' - return (self.is_all_encoder_attn_metadata_set - and (self.cross_slot_mapping is not None) - and (self.cross_block_tables is not None)) + """ + return ( + self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None) + ) @property def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: @@ -190,7 +203,7 @@ class TorchSDPAMetadata(AttentionMetadata): self, attn_type: str, ): - ''' + """ Extract appropriate sequence lengths from attention metadata according to attention type. @@ -203,10 +216,12 @@ class TorchSDPAMetadata(AttentionMetadata): Returns: * Appropriate sequence lengths tensor for query * Appropriate sequence lengths tensor for key & value - ''' + """ - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): + if ( + attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY + ): seq_lens_q = self.seq_lens seq_lens_kv = self.seq_lens elif attn_type == AttentionType.ENCODER: @@ -223,7 +238,7 @@ class TorchSDPAMetadata(AttentionMetadata): self, attn_type: str, ) -> Optional[list[torch.Tensor]]: - ''' + """ Extract appropriate attention bias from attention metadata according to attention type. @@ -235,10 +250,12 @@ class TorchSDPAMetadata(AttentionMetadata): Returns: * Appropriate attention bias value given the attention type - ''' + """ - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): + if ( + attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY + ): return self.attn_bias elif attn_type == AttentionType.ENCODER: return self.encoder_attn_bias @@ -252,7 +269,7 @@ class TorchSDPAMetadata(AttentionMetadata): attn_bias: list[torch.Tensor], attn_type: str, ) -> None: - ''' + """ Update appropriate attention bias field of attention metadata, according to attention type. 
@@ -262,10 +279,12 @@ class TorchSDPAMetadata(AttentionMetadata): * attn_bias: The desired attention bias value * attn_type: encoder attention, decoder self-attention, encoder/decoder cross-attention - ''' + """ - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): + if ( + attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY + ): self.attn_bias = attn_bias elif attn_type == AttentionType.ENCODER: self.encoder_attn_bias = attn_bias @@ -278,7 +297,7 @@ class TorchSDPAMetadata(AttentionMetadata): self, attn_type: str, ) -> tuple: - ''' + """ The particular choice of sequence-length- and block-table-related attributes which should be extracted from attn_metadata is dependent on the type of attention operation. @@ -300,23 +319,30 @@ class TorchSDPAMetadata(AttentionMetadata): * Appropriate sequence-lengths tensor * Appropriate max sequence-length scalar * Appropriate block tables (or None) - ''' + """ - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): + if ( + attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY + ): # Decoder self-attention # Choose max_seq_len based on whether we are in prompt_run - return (self.decode_seq_lens_tensor, self.decode_max_seq_len, - self.decode_block_tables) + return ( + self.decode_seq_lens_tensor, + self.decode_max_seq_len, + self.decode_block_tables, + ) elif attn_type == AttentionType.ENCODER_DECODER: # Enc/dec cross-attention KVs match encoder sequence length; # cross-attention utilizes special "cross" block tables - return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, - self.cross_block_tables) + return ( + self.encoder_seq_lens_tensor, + self.max_encoder_seq_len, + self.cross_block_tables, + ) elif attn_type == AttentionType.ENCODER: # No block tables associated with encoder attention - return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, - None) + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, None) else: raise AttributeError(f"Invalid attention type {str(attn_type)}") @@ -324,8 +350,13 @@ class TorchSDPAMetadata(AttentionMetadata): class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): reorder_batch_threshold: int = 1 - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device) -> None: + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ) -> None: super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.scheduler_config = vllm_config.scheduler_config @@ -338,10 +369,12 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): ) self.seq_start_loc_np = self.seq_start_loc_cpu.numpy() - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> TorchSDPAMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> TorchSDPAMetadata: num_reqs = common_attn_metadata.num_reqs max_query_len = common_attn_metadata.max_query_len @@ -351,22 +384,27 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu query_start_loc_np = query_start_loc_cpu.numpy() - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ - split_decodes_and_prefills(common_attn_metadata, 
- decode_threshold=self.reorder_batch_threshold, - require_uniform=True) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=True, + ) + ) - max_prefill_seq_len = seq_lens_np[num_decodes:num_reqs].max().item( - ) if num_prefills > 0 else 0 - max_decode_seq_len = seq_lens_np[:num_decodes].max().item( - ) if num_prefills < num_reqs else 0 + max_prefill_seq_len = ( + seq_lens_np[num_decodes:num_reqs].max().item() if num_prefills > 0 else 0 + ) + max_decode_seq_len = ( + seq_lens_np[:num_decodes].max().item() if num_prefills < num_reqs else 0 + ) self.seq_start_loc_np[0] = 0 - np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1]) + np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1 : num_reqs + 1]) slot_mapping = common_attn_metadata.slot_mapping.long() block_table_tensor = common_attn_metadata.block_table_tensor query_start_loc_np = query_start_loc_cpu.numpy() - query_start_loc_np[num_decodes:num_reqs + 1] -= num_decode_tokens + query_start_loc_np[num_decodes : num_reqs + 1] -= num_decode_tokens attn_metadata = TorchSDPAMetadata( num_prefills=num_prefills, @@ -381,21 +419,20 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): chunked_prefill=self.scheduler_config.chunked_prefill_enabled, max_query_len=max_query_len, prefill_max_seq_len=max_prefill_seq_len, - prefill_query_start_loc=query_start_loc_cpu[num_decodes:num_reqs + - 1], # prefill - prefill_seq_start_loc=self.seq_start_loc_cpu[num_decodes:num_reqs + - 1], # prefill - prefill_block_tables=block_table_tensor[ - num_decodes:num_reqs], # prefill - query_start_loc=query_start_loc_cpu[:num_reqs + - 1], # for logits index + prefill_query_start_loc=query_start_loc_cpu[ + num_decodes : num_reqs + 1 + ], # prefill + prefill_seq_start_loc=self.seq_start_loc_cpu[ + num_decodes : num_reqs + 1 + ], # prefill + prefill_block_tables=block_table_tensor[num_decodes:num_reqs], # prefill + query_start_loc=query_start_loc_cpu[: num_reqs + 1], # for logits index ) return attn_metadata class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): - def __init__( self, num_heads: int, @@ -412,8 +449,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported in V0.") if logits_soft_cap is not None: - logger.warning_once("Torch SPDA does not support logits soft cap. " - "Outputs may be slightly off.") + logger.warning_once( + "Torch SPDA does not support logits soft cap. " + "Outputs may be slightly off." + ) self.paged_attn_impl = _get_paged_attn_impl() self.num_heads = num_heads self.head_size = head_size @@ -426,13 +465,15 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): self.kv_cache_dtype = kv_cache_dtype self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.need_mask = (self.alibi_slopes is not None - or self.sliding_window is not None) + self.need_mask = ( + self.alibi_slopes is not None or self.sliding_window is not None + ) if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex: raise NotImplementedError( "Torch SDPA backend FP8 KV cache requires " - "intel_extension_for_pytorch support.") + "intel_extension_for_pytorch support." 
+ ) self.attn_type = attn_type def forward( @@ -464,22 +505,28 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" - " for TorchSDPABackendImpl") + " for TorchSDPABackendImpl" + ) # For warming-up if attn_metadata is None: return query attn_type = self.attn_type - if (attn_type == AttentionType.ENCODER - and (not attn_metadata.is_all_encoder_attn_metadata_set)): - raise AttributeError("Encoder attention requires setting " - "encoder metadata attributes.") - elif (attn_type == AttentionType.ENCODER_DECODER - and (not attn_metadata.is_all_cross_attn_metadata_set)): - raise AttributeError("Encoder/decoder cross-attention " - "requires setting cross-attention " - "metadata attributes.") + if attn_type == AttentionType.ENCODER and ( + not attn_metadata.is_all_encoder_attn_metadata_set + ): + raise AttributeError( + "Encoder attention requires setting encoder metadata attributes." + ) + elif attn_type == AttentionType.ENCODER_DECODER and ( + not attn_metadata.is_all_cross_attn_metadata_set + ): + raise AttributeError( + "Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes." + ) # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -490,7 +537,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): else: assert value is None - if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0: # KV-cache during decoder-self- or # encoder-decoder-cross-attention, but not # during encoder attention. @@ -499,7 +546,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): # we still need to break out key_cache and value_cache # i.e. for later use by paged attention key_cache, value_cache = self.paged_attn_impl.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) + kv_cache, self.num_kv_heads, self.head_size + ) if (key is not None) and (value is not None): if attn_type == AttentionType.ENCODER_DECODER: @@ -512,8 +560,15 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): updated_slot_mapping = attn_metadata.slot_mapping self.paged_attn_impl.write_to_paged_cache( - key, value, key_cache, value_cache, updated_slot_mapping, - self.kv_cache_dtype, layer._k_scale, layer._v_scale) + key, + value, + key_cache, + value_cache, + updated_slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if attn_type != AttentionType.ENCODER: # Decoder self-attention supports chunked prefill. 
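
For readers tracing the TorchSDPAMetadataBuilderV1.build() hunk above, the decode/prefill bookkeeping is plain prefix-sum arithmetic. The following is a minimal NumPy sketch with made-up request sizes; only the np.cumsum call and the slice rebase mirror the builder, and every number here is hypothetical, not taken from the patch.

import numpy as np

# Hypothetical batch: 2 decode requests (1 new token each) followed by 2 prefills.
seq_lens_np = np.array([9, 17, 5, 3], dtype=np.int32)   # total KV length per request
query_lens_np = np.array([1, 1, 5, 3], dtype=np.int32)  # new tokens per request
num_reqs = seq_lens_np.shape[0]
num_decodes, num_decode_tokens = 2, 2

# seq_start_loc[i] is the offset at which request i's KV sequence begins.
seq_start_loc = np.zeros(num_reqs + 1, dtype=np.int32)
np.cumsum(seq_lens_np, out=seq_start_loc[1 : num_reqs + 1])

# query_start_loc over the whole batch, then rebased so the prefill slice
# starts at zero, mirroring
# `query_start_loc_np[num_decodes : num_reqs + 1] -= num_decode_tokens`.
query_start_loc = np.zeros(num_reqs + 1, dtype=np.int32)
np.cumsum(query_lens_np, out=query_start_loc[1 : num_reqs + 1])
query_start_loc[num_decodes : num_reqs + 1] -= num_decode_tokens

print(seq_start_loc)                  # [ 0  9 26 31 34]
print(query_start_loc)                # [0 1 0 5 8]
print(query_start_loc[num_decodes:])  # prefill-relative offsets: [0 5 8]
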
@@ -539,20 +594,18 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): if prefill_meta := attn_metadata.prefill_metadata: if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore assert attn_metadata.seq_lens is not None - self._run_sdpa_forward(output, - query, - key, - value, - prefill_meta, - attn_type=attn_type) + self._run_sdpa_forward( + output, query, key, value, prefill_meta, attn_type=attn_type + ) else: # prefix-enabled attention assert not self.need_mask import intel_extension_for_pytorch.llm.modules as ipex_modules + output = torch.empty_like(query) ipex_modules.PagedAttention.flash_attn_varlen_func( - output[prefill_meta.num_decode_tokens:, :, :], - query[prefill_meta.num_decode_tokens:, :, :], + output[prefill_meta.num_decode_tokens :, :, :], + query[prefill_meta.num_decode_tokens :, :, :], key_cache, value_cache, prefill_meta.prefill_query_start_loc, @@ -567,7 +620,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): if decode_meta := attn_metadata.decode_metadata: assert attn_type != AttentionType.ENCODER_ONLY, ( - "Encoder-only models should not have decode metadata.") + "Encoder-only models should not have decode metadata." + ) # Decoding run. ( seq_lens_arg, @@ -576,8 +630,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ) = decode_meta.get_seq_len_block_table_args(attn_type) self.paged_attn_impl.forward_decode( - output[:attn_metadata.num_decode_tokens, :, :], - query[:attn_metadata.num_decode_tokens, :, :], + output[: attn_metadata.num_decode_tokens, :, :], + query[: attn_metadata.num_decode_tokens, :, :], key_cache, value_cache, block_tables_arg, @@ -607,13 +661,15 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): if attn_masks is None: if self.alibi_slopes is not None: attn_masks = _make_alibi_bias( - self.alibi_slopes, query.dtype, - attn_metadata.seq_lens) # type: ignore + self.alibi_slopes, + query.dtype, + attn_metadata.seq_lens, # type: ignore + ) elif self.sliding_window is not None: assert attn_metadata.seq_lens is not None attn_masks = _make_sliding_window_bias( - attn_metadata.seq_lens, self.sliding_window, - query.dtype) # type: ignore + attn_metadata.seq_lens, self.sliding_window, query.dtype + ) else: seq_lens, _ = attn_metadata.get_seq_lens(attn_type) attn_masks = [None] * len(seq_lens) @@ -627,22 +683,26 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): key = key.repeat_interleave(self.num_queries_per_kv, dim=-3) value = value.repeat_interleave(self.num_queries_per_kv, dim=-3) - causal_attn = (attn_type == AttentionType.DECODER) + causal_attn = attn_type == AttentionType.DECODER seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) start_q, start_kv = 0, 0 - for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, - attn_masks): + for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, attn_masks): end_q = start_q + seq_len_q end_kv = start_kv + seq_len_kv - sub_out = scaled_dot_product_attention( - query[None, :, start_q:end_q, :], - key[None, :, start_kv:end_kv, :], - value[None, :, start_kv:end_kv, :], - attn_mask=mask, - dropout_p=0.0, - is_causal=causal_attn and mask is None, - scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + sub_out = ( + scaled_dot_product_attention( + query[None, :, start_q:end_q, :], + key[None, :, start_kv:end_kv, :], + value[None, :, start_kv:end_kv, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=causal_attn and mask is None, + scale=self.scale, + ) + .squeeze(0) + .movedim(query.dim() - 2, 0) + ) 
output[start_q:end_q, :, :] = sub_out start_q, start_kv = end_q, end_kv @@ -665,9 +725,11 @@ def _make_alibi_bias( num_heads = alibi_slopes.shape[0] bias = bias[None, :].repeat((num_heads, 1, 1)) bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) - inf_mask = torch.empty( - (1, seq_len, seq_len), - dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) + inf_mask = ( + torch.empty((1, seq_len, seq_len), dtype=bias.dtype) + .fill_(-torch.inf) + .triu_(diagonal=1) + ) attn_biases.append((bias + inf_mask).to(dtype)) return attn_biases @@ -696,7 +758,6 @@ def _make_sliding_window_bias( class _PagedAttention: - @staticmethod def validate_head_size(head_size: int) -> tuple[bool, list[int]]: SUPPORT_HS = [32, 64, 80, 96, 112, 128, 192, 256] @@ -723,8 +784,7 @@ class _PagedAttention: num_blocks = kv_cache.shape[1] key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, - -1, x) + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x) value_cache = kv_cache[1] value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) return key_cache, value_cache @@ -800,7 +860,6 @@ class _PagedAttention: class _IPEXPagedAttention(_PagedAttention): - @staticmethod def validate_head_size(head_size: int) -> tuple[bool, list[int]]: return True, [] @@ -833,8 +892,8 @@ class _IPEXPagedAttention(_PagedAttention): *args, ) -> None: ipex_modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, - slot_mapping.flatten().int()) + key, value, key_cache, value_cache, slot_mapping.flatten().int() + ) @staticmethod def forward_decode( @@ -854,17 +913,30 @@ class _IPEXPagedAttention(_PagedAttention): *args, ) -> None: block_size = value_cache.shape[2] - head_mapping = torch.arange( - 0, - num_kv_heads, - device="cpu", - dtype=torch.int32, - ).view(num_kv_heads, - 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() + head_mapping = ( + torch.arange( + 0, + num_kv_heads, + device="cpu", + dtype=torch.int32, + ) + .view(num_kv_heads, 1) + .repeat_interleave(query.size(1) // num_kv_heads) + .flatten() + ) ipex_modules.PagedAttention.single_query_cached_kv_attention( - output, query.contiguous(), key_cache, value_cache, head_mapping, - scale, block_tables, context_lens, block_size, max_context_len, - alibi_slopes) + output, + query.contiguous(), + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) def _get_paged_attn_impl(): diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f0770f7441..bb3dcddba3 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" + from dataclasses import dataclass from typing import Optional @@ -8,34 +9,43 @@ import numpy as np import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, + is_quantized_kv_cache, +) from vllm.attention.layer import Attention from vllm.attention.ops.merge_attn_states import merge_attn_states -from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version, - 
is_flash_attn_varlen_func_available) +from vllm.attention.utils.fa_utils import ( + flash_attn_supports_fp8, + get_flash_attn_version, + is_flash_attn_varlen_func_available, +) if is_flash_attn_varlen_func_available(): - from vllm.attention.utils.fa_utils import (flash_attn_varlen_func, - get_scheduler_metadata, - reshape_and_cache_flash) + from vllm.attention.utils.fa_utils import ( + flash_attn_varlen_func, + get_scheduler_metadata, + reshape_and_cache_flash, + ) from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.utils import cdiv -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - get_kv_cache_layout) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + get_kv_cache_layout, +) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) class FlashAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True supports_quant_query_input: bool = True @@ -56,7 +66,8 @@ class FlashAttentionBackend(AttentionBackend): f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: @@ -141,7 +152,8 @@ class FlashAttentionMetadata: def _get_sliding_window_configs( - vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: + vllm_config: VllmConfig, +) -> set[Optional[tuple[int, int]]]: """Get the set of all sliding window configs used in the model.""" sliding_window_configs: set[Optional[tuple[int, int]]] = set() layers = get_layers_from_vllm_config(vllm_config, Attention) @@ -151,8 +163,7 @@ def _get_sliding_window_configs( return sliding_window_configs -class FlashAttentionMetadataBuilder( - AttentionMetadataBuilder[FlashAttentionMetadata]): +class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetadata]): # FA3: # Supports full cudagraphs for all cases. # @@ -171,11 +182,19 @@ class FlashAttentionMetadataBuilder( # to FULL_AND_PIECEWISE. 
# TODO(luka, lucas): audit FA2 as part of: # https://github.com/vllm-project/vllm/issues/22945 - cudagraph_support = AttentionCGSupport.ALWAYS \ - if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH + cudagraph_support = ( + AttentionCGSupport.ALWAYS + if get_flash_attn_version() == 3 + else AttentionCGSupport.UNIFORM_BATCH + ) - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config @@ -183,18 +202,19 @@ class FlashAttentionMetadataBuilder( self.compilation_config = vllm_config.compilation_config self.num_heads_q = self.model_config.get_num_attention_heads( - self.parallel_config) - self.num_heads_kv = self.model_config.get_num_kv_heads( - self.parallel_config) + self.parallel_config + ) + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) self.kv_cache_dtype = kv_cache_spec.dtype self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.max_num_splits = 0 # No upper bound on the number of splits. - self.aot_schedule = (get_flash_attn_version() == 3) + self.aot_schedule = get_flash_attn_version() == 3 - self.use_full_cuda_graph = \ + self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) self.max_cudagraph_size = self.compilation_config.max_capture_size if self.use_full_cuda_graph and self.aot_schedule: @@ -202,8 +222,8 @@ class FlashAttentionMetadataBuilder( # This condition derives from FA3's internal heuristic. # TODO(woosuk): Support larger cudagraph sizes. raise ValueError( - "Capture size larger than 992 is not supported for " - "full cuda graph.") + "Capture size larger than 992 is not supported for full cuda graph." + ) self.scheduler_metadata = torch.zeros( vllm_config.scheduler_config.max_num_seqs + 1, @@ -213,19 +233,20 @@ class FlashAttentionMetadataBuilder( # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. - self.max_num_splits = ( - envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH) + self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. self.aot_sliding_window: Optional[tuple[int, int]] = None - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlashAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashAttentionMetadata: """ - fast_build disables AOT scheduling, used when there will be few + fast_build disables AOT scheduling, used when there will be few iterations i.e. spec-decode """ num_reqs = common_attn_metadata.num_reqs @@ -249,8 +270,7 @@ class FlashAttentionMetadataBuilder( # build() call so the layers are constructed (cannot populate) # in __init__. 
if aot_schedule: - sliding_window_configs = _get_sliding_window_configs( - self.vllm_config) + sliding_window_configs = _get_sliding_window_configs(self.vllm_config) if len(sliding_window_configs) == 1: sliding_window_config = sliding_window_configs.pop() if sliding_window_config is not None: @@ -260,20 +280,21 @@ class FlashAttentionMetadataBuilder( aot_schedule = False max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible - if self.use_full_cuda_graph and \ - num_actual_tokens <= self.max_cudagraph_size: + if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size: # NOTE(woosuk): Setting num_splits > 1 may increase the memory # usage, because the intermediate buffers of size [num_splits, # num_heads, num_tokens, head_size] are allocated. Therefore, # we only set num_splits when using cuda graphs. max_num_splits = self.max_num_splits - def schedule(batch_size, cu_query_lens, max_query_len, seqlens, - max_seq_len, causal): + def schedule( + batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + ): cache_dtype = self.cache_config.cache_dtype if cache_dtype.startswith("fp8"): qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( - cache_dtype) + cache_dtype + ) else: qkv_dtype = self.kv_cache_dtype if aot_schedule: @@ -297,39 +318,44 @@ class FlashAttentionMetadataBuilder( use_cascade = common_prefix_len > 0 if use_cascade: - cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.device + ) suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) prefix_scheduler_metadata = schedule( batch_size=1, cu_query_lens=cu_prefix_query_lens, max_query_len=num_actual_tokens, seqlens=prefix_kv_lens, max_seq_len=common_prefix_len, - causal=False) - scheduler_metadata = schedule(batch_size=num_reqs, - cu_query_lens=query_start_loc, - max_query_len=max_query_len, - seqlens=suffix_kv_lens, - max_seq_len=max_seq_len - - common_prefix_len, - causal=True) + causal=False, + ) + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=suffix_kv_lens, + max_seq_len=max_seq_len - common_prefix_len, + causal=True, + ) else: cu_prefix_query_lens = None prefix_kv_lens = None suffix_kv_lens = None prefix_scheduler_metadata = None - scheduler_metadata = schedule(batch_size=num_reqs, - cu_query_lens=query_start_loc, - max_query_len=max_query_len, - seqlens=seq_lens, - max_seq_len=max_seq_len, - causal=causal) + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=seq_lens, + max_seq_len=max_seq_len, + causal=causal, + ) # For FA3 + full cudagraph if self.use_full_cuda_graph and scheduler_metadata is not None: n = scheduler_metadata.shape[0] @@ -357,7 +383,8 @@ class FlashAttentionMetadataBuilder( suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, max_num_splits=max_num_splits, - causal=causal) + causal=causal, + ) return attn_metadata def use_cascade_attention(self, *args, **kwargs) -> bool: @@ -365,7 +392,6 @@ class FlashAttentionMetadataBuilder( class 
FlashAttentionImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -406,18 +432,20 @@ class FlashAttentionImpl(AttentionImpl): self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() - if is_quantized_kv_cache(self.kv_cache_dtype) \ - and not flash_attn_supports_fp8(): + if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): raise NotImplementedError( - "FlashAttention does not support fp8 kv-cache on this device.") + "FlashAttention does not support fp8 kv-cache on this device." + ) self.sinks = sinks if self.sinks is not None: assert self.vllm_flash_attn_version == 3, ( - "Sinks are only supported in FlashAttention 3") + "Sinks are only supported in FlashAttention 3" + ) assert self.sinks.shape[0] == num_heads, ( "Sinks must have the same number of heads as the number of " - "heads in the layer") + "heads in the layer" + ) def forward( self, @@ -450,8 +478,8 @@ class FlashAttentionImpl(AttentionImpl): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashAttentionImpl") + "fused output quantization is not yet supported for FlashAttentionImpl" + ) if attn_metadata is None: # Profiling run. @@ -474,11 +502,14 @@ class FlashAttentionImpl(AttentionImpl): if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): # For encoder attention, # we use direct Q, K, V tensors without caching - return self._forward_encoder_attention(query[:num_actual_tokens], - key[:num_actual_tokens], - value[:num_actual_tokens], - output[:num_actual_tokens], - attn_metadata, layer) + return self._forward_encoder_attention( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, + layer, + ) # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(0) @@ -486,8 +517,11 @@ class FlashAttentionImpl(AttentionImpl): # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached # in KV cache. - if (self.kv_sharing_target_layer_name is None and key is not None - and value is not None): + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. 
# NOTE(woosuk): Here, key and value are padded while slot_mapping is @@ -509,7 +543,8 @@ class FlashAttentionImpl(AttentionImpl): if self.kv_cache_dtype.startswith("fp8"): # queries are quantized in the attention layer dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( - self.kv_cache_dtype) + self.kv_cache_dtype + ) key_cache = key_cache.view(dtype) value_cache = value_cache.view(dtype) @@ -597,7 +632,8 @@ class FlashAttentionImpl(AttentionImpl): # For encoder attention, process FP8 quantization if needed if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError( - "quantization is not supported for encoder attention") + "quantization is not supported for encoder attention" + ) # Use encoder-specific metadata for sequence information cu_seqlens_q = attn_metadata.query_start_loc @@ -607,7 +643,8 @@ class FlashAttentionImpl(AttentionImpl): descale_shape = ( cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] - self.num_kv_heads) + self.num_kv_heads, + ) # Call flash attention directly on Q, K, V tensors flash_attn_varlen_func( @@ -670,8 +707,12 @@ def use_cascade_attention( num_queries_per_kv = num_query_heads // num_kv_heads # The criteria for using FlashDecoding can be found in the following link: # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 - use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window - and not use_alibi and np.all(query_lens == 1)) + use_flash_decoding = ( + num_queries_per_kv > 1 + and not use_sliding_window + and not use_alibi + and np.all(query_lens == 1) + ) if not use_flash_decoding: # Use cascade attention. return True @@ -693,8 +734,9 @@ def use_cascade_attention( cascade_waves = cdiv(cascade_ctas, num_sms) cascade_time = cascade_waves * num_prefix_tiles - flash_decoding_ctas = (num_reqs * num_kv_heads * - cdiv(num_queries_per_kv, q_tile_size)) + flash_decoding_ctas = ( + num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size) + ) flash_decoding_ctas *= num_prefix_tiles flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) @@ -726,10 +768,11 @@ def cascade_attention( k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - assert alibi_slopes is None, ("Cascade attention does not support ALiBi.") + assert alibi_slopes is None, "Cascade attention does not support ALiBi." # TODO: Support sliding window. assert sliding_window == (-1, -1), ( - "Cascade attention does not support sliding window.") + "Cascade attention does not support sliding window." 
+ ) num_tokens = query.shape[0] block_size = key_cache.shape[-3] @@ -755,12 +798,9 @@ def cascade_attention( return_softmax_lse=True, scheduler_metadata=prefix_scheduler_metadata, fa_version=fa_version, - q_descale=q_descale.expand(descale_shape) - if q_descale is not None else None, - k_descale=k_descale.expand(descale_shape) - if k_descale is not None else None, - v_descale=v_descale.expand(descale_shape) - if v_descale is not None else None, + q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, + k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, + v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, ) descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) @@ -782,14 +822,10 @@ def cascade_attention( return_softmax_lse=True, scheduler_metadata=suffix_scheduler_metadata, fa_version=fa_version, - q_descale=q_descale.expand(descale_shape) - if q_descale is not None else None, - k_descale=k_descale.expand(descale_shape) - if k_descale is not None else None, - v_descale=v_descale.expand(descale_shape) - if v_descale is not None else None, + q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, + k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, + v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, ) # Merge prefix and suffix outputs, and store the result in output. - merge_attn_states(output, prefix_output, prefix_lse, suffix_output, - suffix_lse) + merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 15a252734d..1c05a17db8 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashInfer.""" + from __future__ import annotations from dataclasses import dataclass @@ -8,36 +9,50 @@ from typing import ClassVar, Optional, Union import numpy as np import torch -from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, - BatchPrefillWithPagedKVCacheWrapper, - MultiLevelCascadeAttentionWrapper) +from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + MultiLevelCascadeAttentionWrapper, +) from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache from flashinfer.utils import FP4Tensor from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionType, +) from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym, kNvfp4Quant) + QuantKey, + kFp8StaticTensorSym, + kNvfp4Quant, +) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv, is_pin_memory_available -from vllm.utils.flashinfer import (can_use_trtllm_attention, - flashinfer_disable_q_quantization, - supports_trtllm_attention, - use_trtllm_attention) +from vllm.utils.flashinfer import ( + can_use_trtllm_attention, + flashinfer_disable_q_quantization, + supports_trtllm_attention, + 
use_trtllm_attention, +) + # yapf conflicts with isort for this block # yapf: disable -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - get_kv_cache_layout, - get_per_layer_parameters, - infer_global_hyperparameters, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + get_kv_cache_layout, + get_per_layer_parameters, + infer_global_hyperparameters, + split_decodes_and_prefills, +) + # yapf: enable from vllm.v1.kv_cache_interface import AttentionSpec @@ -55,7 +70,8 @@ def _get_trtllm_gen_workspace_buffer(): global trtllm_gen_workspace_buffer if trtllm_gen_workspace_buffer is None: trtllm_gen_workspace_buffer = torch.zeros( - FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device='cuda') + FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda" + ) return trtllm_gen_workspace_buffer @@ -72,9 +88,9 @@ def _trtllm_prefill_attn_kvfp8_dequant( ): batch_idx = tl.program_id(0).to(tl.int64) mock_block_table_idx = tl.program_id(1).to(tl.int64) - orig_page_num = tl.load(block_tables_prefill_ptr + - batch_idx * block_table_stride + - mock_block_table_idx).to(tl.int64) + orig_page_num = tl.load( + block_tables_prefill_ptr + batch_idx * block_table_stride + mock_block_table_idx + ).to(tl.int64) if orig_page_num <= 0: return dequant_dtype = mock_kv_cache_ptr.dtype.element_ty @@ -84,20 +100,24 @@ def _trtllm_prefill_attn_kvfp8_dequant( offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) fp8_vals = tl.load(kv_cache_ptr + offset) dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val - mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx - + 1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + mock_cache_offset = ( + batch_idx * block_table_stride + mock_block_table_idx + 1 + ) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) dequantized_vals = dequantized_vals.to(dequant_dtype) tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) # Dequantize V v_scale_val = tl.load(v_scale_ptr) - offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + - tl.arange(0, K_CACHE_STRIDE)) + offset = ( + orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + ) fp8_vals = tl.load(kv_cache_ptr + offset) dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val mock_cache_offset = ( - (batch_idx * block_table_stride + mock_block_table_idx + 1) * - KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)) + (batch_idx * block_table_stride + mock_block_table_idx + 1) * KV_CACHE_STRIDE + + K_CACHE_STRIDE + + tl.arange(0, K_CACHE_STRIDE) + ) dequantized_vals = dequantized_vals.to(dequant_dtype) tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) @@ -117,9 +137,7 @@ def trtllm_prefill_attn_kvfp8_dequant( kv_cache_stride = k_cache_stride * s[1] new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4]) # mock kv cache contains just the pages needed by this prefill - mock_kv_cache = torch.empty(new_s, - dtype=dequant_dtype, - device=kv_cache.device) + mock_kv_cache = torch.empty(new_s, dtype=dequant_dtype, device=kv_cache.device) # we simply sequentially index the pages needed by this prefill mock_block_table = torch.arange( start=1, @@ -162,7 +180,8 @@ class FlashInferBackend(AttentionBackend): f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. 
" "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: @@ -249,13 +268,19 @@ class FlashInferMetadata: class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ + cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) reorder_batch_threshold: int = 1 - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.cache_config = vllm_config.cache_config self.model_config = vllm_config.model_config @@ -264,22 +289,27 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self._decode_wrapper = None # Wrapper for decode (general shape) self.compilation_config = vllm_config.compilation_config - max_num_pages_per_req = cdiv(self.model_config.max_model_len, - self.kv_cache_spec.block_size) + max_num_pages_per_req = cdiv( + self.model_config.max_model_len, self.kv_cache_spec.block_size + ) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req - self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\ - decode_mode() == CUDAGraphMode.FULL) + self.enable_cuda_graph = ( + self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + ) if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. self._decode_wrappers_cudagraph: dict[ - int, BatchDecodeWithPagedKVCacheWrapper] = {} + int, BatchDecodeWithPagedKVCacheWrapper + ] = {} self._decode_cudagraph_max_bs = min( - max_num_reqs, self.compilation_config.max_capture_size) + max_num_reqs, self.compilation_config.max_capture_size + ) self.num_qo_heads = self.model_config.get_num_attention_heads( - self.vllm_config.parallel_config) + self.vllm_config.parallel_config + ) self.num_kv_heads = self.kv_cache_spec.num_kv_heads self.head_dim = self.kv_cache_spec.head_size FlashInferBackend.validate_head_size(self.head_dim) @@ -287,9 +317,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.cache_dtype = self.cache_config.cache_dtype if self.cache_dtype.startswith("fp8"): - self.kv_cache_dtype = ( - FlashInferBackend.get_fp8_dtype_for_flashinfer( - self.cache_dtype)) + self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.cache_dtype + ) else: assert self.kv_cache_spec.dtype == self.model_config.dtype self.kv_cache_dtype = self.kv_cache_spec.dtype @@ -298,14 +328,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. 
Otherwise, try to # use fp8 q if kv cache is fp8, and will fall back to model dtype # if TRTLLM attention kernel is not used when building attn metadata - if supports_trtllm_attention() and \ - not flashinfer_disable_q_quantization(): + if supports_trtllm_attention() and not flashinfer_disable_q_quantization(): self.q_data_type = self.kv_cache_dtype else: self.q_data_type = self.model_config.dtype - supports_spec_as_decode = \ - can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads) + supports_spec_as_decode = can_use_trtllm_attention( + self.num_qo_heads, self.num_kv_heads + ) self._init_reorder_batch_threshold(1, supports_spec_as_decode) self._cascade_wrapper = None # Wrapper for cascade attention @@ -313,7 +343,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # Global hyperparameters shared by all attention layers # TODO: discard this for trtllm-gen backend self.global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)) + get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl) + ) self.sm_scale = self.global_hyperparameters.sm_scale self.window_left = self.global_hyperparameters.window_left self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap @@ -322,67 +353,62 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): raise NotImplementedError( "FlashInfer backend currently does not support attention " "sinks, please use trtllm on blackwell or flash attention on " - "earlier GPUs.") + "earlier GPUs." + ) # Preparing persistent buffers (device-side) - self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, - dtype=torch.int32, - device=self.device) + self.paged_kv_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=self.device + ) self.paged_kv_indices = torch.zeros( max_num_pages, # max num pages possible dtype=torch.int32, - device=self.device) - self.paged_kv_last_page_len = torch.zeros(max_num_reqs, - dtype=torch.int32, - device=self.device) + device=self.device, + ) + self.paged_kv_last_page_len = torch.zeros( + max_num_reqs, dtype=torch.int32, device=self.device + ) # host-side buffer pin_memory = is_pin_memory_available() - self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1, - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) + self.paged_kv_indptr_cpu = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy() self.paged_kv_indptr_buffer = torch.zeros_like( - self.paged_kv_indptr_cpu, pin_memory=pin_memory) - self.paged_kv_indices_cpu = torch.zeros(max_num_pages, - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs, - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.paged_kv_last_page_len_np = ( - self.paged_kv_last_page_len_cpu.numpy()) + self.paged_kv_indptr_cpu, pin_memory=pin_memory + ) + self.paged_kv_indices_cpu = torch.zeros( + max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_last_page_len_cpu = torch.zeros( + max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy() def _get_workspace_buffer(self): if self._workspace_buffer is None: self._workspace_buffer = torch.zeros( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) + FLASHINFER_WORKSPACE_BUFFER_SIZE, 
dtype=torch.uint8, device=self.device + ) return self._workspace_buffer def _get_prefill_wrapper(self): if self._prefill_wrapper is None: self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), get_kv_cache_layout()) + self._get_workspace_buffer(), get_kv_cache_layout() + ) return self._prefill_wrapper - def _get_decode_wrapper(self, - batch_size: int, - use_cudagraph: bool = False): + def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False): if use_cudagraph: - decode_wrapper = self._decode_wrappers_cudagraph.get( - batch_size, None) + decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None) else: decode_wrapper = self._decode_wrapper if decode_wrapper is None: if use_cudagraph: - paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1] + paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1] paged_kv_indices = self.paged_kv_indices - paged_kv_last_page_len = self.paged_kv_last_page_len[: - batch_size] + paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size] else: paged_kv_indptr = None paged_kv_indices = None @@ -411,19 +437,25 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): def _get_cascade_wrapper(self): if self._cascade_wrapper is None: self._cascade_wrapper = MultiLevelCascadeAttentionWrapper( - 2, self._get_workspace_buffer(), get_kv_cache_layout()) + 2, self._get_workspace_buffer(), get_kv_cache_layout() + ) return self._cascade_wrapper - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlashInferMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashInferMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=self.reorder_batch_threshold, - require_uniform=True) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=True, + ) + ) page_size = self.page_size max_q_len = common_attn_metadata.max_query_len @@ -442,17 +474,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_common_kv_blocks = common_prefix_len // page_size # Create CPU versions directly for cascade (no GPU versions needed) - shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device='cpu') - shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks], - dtype=torch.int32, - device='cpu') - shared_kv_page_indices_cpu = block_table_tensor[ - 0, :num_common_kv_blocks] - shared_kv_last_page_len_cpu = torch.tensor([page_size], - dtype=torch.int32, - device='cpu') + shared_qo_indptr_cpu = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device="cpu" + ) + shared_kv_page_indptr_cpu = torch.tensor( + [0, num_common_kv_blocks], dtype=torch.int32, device="cpu" + ) + shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks] + shared_kv_last_page_len_cpu = torch.tensor( + [page_size], dtype=torch.int32, device="cpu" + ) # Remove the blocks of the shared prefix from all requests. 
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] @@ -467,22 +498,23 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): np.cumsum( num_blocks_np, dtype=np.int32, - out=self.paged_kv_indptr_np[1:num_reqs + 1], + out=self.paged_kv_indptr_np[1 : num_reqs + 1], ) # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified # after this line (e.g., for cuda graphs), we need to copy the data to # self.paged_kv_indptr_buffer to avoid race condition. - self.paged_kv_indptr_buffer[:num_reqs + - 1] = (self.paged_kv_indptr_cpu[:num_reqs + - 1]) - paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1] - paged_kv_indptr.copy_(self.paged_kv_indptr_buffer[:num_reqs + 1], - non_blocking=True) + self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[ + : num_reqs + 1 + ] + paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1] + paged_kv_indptr.copy_( + self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True + ) # write self.paged_kv_indices inplace num_actual_pages = self.paged_kv_indptr_np[num_reqs] paged_kv_indices = self.paged_kv_indices[:num_actual_pages] - _copy_page_indices_kernel[(num_reqs, )]( + _copy_page_indices_kernel[(num_reqs,)]( paged_kv_indices, block_table_tensor, block_table_tensor.stride(0), @@ -499,29 +531,34 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ) uses_spec_reorder = self.reorder_batch_threshold > 1 - prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads, - self.num_kv_heads, - num_prefill_tokens, - max_seq_len, - self.cache_dtype, - self.q_data_type, - is_prefill=True, - has_sinks=self.has_sinks, - has_spec=uses_spec_reorder) - decode_use_trtllm = use_trtllm_attention(self.num_qo_heads, - self.num_kv_heads, - num_decode_tokens, - max_seq_len, - self.cache_dtype, - self.q_data_type, - is_prefill=False, - has_sinks=self.has_sinks, - has_spec=uses_spec_reorder) + prefill_use_trtllm = use_trtllm_attention( + self.num_qo_heads, + self.num_kv_heads, + num_prefill_tokens, + max_seq_len, + self.cache_dtype, + self.q_data_type, + is_prefill=True, + has_sinks=self.has_sinks, + has_spec=uses_spec_reorder, + ) + decode_use_trtllm = use_trtllm_attention( + self.num_qo_heads, + self.num_kv_heads, + num_decode_tokens, + max_seq_len, + self.cache_dtype, + self.q_data_type, + is_prefill=False, + has_sinks=self.has_sinks, + has_spec=uses_spec_reorder, + ) if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm): raise NotImplementedError( "FlashInfer backend currently does not support attention " "sinks, please use trtllm on blackwell or flash attention on " - "earlier GPUs.") + "earlier GPUs." + ) # If TRTLLM attention is not used, the q quantization is not supported. # Fall back to use model dtype. 
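
A brief, self-contained illustration of the paged-KV indexing that the surrounding FlashInferMetadataBuilder hunks populate: paged_kv_indptr is a prefix sum over per-request page counts, and paged_kv_last_page_len holds how many tokens sit in each request's final page (kept in [1, page_size], since FlashInfer treats 0 as a full page rather than an empty one). The snippet below is a NumPy sketch with hypothetical lengths, not the builder's actual code path.

import numpy as np

page_size = 16                                    # hypothetical block size
seq_lens = np.array([5, 16, 33], dtype=np.int32)  # hypothetical KV lengths
num_reqs = seq_lens.shape[0]

# Pages needed per request (ceil division), then the prefix sum, mirroring
# `np.cumsum(num_blocks_np, out=self.paged_kv_indptr_np[1 : num_reqs + 1])`.
num_blocks_np = (seq_lens + page_size - 1) // page_size  # [1, 1, 3]
paged_kv_indptr = np.zeros(num_reqs + 1, dtype=np.int32)
np.cumsum(num_blocks_np, out=paged_kv_indptr[1 : num_reqs + 1])

# Tokens resident in each request's final page; a completely full final page
# reports page_size rather than 0.
paged_kv_last_page_len = (seq_lens - 1) % page_size + 1  # [5, 16, 1]

print(paged_kv_indptr)         # [0 1 2 5]
print(paged_kv_last_page_len)  # [ 5 16  1]
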
@@ -547,7 +584,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ) qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu - paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs] + paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs] paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs] if attn_metadata.use_cascade: @@ -578,17 +615,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # Decodes are first so prefills start after the last decode prefill_start = num_decodes attn_metadata.prefill_wrapper = self._get_prefill_wrapper() - assert qo_indptr_cpu[prefill_start:].shape[ - 0] == num_prefills + 1 - assert paged_kv_indptr_cpu[prefill_start:].shape[ - 0] == num_prefills + 1 - assert paged_kv_last_page_len_cpu[prefill_start:].shape[ - 0] == num_prefills + assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 + assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 + assert ( + paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills + ) # Since prefill_wrapper.run() will be called with # query[num_decode_tokens:] we need to adjust the qo_indptr # to be relative to the start of the prefill queries. - qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[ - prefill_start] + qo_indptr_cpu = ( + qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start] + ) paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] # Recompute max_q_len for the slice of requests we are using @@ -596,8 +633,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # we have a non-uniform batch with some short decodes offloaded # to the prefill pathway query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1] - attn_metadata.max_q_len_prefill = \ - int(query_lens_prefill.max().item()) + attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item()) if not attn_metadata.prefill_use_trtllm: attn_metadata.prefill_wrapper.plan( @@ -618,42 +654,50 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ) else: attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) if num_decodes > 0: pure_decode = num_prefills == 0 # possible required padding for cudagraph replay - use_cudagraph = (self.enable_cuda_graph and pure_decode and - num_decodes <= self._decode_cudagraph_max_bs) + use_cudagraph = ( + self.enable_cuda_graph + and pure_decode + and num_decodes <= self._decode_cudagraph_max_bs + ) if use_cudagraph: - num_input_tokens = ( - self.vllm_config.pad_for_cudagraph(num_decode_tokens)) + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_decode_tokens + ) # Carefully fulfill the padding region with reasonable value # on cpu. # Make sure paged_kv_indptr_cpu is not decreasing - self.paged_kv_indptr_cpu[1 + num_decodes:1 + - num_input_tokens].fill_( - paged_kv_indptr_cpu[-1]) + self.paged_kv_indptr_cpu[ + 1 + num_decodes : 1 + num_input_tokens + ].fill_(paged_kv_indptr_cpu[-1]) # Fill the remaining paged_kv_last_page_len_cpu with 1. # This is because flashinfer treats 0 as a full page # instead of empty. 
- self.paged_kv_last_page_len_cpu[ - num_decodes:num_input_tokens].fill_(1) + self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_( + 1 + ) else: num_input_tokens = num_decode_tokens attn_metadata.decode_wrapper = self._get_decode_wrapper( - num_input_tokens, use_cudagraph) + num_input_tokens, use_cudagraph + ) if not attn_metadata.decode_use_trtllm: # Use the persistent buffer with padding length, # instead of the same address but chunked version # in atten_metadata when using cudagraph. fast_plan_decode( attn_metadata.decode_wrapper, - self.paged_kv_indptr_cpu[:num_input_tokens + 1], + self.paged_kv_indptr_cpu[: num_input_tokens + 1], paged_kv_indices, self.paged_kv_last_page_len_cpu[:num_input_tokens], seq_lens_cpu[:num_input_tokens], @@ -682,7 +726,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): class FlashInferImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -708,8 +751,9 @@ class FlashInferImpl(AttentionImpl): self.sliding_window = (-1, -1) else: self.sliding_window = (sliding_window - 1, 0) - self.window_left = (self.sliding_window[0] - if self.sliding_window is not None else -1) + self.window_left = ( + self.sliding_window[0] if self.sliding_window is not None else -1 + ) self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name @@ -717,10 +761,12 @@ class FlashInferImpl(AttentionImpl): self.num_queries_per_kv = self.num_heads // self.num_kv_heads if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashInferImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl" + ) self.sinks: Optional[torch.Tensor] = None if sinks is not None: @@ -728,19 +774,23 @@ class FlashInferImpl(AttentionImpl): raise ValueError( "Sinks must have the same number of heads as the number of " f"heads in the layer. Expected {num_heads}, but got " - f"{sinks.shape[0]}.") + f"{sinks.shape[0]}." + ) self.sinks = sinks - self.support_trtllm_attn = (supports_trtllm_attention() - and num_heads % num_kv_heads == 0) + self.support_trtllm_attn = ( + supports_trtllm_attention() and num_heads % num_kv_heads == 0 + ) self.bmm1_scale: Optional[float] = None self.bmm2_scale: Optional[float] = None self.o_sf_scale: Optional[float] = None def fused_output_quant_supported(self, quant_key: QuantKey): - return (self.support_trtllm_attn - and self.kv_cache_dtype.startswith("fp8") - and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)) + return ( + self.support_trtllm_attn + and self.kv_cache_dtype.startswith("fp8") + and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) + ) def forward( self, @@ -774,28 +824,32 @@ class FlashInferImpl(AttentionImpl): return output if self.bmm1_scale is None: - self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float * - self.scale) + self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale if self.bmm2_scale is None: self.bmm2_scale = layer._v_scale_float # The attn+quant fusion happens when output_scale is provided. 
if output_scale is None: - assert output_block_scale is None, "output_block_scale "\ - "is not supported when fusion has not happened" + assert output_block_scale is None, ( + "output_block_scale is not supported when fusion has not happened" + ) else: - assert attn_metadata.q_data_type == FP8_DTYPE, \ + assert attn_metadata.q_data_type == FP8_DTYPE, ( "Query must be FP8 when attn+quant fusion happened." - assert (attn_metadata.prefill_use_trtllm and - attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn" + ) + assert ( + attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm + ), "Must use TRT-LLM attn" if output.dtype == FP8_DTYPE: - assert output_block_scale is None, \ + assert output_block_scale is None, ( "output_block_scale should not be provided for fp8 output" + ) elif output.dtype == FP4_DTYPE: - assert output_block_scale is not None, \ + assert output_block_scale is not None, ( "output_block_scale is required for nvfp4 output" + ) else: raise ValueError(f"Unsupported output dtype: {output.dtype}") @@ -813,9 +867,9 @@ class FlashInferImpl(AttentionImpl): if attn_metadata.q_data_type == FP8_DTYPE: num_tokens, num_heads, head_size = query.shape query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) + query.reshape((num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale, + ) query = query.reshape((num_tokens, num_heads, head_size)) # IMPORTANT! @@ -852,7 +906,8 @@ class FlashInferImpl(AttentionImpl): # to process the cache when the kv_cache_dtype is fp8 if self.kv_cache_dtype.startswith("fp8"): torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - self.kv_cache_dtype) + self.kv_cache_dtype + ) kv_cache = kv_cache.view(torch_dtype) # Inputs and outputs may be padded for CUDA graphs @@ -886,8 +941,7 @@ class FlashInferImpl(AttentionImpl): if not attn_metadata.prefill_use_trtllm: assert prefill_wrapper._causal assert prefill_wrapper._window_left == self.window_left - assert prefill_wrapper._logits_soft_cap == ( - self.logits_soft_cap or 0.0) + assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert prefill_wrapper._sm_scale == self.scale prefill_wrapper.run( prefill_query, @@ -900,8 +954,7 @@ class FlashInferImpl(AttentionImpl): # prefill_query may be non-contiguous prefill_query = prefill_query.contiguous() workspace_buffer = _get_trtllm_gen_workspace_buffer() - block_tables_prefill = attn_metadata.block_table_tensor[ - num_decodes:] + block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:] seq_lens_prefill = attn_metadata.seq_lens[num_decodes:] # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND @@ -914,28 +967,31 @@ class FlashInferImpl(AttentionImpl): if output.dtype == FP4_DTYPE: assert self.o_sf_scale is not None - out = FP4Tensor(data=output[num_decode_tokens:], - scale=output_block_scale, - scale_start_index=num_decode_tokens, - original_shape=prefill_query.shape) + out = FP4Tensor( + data=output[num_decode_tokens:], + scale=output_block_scale, + scale_start_index=num_decode_tokens, + original_shape=prefill_query.shape, + ) else: assert self.o_sf_scale is None out = output[num_decode_tokens:] - if attn_metadata.q_data_type != FP8_DTYPE \ - and self.kv_cache_dtype.startswith("fp8"): + if ( + attn_metadata.q_data_type != FP8_DTYPE + and self.kv_cache_dtype.startswith("fp8") + ): # TRTLLM prefill attention does not support BF16 Q # and fp8 kv cache. 
So to enable prefill attention # with fp8 kv cache, we can construct a mock block # and mock kv cache with BF16 KV involved in the prefill - mock_kv_cache, mock_block_table = ( - trtllm_prefill_attn_kvfp8_dequant( - kv_cache_permute, - block_tables_prefill, - layer._k_scale, - layer._v_scale, - attn_metadata.q_data_type, - )) + mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant( + kv_cache_permute, + block_tables_prefill, + layer._k_scale, + layer._v_scale, + attn_metadata.q_data_type, + ) else: mock_kv_cache = kv_cache_permute mock_block_table = block_tables_prefill @@ -967,8 +1023,7 @@ class FlashInferImpl(AttentionImpl): if not attn_metadata.decode_use_trtllm: assert decode_wrapper._window_left == self.window_left - assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap - or 0.0) + assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert decode_wrapper._sm_scale == self.scale decode_wrapper.run( decode_query, @@ -981,8 +1036,9 @@ class FlashInferImpl(AttentionImpl): # decode_query may be non-contiguous decode_query = decode_query.contiguous() workspace_buffer = _get_trtllm_gen_workspace_buffer() - block_tables_decode = attn_metadata.\ - block_table_tensor[:num_decode_tokens] + block_tables_decode = attn_metadata.block_table_tensor[ + :num_decode_tokens + ] seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND @@ -995,10 +1051,12 @@ class FlashInferImpl(AttentionImpl): if output.dtype == FP4_DTYPE: assert self.o_sf_scale is not None - out = FP4Tensor(data=output[:num_decode_tokens], - scale=output_block_scale, - scale_start_index=0, - original_shape=decode_query.shape) + out = FP4Tensor( + data=output[:num_decode_tokens], + scale=output_block_scale, + scale_start_index=0, + original_shape=decode_query.shape, + ) else: assert self.o_sf_scale is None out = output[:num_decode_tokens] @@ -1008,8 +1066,7 @@ class FlashInferImpl(AttentionImpl): # attention to be initialized with q_len = 0 q_len_per_req = 1 else: - q_len_per_req = \ - num_decode_tokens // attn_metadata.num_decodes + q_len_per_req = num_decode_tokens // attn_metadata.num_decodes trtllm_batch_decode_with_kv_cache( query=decode_query, @@ -1024,7 +1081,8 @@ class FlashInferImpl(AttentionImpl): sinks=self.sinks, o_sf_scale=self.o_sf_scale, out=out, - q_len_per_req=q_len_per_req) + q_len_per_req=q_len_per_req, + ) return output_padded @@ -1065,8 +1123,7 @@ def fast_plan_decode( # Warm up with the original plan if it is first call, and always run the # original plan if we run for dynamic shape. For fixed shape (cudagraph), # this warm up is to generate the _cached_module for the decode wrapper. 
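# A condensed sketch of the warm-up behaviour described above, assuming only that
# the wrapper caches its compiled decode module on the first full plan(); the
# `fast_update` callable is a hypothetical stand-in for the copy-only re-plan used
# on later fixed-shape (cudagraph) calls.
def plan_with_warmup(wrapper, fast_update, *plan_args, **plan_kwargs):
    first_call = getattr(wrapper, "vllm_first_call", True)
    if not getattr(wrapper, "is_cuda_graph_enabled", False) or first_call:
        # First call (or dynamic shape): run the original plan so the wrapper can
        # build and cache its decode module.
        wrapper.plan(*plan_args, **plan_kwargs)
        wrapper.vllm_first_call = False
    else:
        # Fixed shape (cudagraph) afterwards: skip re-planning, only refresh buffers.
        fast_update(wrapper, *plan_args, **plan_kwargs)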
- if not self.is_cuda_graph_enabled or \ - getattr(self, "vllm_first_call", True): + if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True): self.plan( indptr_cpu, indices, @@ -1106,26 +1163,28 @@ def fast_plan_decode( if kv_data_type is None: kv_data_type = q_data_type - q_data_type = getattr(torch, q_data_type) if isinstance( - q_data_type, str) else q_data_type - kv_data_type = getattr(torch, kv_data_type) if isinstance( - kv_data_type, str) else kv_data_type + q_data_type = ( + getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type + ) + kv_data_type = ( + getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type + ) if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed in cudagraph mode, the runtime " "batch size {} mismatches the batch size set during " - "initialization {}".format(batch_size, self._fixed_batch_size)) + "initialization {}".format(batch_size, self._fixed_batch_size) + ) if len(indices) > len(self._paged_kv_indices_buf): raise ValueError( - "The size of indices should be less than or equal to the " - "allocated buffer") + "The size of indices should be less than or equal to the allocated buffer" + ) # host-to-device copy for the indptr buffer self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True) # host-to-device copy for the last_page_len buffer - self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, - non_blocking=True) + self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True) qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") @@ -1176,6 +1235,8 @@ def _copy_page_indices_kernel( offset = tl.arange(0, BLOCK_SIZE) for i in tl.range(0, num_blocks, BLOCK_SIZE): block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks) - tl.store(page_indices + start_idx + i + offset, - block_ids, - mask=i + offset < num_blocks) + tl.store( + page_indices + start_idx + i + offset, + block_ids, + mask=i + offset < num_blocks, + ) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index e548b51060..4640e62abf 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -8,21 +8,32 @@ from typing import TYPE_CHECKING, Optional, Union import torch import torch._dynamo.decorators import torch.nn.functional as F -from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, - _score_mod_signature, and_masks, - create_block_mask, - flex_attention) +from torch.nn.attention.flex_attention import ( + BlockMask, + _mask_mod_signature, + _score_mod_signature, + and_masks, + create_block_mask, + flex_attention, +) -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, + is_quantized_kv_cache, +) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( - vllm_kernel_override_batch_invariant) + vllm_kernel_override_batch_invariant, +) from vllm.utils import cdiv, is_torch_equal_or_newer -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec logger = 
init_logger(__name__) @@ -31,9 +42,9 @@ if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch -create_block_mask_compiled = torch.compile(create_block_mask, - fullgraph=True, - mode="reduce-overhead") +create_block_mask_compiled = torch.compile( + create_block_mask, fullgraph=True, mode="reduce-overhead" +) flex_attention_compiled = torch.compile(flex_attention, fullgraph=True) @@ -41,7 +52,8 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: device = offsets.device counts = offsets[1:] - offsets[:-1] return torch.repeat_interleave( - torch.arange(len(counts), device=device, dtype=torch.int32), counts) + torch.arange(len(counts), device=device, dtype=torch.int32), counts + ) def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): @@ -103,10 +115,13 @@ class FlexAttentionBackend(AttentionBackend): return False -#@torch.compile(fullgraph=True, mode="reduce-overhead") -def physical_to_logical_mapping(block_table: torch.Tensor, - seq_lens: torch.Tensor, block_size: int, - total_blocks: int) -> torch.Tensor: +# @torch.compile(fullgraph=True, mode="reduce-overhead") +def physical_to_logical_mapping( + block_table: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + total_blocks: int, +) -> torch.Tensor: """ Creates an inverse mapping from physical block locations to logical indices. @@ -176,35 +191,37 @@ def physical_to_logical_mapping(block_table: torch.Tensor, max_reqs, max_num_blocks = block_table.shape device = block_table.device - physical_to_logical = torch.full((max_reqs, total_blocks), - -1, - dtype=torch.long, - device=device) + physical_to_logical = torch.full( + (max_reqs, total_blocks), -1, dtype=torch.long, device=device + ) # Only process valid blocks to avoid garbage values num_blocks_per_seq = cdiv(seq_lens, block_size) - mask = torch.arange(max_num_blocks, - device=device)[None, :] < num_blocks_per_seq[:, None] + mask = ( + torch.arange(max_num_blocks, device=device)[None, :] + < num_blocks_per_seq[:, None] + ) valid_block_table = torch.where(mask, block_table, 0) valid_logical_indices = torch.where( - mask, - torch.arange(max_num_blocks, device=device)[None, :], 0) + mask, torch.arange(max_num_blocks, device=device)[None, :], 0 + ) - physical_to_logical.scatter_(-1, valid_block_table.to(torch.int64), - valid_logical_indices) + physical_to_logical.scatter_( + -1, valid_block_table.to(torch.int64), valid_logical_indices + ) # NB - Seems like block 0 is always empty so we reset it manually physical_to_logical[:, 0] = -1 return physical_to_logical def unique_static_unsorted( - x: torch.Tensor, - *, - M: int, # maximum positive value (0 is “skip me”) - dim: int = -1, # axis along which to deduplicate - ignored_val: int = 0, # value to ignore - pad_val: int = -1, # sentinel for unused slots + x: torch.Tensor, + *, + M: int, # maximum positive value (0 is “skip me”) + dim: int = -1, # axis along which to deduplicate + ignored_val: int = 0, # value to ignore + pad_val: int = -1, # sentinel for unused slots ) -> torch.Tensor: """ - Keeps the first occurrence of each non-zero value while preserving order, @@ -236,8 +253,7 @@ def unique_static_unsorted( first_idx.scatter_reduce_(1, x_flat, idx, reduce="amin") # ── keep mask: first occurrence *and* value ≠ 0 ───────────────────── - keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat) - ) # [B, N] + keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat)) # [B, N] # ── left-pack uniques into a fresh 
tensor ─────────────────────────── dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1 # where to go @@ -251,8 +267,9 @@ def unique_static_unsorted( return packed -def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, - kv_idx: torch.Tensor): +def causal_mask_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor +): return q_idx >= kv_idx @@ -317,8 +334,7 @@ class FlexAttentionMetadata: physical_kv_block = physical_kv_idx // self.block_size physical_kv_offset = physical_kv_idx % self.block_size logical_block_idx = self.physical_to_logical[q_req, physical_kv_block] - logical_kv_idx = (logical_block_idx * self.block_size + - physical_kv_offset) + logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset # Determine valid kv indices live_block = logical_block_idx >= 0 @@ -352,9 +368,9 @@ class FlexAttentionMetadata: q_idx: torch.Tensor, physical_kv_idx: torch.Tensor, ) -> torch.Tensor: - (is_valid, logical_q_idx, - logical_kv_idx) = self._convert_physical_to_logical( - self.doc_ids, q_idx, physical_kv_idx) + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx) + ) # Apply mask modification only for valid indices return torch.where( is_valid, @@ -392,11 +408,11 @@ class FlexAttentionMetadata: """ if self.sliding_window is None: - raise ValueError( - "sliding_window must be set for sliding window attention") + raise ValueError("sliding_window must be set for sliding window attention") - def sliding_window_mask_mod(b: torch.Tensor, h: torch.Tensor, - q_idx: torch.Tensor, kv_idx: torch.Tensor): + def sliding_window_mask_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ): return torch.abs(q_idx - kv_idx) < self.sliding_window def final_mask_mod( @@ -405,9 +421,9 @@ class FlexAttentionMetadata: q_idx: torch.Tensor, physical_kv_idx: torch.Tensor, ) -> torch.Tensor: - (is_valid, logical_q_idx, - logical_kv_idx) = self._convert_physical_to_logical( - self.doc_ids, q_idx, physical_kv_idx) + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx) + ) return torch.where( is_valid, sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx), @@ -451,18 +467,19 @@ class FlexAttentionMetadata: q_idx: torch.Tensor, physical_kv_idx: torch.Tensor, ) -> torch.Tensor: - (is_valid, logical_q_idx, - logical_kv_idx) = self._convert_physical_to_logical( - request_lookup, q_idx, physical_kv_idx) + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical( + request_lookup, q_idx, physical_kv_idx + ) + ) return torch.where( is_valid, - user_score_mod(score, - b, - h, - logical_q_idx, - logical_kv_idx, - physical_q=q_idx), -float('inf')) + user_score_mod( + score, b, h, logical_q_idx, logical_kv_idx, physical_q=q_idx + ), + -float("inf"), + ) return transformed_score_mod @@ -493,18 +510,22 @@ class FlexAttentionMetadata: f"FlexAttention currently requires the cache block size " f"({self.block_size}) to be equal to the kv_block_size " f"({self.kv_block_size}). Please check your model's " - f"configuration.") + f"configuration." 
+ ) used_pages = self.block_table[ - self.doc_ids, :cdiv(self.max_seq_len, self.block_size)] - used_pages_padded = pad_to_multiple(used_pages, - multiple=self.q_block_size, - dim=0) + self.doc_ids, : cdiv(self.max_seq_len, self.block_size) + ] + used_pages_padded = pad_to_multiple( + used_pages, multiple=self.q_block_size, dim=0 + ) used_pages_padded = used_pages_padded.reshape( - used_pages_padded.shape[0] // self.q_block_size, -1) + used_pages_padded.shape[0] // self.q_block_size, -1 + ) used_pages_padded = used_pages_padded // page_to_block_ratio - kv_indices = unique_static_unsorted((used_pages_padded.long()), - M=self.num_blocks).to(torch.int32) + kv_indices = unique_static_unsorted( + (used_pages_padded.long()), M=self.num_blocks + ).to(torch.int32) kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32) block_mask_kwargs = { @@ -524,8 +545,7 @@ class FlexAttentionMetadata: def build_block_mask(self) -> BlockMask: mask_mod = self.get_mask_mod() - kv_len = (self.total_cache_tokens - if self.causal else self.num_actual_tokens) + kv_len = self.total_cache_tokens if self.causal else self.num_actual_tokens return create_block_mask_compiled( mask_mod, None, @@ -555,11 +575,14 @@ class FlexAttentionMetadata: self.block_mask = self.build_block_mask() -class FlexAttentionMetadataBuilder( - AttentionMetadataBuilder[FlexAttentionMetadata]): - - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): +class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.model_config = vllm_config.model_config @@ -567,26 +590,27 @@ class FlexAttentionMetadataBuilder( self.cache_config = vllm_config.cache_config self.num_heads_q = self.model_config.get_num_attention_heads( - self.parallel_config) - self.num_heads_kv = self.model_config.get_num_kv_heads( - self.parallel_config) + self.parallel_config + ) + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec self.direct_build: bool = is_torch_equal_or_newer("2.9.0.dev0") - self.q_block_size: int = 16 if is_torch_equal_or_newer( - "2.9.0.dev0") else 128 - self.kv_block_size: int = 16 if is_torch_equal_or_newer( - "2.9.0.dev0") else 128 + self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128 + self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128 - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: + def reorder_batch( + self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput" + ) -> bool: return False - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlexAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlexAttentionMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -609,15 +633,18 @@ class FlexAttentionMetadataBuilder( max_possible_seq_len = self.model_config.max_model_len num_gpu_blocks = 
self.cache_config.num_gpu_blocks - assert num_gpu_blocks is not None, \ + assert num_gpu_blocks is not None, ( "FlexAttention requires num_gpu_blocks to be set" - total_cache_tokens = (num_gpu_blocks * block_size) + ) + total_cache_tokens = num_gpu_blocks * block_size inverse_block_table = physical_to_logical_mapping( - block_table_tensor, seq_lens, block_size, num_gpu_blocks) + block_table_tensor, seq_lens, block_size, num_gpu_blocks + ) offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) out = FlexAttentionMetadata( causal=common_attn_metadata.causal, @@ -675,14 +702,15 @@ class FlexAttentionImpl(AttentionImpl): self.num_kv_heads = num_kv_heads self.attn_type = attn_type - if attn_type not in (AttentionType.ENCODER_ONLY, - AttentionType.DECODER): + if attn_type not in (AttentionType.ENCODER_ONLY, AttentionType.DECODER): raise NotImplementedError( - f"FlexAttention does not support {attn_type} attention") + f"FlexAttention does not support {attn_type} attention" + ) if alibi_slopes is not None: raise NotImplementedError( - "FlexAttention does not support alibi slopes yet.") + "FlexAttention does not support alibi slopes yet." + ) else: self.alibi_slopes = None @@ -692,19 +720,20 @@ class FlexAttentionImpl(AttentionImpl): self.logits_soft_cap = logits_soft_cap if self.logits_soft_cap is not None: raise NotImplementedError( - "FlexAttention does not support logits soft cap yet.") + "FlexAttention does not support logits soft cap yet." + ) assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads if kv_sharing_target_layer_name is not None: - raise NotImplementedError( - "FlexAttention does not support kv sharing yet.") + raise NotImplementedError("FlexAttention does not support kv sharing yet.") FlexAttentionBackend.validate_head_size(head_size) if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( - "FlexAttention does not support quantized kv-cache. Yet") + "FlexAttention does not support quantized kv-cache. Yet" + ) @staticmethod def view_as_4d(tensor: torch.Tensor) -> torch.Tensor: @@ -741,8 +770,8 @@ class FlexAttentionImpl(AttentionImpl): assert output is not None, "Output tensor must be provided." if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlexAttentionImpl") + "fused output quantization is not yet supported for FlexAttentionImpl" + ) enable_gqa = self.num_kv_heads != self.num_heads @@ -761,11 +790,11 @@ class FlexAttentionImpl(AttentionImpl): # in direct block mask building code path. logger.warning_once( "Using direct block mask building with sliding window, " - "which is suboptimal now. Performance may be degraded.") + "which is suboptimal now. Performance may be degraded." 
+ ) # update mask mod in attention metadata attn_metadata.mask_mod = attn_metadata.get_mask_mod() - attn_metadata.block_mask = ( - attn_metadata._build_block_mask_direct()) + attn_metadata.block_mask = attn_metadata._build_block_mask_direct() else: attn_metadata.block_mask = attn_metadata.build_block_mask() @@ -778,8 +807,9 @@ class FlexAttentionImpl(AttentionImpl): ) query = query[:, :, :num_actual_tokens, :] - if ((key_tensor.size(-2) > num_actual_tokens) - or (value_tensor.size(-2) > num_actual_tokens)): + if (key_tensor.size(-2) > num_actual_tokens) or ( + value_tensor.size(-2) > num_actual_tokens + ): # In the encoder-only model with torch.compile, # qkv might be padded, which might cause exception. # see: https://github.com/vllm-project/vllm/pull/24872#discussion_r2353252290 @@ -803,8 +833,7 @@ class FlexAttentionImpl(AttentionImpl): # View out the block_size dim key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size) - value_cache = value_cache.view(-1, self.num_kv_heads, - self.head_size) + value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size) query, key_tensor, value_tensor = map( lambda x: self.view_as_4d(x).permute(0, 2, 1, 3), (query, key_cache, value_cache), @@ -818,8 +847,9 @@ class FlexAttentionImpl(AttentionImpl): assert attn_metadata.block_mask is not None block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE - kernel_options = get_kernel_options(query, block_m, block_n, - attn_metadata.direct_build) + kernel_options = get_kernel_options( + query, block_m, block_n, attn_metadata.direct_build + ) out = flex_attention_compiled( query, key_tensor, @@ -837,8 +867,9 @@ class FlexAttentionImpl(AttentionImpl): return output -def get_kernel_options(query, block_m, block_n, - use_direct_build: bool) -> dict[str, Union[int, bool]]: +def get_kernel_options( + query, block_m, block_n, use_direct_build: bool +) -> dict[str, Union[int, bool]]: kernel_options: dict[str, Union[int, bool]] = { "FORCE_USE_FLEX_ATTENTION": True, } diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 11f165d6cf..0e271da5fb 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Backend for GatedDeltaNet attention.""" + from dataclasses import dataclass from typing import Optional @@ -9,16 +10,17 @@ import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - compute_causal_conv1d_metadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + compute_causal_conv1d_metadata, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec class GDNAttentionBackend(AttentionBackend): - @staticmethod def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]: return GDNAttentionMetadataBuilder @@ -36,19 +38,21 @@ class GDNAttentionMetadata: has_initial_state: Optional[torch.Tensor] = None - spec_query_start_loc: Optional[ - torch.Tensor] = None # shape: [num_spec_decodes + 1,] - non_spec_query_start_loc: Optional[ - torch.Tensor] = None # shape: [batch - num_spec_decodes + 1,] + spec_query_start_loc: 
Optional[torch.Tensor] = ( + None # shape: [num_spec_decodes + 1,] + ) + non_spec_query_start_loc: Optional[torch.Tensor] = ( + None # shape: [batch - num_spec_decodes + 1,] + ) - spec_state_indices_tensor: Optional[ - torch.Tensor] = None # shape: [batch, num_spec] - non_spec_state_indices_tensor: Optional[ - torch.Tensor] = None # shape: [batch - num_spec_decodes,] + spec_state_indices_tensor: Optional[torch.Tensor] = None # shape: [batch, num_spec] + non_spec_state_indices_tensor: Optional[torch.Tensor] = ( + None # shape: [batch - num_spec_decodes,] + ) spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,] - spec_token_masks: Optional[ - torch. - Tensor] = None # shape: [num_prefill_tokens + num_decode_tokens,] + spec_token_masks: Optional[torch.Tensor] = ( + None # shape: [num_prefill_tokens + num_decode_tokens,] + ) num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,] # The following attributes are for triton implementation of causal_conv1d @@ -57,15 +61,18 @@ class GDNAttentionMetadata: token_chunk_offset_ptr: Optional[torch.Tensor] = None -class GDNAttentionMetadataBuilder( - AttentionMetadataBuilder[GDNAttentionMetadata]): - +class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]): cudagraph_support = AttentionCGSupport.UNIFORM_BATCH reorder_batch_threshold: int = 1 - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): assert isinstance(kv_cache_spec, MambaSpec) self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config @@ -78,11 +85,13 @@ class GDNAttentionMetadataBuilder( self.use_spec_decode = self.num_spec > 0 self._init_reorder_batch_threshold(1, self.use_spec_decode) - self.use_full_cuda_graph = \ + self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) self.decode_cudagraph_max_bs = min( - self.vllm_config.scheduler_config.max_num_seqs * - (self.num_spec + 1), self.compilation_config.max_capture_size) + self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1), + self.compilation_config.max_capture_size, + ) self.spec_state_indices_tensor = torch.empty( (self.decode_cudagraph_max_bs, self.num_spec + 1), @@ -90,32 +99,32 @@ class GDNAttentionMetadataBuilder( device=device, ) self.non_spec_state_indices_tensor = torch.empty( - (self.decode_cudagraph_max_bs, ), + (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) self.spec_sequence_masks = torch.empty( - (self.decode_cudagraph_max_bs, ), + (self.decode_cudagraph_max_bs,), dtype=torch.bool, device=device, ) self.spec_token_masks = torch.empty( - (self.decode_cudagraph_max_bs * (self.num_spec + 1), ), + (self.decode_cudagraph_max_bs * (self.num_spec + 1),), dtype=torch.bool, device=device, ) self.spec_query_start_loc = torch.empty( - (self.decode_cudagraph_max_bs + 1, ), + (self.decode_cudagraph_max_bs + 1,), dtype=torch.int32, device=device, ) self.non_spec_query_start_loc = torch.empty( - (self.decode_cudagraph_max_bs + 1, ), + (self.decode_cudagraph_max_bs + 1,), dtype=torch.int32, device=device, ) self.num_accepted_tokens = torch.empty( - (self.decode_cudagraph_max_bs, ), + (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) @@ -135,9 +144,14 @@ class GDNAttentionMetadataBuilder( context_lens_tensor = context_lens.to(query_start_loc.device) nums_dict, 
batch_ptr, token_chunk_offset_ptr = None, None, None - if (not self.use_spec_decode or num_decode_draft_tokens_cpu is None - or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= - 0].sum().item() == 0): + if ( + not self.use_spec_decode + or num_decode_draft_tokens_cpu is None + or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0] + .sum() + .item() + == 0 + ): spec_sequence_masks = None num_spec_decodes = 0 else: @@ -147,11 +161,13 @@ class GDNAttentionMetadataBuilder( spec_sequence_masks = None else: spec_sequence_masks = spec_sequence_masks.to( - query_start_loc.device, non_blocking=True) + query_start_loc.device, non_blocking=True + ) if spec_sequence_masks is None: num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(m, decode_threshold=1)) + split_decodes_and_prefills(m, decode_threshold=1) + ) num_spec_decode_tokens = 0 spec_token_masks = None spec_state_indices_tensor = None @@ -166,45 +182,56 @@ class GDNAttentionMetadataBuilder( num_decodes = (non_spec_query_lens == 1).sum().item() num_prefills = non_spec_query_lens.size(0) - num_decodes num_decode_tokens = num_decodes - num_prefill_tokens = non_spec_query_lens.sum().item( - ) - num_decode_tokens + num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens if num_prefills == 0 and num_decodes == 0: spec_token_masks = torch.ones( - (min(num_spec_decodes * - (self.num_spec + 1), query_start_loc[-1].item())), + ( + min( + num_spec_decodes * (self.num_spec + 1), + query_start_loc[-1].item(), + ) + ), dtype=torch.bool, - device=query_start_loc.device) - spec_state_indices_tensor = m.block_table_tensor[:, :self. - num_spec + 1] + device=query_start_loc.device, + ) + spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1] non_spec_state_indices_tensor = None spec_query_start_loc = query_start_loc non_spec_query_start_loc = None else: spec_token_masks = torch.repeat_interleave( - spec_sequence_masks, query_lens) + spec_sequence_masks, query_lens + ) spec_state_indices_tensor = m.block_table_tensor[ - spec_sequence_masks, :self.num_spec + 1] - non_spec_state_indices_tensor = \ - m.block_table_tensor[~spec_sequence_masks, 0] + spec_sequence_masks, : self.num_spec + 1 + ] + non_spec_state_indices_tensor = m.block_table_tensor[ + ~spec_sequence_masks, 0 + ] spec_query_start_loc = torch.zeros( num_spec_decodes + 1, dtype=torch.int32, - device=query_start_loc.device) - torch.cumsum(query_lens[spec_sequence_masks], - dim=0, - out=spec_query_start_loc[1:]) + device=query_start_loc.device, + ) + torch.cumsum( + query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:] + ) non_spec_query_start_loc = torch.zeros( query_lens.size(0) - num_spec_decodes + 1, dtype=torch.int32, - device=query_start_loc.device) - torch.cumsum(query_lens[~spec_sequence_masks], - dim=0, - out=non_spec_query_start_loc[1:]) + device=query_start_loc.device, + ) + torch.cumsum( + query_lens[~spec_sequence_masks], + dim=0, + out=non_spec_query_start_loc[1:], + ) - num_spec_decode_tokens = (query_lens.sum().item() - - num_prefill_tokens - num_decode_tokens) + num_spec_decode_tokens = ( + query_lens.sum().item() - num_prefill_tokens - num_decode_tokens + ) assert num_accepted_tokens is not None num_accepted_tokens = num_accepted_tokens[spec_sequence_masks] @@ -212,12 +239,14 @@ class GDNAttentionMetadataBuilder( has_initial_state = context_lens_tensor > 0 if spec_sequence_masks is not None: has_initial_state = has_initial_state[~spec_sequence_masks] - nums_dict, 
batch_ptr, token_chunk_offset_ptr = \ + nums_dict, batch_ptr, token_chunk_offset_ptr = ( compute_causal_conv1d_metadata(non_spec_query_start_loc) + ) else: has_initial_state = None - num_actual_tokens = num_prefill_tokens + num_decode_tokens + \ - num_spec_decode_tokens + num_actual_tokens = ( + num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens + ) # prepare tensors for cudagraph # @@ -226,64 +255,71 @@ class GDNAttentionMetadataBuilder( # # In above cases, the max possible batch size for n tokens, can be # min(n, cudagraph_max_bs). - if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0 - and num_spec_decodes <= self.decode_cudagraph_max_bs - and num_spec_decode_tokens <= self.decode_cudagraph_max_bs): - num_actual_tokens = self.vllm_config.pad_for_cudagraph( - m.num_actual_tokens) + if ( + self.use_full_cuda_graph + and num_prefills == 0 + and num_decodes == 0 + and num_spec_decodes <= self.decode_cudagraph_max_bs + and num_spec_decode_tokens <= self.decode_cudagraph_max_bs + ): + num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens) batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens) self.spec_state_indices_tensor[:num_spec_decodes].copy_( - spec_state_indices_tensor, non_blocking=True) - spec_state_indices_tensor = self.spec_state_indices_tensor[: - batch_size] + spec_state_indices_tensor, non_blocking=True + ) + spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size] spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID) self.spec_sequence_masks[:num_spec_decodes].copy_( - spec_sequence_masks, non_blocking=True) + spec_sequence_masks, non_blocking=True + ) spec_sequence_masks = self.spec_sequence_masks[:batch_size] spec_sequence_masks[num_spec_decodes:].fill_(False) assert spec_token_masks is not None - self.spec_token_masks[:spec_token_masks.size(0)].copy_( - spec_token_masks, non_blocking=True) + self.spec_token_masks[: spec_token_masks.size(0)].copy_( + spec_token_masks, non_blocking=True + ) spec_token_masks = self.spec_token_masks[:num_actual_tokens] - spec_token_masks[spec_token_masks.size(0):].fill_(False) + spec_token_masks[spec_token_masks.size(0) :].fill_(False) - self.spec_query_start_loc[:num_spec_decodes + 1].copy_( - spec_query_start_loc, non_blocking=True) - spec_num_query_tokens = spec_query_start_loc[ - -1] # type: ignore[index] - spec_query_start_loc = self.spec_query_start_loc[:batch_size + 1] - spec_query_start_loc[num_spec_decodes + - 1:].fill_(spec_num_query_tokens) + self.spec_query_start_loc[: num_spec_decodes + 1].copy_( + spec_query_start_loc, non_blocking=True + ) + spec_num_query_tokens = spec_query_start_loc[-1] # type: ignore[index] + spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1] + spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens) self.num_accepted_tokens[:num_spec_decodes].copy_( - num_accepted_tokens, non_blocking=True) + num_accepted_tokens, non_blocking=True + ) num_accepted_tokens = self.num_accepted_tokens[:batch_size] num_accepted_tokens[num_spec_decodes:].fill_(1) - if (self.use_full_cuda_graph and num_prefills == 0 - and num_spec_decodes == 0 - and num_decodes <= self.decode_cudagraph_max_bs): - num_actual_tokens = self.vllm_config.pad_for_cudagraph( - m.num_actual_tokens) + if ( + self.use_full_cuda_graph + and num_prefills == 0 + and num_spec_decodes == 0 + and num_decodes <= self.decode_cudagraph_max_bs + ): + num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens) batch_size = 
num_actual_tokens self.non_spec_state_indices_tensor[:num_decodes].copy_( - non_spec_state_indices_tensor, non_blocking=True) - non_spec_state_indices_tensor = \ - self.non_spec_state_indices_tensor[:batch_size] + non_spec_state_indices_tensor, non_blocking=True + ) + non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[ + :batch_size + ] non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID) - self.non_spec_query_start_loc[:num_decodes + 1].copy_( - non_spec_query_start_loc, non_blocking=True) - non_spec_num_query_tokens = non_spec_query_start_loc[ - -1] # type: ignore[index] - non_spec_query_start_loc = \ - self.non_spec_query_start_loc[:batch_size + 1] - non_spec_query_start_loc[num_decodes + - 1:].fill_(non_spec_num_query_tokens) + self.non_spec_query_start_loc[: num_decodes + 1].copy_( + non_spec_query_start_loc, non_blocking=True + ) + non_spec_num_query_tokens = non_spec_query_start_loc[-1] # type: ignore[index] + non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1] + non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens) attn_metadata = GDNAttentionMetadata( num_prefills=num_prefills, @@ -308,7 +344,8 @@ class GDNAttentionMetadataBuilder( return attn_metadata def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata): + self, common_attn_metadata: CommonAttentionMetadata + ): """ This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with Mamba. @@ -317,16 +354,17 @@ class GDNAttentionMetadataBuilder( assert ( m.num_reqs <= self.decode_cudagraph_max_bs - and m.num_actual_tokens <= self.decode_cudagraph_max_bs), ( - f"GDN only supports decode-only full CUDAGraph capture. " - f"Make sure batch size ({m.num_reqs}) <= " - f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), " - f"and number of tokens ({m.num_actual_tokens}) <= " - f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}).") + and m.num_actual_tokens <= self.decode_cudagraph_max_bs + ), ( + f"GDN only supports decode-only full CUDAGraph capture. " + f"Make sure batch size ({m.num_reqs}) <= " + f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), " + f"and number of tokens ({m.num_actual_tokens}) <= " + f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})." 
+ ) num_accepted_tokens = torch.diff(m.query_start_loc) num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu() m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu() - return self.build(0, m, num_accepted_tokens, - num_decode_draft_tokens_cpu) + return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu) diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index 0dc62d6680..1900c50849 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -6,14 +6,15 @@ import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec class LinearAttentionBackend(AttentionBackend): - @staticmethod def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]: return LinearAttentionMetadataBuilder @@ -31,20 +32,25 @@ class LinearAttentionMetadata: state_indices_tensor: torch.Tensor # shape: [batch,] -class LinearAttentionMetadataBuilder( - AttentionMetadataBuilder[LinearAttentionMetadata]): - +class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]): reorder_batch_threshold: int = 1 - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) assert isinstance(kv_cache_spec, MambaSpec) - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> LinearAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> LinearAttentionMetadata: query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens @@ -52,8 +58,9 @@ class LinearAttentionMetadataBuilder( num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=self.reorder_batch_threshold)) + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) attn_metadata = LinearAttentionMetadata( num_prefills=num_prefills, diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 7cbfa2c2c9..e305cb2d87 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -8,14 +8,14 @@ import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.v1.attention.backends.mamba_attn import ( - BaseMambaAttentionMetadataBuilder) -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + split_decodes_and_prefills, +) class Mamba1AttentionBackend(AttentionBackend): - @staticmethod def get_builder_cls() -> 
type["Mamba1AttentionMetadataBuilder"]: return Mamba1AttentionMetadataBuilder @@ -35,8 +35,8 @@ class Mamba1AttentionMetadata: class Mamba1AttentionMetadataBuilder( - BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]): - + BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata] +): def build( self, common_prefix_len: int, @@ -47,24 +47,30 @@ class Mamba1AttentionMetadataBuilder( state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to( - query_start_loc.device) + query_start_loc.device + ) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=self.reorder_batch_threshold)) + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) has_initial_states = None padded_decodes = num_decodes if num_prefills > 0: has_initial_states = context_lens_tensor > 0 - elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs - and self.compilation_config.full_cuda_graph): + elif ( + num_decodes > 0 + and num_decodes <= self.decode_cudagraph_max_bs + and self.compilation_config.full_cuda_graph + ): state_indices_for_decode = state_indices_tensor[:num_decodes] padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes) self.state_indices_tensor[:num_decodes].copy_( - state_indices_for_decode, non_blocking=True) + state_indices_for_decode, non_blocking=True + ) state_indices_tensor = self.state_indices_tensor[:padded_decodes] state_indices_tensor[num_decodes:] = PAD_SLOT_ID diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 49fe1584e7..ae8a0e92da 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -9,12 +9,13 @@ import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig from vllm.utils import cdiv -from vllm.v1.attention.backends.mamba_attn import ( - BaseMambaAttentionMetadataBuilder) -from vllm.v1.attention.backends.utils import (PAD_SLOT_ID, - CommonAttentionMetadata, - compute_causal_conv1d_metadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder +from vllm.v1.attention.backends.utils import ( + PAD_SLOT_ID, + CommonAttentionMetadata, + compute_causal_conv1d_metadata, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec @@ -68,27 +69,26 @@ def compute_varlen_chunk_metadata( # Exclusive prefix sum over logical-chunk lengths if chunk_lens: - cu_chunk_seqlens = torch.tensor([0] + - list(itertools.accumulate(chunk_lens)), - device=device, - dtype=torch.int32) + cu_chunk_seqlens = torch.tensor( + [0] + list(itertools.accumulate(chunk_lens)), + device=device, + dtype=torch.int32, + ) # Final boundary must equal total tokens assert int(cu_chunk_seqlens[-1].item()) == total else: cu_chunk_seqlens = torch.tensor([0], device=device, dtype=torch.int32) - last_chunk_indices_t = (torch.tensor( - last_chunk_indices, device=device, dtype=torch.int32) - if len(starts) > 0 else torch.empty( - (0, ), device=device, dtype=torch.int32)) - seq_idx_chunks_t = torch.tensor(seq_idx_chunks, - device=device, - dtype=torch.int32) + last_chunk_indices_t = ( + torch.tensor(last_chunk_indices, device=device, dtype=torch.int32) + if len(starts) > 0 + else torch.empty((0,), device=device, dtype=torch.int32) + ) + seq_idx_chunks_t = torch.tensor(seq_idx_chunks, 
device=device, dtype=torch.int32) return cu_chunk_seqlens, last_chunk_indices_t, seq_idx_chunks_t class Mamba2AttentionBackend(AttentionBackend): - @staticmethod def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]: return Mamba2AttentionMetadataBuilder @@ -135,37 +135,48 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( - BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]): - - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata] +): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() assert self.chunk_size is not None, ( - "chunk_size needs to be set in the model config for Mamba2 models") + "chunk_size needs to be set in the model config for Mamba2 models" + ) if self.vllm_config.cache_config.enable_prefix_caching: self.state_indices_tensor = torch.empty( - (self.decode_cudagraph_max_bs, - cdiv(vllm_config.model_config.max_model_len, - kv_cache_spec.block_size)), + ( + self.decode_cudagraph_max_bs, + cdiv( + vllm_config.model_config.max_model_len, kv_cache_spec.block_size + ), + ), dtype=torch.int32, device=device, ) self.current_last_idx = torch.empty( - (self.decode_cudagraph_max_bs, ), + (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) self.last_state_idx = torch.empty( - (self.decode_cudagraph_max_bs, ), + (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> Mamba2AttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> Mamba2AttentionMetadata: num_reqs = common_attn_metadata.num_reqs seq_lens = common_attn_metadata.seq_lens @@ -192,12 +203,11 @@ class Mamba2AttentionMetadataBuilder( # Additional cache-related varaiables: mamba_block_size = self.kv_cache_spec.block_size seq_lens_pending = ( - torch.roll(common_attn_metadata.query_start_loc, -1, -1) - - common_attn_metadata.query_start_loc)[:-1] - context_lens = common_attn_metadata.seq_lens - \ - seq_lens_pending - last_computed_offset = \ - context_lens % mamba_block_size + torch.roll(common_attn_metadata.query_start_loc, -1, -1) + - common_attn_metadata.query_start_loc + )[:-1] + context_lens = common_attn_metadata.seq_lens - seq_lens_pending + last_computed_offset = context_lens % mamba_block_size # Indices: last_computed <= current_first <= current_last # Cases: # last_computed == current_first if last state was partially @@ -205,55 +215,65 @@ class Mamba2AttentionMetadataBuilder( # current_first == current_last if no block crossing occurs, and # only one state will be stored # 0th based indexing leads to "-1" -> e.g. 
16 computed -> state[15]: - current_last_idx = cdiv(context_lens + seq_lens_pending, - mamba_block_size) - 1 + current_last_idx = ( + cdiv(context_lens + seq_lens_pending, mamba_block_size) - 1 + ) current_first_idx = cdiv(context_lens + 1, mamba_block_size) - 1 last_state_idx = cdiv(context_lens, mamba_block_size) - 1 # -1 in case it's non-computed and causes later issues with indexing - last_state_idx = \ - last_state_idx.clamp(min=0) + last_state_idx = last_state_idx.clamp(min=0) else: # Always return just a single block per each request: - state_indices_tensor = common_attn_metadata.block_table_tensor[:, - 0] + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] # Additional cache-related varaiables: current_last_idx = None last_state_idx = None num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=self.reorder_batch_threshold)) + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) # Compute seq_idx for prefill only if num_prefills > 0: - #[batch,] + # [batch,] has_initial_states_cpu = ( - common_attn_metadata. - num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) + common_attn_metadata.num_computed_tokens_cpu[ + num_reqs - num_prefills : num_reqs + ] + > 0 + ) prep_initial_states = torch.any(has_initial_states_cpu).item() has_initial_states_p = has_initial_states_cpu.to( - common_attn_metadata.query_start_loc.device) + common_attn_metadata.query_start_loc.device + ) - query_start_loc_p = common_attn_metadata.query_start_loc[ - -num_prefills - 1:] - num_decode_tokens + query_start_loc_p = ( + common_attn_metadata.query_start_loc[-num_prefills - 1 :] + - num_decode_tokens + ) if self.vllm_config.cache_config.enable_prefix_caching: assert context_lens is not None - context_lens_p = context_lens[num_reqs - num_prefills:num_reqs] + context_lens_p = context_lens[num_reqs - num_prefills : num_reqs] assert last_computed_offset is not None last_computed_offset_p = last_computed_offset[ - num_reqs - num_prefills:num_reqs] + num_reqs - num_prefills : num_reqs + ] assert current_first_idx is not None - current_first_idx_p = current_first_idx[num_reqs - - num_prefills:num_reqs] + current_first_idx_p = current_first_idx[ + num_reqs - num_prefills : num_reqs + ] - num_computed_tokens_p = \ - common_attn_metadata.num_computed_tokens_cpu[ - num_reqs - num_prefills:num_reqs] - query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[ - -num_prefills - 1:] - num_decode_tokens + num_computed_tokens_p = common_attn_metadata.num_computed_tokens_cpu[ + num_reqs - num_prefills : num_reqs + ] + query_start_loc_p_cpu = ( + common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :] + - num_decode_tokens + ) # The code below carefully constructs the chunks such that: # 1. Chunks contain tokens from a *single* sequence only. @@ -271,8 +291,10 @@ class Mamba2AttentionMetadataBuilder( seqlen_pos = 0 for req_idx in range(num_prefills): this_num_computed = num_computed_tokens_p[req_idx].item() - this_new_tokens = query_start_loc_p_cpu[req_idx + 1].item( - ) - query_start_loc_p_cpu[req_idx].item() + this_new_tokens = ( + query_start_loc_p_cpu[req_idx + 1].item() + - query_start_loc_p_cpu[req_idx].item() + ) # if computed tokens are not chunk-aligned, use the first # chunk to finish it off @@ -280,8 +302,10 @@ class Mamba2AttentionMetadataBuilder( seq_idx.append(req_idx) cu_chunk_seqlen.append(seqlen_pos) # how many tokens to finish the chunk? 
- chunk_len = cdiv(this_num_computed, self.chunk_size - ) * self.chunk_size - this_num_computed + chunk_len = ( + cdiv(this_num_computed, self.chunk_size) * self.chunk_size + - this_num_computed + ) # we can only use at most this_new_tokens chunk_len = min(chunk_len, this_new_tokens) seqlen_pos += chunk_len @@ -300,40 +324,40 @@ class Mamba2AttentionMetadataBuilder( cu_chunk_seqlen.append(seqlen_pos) - seq_idx_p = torch.as_tensor(seq_idx, - device=query_start_loc_p.device, - dtype=torch.int32) + seq_idx_p = torch.as_tensor( + seq_idx, device=query_start_loc_p.device, dtype=torch.int32 + ) cu_chunk_seqlen_p = torch.as_tensor( - cu_chunk_seqlen, - device=query_start_loc_p.device, - dtype=torch.int32) + cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32 + ) last_chunk_indices_p = torch.as_tensor( - last_chunk_indices, - device=query_start_loc_p.device, - dtype=torch.int32) + last_chunk_indices, device=query_start_loc_p.device, dtype=torch.int32 + ) - nums_dict, batch_ptr, token_chunk_offset_ptr = \ + nums_dict, batch_ptr, token_chunk_offset_ptr = ( compute_causal_conv1d_metadata(query_start_loc_p) + ) elif num_decodes <= self.decode_cudagraph_max_bs: # Pad state tensor for CUDA graph num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) - self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor, - non_blocking=True) + self.state_indices_tensor[:num_decodes].copy_( + state_indices_tensor, non_blocking=True + ) state_indices_tensor = self.state_indices_tensor[:num_input_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID if self.vllm_config.cache_config.enable_prefix_caching: - self.current_last_idx[:num_decodes].copy_(current_last_idx, - non_blocking=True) - current_last_idx = \ - self.current_last_idx[:num_input_tokens] + self.current_last_idx[:num_decodes].copy_( + current_last_idx, non_blocking=True + ) + current_last_idx = self.current_last_idx[:num_input_tokens] current_last_idx[num_decodes:] = 0 - self.last_state_idx[:num_decodes].copy_(last_state_idx, - non_blocking=True) - last_state_idx = \ - self.last_state_idx[:num_input_tokens] + self.last_state_idx[:num_decodes].copy_( + last_state_idx, non_blocking=True + ) + last_state_idx = self.last_state_idx[:num_input_tokens] last_state_idx[num_decodes:] = 0 attn_metadata = Mamba2AttentionMetadata( diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index ef342ce421..5aafb9813d 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -7,9 +7,11 @@ from typing import ClassVar, TypeVar import torch from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec M = TypeVar("M") @@ -17,35 +19,44 @@ M = TypeVar("M") class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): reorder_batch_threshold: int = 1 - cudagraph_support: ClassVar[AttentionCGSupport] = \ + cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: 
torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) assert isinstance(kv_cache_spec, MambaSpec) self.compilation_config = vllm_config.compilation_config self.decode_cudagraph_max_bs = min( self.vllm_config.scheduler_config.max_num_seqs, - self.compilation_config.max_capture_size) + self.compilation_config.max_capture_size, + ) self.state_indices_tensor = torch.empty( - (self.decode_cudagraph_max_bs, ), + (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: + self, common_attn_metadata: CommonAttentionMetadata + ) -> M: """ This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with Mamba. """ m = common_attn_metadata - assert m.num_reqs == m.num_actual_tokens, \ - "Mamba only supports decode-only full CUDAGraph capture. " \ + assert m.num_reqs == m.num_actual_tokens, ( + "Mamba only supports decode-only full CUDAGraph capture. " "Make sure all cudagraph capture sizes <= max_num_seq." + ) m.max_query_len = 1 # decode-only diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 963f1c5abf..a266f89bbb 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -197,9 +197,12 @@ from tqdm import tqdm import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, - AttentionMetadata, - MLAAttentionImpl) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionLayer, + AttentionMetadata, + MLAAttentionImpl, +) from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states @@ -207,21 +210,26 @@ from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearBase, + UnquantizedLinearMethod, +) from vllm.platforms import current_platform from vllm.utils import cdiv, round_down from vllm.utils.flashinfer import has_nvidia_artifactory -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - get_per_layer_parameters, - infer_global_hyperparameters, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, + get_per_layer_parameters, + infer_global_hyperparameters, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec try: from vllm.vllm_flash_attn import flash_attn_varlen_func + is_vllm_fa = True except ImportError: # For rocm use upstream flash attention @@ -231,26 +239,29 @@ except ImportError: try: from flashinfer import BatchPrefillWithRaggedKVCacheWrapper - from flashinfer.prefill import ( # noqa: F401 - cudnn_batch_prefill_with_kv_cache) + from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache # noqa: F401 + flashinfer_available = True except ImportError: flashinfer_available = False def is_rocm_aiter_fp8bmm_enabled() -> bool: - return current_platform.is_rocm() \ - and 
envs.VLLM_ROCM_USE_AITER_FP8BMM \ + return ( + current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER_FP8BMM and envs.VLLM_ROCM_USE_AITER + ) if is_rocm_aiter_fp8bmm_enabled(): from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 # isort: skip - batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant - as aiter_triton_fp8_bmm) + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, + ) def dynamic_per_batched_tensor_quant( - x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn): + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn + ): DTYPE_MAX = torch.finfo(dtype).max min_val, max_val = x.aminmax() amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) @@ -265,7 +276,6 @@ CUDNN_WORKSPACE_SIZE = 12800 class MLACommonBackend(AttentionBackend): - accept_output_buffer: bool = True @staticmethod @@ -307,12 +317,13 @@ class MLACommonBackend(AttentionBackend): f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @dataclass class MLACommonPrefillMetadata: - """ Prefill Specific Metadata """ + """Prefill Specific Metadata""" @dataclass class ChunkedContextMetadata: @@ -340,16 +351,15 @@ class MLACommonPrefillMetadata: @dataclass class FlashInferPrefillMetadata(MLACommonPrefillMetadata): - prefill_main: Optional['BatchPrefillWithRaggedKVCacheWrapper'] = None - prefill_chunks: list['BatchPrefillWithRaggedKVCacheWrapper'] = field( - default_factory=list) + prefill_main: Optional["BatchPrefillWithRaggedKVCacheWrapper"] = None + prefill_chunks: list["BatchPrefillWithRaggedKVCacheWrapper"] = field( + default_factory=list + ) @dataclass class CudnnPrefillMetadata(MLACommonPrefillMetadata): - - class ChunkedContextMetadata( - MLACommonPrefillMetadata.ChunkedContextMetadata): + class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata): seq_lens: torch.Tensor query_seq_lens: Optional[torch.Tensor] = None @@ -372,6 +382,7 @@ class MLACommonMetadata(Generic[D]): NOTE: Please read the comment at the top of the file before trying to understand this class """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -398,9 +409,9 @@ class MLACommonMetadata(Generic[D]): head_dim: Optional[int] = None decode: Optional[D] = None - prefill: Optional[Union[MLACommonPrefillMetadata, - FlashInferPrefillMetadata, - CudnnPrefillMetadata]] = None + prefill: Optional[ + Union[MLACommonPrefillMetadata, FlashInferPrefillMetadata, CudnnPrefillMetadata] + ] = None def __post_init__(self): if self.head_dim is not None: @@ -414,15 +425,21 @@ A = TypeVar("A") def use_flashinfer_prefill() -> bool: # For blackwell default to flashinfer prefill if it's available since # it is faster than FA2. 
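# A rough, simplified sketch of the prefill-path precedence implied by these
# helpers (several availability checks are omitted and the return values are
# illustrative labels, not vLLM identifiers): on SM100 FlashInfer prefill is
# preferred unless disabled, cuDNN prefill is used only when explicitly opted in,
# and the FA2-based common path is the fallback otherwise.
def sketch_prefill_choice(
    is_sm100: bool,
    flashinfer_available: bool,
    disable_flashinfer: bool,
    use_cudnn: bool,
) -> str:
    if is_sm100 and flashinfer_available and not disable_flashinfer and not use_cudnn:
        return "flashinfer-prefill"
    if is_sm100 and flashinfer_available and use_cudnn:
        return "cudnn-prefill"
    return "fa-common-prefill"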
- return (not envs.VLLM_DISABLE_FLASHINFER_PREFILL and flashinfer_available - and not envs.VLLM_USE_CUDNN_PREFILL - and current_platform.is_device_capability(100)) + return ( + not envs.VLLM_DISABLE_FLASHINFER_PREFILL + and flashinfer_available + and not envs.VLLM_USE_CUDNN_PREFILL + and current_platform.is_device_capability(100) + ) def use_cudnn_prefill() -> bool: - return (flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL - and current_platform.is_device_capability(100) - and has_nvidia_artifactory()) + return ( + flashinfer_available + and envs.VLLM_USE_CUDNN_PREFILL + and current_platform.is_device_capability(100) + and has_nvidia_artifactory() + ) # Currently 394MB, this can be tuned based on GEMM sizes used. @@ -436,19 +453,21 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): NOTE: Please read the comment at the top of the file before trying to understand this class """ + reorder_batch_threshold: int = 1 @staticmethod - def determine_chunked_prefill_workspace_size( - vllm_config: VllmConfig) -> int: + def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int: scheduler_config = vllm_config.scheduler_config cache_config = vllm_config.cache_config model_config = vllm_config.model_config chunked_prefill_workspace_size = min( # Try for 8 full-length requests or at least 4 pages per request - max(8 * model_config.max_model_len, - 4 * scheduler_config.max_num_seqs * cache_config.block_size), + max( + 8 * model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * cache_config.block_size, + ), # For long-context models try not to over-allocate limiting # kv-cache space, limiting it to 64k tokens, # which would result in the workspace being: @@ -457,23 +476,28 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): # which would result in up-projected context being # 2*(192*128)*(64*1024) = 3gb # (assuming 192 QK head dim, 128 heads, and fp16) - 64 * 1024) + 64 * 1024, + ) # Enforce that we have enough for at least 1 page per request chunked_prefill_workspace_size = max( chunked_prefill_workspace_size, - scheduler_config.max_num_seqs * cache_config.block_size) + scheduler_config.max_num_seqs * cache_config.block_size, + ) return chunked_prefill_workspace_size - def __init__(self, - kv_cache_spec: AttentionSpec, - layer_names: list[str], - vllm_config: VllmConfig, - device: torch.device, - metadata_cls: Optional[type[M]] = None): - self.metadata_cls = metadata_cls \ - if metadata_cls is not None else MLACommonMetadata + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: Optional[type[M]] = None, + ): + self.metadata_cls = ( + metadata_cls if metadata_cls is not None else MLACommonMetadata + ) self.kv_cache_spec = kv_cache_spec scheduler_config = vllm_config.scheduler_config self.model_config = vllm_config.model_config @@ -481,8 +505,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): self.compilation_config = vllm_config.compilation_config self.device = device - self.num_heads = self.model_config.get_num_attention_heads( - parallel_config) + self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.aot_schedule = current_platform.is_cuda() try: @@ -497,27 +520,31 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self.aot_schedule: self.page_size = self.kv_cache_spec.block_size - self.chunked_prefill_workspace_size = \ +
self.chunked_prefill_workspace_size = ( self.determine_chunked_prefill_workspace_size(vllm_config) + ) if self.dcp_world_size > 1: # Note(hc): The local kvcache is incomplete when DCP is triggered, # an additional kvcache allgather across the DCP group is therefore # required, so the workspace has to be enlarged by 1/DCP relative # to the original TP allocation. - assert self.chunked_prefill_workspace_size % \ - self.dcp_world_size == 0 + assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0 self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size + - self.chunked_prefill_workspace_size // self.dcp_world_size, - self.model_config.get_head_size()), + ( + self.chunked_prefill_workspace_size + + self.chunked_prefill_workspace_size // self.dcp_world_size, + self.model_config.get_head_size(), + ), dtype=self.model_config.dtype, device=device, ) else: self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size, - self.model_config.get_head_size()), + ( + self.chunked_prefill_workspace_size, + self.model_config.get_head_size(), + ), dtype=self.model_config.dtype, device=device, ) @@ -526,23 +553,23 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): self._use_fi_prefill = use_flashinfer_prefill() self.prefill_metadata_cls = ( FlashInferPrefillMetadata - if self._use_fi_prefill else CudnnPrefillMetadata - if self._use_cudnn_prefill else MLACommonPrefillMetadata) + if self._use_fi_prefill + else CudnnPrefillMetadata + if self._use_cudnn_prefill + else MLACommonPrefillMetadata + ) if self._use_fi_prefill: self._workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=device) + FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device + ) - self._fi_prefill_main: Optional[ - BatchPrefillWithRaggedKVCacheWrapper] = None - self._fi_prefill_chunks: list[ - BatchPrefillWithRaggedKVCacheWrapper] = [] + self._fi_prefill_main: Optional[BatchPrefillWithRaggedKVCacheWrapper] = None + self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = [] self._global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(vllm_config, layer_names, - MLACommonImpl)) + get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl) + ) if self._use_cudnn_prefill: self.cudnn_workspace = torch.empty( @@ -561,7 +588,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self._fi_prefill_main is None: self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper( - self._workspace_buffer, "NHD", backend="cutlass") + self._workspace_buffer, "NHD", backend="cutlass" + ) if has_context: num_chunks = chunked_context.cu_seq_lens.shape[0] @@ -570,7 +598,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): for _ in range(len(self._fi_prefill_chunks), num_chunks): self._fi_prefill_chunks.append( BatchPrefillWithRaggedKVCacheWrapper( - self._workspace_buffer, "NHD", backend="cutlass")) + self._workspace_buffer, "NHD", backend="cutlass" + ) + ) assert num_chunks <= len(self._fi_prefill_chunks) # In MLA, the non-latent num_qo_heads == num_kv_heads @@ -581,8 +611,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): assert self.kv_cache_spec.num_kv_heads == 1 # Get non-latent head_dim_qk and head_dim_vo - head_dim_qk = (self.mla_dims.qk_nope_head_dim + - self.mla_dims.qk_rope_head_dim) + head_dim_qk = self.mla_dims.qk_nope_head_dim + self.mla_dims.qk_rope_head_dim head_dim_vo = self.mla_dims.v_head_dim # For main run, qo_indptr == 
kv_indptr @@ -618,45 +647,50 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): causal=False, # This is context run sm_scale=self._global_hyperparameters.sm_scale, window_left=self._global_hyperparameters.window_left, - logits_soft_cap=self._global_hyperparameters. - logits_soft_cap, + logits_soft_cap=self._global_hyperparameters.logits_soft_cap, q_data_type=self.model_config.dtype, ) prefill.prefill_main = self._fi_prefill_main prefill.prefill_chunks = self._fi_prefill_chunks - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int) -> MLACommonDecodeMetadata: + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + ) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, ) def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: + self, common_attn_metadata: CommonAttentionMetadata + ) -> M: """ This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with MLA. """ m = common_attn_metadata - assert m.num_reqs <= (m.num_actual_tokens * - self.reorder_batch_threshold), \ - "MLA only supports decode-only full CUDAGraph capture. " \ + assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), ( + "MLA only supports decode-only full CUDAGraph capture. " "Make sure all cudagraph capture sizes <= max_num_seq." + ) assert m.max_query_len <= self.reorder_batch_threshold # decode only return self.build(0, m) - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> M: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> M: num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -676,18 +710,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu - - query_seq_lens_cpu) + num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=self.reorder_batch_threshold) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) # Note(hc): update seq_lens of decode reqs under DCP. 
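# e.g. for a decode request with seq_len=10 and dcp_world_size=4, the # expression below yields local lengths [3, 3, 2, 2] across ranks 0-3, # summing back to the full sequence length.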
if self.dcp_world_size > 1: - seq_lens[:num_decodes] = seq_lens[:num_decodes] \ - // self.dcp_world_size + (self.dcp_rank <= \ - (seq_lens[:num_decodes] - 1) % self.dcp_world_size) + seq_lens[:num_decodes] = seq_lens[:num_decodes] // self.dcp_world_size + ( + self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size + ) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens @@ -698,13 +733,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] # Note(hc): The context lengths from the perspective of dcp rank0. - cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() / - self.dcp_world_size).int() + cp_context_lens_cpu = torch.ceil( + context_lens_cpu.float() / self.dcp_world_size + ).int() origin_context_lens = context_lens_cpu.tolist() max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() - prefill_query_start_loc = query_start_loc[ - reqs_start:] - query_start_loc[reqs_start] + prefill_query_start_loc = ( + query_start_loc[reqs_start:] - query_start_loc[reqs_start] + ) chunked_context_metadata = None if max_context_len_cpu > 0: @@ -716,16 +753,16 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): # prefill in the batch, we could probably use a more advanced # algorithm here and allocate more workspace to prefills with # longer context lengths - max_context_chunk = (self.chunked_prefill_workspace_size // - num_prefills_with_context_cpu) + max_context_chunk = ( + self.chunked_prefill_workspace_size // num_prefills_with_context_cpu + ) if self.aot_schedule: # align max_context_chunk to page_size by rounding down, # currently the `gather_and_maybe_dequant_cache` kernel # cannot handle `context_chunk_starts` that are not aligned # to page_size - max_context_chunk = round_down(max_context_chunk, - self.page_size) + max_context_chunk = round_down(max_context_chunk, self.page_size) assert max_context_chunk > 0 num_chunks = cdiv(max_context_len_cpu, max_context_chunk) @@ -736,22 +773,23 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] # Note(simon): this is done in CPU because of downstream's # use of `to_list`. - chunk_starts = \ - torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, num_prefills) \ + chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) * max_context_chunk - chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), - chunk_starts + max_context_chunk) + ) + chunk_ends = torch.min( + context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk + ) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - cu_seq_lens_cpu = torch.zeros(num_chunks, - num_prefills + 1, - dtype=torch.int32, - pin_memory=True) - torch.cumsum(chunk_seq_lens, - dim=1, - out=cu_seq_lens_cpu[:, 1:], - dtype=torch.int32) + cu_seq_lens_cpu = torch.zeros( + num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True ) + torch.cumsum( + chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32 + ) if self.dcp_world_size > 1: # Note(hc): The above max_context_chunk already enforces @@ -760,36 +798,37 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): # cp_gather_cache which does not require `cp_chunk_starts` # aligned to page_size.
assert max_context_chunk % self.dcp_world_size == 0 - cp_max_context_chunk = max_context_chunk // \ - self.dcp_world_size - cp_chunk_starts = \ - torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, num_prefills) \ + cp_max_context_chunk = max_context_chunk // self.dcp_world_size + cp_chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) * cp_max_context_chunk + ) cp_chunk_ends = torch.min( cp_context_lens_cpu.unsqueeze(0), - cp_chunk_starts + cp_max_context_chunk) - cp_chunk_seq_lens = (cp_chunk_ends - - cp_chunk_starts).clamp(min=0) + cp_chunk_starts + cp_max_context_chunk, + ) + cp_chunk_seq_lens = (cp_chunk_ends - cp_chunk_starts).clamp(min=0) - cp_cu_seq_lens_cpu = torch.zeros(num_chunks, - num_prefills + 1, - dtype=torch.int32, - pin_memory=True) - torch.cumsum(cp_chunk_seq_lens, - dim=1, - out=cp_cu_seq_lens_cpu[:, 1:], - dtype=torch.int32) + cp_cu_seq_lens_cpu = torch.zeros( + num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True + ) + torch.cumsum( + cp_chunk_seq_lens, + dim=1, + out=cp_cu_seq_lens_cpu[:, 1:], + dtype=torch.int32, + ) - chunked_context_metadata_cls = \ - CudnnPrefillMetadata.ChunkedContextMetadata \ - if self._use_cudnn_prefill else \ - MLACommonPrefillMetadata.ChunkedContextMetadata + chunked_context_metadata_cls = ( + CudnnPrefillMetadata.ChunkedContextMetadata + if self._use_cudnn_prefill + else MLACommonPrefillMetadata.ChunkedContextMetadata + ) if self.dcp_world_size > 1: - chunked_context_metadata = \ - chunked_context_metadata_cls( - cu_seq_lens=cu_seq_lens_cpu \ - .to(device, non_blocking=True), + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), starts=cp_chunk_starts.to(device, non_blocking=True), seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), @@ -797,16 +836,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): workspace=self.chunked_prefill_workspace, cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(), origin_context_lens=origin_context_lens, - cp_cu_seq_lens=cp_cu_seq_lens_cpu \ - .to(device, non_blocking=True), + cp_cu_seq_lens=cp_cu_seq_lens_cpu.to(device, non_blocking=True), chunk_size=max_context_chunk, cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), ) else: - chunked_context_metadata = \ - chunked_context_metadata_cls( - cu_seq_lens=cu_seq_lens_cpu \ - .to(device, non_blocking=True), + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), starts=chunk_starts.to(device, non_blocking=True), seq_tot=chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), @@ -817,8 +853,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self._use_cudnn_prefill: chunked_context_metadata.seq_lens = chunk_seq_lens - assert max(chunked_context_metadata.max_seq_lens) <= \ - self.chunked_prefill_workspace_size + assert ( + max(chunked_context_metadata.max_seq_lens) + <= self.chunked_prefill_workspace_size + ) prefill_metadata = self.prefill_metadata_cls( block_table=block_table_tensor[reqs_start:, ...], @@ -829,8 +867,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self._use_cudnn_prefill: assert isinstance(prefill_metadata, CudnnPrefillMetadata) - prefill_metadata.query_seq_lens = prefill_query_start_loc[1:] \ - - prefill_query_start_loc[:-1] + prefill_metadata.query_seq_lens = ( + prefill_query_start_loc[1:] - 
prefill_query_start_loc[:-1] + ) prefill_metadata.cudnn_workspace = self.cudnn_workspace decode_metadata = None @@ -839,8 +878,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): block_table_tensor=block_table_tensor[:num_decodes, ...], seq_lens_cpu=seq_lens_cpu[:num_decodes], seq_lens_device=seq_lens[:num_decodes], - query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1], - query_start_loc_device=query_start_loc[:num_decodes + 1], + query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], + query_start_loc_device=query_start_loc[: num_decodes + 1], num_decode_tokens=num_decode_tokens, ) @@ -897,12 +936,14 @@ def reorg_kvcache( k_pe_segments = [] src_token_idx = 0 max_seq_len_check = 0 - for cp_chunk_seq_len, origin_context_len in zip(cp_chunk_seq_lens_lst, - origin_context_lens): + for cp_chunk_seq_len, origin_context_len in zip( + cp_chunk_seq_lens_lst, origin_context_lens + ): chunk_context_len = chunk_size if cp_chunk_seq_len != 0: chunk_context_len = min( - chunk_context_len, origin_context_len - chunk_size * chunk_idx) + chunk_context_len, origin_context_len - chunk_size * chunk_idx + ) cp_target_rank = (chunk_context_len - 1) % cp_world_size cur_seq_len = 0 for rank in range(cp_world_size): @@ -911,14 +952,16 @@ def reorg_kvcache( else: real_cp_chunk_seq_len = cp_chunk_seq_len if real_cp_chunk_seq_len: - kv_c_segment = allgatered_kv_c_normed[rank * toks + - src_token_idx:rank * - toks + src_token_idx + - real_cp_chunk_seq_len] - k_pe_segment = allgatered_k_pe[rank * toks + - src_token_idx:rank * toks + - src_token_idx + - real_cp_chunk_seq_len] + kv_c_segment = allgatered_kv_c_normed[ + rank * toks + src_token_idx : rank * toks + + src_token_idx + + real_cp_chunk_seq_len + ] + k_pe_segment = allgatered_k_pe[ + rank * toks + src_token_idx : rank * toks + + src_token_idx + + real_cp_chunk_seq_len + ] kv_c_segments.append(kv_c_segment) k_pe_segments.append(k_pe_segment) cur_seq_len += real_cp_chunk_seq_len @@ -983,25 +1026,24 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): self.q_pad_num_heads = q_pad_num_heads def process_weights_after_loading(self, act_dtype: torch.dtype): - def get_layer_weight(layer): WEIGHT_NAMES = ("weight", "qweight", "weight_packed") for attr in WEIGHT_NAMES: if hasattr(layer, attr): return getattr(layer, attr) raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") + f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}." 
+ ) def get_and_maybe_dequant_weights(layer: LinearBase): if not isinstance(layer.quant_method, UnquantizedLinearMethod): # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) + eye = torch.eye( + layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device, + ) + dequant_weights = layer.quant_method.apply(layer, eye, bias=None) del eye # standardize to (output, input) return dequant_weights.T @@ -1013,12 +1055,14 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + ), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}" + ) kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_heads, @@ -1026,15 +1070,18 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): ) W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) if is_rocm_aiter_fp8bmm_enabled(): W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( - W_K, dtype=current_platform.fp8_dtype()) + W_K, dtype=current_platform.fp8_dtype() + ) self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( - W_V, dtype=current_platform.fp8_dtype()) + W_V, dtype=current_platform.fp8_dtype() + ) # The kernel operates on non-padded inputs. 
Hence, pre-compiling # triton kernel to avoid runtime compilation for unseen batch sizes @@ -1050,23 +1097,23 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): ) for m in pre_compilation_list: - x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), - dtype=torch.bfloat16, - device=self.W_K.device) - aiter_triton_fp8_bmm(x, - self.W_K, - self.W_K_scale, - group_size=128, - transpose_bm=True) + x = torch.empty( + (self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device, + ) + aiter_triton_fp8_bmm( + x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True + ) - x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), - dtype=torch.bfloat16, - device=self.W_V.device) - aiter_triton_fp8_bmm(x, - self.W_V, - self.W_V_scale, - group_size=128, - transpose_bm=True) + x = torch.empty( + (self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device, + ) + aiter_triton_fp8_bmm( + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + ) else: # Convert from (L, N, V) to (N, L, V) self.W_UV = W_UV.transpose(0, 1) @@ -1078,11 +1125,9 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) if is_rocm_aiter_fp8bmm_enabled(): # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) - x = aiter_triton_fp8_bmm(x, - self.W_V, - self.W_V_scale, - group_size=128, - transpose_bm=True) + x = aiter_triton_fp8_bmm( + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + ) # Convert from (B, N, V) to (B, N * V) x = x.reshape(-1, self.num_heads * self.v_head_dim) # Copy result @@ -1095,8 +1140,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" # Convert from (N, B, V) to (B, N * V) - out_new = out.transpose(0, 1).reshape( - -1, self.num_heads * self.v_head_dim) + out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) # Adjust output buffer shape back to the original (B, N * V) N, B, V = out.shape @@ -1120,8 +1164,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): self._pad_v = False elif use_cudnn_prefill(): logger.debug_once("Using CUDNN prefill for MLA") - self._run_prefill_context_chunk = \ - self._run_prefill_context_chunk_cudnn + self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn self._pad_v = False else: # Use FlashAttention @@ -1136,9 +1179,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): self.flash_attn_varlen_func = flash_attn_varlen_func self.vllm_flash_attn_version = get_flash_attn_version() if self.vllm_flash_attn_version is not None: - self.flash_attn_varlen_func = \ - functools.partial(flash_attn_varlen_func, - fa_version=self.vllm_flash_attn_version) + self.flash_attn_varlen_func = functools.partial( + flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version + ) # For MLA the v head dim is smaller than qk head dim so we pad out # v with 0s to match the qk head dim for attention backends that do @@ -1146,25 +1189,25 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): # We don't need to pad V if we are on a hopper system with FA3 self._pad_v = self.vllm_flash_attn_version is None or not ( self.vllm_flash_attn_version == 3 - and current_platform.get_device_capability()[0] == 9) + and current_platform.get_device_capability()[0] == 9 + ) self.dcp_world_size: Optional[int] = None - 
self.chunked_prefill_workspace_size = \ + self.chunked_prefill_workspace_size = ( MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( - get_current_vllm_config()) + get_current_vllm_config() + ) + ) - def _flash_attn_varlen_diff_headdims(self, - q, - k, - v, - return_softmax_lse=False, - softmax_scale=None, - **kwargs): + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): maybe_padded_v = v if self._pad_v: maybe_padded_v = torch.nn.functional.pad( - v, [0, q.shape[-1] - v.shape[-1]], value=0) + v, [0, q.shape[-1] - v.shape[-1]], value=0 + ) if is_vllm_fa: kwargs["return_softmax_lse"] = return_softmax_lse @@ -1192,8 +1235,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): return attn_out, lse return attn_out - def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q, - k, v, return_softmax_lse): + def _run_prefill_new_tokens_fa( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): return self._flash_attn_varlen_diff_headdims( q=q, k=k, @@ -1207,8 +1251,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): return_softmax_lse=return_softmax_lse, ) - def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q, - k, v, return_softmax_lse): + def _run_prefill_new_tokens_fi( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): assert isinstance(prefill, FlashInferPrefillMetadata) assert prefill.prefill_main is not None ret = prefill.prefill_main.run( @@ -1223,8 +1268,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): return ret[0], ret[1].transpose(0, 1).contiguous() return ret - def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata, - q, k, v, return_softmax_lse): + def _run_prefill_new_tokens_cudnn( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): assert isinstance(prefill, CudnnPrefillMetadata) assert prefill.query_seq_lens is not None output, lse = cudnn_batch_prefill_with_kv_cache( @@ -1239,15 +1285,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1), causal=True, return_lse=True, # do not support False for now - is_cuda_graph_compatible= - True, #Indicates actual_seq_lens are on GPU or CPU. + is_cuda_graph_compatible=True, # Indicates actual_seq_lens are on GPU or CPU. 
) if return_softmax_lse: return output, lse return output - def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata, - chunk_idx: int, q, k, v): + def _run_prefill_context_chunk_fa( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): assert prefill.chunked_context is not None return self._flash_attn_varlen_diff_headdims( q=q, @@ -1262,8 +1308,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): return_softmax_lse=True, ) - def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata, - chunk_idx: int, q, k, v): + def _run_prefill_context_chunk_fi( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): assert isinstance(prefill, FlashInferPrefillMetadata) attn_out, lse = prefill.prefill_chunks[chunk_idx].run( q=q, @@ -1274,9 +1321,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): # Convert from (q_len, num_heads) to (num_heads, q_len) return attn_out, lse.transpose(0, 1).contiguous() - def _run_prefill_context_chunk_cudnn(self, - prefill: MLACommonPrefillMetadata, - chunk_idx: int, q, k, v): + def _run_prefill_context_chunk_cudnn( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): assert isinstance(prefill, CudnnPrefillMetadata) assert prefill.chunked_context is not None assert prefill.chunked_context.seq_lens[chunk_idx] is not None @@ -1290,34 +1337,33 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): max_token_per_sequence=prefill.max_query_len, max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx], actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), - actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx]. - view(-1, 1, 1, 1), + actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].view( + -1, 1, 1, 1 + ), causal=False, return_lse=True, - is_cuda_graph_compatible= - True, #Indicates actual_seq_lens are on GPU or CPU. + is_cuda_graph_compatible=True, # Indicates actual_seq_lens are on GPU or CPU. ) def process_weights_after_loading(self, act_dtype: torch.dtype): - def get_layer_weight(layer): WEIGHT_NAMES = ("weight", "qweight", "weight_packed") for attr in WEIGHT_NAMES: if hasattr(layer, attr): return getattr(layer, attr) raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") + f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}." 
+ ) def get_and_maybe_dequant_weights(layer: LinearBase): if not isinstance(layer.quant_method, UnquantizedLinearMethod): # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) + eye = torch.eye( + layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device, + ) + dequant_weights = layer.quant_method.apply(layer, eye, bias=None) del eye # standardize to (output, input) return dequant_weights.T @@ -1329,12 +1375,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + ), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}" + ) kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_heads, @@ -1342,15 +1390,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ) W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) if is_rocm_aiter_fp8bmm_enabled(): W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( - W_K, dtype=current_platform.fp8_dtype()) + W_K, dtype=current_platform.fp8_dtype() + ) self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( - W_V, dtype=current_platform.fp8_dtype()) + W_V, dtype=current_platform.fp8_dtype() + ) # The kernel operates on non-padded inputs. 
Hence, pre-compiling # triton kernel to avoid runtime compilation for unseen batch sizes @@ -1366,23 +1417,23 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ) for m in pre_compilation_list: - x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), - dtype=torch.bfloat16, - device=self.W_K.device) - aiter_triton_fp8_bmm(x, - self.W_K, - self.W_K_scale, - group_size=128, - transpose_bm=True) + x = torch.empty( + (self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device, + ) + aiter_triton_fp8_bmm( + x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True + ) - x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), - dtype=torch.bfloat16, - device=self.W_V.device) - aiter_triton_fp8_bmm(x, - self.W_V, - self.W_V_scale, - group_size=128, - transpose_bm=True) + x = torch.empty( + (self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device, + ) + aiter_triton_fp8_bmm( + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + ) else: # Convert from (L, N, V) to (N, L, V) self.W_UV = W_UV.transpose(0, 1) @@ -1418,18 +1469,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): seq_starts=prefill_metadata.chunked_context.starts[i], ) - kv_c_normed = workspace[:toks]\ - [..., :self.kv_lora_rank] - k_pe = workspace[:toks]\ - [..., self.kv_lora_rank:].unsqueeze(1) + kv_c_normed = workspace[:toks][..., : self.kv_lora_rank] + k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -1494,44 +1542,45 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): # |------- N tokens --------|--------- N*dcp_size tokens ----------| # |<- use for loca_gather ->|<--------- use for allgather -------->| allgather_offset = workspace.shape[0] // (dcp_world_size + 1) - assert allgather_offset * (dcp_world_size + - 1) == workspace.shape[0] + assert allgather_offset * (dcp_world_size + 1) == workspace.shape[0] assert toks <= allgather_offset local_gathered_kvcache = workspace[:toks] cur_allgather_workspace = workspace[ - allgather_offset:allgather_offset * (1 + dcp_world_size)] + allgather_offset : allgather_offset * (1 + dcp_world_size) + ] assert toks * dcp_world_size <= cur_allgather_workspace.shape[0] - cur_allgather_kvcache = cur_allgather_workspace[:toks * - dcp_world_size] - cur_allgather_kvcache.copy_(get_dcp_group().all_gather( - local_gathered_kvcache, dim=0)) - assert cur_allgather_kvcache.shape[ - -1] == self.kv_lora_rank + self.qk_rope_head_dim - allgatered_kv_c_normed, allgatered_k_pe = \ - cur_allgather_kvcache.unsqueeze( - 1).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size] + cur_allgather_kvcache.copy_( + get_dcp_group().all_gather(local_gathered_kvcache, dim=0) + ) + assert ( + cur_allgather_kvcache.shape[-1] + == self.kv_lora_rank + self.qk_rope_head_dim + ) + allgatered_kv_c_normed, 
allgatered_k_pe = cur_allgather_kvcache.unsqueeze( + 1 + ).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed, k_pe = reorg_kvcache( allgatered_kv_c_normed, allgatered_k_pe, - cp_chunk_seq_lens_lst=prefill_metadata.chunked_context. - cp_chunk_seq_lens[i], - origin_context_lens=prefill_metadata.chunked_context. - origin_context_lens, + cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[ + i + ], + origin_context_lens=prefill_metadata.chunked_context.origin_context_lens, cp_world_size=dcp_world_size, - sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i] - [-1], + sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1], max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], chunk_size=prefill_metadata.chunked_context.chunk_size, chunk_idx=i, - toks=toks) + toks=toks, + ) - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -1574,10 +1623,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): assert self.dcp_world_size is not None has_context = attn_metadata.prefill.chunked_context is not None - kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) @@ -1592,14 +1641,19 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): if has_context: suffix_output, suffix_lse = output if self.dcp_world_size > 1: - context_output, context_lse = \ + context_output, context_lse = ( self._context_parallel_compute_prefill_context( - q, kv_c_and_k_pe_cache, attn_metadata, - k_scale=None, dcp_world_size=self.dcp_world_size) + q, + kv_c_and_k_pe_cache, + attn_metadata, + k_scale=None, + dcp_world_size=self.dcp_world_size, + ) + ) else: - context_output, context_lse = \ - self._compute_prefill_context( - q, kv_c_and_k_pe_cache, attn_metadata, k_scale) + context_output, context_lse = self._compute_prefill_context( + q, kv_c_and_k_pe_cache, attn_metadata, k_scale + ) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1612,7 +1666,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): # unpad if necessary if self._pad_v: - output = output[..., :v.shape[-1]] + output = output[..., : v.shape[-1]] return output.flatten(start_dim=-2) @@ -1642,16 +1696,19 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for MLACommonImpl") + "fused output quantization is not yet supported for MLACommonImpl" + ) if attn_metadata is None: # During the profile run try to simulate the worst case output size # for
`self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` # since this can be large _ = torch.empty( - (self.chunked_prefill_workspace_size, self.num_heads, - self.qk_nope_head_dim + self.v_head_dim), + ( + self.chunked_prefill_workspace_size, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ), device=k_c_normed.device, dtype=k_c_normed.dtype, ) @@ -1675,9 +1732,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] - assert attn_metadata.num_decodes is not None and \ - attn_metadata.num_prefills is not None and \ - attn_metadata.num_decode_tokens is not None + assert ( + attn_metadata.num_decodes is not None + and attn_metadata.num_prefills is not None + and attn_metadata.num_decode_tokens is not None + ) has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 @@ -1705,39 +1764,47 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): if has_prefill: output[num_decode_tokens:] = self._forward_prefill( - prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata, layer._k_scale) + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + kv_cache, + attn_metadata, + layer._k_scale, + ) if has_decode: assert attn_metadata.decode is not None decode_q_nope, decode_q_pe = decode_q.split( - [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) # Convert from (B, N, P) to (N, B, P) decode_q_nope = decode_q_nope.transpose(0, 1) # Pads the head_dim if necessary (for the underlying kernel) if self.q_pad_num_heads is not None: B, N, L = decode_q_pe.shape - decode_pe_padded = decode_q_pe.new_empty( - (B, self.q_pad_num_heads, L)) + decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) decode_pe_padded.resize_((B, N, L)) decode_pe_padded.copy_(decode_q_pe) decode_q_pe = decode_pe_padded if is_rocm_aiter_fp8bmm_enabled(): # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) - decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope, - self.W_K, - self.W_K_scale, - group_size=128, - transpose_bm=True) + decode_ql_nope = aiter_triton_fp8_bmm( + decode_q_nope, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True, + ) else: # Pads the head_dim if necessary (for the underlying kernel) N, B, P = decode_q_nope.shape _, _, L = self.W_UK_T.shape if self.q_pad_num_heads is not None: decode_ql_nope = decode_q_nope.new_empty( - (self.q_pad_num_heads, B, L)) + (self.q_pad_num_heads, B, L) + ) decode_ql_nope.resize_((N, B, L)) else: @@ -1751,15 +1818,17 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): if fp8_attention: ql_nope_shape = decode_ql_nope.shape decode_ql_nope, _ = ops.scaled_fp8_quant( - decode_ql_nope.reshape([ - ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2] - ]), layer._q_scale) + decode_ql_nope.reshape( + [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]] + ), + layer._q_scale, + ) decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape) q_pe_shape = decode_q_pe.shape decode_q_pe, _ = ops.scaled_fp8_quant( - decode_q_pe.reshape( - [q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), - layer._q_scale) + decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), + layer._q_scale, + ) decode_q_pe = decode_q_pe.reshape(q_pe_shape) decode_q = (decode_ql_nope, decode_q_pe) @@ -1771,8 +1840,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): decode_q = get_dcp_group().all_gather(decode_q, dim=1) # call decode attn - attn_out, lse = 
self._forward_decode(decode_q, kv_cache, - attn_metadata, layer) + attn_out, lse = self._forward_decode( + decode_q, kv_cache, attn_metadata, layer + ) # correct dcp attn_out with lse. if self.dcp_world_size > 1: diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index d44e20f2cb..a3c677ca21 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -7,13 +7,18 @@ from typing import ClassVar, Optional, Union import torch import vllm._custom_ops as ops -from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import ( + AttentionLayer, + AttentionType, + is_quantized_kv_cache, +) from vllm.logger import init_logger -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, +) from vllm.v1.attention.backends.utils import AttentionCGSupport logger = init_logger(__name__) @@ -21,12 +26,12 @@ logger = init_logger(__name__) class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): # enable full CUDA Graph support for decode-only capture - cudagraph_support: ClassVar[ - AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + cudagraph_support: ClassVar[AttentionCGSupport] = ( + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) class CutlassMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "CUTLASS_MLA" @@ -41,11 +46,10 @@ class CutlassMLABackend(MLACommonBackend): class SM100Workspace: - def __init__(self, initial_workspace_size): - self._workspace_buf = torch.empty(initial_workspace_size, - device="cuda", - dtype=torch.uint8) + self._workspace_buf = torch.empty( + initial_workspace_size, device="cuda", dtype=torch.uint8 + ) self._block_size = 128 # Forced to 128 @@ -57,8 +61,7 @@ class SM100Workspace: def get_buf(self): return self._workspace_buf - def ensure_size(self, attn_metadata: MLACommonMetadata, - num_kv_splits: int): + def ensure_size(self, attn_metadata: MLACommonMetadata, num_kv_splits: int): batch_size = attn_metadata.num_reqs max_seq_len = attn_metadata.max_query_len @@ -66,7 +69,8 @@ class SM100Workspace: max_seq_len * self._block_size, batch_size, self._sm_count, - num_kv_splits=num_kv_splits) + num_kv_splits=num_kv_splits, + ) if self._workspace_buf.shape[0] < workspace_size: self._workspace_buf.resize_(workspace_size) @@ -81,51 +85,56 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): can_return_lse_for_decode: bool = True def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, - head_size, - scale, - num_kv_heads, - alibi_slopes, - sliding_window, - kv_cache_dtype, - logits_soft_cap, - attn_type, - kv_sharing_target_layer_name, - q_pad_num_heads=MAX_HEADS, - **mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str,
+ kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + q_pad_num_heads=MAX_HEADS, + **mla_args, + ) unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "CutlassMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "CutlassMLAImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "CutlassMLAImpl" + ) # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging # issues. In case the code hangs, use: # FORCE_NUM_KV_SPLITS=1 force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None) if force_num_kv_splits: - logger.warning_once("Forcing num_kv_splits to %d", - int(force_num_kv_splits)) + logger.warning_once("Forcing num_kv_splits to %d", int(force_num_kv_splits)) self._num_kv_splits = int(force_num_kv_splits) else: self._num_kv_splits = -1 # => Auto-detect @@ -144,14 +153,13 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): sm_scale: float, num_kv_splits: int, ) -> tuple[torch.Tensor, torch.Tensor]: - assert (q_nope.ndim == 3 - ), f"q_nope must be a 3D tensor, but got {q_nope.ndim}" - assert ( - q_pe.ndim == 3), f"q_pe must be a 3D tensor, but got {q_pe.ndim}" - assert ( - kv_c_and_k_pe_cache.ndim == 3 - ), "kv_c_and_k_pe_cache must be a 3D tensor, but got {}".format( - kv_c_and_k_pe_cache.ndim) + assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}" + assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}" + assert kv_c_and_k_pe_cache.ndim == 3, ( + "kv_c_and_k_pe_cache must be a 3D tensor, but got {}".format( + kv_c_and_k_pe_cache.ndim + ) + ) B_q, H, D_q_nope = q_nope.shape B_q_2, H_2, D_q_pe = q_pe.shape @@ -171,28 +179,31 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): assert len(page_table.shape) == 2 B_block_table, block_num = page_table.shape assert B_block_table == B_q - assert (block_num - > 0), f"block num must be greater than 0, got {block_num}" + assert block_num > 0, f"block num must be greater than 0, got {block_num}" assert block_num % (128 / PAGE_SIZE) == 0 - assert q_nope.dtype in ( - torch.float16, torch.bfloat16, torch.float8_e4m3fn), ( - f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got " - f"{q_nope.dtype}.") + assert q_nope.dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fn), ( + f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got {q_nope.dtype}." + ) assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype - assert ( - seq_lens.dtype == torch.int32 - ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}." - assert ( - page_table.dtype == torch.int32 - ), f"page_table.dtype needs to be int32 but got {page_table.dtype}." + assert seq_lens.dtype == torch.int32, ( + f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}." + ) + assert page_table.dtype == torch.int32, ( + f"page_table.dtype needs to be int32 but got {page_table.dtype}." 
+ ) - dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype) - else q_nope.dtype) + dtype = ( + torch.bfloat16 + if is_quantized_kv_cache(self.kv_cache_dtype) + else q_nope.dtype + ) out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype) - lse = (torch.empty( - (B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device) - if self.need_to_return_lse_for_decode else torch.Tensor()) + lse = ( + torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device) + if self.need_to_return_lse_for_decode + else torch.Tensor() + ) ops.sm100_cutlass_mla_decode( out, @@ -228,7 +239,8 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): q_nope, q_pe = q else: q_nope, q_pe = torch.split( - q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) # Adjust workspace size (if necessary) self._workspace.ensure_size(attn_metadata, self._num_kv_splits) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 652b1cdb6b..c0c2dbe1f9 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -7,18 +7,25 @@ from typing import ClassVar, Optional, Union import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, - is_quantized_kv_cache) -from vllm.attention.utils.fa_utils import (flash_attn_supports_mla, - get_flash_attn_version) +from vllm.attention.backends.abstract import ( + AttentionLayer, + AttentionType, + is_quantized_kv_cache, +) +from vllm.attention.utils.fa_utils import ( + flash_attn_supports_mla, + get_flash_attn_version, +) from vllm.config import VllmConfig from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonDecodeMetadata, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, +) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata @@ -27,7 +34,6 @@ logger = init_logger(__name__) class FlashAttnMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "FLASH_ATTN_MLA" @@ -59,22 +65,27 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): pass -class FlashAttnMLAMetadataBuilder( - MLACommonMetadataBuilder[FlashAttnMLAMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_BATCH +class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH reorder_batch_threshold: int = 512 - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): - super().__init__(kv_cache_spec, layer_names, vllm_config, device, - FlashAttnMLAMetadata) + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__( + kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata + ) self.max_num_splits = 0 # No upper bound on the number of splits. 
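# (raised below to VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH when full # cuda graphs are enabled, so intermediate buffers are sized at capture time)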
- self.fa_aot_schedule = (get_flash_attn_version() == 3) + self.fa_aot_schedule = get_flash_attn_version() == 3 - self.use_full_cuda_graph = \ + self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) if self.use_full_cuda_graph and self.fa_aot_schedule: self.max_cudagraph_size = self.compilation_config.max_capture_size @@ -83,8 +94,8 @@ class FlashAttnMLAMetadataBuilder( # This condition derives from FA3's internal heuristic. # TODO(woosuk): Support larger cudagraph sizes. raise ValueError( - "Capture size larger than 992 is not supported for " - "full cuda graph.") + "Capture size larger than 992 is not supported for full cuda graph." + ) self.scheduler_metadata = torch.zeros( vllm_config.scheduler_config.max_num_seqs + 1, @@ -94,16 +105,17 @@ class FlashAttnMLAMetadataBuilder( # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. - self.max_num_splits = ( - envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH) + self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH # TODO(lucas): Until we add support for the DCP custom masking we need # to restrict decodes to q_len == 1 when DCP is enabled. - self.reorder_batch_threshold = 1 \ - if get_dcp_group().world_size > 1 else self.reorder_batch_threshold + self.reorder_batch_threshold = ( + 1 if get_dcp_group().world_size > 1 else self.reorder_batch_threshold + ) - def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, - max_seq_len, causal): + def _schedule_decode( + self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + ): if self.fa_aot_schedule: return get_scheduler_metadata( batch_size=num_reqs, @@ -122,13 +134,16 @@ class FlashAttnMLAMetadataBuilder( ) return None - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int) -> FlashAttnMLADecodeMetadata: - query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + ) -> FlashAttnMLADecodeMetadata: + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] max_query_len = query_lens_cpu.max().item() max_seq_len = seq_lens_cpu.max().item() @@ -146,9 +161,10 @@ class FlashAttnMLAMetadataBuilder( if self.use_full_cuda_graph and scheduler_metadata is not None: n = scheduler_metadata.shape[0] # Ensure the persistent buffer is large enough - assert n <= self.scheduler_metadata.shape[0], \ - f"Scheduler metadata size {n} exceeds buffer size " + \ - f"{self.scheduler_metadata.shape[0]}" + assert n <= self.scheduler_metadata.shape[0], ( + f"Scheduler metadata size {n} exceeds buffer size " + + f"{self.scheduler_metadata.shape[0]}" + ) self.scheduler_metadata[:n] = scheduler_metadata # NOTE(woosuk): We should zero out the rest of the scheduler # metadata to guarantee the correctness. 
Otherwise, some thread @@ -179,42 +195,55 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): can_return_lse_for_decode: bool = True def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) - assert flash_attn_supports_mla(), \ - "FlashAttnMLA is not supported on this device" + assert flash_attn_supports_mla(), "FlashAttnMLA is not supported on this device" unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "FlashAttnMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashAttnMLAImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttnMLAImpl" + ) if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( - "FlashAttnMLA V1 with FP8 KV cache not yet supported") + "FlashAttnMLA V1 with FP8 KV cache not yet supported" + ) def _forward_decode( self, @@ -230,14 +259,14 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): q_nope, q_pe = q else: q_nope, q_pe = torch.split( - q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError( - "FP8 FlashAttention MLA not yet supported") + raise NotImplementedError("FP8 FlashAttention MLA not yet supported") - kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] - k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] + kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] + k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :] # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the # kernel uses this to calculate grid dimensions. 
Ensure it's at least 1 diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 701248670f..f0ea1d653c 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -8,9 +8,11 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla from vllm.attention.backends.abstract import AttentionLayer, AttentionType from vllm.logger import init_logger -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, +) logger = init_logger(__name__) @@ -18,7 +20,6 @@ FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 class FlashInferMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "FLASHINFER_MLA" @@ -36,37 +37,49 @@ g_fi_workspace = torch.zeros( class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]): - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "FlashInferMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashInferMLAImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferMLAImpl" + ) self._workspace_buffer = g_fi_workspace self.bmm1_scale: Optional[float] = None @@ -90,8 +103,7 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]): q = q.unsqueeze(1) if self.bmm1_scale is None: - self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float * - self.scale) + self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale if self.bmm2_scale is None: self.bmm2_scale = layer._v_scale_float diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 67c21f83cf..56480832bc 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -7,16 +7,20 @@ from typing import ClassVar, Optional, Union import torch from vllm.attention.backends.abstract import AttentionLayer, AttentionType -from 
vllm.attention.ops.flashmla import (flash_mla_with_kvcache, - get_mla_metadata, - is_flashmla_supported) +from vllm.attention.ops.flashmla import ( + flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_supported, +) from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonDecodeMetadata, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, +) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec @@ -24,7 +28,6 @@ logger = init_logger(__name__) class FlashMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "FLASHMLA" @@ -54,16 +57,22 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_BATCH + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): - super().__init__(kv_cache_spec, layer_names, vllm_config, device, - FlashMLAMetadata) + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__( + kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata + ) self.num_q_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) self.cg_buf_tile_scheduler_metadata = None self.cg_buf_num_splits = None @@ -82,19 +91,22 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): self.cg_buf_num_splits = torch.empty( (vllm_config.scheduler_config.max_num_seqs + 1), device=self.device, - dtype=torch.int32) + dtype=torch.int32, + ) - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int) -> FlashMLADecodeMetadata: - tile_scheduler_metadata, num_splits = \ - get_mla_metadata( + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + ) -> FlashMLADecodeMetadata: + tile_scheduler_metadata, num_splits = get_mla_metadata( seq_lens_device, self.num_q_heads, - 1, # MQA for the decode path + 1, # MQA for the decode path ) # TODO: we can disambiguate between decode and mixed-prefill decode here @@ -107,8 +119,9 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): sm_parts = tile_scheduler_metadata.size(0) # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize) assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0) - tile_scheduler_metadata_view = \ - self.cg_buf_tile_scheduler_metadata[:sm_parts] + tile_scheduler_metadata_view = self.cg_buf_tile_scheduler_metadata[ + :sm_parts + ] tile_scheduler_metadata_view.copy_(tile_scheduler_metadata) tile_scheduler_metadata = tile_scheduler_metadata_view @@ -133,27 +146,36 @@ class 
FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): - can_return_lse_for_decode: bool = True def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) is_supported, reason = is_flashmla_supported() assert is_supported, reason @@ -162,13 +184,16 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): if any(unsupported_features): raise NotImplementedError( "FlashMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashMLAImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl" + ) def _forward_decode( self, @@ -191,8 +216,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): block_table=attn_metadata.decode.block_table, cache_seqlens=attn_metadata.decode.seq_lens, head_dim_v=self.kv_lora_rank, - tile_scheduler_metadata=attn_metadata.decode. 
- tile_scheduler_metadata, + tile_scheduler_metadata=attn_metadata.decode.tile_scheduler_metadata, num_splits=attn_metadata.decode.num_splits, softmax_scale=self.scale, causal=True, diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 36c3c18804..21d67f832b 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -8,21 +8,28 @@ import numpy as np import torch from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, - AttentionMetadata) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionLayer, + AttentionMetadata, +) from vllm.attention.backends.utils import get_mla_dims -from vllm.attention.ops.flashmla import (flash_mla_sparse_prefill, - flash_mla_with_kvcache, - get_mla_metadata) +from vllm.attention.ops.flashmla import ( + flash_mla_sparse_prefill, + flash_mla_with_kvcache, + get_mla_metadata, +) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec if TYPE_CHECKING: @@ -47,11 +54,10 @@ structured as: def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor: # Convert base-2 LSE to natural-log LSE # Keep FP32 for numerical stability during the merge. - return (lse_base2.to(torch.float32) * math.log(2.0)) + return lse_base2.to(torch.float32) * math.log(2.0) class FlashMLASparseBackend(AttentionBackend): - accept_output_buffer: bool = True @staticmethod @@ -113,13 +119,14 @@ class FlashMLASparseDecodeAndContextMetadata: dummy_block_table: torch.Tensor = None def filter_prefill_indices( - self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, indices: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: assert self.prefill_context_lengths is not None prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1) - context_indices = torch.where(indices < prefill_context_lengths, - indices, -1) - new_token_indices = torch.where(indices >= prefill_context_lengths, - indices - prefill_context_lengths, -1) + context_indices = torch.where(indices < prefill_context_lengths, indices, -1) + new_token_indices = torch.where( + indices >= prefill_context_lengths, indices - prefill_context_lengths, -1 + ) return context_indices, new_token_indices @@ -194,8 +201,9 @@ def _convert_req_index_to_global_index_kernel( base = tl.load(bt_ptr, mask=valid_block, other=0) # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset - out_val = tl.where(is_invalid_tok | (~valid_block), -1, - base * BLOCK_SIZE + inblock_off) + out_val = tl.where( + is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off + ) # Store results out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 @@ -203,31 +211,30 @@ def _convert_req_index_to_global_index_kernel( def triton_convert_req_index_to_global_index( - req_id: torch.Tensor, # int32 [num_tokens] - block_table: torch. 
- Tensor,  # int32 [num_requests, max_num_blocks_per_req] - token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS] - BLOCK_SIZE: int = 64, - NUM_TOPK_TOKENS: int = 2048, - BLOCK_N: int = 128,  # tile width along columns + req_id: torch.Tensor,  # int32 [num_tokens] + block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS] + BLOCK_SIZE: int = 64, + NUM_TOPK_TOKENS: int = 2048, + BLOCK_N: int = 128,  # tile width along columns ): """ out[token_id, indice_id] = - block_table[req_id[token_id], + block_table[req_id[token_id], token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + token_indices[token_id, indice_id] % BLOCK_SIZE Only when token_indices[token_id, indice_id] == -1 do we output -1. - For safety, we also output -1 if the derived block_id would be + For safety, we also output -1 if the derived block_id would be out-of-bounds. """ assert req_id.dtype == torch.int32 assert block_table.dtype == torch.int32 assert token_indices.dtype == torch.int32 assert token_indices.shape[1] == NUM_TOPK_TOKENS - assert NUM_TOPK_TOKENS % BLOCK_N == 0, \ - f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by" \ - f"BLOCK_N ({BLOCK_N})" + assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})" + ) num_tokens = req_id.shape[0] num_requests, max_num_blocks_per_req = block_table.shape @@ -268,14 +275,16 @@ def triton_convert_req_index_to_global_index( @dataclass -class FlashMLASparseMetadataBuilder( - AttentionMetadataBuilder[FlashMLASparseMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_BATCH - - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): +class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): cache_config = vllm_config.cache_config self.kv_cache_spec = kv_cache_spec self.model_config = vllm_config.model_config @@ -285,28 +294,27 @@ class FlashMLASparseMetadataBuilder( props = torch.cuda.get_device_properties(device) sm_count = props.multi_processor_count - self.num_heads = self.model_config.get_num_attention_heads( - parallel_config) + self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.topk_tokens = vllm_config.model_config.hf_config.index_topk self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla" - self.topk_tokens_tensor = torch.tensor([self.topk_tokens], - device=device, - dtype=torch.int32) + self.topk_tokens_tensor = torch.tensor( + [self.topk_tokens], device=device, dtype=torch.int32 + ) self.max_model_len_tensor = torch.tensor( - [self.model_config.max_model_len], - device=device, - dtype=torch.int32) + [self.model_config.max_model_len], device=device, dtype=torch.int32 + ) # this is ignored by `flash_mla_with_kvcache` if indices not None - self.dummy_block_table = torch.empty((1, 1), - dtype=torch.int32, - device=self.device) + self.dummy_block_table = torch.empty( + (1, 1), dtype=torch.int32, device=self.device + ) # Equation taken from FlashMLA/csrc/pybind.cpp h_q, h_k = self.num_heads, 1 s_q = 1  # inversely proportional to s_q, so s_q = 1 is the largest
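# ---------------------------------------------------------------------------
# [Editor's sketch, not part of this diff or of vLLM] The docstring above
# defines the mapping from request-local top-k token indices to global slots
# in the paged KV cache. A minimal eager-PyTorch equivalent of that formula is
# shown for illustration only; the helper name is hypothetical, and the Triton
# kernel additionally writes -1 when the derived block id is out of bounds.
def _reference_req_to_global_index(req_id, block_table, token_indices, block_size=64):
    idx = token_indices.long()  # [num_tokens, NUM_TOPK_TOKENS]
    # look up each index's block in the owning request's row of the block table
    base = block_table.long()[req_id.long().unsqueeze(1), idx // block_size]
    out = base * block_size + idx % block_size  # global slot per (token, index)
    return out.masked_fill(token_indices == -1, -1)  # preserve the -1 sentinel
# ---------------------------------------------------------------------------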
max_num_sm_parts = int( - max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)) + max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1) + ) if current_platform.is_device_capability(100): max_num_sm_parts *= 2 self.tile_scheduler_metadata_buffer = torch.empty( @@ -314,34 +322,38 @@ class FlashMLASparseMetadataBuilder( # see: FlashMLA/csrc/params.h (max_num_sm_parts, 8), dtype=torch.int32, - device=device) + device=device, + ) self.num_splits_buffer = torch.empty( # We pack all the tokens into one batch for sparse attention. # Otherwise, we can exceed the sm of `get_mla_metadata`. - ( - 2, ), + (2,), dtype=torch.int32, - device=device) + device=device, + ) self.req_id_per_token_buffer = torch.empty( - (vllm_config.scheduler_config.max_num_batched_tokens, ), + (vllm_config.scheduler_config.max_num_batched_tokens,), dtype=torch.int32, - device=device) - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlashMLASparseMetadata: + device=device, + ) + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashMLASparseMetadata: num_tokens = common_attn_metadata.num_actual_tokens - starts = np.asarray(common_attn_metadata.query_start_loc_cpu, - dtype=np.int32) + starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32) seg_lengths = np.diff(starts) req_id_per_token = np.repeat( - np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths) + np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths + ) # Zero-fill for cudagraphs self.req_id_per_token_buffer.fill_(0) - self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\ - .copy_(torch.from_numpy(req_id_per_token), non_blocking=True) + self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( + torch.from_numpy(req_id_per_token), non_blocking=True + ) req_id_per_token = self.req_id_per_token_buffer[:num_tokens] fp8_extra_metadata = None @@ -357,8 +369,9 @@ class FlashMLASparseMetadataBuilder( num_sm_parts = tile_scheduler_metadata.size(0) # Copy to persistent buffer for full-CG support - tile_scheduler_metadata_buffer = \ - self.tile_scheduler_metadata_buffer[:num_sm_parts] + tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[ + :num_sm_parts + ] tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata) self.num_splits_buffer.copy_(num_splits) @@ -371,7 +384,8 @@ class FlashMLASparseMetadataBuilder( # accidentally mark indices invalid, we will use -1 exclusively # to mark invalid indices cache_lens=self.max_model_len_tensor, - dummy_block_table=self.dummy_block_table) + dummy_block_table=self.dummy_block_table, + ) metadata = FlashMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, @@ -390,62 +404,79 @@ class FlashMLASparseMetadataBuilder( class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - topk_indice_buffer: Optional[torch.Tensor] = None, - indexer: Optional["Indexer"] = None, - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) 
+ self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + topk_indice_buffer: Optional[torch.Tensor] = None, + indexer: Optional["Indexer"] = None, + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) self.softmax_scale = scale assert indexer is not None self.topk_indices_buffer = indexer.topk_indices_buffer - self.padding = 128 if current_platform.is_device_capability( - 100) else 64 + self.padding = 128 if current_platform.is_device_capability(100) else 64 def _forward_bf16_kv( - self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, - attn_metadata: FlashMLASparseMetadata) -> torch.Tensor: + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: num_tokens = q.shape[0] kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( - -1, 1, kv_c_and_k_pe_cache.shape[-1]) + -1, 1, kv_c_and_k_pe_cache.shape[-1] + ) # NOTE(Chen): kernel requires num_local_head to be a multiple of # 64 on hopper and 128 on blackwell if self.num_heads % self.padding != 0: assert self.padding % self.num_heads == 0 - logger.warning_once(f"padding num_heads to {self.padding} \ - due to sparse attn kernel requirement") + logger.warning_once( + f"padding num_heads to {self.padding} \ + due to sparse attn kernel requirement" + ) q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2])) - q_padded[:, :self.num_heads, :] = q + q_padded[:, : self.num_heads, :] = q q = q_padded topk_indices = topk_indices.view(num_tokens, 1, -1) - output = flash_mla_sparse_prefill(q, kv_c_and_k_pe_cache, topk_indices, - self.softmax_scale)[0] - output = output[:, :self.num_heads, :] + output = flash_mla_sparse_prefill( + q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale + )[0] + output = output[:, : self.num_heads, :] return output - def _forward_fp8_kv(self, q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, - attn_metadata: FlashMLASparseMetadata) -> torch.Tensor: - + def _forward_fp8_kv( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: assert attn_metadata.fp8_extra_metadata is not None extra_metadata = attn_metadata.fp8_extra_metadata @@ -483,8 +514,8 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for MLACommonImpl") + "fused output quantization is not yet supported for MLACommonImpl" + ) if attn_metadata is None: # The zero fill is required when used with DP + EP @@ -500,8 +531,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] 
- q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # Convert from (B, N, P) to (N, B, P) q_nope = q_nope.transpose(0, 1) # Multiply (N, B, P) x (N, P, L) -> (N, B, L) @@ -534,11 +564,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ) if self.kv_cache_dtype != "fp8_ds_mla": - attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices_global, - attn_metadata) + attn_out = self._forward_bf16_kv( + q, kv_cache, topk_indices_global, attn_metadata + ) else: - attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global, - attn_metadata) + attn_out = self._forward_fp8_kv( + q, kv_cache, topk_indices_global, attn_metadata + ) self._v_up_proj(attn_out, out=output[:num_actual_toks]) return output diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 94b963f34e..1344840af6 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -5,21 +5,21 @@ from typing import ClassVar, Optional import torch -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) +from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills, +) logger = init_logger(__name__) class DeepseekV32IndexerBackend(AttentionBackend): - @staticmethod def get_metadata_cls() -> type["AttentionMetadata"]: return DeepseekV32IndexerMetadata @@ -76,7 +76,6 @@ class DeepSeekV32IndexerDecodeMetadata: @dataclass class DeepseekV32IndexerMetadata: - # FIXME (zyongye) # hacky way to access the data now, need to be in chunked meta seq_lens: torch.Tensor @@ -104,27 +103,27 @@ class DeepseekV32IndexerMetadata: # TODO (zyongye) optimize this, this is now vibe coded def kv_spans_from_batches( - start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, - device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device +) -> tuple[torch.Tensor, torch.Tensor]: """ Args: - start_seq_loc: 1D long tensor [B+1], cumulative counts of + start_seq_loc: 1D long tensor [B+1], cumulative counts of selected tokens per batch. - Example: [0, 2, 4, 7] -> + Example: [0, 2, 4, 7] -> batch sizes (selected) [2, 2, 3], N=7 tokens total. - seq_len_per_batch: 1D long tensor [B], + seq_len_per_batch: 1D long tensor [B], full sequence length (KV length) of each batch. Example: [5, 9, 4]. Returns: - start_tensor: 1D long tensor [N], start offset in the + start_tensor: 1D long tensor [N], start offset in the concatenated KV cache for each token's batch. - end_location: 1D long tensor [N], + end_location: 1D long tensor [N], **exclusive** end = start + token's local position. (So the attended KV slice is kv[start:end].) 
- Assumes each batch contributes its full `seq_len_per_batch[i]` - keys to the KV cache, andthe selected tokens within a batch + Assumes each batch contributes its full `seq_len_per_batch[i]` + keys to the KV cache, andthe selected tokens within a batch are the **last** `counts[i]` positions of that sequence. """ q = start_seq_loc.to(dtype=torch.long) @@ -138,8 +137,10 @@ def kv_spans_from_batches( B = L.numel() if N == 0: - return (torch.empty(0, dtype=torch.long, device=device), - torch.empty(0, dtype=torch.long, device=device)) + return ( + torch.empty(0, dtype=torch.long, device=device), + torch.empty(0, dtype=torch.long, device=device), + ) # KV start offsets per batch in the concatenated KV cache kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B] @@ -155,8 +156,9 @@ def kv_spans_from_batches( L_expand = torch.repeat_interleave(L, counts) # [N] m_expand = torch.repeat_interleave(counts, counts) # [N] # position within the selected block: 1..counts[b] - pos_within = (torch.arange(N, dtype=torch.long) - - torch.repeat_interleave(q[:-1], counts) + 1) + pos_within = ( + torch.arange(N, dtype=torch.long) - torch.repeat_interleave(q[:-1], counts) + 1 + ) local_pos = L_expand - m_expand + pos_within # [N], 1-based end_location = start_tensor + local_pos # exclusive end @@ -171,9 +173,9 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig): return max_model_len * 2 -def split_prefill_chunks(seq_lens_cpu: torch.Tensor, - max_prefill_buffer_size: int, - reqs_start: int) -> list[tuple[int, int]]: +def split_prefill_chunks( + seq_lens_cpu: torch.Tensor, max_prefill_buffer_size: int, reqs_start: int +) -> list[tuple[int, int]]: """ Split the prefill chunks into a list of tuples of (reqs_start, reqs_end) such that the total sequence length of each chunk is less than the @@ -183,7 +185,7 @@ def split_prefill_chunks(seq_lens_cpu: torch.Tensor, seq_lens_cpu: The sequence lengths of the prefill requests. max_prefill_buffer_size: The maximum prefill buffer size. reqs_start: The start index of the prefill requests. - + Returns: A list of tuples of (reqs_start, reqs_end). """ @@ -203,20 +205,22 @@ def split_prefill_chunks(seq_lens_cpu: torch.Tensor, class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): - cudagraph_support: ClassVar[AttentionCGSupport] = \ + cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) reorder_batch_threshold: int = 1 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) scheduler_config = self.vllm_config.scheduler_config - #NOTE(Chen):an estimated max size of flattened_kv. Need to double check. - self.max_prefill_buffer_size = get_max_prefill_buffer_size( - self.vllm_config) + # NOTE(Chen):an estimated max size of flattened_kv. Need to double check. 
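# ---------------------------------------------------------------------------
# [Editor's note, not part of this diff] Working the docstring's own example
# for kv_spans_from_batches: start_seq_loc = [0, 2, 4, 7] and
# seq_len_per_batch = [5, 9, 4] give counts = [2, 2, 3] selected tokens and
# per-batch KV start offsets [0, 5, 14]. Since the selected tokens are the
# last counts[i] positions of each sequence, the spans work out to
#
#     start_tensor = [0, 0, 5,  5, 14, 14, 14]
#     end_location = [4, 5, 13, 14, 16, 17, 18]   # exclusive ends
#
# so, e.g., the final token of batch 1 attends to kv[5:14], all 9 of its keys.
# ---------------------------------------------------------------------------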
+ self.max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config) self.num_speculative_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config else 0) + if self.vllm_config.speculative_config + else 0 + ) # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2 self.reorder_batch_threshold += min(self.num_speculative_tokens, 1) @@ -225,31 +229,38 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): self.num_sms = sm_count self.decode_lens_buffer = torch.empty( - (scheduler_config.max_num_seqs, ), - dtype=torch.int32, - device=self.device) + (scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device + ) # See: DeepGMM/csrc/apis/attention.hpp - self.scheduler_metadata_buffer = torch.empty((self.num_sms + 1, 2), - dtype=torch.int32, - device=self.device) + self.scheduler_metadata_buffer = torch.empty( + (self.num_sms + 1, 2), dtype=torch.int32, device=self.device + ) - def build_one_prefill_chunk(self, reqs_start, reqs_end, - query_start_loc_cpu, seq_lens_cpu, - block_table): - prefill_query_start_loc = query_start_loc_cpu[ - reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start] + def build_one_prefill_chunk( + self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table + ): + prefill_query_start_loc = ( + query_start_loc_cpu[reqs_start : reqs_end + 1] + - query_start_loc_cpu[reqs_start] + ) cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches( - prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], - self.device) + prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device + ) token_start = query_start_loc_cpu[reqs_start].item() token_end = query_start_loc_cpu[reqs_end].item() total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum() assert total_seq_lens <= self.max_prefill_buffer_size - cu_seq_lens = torch.cat([ - torch.zeros(1, dtype=torch.int32), - seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0) - ]).to(torch.int32).to(self.device) + cu_seq_lens = ( + torch.cat( + [ + torch.zeros(1, dtype=torch.int32), + seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0), + ] + ) + .to(torch.int32) + .to(self.device) + ) return DeepseekV32IndexerPrefillChunkMetadata( cu_seqlen_ks=cu_seqlen_ks, cu_seqlen_ke=cu_seqlen_ke, @@ -261,19 +272,21 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): num_reqs=reqs_end - reqs_start, ) - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> DeepseekV32IndexerMetadata: - + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> DeepseekV32IndexerMetadata: num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=self.reorder_batch_threshold) + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens @@ -287,33 +300,39 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ) chunks = [ self.build_one_prefill_chunk( - reqs_start, reqs_end, query_start_loc_cpu, + reqs_start, + reqs_end, + query_start_loc_cpu, 
common_attn_metadata.seq_lens_cpu, - common_attn_metadata.block_table_tensor) + common_attn_metadata.block_table_tensor, + ) for reqs_start, reqs_end in chunk_seq_ids ] prefill_metadata = DeepseekV32IndexerPrefillMetadata( - chunks=chunks, ) + chunks=chunks, + ) decode_metadata = None if num_decodes > 0: - torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1], - out=self.decode_lens_buffer[:num_decodes]) + torch.diff( + common_attn_metadata.query_start_loc[: num_decodes + 1], + out=self.decode_lens_buffer[:num_decodes], + ) decode_lens = self.decode_lens_buffer[:num_decodes] decode_lens_cpu = torch.diff( - common_attn_metadata.query_start_loc_cpu[:num_decodes + 1]) + common_attn_metadata.query_start_loc_cpu[: num_decodes + 1] + ) # Use CPU to avoid GPU sync; breaking async scheduling - requires_padding = (decode_lens_cpu.max() - > decode_lens_cpu.min()).item() + requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item() seq_lens = common_attn_metadata.seq_lens[:num_decodes] self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( - seq_lens, self.kv_cache_spec.block_size, self.num_sms) + seq_lens, self.kv_cache_spec.block_size, self.num_sms + ) decode_metadata = DeepSeekV32IndexerDecodeMetadata( - block_table=common_attn_metadata. - block_table_tensor[:num_decodes, ...], + block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...], seq_lens=common_attn_metadata.seq_lens[:num_decodes], decode_lens=decode_lens, requires_padding=requires_padding, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 79247e569b..aa9be9119d 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -11,13 +11,16 @@ from vllm.attention.backends.abstract import AttentionLayer from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd from vllm.config import VllmConfig from vllm.utils import cdiv + # yapf conflicts with isort for this docstring # yapf: disable -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonDecodeMetadata, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, +) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec @@ -25,12 +28,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec def is_aiter_mla_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER \ - and envs.VLLM_ROCM_USE_AITER_MLA + return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MLA class AiterMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "ROCM_AITER_MLA" @@ -68,19 +69,28 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): # TODO(luka, lucas): audit this as part of: # https://github.com/vllm-project/vllm/issues/22945 - cudagraph_support: ClassVar[AttentionCGSupport] = \ + cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): - super().__init__(kv_cache_spec, layer_names, vllm_config, device, - AiterMLAMetadata) - assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ - "only 
supports block size 1." + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__( + kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata + ) + assert self.kv_cache_spec.block_size == 1, ( + "AITER MLA only supports block size 1." + ) self.compilation_config = vllm_config.compilation_config - max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len, - self.kv_cache_spec.block_size) + max_num_pages_per_req = cdiv( + vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size + ) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req @@ -89,74 +99,78 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): # so we can only use the persistent buffer if a cudagraph is actually # being used. if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, - dtype=torch.int32, - device=device) - self.paged_kv_indices = torch.zeros(max_num_pages, - dtype=torch.int32, - device=device) - self.paged_kv_last_page_len = torch.zeros(max_num_reqs, - dtype=torch.int32, - device=device) + self.paged_kv_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=device + ) + self.paged_kv_indices = torch.zeros( + max_num_pages, dtype=torch.int32, device=device + ) + self.paged_kv_last_page_len = torch.zeros( + max_num_reqs, dtype=torch.int32, device=device + ) - self.qo_indptr = torch.arange(0, - max_num_reqs + 1, - dtype=torch.int32, - device=device) + self.qo_indptr = torch.arange( + 0, max_num_reqs + 1, dtype=torch.int32, device=device + ) - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int) -> AiterMLADecodeMetadata: + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + ) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size block_table_bounds = (seq_lens_device + page_size - 1) // page_size device = self.device num_reqs = seq_lens_device.size(0) - mask = (torch.arange(block_table_tensor.size(1), - dtype=block_table_tensor.dtype, - device=device).unsqueeze(0) - < block_table_bounds.unsqueeze(1)) + mask = torch.arange( + block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device ).unsqueeze(0) < block_table_bounds.unsqueeze(1) paged_kv_indices = block_table_tensor[mask] paged_kv_last_page_len = seq_lens_device % page_size - paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, - page_size, paged_kv_last_page_len) + paged_kv_last_page_len = torch.where( + paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len ) - paged_kv_indptr = torch.cat([ - torch.zeros(1, dtype=block_table_bounds.dtype, device=device), - block_table_bounds.cumsum(dim=0, dtype=torch.int32) - ]) + paged_kv_indptr = torch.cat( + [ + torch.zeros(1, dtype=block_table_bounds.dtype, device=device), + block_table_bounds.cumsum(dim=0, dtype=torch.int32), + ] + ) if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - num_actual_pages = paged_kv_indices.size(0) - self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices, - non_blocking=True) +
self.paged_kv_indices[:num_actual_pages].copy_( + paged_kv_indices, non_blocking=True + ) self.paged_kv_indices[num_actual_pages:].fill_(-1) paged_kv_indices = self.paged_kv_indices[:num_actual_pages] - self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr, - non_blocking=True) - self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1]) - paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs] + self.paged_kv_indptr[: 1 + num_reqs].copy_( + paged_kv_indptr, non_blocking=True + ) + self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1]) + paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs] self.paged_kv_last_page_len[:num_reqs].copy_( - paged_kv_last_page_len, non_blocking=True) + paged_kv_last_page_len, non_blocking=True + ) self.paged_kv_last_page_len[num_reqs:].fill_(1) paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] - qo_indptr = self.qo_indptr[:1 + num_reqs] + qo_indptr = self.qo_indptr[: 1 + num_reqs] else: - qo_indptr = torch.arange(0, - num_reqs + 1, - step=1, - dtype=torch.int32, - device=device) + qo_indptr = torch.arange( + 0, num_reqs + 1, step=1, dtype=torch.int32, device=device + ) attn_metadata = AiterMLADecodeMetadata( block_table=block_table_tensor, @@ -164,51 +178,60 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, - qo_indptr=qo_indptr) + qo_indptr=qo_indptr, + ) return attn_metadata class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - assert (num_heads == 16 or num_heads == 128), ( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) + assert num_heads == 16 or num_heads == 128, ( f"Aiter MLA only supports 16 or 128 number of heads.\n" f"Provided {num_heads} number of heads.\n" - "Try adjusting tensor_parallel_size value.") + "Try adjusting tensor_parallel_size value." 
+ ) unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "Aiter MLA does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) from aiter import flash_attn_varlen_func + self.flash_attn_varlen_func = flash_attn_varlen_func - def _flash_attn_varlen_diff_headdims(self, - q, - k, - v, - return_softmax_lse=False, - softmax_scale=None, - **kwargs): + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): output = self.flash_attn_varlen_func( q=q, k=k, @@ -235,21 +258,25 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): assert isinstance(q, torch.Tensor) B = q.shape[0] - o = torch.zeros(B, - self.num_heads, - self.kv_lora_rank, - dtype=q.dtype, - device=q.device) + o = torch.zeros( + B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device + ) kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) # max_seqlen_qo must be 1 except for MTP # TODO: Find the best value for MTP max_seqlen_qo = 1 - aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, - attn_metadata.decode.qo_indptr, max_seqlen_qo, - attn_metadata.decode.paged_kv_indptr, - attn_metadata.decode.paged_kv_indices, - attn_metadata.decode.paged_kv_last_page_len) + aiter_mla_decode_fwd( + q, + kv_buffer, + o, + self.scale, + attn_metadata.decode.qo_indptr, + max_seqlen_qo, + attn_metadata.decode.paged_kv_indptr, + attn_metadata.decode.paged_kv_indices, + attn_metadata.decode.paged_kv_last_page_len, + ) return o, None diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 076152061d..3b6718c48d 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -6,22 +6,26 @@ from typing import Optional, Union import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import ( + AttentionLayer, + AttentionType, + is_quantized_kv_cache, +) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, +) logger = init_logger(__name__) class TritonMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "TRITON_MLA" @@ -35,54 +39,64 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): can_return_lse_for_decode: bool = True def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + 
sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "TritonMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TritonMLAImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonMLAImpl" + ) if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( - "TritonMLA V1 with FP8 KV cache not yet supported") + "TritonMLA V1 with FP8 KV cache not yet supported" + ) self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN self.triton_fa_func = triton_attention if HAS_TRITON else None - def _flash_attn_varlen_diff_headdims_rocm(self, - q, - k, - v, - softmax_scale=None, - **kwargs): + def _flash_attn_varlen_diff_headdims_rocm( + self, q, k, v, softmax_scale=None, **kwargs + ): assert self.triton_fa_func is not None # Triton Attention requires a padded V - padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) + padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0) # The output of triton_attention is a tuple of # [output_tensor, encoded_softmax] where encoded_softmax is always None output_tensor, _ = self.triton_fa_func( @@ -101,18 +115,17 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): return output_tensor - def _flash_attn_varlen_diff_headdims(self, - q, - k, - v, - return_softmax_lse=False, - softmax_scale=None, - **kwargs): - if current_platform.is_rocm() \ - and self.use_triton_flash_attn \ - and not return_softmax_lse: + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): + if ( + current_platform.is_rocm() + and self.use_triton_flash_attn + and not return_softmax_lse + ): return self._flash_attn_varlen_diff_headdims_rocm( - q, k, v, softmax_scale=softmax_scale, **kwargs) + q, k, v, softmax_scale=softmax_scale, **kwargs + ) else: return super()._flash_attn_varlen_diff_headdims( q, @@ -120,7 +133,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): v, return_softmax_lse=return_softmax_lse, softmax_scale=softmax_scale, - **kwargs) + **kwargs, + ) def _forward_decode( self, @@ -141,11 +155,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): assert isinstance(q, torch.Tensor) B = q.shape[0] q_num_heads = q.shape[1] - o = torch.zeros(B, - q_num_heads, - self.kv_lora_rank, - dtype=q.dtype, - device=q.device) + o = torch.zeros( + B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device + ) lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device) num_kv_splits = 4 # TODO: heuristic @@ -165,13 +177,22 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): # Add a head dim of 1 kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) - kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + 
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] PAGE_SIZE = kv_c_and_k_pe_cache.size(1) # Run MQA - decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, lse, - attn_metadata.decode.block_table, - attn_metadata.decode.seq_lens, attn_logits, - num_kv_splits, self.scale, PAGE_SIZE) + decode_attention_fwd( + q, + kv_c_and_k_pe_cache, + kv_c_cache, + o, + lse, + attn_metadata.decode.block_table, + attn_metadata.decode.seq_lens, + attn_logits, + num_kv_splits, + self.scale, + PAGE_SIZE, + ) return o, lse diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 7ac1a063f5..7e83e7a681 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -6,8 +6,12 @@ from typing import Optional import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionLayer, + AttentionType, +) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import cdiv, next_power_of_2 @@ -41,49 +45,62 @@ except ImportError: from torch_xla.experimental.custom_kernel import XLA_LIB @requires_jax - def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, num_slices_per_block: int): + def kv_cache_update_op_impl( + kv: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int, + ): from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update + new_kv_cache = xb.call_jax( kv_cache_update, - (kv, slot_mapping, kv_cache, num_kv_update_slices), { - "page_size": page_size, - "num_slices_per_block": num_slices_per_block - }) + (kv, slot_mapping, kv_cache, num_kv_update_slices), + {"page_size": page_size, "num_slices_per_block": num_slices_per_block}, + ) return new_kv_cache - XLA_LIB.define( - "kv_cache_update_op(Tensor kv, Tensor slot_mapping," \ - "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \ - "int num_slices_per_block)" \ - "-> Tensor", ) + "kv_cache_update_op(Tensor kv, Tensor slot_mapping," + "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," + "int num_slices_per_block)" + "-> Tensor", + ) @impl(XLA_LIB, "kv_cache_update_op", "XLA") - def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int) -> torch.Tensor: - new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, - num_kv_update_slices, page_size, - num_slices_per_block) + def kv_cache_update_op_xla( + kv: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int, + ) -> torch.Tensor: + new_kv_cache = kv_cache_update_op_impl( + kv, + slot_mapping, + kv_cache, + num_kv_update_slices, + page_size, + num_slices_per_block, + ) return new_kv_cache @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") - def kv_cache_update_op_non_xla(kv: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int) -> torch.Tensor: + def kv_cache_update_op_non_xla( + kv: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + 
num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int, + ) -> torch.Tensor: return kv_cache class PallasAttentionBackend(AttentionBackend): - @staticmethod def get_name() -> str: return "PALLAS" @@ -104,8 +121,9 @@ class PallasAttentionBackend(AttentionBackend): head_size: int, cache_dtype_str: str = "auto", ) -> tuple[int, ...]: - padded_head_size = cdiv( - head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + padded_head_size = ( + cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + ) return (num_blocks, block_size, num_kv_heads * 2, padded_head_size) @staticmethod @@ -122,10 +140,12 @@ class PallasAttentionBackend(AttentionBackend): # we simply make sure that the size is smaller than half of SMEM capacity. @staticmethod def get_min_page_size(vllm_config: VllmConfig) -> int: - max_num_page_per_req = (1024 * 1024 // 2 // - vllm_config.scheduler_config.max_num_seqs // 4) - min_page_size = cdiv(vllm_config.model_config.max_model_len, - max_num_page_per_req) + max_num_page_per_req = ( + 1024 * 1024 // 2 // vllm_config.scheduler_config.max_num_seqs // 4 + ) + min_page_size = cdiv( + vllm_config.model_config.max_model_len, max_num_page_per_req + ) min_page_size = 1 << (min_page_size - 1).bit_length() return min_page_size @@ -146,8 +166,7 @@ class PallasAttentionBackend(AttentionBackend): # handle VREG spills. if vllm_config.model_config.max_model_len > 8192: return 16 - page_size = next_power_of_2( - vllm_config.model_config.max_model_len) // 16 + page_size = next_power_of_2(vllm_config.model_config.max_model_len) // 16 if page_size <= 16: return 16 if page_size >= 256: @@ -176,7 +195,6 @@ class PallasMetadata: class PallasAttentionBackendImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -203,15 +221,18 @@ class PallasAttentionBackendImpl(AttentionImpl): raise NotImplementedError("Alibi slopes is not supported.") if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "PallasAttentionBackendImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl" + ) self.kv_cache_quantized_dtype = None if kv_cache_dtype != "auto": self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get( - kv_cache_dtype.lower().strip()) + kv_cache_dtype.lower().strip() + ) def forward( self, @@ -240,7 +261,8 @@ class PallasAttentionBackendImpl(AttentionImpl): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" - " for PallasAttentionBackendImpl") + " for PallasAttentionBackendImpl" + ) # For determine_available_memory case. 
if kv_cache.numel() == 0: @@ -253,15 +275,18 @@ class PallasAttentionBackendImpl(AttentionImpl): key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: - padded_head_size = cdiv( - self.head_size, - TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + padded_head_size = ( + cdiv(self.head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + ) query = torch.nn.functional.pad( - query, (0, padded_head_size - self.head_size), value=0.0) + query, (0, padded_head_size - self.head_size), value=0.0 + ) key = torch.nn.functional.pad( - key, (0, padded_head_size - self.head_size), value=0.0) + key, (0, padded_head_size - self.head_size), value=0.0 + ) value = torch.nn.functional.pad( - value, (0, padded_head_size - self.head_size), value=0.0) + value, (0, padded_head_size - self.head_size), value=0.0 + ) if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0: # Write input keys and values to the KV cache. @@ -280,9 +305,9 @@ class PallasAttentionBackendImpl(AttentionImpl): ) if self.kv_cache_quantized_dtype is not None and ( - layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0): - raise ValueError( - "k_scale_float and v_scale_float must be non-zero") + layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0 + ): + raise ValueError("k_scale_float and v_scale_float must be non-zero") output = torch.ops.xla.ragged_paged_attention( query, kv_cache, @@ -305,7 +330,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ) if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: - output = output[:, :, :self.head_size] + output = output[:, :, : self.head_size] return output.reshape(num_tokens, hidden_size) @@ -321,7 +346,7 @@ def write_to_kv_cache( k_scale: float = 1.0, v_scale: float = 1.0, ) -> None: - """ Write the key and values to the KV cache. + """Write the key and values to the KV cache. Args: key: shape = [num_tokens, num_kv_heads, head_size] @@ -330,8 +355,7 @@ def write_to_kv_cache( num_slices_per_kv_cache_update_block: int """ _, page_size, num_combined_kv_heads, head_size = kv_cache.shape - head_size = cdiv(head_size, - TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT if kv_cache_quantized_dtype is not None: dtype_info = torch.finfo(kv_cache_quantized_dtype) @@ -343,15 +367,19 @@ def write_to_kv_cache( value = torch.clamp(value, dtype_info.min, dtype_info.max) value = value.to(kv_cache_quantized_dtype) - kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, - head_size) + kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size) torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True) kv_cache = kv_cache.flatten(0, 1) new_kv_cache = torch.ops.xla.kv_cache_update_op( - kv, slot_mapping, kv_cache, num_kv_update_slices, page_size, - num_slices_per_kv_cache_update_block) + kv, + slot_mapping, + kv_cache, + num_kv_update_slices, + page_size, + num_slices_per_kv_cache_update_block, + ) # NOTE: the in-place copy will be optimized away by XLA compiler. 
kv_cache.copy_(new_kv_cache) @@ -389,15 +417,18 @@ def get_dtype_packing(dtype): if 32 % bits != 0: raise ValueError( f"The bit width must be divisible by 32, but got bits={bits}, " - "dtype={dtype}") + "dtype={dtype}" + ) return 32 // bits -def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int, - kv_cache_dtype: torch.dtype) -> int: +def get_page_size_bytes( + block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype +) -> int: """Returns the size in bytes of one page of the KV cache.""" - padded_head_size = cdiv(head_size, - TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + padded_head_size = ( + cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + ) num_combined_kv_heads = num_kv_heads * 2 # NOTE: for the implicit padding in XLA @@ -405,5 +436,6 @@ def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int, num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing kv_cache_dtype_bits = dtype_bits(kv_cache_dtype) - return (block_size * num_combined_kv_heads * padded_head_size * - kv_cache_dtype_bits // 8) + return ( + block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bits // 8 + ) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index ed63c7b1bd..348eca55ee 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -1,19 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" + from dataclasses import dataclass from typing import Optional import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, +) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec _PARTITION_SIZE_ROCM = 256 @@ -43,55 +50,63 @@ if current_platform.is_rocm(): batch_idx = tl.program_id(0) block_idx = tl.program_id(1) - batch_query_indexes = tl.load(b_query_lens_loc + batch_idx + - tl.arange(0, 2)) + batch_query_indexes = tl.load(b_query_lens_loc + batch_idx + tl.arange(0, 2)) batch_query_start, batch_query_end = tl.split(batch_query_indexes) query_len = batch_query_end - batch_query_start if query_len <= 1: return - batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx + - tl.arange(0, 2)) + batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx + tl.arange(0, 2)) batch_token_start, batch_token_end = tl.split(batch_token_indexes) seq_len = batch_token_end - batch_token_start if block_idx * BLOCK_SIZE < seq_len: - block_mask = (block_idx * BLOCK_SIZE + - tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len + block_mask = ( + block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None] + ) < seq_len - kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 + - block_idx).to(tl.int64) + kv_idx = tl.load( + block_table + batch_idx * block_table_stride_0 + block_idx + ).to(tl.int64) - kv_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange( - 0, 
BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :] - k_vals = tl.load(k_buffer_ptr + kv_buffer_off, - mask=block_mask, - other=0.0) + kv_buffer_off = ( + kv_idx * BLOCK_SIZE * E_DIM + + tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + + tl.arange(0, E_DIM)[None, :] + ) + k_vals = tl.load(k_buffer_ptr + kv_buffer_off, mask=block_mask, other=0.0) if k_vals.dtype.is_fp8(): - k_vals = (k_vals.to(tl.float32) * - tl.load(k_scale)).to(output_dtype) + k_vals = (k_vals.to(tl.float32) * tl.load(k_scale)).to(output_dtype) else: k_vals = k_vals.to(output_dtype) - v_vals = tl.load(v_buffer_ptr + kv_buffer_off, - mask=block_mask, - other=0.0) + v_vals = tl.load(v_buffer_ptr + kv_buffer_off, mask=block_mask, other=0.0) if v_vals.dtype.is_fp8(): - v_vals = (v_vals.to(tl.float32) * - tl.load(v_scale)).to(output_dtype) + v_vals = (v_vals.to(tl.float32) * tl.load(v_scale)).to(output_dtype) else: v_vals = v_vals.to(output_dtype) - kv_values_off = batch_token_start * E_DIM + \ - block_idx * BLOCK_SIZE * E_DIM + \ - tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + \ - tl.arange(0, E_DIM)[None, :] + kv_values_off = ( + batch_token_start * E_DIM + + block_idx * BLOCK_SIZE * E_DIM + + tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + + tl.arange(0, E_DIM)[None, :] + ) tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask) tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask) - def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table, - k_cache, v_cache, max_seq_len, k_scale, v_scale, - output_dtype, total_tokens): + def vllm_layout_trans( + b_query_lens_loc, + b_seq_lens_loc, + block_table, + k_cache, + v_cache, + max_seq_len, + k_scale, + v_scale, + output_dtype, + total_tokens, + ): H_KV = v_cache.shape[2] D = v_cache.shape[3] BLOCK_SIZE = v_cache.shape[1] @@ -107,8 +122,7 @@ if current_platform.is_rocm(): device=v_cache.device, ) - grid = (block_table.shape[0], - (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE) + grid = (block_table.shape[0], (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE) if output_dtype == torch.float16: output_dtype = tl.float16 @@ -117,19 +131,21 @@ if current_platform.is_rocm(): else: raise ValueError(f"Unsupported output dtype: {output_dtype}") - _vllm_layout_trans_kernel[grid](k_cache, - v_cache, - k_values, - v_values, - b_query_lens_loc, - b_seq_lens_loc, - block_table, - block_table.stride(0), - k_scale, - v_scale, - output_dtype=output_dtype, - E_DIM=H_KV * D, - BLOCK_SIZE=BLOCK_SIZE) + _vllm_layout_trans_kernel[grid]( + k_cache, + v_cache, + k_values, + v_values, + b_query_lens_loc, + b_seq_lens_loc, + block_table, + block_table.stride(0), + k_scale, + v_scale, + output_dtype=output_dtype, + E_DIM=H_KV * D, + BLOCK_SIZE=BLOCK_SIZE, + ) return k_values, v_values @@ -152,9 +168,18 @@ if current_platform.is_rocm(): ) -> torch.Tensor: if total_tokens == 0: total_tokens = int(cu_seqlens_k[-1].item()) - k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table, - k_cache, v_cache, max_seqlen_k, k_scale, - v_scale, q.dtype, total_tokens) + k, v = vllm_layout_trans( + cu_seqlens_q, + cu_seqlens_k, + block_table, + k_cache, + v_cache, + max_seqlen_k, + k_scale, + v_scale, + q.dtype, + total_tokens, + ) output = aiter.flash_attn_varlen_func( q=q, @@ -190,16 +215,17 @@ if current_platform.is_rocm(): v_scale: torch.Tensor, total_tokens: int = 0, ) -> torch.Tensor: - return torch.empty(q.shape[0], - q.shape[1], - v_cache.shape[-2], - dtype=q.dtype, - device=q.device) + return torch.empty( + q.shape[0], q.shape[1], v_cache.shape[-2], dtype=q.dtype, device=q.device + ) - 
direct_register_custom_op("flash_attn_varlen_func", - flash_attn_varlen_func_impl, ["out"], - flash_attn_varlen_func_fake, - dispatch_key=current_platform.dispatch_key) + direct_register_custom_op( + "flash_attn_varlen_func", + flash_attn_varlen_func_impl, + ["out"], + flash_attn_varlen_func_fake, + dispatch_key=current_platform.dispatch_key, + ) logger = init_logger(__name__) @@ -231,11 +257,17 @@ class AiterFlashAttentionMetadata: class AiterFlashAttentionMetadataBuilder( - AttentionMetadataBuilder[AiterFlashAttentionMetadata]): + AttentionMetadataBuilder[AiterFlashAttentionMetadata] +): cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.model_config = vllm_config.model_config @@ -243,9 +275,9 @@ class AiterFlashAttentionMetadataBuilder( self.cache_config = vllm_config.cache_config self.num_heads_q = self.model_config.get_num_attention_heads( - self.parallel_config) - self.num_heads_kv = self.model_config.get_num_kv_heads( - self.parallel_config) + self.parallel_config + ) + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size # Sliding window size to be used with the AOT scheduler will be @@ -254,19 +286,22 @@ class AiterFlashAttentionMetadataBuilder( self.total_tokens: int = 0 def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata): - self.total_tokens = self.model_config.max_model_len \ + self, common_attn_metadata: CommonAttentionMetadata + ): + self.total_tokens = ( + self.model_config.max_model_len * self.vllm_config.scheduler_config.max_num_partial_prefills - res = self.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata) + ) + res = self.build(common_prefix_len=0, common_attn_metadata=common_attn_metadata) self.total_tokens = 0 return res - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> 'AiterFlashAttentionMetadata': - + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> "AiterFlashAttentionMetadata": num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len max_seq_len = common_attn_metadata.max_seq_len @@ -277,20 +312,18 @@ class AiterFlashAttentionMetadataBuilder( if max_query_len > 1: # We pre-compute cumulative seq len needed for prefill attention # here to avoid recomputing it for every layer - cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, - dtype=torch.int32, - device=seq_lens.device) - torch.cumsum(seq_lens, - dim=0, - dtype=cu_seq_lens.dtype, - out=cu_seq_lens[1:]) + cu_seq_lens = torch.zeros( + seq_lens.shape[0] + 1, dtype=torch.int32, device=seq_lens.device + ) + torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]) num_actual_kv_tokens = int(cu_seq_lens[-1].item()) else: cu_seq_lens = None num_actual_kv_tokens = 0 - def schedule(batch_size, cu_query_lens, max_query_len, seqlens, - max_seq_len, causal): + def schedule( + batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + ): return None use_cascade = 
common_prefix_len > 0 @@ -316,7 +349,6 @@ class AiterFlashAttentionMetadataBuilder( class AiterFlashAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod @@ -336,7 +368,8 @@ class AiterFlashAttentionBackend(AttentionBackend): f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: @@ -368,7 +401,6 @@ class AiterFlashAttentionBackend(AttentionBackend): class AiterFlashAttentionImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -396,7 +428,7 @@ class AiterFlashAttentionImpl(AttentionImpl): self.kv_cache_dtype = kv_cache_dtype if logits_soft_cap is None: # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - logits_soft_cap = 0. + logits_soft_cap = 0.0 self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name @@ -406,10 +438,12 @@ class AiterFlashAttentionImpl(AttentionImpl): AiterFlashAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashAttentionImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttentionImpl" + ) def forward( self, @@ -442,8 +476,8 @@ class AiterFlashAttentionImpl(AttentionImpl): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashAttentionImpl") + "fused output quantization is not yet supported for FlashAttentionImpl" + ) if attn_metadata is None: # Profiling run. 
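To make the intent of the _vllm_layout_trans_kernel / vllm_layout_trans hunks above easier to follow, here is a rough pure-PyTorch equivalent of the gather they perform. It is only a sketch: the fp8 dequantization branches and the query_len <= 1 early exit for decode-only requests are ignored.

import torch


def layout_trans_reference(cu_seqlens_k, block_table, k_cache, v_cache):
    # k_cache / v_cache: [num_blocks, BLOCK_SIZE, H_KV, D] (paged layout).
    # Returns contiguous [total_tokens, H_KV, D] tensors, token-major.
    block_size = k_cache.shape[1]
    total_tokens = int(cu_seqlens_k[-1])
    k_out = k_cache.new_empty(total_tokens, *k_cache.shape[2:])
    v_out = v_cache.new_empty(total_tokens, *v_cache.shape[2:])
    for req in range(block_table.shape[0]):
        start, end = int(cu_seqlens_k[req]), int(cu_seqlens_k[req + 1])
        for tok in range(end - start):
            blk = int(block_table[req, tok // block_size])
            k_out[start + tok] = k_cache[blk, tok % block_size]
            v_out[start + tok] = v_cache[blk, tok % block_size]
    return k_out, v_out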
@@ -512,13 +546,14 @@ class AiterFlashAttentionImpl(AttentionImpl): _, num_heads, head_size = query.shape nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8 num_seqs = seqused_k.shape[0] - max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM - - 1) // _PARTITION_SIZE_ROCM + max_num_partitions = ( + max_seqlen_k + _PARTITION_SIZE_ROCM - 1 + ) // _PARTITION_SIZE_ROCM workspace_buffer = torch.empty( - (num_seqs * num_heads * max_num_partitions * head_size) * - nbytes_per_qo_elem + 2 * - (num_seqs * num_heads * max_num_partitions) * 4, + (num_seqs * num_heads * max_num_partitions * head_size) + * nbytes_per_qo_elem + + 2 * (num_seqs * num_heads * max_num_partitions) * 4, dtype=torch.uint8, device=output.device, ) @@ -546,4 +581,5 @@ class AiterFlashAttentionImpl(AttentionImpl): return output else: raise NotImplementedError( - "Cascade attention is not implemented for ROCM AITER") + "Cascade attention is not implemented for ROCM AITER" + ) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 1748a48168..4c24770aa2 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with PagedAttention and Triton prefix prefill.""" + from dataclasses import dataclass from functools import cache from typing import ClassVar, Optional @@ -9,20 +10,27 @@ import torch from vllm import _custom_ops as ops from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) -from vllm.attention.ops.chunked_prefill_paged_decode import ( - chunked_prefill_paged_decode) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, +) +from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode from vllm.attention.ops.paged_attn import PagedAttention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym) + QuantKey, + kFp8StaticTensorSym, +) from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) @@ -58,21 +66,25 @@ class RocmAttentionMetadata: prefix_scheduler_metadata: Optional[torch.Tensor] = None -class RocmAttentionMetadataBuilder( - AttentionMetadataBuilder[RocmAttentionMetadata]): +class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.block_size = kv_cache_spec.block_size model_config = vllm_config.model_config self.num_heads_q = model_config.get_num_attention_heads( - 
vllm_config.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) + self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config) self.headdim = model_config.get_head_size() def build_for_cudagraph_capture( @@ -93,10 +105,12 @@ class RocmAttentionMetadataBuilder( return attn_metadata - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> RocmAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> RocmAttentionMetadata: num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -109,14 +123,13 @@ class RocmAttentionMetadataBuilder( use_cascade = common_prefix_len > 0 if use_cascade: - cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) - suffix_kv_lens = (common_attn_metadata.seq_lens_cpu - - common_prefix_len) + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.device + ) + suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len suffix_kv_lens = suffix_kv_lens.to(self.device) else: cu_prefix_query_lens = None @@ -143,7 +156,6 @@ class RocmAttentionMetadataBuilder( class RocmAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod @@ -163,7 +175,8 @@ class RocmAttentionBackend(AttentionBackend): f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: @@ -203,12 +216,10 @@ def use_aiter_unified_attention() -> bool: """Check if aiter unified attention should be used.""" # VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set # to 1 as default - return envs.VLLM_ROCM_USE_AITER \ - and envs.VLLM_USE_AITER_UNIFIED_ATTENTION + return envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION class RocmAttentionImpl(AttentionImpl): - def fused_output_quant_supported(self, quant_key: QuantKey): return quant_key == kFp8StaticTensorSym @@ -249,29 +260,30 @@ class RocmAttentionImpl(AttentionImpl): RocmAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "RocmAttentionImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "RocmAttentionImpl" + ) self.fp8_dtype = current_platform.fp8_dtype() - self.force_prefill_decode_attn = \ - envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION + self.force_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION if not self.force_prefill_decode_attn: # If not using prefill decode attention, we use the Triton # unified attention implementation. 
if use_aiter_unified_attention(): - logger.info_once( - "Using aiter unified attention for RocmAttentionImpl") - from aiter.ops.triton.unified_attention import ( - unified_attention) + logger.info_once("Using aiter unified attention for RocmAttentionImpl") + from aiter.ops.triton.unified_attention import unified_attention + self.unified_attention = unified_attention else: - logger.info_once( - "Using vllm unified attention for RocmAttentionImpl") + logger.info_once("Using vllm unified attention for RocmAttentionImpl") from vllm.attention.ops.triton_unified_attention import ( - unified_attention) + unified_attention, + ) + self.unified_attention = unified_attention self.sinks = sinks @@ -279,7 +291,8 @@ class RocmAttentionImpl(AttentionImpl): assert sinks.shape[0] == num_heads, ( "Sinks must have the same number of heads as the number of " f"heads in the layer. Sinks shape: {sinks.shape}, " - f"num_heads: {num_heads}.") + f"num_heads: {num_heads}." + ) def forward( self, @@ -310,7 +323,8 @@ class RocmAttentionImpl(AttentionImpl): if output_block_scale is not None: raise NotImplementedError( "fused block_scale output quantization is not yet supported" - " for RocmAttentionImpl") + " for RocmAttentionImpl" + ) if attn_metadata is None: # Profiling run. @@ -332,7 +346,8 @@ class RocmAttentionImpl(AttentionImpl): if use_prefill_decode_attn: key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) + kv_cache, self.num_kv_heads, self.head_size + ) else: key_cache, value_cache = kv_cache.unbind(0) @@ -366,16 +381,17 @@ class RocmAttentionImpl(AttentionImpl): key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) num_tokens, num_heads, head_size = query.shape - assert layer._q_scale_float == 1.0, \ + assert layer._q_scale_float == 1.0, ( "A non 1.0 q_scale is not currently supported." + ) if current_platform.is_cuda(): # Skip Q quantization on ROCm and XPU, enable this on cuda # only, since dequantizing back to f32 in the attention kernel # is not supported. 
query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) + query.reshape((num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale, + ) query = query.reshape((num_tokens, num_heads, head_size)) cu_seqlens_q = attn_metadata.query_start_loc @@ -430,6 +446,7 @@ class RocmAttentionImpl(AttentionImpl): k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), sinks=self.sinks, - output_scale=output_scale) + output_scale=output_scale, + ) return output diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index ba0fba4281..74cfecca76 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -6,16 +6,16 @@ from typing import Optional import torch from vllm.attention.backends.abstract import AttentionBackend -from vllm.v1.attention.backends.mamba_attn import ( - BaseMambaAttentionMetadataBuilder) -from vllm.v1.attention.backends.utils import (PAD_SLOT_ID, - CommonAttentionMetadata, - compute_causal_conv1d_metadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder +from vllm.v1.attention.backends.utils import ( + PAD_SLOT_ID, + CommonAttentionMetadata, + compute_causal_conv1d_metadata, + split_decodes_and_prefills, +) class ShortConvAttentionBackend(AttentionBackend): - @staticmethod def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]: return ShortConvAttentionMetadataBuilder @@ -39,12 +39,14 @@ class ShortConvAttentionMetadata: class ShortConvAttentionMetadataBuilder( - BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]): - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> ShortConvAttentionMetadata: + BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata] +): + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> ShortConvAttentionMetadata: num_reqs = common_attn_metadata.num_reqs query_start_loc = common_attn_metadata.query_start_loc state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] @@ -54,28 +56,38 @@ class ShortConvAttentionMetadataBuilder( num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=self.reorder_batch_threshold)) + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) has_initial_states_p = None if num_prefills > 0: has_initial_states_cpu = ( - common_attn_metadata. 
- num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) - has_initial_states_p = has_initial_states_cpu.to( - query_start_loc.device) + common_attn_metadata.num_computed_tokens_cpu[ + num_reqs - num_prefills : num_reqs + ] + > 0 + ) + has_initial_states_p = has_initial_states_cpu.to(query_start_loc.device) - query_start_loc_p = common_attn_metadata.query_start_loc[ - -num_prefills - 1:] - num_decode_tokens + query_start_loc_p = ( + common_attn_metadata.query_start_loc[-num_prefills - 1 :] + - num_decode_tokens + ) - nums_dict, batch_ptr, token_chunk_offset_ptr = \ + nums_dict, batch_ptr, token_chunk_offset_ptr = ( compute_causal_conv1d_metadata(query_start_loc_p) + ) - elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs - and self.compilation_config.full_cuda_graph): + elif ( + num_decodes > 0 + and num_decodes <= self.decode_cudagraph_max_bs + and self.compilation_config.full_cuda_graph + ): num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) - self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor, - non_blocking=True) + self.state_indices_tensor[:num_decodes].copy_( + state_indices_tensor, non_blocking=True + ) state_indices_tensor = self.state_indices_tensor[:num_input_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 583756129a..2a7770c87d 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -8,14 +8,21 @@ from typing import TYPE_CHECKING, Optional import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, +) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) + AttentionMetadataBuilder, + CommonAttentionMetadata, + reorder_batch_to_split_decodes_and_prefills, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec if TYPE_CHECKING: @@ -28,7 +35,6 @@ logger = init_logger(__name__) class TreeAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod @@ -48,7 +54,8 @@ class TreeAttentionBackend(AttentionBackend): f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." 
+ ) @staticmethod def get_name() -> str: @@ -114,9 +121,9 @@ class TreeAttentionMetadata: # metadata structure return self._cached_prefill_metadata - q_start_loc = self.query_start_loc[self.num_decodes:] + q_start_loc = self.query_start_loc[self.num_decodes :] q_seqlens = torch.diff(q_start_loc) - kv_seqlens = self.seq_lens[self.num_decodes:] + kv_seqlens = self.seq_lens[self.num_decodes :] # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = TreeAttentionMetadata( num_actual_tokens=self.num_prefill_tokens, @@ -124,8 +131,8 @@ class TreeAttentionMetadata: query_start_loc=q_start_loc - q_start_loc[0], max_seq_len=int(kv_seqlens.max().item()), seq_lens=kv_seqlens, - block_table=self.block_table[self.num_decodes:], - slot_mapping=self.slot_mapping[self.num_decode_tokens:], + block_table=self.block_table[self.num_decodes :], + slot_mapping=self.slot_mapping[self.num_decode_tokens :], ) return self._cached_prefill_metadata @@ -139,9 +146,9 @@ class TreeAttentionMetadata: # metadata structure return self._cached_decode_metadata - q_start_loc = self.query_start_loc[:self.num_decodes + 1] + q_start_loc = self.query_start_loc[: self.num_decodes + 1] q_seqlens = torch.diff(q_start_loc) - kv_seqlens = self.seq_lens[:self.num_decodes] + kv_seqlens = self.seq_lens[: self.num_decodes] # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = TreeAttentionMetadata( num_actual_tokens=self.num_decode_tokens, @@ -149,16 +156,14 @@ class TreeAttentionMetadata: query_start_loc=q_start_loc, max_seq_len=int(kv_seqlens.max().item()), seq_lens=kv_seqlens, - block_table=self.block_table[:self.num_decodes], - slot_mapping=self.slot_mapping[:self.num_decode_tokens], + block_table=self.block_table[: self.num_decodes], + slot_mapping=self.slot_mapping[: self.num_decode_tokens], tree_attn_bias=self.tree_attn_bias, ) return self._cached_decode_metadata -class TreeAttentionMetadataBuilder( - AttentionMetadataBuilder[TreeAttentionMetadata]): - +class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadata]): def __init__( self, kv_cache_spec: AttentionSpec, @@ -172,10 +177,9 @@ class TreeAttentionMetadataBuilder( spec_config = vllm_config.speculative_config spec_token_tree = (spec := spec_config) and spec.speculative_token_tree - tree_choices: list[tuple[int, - ...]] = (ast.literal_eval(spec_token_tree) - if spec_token_tree is not None else - [(0, )]) + tree_choices: list[tuple[int, ...]] = ( + ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)] + ) # Construct the tree attention bias. 
depth_counts = _get_depth_counts(tree_choices) self.tree_attn_bias = _prepare_tree_attn_bias( @@ -185,12 +189,12 @@ class TreeAttentionMetadataBuilder( device=device, ) - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: + def reorder_batch( + self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput" + ) -> bool: return reorder_batch_to_split_decodes_and_prefills( - input_batch, - scheduler_output, - decode_threshold=self.tree_attn_bias.shape[0]) + input_batch, scheduler_output, decode_threshold=self.tree_attn_bias.shape[0] + ) def build( self, @@ -200,8 +204,10 @@ class TreeAttentionMetadataBuilder( ) -> TreeAttentionMetadata: decode_threshold = self.tree_attn_bias.shape[0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=decode_threshold)) + split_decodes_and_prefills( + common_attn_metadata, decode_threshold=decode_threshold + ) + ) num_actual_tokens = common_attn_metadata.num_actual_tokens q_start_loc = common_attn_metadata.query_start_loc @@ -241,8 +247,7 @@ class TreeAttentionMetadataBuilder( # Slice the tree attention bias for drafting. Exclude # the root level. start, end = 1, 1 + common_attn_metadata.max_query_len - self.tree_attn_bias = self.tree_attn_bias[start:end, - start:end].contiguous() + self.tree_attn_bias = self.tree_attn_bias[start:end, start:end].contiguous() # Build attention bias. attn_metadata = self.build(0, common_attn_metadata, fast_build=True) @@ -273,10 +278,9 @@ def _prepare_tree_attn_bias( ) -> torch.Tensor: # +1 comes from the additional root node. tree_len = len(sorted_tree_choices) + 1 - tree_attn_mask = torch.full((tree_len, tree_len), - -torch.inf, - device=device, - dtype=dtype) + tree_attn_mask = torch.full( + (tree_len, tree_len), -torch.inf, device=device, dtype=dtype + ) # Set diagonal to all zeros. Each token should # attend to itself. @@ -298,14 +302,14 @@ def _prepare_tree_attn_bias( ancestor_idx = [] for c in range(len(cur_tree_choice) - 1): ancestor_idx.append( - sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1) + sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1 + ) tree_attn_mask[j + start + 1, ancestor_idx] = mask_val start += depth_counts[i] return tree_attn_mask class TreeAttentionImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -341,10 +345,12 @@ class TreeAttentionImpl(AttentionImpl): TreeAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TreeAttentionImpl.") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TreeAttentionImpl." + ) def forward( self, @@ -374,8 +380,8 @@ class TreeAttentionImpl(AttentionImpl): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for TreeAttentionImpl") + "fused output quantization is not yet supported for TreeAttentionImpl" + ) if attn_metadata is None: # Profiling run. 
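For orientation on the bias that _prepare_tree_attn_bias above constructs, here is a toy 3-token case (a root plus two independent draft children, i.e. choices [(0,), (1,)]), written from scratch purely for illustration: each position attends to itself and its ancestors, while sibling entries stay at -inf.

import torch

tree_len = 3                         # root + 2 draft tokens
bias = torch.full((tree_len, tree_len), -torch.inf)
bias.fill_diagonal_(0.0)             # every token attends to itself
bias[1:, 0] = 0.0                    # both children attend to their ancestor (the root)
# bias:
# [[0., -inf, -inf],
#  [0.,   0., -inf],
#  [0., -inf,   0.]]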
@@ -404,8 +410,7 @@ class TreeAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens num_decode_tokens = attn_metadata.num_decode_tokens - descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, - key.shape[1]) + descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, key.shape[1]) if prefill_meta := attn_metadata.prefill_metadata: unified_attention( q=query[num_decode_tokens:num_actual_tokens], diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 3983c5edc7..9997ed16be 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -1,24 +1,34 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """High-Performance Triton-only Attention layer.""" + from dataclasses import dataclass from typing import ClassVar, Optional import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, +) from vllm.attention.ops.triton_reshape_and_cache_flash import ( - triton_reshape_and_cache_flash) + triton_reshape_and_cache_flash, +) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym) + QuantKey, + kFp8StaticTensorSym, +) from vllm.platforms import current_platform -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec if current_platform.is_cuda_alike(): @@ -59,21 +69,25 @@ class TritonAttentionMetadata: prefix_scheduler_metadata: Optional[torch.Tensor] = None -class TritonAttentionMetadataBuilder( - AttentionMetadataBuilder[TritonAttentionMetadata]): +class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.block_size = kv_cache_spec.block_size model_config = vllm_config.model_config self.num_heads_q = model_config.get_num_attention_heads( - vllm_config.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) + self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config) self.headdim = model_config.get_head_size() def build_for_cudagraph_capture( @@ -86,10 +100,12 @@ class TritonAttentionMetadataBuilder( attn_metadata.seq_lens.fill_(1) return attn_metadata - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> TritonAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> TritonAttentionMetadata: 
num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -102,14 +118,13 @@ class TritonAttentionMetadataBuilder( use_cascade = common_prefix_len > 0 if use_cascade: - cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) - suffix_kv_lens = (common_attn_metadata.seq_lens_cpu - - common_prefix_len) + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.device + ) + suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len suffix_kv_lens = suffix_kv_lens.to(self.device) else: cu_prefix_query_lens = None @@ -136,7 +151,6 @@ class TritonAttentionMetadataBuilder( class TritonAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod @@ -151,7 +165,8 @@ class TritonAttentionBackend(AttentionBackend): f"Head size {head_size} is not supported by TritonAttention." f"Head sizes need to be larger or equal 32 for this backend. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: @@ -187,7 +202,6 @@ class TritonAttentionBackend(AttentionBackend): class TritonAttentionImpl(AttentionImpl): - def fused_output_quant_supported(self, quant_key: QuantKey): return quant_key == kFp8StaticTensorSym @@ -228,10 +242,12 @@ class TritonAttentionImpl(AttentionImpl): TritonAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TritonAttentionImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonAttentionImpl" + ) self.fp8_dtype = current_platform.fp8_dtype() @@ -240,7 +256,8 @@ class TritonAttentionImpl(AttentionImpl): assert sinks.shape[0] == num_heads, ( "Sinks must have the same number of heads as the number of " f"heads in the layer. Sinks shape: {sinks.shape}, " - f"num_heads: {num_heads}.") + f"num_heads: {num_heads}." + ) def forward( self, @@ -271,7 +288,8 @@ class TritonAttentionImpl(AttentionImpl): if output_block_scale is not None: raise NotImplementedError( "fused block_scale output quantization is not yet supported" - " for TritonAttentionImpl") + " for TritonAttentionImpl" + ) if attn_metadata is None: # Profiling run. @@ -316,16 +334,17 @@ class TritonAttentionImpl(AttentionImpl): key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) num_tokens, num_heads, head_size = query.shape - assert layer._q_scale_float == 1.0, \ + assert layer._q_scale_float == 1.0, ( "A non 1.0 q_scale is not currently supported." + ) if current_platform.is_cuda(): # Skip Q quantization on ROCm and XPU, enable this on cuda # only, since dequantizing back to f32 in the attention kernel # is not supported. 
query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) + query.reshape((num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale, + ) query = query.reshape((num_tokens, num_heads, head_size)) cu_seqlens_q = attn_metadata.query_start_loc diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index f37a829f40..bddb2f22f0 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -5,8 +5,18 @@ import enum import functools from abc import abstractmethod from dataclasses import dataclass, fields, make_dataclass -from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Literal, Optional, - Protocol, TypeVar, Union, get_args) +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + Literal, + Optional, + Protocol, + TypeVar, + Union, + get_args, +) import numpy as np import torch @@ -21,11 +31,11 @@ if TYPE_CHECKING: from vllm.v1.worker.gpu_input_batch import InputBatch import vllm.envs as envs -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) +from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.layer import Attention from vllm.distributed.kv_transfer.kv_connector.utils import ( - get_kv_connector_cache_layout) + get_kv_connector_cache_layout, +) from vllm.logger import init_logger from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.ubatch_utils import UBatchSlice @@ -46,7 +56,7 @@ class CommonAttentionMetadata: """ Per-batch attention metadata, shared across layers and backends. AttentionMetadataBuilder instances use it to construct per-layer metadata. - + For many of the tensors we keep both GPU and CPU versions. """ @@ -89,26 +99,27 @@ def slice_query_start_locs( request_slice: slice, ) -> torch.Tensor: """ - Creates a new query_start_loc that corresponds to the requests in + Creates a new query_start_loc that corresponds to the requests in request_slice. Note: This function creates a new tensor to hold the new query_start_locs. This will break cudagraph compatibility. 
""" - return query_start_loc[request_slice.start: request_slice.stop + 1] -\ - query_start_loc[request_slice.start] + return ( + query_start_loc[request_slice.start : request_slice.stop + 1] + - query_start_loc[request_slice.start] + ) def _make_metadata_with_slice( - ubatch_slice: UBatchSlice, - attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata: + ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata +) -> CommonAttentionMetadata: """ - This function creates a new CommonAttentionMetadata that corresponds to + This function creates a new CommonAttentionMetadata that corresponds to the requests included in ubatch_slice """ - assert not ubatch_slice.is_empty(), ( - f"Ubatch slice {ubatch_slice} is empty") + assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty" request_slice = ubatch_slice.request_slice token_slice = ubatch_slice.token_slice @@ -119,10 +130,12 @@ def _make_metadata_with_slice( last_req = request_slice.stop - 1 last_tok = token_slice.stop - 1 - assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], \ + assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], ( "Token slice start outside of first request" - assert start_locs[last_req] <= last_tok < start_locs[last_req+1], \ + ) + assert start_locs[last_req] <= last_tok < start_locs[last_req + 1], ( "Token slice end outside of last request" + ) # If the "middle" request has tokens in both ubatches, we have to split it. # If ubatch_slice is the first ubatch then we will be splitting the last @@ -132,12 +145,13 @@ def _make_metadata_with_slice( splits_last_request = last_tok < start_locs[last_req + 1] - 1 query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice) - query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc, - request_slice) + query_start_loc = slice_query_start_locs( + attn_metadata.query_start_loc, request_slice + ) assert len(query_start_loc) >= 2, ( - f"query_start_loc must have at least 2 elements, " - f"got {len(query_start_loc)}") + f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}" + ) if splits_first_request: tokens_skipped = first_tok - start_locs[first_req] @@ -159,14 +173,13 @@ def _make_metadata_with_slice( seq_lens_cpu[-1] -= tokens_skipped max_seq_len = int(seq_lens_cpu.max()) - num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[ - request_slice] + num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice] num_requests = request_slice.stop - request_slice.start num_actual_tokens = token_slice.stop - token_slice.start max_query_len = int( - torch.max(torch.abs(query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1])).item()) + torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item() + ) # This is to account for the case where we are in a dummy # run and query_start_loc_cpu is full of 0s @@ -196,15 +209,14 @@ def split_attn_metadata( common_attn_metadata: CommonAttentionMetadata, ) -> list[CommonAttentionMetadata]: """ - Creates a new CommonAttentionMetadata instance that corresponds to the + Creates a new CommonAttentionMetadata instance that corresponds to the requests for each UBatchSlice in ubatch_slices. 
Note: This function does not modify common_attn_metadata """ results = [] for ubatch_slice in ubatch_slices: - results.append( - _make_metadata_with_slice(ubatch_slice, common_attn_metadata)) + results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata)) return results @@ -213,7 +225,7 @@ M = TypeVar("M") class AttentionCGSupport(enum.Enum): - """ Constants for the cudagraph support of the attention backend + """Constants for the cudagraph support of the attention backend Here we do not consider the cascade attention, as currently it is never cudagraph supported.""" @@ -231,46 +243,53 @@ class AttentionCGSupport(enum.Enum): class AttentionMetadataBuilder(abc.ABC, Generic[M]): # Does this backend/builder support CUDA Graphs for attention (default: no). - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.NEVER + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER # Does this backend/builder reorder the batch? # If not, set this to None. Otherwise set it to the query # length that will be pulled into the front of the batch. reorder_batch_threshold: Optional[int] = None @abstractmethod - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): self.kv_cache_spec = kv_cache_spec self.layer_names = layer_names self.vllm_config = vllm_config self.device = device def _init_reorder_batch_threshold( - self, - reorder_batch_threshold: int = 1, - supports_spec_as_decode: bool = False) -> None: + self, reorder_batch_threshold: int = 1, supports_spec_as_decode: bool = False + ) -> None: self.reorder_batch_threshold = reorder_batch_threshold - if self.reorder_batch_threshold is not None \ - and supports_spec_as_decode: + if self.reorder_batch_threshold is not None and supports_spec_as_decode: # If the backend supports spec-as-decode kernels, then we can set # the reorder_batch_threshold based on the number of speculative # tokens from the config. speculative_config = self.vllm_config.speculative_config - if (speculative_config is not None - and speculative_config.num_speculative_tokens is not None): - self.reorder_batch_threshold = \ + if ( + speculative_config is not None + and speculative_config.num_speculative_tokens is not None + ): + self.reorder_batch_threshold = ( 1 + speculative_config.num_speculative_tokens + ) @abstractmethod - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> M: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> M: """ Central method that builds attention metadata. Some builders (MLA) require reorder_batch to be called prior to build. - + Args: common_prefix_len: The length of the common prefix of the batch. common_attn_metadata: The common attention metadata. @@ -280,8 +299,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): """ raise NotImplementedError - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: + def reorder_batch( + self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput" + ) -> bool: """ Update the order of requests in the batch based on the attention backend's needs. 
For example, some attention backends (namely MLA) may @@ -298,14 +318,16 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): raise NotImplementedError def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: + self, common_attn_metadata: CommonAttentionMetadata + ) -> M: """ Build attention metadata for CUDA graph capture. Uses build by default. Subclasses that override this method should call self.build or super().build_for_cudagraph_capture. """ - return self.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata) + return self.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) def build_for_drafting( self, @@ -314,7 +336,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ) -> M: """ Build attention metadata for draft model. Uses build by default. - + Args: common_attn_metadata: The common attention metadata. draft_index: The index of the current draft operation. @@ -323,9 +345,11 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): For tree-based attention, this index instead refers to the draft attempt for the i-th level in the tree of tokens. """ - return self.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - fast_build=True) + return self.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + fast_build=True, + ) def use_cascade_attention( self, @@ -348,8 +372,11 @@ def get_kv_cache_layout(): if _KV_CACHE_LAYOUT_OVERRIDE is not None: cache_layout = _KV_CACHE_LAYOUT_OVERRIDE - logger.info_once("`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " \ - "Setting KV cache layout to %s.", cache_layout) + logger.info_once( + "`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " + "Setting KV cache layout to %s.", + cache_layout, + ) return cache_layout # Format specified by the user. @@ -359,8 +386,11 @@ def get_kv_cache_layout(): cache_layout = get_kv_connector_cache_layout() else: assert is_valid_kv_cache_layout(cache_layout) - logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \ - "detected. Setting KV cache layout to %s.", cache_layout) + logger.info_once( + "`VLLM_KV_CACHE_LAYOUT` environment variable " + "detected. Setting KV cache layout to %s.", + cache_layout, + ) return cache_layout @@ -385,8 +415,8 @@ class PerLayerParameters: def get_per_layer_parameters( - vllm_config: VllmConfig, layer_names: list[str], - cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]: + vllm_config: VllmConfig, layer_names: list[str], cls_: type["AttentionImpl"] +) -> dict[str, PerLayerParameters]: """ Scan layers in `layer_names` and determine some hyperparameters to use during `plan`. 
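As a concrete reading of the _init_reorder_batch_threshold hunk further above: when a backend supports spec-as-decode and speculative decoding is configured, the decode threshold widens to cover the bonus token plus all speculative tokens. The numbers below are hypothetical.

supports_spec_as_decode = True
num_speculative_tokens = 3           # hypothetical drafter proposing 3 tokens

reorder_batch_threshold = 1          # default: a pure decode step has 1 query token
if supports_spec_as_decode and num_speculative_tokens is not None:
    # Verification steps carry 1 + num_speculative_tokens query tokens per
    # request, so they should still be sorted to the decode side of the batch.
    reorder_batch_threshold = 1 + num_speculative_tokens

assert reorder_batch_threshold == 4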
@@ -406,17 +436,18 @@ def get_per_layer_parameters( sm_scale = impl.scale has_sinks = getattr(impl, "sinks", None) is not None - per_layer_params[key] = PerLayerParameters(window_left, - logits_soft_cap, sm_scale, - has_sinks) + per_layer_params[key] = PerLayerParameters( + window_left, logits_soft_cap, sm_scale, has_sinks + ) return per_layer_params def infer_global_hyperparameters( - per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters: + per_layer_params: dict[str, PerLayerParameters], +) -> PerLayerParameters: """ - Currently, FlashInfer backend other than trtllm-gen + Currently, FlashInfer backend other than trtllm-gen only support models in which all layers share the same values for the following hyperparameters: - `window_left` @@ -437,13 +468,15 @@ def infer_global_hyperparameters( for params in param_sets: if params.window_left != global_params.window_left: raise ValueError( - "Window left is not the same for all layers. " \ - "One potential fix is to set disable_sliding_window=True") + "Window left is not the same for all layers. " + "One potential fix is to set disable_sliding_window=True" + ) assert params == global_params, ( "FlashInfer backend currently only supports models in which all" "layers share the same values " "for the following hyperparameters:" - "`window_left`, `logits_soft_cap`, `sm_scale`.") + "`window_left`, `logits_soft_cap`, `sm_scale`." + ) return global_params @@ -525,11 +558,10 @@ def make_local_attention_virtual_batches( # new_tokens_in_first_block = [2, 1, 4] # local_blocks = [2, 4, 2] q_tokens_in_first_block = np.minimum( - attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), - q_seqlens).astype(np.int32) + attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens + ).astype(np.int32) tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) - local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, - attn_chunk_size) + local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size) # Once we know the number of local blocks we can compute the request spans # for each batch idx, we can figure out the number of "virtual" requests we @@ -550,14 +582,13 @@ def make_local_attention_virtual_batches( rarange = np.repeat(local_blocks, local_blocks) - arange - 1 # Then we can compute the seqlens_q_local, handling the fact that the # first and last blocks could be partial - seqlens_q_local = \ - np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) + seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) # set the first block since this may be a partial block seqlens_q_local[arange == 0] = q_tokens_in_first_block # set the remaining blocks seqlens_q_local[arange > 0] = np.minimum( - seqlens_q_local - attn_chunk_size * (arange - 1), - attn_chunk_size)[arange > 0] + seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size + )[arange > 0] # convert from q_seqlens to cu_seqlens_q cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32) @@ -569,22 +600,20 @@ def make_local_attention_virtual_batches( # batch # For our example this will be: # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] - seqlens_k_local = np.full(cu_num_blocks[-1], - attn_chunk_size, - dtype=np.int32) + seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32) seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block num_computed_tokens_local = seqlens_k_local - seqlens_q_local - k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \ - (rarange * 
attn_chunk_size + \ - np.repeat(tokens_in_last_block, local_blocks)) + k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - ( + rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks) + ) # For the example the local attention blocks start at: # _b0_ _____b1_____ _b2_ # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] block_starts = k_seqstarts_absolute // block_size - assert attn_chunk_size % block_size == 0, \ - f"attn_chunk_size {attn_chunk_size} is not " \ - f"divisible by block_size {block_size}" + assert attn_chunk_size % block_size == 0, ( + f"attn_chunk_size {attn_chunk_size} is not divisible by block_size {block_size}" + ) pages_per_local_batch = attn_chunk_size // block_size # Create a block_table for the local attention blocks @@ -605,12 +634,14 @@ def make_local_attention_virtual_batches( # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) # ] - block_indices = (block_starts[:, None] + - np.arange(pages_per_local_batch, dtype=np.int32)) - block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - - 1) - batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), - local_blocks * pages_per_local_batch) + block_indices = block_starts[:, None] + np.arange( + pages_per_local_batch, dtype=np.int32 + ) + block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - 1) + batch_indices = np.repeat( + np.arange(actual_batch_size, dtype=np.int32), + local_blocks * pages_per_local_batch, + ) # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance # regression when using numpy arrays (batch and block indices) to index into @@ -618,8 +649,9 @@ def make_local_attention_virtual_batches( # tensor first, which recovers perf. 
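The commented outputs in the make_local_attention_virtual_batches hunks above (new_tokens_in_first_block = [2, 1, 4], local_blocks = [2, 4, 2], seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1], k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]) are consistent with attn_chunk_size = 4, q_seqlens = [4, 10, 5], and kv seq_lens = [6, 17, 9]. Treating those as the (assumed) example inputs, a short NumPy sketch reproduces the first few quantities:

import numpy as np


def cdiv(a, b):
    return -(-a // b)


attn_chunk_size = 4
q_seqlens = np.array([4, 10, 5])      # assumed example inputs
seq_lens_np = np.array([6, 17, 9])

q_tokens_in_first_block = np.minimum(
    attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
)                                                                   # -> [2, 1, 4]
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)       # -> [2, 1, 1]
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)   # -> [2, 4, 2]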
batch_indices_torch = torch.from_numpy(batch_indices) block_indices_torch = torch.from_numpy(block_indices) - block_table_local = block_table[batch_indices_torch, block_indices_torch]\ - .view(virtual_batches, -1) + block_table_local = block_table[batch_indices_torch, block_indices_torch].view( + virtual_batches, -1 + ) query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) seq_lens_cpu = torch.from_numpy(seqlens_k_local) @@ -627,8 +659,7 @@ def make_local_attention_virtual_batches( return CommonAttentionMetadata( query_start_loc_cpu=query_start_loc_cpu, - query_start_loc=query_start_loc_cpu.to(device=device, - non_blocking=True), + query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True), seq_lens_cpu=seq_lens_cpu, seq_lens=seq_lens_cpu.to(device=device, non_blocking=True), num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), @@ -668,9 +699,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( # Find how many decode indices belong to each request # request_ids: [0, 1, 1, 2] - request_ids = torch.bucketize(logits_indices, - query_start_loc[1:], - right=True) + request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True) # Figure out how many tokens are in each request # num_decode_tokens: [1, 2, 1] @@ -678,9 +707,9 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( # Calculate new query_start_loc with tokens in generation_indices # decode_query_start_loc: [0, 1, 3, 4] - decode_query_start_loc = torch.empty(num_reqs + 1, - device=query_start_loc.device, - dtype=query_start_loc.dtype) + decode_query_start_loc = torch.empty( + num_reqs + 1, device=query_start_loc.device, dtype=query_start_loc.dtype + ) decode_query_start_loc[0] = 0 decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) @@ -689,8 +718,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( common_attn_metadata = CommonAttentionMetadata( query_start_loc=decode_query_start_loc, - query_start_loc_cpu=decode_query_start_loc.to("cpu", - non_blocking=True), + query_start_loc_cpu=decode_query_start_loc.to("cpu", non_blocking=True), seq_lens=seq_lens, seq_lens_cpu=seq_lens.to("cpu", non_blocking=True), num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, @@ -706,22 +734,25 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( def subclass_attention_backend( - name_prefix: str, attention_backend_cls: type[AttentionBackend], - builder_cls: type[AttentionMetadataBuilder[M]] + name_prefix: str, + attention_backend_cls: type[AttentionBackend], + builder_cls: type[AttentionMetadataBuilder[M]], ) -> type[AttentionBackend]: """ Return a new subclass where `get_builder_cls` returns `builder_cls`. """ name: str = name_prefix + attention_backend_cls.__name__ # type: ignore - return type(name, (attention_backend_cls, ), - {"get_builder_cls": lambda: builder_cls}) + return type( + name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls} + ) def split_decodes_and_prefills( - common_attn_metadata: CommonAttentionMetadata, - decode_threshold: int = 1, - require_uniform: bool = False) -> tuple[int, int, int, int]: + common_attn_metadata: CommonAttentionMetadata, + decode_threshold: int = 1, + require_uniform: bool = False, +) -> tuple[int, int, int, int]: """ Assuming a reordered batch, finds the boundary between prefill and decode requests. 
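The subclass_attention_backend helper reformatted above builds the new backend class with a bare type() call; here is a self-contained toy of the same pattern (the class names are made up for the example).

class DummyBackend:
    @staticmethod
    def get_builder_cls():
        return None


class DummyBuilder:
    pass


Patched = type(
    "FastPrefillDummyBackend",
    (DummyBackend,),
    {"get_builder_cls": lambda: DummyBuilder},
)
assert issubclass(Patched, DummyBackend)
assert Patched.get_builder_cls() is DummyBuilder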
@@ -745,8 +776,9 @@ def split_decodes_and_prefills( num_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc_cpu - if max_query_len <= decode_threshold and \ - (not require_uniform or decode_threshold <= 1): + if max_query_len <= decode_threshold and ( + not require_uniform or decode_threshold <= 1 + ): return num_reqs, 0, num_tokens, 0 query_lens = query_start_loc[1:] - query_start_loc[:-1] @@ -779,7 +811,7 @@ def reorder_batch_to_split_decodes_and_prefills( """ Reorders the batch to split into prefill and decode requests; places all requests with <= decode_threshold tokens at the front of the batch. - + Returns: True if the batch was modified, False otherwise. """ @@ -834,8 +866,7 @@ def reorder_batch_to_split_decodes_and_prefills( return modified_batch -def reshape_query_for_spec_decode(query: torch.Tensor, - batch_size: int) -> torch.Tensor: +def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor: """ Reshapes the query tensor for the specified batch size, so that it has shape (batch_size, seq_len, num_heads, head_dim). @@ -845,13 +876,13 @@ def reshape_query_for_spec_decode(query: torch.Tensor, num_heads = query.shape[1] head_dim = query.shape[2] assert total_tokens % batch_size == 0, ( - f"{total_tokens=} is not divisible by {batch_size=}") + f"{total_tokens=} is not divisible by {batch_size=}" + ) seq_len = total_tokens // batch_size return query.view(batch_size, seq_len, num_heads, head_dim) -def reshape_attn_output_for_spec_decode( - attn_output: torch.Tensor) -> torch.Tensor: +def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor: """ Reshapes the attention output tensor, so that the batch_size and seq_len dimensions are combined. 
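# --- Illustrative sketch (editor's note, not part of the diff) ---------------
# The spec-decode reshapes above on toy shapes: 2 requests, 3 speculative
# tokens each, 4 heads, head_dim 8.
import torch

batch_size, seq_len, num_heads, head_dim = 2, 3, 4, 8
query = torch.randn(batch_size * seq_len, num_heads, head_dim)

# reshape_query_for_spec_decode: (total_tokens, H, D) -> (B, T, H, D).
# total_tokens must be divisible by batch_size, hence the assert above.
q4d = query.view(batch_size, seq_len, num_heads, head_dim)

# reshape_attn_output_for_spec_decode: fold (B, T) back into one token axis.
attn_output = torch.randn(batch_size, seq_len, num_heads, head_dim)
flat = attn_output.view(batch_size * seq_len, num_heads, head_dim)
assert flat.shape == query.shape
# -----------------------------------------------------------------------------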
@@ -859,16 +890,14 @@ def reshape_attn_output_for_spec_decode( if attn_output.dim() == 3: # Already in the correct shape return attn_output - assert attn_output.dim() == 4, \ - f"attn_output must be 4D, got {attn_output.dim()}D" + assert attn_output.dim() == 4, f"attn_output must be 4D, got {attn_output.dim()}D" total_tokens = attn_output.shape[0] * attn_output.shape[1] - return attn_output.view(total_tokens, attn_output.shape[2], - attn_output.shape[3]) + return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3]) KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [ - ('logits_indices_padded', Optional[torch.Tensor], None), - ('num_logits_indices', int, 0), + ("logits_indices_padded", Optional[torch.Tensor], None), + ("num_logits_indices", int, 0), ] @@ -881,7 +910,7 @@ def subclass_attention_metadata( Return a new subclass of `metadata_cls` with additional fields """ name: str = name_prefix + metadata_cls.__name__ # type: ignore - Wrapped = make_dataclass(name, fields, bases=(metadata_cls, )) + Wrapped = make_dataclass(name, fields, bases=(metadata_cls,)) return Wrapped @@ -895,55 +924,55 @@ def create_fast_prefill_custom_backend( prefix: str, underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: - underlying_builder = underlying_attn_backend.get_builder_cls() class FastPrefillAttentionBuilder(underlying_builder): # type: ignore - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> AttentionMetadata: - new_common_attn_metadata =\ - make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata) - metadata = super().build(common_prefix_len, - new_common_attn_metadata, fast_build) + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: + new_common_attn_metadata = ( + make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata) + ) + metadata = super().build( + common_prefix_len, new_common_attn_metadata, fast_build + ) class KVSharingFastPrefillAttentionMetadata( - metadata.__class__, # type: ignore - KVSharingFastPrefillMetadata): - + metadata.__class__, # type: ignore + KVSharingFastPrefillMetadata, + ): def __init__(self, metadata, common_attn_metadata): # Shallow copy all fields in metadata cls for field in fields(metadata.__class__): - setattr(self, field.name, - getattr(metadata, field.name)) + setattr(self, field.name, getattr(metadata, field.name)) # Set additional fields that will be used in model code - assert (common_attn_metadata.logits_indices_padded - is not None - and common_attn_metadata.num_logits_indices - is not None) - self.logits_indices_padded = \ + assert ( + common_attn_metadata.logits_indices_padded is not None + and common_attn_metadata.num_logits_indices is not None + ) + self.logits_indices_padded = ( common_attn_metadata.logits_indices_padded - self.num_logits_indices = \ - common_attn_metadata.num_logits_indices + ) + self.num_logits_indices = common_attn_metadata.num_logits_indices - return KVSharingFastPrefillAttentionMetadata( - metadata, common_attn_metadata) + return KVSharingFastPrefillAttentionMetadata(metadata, common_attn_metadata) attn_backend = subclass_attention_backend( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, - builder_cls=FastPrefillAttentionBuilder) + builder_cls=FastPrefillAttentionBuilder, + ) return attn_backend def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): - # Needed for 
causal_conv1d - seqlens = query_start_loc_p.diff().to('cpu') + seqlens = query_start_loc_p.diff().to("cpu") nums_dict = {} # type: ignore batch_ptr = None token_chunk_offset_ptr = None @@ -951,40 +980,39 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): for BLOCK_M in [8]: # cover all BLOCK_M values nums = -(-seqlens // BLOCK_M) nums_dict[BLOCK_M] = {} - nums_dict[BLOCK_M]['nums'] = nums - nums_dict[BLOCK_M]['tot'] = nums.sum().item() + nums_dict[BLOCK_M]["nums"] = nums + nums_dict[BLOCK_M]["tot"] = nums.sum().item() mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums)) - nums_dict[BLOCK_M]['mlist'] = mlist - mlist_len = len(nums_dict[BLOCK_M]['mlist']) - nums_dict[BLOCK_M]['mlist_len'] = mlist_len + nums_dict[BLOCK_M]["mlist"] = mlist + mlist_len = len(nums_dict[BLOCK_M]["mlist"]) + nums_dict[BLOCK_M]["mlist_len"] = mlist_len MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2 offsetlist = [] # type: ignore for idx, num in enumerate(nums): offsetlist.extend(range(num)) offsetlist = torch.tensor(offsetlist, dtype=torch.int32) - nums_dict[BLOCK_M]['offsetlist'] = offsetlist + nums_dict[BLOCK_M]["offsetlist"] = offsetlist if batch_ptr is None: # Update default value after class definition - batch_ptr = torch.full((MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=device) - token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=device) + batch_ptr = torch.full( + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device + ) + token_chunk_offset_ptr = torch.full( + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device + ) else: if batch_ptr.nelement() < MAX_NUM_PROGRAMS: batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) token_chunk_offset_ptr.resize_( # type: ignore - MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) + MAX_NUM_PROGRAMS + ).fill_(PAD_SLOT_ID) batch_ptr[0:mlist_len].copy_(mlist) token_chunk_offset_ptr[ # type: ignore - 0:mlist_len].copy_(offsetlist) - nums_dict[BLOCK_M]['batch_ptr'] = batch_ptr - nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (token_chunk_offset_ptr - ) # type: ignore + 0:mlist_len + ].copy_(offsetlist) + nums_dict[BLOCK_M]["batch_ptr"] = batch_ptr + nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr # type: ignore return nums_dict, batch_ptr, token_chunk_offset_ptr diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index 9d667ee04f..17e752277c 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -7,20 +7,29 @@ from typing import TYPE_CHECKING, Optional import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, +) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) + AttentionMetadataBuilder, + CommonAttentionMetadata, + reorder_batch_to_split_decodes_and_prefills, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec try: from xformers import ops as xops from xformers.ops.fmha.attn_bias import ( - AttentionBias, PagedBlockDiagonalCausalWithOffsetPaddedKeysMask) + 
AttentionBias, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + ) XFORMERS_AVAILABLE = True except ImportError: @@ -36,7 +45,6 @@ logger = init_logger(__name__) class XFormersAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod @@ -86,7 +94,8 @@ class XFormersAttentionBackend(AttentionBackend): f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: @@ -153,9 +162,9 @@ class XFormersAttentionMetadata: # metadata structure return self._cached_prefill_metadata - q_start_loc = self.query_start_loc[self.num_decodes:] + q_start_loc = self.query_start_loc[self.num_decodes :] q_seqlens = torch.diff(q_start_loc) - kv_seqlens = self.seq_lens[self.num_decodes:] + kv_seqlens = self.seq_lens[self.num_decodes :] # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = XFormersAttentionMetadata( num_actual_tokens=self.num_prefill_tokens, @@ -163,8 +172,8 @@ class XFormersAttentionMetadata: query_start_loc=q_start_loc - q_start_loc[0], max_seq_len=int(kv_seqlens.max().item()), seq_lens=kv_seqlens, - block_table=self.block_table[self.num_decodes:], - slot_mapping=self.slot_mapping[self.num_decode_tokens:], + block_table=self.block_table[self.num_decodes :], + slot_mapping=self.slot_mapping[self.num_decode_tokens :], ) return self._cached_prefill_metadata @@ -180,24 +189,24 @@ class XFormersAttentionMetadata: q_start_loc = self.query_start_loc q_seqlens = torch.diff(q_start_loc) - decode_kv_seqlens = self.seq_lens[:self.num_decodes] + decode_kv_seqlens = self.seq_lens[: self.num_decodes] # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = XFormersAttentionMetadata( num_actual_tokens=self.num_decode_tokens, - max_query_len=int(q_seqlens[:self.num_decodes].max().item()), - query_start_loc=q_start_loc[:self.num_decodes + 1], + max_query_len=int(q_seqlens[: self.num_decodes].max().item()), + query_start_loc=q_start_loc[: self.num_decodes + 1], max_seq_len=int(decode_kv_seqlens.max().item()), seq_lens=decode_kv_seqlens, - block_table=self.block_table[:self.num_decodes], - slot_mapping=self.slot_mapping[:self.num_decode_tokens], + block_table=self.block_table[: self.num_decodes], + slot_mapping=self.slot_mapping[: self.num_decode_tokens], attn_bias=self.attn_bias, ) return self._cached_decode_metadata class XFormersAttentionMetadataBuilder( - AttentionMetadataBuilder[XFormersAttentionMetadata]): - + AttentionMetadataBuilder[XFormersAttentionMetadata] +): reorder_batch_threshold: int = 1 def __init__( @@ -214,12 +223,12 @@ class XFormersAttentionMetadataBuilder( self._num_decodes = 0 self._num_decode_tokens = 0 - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: + def reorder_batch( + self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput" + ) -> bool: return reorder_batch_to_split_decodes_and_prefills( - input_batch, - scheduler_output, - decode_threshold=self.reorder_batch_threshold) + input_batch, scheduler_output, decode_threshold=self.reorder_batch_threshold + ) def build( self, @@ -229,8 +238,9 @@ class XFormersAttentionMetadataBuilder( ) -> XFormersAttentionMetadata: num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - 
common_attn_metadata, - decode_threshold=self.reorder_batch_threshold)) + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) num_actual_tokens = common_attn_metadata.num_actual_tokens q_start_loc = common_attn_metadata.query_start_loc @@ -246,14 +256,13 @@ class XFormersAttentionMetadataBuilder( # Construct the decoder bias. decode_q_seqlens = q_seqlens[:num_decodes] decode_kv_seqlens = kv_seqlens[:num_decodes] - bias = ( - PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( - q_seqlen=decode_q_seqlens.tolist(), - kv_seqlen=decode_kv_seqlens.tolist(), - page_size=self.block_size, - block_tables=block_table[:num_decodes], - device=block_table.device, - )) + bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=decode_q_seqlens.tolist(), + kv_seqlen=decode_kv_seqlens.tolist(), + page_size=self.block_size, + block_tables=block_table[:num_decodes], + device=block_table.device, + ) return XFormersAttentionMetadata( num_actual_tokens=num_actual_tokens, @@ -272,7 +281,6 @@ class XFormersAttentionMetadataBuilder( class XFormersAttentionImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -289,8 +297,7 @@ class XFormersAttentionImpl(AttentionImpl): if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported in V0.") if alibi_slopes is not None: - raise NotImplementedError( - "XFormers does not support alibi slopes yet.") + raise NotImplementedError("XFormers does not support alibi slopes yet.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -313,10 +320,12 @@ class XFormersAttentionImpl(AttentionImpl): XFormersAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "XFormersAttentionImpl.") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "XFormersAttentionImpl." + ) def forward( self, @@ -347,7 +356,8 @@ class XFormersAttentionImpl(AttentionImpl): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" - " for XFormersAttentionImpl") + " for XFormersAttentionImpl" + ) if attn_metadata is None: # Profiling run. @@ -377,8 +387,7 @@ class XFormersAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens num_decode_tokens = attn_metadata.num_decode_tokens if prefill_meta := attn_metadata.prefill_metadata: - descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, - key.shape[1]) + descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, key.shape[1]) unified_attention( q=query[num_decode_tokens:num_actual_tokens], k=key_cache, @@ -403,36 +412,38 @@ class XFormersAttentionImpl(AttentionImpl): # Query for decode. KV is not needed because it is already cached. decode_query = query[:num_decode_tokens] # Reshape query to [1, B_T, G, H, D]. 
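# --- Illustrative sketch (editor's note, not part of the diff) ---------------
# The [1, B_T, G, H, D] layout used by the xFormers decode path below, on toy
# shapes: 4 decode tokens, 2 KV heads, 2 query heads per KV head. The real KV
# cache is paged; a plain tensor stands in for it here.
import torch

num_kv_heads, num_queries_per_kv, head_size = 2, 2, 8
decode_query = torch.randn(4, num_kv_heads * num_queries_per_kv, head_size)

q = decode_query.view(1, -1, num_kv_heads, num_queries_per_kv, head_size)

# The cache stores one head per KV group; expand() broadcasts it across the
# query group without copying memory.
key_cache = torch.randn(4, num_kv_heads, head_size)
cache_k = key_cache.view(1, -1, num_kv_heads, 1, head_size).expand(
    1, -1, num_kv_heads, num_queries_per_kv, head_size
)
assert q.shape == (1, 4, 2, 2, 8) and cache_k.shape == (1, 4, 2, 2, 8)
# -----------------------------------------------------------------------------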
- q = decode_query.view(1, -1, self.num_kv_heads, - self.num_queries_per_kv, self.head_size) + q = decode_query.view( + 1, -1, self.num_kv_heads, self.num_queries_per_kv, self.head_size + ) # Reshape the k and v caches to [1, Bkv_T, G, H, D] - cache_k = key_cache.view(1, -1, self.num_kv_heads, 1, - self.head_size).expand( - 1, - -1, - self.num_kv_heads, - self.num_queries_per_kv, - self.head_size, - ) - cache_v = value_cache.view(1, -1, self.num_kv_heads, 1, - self.head_size).expand( - 1, - -1, - self.num_kv_heads, - self.num_queries_per_kv, - self.head_size, - ) + cache_k = key_cache.view( + 1, -1, self.num_kv_heads, 1, self.head_size + ).expand( + 1, + -1, + self.num_kv_heads, + self.num_queries_per_kv, + self.head_size, + ) + cache_v = value_cache.view( + 1, -1, self.num_kv_heads, 1, self.head_size + ).expand( + 1, + -1, + self.num_kv_heads, + self.num_queries_per_kv, + self.head_size, + ) attn_bias = decode_meta.attn_bias - output[: - num_decode_tokens] = xops.memory_efficient_attention_forward( - q, - cache_k, - cache_v, - attn_bias=attn_bias, - p=0.0, - scale=self.scale, - ).view(decode_query.shape) + output[:num_decode_tokens] = xops.memory_efficient_attention_forward( + q, + cache_k, + cache_v, + attn_bias=attn_bias, + p=0.0, + scale=self.scale, + ).view(decode_query.shape) # Reshape the output tensor. return output diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 617a724a1a..ddfd943227 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -3,16 +3,24 @@ from collections.abc import Iterable from typing import Any, Optional, Union -from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared, - BlockRemoved, BlockStored, - KVCacheEvent) +from vllm.distributed.kv_events import ( + MEDIUM_GPU, + AllBlocksCleared, + BlockRemoved, + BlockStored, + KVCacheEvent, +) from vllm.logger import init_logger -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - ExternalBlockHash, - FreeKVCacheBlockQueue, KVCacheBlock, - get_block_hash, - make_block_hash_with_group_id, - maybe_convert_block_hash) +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + BlockHashWithGroupId, + ExternalBlockHash, + FreeKVCacheBlockQueue, + KVCacheBlock, + get_block_hash, + make_block_hash_with_group_id, + maybe_convert_block_hash, +) from vllm.v1.request import Request logger = init_logger(__name__) @@ -20,7 +28,7 @@ logger = init_logger(__name__) class BlockHashToBlockMap: """ - Cache of blocks that are used for prefix caching. It caches blocks + Cache of blocks that are used for prefix caching. It caches blocks from hash directly to a block or multiple blocks (i.e. {block_hash: KVCacheBlocks}) - Mostly block_hash maps to a single KVCacheBlock, and KVCacheBlocks @@ -42,11 +50,11 @@ class BlockHashToBlockMap: """ def __init__(self): - self._cache: dict[BlockHashWithGroupId, - Union[KVCacheBlock, dict[int, KVCacheBlock]]] = {} + self._cache: dict[ + BlockHashWithGroupId, Union[KVCacheBlock, dict[int, KVCacheBlock]] + ] = {} - def get_one_block(self, - key: BlockHashWithGroupId) -> Optional[KVCacheBlock]: + def get_one_block(self, key: BlockHashWithGroupId) -> Optional[KVCacheBlock]: """ Gets any block with the given block hash key. 
""" @@ -77,8 +85,7 @@ class BlockHashToBlockMap: else: self._unexpected_blocks_type(blocks) - def pop(self, key: BlockHashWithGroupId, - block_id: int) -> Optional[KVCacheBlock]: + def pop(self, key: BlockHashWithGroupId, block_id: int) -> Optional[KVCacheBlock]: """ Checks if block_hash exists and pop block_id from the cache """ @@ -148,8 +155,7 @@ class BlockPool: self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) # Cache for block lookup - self.cached_block_hash_to_block: BlockHashToBlockMap = \ - BlockHashToBlockMap() + self.cached_block_hash_to_block: BlockHashToBlockMap = BlockHashToBlockMap() # To represent a placeholder block with block_id=0. # The ref_cnt of null_block is not maintained, needs special care to @@ -161,9 +167,9 @@ class BlockPool: self.kv_event_queue: list[KVCacheEvent] = [] def get_cached_block( - self, block_hash: BlockHash, - kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]: - """Get the cached block by the block hash for each group in + self, block_hash: BlockHash, kv_cache_group_ids: list[int] + ) -> Optional[list[KVCacheBlock]]: + """Get the cached block by the block hash for each group in `kv_cache_group_ids`, or None if cache miss for any group. If there are duplicated blocks, we return the first block in the cache. @@ -177,9 +183,11 @@ class BlockPool: cached_blocks = [] for group_id in kv_cache_group_ids: block_hash_with_group_id = make_block_hash_with_group_id( - block_hash, group_id) + block_hash, group_id + ) block = self.cached_block_hash_to_block.get_one_block( - block_hash_with_group_id) + block_hash_with_group_id + ) if not block: return None cached_blocks.append(block) @@ -218,17 +226,18 @@ class BlockPool: new_block_hashes = request.block_hashes[num_cached_blocks:] new_hashes: Optional[list[ExternalBlockHash]] = ( - [] if self.enable_kv_cache_events else None) + [] if self.enable_kv_cache_events else None + ) for i, blk in enumerate(new_full_blocks): assert blk.block_hash is None block_hash = new_block_hashes[i] # Update and added the full block to the cache. block_hash_with_group_id = make_block_hash_with_group_id( - block_hash, kv_cache_group_id) + block_hash, kv_cache_group_id + ) blk.block_hash = block_hash_with_group_id - self.cached_block_hash_to_block.insert(block_hash_with_group_id, - blk) + self.cached_block_hash_to_block.insert(block_hash_with_group_id, blk) if new_hashes is not None: new_hashes.append(maybe_convert_block_hash(block_hash)) @@ -239,20 +248,21 @@ class BlockPool: parent_block = blocks[num_cached_blocks - 1] assert parent_block.block_hash is not None parent_block_hash = maybe_convert_block_hash( - get_block_hash(parent_block.block_hash)) + get_block_hash(parent_block.block_hash) + ) self.kv_event_queue.append( BlockStored( block_hashes=new_hashes, parent_block_hash=parent_block_hash, - token_ids=request. - all_token_ids[num_cached_blocks * - block_size:num_full_blocks * block_size], + token_ids=request.all_token_ids[ + num_cached_blocks * block_size : num_full_blocks * block_size + ], block_size=block_size, - lora_id=request.lora_request.id - if request.lora_request else None, + lora_id=request.lora_request.id if request.lora_request else None, medium=MEDIUM_GPU, - )) + ) + ) def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: """Get new blocks from the free block pool. @@ -266,8 +276,7 @@ class BlockPool: A list of new block. 
""" if num_blocks > self.get_num_free_blocks(): - raise ValueError( - f"Cannot get {num_blocks} free blocks from the pool") + raise ValueError(f"Cannot get {num_blocks} free blocks from the pool") ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks) @@ -299,8 +308,7 @@ class BlockPool: # The block doesn't have hash, eviction is not needed return False - if self.cached_block_hash_to_block.pop(block_hash, - block.block_id) is None: + if self.cached_block_hash_to_block.pop(block_hash, block.block_id) is None: # block not found in cached_block_hash_to_block, # eviction is not needed return False @@ -313,10 +321,11 @@ class BlockPool: # we disable hybrid kv cache manager when kv cache event is # enabled, so there is only one group. self.kv_event_queue.append( - BlockRemoved(block_hashes=[ - maybe_convert_block_hash(get_block_hash(block_hash)) - ], - medium=MEDIUM_GPU)) + BlockRemoved( + block_hashes=[maybe_convert_block_hash(get_block_hash(block_hash))], + medium=MEDIUM_GPU, + ) + ) return True def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None: @@ -347,10 +356,9 @@ class BlockPool: blocks_list = list(ordered_blocks) for block in blocks_list: block.ref_cnt -= 1 - self.free_block_queue.append_n([ - block for block in blocks_list - if block.ref_cnt == 0 and not block.is_null - ]) + self.free_block_queue.append_n( + [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null] + ) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -365,7 +373,9 @@ class BlockPool: if num_used_blocks != 1: # The null block is always marked as used logger.warning( "Failed to reset prefix cache because some " - "blocks (%d) are not freed yet", num_used_blocks - 1) + "blocks (%d) are not freed yet", + num_used_blocks - 1, + ) return False # Remove all hashes so that no new blocks will hit. @@ -405,7 +415,7 @@ class BlockPool: def take_events(self) -> list[KVCacheEvent]: """Atomically takes all events and clears the queue. - + Returns: A list of KV cache events. """ diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index eadea15a2e..c70025992e 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -33,12 +33,12 @@ class EncoderCacheManager: within requests, allowing for fine-grained memory management and enabling chunked processing of multimodal inputs. - Cache is enabled to share embeddings of same multimodal data - item (identified by their hash value) between different requests, - and eviction takes place at allocation time when there's no free + Cache is enabled to share embeddings of same multimodal data + item (identified by their hash value) between different requests, + and eviction takes place at allocation time when there's no free space for new embeddings. Oldest cached embeddings with no request referenced will be first evicted. - + Args: cache_size: Limit the size of the cache, measured by the number of tokens from the input sequence. @@ -99,27 +99,31 @@ class EncoderCacheManager: self.cached[mm_hash].add(request.request_id) return True - def can_allocate(self, request: Request, input_id: int, - encoder_compute_budget: int, - num_tokens_to_schedule: int) -> bool: - """Check if there's sufficient cache space for a multimodal input. 
+ def can_allocate( + self, + request: Request, + input_id: int, + encoder_compute_budget: int, + num_tokens_to_schedule: int, + ) -> bool: + """Check if there's sufficient cache space for a multimodal input. If there is, return True and update EncoderCacheManager state. If there is not enough free space in `num_free_slots` but there is enough reclaimable space in `num_freeable_slots`, entries will be evicted from `freeable` (their mm_hash appended to `freed`) until - enough space is available, and then this method returns True. + enough space is available, and then this method returns True. Older entries are evicted first. - - Returns False only if the requested number of tokens exceeds both + + Returns False only if the requested number of tokens exceeds both the free and reclaimable capacities combined. Args: request: The request containing the multimodal input. input_id: Index of the multimodal input within the request. - encoder_compute_budget: Number of encoder tokens allowed to be + encoder_compute_budget: Number of encoder tokens allowed to be computed when this method is invoked. - num_tokens_to_schedule: Number of tokens already scheduled to be + num_tokens_to_schedule: Number of tokens already scheduled to be allocated with cache space when this method is invoked. Returns: @@ -127,7 +131,7 @@ class EncoderCacheManager: input (possibly after reclaiming `freeable` entries); otherwise False. - Note: This method does not allocate physical memory for the encoder + Note: This method does not allocate physical memory for the encoder output but only the state of EncoderCacheManager. """ num_tokens = request.get_num_encoder_tokens(input_id) @@ -202,7 +206,7 @@ class EncoderCacheManager: When the reference set for the corresponding `mm_hash` becomes empty, the entry is appended to `freeable` and `num_freeable_slots` is - increased by the number of encoder tokens for that input. + increased by the number of encoder tokens for that input. The entry is NOT physically freed until capacity is needed (e.g., by `can_allocate`). @@ -221,8 +225,8 @@ class EncoderCacheManager: def free(self, request: Request) -> None: """Free all encoder input cache reference held by *request*. - For each cached input ID, `free_encoder_input` is invoked. - The data stays in memory until eviction is triggered by a future + For each cached input ID, `free_encoder_input` is invoked. + The data stays in memory until eviction is triggered by a future attempt allocation called by 'can_allocate'. Typically called when a request is finished, cancelled, or aborted. @@ -236,9 +240,9 @@ class EncoderCacheManager: Returns: List of mm_hash strings that were actually evicted since the last - call to be used by the scheduler to notify workers about which - encoder outputs can be removed from their caches. The internal - list is cleared after this call. + call to be used by the scheduler to notify workers about which + encoder outputs can be removed from their caches. The internal + list is cleared after this call. """ freed = self.freed self.freed = [] @@ -250,7 +254,7 @@ def compute_encoder_budget( scheduler_config: "SchedulerConfig", mm_registry: MultiModalRegistry, ) -> tuple[int, int]: - """Compute the encoder cache budget based on the model and scheduler + """Compute the encoder cache budget based on the model and scheduler configurations. Returns: @@ -260,8 +264,9 @@ def compute_encoder_budget( from the input sequence. 
""" if mm_registry.supports_multimodal_inputs(model_config): - max_tokens_by_modality = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config) + max_tokens_by_modality = ( + mm_registry.get_max_tokens_per_item_by_nonzero_modality(model_config) + ) return compute_mm_encoder_budget( scheduler_config, @@ -271,18 +276,17 @@ def compute_encoder_budget( return compute_text_encoder_budget(scheduler_config) -def compute_text_encoder_budget( - scheduler_config: "SchedulerConfig") -> tuple[int, int]: - """Compute the encoder cache budget based on the model and scheduler +def compute_text_encoder_budget(scheduler_config: "SchedulerConfig") -> tuple[int, int]: + """Compute the encoder cache budget based on the model and scheduler configurations for a text-only model. Args: scheduler_config: Scheduler configuration. Returns: - - Compute budget for encoder execution, in unit of number of tokens + - Compute budget for encoder execution, in unit of number of tokens in the input sequence. - - Space budget for encoder cache size, in unit of number of tokens + - Space budget for encoder cache size, in unit of number of tokens in the input sequence. """ # Currently text-only encoder-decoder models are not supported @@ -293,7 +297,7 @@ def compute_mm_encoder_budget( scheduler_config: "SchedulerConfig", max_tokens_by_modality: Mapping[str, int], ) -> tuple[int, int]: - """Compute the encoder cache budget based on the model and scheduler + """Compute the encoder cache budget based on the model and scheduler configurations for a multimodal model. Args: @@ -312,22 +316,28 @@ def compute_mm_encoder_budget( logger.warning( "All non-text modalities supported by the model have been " "explicitly disabled via limit_mm_per_prompt. Encoder cache will " - "not be initialized.") + "not be initialized." + ) return 0, 0 max_tokens_per_mm_item = max(max_tokens_by_modality.values()) - if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item - > scheduler_config.max_num_batched_tokens): + if ( + scheduler_config.disable_chunked_mm_input + and max_tokens_per_mm_item > scheduler_config.max_num_batched_tokens + ): raise ValueError( "Chunked MM input disabled but max_tokens_per_mm_item " f"({max_tokens_per_mm_item}) is larger than max_num_batched_tokens" f" ({scheduler_config.max_num_batched_tokens}). Please increase " - "max_num_batched_tokens.") + "max_num_batched_tokens." 
+ ) - encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens, - max_tokens_per_mm_item) - encoder_cache_size = max(scheduler_config.encoder_cache_size, - max_tokens_per_mm_item) + encoder_compute_budget = max( + scheduler_config.max_num_encoder_input_tokens, max_tokens_per_mm_item + ) + encoder_cache_size = max( + scheduler_config.encoder_cache_size, max_tokens_per_mm_item + ) return encoder_compute_budget, encoder_cache_size diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 86771060c4..37e1b7ca39 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -6,9 +6,11 @@ from typing import Optional from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.single_type_kv_cache_manager import ( - CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec) + CrossAttentionManager, + FullAttentionManager, + get_manager_for_kv_cache_spec, +) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.request import Request @@ -30,8 +32,9 @@ class KVCacheCoordinator(ABC): self.max_model_len = max_model_len self.enable_caching = enable_caching - self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching, - enable_kv_cache_events) + self.block_pool = BlockPool( + kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events + ) # Needs special handling for find_longest_cache_hit if eagle is enabled self.use_eagle = use_eagle @@ -41,19 +44,23 @@ class KVCacheCoordinator(ABC): block_pool=self.block_pool, kv_cache_group_id=i, dcp_world_size=dcp_world_size, - ) for i, kv_cache_group in enumerate( - self.kv_cache_config.kv_cache_groups)) + ) + for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups) + ) - def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int, - new_computed_blocks: tuple[ - list[KVCacheBlock], ...], - num_encoder_tokens: int) -> int: + def get_num_blocks_to_allocate( + self, + request_id: str, + num_tokens: int, + new_computed_blocks: tuple[list[KVCacheBlock], ...], + num_encoder_tokens: int, + ) -> int: """ Get the number of blocks needed to be allocated for the request. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). new_computed_blocks: The new computed blocks just hitting the prefix caching. @@ -69,15 +76,17 @@ class KVCacheCoordinator(ABC): # For cross-attention, we issue a single static allocation # of blocks based on the number of encoder input tokens. num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_encoder_tokens, []) + request_id, num_encoder_tokens, [] + ) else: num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks[i]) + request_id, num_tokens, new_computed_blocks[i] + ) return num_blocks_to_allocate def save_new_computed_blocks( - self, request_id: str, - new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None: + self, request_id: str, new_computed_blocks: tuple[list[KVCacheBlock], ...] + ) -> None: """ Add the new computed blocks to the request. @@ -87,21 +96,18 @@ class KVCacheCoordinator(ABC): prefix cache. 
""" for i, manager in enumerate(self.single_type_managers): - manager.save_new_computed_blocks(request_id, - new_computed_blocks[i]) + manager.save_new_computed_blocks(request_id, new_computed_blocks[i]) def allocate_new_blocks( - self, - request_id: str, - num_tokens: int, - num_encoder_tokens: int = 0) -> tuple[list[KVCacheBlock], ...]: + self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0 + ) -> tuple[list[KVCacheBlock], ...]: """ - Allocate new blocks for the request to give it at least `num_tokens` + Allocate new blocks for the request to give it at least `num_tokens` token slots. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). num_encoder_tokens: The number of encoder tokens for allocating blocks for cross-attention. @@ -111,9 +117,13 @@ class KVCacheCoordinator(ABC): """ return tuple( manager.allocate_new_blocks( - request_id, num_encoder_tokens if isinstance( - manager, CrossAttentionManager) else num_tokens) - for manager in self.single_type_managers) + request_id, + num_encoder_tokens + if isinstance(manager, CrossAttentionManager) + else num_tokens, + ) + for manager in self.single_type_managers + ) def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """ @@ -138,8 +148,9 @@ class KVCacheCoordinator(ABC): for manager in self.single_type_managers: manager.free(request_id) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> list[int]: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> list[int]: """ Get the number of common prefix blocks for all requests in the RUNNING state for each kv cache group. @@ -154,16 +165,14 @@ class KVCacheCoordinator(ABC): the RUNNING state for each kv cache group. """ num_blocks_per_group = [ - manager.get_num_common_prefix_blocks(request_id, - num_running_requests) + manager.get_num_common_prefix_blocks(request_id, num_running_requests) for manager in self.single_type_managers ] return num_blocks_per_group - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: """ - Remove the blocks that are no longer needed from `blocks` and replace + Remove the blocks that are no longer needed from `blocks` and replace the removed blocks with null_block. Args: @@ -179,7 +188,8 @@ class KVCacheCoordinator(ABC): """ return tuple( manager.req_to_blocks.get(request_id) or [] - for manager in self.single_type_managers) + for manager in self.single_type_managers + ) @abstractmethod def find_longest_cache_hit( @@ -198,19 +208,27 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): Does not implement any features related to prefix caching. 
""" - def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, enable_kv_cache_events: bool, - dcp_world_size: int): - super().__init__(kv_cache_config, - max_model_len, - use_eagle, - False, - enable_kv_cache_events, - dcp_world_size=dcp_world_size) + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, + ): + super().__init__( + kv_cache_config, + max_model_len, + use_eagle, + False, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) self.num_single_type_manager = len(self.single_type_managers) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> list[int]: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> list[int]: return [0] * self.num_single_type_manager def find_longest_cache_hit( @@ -219,7 +237,8 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): max_cache_hit_length: int, ) -> tuple[tuple[list[KVCacheBlock], ...], int]: blocks: tuple[list[KVCacheBlock], ...] = tuple( - [] for _ in range(self.num_single_type_manager)) + [] for _ in range(self.num_single_type_manager) + ) return blocks, 0 @@ -230,23 +249,31 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): full attention or all attention layers use sliding window attention. """ - def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool, dcp_world_size: int): - super().__init__(kv_cache_config, - max_model_len, - use_eagle, - enable_caching, - enable_kv_cache_events, - dcp_world_size=dcp_world_size) - self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, + ): + super().__init__( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) + self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size self.dcp_world_size = dcp_world_size if dcp_world_size > 1: self.block_size *= dcp_world_size assert len(self.kv_cache_config.kv_cache_groups) == 1, ( - "UnitaryKVCacheCoordinator assumes only one kv cache group") + "UnitaryKVCacheCoordinator assumes only one kv cache group" + ) def find_longest_cache_hit( self, @@ -269,26 +296,34 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): """ KV cache coordinator for hybrid models with multiple KV cache types, and thus multiple kv cache groups. - To simplify `find_longest_cache_hit`, it only supports the combination of + To simplify `find_longest_cache_hit`, it only supports the combination of two types of KV cache groups, and one of them must be full attention. May extend to more general cases in the future. 
""" - def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool, dcp_world_size: int): - super().__init__(kv_cache_config, - max_model_len, - use_eagle, - enable_caching, - enable_kv_cache_events, - dcp_world_size=dcp_world_size) + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, + ): + super().__init__( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) assert dcp_world_size == 1, "DCP not support hybrid attn now." self.verify_and_split_kv_cache_groups() def verify_and_split_kv_cache_groups(self) -> None: """ - Verifies that the model has exactly two types of KV cache groups, and + Verifies that the model has exactly two types of KV cache groups, and one of them is full attention. Then, split the kv cache groups into full attention groups and other groups. """ @@ -303,7 +338,8 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): else: assert full_attention_spec == g.kv_cache_spec, ( "HybridKVCacheCoordinator assumes exactly one type of " - "full attention groups now.") + "full attention groups now." + ) self.full_attention_group_ids.append(i) else: if other_spec is None: @@ -311,19 +347,22 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): else: assert other_spec == g.kv_cache_spec, ( "HybridKVCacheCoordinator assumes " - "exactly one other type of groups now.") + "exactly one other type of groups now." + ) self.other_group_ids.append(i) assert full_attention_spec is not None, ( "HybridKVCacheCoordinator assumes exactly one type of full " - "attention groups now.") + "attention groups now." + ) assert other_spec is not None, ( - "HybridKVCacheCoordinator assumes exactly one type of other " - "groups now.") + "HybridKVCacheCoordinator assumes exactly one type of other groups now." + ) self.full_attention_manager_cls = FullAttentionManager self.other_attention_cls = self.single_type_managers[ - self.other_group_ids[0]].__class__ + self.other_group_ids[0] + ].__class__ self.full_attention_spec = full_attention_spec self.other_spec = other_spec self.full_attention_block_size = self.full_attention_spec.block_size @@ -334,7 +373,8 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): divisible = self.other_block_size % self.full_attention_block_size assert divisible == 0, ( "KVCacheCoordinator assumes the block_size of full " - "attention layers is divisible by other layers now.") + "attention layers is divisible by other layers now." + ) if max(self.full_attention_group_ids) < min(self.other_group_ids): self.full_attn_first = True @@ -347,7 +387,8 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): "do not interleave, either full attention group ids " "are before other attention group ids or vice versa." "This is for simplifying merging hit_blocks_full_attn and " - "hit_blocks_other_attn to hit_blocks.") + "hit_blocks_other_attn to hit_blocks." + ) def find_longest_cache_hit( self, @@ -367,29 +408,26 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): - The number of tokens of the longest cache hit. """ # First, find the longest cache hit for full attention. 
- hit_blocks_full_attn = ( - self.full_attention_manager_cls.find_longest_cache_hit( - block_hashes=block_hashes, - max_length=max_cache_hit_length, - kv_cache_group_ids=self.full_attention_group_ids, - block_pool=self.block_pool, - kv_cache_spec=self.full_attention_spec, - use_eagle=self.use_eagle, - )) - hit_length = len( - hit_blocks_full_attn[0]) * self.full_attention_block_size + hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit( + block_hashes=block_hashes, + max_length=max_cache_hit_length, + kv_cache_group_ids=self.full_attention_group_ids, + block_pool=self.block_pool, + kv_cache_spec=self.full_attention_spec, + use_eagle=self.use_eagle, + ) + hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size # Next, find the cache hit for the other attention WITHIN # the cache hit of full attention. - hit_blocks_other_attn = ( - self.other_attention_cls.find_longest_cache_hit( - block_hashes=block_hashes, - max_length=hit_length, - kv_cache_group_ids=self.other_group_ids, - block_pool=self.block_pool, - kv_cache_spec=self.other_spec, - use_eagle=self.use_eagle, - )) + hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit( + block_hashes=block_hashes, + max_length=hit_length, + kv_cache_group_ids=self.other_group_ids, + block_pool=self.block_pool, + kv_cache_spec=self.other_spec, + use_eagle=self.use_eagle, + ) hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size # NOTE: the prefix cache hit length must be a multiple of block_size as @@ -404,7 +442,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): # Truncate the full attention cache hit to the length of the # cache hit of the other attention. for group_hit_blocks in hit_blocks_full_attn: - del group_hit_blocks[hit_length // self.full_attention_block_size:] + del group_hit_blocks[hit_length // self.full_attention_block_size :] # Merge the hit blocks of full attention and other attention. 
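# --- Illustrative sketch (editor's note, not part of the diff) ---------------
# The two-phase lookup in HybridKVCacheCoordinator.find_longest_cache_hit,
# reduced to block-count arithmetic with toy numbers; the real code returns the
# actual KVCacheBlock lists per group.
full_attn_block_size, other_block_size = 16, 32   # other % full == 0, as asserted above

# Phase 1: longest prefix hit for the full-attention groups (say 7 blocks).
full_attn_hit_blocks = 7
hit_length = full_attn_hit_blocks * full_attn_block_size        # 112 tokens

# Phase 2: the other attention type is searched only WITHIN that hit, so the
# combined hit can only shrink (say 3 of its larger blocks are cached).
other_hit_blocks = min(3, hit_length // other_block_size)
hit_length = other_hit_blocks * other_block_size                # 96 tokens

# Finally the full-attention hit is truncated to the common length before the
# per-group hit lists are concatenated in group-id order.
full_attn_hit_blocks = hit_length // full_attn_block_size       # 6 blocks
assert (full_attn_hit_blocks, other_hit_blocks, hit_length) == (6, 3, 96)
# -----------------------------------------------------------------------------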
if self.full_attn_first: @@ -414,27 +452,36 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): return hit_blocks, hit_length -def get_kv_cache_coordinator(kv_cache_config: KVCacheConfig, - max_model_len: int, use_eagle: bool, - enable_caching: bool, - enable_kv_cache_events: bool, - dcp_world_size: int) -> KVCacheCoordinator: +def get_kv_cache_coordinator( + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, +) -> KVCacheCoordinator: if not enable_caching: - return KVCacheCoordinatorNoPrefixCache(kv_cache_config, - max_model_len, - use_eagle, - enable_kv_cache_events, - dcp_world_size=dcp_world_size) + return KVCacheCoordinatorNoPrefixCache( + kv_cache_config, + max_model_len, + use_eagle, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) if len(kv_cache_config.kv_cache_groups) == 1: - return UnitaryKVCacheCoordinator(kv_cache_config, - max_model_len, - use_eagle, - enable_caching, - enable_kv_cache_events, - dcp_world_size=dcp_world_size) - return HybridKVCacheCoordinator(kv_cache_config, - max_model_len, - use_eagle, - enable_caching, - enable_kv_cache_events, - dcp_world_size=dcp_world_size) + return UnitaryKVCacheCoordinator( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) + return HybridKVCacheCoordinator( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 0af98e7ba2..3e1a83a8a2 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -22,6 +22,7 @@ class KVCacheBlocks: Scheduler and KVCacheManager, to hide KVCacheManager's internal data structure from the Scheduler. """ + blocks: tuple[list[KVCacheBlock], ...] """ `blocks[i][j]` refers to the i-th kv_cache_group @@ -35,22 +36,20 @@ class KVCacheBlocks: def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": """Adds two KVCacheBlocks instances.""" return KVCacheBlocks( - tuple(blk1 + blk2 - for blk1, blk2 in zip(self.blocks, other.blocks))) + tuple(blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)) + ) @overload def get_block_ids( self, allow_none: Literal[False] = False, - ) -> tuple[list[int], ...]: - ... + ) -> tuple[list[int], ...]: ... @overload def get_block_ids( self, allow_none: Literal[True] = True, - ) -> Optional[tuple[list[int], ...]]: - ... + ) -> Optional[tuple[list[int], ...]]: ... 
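# --- Illustrative sketch (editor's note, not part of the diff) ---------------
# Minimal standalone version of the Literal-typed overload pattern used by
# KVCacheBlocks.get_block_ids above: type checkers drop the Optional when
# allow_none=False. Toy class, not vLLM API.
from typing import Literal, Optional, overload


class Ids:
    def __init__(self, ids: Optional[list[int]]) -> None:
        self._ids = ids

    @overload
    def get(self, allow_none: Literal[False] = False) -> list[int]: ...
    @overload
    def get(self, allow_none: Literal[True] = True) -> Optional[list[int]]: ...

    def get(self, allow_none: bool = False) -> Optional[list[int]]:
        if self._ids is None and not allow_none:
            raise ValueError("ids not available")
        return self._ids


assert Ids([1, 2]).get() == [1, 2]
assert Ids(None).get(allow_none=True) is None
# -----------------------------------------------------------------------------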
def get_block_ids( self, @@ -72,10 +71,7 @@ class KVCacheBlocks: def get_unhashed_block_ids(self) -> list[int]: """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" assert len(self.blocks) == 1, "Only one group is supported" - return [ - block.block_id for block in self.blocks[0] - if block.block_hash is None - ] + return [block.block_id for block in self.blocks[0] if block.block_hash is None] def new_empty(self) -> "KVCacheBlocks": """Creates a new KVCacheBlocks instance with no blocks.""" @@ -83,7 +79,6 @@ class KVCacheBlocks: class KVCacheManager: - def __init__( self, kv_cache_config: KVCacheConfig, @@ -104,12 +99,18 @@ class KVCacheManager: self.block_size: Optional[int] = None if self.enable_caching: - assert len( - set(g.kv_cache_spec.block_size - for g in kv_cache_config.kv_cache_groups) - ) == 1, "Only one block size is supported for now" + assert ( + len( + set( + g.kv_cache_spec.block_size + for g in kv_cache_config.kv_cache_groups + ) + ) + == 1 + ), "Only one block size is supported for now" self.block_size = kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.block_size + 0 + ].kv_cache_spec.block_size if dcp_world_size > 1: assert len(kv_cache_config.kv_cache_groups) == 1 @@ -151,8 +152,7 @@ class KVCacheManager: self.prefix_cache_stats = PrefixCacheStats() return stats - def get_computed_blocks(self, - request: Request) -> tuple[KVCacheBlocks, int]: + def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -166,9 +166,10 @@ class KVCacheManager: """ # Prefix caching is disabled or # When the request requires prompt logprobs, we skip prefix caching. - if (not self.enable_caching - or (request.sampling_params is not None - and request.sampling_params.prompt_logprobs is not None)): + if not self.enable_caching or ( + request.sampling_params is not None + and request.sampling_params.prompt_logprobs is not None + ): return self.create_empty_block_list(), 0 # NOTE: When all tokens hit the cache, we must recompute the last token @@ -179,8 +180,10 @@ class KVCacheManager: # could slightly improve performance in the future. max_cache_hit_length = request.num_tokens - 1 computed_blocks, num_new_computed_tokens = ( - self.coordinator.find_longest_cache_hit(request.block_hashes, - max_cache_hit_length)) + self.coordinator.find_longest_cache_hit( + request.block_hashes, max_cache_hit_length + ) + ) if self.log_stats: assert self.prefix_cache_stats is not None @@ -188,8 +191,7 @@ class KVCacheManager: # Previously preempted request self.prefix_cache_stats.preempted_requests += 1 self.prefix_cache_stats.preempted_queries += request.num_tokens - self.prefix_cache_stats.preempted_hits += ( - num_new_computed_tokens) + self.prefix_cache_stats.preempted_hits += num_new_computed_tokens else: # New request self.prefix_cache_stats.requests += 1 @@ -250,7 +252,8 @@ class KVCacheManager: new_computed_block_list = new_computed_blocks.blocks else: new_computed_block_list = tuple( - [] for _ in range(len(self.kv_cache_config.kv_cache_groups))) + [] for _ in range(len(self.kv_cache_config.kv_cache_groups)) + ) # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). @@ -258,16 +261,17 @@ class KVCacheManager: # insufficient free blocks. # Should call this function before allocating new blocks to reduce # the number of evicted blocks. 
- self.coordinator.remove_skipped_blocks(request.request_id, - request.num_computed_tokens) + self.coordinator.remove_skipped_blocks( + request.request_id, request.num_computed_tokens + ) # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits - num_computed_tokens = (request.num_computed_tokens + - num_new_computed_tokens) + num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens num_tokens_need_slot = min( num_computed_tokens + num_new_tokens + num_lookahead_tokens, - self.max_model_len) + self.max_model_len, + ) num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( request_id=request.request_id, @@ -285,16 +289,18 @@ class KVCacheManager: self.block_pool.touch(new_computed_block_list) else: assert not any(new_computed_block_list), ( - "Computed blocks should be empty when " - "prefix caching is disabled") + "Computed blocks should be empty when prefix caching is disabled" + ) # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - self.coordinator.save_new_computed_blocks(request.request_id, - new_computed_block_list) + self.coordinator.save_new_computed_blocks( + request.request_id, new_computed_block_list + ) new_blocks = self.coordinator.allocate_new_blocks( - request.request_id, num_tokens_need_slot, num_encoder_tokens) + request.request_id, num_tokens_need_slot, num_encoder_tokens + ) # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. @@ -305,8 +311,9 @@ class KVCacheManager: # num_new_tokens, but must exclude "non-committable" tokens (e.g., # draft tokens that could be rejected). Therefore, we cap the number # at `request.num_tokens`, ensuring only "finalized" tokens are cached. - num_tokens_to_cache = min(num_computed_tokens + num_new_tokens, - request.num_tokens) + num_tokens_to_cache = min( + num_computed_tokens + num_new_tokens, request.num_tokens + ) self.coordinator.cache_blocks(request, num_tokens_to_cache) return KVCacheBlocks(new_blocks) @@ -378,7 +385,8 @@ class KVCacheManager: """ assert request.status == RequestStatus.RUNNING return self.coordinator.get_num_common_prefix_blocks( - request.request_id, num_running_requests) + request.request_id, num_running_requests + ) def take_events(self) -> list[KVCacheEvent]: """Take the KV cache events from the block pool. 
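# --- Illustrative sketch (editor's note, not part of the diff) ---------------
# The token/slot accounting in the KVCacheManager allocation path above, with
# toy numbers: block_size 16, a 40-token request, 30 tokens already computed,
# 10 newly scheduled, 4 lookahead slots for speculative decoding.
block_size, max_model_len = 16, 2048
num_computed_tokens, num_new_tokens, num_lookahead_tokens = 30, 10, 4
request_num_tokens = 40          # only "finalized" tokens may be cached

num_tokens_need_slot = min(
    num_computed_tokens + num_new_tokens + num_lookahead_tokens, max_model_len
)                                                            # 44 slots
num_blocks_needed = -(-num_tokens_need_slot // block_size)   # ceil div -> 3 blocks

# Draft/lookahead tokens may still be rejected, so caching is capped at the
# request's committed token count.
num_tokens_to_cache = min(num_computed_tokens + num_new_tokens, request_num_tokens)
assert (num_tokens_need_slot, num_blocks_needed, num_tokens_to_cache) == (44, 3, 40)
# -----------------------------------------------------------------------------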
@@ -403,5 +411,4 @@ class KVCacheManager: def create_empty_block_list(self) -> KVCacheBlocks: """Creates a new KVCacheBlocks instance with no blocks.""" - return KVCacheBlocks(tuple([] - for _ in range(self.num_kv_cache_groups))) + return KVCacheBlocks(tuple([] for _ in range(self.num_kv_cache_groups))) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index bbfd93413f..4683ad6298 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -13,11 +13,16 @@ from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import GiB_bytes, cdiv, sha256_cbor -from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - KVCacheTensor, SlidingWindowSpec, - UniformTypeKVCacheSpecs) +from vllm.v1.kv_cache_interface import ( + ChunkedLocalAttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + KVCacheTensor, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -37,16 +42,16 @@ BlockHashWithGroupId = NewType("BlockHashWithGroupId", bytes) ExternalBlockHash = Union[bytes, int] -def make_block_hash_with_group_id(block_hash: BlockHash, - group_id: int) -> BlockHashWithGroupId: +def make_block_hash_with_group_id( + block_hash: BlockHash, group_id: int +) -> BlockHashWithGroupId: """Pack a ``BlockHash`` and group id into a ``BlockHashWithGroupId``. The group id is encoded using 4 bytes in big-endian order and appended to the block hash bytes. This representation avoids creating tuples while still allowing us to recover both components when needed. """ - return BlockHashWithGroupId(block_hash + - group_id.to_bytes(4, "big", signed=False)) + return BlockHashWithGroupId(block_hash + group_id.to_bytes(4, "big", signed=False)) def get_block_hash(key: BlockHashWithGroupId) -> BlockHash: @@ -87,7 +92,8 @@ def init_none_hash(hash_fn: Callable[[Any], bytes]): "PYTHONHASHSEED is not set. This will lead to non-reproducible " "block-hashes when using sha256_cbor as the hash function." "Consider setting PYTHONHASHSEED to a fixed value for " - "reproducibility.") + "reproducibility." + ) if hash_seed is None: NONE_HASH = BlockHash(os.urandom(32)) @@ -143,9 +149,10 @@ class PrefixCachingMetrics: # Remove the oldest stats until number of requests does not exceed # the limit. # NOTE: We preserve the latest added stats regardless. - while len( - self.query_queue - ) > 1 and self.aggregated_requests > self.max_recent_requests: + while ( + len(self.query_queue) > 1 + and self.aggregated_requests > self.max_recent_requests + ): old_requests, old_queries, old_hits = self.query_queue.popleft() self.aggregated_requests -= old_requests self.aggregated_query_total -= old_queries @@ -169,6 +176,7 @@ class PrefixCachingMetrics: @dataclass class KVCacheBlock: """KV-cache block metadata.""" + # Block ID, ranging from 0 to num_gpu_blocks - 1. block_id: int # Reference count. @@ -192,7 +200,8 @@ class KVCacheBlock: @block_hash.setter def block_hash(self, block_hash: BlockHashWithGroupId): assert self.block_hash is None, ( - "The block already has a hash. This should not happen.") + "The block already has a hash. This should not happen." 
+ ) self._block_hash = block_hash def reset_hash(self): @@ -202,15 +211,15 @@ class KVCacheBlock: def __repr__(self) -> str: # Use block_id instead of KVCacheBlock object to avoid calling __repr__ # on KVCacheBlock object recursively. - prev_block_id = (self.prev_free_block.block_id - if self.prev_free_block else None) - next_block_id = (self.next_free_block.block_id - if self.next_free_block else None) - return (f"KVCacheBlock(block_id={self.block_id}, " - f"ref_cnt={self.ref_cnt}, " - f"_block_hash={self._block_hash!r}, " - f"prev_free_block={prev_block_id}, " - f"next_free_block={next_block_id})") + prev_block_id = self.prev_free_block.block_id if self.prev_free_block else None + next_block_id = self.next_free_block.block_id if self.next_free_block else None + return ( + f"KVCacheBlock(block_id={self.block_id}, " + f"ref_cnt={self.ref_cnt}, " + f"_block_hash={self._block_hash!r}, " + f"prev_free_block={prev_block_id}, " + f"next_free_block={next_block_id})" + ) class FreeKVCacheBlockQueue: @@ -271,12 +280,14 @@ class FreeKVCacheBlockQueue: Returns: The first free block. """ - if (self.fake_free_list_head.next_free_block - is self.fake_free_list_tail - or self.fake_free_list_head.next_free_block is None): + if ( + self.fake_free_list_head.next_free_block is self.fake_free_list_tail + or self.fake_free_list_head.next_free_block is None + ): assert self.num_free_blocks == 0, ( f"num_free_blocks ({self.num_free_blocks}) is out of sync " - "with the free list.") + "with the free list." + ) raise ValueError("No free blocks available") first_block: KVCacheBlock = self.fake_free_list_head.next_free_block @@ -284,8 +295,10 @@ class FreeKVCacheBlockQueue: if first_block.next_free_block is None: # This should not happen if the block is from the free list. # It indicates a bug in the caller's logic. - raise RuntimeError("Invalid block found in popleft() " - "which doesn't have a valid next_free_block") + raise RuntimeError( + "Invalid block found in popleft() " + "which doesn't have a valid next_free_block" + ) # Connect fake_head and the next block of first_block (i.e. second block # or fake tail). @@ -360,7 +373,8 @@ class FreeKVCacheBlockQueue: """ if self.fake_free_list_tail.prev_free_block is None: raise RuntimeError( - "prev_free_block of fake_free_list_tail should always exist") + "prev_free_block of fake_free_list_tail should always exist" + ) last_block: KVCacheBlock = self.fake_free_list_tail.prev_free_block # Connect the new block after the last block. @@ -384,7 +398,8 @@ class FreeKVCacheBlockQueue: last_block = self.fake_free_list_tail.prev_free_block assert last_block is not None, ( - "prev_free_block of fake_free_list_tail should always exist") + "prev_free_block of fake_free_list_tail should always exist" + ) # Add inter-connections between consecutive blocks for block in blocks: block.prev_free_block = last_block @@ -406,7 +421,8 @@ class FreeKVCacheBlockQueue: ret = [] if self.fake_free_list_head.next_free_block is None: raise RuntimeError( - "next_free_block of fake_free_list_head should always exist") + "next_free_block of fake_free_list_head should always exist" + ) # Start from the first block curr_block: KVCacheBlock = self.fake_free_list_head.next_free_block # As long as next_free_block is available, we haven't reached to @@ -430,14 +446,16 @@ def need_extra_keys(request: Request) -> bool: # Multimodal requests need to include the MM hash. # LoRA requests need to include the LoRA ID. # Request with provided cache salt need to include the salt. 
- return bool(request.mm_features) or (request.lora_request - is not None) or (request.cache_salt - is not None) + return ( + bool(request.mm_features) + or (request.lora_request is not None) + or (request.cache_salt is not None) + ) -def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, - end_token_idx: int, - start_mm_idx: int) -> tuple[list[Any], int]: +def _gen_mm_extra_hash_keys( + request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int +) -> tuple[list[Any], int]: """Generate extra keys related to MultiModal request for block hash computation. For multi-modal inputs, the extra keys are (mm_hash, start_offset) that indicate a mm input contained in the @@ -515,8 +533,8 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[int]: def generate_block_hash_extra_keys( - request: Request, start_token_idx: int, end_token_idx: int, - start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]: + request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int +) -> tuple[Optional[tuple[Any, ...]], int]: """Generate extra keys for the block hash. The extra keys can come from the multi-modal inputs and request specific metadata (e.g., LoRA ID). @@ -531,10 +549,12 @@ def generate_block_hash_extra_keys( """ mm_extra_keys: list[Any] mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( - request, start_token_idx, end_token_idx, start_mm_idx) + request, start_token_idx, end_token_idx, start_mm_idx + ) lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) - cache_salt_keys: list[str] = [request.cache_salt] if ( - start_token_idx == 0 and request.cache_salt) else [] + cache_salt_keys: list[str] = ( + [request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else [] + ) extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys @@ -545,10 +565,11 @@ def generate_block_hash_extra_keys( def hash_block_tokens( - hash_function: Callable[[Any], bytes], - parent_block_hash: Optional[BlockHash], - curr_block_token_ids: Sequence[int], - extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash: + hash_function: Callable[[Any], bytes], + parent_block_hash: Optional[BlockHash], + curr_block_token_ids: Sequence[int], + extra_keys: Optional[tuple[Any, ...]] = None, +) -> BlockHash: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing @@ -569,8 +590,8 @@ def hash_block_tokens( curr_block_token_ids_tuple = tuple(curr_block_token_ids) return BlockHash( - hash_function( - (parent_block_hash, curr_block_token_ids_tuple, extra_keys))) + hash_function((parent_block_hash, curr_block_token_ids_tuple, extra_keys)) + ) def get_request_block_hasher( @@ -597,8 +618,9 @@ def get_request_block_hasher( # last mm input. curr_mm_idx = -1 - prev_block_hash_value = (request.block_hashes[-1] - if request.block_hashes else None) + prev_block_hash_value = ( + request.block_hashes[-1] if request.block_hashes else None + ) new_block_hashes: list[BlockHash] = [] while True: end_token_idx = start_token_idx + block_size @@ -608,13 +630,14 @@ def get_request_block_hasher( # MM and LoRA requests need extra keys for block-hash computation. 
extra_keys, curr_mm_idx = generate_block_hash_extra_keys( - request, start_token_idx, end_token_idx, curr_mm_idx) + request, start_token_idx, end_token_idx, curr_mm_idx + ) # Compute the hash of the current block block_tokens = request.all_token_ids[start_token_idx:end_token_idx] - block_hash = hash_block_tokens(caching_hash_fn, - prev_block_hash_value, block_tokens, - extra_keys) + block_hash = hash_block_tokens( + caching_hash_fn, prev_block_hash_value, block_tokens, extra_keys + ) new_block_hashes.append(block_hash) start_token_idx += block_size @@ -625,18 +648,20 @@ def get_request_block_hasher( return request_block_hasher -def max_memory_usage_bytes(vllm_config: VllmConfig, - kv_cache_specs: Iterable[KVCacheSpec]) -> int: +def max_memory_usage_bytes( + vllm_config: VllmConfig, kv_cache_specs: Iterable[KVCacheSpec] +) -> int: """ Get the maximum memory usage in bytes for the given KV cache specs. """ - return sum( - spec.max_memory_usage_bytes(vllm_config) for spec in kv_cache_specs) + return sum(spec.max_memory_usage_bytes(vllm_config) for spec in kv_cache_specs) -def estimate_max_model_len(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> int: +def estimate_max_model_len( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +) -> int: """ Estimates the maximum model length that can fit in the available memory using binary search. @@ -655,8 +680,7 @@ def estimate_max_model_len(vllm_config: VllmConfig, # Modify the max_model_len for this calculation vllm_config.model_config.max_model_len = model_len # Calculate memory needed for the given model length - memory_needed = max_memory_usage_bytes(vllm_config, - kv_cache_spec.values()) + memory_needed = max_memory_usage_bytes(vllm_config, kv_cache_spec.values()) return memory_needed <= available_memory # Binary search for the maximum model length @@ -679,9 +703,11 @@ def estimate_max_model_len(vllm_config: VllmConfig, return result -def check_enough_kv_cache_memory(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int): +def check_enough_kv_cache_memory( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +): """ Checks whether `available_memory` is enough for the KV cache to hold at least one request with the model's max_model_len. @@ -700,36 +726,41 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, return if available_memory <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") + raise ValueError( + "No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine." + ) max_model_len = vllm_config.model_config.max_model_len needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values()) if needed_memory > available_memory: # Estimate the maximum model length that can fit in the available memory - estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec, - available_memory) + estimated_max_len = estimate_max_model_len( + vllm_config, kv_cache_spec, available_memory + ) estimated_msg = "" if estimated_max_len > 0: estimated_msg = ( "Based on the available memory, " - f"the estimated maximum model length is {estimated_max_len}.") + f"the estimated maximum model length is {estimated_max_len}." 
+ ) raise ValueError( f"To serve at least one request with the models's max seq len " - f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV " + f"({max_model_len}), ({needed_memory / GiB_bytes:.2f} GiB KV " f"cache is needed, which is larger than the available KV cache " - f"memory ({available_memory/GiB_bytes:.2f} GiB). " + f"memory ({available_memory / GiB_bytes:.2f} GiB). " f"{estimated_msg} " f"Try increasing `gpu_memory_utilization` or decreasing " - f"`max_model_len` when initializing the engine.") + f"`max_model_len` when initializing the engine." + ) def create_kv_cache_group_specs( - kv_cache_spec: dict[str, KVCacheSpec], - grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]: + kv_cache_spec: dict[str, KVCacheSpec], grouped_layer_names: list[list[str]] +) -> list[KVCacheGroupSpec]: """ Create KVCacheGroupSpec object for each kv cache group layer. The layers in the same group should share the same @@ -752,7 +783,8 @@ def create_kv_cache_group_specs( ] merged_layer_spec = layer_specs[0].merge(layer_specs) kv_cache_groups.append( - KVCacheGroupSpec(layer_names_one_group, merged_layer_spec)) + KVCacheGroupSpec(layer_names_one_group, merged_layer_spec) + ) return kv_cache_groups @@ -782,19 +814,22 @@ def is_kv_cache_spec_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: def get_max_concurrency_for_kv_cache_config( - vllm_config: VllmConfig, kv_cache_config: KVCacheConfig) -> float: + vllm_config: VllmConfig, kv_cache_config: KVCacheConfig +) -> float: """ Get the maximum concurrency for the given KV cache configuration. """ num_layer_per_group = max( - len(group.layer_names) for group in kv_cache_config.kv_cache_groups) + len(group.layer_names) for group in kv_cache_config.kv_cache_groups + ) max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes( - vllm_config, - (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups)) - memory_per_block = kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.page_size_bytes * num_layer_per_group - num_block_per_request = cdiv(max_memory_usage_per_request, - memory_per_block) + vllm_config, (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups) + ) + memory_per_block = ( + kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes + * num_layer_per_group + ) + num_block_per_request = cdiv(max_memory_usage_per_request, memory_per_block) max_concurrency = kv_cache_config.num_blocks / num_block_per_request return max_concurrency @@ -804,18 +839,20 @@ def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int: Override the number of kv cache blocks if `num_gpu_blocks_override` is set. """ if vllm_config.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = \ - vllm_config.cache_config.num_gpu_blocks_override + num_gpu_blocks_override = vllm_config.cache_config.num_gpu_blocks_override logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) + "Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d", + num_blocks, + num_gpu_blocks_override, + ) num_blocks = num_gpu_blocks_override return num_blocks -def get_num_blocks(vllm_config: VllmConfig, num_layers: int, - available_memory: int, page_size: int) -> int: +def get_num_blocks( + vllm_config: VllmConfig, num_layers: int, available_memory: int, page_size: int +) -> int: """ Get the number of kv cache blocks. 
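To make the concurrency formula in get_max_concurrency_for_kv_cache_config above concrete, here is a back-of-the-envelope sketch with hypothetical sizes. It follows the same shape as the code (worst-case per-request KV memory divided by per-block memory, rounded up), but every number is invented and nothing here comes from this diff:

from math import ceil

num_layer_per_group = 32                       # layers sharing one block table
bytes_per_token_per_layer = 2 * 2 * 8 * 128    # K+V, fp16, 8 KV heads, head size 128
max_model_len = 8192
block_size = 16                                # tokens per block
page_size_bytes = block_size * bytes_per_token_per_layer

# Worst case: one request pins max_model_len tokens in every layer of the group.
max_memory_per_request = num_layer_per_group * max_model_len * bytes_per_token_per_layer
memory_per_block = page_size_bytes * num_layer_per_group
num_blocks_per_request = ceil(max_memory_per_request / memory_per_block)  # plays the role of cdiv()
assert num_blocks_per_request == max_model_len // block_size == 512

num_gpu_blocks = 20_000                        # whatever fits in the KV cache budget
max_concurrency = num_gpu_blocks / num_blocks_per_request
print(f"Maximum concurrency for {max_model_len:,} tokens per request: {max_concurrency:.2f}x")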
@@ -841,9 +878,10 @@ def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int: def _get_kv_cache_groups_uniform_spec( - kv_cache_specs: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: + kv_cache_specs: dict[str, KVCacheSpec], +) -> list[KVCacheGroupSpec]: """ - Generates the KV cache configuration for a model with the same KV cache + Generates the KV cache configuration for a model with the same KV cache spec for all layers. Args: @@ -853,12 +891,12 @@ def _get_kv_cache_groups_uniform_spec( The generated KVCacheGroupSpecs """ - return create_kv_cache_group_specs(kv_cache_specs, - [list(kv_cache_specs.keys())]) + return create_kv_cache_group_specs(kv_cache_specs, [list(kv_cache_specs.keys())]) def _get_kv_cache_groups_uniform_type( - spec: UniformTypeKVCacheSpecs) -> list[KVCacheGroupSpec]: + spec: UniformTypeKVCacheSpecs, +) -> list[KVCacheGroupSpec]: """ Generates the KV cache configuration for a model with one type of KV cache but different hidden sizes. All layers are merged into one group. @@ -873,8 +911,7 @@ def _get_kv_cache_groups_uniform_type( return [KVCacheGroupSpec(list(spec.kv_cache_specs.keys()), spec)] -def is_kv_cache_page_size_uniform( - kv_cache_spec: dict[str, KVCacheSpec]) -> bool: +def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: """ Whether all layers in the given KVCacheSpec have the same page size. Args: @@ -888,70 +925,69 @@ def is_kv_cache_page_size_uniform( return len(page_sizes) == 1 -def is_kv_cache_type_attention_free( - kv_cache_spec: dict[str, KVCacheSpec]) -> bool: - +def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: # kv_cache_spec is an empty dict for attention free models return not kv_cache_spec def _get_kv_cache_groups_uniform_page_size( - kv_cache_spec: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: + kv_cache_spec: dict[str, KVCacheSpec], +) -> list[KVCacheGroupSpec]: """ - Generates the KV cache groups for hybrid models with multiple - attention types but still with a uniform page size (physical memory per + Generates the KV cache groups for hybrid models with multiple + attention types but still with a uniform page size (physical memory per block per layer) for all layers. Detailed explanation about kv cache management of hybrid models: The layers in the models are repeated with some patterns, e.g., a model with 10 full attention layers and 20 sliding window attention layers can be - regarded as repeating the pattern (1 * full, 2 * sw) 10 times. + regarded as repeating the pattern (1 * full, 2 * sw) 10 times. The KVCacheManager allocates different block tables for each of the 3 layers - in the pattern, and repeats each of them 10 times to generate the + in the pattern, and repeats each of them 10 times to generate the block_table for the 30 layers in the model. Therefore, we can group the layers in the model into 3 kv_cache_groups, each of which contains 10 layers in the model. The KVCacheManager allocates the block_table for each group based on its - kv_cache spec, and the model runner applies the block table to each layer + kv_cache spec, and the model runner applies the block table to each layer in the group. For example: - 1. A model only uses full attention. The pattern is - (num_hidden_layers * full), so there is only one group and the block table - is shared by all layers. It is already handled by + 1. A model only uses full attention. The pattern is + (num_hidden_layers * full), so there is only one group and the block table + is shared by all layers. 
It is already handled by `_get_kv_cache_config_uniform_type`. - 2. A model with 10 full attention layers and 20 sliding window - attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so + 2. A model with 10 full attention layers and 20 sliding window + attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so there are 3 kv_cache_groups, each of which represents 10 layers. To simplify the implementation, we make the following assumptions: - 1. Physical memory per block: Must be the same across all KV cache groups. + 1. Physical memory per block: Must be the same across all KV cache groups. Breaking this assumption is non-trivial due to memory fragmentation concerns when allocating blocks of different sizes. - 2. Tokens per block (block_size): Currently, we directly use - `CacheConfig.block_size` for all layers. It can be extended to vary by KV - cache group, but within each KV cache group, all layers must share the same + 2. Tokens per block (block_size): Currently, we directly use + `CacheConfig.block_size` for all layers. It can be extended to vary by KV + cache group, but within each KV cache group, all layers must share the same block size. - 3. Physical memory per token per layer: This property is decided by model - config. Currently we only support models that have the same physical memory - per token per layer for all layers. Can be relaxed with a simple extension, + 3. Physical memory per token per layer: This property is decided by model + config. Currently we only support models that have the same physical memory + per token per layer for all layers. Can be relaxed with a simple extension, but still need to keep physical memory per block the same for all groups. - 4. Number of layers per group: Currently assumed the same for all layers. - Can be relaxed with a simple extension, but still need to keep physical + 4. Number of layers per group: Currently assumed the same for all layers. + Can be relaxed with a simple extension, but still need to keep physical memory per block the same for all groups. 5. Attention type within groups: All layers in a group must share the same - attention type. One exception is that, when - `--disable-hybrid-kv-cache-manager` is true, the single group for full - attention layers may also include attention layers using sliding window or + attention type. One exception is that, when + `--disable-hybrid-kv-cache-manager` is true, the single group for full + attention layers may also include attention layers using sliding window or LLaMA 4 local attention. See `unify_hybrid_kv_cache_specs` for more details. - 6. Support for multiple attention types: The design for most components is - general to an arbitrary number of attention types. But - `find_longest_cache_hit` only supports one attention type or two + 6. Support for multiple attention types: The design for most components is + general to an arbitrary number of attention types. But + `find_longest_cache_hit` only supports one attention type or two types of full-attention plus exactly one another type. The general - implementation of this function is feasible but we don't know how to + implementation of this function is feasible but we don't know how to implement it cleanly yet. 
- As we assume tokens per block, physical memory per token per layer, and - number of layers per group are the same now, we can ensure that physical + As we assume tokens per block, physical memory per token per layer, and + number of layers per group are the same now, we can ensure that physical memory per block is the same for all groups. Args: @@ -1005,10 +1041,12 @@ def _get_kv_cache_groups_uniform_page_size( return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) -def get_kv_cache_config_from_groups(vllm_config: VllmConfig, - kv_cache_groups: list[KVCacheGroupSpec], - kv_cache_specs: dict[str, KVCacheSpec], - available_memory: int) -> KVCacheConfig: +def get_kv_cache_config_from_groups( + vllm_config: VllmConfig, + kv_cache_groups: list[KVCacheGroupSpec], + kv_cache_specs: dict[str, KVCacheSpec], + available_memory: int, +) -> KVCacheConfig: """ Generate the KV cache configuration from the KV cache groups and spec of each layer. @@ -1031,19 +1069,22 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, ) # Determine how model runners should initialize the KV cache tensors. - if len(kv_cache_groups) == 1 and \ - isinstance(kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs): + if len(kv_cache_groups) == 1 and isinstance( + kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs + ): # Special case: all layers have the same type of KV cache but with # different hidden size. Allocate different amount of memory for each # layer based on its hidden size. - num_blocks = available_memory // kv_cache_groups[ - 0].kv_cache_spec.page_size_bytes + num_blocks = ( + available_memory // kv_cache_groups[0].kv_cache_spec.page_size_bytes + ) num_blocks = may_override_num_blocks(vllm_config, num_blocks) per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs kv_cache_tensors = [ - KVCacheTensor(size=per_layer_specs[layer_name].page_size_bytes * - num_blocks, - shared_by=[layer_name]) + KVCacheTensor( + size=per_layer_specs[layer_name].page_size_bytes * num_blocks, + shared_by=[layer_name], + ) for layer_name in kv_cache_groups[0].layer_names ] else: @@ -1059,8 +1100,9 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, page_size = get_uniform_page_size(kv_cache_specs) assert group_size > 0, "group_size must be greater than 0" - num_blocks = get_num_blocks(vllm_config, group_size, available_memory, - page_size) + num_blocks = get_num_blocks( + vllm_config, group_size, available_memory, page_size + ) kv_cache_tensors = [] for i in range(group_size): shared_by = [] @@ -1068,8 +1110,8 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, if i < len(kv_cache_groups[j].layer_names): shared_by.append(kv_cache_groups[j].layer_names[i]) kv_cache_tensors.append( - KVCacheTensor(size=page_size * num_blocks, - shared_by=shared_by)) + KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by) + ) kv_cache_config = KVCacheConfig( num_blocks=num_blocks, @@ -1077,8 +1119,7 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, kv_cache_groups=kv_cache_groups, ) - min_block_size = min( - [group.kv_cache_spec.block_size for group in kv_cache_groups]) + min_block_size = min([group.kv_cache_spec.block_size for group in kv_cache_groups]) # Print the KV cache size and maximum concurrency. 
num_tokens = num_blocks // len(kv_cache_groups) * min_block_size @@ -1086,14 +1127,19 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, num_tokens *= vllm_config.parallel_config.decode_context_parallel_size logger.info( "Multiplying the GPU KV cache size by the dcp_world_size %d.", - vllm_config.parallel_config.decode_context_parallel_size) + vllm_config.parallel_config.decode_context_parallel_size, + ) num_tokens_str = f"{num_tokens:,}" logger.info("GPU KV cache size: %s tokens", num_tokens_str) max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" max_concurrency = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config) - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - max_model_len_str, max_concurrency) + vllm_config, kv_cache_config + ) + logger.info( + "Maximum concurrency for %s tokens per request: %.2fx", + max_model_len_str, + max_concurrency, + ) return kv_cache_config @@ -1108,25 +1154,27 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ if is_kv_cache_spec_uniform( - kv_cache_spec) or UniformTypeKVCacheSpecs.is_uniform_type( - kv_cache_spec): + kv_cache_spec + ) or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec): return logger.warning( "Hybrid KV cache manager is disabled for this hybrid model, " "This means we do not enable any optimizations for saving KV cache " "memory (e.g., dropping the KV cache outside the sliding window). " - "The compute of layers like sliding window is still saved.") + "The compute of layers like sliding window is still saved." + ) has_full_attention = any( - isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values()) + isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values() + ) has_sliding_window = any( - isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values()) + isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values() + ) has_chunked_local_attention = any( - isinstance(spec, ChunkedLocalAttentionSpec) - for spec in kv_cache_spec.values()) - if has_full_attention and (has_sliding_window - or has_chunked_local_attention): + isinstance(spec, ChunkedLocalAttentionSpec) for spec in kv_cache_spec.values() + ) + if has_full_attention and (has_sliding_window or has_chunked_local_attention): for layer_name, spec in kv_cache_spec.items(): if isinstance(spec, SlidingWindowSpec): kv_cache_spec[layer_name] = FullAttentionSpec( @@ -1145,15 +1193,19 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): attention_chunk_size=spec.attention_chunk_size, ) - if not (is_kv_cache_spec_uniform(kv_cache_spec) - or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec)): - raise ValueError("Hybrid KV cache manager is disabled but failed to " - "convert the KV cache specs to one unified type.") + if not ( + is_kv_cache_spec_uniform(kv_cache_spec) + or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec) + ): + raise ValueError( + "Hybrid KV cache manager is disabled but failed to " + "convert the KV cache specs to one unified type." + ) def get_kv_cache_groups( - vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: + vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec] +) -> list[KVCacheGroupSpec]: """ Split the layers in the model into groups with the same KV cache spec. 
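The hybrid-model grouping that the long docstring above describes (10 full-attention plus 20 sliding-window layers, pattern (1 * full, 2 * sw), hence 3 KV cache groups of 10 layers each) is easier to see in a toy sketch. The grouping rule below (bucket layers by spec, then split each bucket into equally sized chunks) is an assumption made purely for illustration and is simpler than the real _get_kv_cache_groups_uniform_page_size:

from collections import defaultdict

# Toy model: 30 layers, every third one full attention, the rest sliding window.
layer_specs = {f"layers.{i}.attn": ("full" if i % 3 == 0 else "sw") for i in range(30)}

# 1) Bucket layer names by their KV cache spec.
by_spec: dict[str, list[str]] = defaultdict(list)
for name, spec in layer_specs.items():
    by_spec[spec].append(name)

# 2) Split each bucket into chunks of group_size layers so that every group holds
#    the same number of layers and a single spec (assumed here: smallest bucket size).
group_size = min(len(names) for names in by_spec.values())  # 10 in this example
groups = [
    names[start:start + group_size]
    for names in by_spec.values()
    for start in range(0, len(names), group_size)
]
assert [len(g) for g in groups] == [10, 10, 10]  # 1 full group + 2 sliding-window groups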
@@ -1192,14 +1244,14 @@ def get_kv_cache_groups( def generate_scheduler_kv_cache_config( - kv_cache_configs: list[KVCacheConfig]) -> KVCacheConfig: + kv_cache_configs: list[KVCacheConfig], +) -> KVCacheConfig: """ Generate the KV cache configuration for the scheduler. """ - assert all([ - cfg.num_blocks == kv_cache_configs[0].num_blocks - for cfg in kv_cache_configs - ]) + assert all( + [cfg.num_blocks == kv_cache_configs[0].num_blocks for cfg in kv_cache_configs] + ) # All workers have the same kv_cache_config except layer names, so use # an arbitrary one to initialize the scheduler. cfg = copy.deepcopy(kv_cache_configs[0]) @@ -1208,15 +1260,18 @@ def generate_scheduler_kv_cache_config( # All layers in the UniformTypeKVCacheSpecs have the same type, # so use an arbitrary one to initialize the scheduler. group.kv_cache_spec = next( - iter(group.kv_cache_spec.kv_cache_specs.values())) + iter(group.kv_cache_spec.kv_cache_specs.values()) + ) return cfg -def get_kv_cache_configs(vllm_config: VllmConfig, - kv_cache_specs: list[dict[str, KVCacheSpec]], - available_memory: list[int]) -> list[KVCacheConfig]: +def get_kv_cache_configs( + vllm_config: VllmConfig, + kv_cache_specs: list[dict[str, KVCacheSpec]], + available_memory: list[int], +) -> list[KVCacheConfig]: """ - Generates the KV cache configurations for a model. + Generates the KV cache configurations for a model. Since we use a shared centralized controller for all workers, we need the `kv_cache_config` to be consistent across all workers to make sure the KV cache allocation can be applied to all workers. However, different @@ -1235,7 +1290,7 @@ def get_kv_cache_configs(vllm_config: VllmConfig, vllm_config: The global VllmConfig kv_cache_specs: List of dict[layer_name, KVCacheSpec] for each worker. available_memory: Memory available for KV cache in bytes for each - worker. + worker. Returns: The generated KVCacheConfigs for each worker. @@ -1243,9 +1298,11 @@ def get_kv_cache_configs(vllm_config: VllmConfig, # Check if the available memory is enough for each worker. for kv_cache_spec_one_worker, available_memory_one_worker in zip( - kv_cache_specs, available_memory): - check_enough_kv_cache_memory(vllm_config, kv_cache_spec_one_worker, - available_memory_one_worker) + kv_cache_specs, available_memory + ): + check_enough_kv_cache_memory( + vllm_config, kv_cache_spec_one_worker, available_memory_one_worker + ) # Merge the KV cache specs of all workers. Different PP stages may have # different layer names, and different TP ranks of the same PP stage should @@ -1258,37 +1315,42 @@ def get_kv_cache_configs(vllm_config: VllmConfig, else: assert merged_kv_cache_specs[layer_name] == layer_spec, ( "The KV cache specs for the same layer are different " - "across workers. This is not supported yet.") - global_kv_cache_groups = get_kv_cache_groups(vllm_config, - merged_kv_cache_specs) + "across workers. This is not supported yet." 
+ ) + global_kv_cache_groups = get_kv_cache_groups(vllm_config, merged_kv_cache_specs) kv_cache_configs: list[KVCacheConfig] = [] for kv_cache_spec_one_worker, available_memory_one_worker in zip( - kv_cache_specs, available_memory): + kv_cache_specs, available_memory + ): kv_cache_groups_one_worker: list[KVCacheGroupSpec] = [] for group in global_kv_cache_groups: group_layer_names_one_worker = [ - layer_name for layer_name in group.layer_names + layer_name + for layer_name in group.layer_names if layer_name in kv_cache_spec_one_worker ] kv_cache_groups_one_worker.append( - KVCacheGroupSpec(group_layer_names_one_worker, - group.kv_cache_spec)) + KVCacheGroupSpec(group_layer_names_one_worker, group.kv_cache_spec) + ) assert sum( - len(group.layer_names) for group in - kv_cache_groups_one_worker) == len(kv_cache_spec_one_worker), ( - "Some layers are not assigned to any group.") + len(group.layer_names) for group in kv_cache_groups_one_worker + ) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group." kv_cache_configs.append( - get_kv_cache_config_from_groups(vllm_config, - kv_cache_groups_one_worker, - kv_cache_spec_one_worker, - available_memory_one_worker)) + get_kv_cache_config_from_groups( + vllm_config, + kv_cache_groups_one_worker, + kv_cache_spec_one_worker, + available_memory_one_worker, + ) + ) # Change the num_blocks of each rank to the smallest among all ranks. We # do not need to shrink the tensor size because it is valid to only use the # first `num_blocks` blocks of the tensor. - min_num_blocks = min(kv_cache_config.num_blocks - for kv_cache_config in kv_cache_configs) + min_num_blocks = min( + kv_cache_config.num_blocks for kv_cache_config in kv_cache_configs + ) for kv_cache_config in kv_cache_configs: kv_cache_config.num_blocks = min_num_blocks diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index 74ff626173..968b4db530 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -12,7 +12,6 @@ logger = init_logger(__name__) class AsyncScheduler(Scheduler): - def _update_after_schedule( self, scheduler_output: SchedulerOutput, @@ -20,8 +19,10 @@ class AsyncScheduler(Scheduler): super()._update_after_schedule(scheduler_output) for req_id in scheduler_output.num_scheduled_tokens: request = self.requests[req_id] - if (request.num_computed_tokens == request.num_tokens + - request.num_output_placeholders): + if ( + request.num_computed_tokens + == request.num_tokens + request.num_output_placeholders + ): # The request will generate a new token in this scheduling step. # TODO(woosuk): Support speculative decoding. request.num_output_placeholders += 1 @@ -33,7 +34,8 @@ class AsyncScheduler(Scheduler): ) -> tuple[list[int], bool]: status_before_update = request.status new_token_ids, stopped = super()._update_request_with_output( - request, new_token_ids) + request, new_token_ids + ) # Update the number of output placeholders. request.num_output_placeholders -= len(new_token_ids) @@ -42,6 +44,6 @@ class AsyncScheduler(Scheduler): # Cache the new tokens. Preempted requests should be skipped. 
if status_before_update == RequestStatus.RUNNING: self.kv_cache_manager.cache_blocks( - request, - request.num_computed_tokens - request.num_output_placeholders) + request, request.num_computed_tokens - request.num_output_placeholders + ) return new_token_ids, stopped diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 5b1de3a66c..b92ef395e9 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -14,7 +14,6 @@ if TYPE_CHECKING: class SchedulerInterface(ABC): - @abstractmethod def schedule(self) -> "SchedulerOutput": """Schedule the requests to process in this scheduling step. @@ -72,7 +71,7 @@ class SchedulerInterface(ABC): @abstractmethod def add_request(self, request: "Request") -> None: """Add a new request to the scheduler's internal queue. - + Args: request: The new request being added. """ @@ -91,7 +90,7 @@ class SchedulerInterface(ABC): 1. When the request is aborted by the client. 2. When the frontend process detects a stop string of the request after de-tokenizing its generated tokens. - + Args: request_ids: A single or a list of request IDs. finished_status: The finished status of the given requests. diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 6874e713af..5d31811662 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -13,8 +13,7 @@ if TYPE_CHECKING: import numpy.typing as npt import torch - from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorMetadata) + from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams @@ -25,7 +24,6 @@ if TYPE_CHECKING: @bc_linter_include @dataclass class NewRequestData: - req_id: str prompt_token_ids: Optional[list[int]] mm_features: list[MultiModalFeatureSpec] @@ -55,42 +53,43 @@ class NewRequestData: ) def __repr__(self) -> str: - prompt_embeds_shape = (self.prompt_embeds.shape - if self.prompt_embeds else None) - return (f"NewRequestData(" - f"req_id={self.req_id}," - f"prompt_token_ids={self.prompt_token_ids}," - f"mm_features={self.mm_features}," - f"sampling_params={self.sampling_params}," - f"block_ids={self.block_ids}," - f"num_computed_tokens={self.num_computed_tokens}," - f"lora_request={self.lora_request}," - f"prompt_embeds_shape={prompt_embeds_shape}" - ")") + prompt_embeds_shape = self.prompt_embeds.shape if self.prompt_embeds else None + return ( + f"NewRequestData(" + f"req_id={self.req_id}," + f"prompt_token_ids={self.prompt_token_ids}," + f"mm_features={self.mm_features}," + f"sampling_params={self.sampling_params}," + f"block_ids={self.block_ids}," + f"num_computed_tokens={self.num_computed_tokens}," + f"lora_request={self.lora_request}," + f"prompt_embeds_shape={prompt_embeds_shape}" + ")" + ) # Version of __repr__ with the prompt data obfuscated def anon_repr(self) -> str: - prompt_token_ids_len = len( - self.prompt_token_ids - ) if self.prompt_token_ids is not None else None - prompt_embeds_shape = (self.prompt_embeds.shape - if self.prompt_embeds else None) - return (f"NewRequestData(" - f"req_id={self.req_id}," - f"prompt_token_ids_len={prompt_token_ids_len}," - f"mm_features={self.mm_features}," - f"sampling_params={self.sampling_params}," - f"block_ids={self.block_ids}," - f"num_computed_tokens={self.num_computed_tokens}," - f"lora_request={self.lora_request}," - f"prompt_embeds_shape={prompt_embeds_shape}" - 
")") + prompt_token_ids_len = ( + len(self.prompt_token_ids) if self.prompt_token_ids is not None else None + ) + prompt_embeds_shape = self.prompt_embeds.shape if self.prompt_embeds else None + return ( + f"NewRequestData(" + f"req_id={self.req_id}," + f"prompt_token_ids_len={prompt_token_ids_len}," + f"mm_features={self.mm_features}," + f"sampling_params={self.sampling_params}," + f"block_ids={self.block_ids}," + f"num_computed_tokens={self.num_computed_tokens}," + f"lora_request={self.lora_request}," + f"prompt_embeds_shape={prompt_embeds_shape}" + ")" + ) @bc_linter_include @dataclass class CachedRequestData: - req_ids: list[str] # If resumed_from_preemption is False, new_block_ids will be appended to # the request's block IDs. If True, new_block_ids will be used as the @@ -122,7 +121,6 @@ class CachedRequestData: @bc_linter_include @dataclass class SchedulerOutput: - # list of the requests that are scheduled for the first time. # We cache the request's data in each worker process, so that we don't # need to re-send it every scheduling step. diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index fc2bc30b9a..33e5ec72eb 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -14,6 +14,7 @@ from vllm.v1.request import Request class SchedulingPolicy(Enum): """Enum for scheduling policies.""" + FCFS = "fcfs" PRIORITY = "priority" @@ -111,9 +112,7 @@ class FCFSRequestQueue(deque[Request], RequestQueue): def remove_requests(self, requests: Iterable[Request]) -> None: """Remove multiple specific requests from the queue.""" requests_to_remove = set(requests) - filtered_requests = [ - req for req in self if req not in requests_to_remove - ] + filtered_requests = [req for req in self if req not in requests_to_remove] # deque does not support in-place filtering, so we need to clear # and extend self.clear() @@ -150,8 +149,7 @@ class PriorityRequestQueue(RequestQueue): def add_request(self, request: Request) -> None: """Add a request to the queue according to priority policy.""" - heapq.heappush(self._heap, - (request.priority, request.arrival_time, request)) + heapq.heappush(self._heap, (request.priority, request.arrival_time, request)) def pop_request(self) -> Request: """Pop a request from the queue according to priority policy.""" @@ -169,15 +167,15 @@ class PriorityRequestQueue(RequestQueue): def prepend_request(self, request: Request) -> None: """Add a request to the queue according to priority policy. - - Note: In a priority queue, there is no concept of prepending to the + + Note: In a priority queue, there is no concept of prepending to the front. Requests are ordered by (priority, arrival_time).""" self.add_request(request) def prepend_requests(self, requests: RequestQueue) -> None: """Add all requests from another queue according to priority policy. - - Note: In a priority queue, there is no concept of prepending to the + + Note: In a priority queue, there is no concept of prepending to the front. 
Requests are ordered by (priority, arrival_time).""" for request in requests: self.add_request(request) @@ -190,8 +188,9 @@ class PriorityRequestQueue(RequestQueue): def remove_requests(self, requests: Iterable[Request]) -> None: """Remove multiple specific requests from the queue.""" requests_to_remove = set(requests) - self._heap = [(p, t, r) for p, t, r in self._heap - if r not in requests_to_remove] + self._heap = [ + (p, t, r) for p, t, r in self._heap if r not in requests_to_remove + ] heapq.heapify(self._heap) def __bool__(self) -> bool: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 6983ccca51..24ff87cd0a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -11,25 +11,24 @@ from typing import Any, Optional, Union from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, - KVConnectorRole) -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorStats) +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorBase_V1, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, - compute_encoder_budget) +from vllm.v1.core.encoder_cache_manager import ( + EncoderCacheManager, + compute_encoder_budget, +) from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface -from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, - SchedulerOutput) -from vllm.v1.core.sched.request_queue import (SchedulingPolicy, - create_request_queue) +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput +from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.utils import check_stop, remove_all -from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, - EngineCoreOutputs) +from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput @@ -41,7 +40,6 @@ logger = init_logger(__name__) class Scheduler(SchedulerInterface): - def __init__( self, vllm_config: VllmConfig, @@ -67,16 +65,17 @@ class Scheduler(SchedulerInterface): # by update_from_outputs(). This is currently used in the multi-engine # case to track request lifetimes efficiently. self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( - defaultdict(set) if include_finished_set else None) + defaultdict(set) if include_finished_set else None + ) # Scheduling constraints. 
self.max_num_running_reqs = self.scheduler_config.max_num_seqs - self.max_num_scheduled_tokens = \ - self.scheduler_config.max_num_batched_tokens + self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens self.max_model_len = self.scheduler_config.max_model_len self.enable_kv_cache_events = ( self.kv_events_config is not None - and self.kv_events_config.enable_kv_cache_events) + and self.kv_events_config.enable_kv_cache_events + ) # Create KVConnector for the Scheduler. Note that each Worker # will have a corresponding KVConnector with Role=WORKER. @@ -85,12 +84,14 @@ class Scheduler(SchedulerInterface): if self.vllm_config.kv_transfer_config is not None: assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "Multiple KV cache groups are not currently supported " - "with KV connectors") + "with KV connectors" + ) assert not self.is_encoder_decoder, ( - "Encoder-decoder models are not currently supported " - "with KV connectors") + "Encoder-decoder models are not currently supported with KV connectors" + ) self.connector = KVConnectorFactory.create_connector( - config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + config=self.vllm_config, role=KVConnectorRole.SCHEDULER + ) self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config, @@ -102,8 +103,7 @@ class Scheduler(SchedulerInterface): self.block_size = self.cache_config.block_size - self.dcp_world_size = \ - vllm_config.parallel_config.decode_context_parallel_size + self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size # Note(hc): The scheduler’s block_size must be multiplied # by dcp_world_size, since block hashes are computed on the # original full token sequence at a granularity of @@ -120,7 +120,8 @@ class Scheduler(SchedulerInterface): self.policy = SchedulingPolicy.FCFS else: raise ValueError( - f"Unknown scheduling policy: {self.scheduler_config.policy}") + f"Unknown scheduling policy: {self.scheduler_config.policy}" + ) # Priority queues for requests. self.waiting = create_request_queue(self.policy) self.running: list[Request] = [] @@ -153,8 +154,7 @@ class Scheduler(SchedulerInterface): # NOTE: For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized because cache size is 0 # for these models. - self.encoder_cache_manager = EncoderCacheManager( - cache_size=encoder_cache_size) + self.encoder_cache_manager = EncoderCacheManager(cache_size=encoder_cache_size) speculative_config = vllm_config.speculative_config self.use_eagle = False @@ -211,30 +211,35 @@ class Scheduler(SchedulerInterface): while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - num_new_tokens = (request.num_tokens_with_spec + - request.num_output_placeholders - - request.num_computed_tokens) - if (0 < self.scheduler_config.long_prefill_token_threshold < - num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = ( + request.num_tokens_with_spec + + request.num_output_placeholders + - request.num_computed_tokens + ) + if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: + num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) # Make sure the input position does not exceed the max model len. # This is necessary when using spec decoding. 
num_new_tokens = min( - num_new_tokens, - self.max_model_len - request.num_computed_tokens) + num_new_tokens, self.max_model_len - request.num_computed_tokens + ) # Schedule encoder inputs. encoder_inputs_to_schedule = None new_encoder_compute_budget = encoder_compute_budget if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget - ) = self._try_schedule_encoder_inputs( - request, request.num_computed_tokens, num_new_tokens, - encoder_compute_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + ) = self._try_schedule_encoder_inputs( + request, + request.num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) if num_new_tokens == 0: # The request cannot be scheduled because one of the following @@ -257,7 +262,8 @@ class Scheduler(SchedulerInterface): new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens, - num_lookahead_tokens=self.num_lookahead_tokens) + num_lookahead_tokens=self.num_lookahead_tokens, + ) if new_blocks is not None: # The request can be scheduled. @@ -282,8 +288,9 @@ class Scheduler(SchedulerInterface): preempted_req.num_computed_tokens = 0 preempted_req.num_preemptions += 1 if self.log_stats: - preempted_req.record_event(EngineCoreEventType.PREEMPTED, - scheduled_timestamp) + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp + ) self.waiting.prepend_request(preempted_req) preempted_reqs.append(preempted_req) @@ -304,19 +311,21 @@ class Scheduler(SchedulerInterface): # Speculative decode related. if request.spec_token_ids: - num_scheduled_spec_tokens = (num_new_tokens + - request.num_computed_tokens - - request.num_tokens) + num_scheduled_spec_tokens = ( + num_new_tokens + request.num_computed_tokens - request.num_tokens + ) if num_scheduled_spec_tokens > 0: # Trim spec_token_ids list to num_scheduled_spec_tokens. del request.spec_token_ids[num_scheduled_spec_tokens:] scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids) + request.spec_token_ids + ) # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + encoder_inputs_to_schedule + ) # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) @@ -326,8 +335,10 @@ class Scheduler(SchedulerInterface): scheduled_loras: set[int] = set() if self.lora_config: scheduled_loras = set( - req.lora_request.lora_int_id for req in scheduled_running_reqs - if req.lora_request and req.lora_request.lora_int_id > 0) + req.lora_request.lora_int_id + for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0 + ) assert len(scheduled_loras) <= self.lora_config.max_loras # Use a temporary RequestQueue to collect requests that need to be @@ -350,7 +361,8 @@ class Scheduler(SchedulerInterface): else: logger.debug( "%s is still in WAITING_FOR_REMOTE_KVS state.", - request.request_id) + request.request_id, + ) self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -368,9 +380,14 @@ class Scheduler(SchedulerInterface): # Check that adding the request still respects the max_loras # constraint. 
- if (self.lora_config and request.lora_request and - (len(scheduled_loras) == self.lora_config.max_loras and - request.lora_request.lora_int_id not in scheduled_loras)): + if ( + self.lora_config + and request.lora_request + and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras + ) + ): # Scheduling would exceed max_loras, skip. self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) @@ -382,15 +399,17 @@ class Scheduler(SchedulerInterface): # Get already-cached tokens. if request.num_computed_tokens == 0: # Get locally-cached tokens. - new_computed_blocks, num_new_local_computed_tokens = \ - self.kv_cache_manager.get_computed_blocks( - request) + new_computed_blocks, num_new_local_computed_tokens = ( + self.kv_cache_manager.get_computed_blocks(request) + ) # Get externally-cached tokens if using a KVConnector. if self.connector is not None: num_external_computed_tokens, load_kv_async = ( self.connector.get_num_new_matched_tokens( - request, num_new_local_computed_tokens)) + request, num_new_local_computed_tokens + ) + ) if num_external_computed_tokens is None: # The request cannot be scheduled because @@ -401,13 +420,15 @@ class Scheduler(SchedulerInterface): continue # Total computed tokens (local + external). - num_computed_tokens = (num_new_local_computed_tokens + - num_external_computed_tokens) + num_computed_tokens = ( + num_new_local_computed_tokens + num_external_computed_tokens + ) # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. else: new_computed_blocks = ( - self.kv_cache_manager.create_empty_block_list()) + self.kv_cache_manager.create_empty_block_list() + ) num_new_local_computed_tokens = 0 num_computed_tokens = request.num_computed_tokens @@ -424,15 +445,21 @@ class Scheduler(SchedulerInterface): # `request.num_prompt_tokens` to consider the resumed # requests, which have output tokens. num_new_tokens = request.num_tokens - num_computed_tokens - if (0 < self.scheduler_config.long_prefill_token_threshold - < num_new_tokens): + if ( + 0 + < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens + ): num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + self.scheduler_config.long_prefill_token_threshold + ) # chunked prefill has to be enabled explicitly to allow # pooling requests to be chunked - if not self.scheduler_config.chunked_prefill_enabled and \ - num_new_tokens > token_budget: + if ( + not self.scheduler_config.chunked_prefill_enabled + and num_new_tokens > token_budget + ): self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -442,11 +469,16 @@ class Scheduler(SchedulerInterface): # Schedule encoder inputs. if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget - ) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_compute_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + ) = self._try_schedule_encoder_inputs( + request, + num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) if num_new_tokens == 0: # The request cannot be scheduled. break @@ -456,9 +488,9 @@ class Scheduler(SchedulerInterface): # extra block gets allocated which # creates a mismatch between the number # of local and remote blocks. 
- effective_lookahead_tokens = (0 if request.num_computed_tokens - == 0 else - self.num_lookahead_tokens) + effective_lookahead_tokens = ( + 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens + ) # Determine if we need to allocate cross-attention blocks. if self.is_encoder_decoder and request.has_encoder_inputs: @@ -466,8 +498,9 @@ class Scheduler(SchedulerInterface): # always padded to the maximum length. If we support other # encoder-decoder models, this will need to be updated if we # want to only allocate what is needed. - num_encoder_tokens =\ + num_encoder_tokens = ( self.scheduler_config.max_num_encoder_input_tokens + ) else: num_encoder_tokens = 0 @@ -509,20 +542,21 @@ class Scheduler(SchedulerInterface): req_index += 1 self.running.append(request) if self.log_stats: - request.record_event(EngineCoreEventType.SCHEDULED, - scheduled_timestamp) + request.record_event( + EngineCoreEventType.SCHEDULED, scheduled_timestamp + ) if request.status == RequestStatus.WAITING: scheduled_new_reqs.append(request) elif request.status == RequestStatus.PREEMPTED: scheduled_resumed_reqs.append(request) else: - raise RuntimeError( - f"Invalid request status: {request.status}") + raise RuntimeError(f"Invalid request status: {request.status}") if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) req_to_new_blocks[request.request_id] = ( - self.kv_cache_manager.get_blocks(request.request_id)) + self.kv_cache_manager.get_blocks(request.request_id) + ) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -533,7 +567,8 @@ class Scheduler(SchedulerInterface): # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + encoder_inputs_to_schedule + ) # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) @@ -551,23 +586,26 @@ class Scheduler(SchedulerInterface): # Since some requests in the RUNNING queue may not be scheduled in # this step, the total number of scheduled requests can be smaller than # len(self.running). - assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + - len(scheduled_running_reqs) <= len(self.running)) + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( + scheduled_running_reqs + ) <= len(self.running) # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = [0] * len( - self.kv_cache_config.kv_cache_groups) + num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) if self.running: any_request = self.running[0] num_common_prefix_blocks = ( self.kv_cache_manager.get_num_common_prefix_blocks( - any_request, len(self.running))) + any_request, len(self.running) + ) + ) # Construct the scheduler output. 
new_reqs_data = [ NewRequestData.from_request( - req, req_to_new_blocks[req.request_id].get_block_ids()) + req, req_to_new_blocks[req.request_id].get_block_ids() + ) for req in scheduled_new_reqs ] cached_reqs_data = self._make_cached_request_data( @@ -577,11 +615,12 @@ class Scheduler(SchedulerInterface): scheduled_spec_decode_tokens, req_to_new_blocks, ) - scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs + - scheduled_resumed_reqs) - structured_output_request_ids, grammar_bitmask = ( - self.get_grammar_bitmask(scheduled_requests, - scheduled_spec_decode_tokens)) + scheduled_requests = ( + scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs + ) + structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask( + scheduled_requests, scheduled_spec_decode_tokens + ) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -595,8 +634,7 @@ class Scheduler(SchedulerInterface): # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, - free_encoder_mm_hashes=self.encoder_cache_manager. - get_freed_mm_hashes(), + free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) @@ -678,16 +716,18 @@ class Scheduler(SchedulerInterface): for req in itertools.chain(running_reqs, resumed_reqs): req_id = req.request_id req_ids.append(req_id) - num_tokens = (num_scheduled_tokens[req_id] - - len(spec_decode_tokens.get(req_id, ()))) + num_tokens = num_scheduled_tokens[req_id] - len( + spec_decode_tokens.get(req_id, ()) + ) if self.use_pp: # When using PP, the scheduler sends the sampled tokens back, # because there's no direct communication between the first- # stage worker and the last-stage worker. Otherwise, we don't # need to send the sampled tokens back because the model runner # will cache them. - token_ids = req.all_token_ids[req.num_computed_tokens:req. - num_computed_tokens + num_tokens] + token_ids = req.all_token_ids[ + req.num_computed_tokens : req.num_computed_tokens + num_tokens + ] new_token_ids.append(token_ids) elif use_connector: # When using a KVConnector, we add a placeholder to avoid index @@ -695,7 +735,8 @@ class Scheduler(SchedulerInterface): # is updated to handle token IDs properly. new_token_ids.append([]) new_block_ids.append( - req_to_new_blocks[req_id].get_block_ids(allow_none=True)) + req_to_new_blocks[req_id].get_block_ids(allow_none=True) + ) num_computed_tokens.append(req.num_computed_tokens) num_output_tokens.append(len(req.output_token_ids)) # Because resumed_reqs is usually empty, it is more efficient to do @@ -764,7 +805,8 @@ class Scheduler(SchedulerInterface): if self.is_encoder_decoder and num_computed_tokens > 0: assert start_pos == 0, ( "Encoder input should be processed at the beginning of " - "the sequence when encoder-decoder models are used.") + "the sequence when encoder-decoder models are used." + ) # Encoder input has already been computed # The calculation here is a bit different. We don't turn encoder # output into tokens that get processed by the decoder and @@ -788,8 +830,7 @@ class Scheduler(SchedulerInterface): # current step. continue - if self.encoder_cache_manager.check_and_update_cache( - request, i): + if self.encoder_cache_manager.check_and_update_cache(request, i): # The encoder input is already computed and cached from a # previous step. 
continue @@ -797,16 +838,18 @@ class Scheduler(SchedulerInterface): # If no encoder input chunking is allowed, we do not want to # partially schedule a multimodal item. If the scheduled range would # only cover part of the mm input, roll back to before the mm item. - if (self.scheduler_config.disable_chunked_mm_input - and num_computed_tokens < start_pos - and (num_computed_tokens + num_new_tokens) - < (start_pos + num_encoder_tokens)): + if ( + self.scheduler_config.disable_chunked_mm_input + and num_computed_tokens < start_pos + and (num_computed_tokens + num_new_tokens) + < (start_pos + num_encoder_tokens) + ): num_new_tokens = start_pos - num_computed_tokens break if not self.encoder_cache_manager.can_allocate( - request, i, encoder_compute_budget, - num_tokens_to_schedule): + request, i, encoder_compute_budget, num_tokens_to_schedule + ): # The encoder cache is full or the encoder budget is exhausted. # NOTE(woosuk): We assume that the encoder input tokens should # be processed altogether, as the encoder usually uses @@ -879,8 +922,9 @@ class Scheduler(SchedulerInterface): outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None - kv_connector_stats = (kv_connector_output.kv_connector_stats - if kv_connector_output else None) + kv_connector_stats = ( + kv_connector_output.kv_connector_stats if kv_connector_output else None + ) failed_kv_load_req_ids = None if kv_connector_output and kv_connector_output.invalid_block_ids: @@ -888,7 +932,8 @@ class Scheduler(SchedulerInterface): # load. Identify affected requests and adjust their computed token # count to trigger recomputation of the invalid blocks. failed_kv_load_req_ids = self._handle_invalid_blocks( - kv_connector_output.invalid_block_ids) + kv_connector_output.invalid_block_ids + ) # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, # the below loop can be a performance bottleneck. We should do our best @@ -908,11 +953,13 @@ class Scheduler(SchedulerInterface): continue req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids = sampled_token_ids[ - req_index] if sampled_token_ids else [] + generated_token_ids = ( + sampled_token_ids[req_index] if sampled_token_ids else [] + ) scheduled_spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + scheduler_output.scheduled_spec_decode_tokens.get(req_id) + ) if scheduled_spec_token_ids: num_draft_tokens = len(scheduled_spec_token_ids) num_accepted = len(generated_token_ids) - 1 @@ -926,7 +973,8 @@ class Scheduler(SchedulerInterface): spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats, num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted) + num_accepted_tokens=num_accepted, + ) stopped = False new_logprobs = None @@ -937,14 +985,14 @@ class Scheduler(SchedulerInterface): # Check for stop and update request status. if new_token_ids: new_token_ids, stopped = self._update_request_with_output( - request, new_token_ids) + request, new_token_ids + ) # Stop checking for pooler models. pooler_output = None if pooler_outputs: pooler_output = pooler_outputs[req_index] - stopped = check_stop(request, self.max_model_len, - pooler_output) + stopped = check_stop(request, self.max_model_len, pooler_output) if stopped: kv_transfer_params = self._free_request(request) @@ -954,28 +1002,29 @@ class Scheduler(SchedulerInterface): stopped_preempted_reqs.add(request) # Extract sample logprobs if needed. 
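The `disable_chunked_mm_input` branch just above rolls the token budget back so a multimodal item is never split across scheduling steps. A worked example with assumed positions (none of these numbers come from a real config):

# Assumed values, for illustration only.
start_pos = 10           # first decoder position covered by the multimodal item
num_encoder_tokens = 8   # the item spans positions [10, 18)
num_computed_tokens = 6
num_new_tokens = 9       # the scheduled range would be [6, 15), cutting the item in half

if (
    num_computed_tokens < start_pos
    and num_computed_tokens + num_new_tokens < start_pos + num_encoder_tokens
):
    # Roll back to just before the item; it gets scheduled whole in a later step.
    num_new_tokens = start_pos - num_computed_tokens

assert num_new_tokens == 4  # only positions [6, 10) are scheduled this step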
- if request.sampling_params is not None \ - and request.sampling_params.logprobs is not None and logprobs: + if ( + request.sampling_params is not None + and request.sampling_params.logprobs is not None + and logprobs + ): # NOTE: once we support N tokens per step (spec decode), # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - if new_token_ids and self.structured_output_manager.should_advance( - request): + if new_token_ids and self.structured_output_manager.should_advance(request): # NOTE: structured_output_request # should not be None if use_structured_output, we have # checked above, so safe to ignore type warning request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids) + req_id, new_token_ids + ) if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids or pooler_output is not None \ - or kv_transfer_params: - + if new_token_ids or pooler_output is not None or kv_transfer_params: # Add EngineCoreOutput for this Request. outputs[request.client_index].append( EngineCoreOutput( @@ -990,7 +1039,8 @@ class Scheduler(SchedulerInterface): kv_transfer_params=kv_transfer_params, trace_headers=request.trace_headers, num_cached_tokens=request.num_cached_tokens, - )) + ) + ) else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors @@ -1023,11 +1073,13 @@ class Scheduler(SchedulerInterface): eco.finished_requests = finished_set else: engine_core_outputs[client_index] = EngineCoreOutputs( - finished_requests=finished_set) + finished_requests=finished_set + ) finished_req_ids.clear() - if (stats := self.make_stats(spec_decoding_stats, - kv_connector_stats)) is not None: + if ( + stats := self.make_stats(spec_decoding_stats, kv_connector_stats) + ) is not None: # Return stats to only one of the front-ends. if (eco := next(iter(engine_core_outputs.values()), None)) is None: # We must return the stats even if there are no request @@ -1058,8 +1110,9 @@ class Scheduler(SchedulerInterface): return new_token_ids, stopped def _free_encoder_inputs(self, request: Request) -> None: - cached_encoder_input_ids = ( - self.encoder_cache_manager.get_cached_input_ids(request)) + cached_encoder_input_ids = self.encoder_cache_manager.get_cached_input_ids( + request + ) # OPTIMIZATION: Avoid list(set) if the set is empty. if not cached_encoder_input_ids: return @@ -1074,21 +1127,19 @@ class Scheduler(SchedulerInterface): # With Whisper, as soon as we've generated a single token, # we know we're done with the encoder input. Cross Attention # KVs have been calculated and cached already. - self.encoder_cache_manager.free_encoder_input( - request, input_id) + self.encoder_cache_manager.free_encoder_input(request, input_id) elif start_pos + num_tokens <= request.num_computed_tokens: # The encoder output is already processed and stored # in the decoder's KV cache. 
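In the acceptance counting above, the sampled output for a request with scheduled draft tokens carries one token beyond the accepted drafts, so `num_accepted` is simply the returned length minus one. A numeric illustration with made-up token IDs:

scheduled_spec_token_ids = [101, 102, 103]  # three draft tokens were scheduled
generated_token_ids = [101, 102, 777]       # two drafts accepted, then one sampled token

num_draft_tokens = len(scheduled_spec_token_ids)
num_accepted = len(generated_token_ids) - 1
num_rejected = num_draft_tokens - num_accepted  # derived here for clarity; not in the diff
assert (num_draft_tokens, num_accepted, num_rejected) == (3, 2, 1)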
- self.encoder_cache_manager.free_encoder_input( - request, input_id) + self.encoder_cache_manager.free_encoder_input(request, input_id) def update_draft_token_ids( self, draft_token_ids: DraftTokenIds, ) -> None: for req_id, spec_token_ids in zip( - draft_token_ids.req_ids, - draft_token_ids.draft_token_ids, + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, ): request = self.requests.get(req_id) if request is None or request.is_finished(): @@ -1102,7 +1153,8 @@ class Scheduler(SchedulerInterface): elif self.structured_output_manager.should_advance(request): metadata = request.structured_output_request request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids) + spec_token_ids + ) else: request.spec_token_ids = spec_token_ids @@ -1128,7 +1180,7 @@ class Scheduler(SchedulerInterface): """ assert RequestStatus.is_finished(finished_status) if isinstance(request_ids, str): - request_ids = (request_ids, ) + request_ids = (request_ids,) else: request_ids = set(request_ids) @@ -1198,15 +1250,15 @@ class Scheduler(SchedulerInterface): return None prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() assert prefix_cache_stats is not None - return SchedulerStats(num_running_reqs=len(self.running), - num_waiting_reqs=len(self.waiting), - kv_cache_usage=self.kv_cache_manager.usage, - prefix_cache_stats=prefix_cache_stats, - spec_decoding_stats=spec_decoding_stats, - num_corrupted_reqs=sum(req.is_output_corrupted - for req in self.running), - kv_connector_stats=kv_connector_stats.data - if kv_connector_stats else None) + return SchedulerStats( + num_running_reqs=len(self.running), + num_waiting_reqs=len(self.waiting), + kv_cache_usage=self.kv_cache_manager.usage, + prefix_cache_stats=prefix_cache_stats, + spec_decoding_stats=spec_decoding_stats, + num_corrupted_reqs=sum(req.is_output_corrupted for req in self.running), + kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None, + ) def make_spec_decoding_stats( self, @@ -1219,8 +1271,8 @@ class Scheduler(SchedulerInterface): if spec_decoding_stats is None: spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) spec_decoding_stats.observe_draft( - num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens) + num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens + ) return spec_decoding_stats def shutdown(self) -> None: @@ -1237,7 +1289,8 @@ class Scheduler(SchedulerInterface): return self.connector def _connector_finished( - self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]: + self, request: Request + ) -> tuple[bool, Optional[dict[str, Any]]]: """ Invoke the KV connector request_finished() method if applicable. @@ -1247,7 +1300,7 @@ class Scheduler(SchedulerInterface): if self.connector is None: return False, None - (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) + (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id) return self.connector.request_finished(request, block_ids) def _update_waiting_for_remote_kv(self, request: Request) -> bool: @@ -1271,8 +1324,7 @@ class Scheduler(SchedulerInterface): # updated in _update_requests_with_invalid_blocks if request.num_computed_tokens: # Cache any valid computed tokens. - self.kv_cache_manager.cache_blocks(request, - request.num_computed_tokens) + self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens) else: # No valid computed tokens, release allocated blocks. 
# There may be a local cache hit on retry. @@ -1281,8 +1333,7 @@ class Scheduler(SchedulerInterface): self.failed_recving_kv_req_ids.remove(request.request_id) else: # Now that the blocks are ready, actually cache them. - (block_ids, ) = self.kv_cache_manager.get_block_ids( - request.request_id) + (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id) num_computed_tokens = len(block_ids) * self.block_size # Handle the case where num request tokens less than one block. num_computed_tokens = min(num_computed_tokens, request.num_tokens) @@ -1298,8 +1349,7 @@ class Scheduler(SchedulerInterface): self.finished_recving_kv_req_ids.remove(request.request_id) return True - def _update_from_kv_xfer_finished(self, - kv_connector_output: KVConnectorOutput): + def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): """ KV Connector: update the scheduler state based on the output. @@ -1314,21 +1364,23 @@ class Scheduler(SchedulerInterface): self.connector.update_connector_output(kv_connector_output) # KV Connector:: update recv and send status from last step. - for req_id in (kv_connector_output.finished_recving or ()): + for req_id in kv_connector_output.finished_recving or (): logger.debug("Finished recving KV transfer for request %s", req_id) self.finished_recving_kv_req_ids.add(req_id) - for req_id in (kv_connector_output.finished_sending or ()): + for req_id in kv_connector_output.finished_sending or (): logger.debug("Finished sending KV transfer for request %s", req_id) if req_id not in self.requests: logger.warning( "Got finished sending KV transfer for request %s," - "but the request is already freed.", req_id) + "but the request is already freed.", + req_id, + ) else: self._free_blocks(self.requests[req_id]) def _update_requests_with_invalid_blocks( - self, requests: Iterable[Request], - invalid_block_ids: set[int]) -> tuple[set[str], int]: + self, requests: Iterable[Request], invalid_block_ids: set[int] + ) -> tuple[set[str], int]: """ Identify and update requests affected by invalid KV cache blocks. @@ -1359,25 +1411,25 @@ class Scheduler(SchedulerInterface): marked_invalid_block = False req_id = request.request_id # TODO (davidb): add support for hybrid memory allocator - (req_block_ids, ) = self.kv_cache_manager.get_block_ids(req_id) + (req_block_ids,) = self.kv_cache_manager.get_block_ids(req_id) # We iterate only over blocks that may contain externally computed # tokens if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: # Async loading. If num_computed_tokens is set it implies we # already processed some block failures for it in a prior step req_num_computed_tokens = ( - request.num_computed_tokens if req_id - in self.failed_recving_kv_req_ids else len(req_block_ids) * - self.block_size) + request.num_computed_tokens + if req_id in self.failed_recving_kv_req_ids + else len(req_block_ids) * self.block_size + ) else: # Sync loading. 
num_computed_tokens includes new tokens req_num_computed_tokens = request.num_cached_tokens - req_num_computed_blocks = (req_num_computed_tokens + - self.block_size - 1) // self.block_size - for idx, block_id in zip(range(req_num_computed_blocks), - req_block_ids): - + req_num_computed_blocks = ( + req_num_computed_tokens + self.block_size - 1 + ) // self.block_size + for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids): if block_id not in invalid_block_ids: continue @@ -1402,8 +1454,9 @@ class Scheduler(SchedulerInterface): marked_invalid_block = True # Truncate the computed tokens at the first failed block request.num_computed_tokens = idx * self.block_size - total_affected_tokens += (req_num_computed_tokens - - request.num_computed_tokens) + total_affected_tokens += ( + req_num_computed_tokens - request.num_computed_tokens + ) if is_affected: if not marked_invalid_block: @@ -1412,8 +1465,9 @@ class Scheduler(SchedulerInterface): # Revert to considering only cached tokens as computed. # Currently this only applies to sync loading; Async # loading does not yet support block sharing - total_affected_tokens += (request.num_computed_tokens - - request.num_cached_tokens) + total_affected_tokens += ( + request.num_computed_tokens - request.num_cached_tokens + ) request.num_computed_tokens = request.num_cached_tokens affected_req_ids.add(request.request_id) @@ -1426,11 +1480,15 @@ class Scheduler(SchedulerInterface): # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) --- async_load_reqs = ( - req for req in self.waiting - if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + req + for req in self.waiting + if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS + ) async_affected_req_ids, num_tokens_to_reschedule = ( - self._update_requests_with_invalid_blocks(async_load_reqs, - invalid_block_ids)) + self._update_requests_with_invalid_blocks( + async_load_reqs, invalid_block_ids + ) + ) total_requests_to_reschedule += len(async_affected_req_ids) total_tokens_to_reschedule += num_tokens_to_reschedule @@ -1441,8 +1499,8 @@ class Scheduler(SchedulerInterface): # --- Handle sync KV loads (running requests) --- sync_affected_req_ids, num_tokens_to_reschedule = ( - self._update_requests_with_invalid_blocks(self.running, - invalid_block_ids)) + self._update_requests_with_invalid_blocks(self.running, invalid_block_ids) + ) total_requests_to_reschedule += len(sync_affected_req_ids) total_tokens_to_reschedule += num_tokens_to_reschedule @@ -1451,7 +1509,9 @@ class Scheduler(SchedulerInterface): logger.warning( "Recovered from KV load failure: " "%d request(s) rescheduled (%d tokens affected).", - total_requests_to_reschedule, total_tokens_to_reschedule) + total_requests_to_reschedule, + total_tokens_to_reschedule, + ) # Return the IDs of affected running requests to skip in # update_from_output. 
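The recovery path above turns a token count into a block count with a ceiling division and then truncates `num_computed_tokens` at the first invalid block. A numeric check with an assumed block size:

block_size = 16                       # assumed; the real value comes from the KV cache config
req_num_computed_tokens = 50
req_num_computed_blocks = (req_num_computed_tokens + block_size - 1) // block_size
assert req_num_computed_blocks == 4   # tokens 0-49 live in blocks 0-3

first_invalid_block_idx = 2           # pretend block 2 failed to load
num_computed_tokens = first_invalid_block_idx * block_size
assert num_computed_tokens == 32      # blocks 0 and 1 survive, the rest is recomputed
assert req_num_computed_tokens - num_computed_tokens == 18  # tokens to reschedule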
diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py index 6b321f4ebb..0979100ed3 100644 --- a/vllm/v1/core/sched/utils.py +++ b/vllm/v1/core/sched/utils.py @@ -40,11 +40,13 @@ def remove_all(lst: list, items_to_remove: set) -> list: return [item for item in lst if item not in items_to_remove] -def check_stop(request: Request, - max_model_len: int, - pooler_output: Optional[torch.Tensor] = None) -> bool: - if (request.num_tokens > max_model_len - or request.num_output_tokens >= request.max_tokens): +def check_stop( + request: Request, max_model_len: int, pooler_output: Optional[torch.Tensor] = None +) -> bool: + if ( + request.num_tokens > max_model_len + or request.num_output_tokens >= request.max_tokens + ): request.status = RequestStatus.FINISHED_LENGTH_CAPPED return True @@ -57,8 +59,7 @@ def check_stop(request: Request, sampling_params = request.sampling_params assert sampling_params is not None last_token_id = request.output_token_ids[-1] - if (not sampling_params.ignore_eos - and last_token_id == request.eos_token_id): + if not sampling_params.ignore_eos and last_token_id == request.eos_token_id: request.status = RequestStatus.FINISHED_STOPPED return True diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 07777efc32..0f71796014 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -7,16 +7,21 @@ from collections import defaultdict from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock -from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - CrossAttentionSpec, FullAttentionSpec, - KVCacheSpec, MambaSpec, - MLAAttentionSpec, SlidingWindowSpec) +from vllm.v1.kv_cache_interface import ( + ChunkedLocalAttentionSpec, + CrossAttentionSpec, + FullAttentionSpec, + KVCacheSpec, + MambaSpec, + MLAAttentionSpec, + SlidingWindowSpec, +) from vllm.v1.request import Request class SingleTypeKVCacheManager(ABC): """ - An abstract base class for a manager that handle the kv cache management + An abstract base class for a manager that handle the kv cache management logic of one specific type of attention layer. """ @@ -44,8 +49,7 @@ class SingleTypeKVCacheManager(ABC): # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: defaultdict[str, - list[KVCacheBlock]] = defaultdict(list) + self.req_to_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list) # {req_id: The number of cached blocks for this given request} # This is used to track the number of cached blocks for each request. @@ -57,14 +61,14 @@ class SingleTypeKVCacheManager(ABC): self._null_block = block_pool.null_block def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, - new_computed_blocks: list[KVCacheBlock]) -> int: + self, request_id: str, num_tokens: int, new_computed_blocks: list[KVCacheBlock] + ) -> int: """ Get the number of blocks needed to be allocated for the request. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). new_computed_blocks: The new computed blocks just hitting the prefix caching. 
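For reference, the length and EOS stop conditions that `check_stop` (reformatted in the vllm/v1/core/sched/utils.py hunk above) evaluates can be reproduced with a simplified standalone function; `FakeRequest` below is a stand-in used only for this sketch, not vLLM's `Request` class:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class FakeRequest:
    num_tokens: int
    num_output_tokens: int
    max_tokens: int
    output_token_ids: list = field(default_factory=list)
    eos_token_id: Optional[int] = None
    ignore_eos: bool = False

def is_stopped(req: FakeRequest, max_model_len: int) -> bool:
    # Length cap: the model context or the per-request budget is exhausted.
    if req.num_tokens > max_model_len or req.num_output_tokens >= req.max_tokens:
        return True
    # EOS stop: the last generated token is the EOS id and EOS is not ignored.
    return (
        not req.ignore_eos
        and bool(req.output_token_ids)
        and req.output_token_ids[-1] == req.eos_token_id
    )

assert is_stopped(FakeRequest(10, 4, 4), max_model_len=2048)                # max_tokens hit
assert not is_stopped(FakeRequest(10, 2, 4, [5, 7], eos_token_id=2), 2048)  # still running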
@@ -74,20 +78,23 @@ class SingleTypeKVCacheManager(ABC): """ num_required_blocks = cdiv(num_tokens, self.block_size) - num_new_blocks = (num_required_blocks - len(new_computed_blocks) - - len(self.req_to_blocks[request_id])) + num_new_blocks = ( + num_required_blocks + - len(new_computed_blocks) + - len(self.req_to_blocks[request_id]) + ) # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it will be changed from a free block # to a computed block when the request is allocated, so we also count # it as needed to be allocated. num_evictable_computed_blocks = sum( - blk.ref_cnt == 0 and not blk.is_null - for blk in new_computed_blocks) + blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks + ) return num_new_blocks + num_evictable_computed_blocks def save_new_computed_blocks( - self, request_id: str, - new_computed_blocks: list[KVCacheBlock]) -> None: + self, request_id: str, new_computed_blocks: list[KVCacheBlock] + ) -> None: """ Add the new computed blocks to the request. @@ -106,15 +113,16 @@ class SingleTypeKVCacheManager(ABC): # A running request. Should not have new computed blocks. assert len(new_computed_blocks) == 0 - def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[KVCacheBlock]: + def allocate_new_blocks( + self, request_id: str, num_tokens: int + ) -> list[KVCacheBlock]: """ - Allocate new blocks for the request to give it at least `num_tokens` + Allocate new blocks for the request to give it at least `num_tokens` token slots. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). Returns: @@ -136,7 +144,7 @@ class SingleTypeKVCacheManager(ABC): Args: request: The request. - num_tokens: The total number of tokens that need to be cached + num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). """ num_cached_blocks = self.num_cached_block[request.request_id] @@ -174,8 +182,9 @@ class SingleTypeKVCacheManager(ABC): self.num_cached_block.pop(request_id, None) @abstractmethod - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: """ Get the number of common prefix blocks for all requests in the RUNNING state. @@ -205,12 +214,12 @@ class SingleTypeKVCacheManager(ABC): dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ - Get the longest cache hit prefix of the blocks that is not longer than - `max_length`. The prefix should be a common prefix hit for all the - kv cache groups in `kv_cache_group_ids`. If no cache hit is found, - return an empty list. - If eagle is enabled, drop the last matched block to force recompute the - last block to get the required hidden states for eagle drafting head. + Get the longest cache hit prefix of the blocks that is not longer than + `max_length`. The prefix should be a common prefix hit for all the + kv cache groups in `kv_cache_group_ids`. If no cache hit is found, + return an empty list. + If eagle is enabled, drop the last matched block to force recompute the + last block to get the required hidden states for eagle drafting head. Need to be customized for each attention type. 
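The allocation count in the hunk above works out to: blocks required, minus blocks already held by the request, minus fresh prefix-cache hits, plus any of those hits that still sit in the free queue and must be revived. A numeric sketch with assumed sizes:

def cdiv(a: int, b: int) -> int:
    # Ceiling division, the same semantics as the cdiv helper used above.
    return -(-a // b)

block_size = 16                  # assumed block size
num_tokens = 100                 # total token slots the request needs
num_required_blocks = cdiv(num_tokens, block_size)   # 7
num_already_allocated = 3        # len(self.req_to_blocks[request_id])
num_new_computed = 2             # new prefix-cache hits for this request
num_evictable_computed = 1       # hits with ref_cnt == 0 sitting in the free queue

num_new_blocks = num_required_blocks - num_new_computed - num_already_allocated
num_blocks_to_allocate = num_new_blocks + num_evictable_computed
assert (num_required_blocks, num_new_blocks, num_blocks_to_allocate) == (7, 2, 3)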
Args: @@ -235,10 +244,9 @@ class SingleTypeKVCacheManager(ABC): raise NotImplementedError @abstractmethod - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: """ - Remove the blocks that are no longer needed from `blocks` and free the + Remove the blocks that are no longer needed from `blocks` and free the blocks. The removed blocks should be replaced by null_block. Need to be customized for each attention type. @@ -250,7 +258,6 @@ class SingleTypeKVCacheManager(ABC): class FullAttentionManager(SingleTypeKVCacheManager): - @classmethod def find_longest_cache_hit( cls, @@ -264,10 +271,13 @@ class FullAttentionManager(SingleTypeKVCacheManager): ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) - ), "FullAttentionManager can only be used for full attention " \ + ), ( + "FullAttentionManager can only be used for full attention " "and chunked local attention groups" + ) computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( - [] for _ in range(len(kv_cache_group_ids))) + [] for _ in range(len(kv_cache_group_ids)) + ) block_size = kv_cache_spec.block_size if dcp_world_size > 1: block_size *= dcp_world_size @@ -277,7 +287,8 @@ class FullAttentionManager(SingleTypeKVCacheManager): # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. if cached_block := block_pool.get_cached_block( - block_hash, kv_cache_group_ids): + block_hash, kv_cache_group_ids + ): for computed, cached in zip(computed_blocks, cached_block): computed.append(cached) else: @@ -287,13 +298,13 @@ class FullAttentionManager(SingleTypeKVCacheManager): computed.pop() return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # No need to remove blocks for full attention. pass - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: blocks = self.req_to_blocks[request_id] num_common_blocks = 0 for block in blocks: @@ -305,9 +316,9 @@ class FullAttentionManager(SingleTypeKVCacheManager): class SlidingWindowManager(SingleTypeKVCacheManager): - - def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, - **kwargs) -> None: + def __init__( + self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, **kwargs + ) -> None: super().__init__(kv_cache_spec, block_pool, **kwargs) self.sliding_window = kv_cache_spec.sliding_window self._null_block = block_pool.null_block @@ -324,13 +335,15 @@ class SlidingWindowManager(SingleTypeKVCacheManager): dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, SlidingWindowSpec), ( - "SlidingWindowManager can only be used for sliding window groups") + "SlidingWindowManager can only be used for sliding window groups" + ) assert dcp_world_size == 1, "DCP not support sliding window attn now." # The number of contiguous blocks needed for prefix cache hit. 
# -1 since the input token itself is also included in the window sliding_window_contiguous_blocks = cdiv( - kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size) + kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size + ) if use_eagle: # Need to drop the last matched block if eagle is enabled. For # sliding window layer, we achieve this by increasing the number of @@ -344,14 +357,17 @@ class SlidingWindowManager(SingleTypeKVCacheManager): # sliding_window_contiguous_blocks), # which is good for low cache hit rate scenarios. max_num_blocks = max_length // kv_cache_spec.block_size - computed_blocks = tuple([block_pool.null_block] * max_num_blocks - for _ in range(len(kv_cache_group_ids))) + computed_blocks = tuple( + [block_pool.null_block] * max_num_blocks + for _ in range(len(kv_cache_group_ids)) + ) num_contiguous_blocks = 0 match_found = False # Search from right to left and early stop when a match is found. for i in range(max_num_blocks - 1, -1, -1): if cached_block := block_pool.get_cached_block( - block_hashes[i], kv_cache_group_ids): + block_hashes[i], kv_cache_group_ids + ): for computed, cached in zip(computed_blocks, cached_block): computed[i] = cached num_contiguous_blocks += 1 @@ -360,7 +376,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager): # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] # when sliding_window_contiguous_blocks=2. for computed in computed_blocks: - del computed[i + num_contiguous_blocks:] + del computed[i + num_contiguous_blocks :] match_found = True break else: @@ -375,8 +391,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager): computed.pop() return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # Remove the blocks that are no longer be in the sliding window and # skipped during the attention computation. last_useful_token = num_computed_tokens - self.sliding_window + 1 @@ -393,21 +408,22 @@ class SlidingWindowManager(SingleTypeKVCacheManager): blocks[i] = self._null_block self.block_pool.free_blocks(removed_blocks) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: """ NOTE(Chen): The prefix blocks are null blocks for sliding window layers. - So it's not correct to count ref_cnt like FullAttentionManager. Return - 0 here for correctness. Need to support cascade attention + sliding + So it's not correct to count ref_cnt like FullAttentionManager. Return + 0 here for correctness. Need to support cascade attention + sliding window in the future. """ return 0 class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): - - def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, - block_pool: BlockPool, **kwargs) -> None: + def __init__( + self, kv_cache_spec: ChunkedLocalAttentionSpec, block_pool: BlockPool, **kwargs + ) -> None: super().__init__(kv_cache_spec, block_pool, **kwargs) self.attention_chunk_size = kv_cache_spec.attention_chunk_size self._null_block = block_pool.null_block @@ -428,19 +444,19 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): prefix of the blocks that is not longer than `max_length`. The prefix should be a common prefix hit for all the kv cache groups in `kv_cache_group_ids`. If no cache hit is found, return an empty list. 
- note we mark as computed if the whole block is outside of the local + note we mark as computed if the whole block is outside of the local window, and set the block as null. Examples: 1. Attention chunk size of 8, block size of 4, max length of 15 - for next token at 15th (zero-indexed), 8th - 14th tokens are in - the window(needs lookup), 0th - 7th are not in the window, - so they are already marked as computed. We check the complete - block3 (8th - 11th tokens), Assume block 3 is hit, we will return + for next token at 15th (zero-indexed), 8th - 14th tokens are in + the window(needs lookup), 0th - 7th are not in the window, + so they are already marked as computed. We check the complete + block3 (8th - 11th tokens), Assume block 3 is hit, we will return [null, null, block 3], otherwise, we return [null, null] 2. Attention chunk size of 8, block size of 4, max length of 16 - for next token at 16th (zero-indexed), 0th - 15th tokens are not - in the window, so they are already marked as computed. + for next token at 16th (zero-indexed), 0th - 15th tokens are not + in the window, so they are already marked as computed. we return 4 blocks[null, null, null, null] Args: @@ -455,39 +471,45 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): A list of cached blocks """ assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), ( - "ChunkedLocalAttentionManager can only be used for " + - "chunked local attention groups") - assert use_eagle is False, ("Hybrid KV cache is not supported for " + - "eagle + chunked local attention.") + "ChunkedLocalAttentionManager can only be used for " + + "chunked local attention groups" + ) + assert use_eagle is False, ( + "Hybrid KV cache is not supported for " + "eagle + chunked local attention." + ) assert dcp_world_size == 1, "DCP not support chunked local attn now." max_num_blocks = max_length // kv_cache_spec.block_size if max_length > 0: - local_attention_start_idx = (max_length // - kv_cache_spec.attention_chunk_size * - kv_cache_spec.attention_chunk_size) + local_attention_start_idx = ( + max_length + // kv_cache_spec.attention_chunk_size + * kv_cache_spec.attention_chunk_size + ) else: local_attention_start_idx = 0 # we marked blocks out of window as computed # with null blocks, and blocks inside window based on cache lookup # result [null] [null] ... [null] [hit block 1 (1st block contain # last window)] [hit block 2] ... [hit block x] - local_attention_start_block_idx = (local_attention_start_idx // - kv_cache_spec.block_size) + local_attention_start_block_idx = ( + local_attention_start_idx // kv_cache_spec.block_size + ) computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( [block_pool.null_block] * local_attention_start_block_idx - for _ in range(len(kv_cache_group_ids))) + for _ in range(len(kv_cache_group_ids)) + ) for i in range(local_attention_start_block_idx, max_num_blocks): block_hash = block_hashes[i] if cached_block := block_pool.get_cached_block( - block_hash, kv_cache_group_ids): + block_hash, kv_cache_group_ids + ): for computed, cached in zip(computed_blocks, cached_block): computed.append(cached) else: break return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # Remove the blocks that are no longer be in the chunked attention # window and skipped during the attention computation. @@ -499,13 +521,14 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): # is 1024. 
for 1023, it will be 0. num_cached_block = self.num_cached_block.get(request_id, 0) local_attention_start_idx = ( - num_computed_tokens - ) // self.attention_chunk_size * self.attention_chunk_size + (num_computed_tokens) + // self.attention_chunk_size + * self.attention_chunk_size + ) first_useful_block_idx = local_attention_start_idx // self.block_size if num_cached_block > 0: # Make sure we don't delete the last cached block - first_useful_block_idx = min(first_useful_block_idx, - num_cached_block - 1) + first_useful_block_idx = min(first_useful_block_idx, num_cached_block - 1) # if block size = 128, 0 -> block 0, 1024 (= 128 * 8) -> # block 8, 372 (= 128 * 2 + 116) -> block 2 blocks = self.req_to_blocks[request_id] @@ -521,8 +544,9 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): blocks[i] = self._null_block self.block_pool.free_blocks(removed_blocks) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: """ cascade attention is not supported by chunked local attention. """ @@ -530,7 +554,6 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): class MambaManager(SingleTypeKVCacheManager): - @classmethod def find_longest_cache_hit( cls, @@ -542,18 +565,20 @@ class MambaManager(SingleTypeKVCacheManager): use_eagle: bool, dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: - assert isinstance( - kv_cache_spec, - MambaSpec), ("MambaManager can only be used for mamba groups") + assert isinstance(kv_cache_spec, MambaSpec), ( + "MambaManager can only be used for mamba groups" + ) assert dcp_world_size == 1, "DCP not support mamba now." computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( - [] for _ in range(len(kv_cache_group_ids))) + [] for _ in range(len(kv_cache_group_ids)) + ) max_num_blocks = max_length // kv_cache_spec.block_size # Search from right to left and early stop when a match is found. for i in range(max_num_blocks - 1, -1, -1): if cached_block := block_pool.get_cached_block( - block_hashes[i], kv_cache_group_ids): + block_hashes[i], kv_cache_group_ids + ): for computed, cached in zip(computed_blocks, cached_block): # the hit length logic later assumes: # hit_length = len(hit_blocks_other_attn[0]) @@ -566,40 +591,46 @@ class MambaManager(SingleTypeKVCacheManager): return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # Here unused blocks may be freed up for running requests. # TODO(@s3woz) Free up all blocks that aren't needed by Mamba2 # (for which find_longest_cache_hit returns block_pool.null_block) pass - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: """ cascade attention is not supported by mamba """ return 0 def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, - new_computed_blocks: list[KVCacheBlock]) -> int: + self, request_id: str, num_tokens: int, new_computed_blocks: list[KVCacheBlock] + ) -> int: # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. 
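The chunk-boundary arithmetic in `ChunkedLocalAttentionManager.remove_skipped_blocks`, reformatted a couple of hunks above, can be checked directly with the 1024/128 numbers its own comments use; only blocks from the start of the current attention chunk onward stay useful:

attention_chunk_size = 1024   # mirrors the example numbers in the comments above
block_size = 128

def local_attention_start_idx(num_computed_tokens: int) -> int:
    # Floor the computed-token count down to the start of its attention chunk.
    return num_computed_tokens // attention_chunk_size * attention_chunk_size

assert local_attention_start_idx(1023) == 0
assert local_attention_start_idx(1024) == 1024
assert local_attention_start_idx(2047) == 1024

# Blocks strictly before the chunk start can be swapped for the null block.
first_useful_block_idx = local_attention_start_idx(1024) // block_size
assert first_useful_block_idx == 8   # 1024 == 128 * 8, matching the comment in the diff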
assert isinstance(self.kv_cache_spec, MambaSpec) if self.kv_cache_spec.num_speculative_blocks > 0: - num_tokens += (self.kv_cache_spec.block_size * - self.kv_cache_spec.num_speculative_blocks) - return super().get_num_blocks_to_allocate(request_id, num_tokens, - new_computed_blocks) + num_tokens += ( + self.kv_cache_spec.block_size + * self.kv_cache_spec.num_speculative_blocks + ) + return super().get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks + ) - def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[KVCacheBlock]: + def allocate_new_blocks( + self, request_id: str, num_tokens: int + ) -> list[KVCacheBlock]: # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. assert isinstance(self.kv_cache_spec, MambaSpec) if self.kv_cache_spec.num_speculative_blocks > 0: - num_tokens += (self.kv_cache_spec.block_size * - self.kv_cache_spec.num_speculative_blocks) + num_tokens += ( + self.kv_cache_spec.block_size + * self.kv_cache_spec.num_speculative_blocks + ) return super().allocate_new_blocks(request_id, num_tokens) @@ -607,8 +638,8 @@ class CrossAttentionManager(SingleTypeKVCacheManager): """Manager for cross-attention KV cache in encoder-decoder models.""" def save_new_computed_blocks( - self, request_id: str, - new_computed_blocks: list[KVCacheBlock]) -> None: + self, request_id: str, new_computed_blocks: list[KVCacheBlock] + ) -> None: # We do not cache blocks for cross-attention to be shared between # requests, so `new_computed_blocks` should always be empty. assert len(new_computed_blocks) == 0 @@ -618,8 +649,9 @@ class CrossAttentionManager(SingleTypeKVCacheManager): # requests, so this method is not relevant. raise ValueError("Should not be called as prefix caching is disabled.") - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: # Cross-attention blocks contain request-specific encoder states # and are not shared between different requests return 0 @@ -644,11 +676,9 @@ class CrossAttentionManager(SingleTypeKVCacheManager): # 2. Encoder states are computed once per request, not incrementally # 3. No reusable prefix exists between different multimodal inputs # Return empty blocks to indicate no cache hits - raise NotImplementedError( - "CrossAttentionManager does not support caching") + raise NotImplementedError("CrossAttentionManager does not support caching") - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # Cross-attention blocks represent encoder states which are needed # for the entire decoding process, so no blocks should be skipped pass @@ -664,8 +694,9 @@ spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { } -def get_manager_for_kv_cache_spec(kv_cache_spec: KVCacheSpec, - **kwargs) -> SingleTypeKVCacheManager: +def get_manager_for_kv_cache_spec( + kv_cache_spec: KVCacheSpec, **kwargs +) -> SingleTypeKVCacheManager: manager_class = spec_manager_map[type(kv_cache_spec)] manager = manager_class(kv_cache_spec, **kwargs) return manager diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 29bb220760..ce47147028 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -12,14 +12,14 @@ class CudagraphDispatcher: cudagraphs. 
The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one - for FULL cudagraph runtime mode. The keys are initialized depending on - attention support and what cudagraph mode is set in CompilationConfig. The + for FULL cudagraph runtime mode. The keys are initialized depending on + attention support and what cudagraph mode is set in CompilationConfig. The keys stored in dispatcher are the only source of truth for valid cudagraphs that can be dispatched at runtime. - At runtime, the dispatch method generates the runtime cudagraph mode (FULL, + At runtime, the dispatch method generates the runtime cudagraph mode (FULL, PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor) - based on the input key. After dispatching (communicated via forward + based on the input key. After dispatching (communicated via forward context), the cudagraph wrappers will trust the dispatch key to either capture or replay (if the mode matches), or pass through to the underlying runnable without cudagraph (if the mode does not match or mode is NONE). @@ -37,28 +37,35 @@ class CudagraphDispatcher: } not_use_piecewise_compilation = ( - not self.cudagraph_mode.requires_piecewise_compilation()) + not self.cudagraph_mode.requires_piecewise_compilation() + ) - assert not_use_piecewise_compilation or \ - self.compilation_config.is_attention_compiled_piecewise(), \ - "Compilation level should be CompilationLevel.PIECEWISE when "\ - "cudagraph_mode piecewise cudagraphs is used, "\ - "and attention should be in splitting_ops or "\ - "inductor splitting should be used. " \ - f"cudagraph_mode={self.cudagraph_mode}, "\ - f"compilation_level={self.compilation_config.level}, "\ + assert ( + not_use_piecewise_compilation + or self.compilation_config.is_attention_compiled_piecewise() + ), ( + "Compilation level should be CompilationLevel.PIECEWISE when " + "cudagraph_mode piecewise cudagraphs is used, " + "and attention should be in splitting_ops or " + "inductor splitting should be used. " + f"cudagraph_mode={self.cudagraph_mode}, " + f"compilation_level={self.compilation_config.level}, " f"splitting_ops={self.compilation_config.splitting_ops}" + ) self.keys_initialized = False - def add_cudagraph_key(self, runtime_mode: CUDAGraphMode, - batch_descriptor: BatchDescriptor): - assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \ + def add_cudagraph_key( + self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor + ): + assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], ( f"Invalid cudagraph runtime mode for keys: {runtime_mode}" + ) self.cudagraph_keys[runtime_mode].add(batch_descriptor) - def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, - uniform_decode_query_len: int): + def initialize_cudagraph_keys( + self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int + ): # This should be called only after attention backend is initialized. # Note: we create all valid keys for cudagraph here but do not @@ -68,33 +75,38 @@ class CudagraphDispatcher: for bs in self.compilation_config.cudagraph_capture_sizes: self.add_cudagraph_key( cudagraph_mode.mixed_mode(), - BatchDescriptor(num_tokens=bs, uniform_decode=False)) + BatchDescriptor(num_tokens=bs, uniform_decode=False), + ) # if decode cudagraph mode is FULL, and we don't already have mixed # mode full cudagraphs then add them here. 
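The hunk that follows narrows the FULL-cudagraph decode keys to batch sizes a uniform decode batch can actually produce, i.e. sizes between the per-request decode query length and `uniform_decode_query_len * max_num_seqs`. A sketch of that filter with assumed numbers (capture sizes and scheduler limits differ per deployment):

uniform_decode_query_len = 2   # assumed, e.g. one verified token plus one speculative token
max_num_seqs = 64              # assumed scheduler limit
cudagraph_capture_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]

max_num_tokens = uniform_decode_query_len * max_num_seqs   # 128
capture_sizes_for_decode = [
    x for x in cudagraph_capture_sizes
    if uniform_decode_query_len <= x <= max_num_tokens
]
assert capture_sizes_for_decode == [2, 4, 8, 16, 32, 64, 128]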
- if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL \ - and cudagraph_mode.separate_routine(): - max_num_tokens = uniform_decode_query_len * \ - self.vllm_config.scheduler_config.max_num_seqs + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and cudagraph_mode.separate_routine() + ): + max_num_tokens = ( + uniform_decode_query_len + * self.vllm_config.scheduler_config.max_num_seqs + ) cudagraph_capture_sizes_for_decode = [ - x for x in self.compilation_config.cudagraph_capture_sizes + x + for x in self.compilation_config.cudagraph_capture_sizes if x <= max_num_tokens and x >= uniform_decode_query_len ] for bs in cudagraph_capture_sizes_for_decode: self.add_cudagraph_key( CUDAGraphMode.FULL, - BatchDescriptor(num_tokens=bs, uniform_decode=True)) + BatchDescriptor(num_tokens=bs, uniform_decode=True), + ) self.keys_initialized = True def dispatch( - self, - batch_descriptor: BatchDescriptor, - use_cascade_attn: bool = False + self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]: """ Given conditions(e.g.,batch descriptor and if using cascade attention), dispatch to a cudagraph runtime mode and the valid batch descriptor. - A new batch descriptor is returned as we might dispatch a uniform batch + A new batch descriptor is returned as we might dispatch a uniform batch to a graph that supports a more general batch (uniform to non-uniform). """ # if not initialized, just skip dispatching. diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 345f5a464c..163c050e55 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -32,6 +32,7 @@ class FinishReason(enum.IntEnum): abort - aborted for another reason """ + STOP = 0 LENGTH = 1 ABORT = 2 @@ -41,11 +42,11 @@ class FinishReason(enum.IntEnum): class EngineCoreRequest( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False): # type: ignore[call-arg] - + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, +): # type: ignore[call-arg] request_id: str prompt_token_ids: Optional[list[int]] mm_features: Optional[list[MultiModalFeatureSpec]] @@ -73,6 +74,7 @@ class EngineCoreRequest( class EngineCoreEventType(enum.IntEnum): """The type of engine core request event.""" + QUEUED = 1 SCHEDULED = 2 PREEMPTED = 3 @@ -85,23 +87,24 @@ class EngineCoreEvent(msgspec.Struct): frontend to calculate intervals between engine core events. These timestamps should not be compared with timestamps from other processes. 
""" + type: EngineCoreEventType timestamp: float @classmethod - def new_event(cls, - event_type: EngineCoreEventType, - timestamp: Optional[float] = None) -> "EngineCoreEvent": + def new_event( + cls, event_type: EngineCoreEventType, timestamp: Optional[float] = None + ) -> "EngineCoreEvent": timestamp = time.monotonic() if timestamp is None else timestamp return cls(event_type, timestamp) class EngineCoreOutput( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False): # type: ignore[call-arg] - + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, +): # type: ignore[call-arg] request_id: str new_token_ids: list[int] @@ -132,10 +135,10 @@ class UtilityResult: class UtilityOutput( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - gc=False): # type: ignore[call-arg] - + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + gc=False, +): # type: ignore[call-arg] call_id: int # Non-None implies the call failed, result should be None. @@ -144,11 +147,11 @@ class UtilityOutput( class EngineCoreOutputs( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False): # type: ignore[call-arg] - + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, +): # type: ignore[call-arg] # NOTE(Nick): We could consider ways to make this more compact, # e.g. columnwise layout @@ -179,12 +182,13 @@ class EngineCoreRequestType(enum.Enum): Request types defined as hex byte strings, so it can be sent over sockets without separate encoding step. """ - ADD = b'\x00' - ABORT = b'\x01' - START_DP_WAVE = b'\x02' - UTILITY = b'\x03' + + ADD = b"\x00" + ABORT = b"\x01" + START_DP_WAVE = b"\x02" + UTILITY = b"\x03" # Sentinel used within EngineCoreProc. - EXECUTOR_FAILED = b'\x04' + EXECUTOR_FAILED = b"\x04" class ReconfigureDistributedRequest(msgspec.Struct): @@ -199,5 +203,6 @@ class ReconfigureRankType(enum.IntEnum): """ Rank type for reconfiguring distributed request. 
""" + KEEP_CURRENT_RANK = -1 SHUTDOWN_CURRENT_RANK = -2 diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ab3a4e5e6f..ca668bc217 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -27,18 +27,14 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask from vllm.tracing import init_tracer -from vllm.transformers_utils.config import ( - maybe_register_config_serialize_by_value) -from vllm.transformers_utils.tokenizer import (AnyTokenizer, - init_tokenizer_from_configs) +from vllm.transformers_utils.config import maybe_register_config_serialize_by_value +from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.utils import (Device, as_list, cancel_task_threadsafe, cdiv, - deprecate_kwargs) +from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError -from vllm.v1.engine.output_processor import (OutputProcessor, - RequestOutputCollector) +from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor @@ -50,7 +46,6 @@ logger = init_logger(__name__) class AsyncLLM(EngineClient): - def __init__( self, vllm_config: VllmConfig, @@ -91,7 +86,8 @@ class AsyncLLM(EngineClient): "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " "This should not happen. As a workaround, try using " "AsyncLLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github.") + "VLLM_USE_V1=0 or 1 and report this issue on Github." + ) # Ensure we can serialize custom transformer configs maybe_register_config_serialize_by_value() @@ -105,14 +101,16 @@ class AsyncLLM(EngineClient): if not log_stats and stat_loggers is not None: logger.info( "AsyncLLM created with log_stats=False and non-empty custom " - "logger list; enabling logging without default stat loggers") + "logger list; enabling logging without default stat loggers" + ) if self.model_config.skip_tokenizer_init: self.tokenizer = None else: # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config) + model_config=vllm_config.model_config + ) # Processor (converts Inputs --> EngineCoreRequests). self.processor = Processor( @@ -122,12 +120,13 @@ class AsyncLLM(EngineClient): ) # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). - self.output_processor = OutputProcessor(self.tokenizer, - log_stats=self.log_stats) + self.output_processor = OutputProcessor( + self.tokenizer, log_stats=self.log_stats + ) if self.observability_config.otlp_traces_endpoint is not None: tracer = init_tracer( - "vllm.llm_engine", - self.observability_config.otlp_traces_endpoint) + "vllm.llm_engine", self.observability_config.otlp_traces_endpoint + ) self.output_processor.tracer = tracer # EngineCore (starts the engine in background process). @@ -163,7 +162,8 @@ class AsyncLLM(EngineClient): if envs.VLLM_TORCH_PROFILER_DIR: logger.info( "Torch profiler enabled. 
AsyncLLM CPU traces will be collected under %s", # noqa: E501 - envs.VLLM_TORCH_PROFILER_DIR) + envs.VLLM_TORCH_PROFILER_DIR, + ) worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm" self.profiler = torch.profiler.profile( activities=[ @@ -171,37 +171,39 @@ class AsyncLLM(EngineClient): ], with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, on_trace_ready=torch.profiler.tensorboard_trace_handler( - envs.VLLM_TORCH_PROFILER_DIR, - worker_name=worker_name, - use_gzip=True)) + envs.VLLM_TORCH_PROFILER_DIR, worker_name=worker_name, use_gzip=True + ), + ) else: self.profiler = None @classmethod @deprecate_kwargs( "disable_log_requests", - additional_message=("This argument will have no effect. " - "Use `enable_log_requests` instead."), + additional_message=( + "This argument will have no effect. Use `enable_log_requests` instead." + ), ) def from_vllm_config( - cls, - vllm_config: VllmConfig, - start_engine_loop: bool = True, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[StatLoggerFactory]] = None, - enable_log_requests: bool = False, - disable_log_stats: bool = False, - client_addresses: Optional[dict[str, str]] = None, - client_count: int = 1, - client_index: int = 0, - disable_log_requests: bool = True, # Deprecated, will be removed + cls, + vllm_config: VllmConfig, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[list[StatLoggerFactory]] = None, + enable_log_requests: bool = False, + disable_log_stats: bool = False, + client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, + client_index: int = 0, + disable_log_requests: bool = True, # Deprecated, will be removed ) -> "AsyncLLM": if not envs.VLLM_USE_V1: raise ValueError( "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " "This should not happen. As a workaround, try using " "AsyncLLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github.") + "VLLM_USE_V1=0 or 1 and report this issue on Github." + ) # Create the LLMEngine. return cls( @@ -288,14 +290,20 @@ class AsyncLLM(EngineClient): assert prompt_text is None logger.warning_once( "Processor has been moved under OpenAIServing and will " - "be removed from AsyncLLM in v0.13.") - request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - tokenization_kwargs, - trace_headers, priority, - data_parallel_rank) - prompt_text = (prompt if isinstance(prompt, str) else - prompt.get("prompt")) + "be removed from AsyncLLM in v0.13." 
+ ) + request = self.processor.process_inputs( + request_id, + prompt, + params, + arrival_time, + lora_request, + tokenization_kwargs, + trace_headers, + priority, + data_parallel_rank, + ) + prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") if is_pooling or params.n == 1: await self._add_request(request, prompt_text, None, 0, queue) @@ -310,22 +318,24 @@ class AsyncLLM(EngineClient): parent_request = ParentRequest(request_id, parent_params) for idx in range(parent_params.n): request_id, child_params = parent_request.get_child_info(idx) - child_request = request if idx == parent_params.n - 1 else copy( - request) + child_request = request if idx == parent_params.n - 1 else copy(request) child_request.request_id = request_id child_request.sampling_params = child_params - await self._add_request(child_request, prompt_text, parent_request, - idx, queue) + await self._add_request( + child_request, prompt_text, parent_request, idx, queue + ) return queue - async def _add_request(self, request: EngineCoreRequest, - prompt: Optional[str], - parent_req: Optional[ParentRequest], index: int, - queue: RequestOutputCollector): - + async def _add_request( + self, + request: EngineCoreRequest, + prompt: Optional[str], + parent_req: Optional[ParentRequest], + index: int, + queue: RequestOutputCollector, + ): # Add the request to OutputProcessor (this process). - self.output_processor.add_request(request, prompt, parent_req, index, - queue) + self.output_processor.add_request(request, prompt, parent_req, index, queue) # Add the EngineCoreRequest to EngineCore (separate process). await self.engine_core.add_request_async(request) @@ -366,12 +376,15 @@ class AsyncLLM(EngineClient): returning the RequestOutput back to the caller. """ - if (self.vllm_config.cache_config.kv_sharing_fast_prefill - and sampling_params.prompt_logprobs): + if ( + self.vllm_config.cache_config.kv_sharing_fast_prefill + and sampling_params.prompt_logprobs + ): raise ValueError( "--kv-sharing-fast-prefill produces incorrect logprobs for " "prompt tokens, please disable it when the requests need " - "prompt logprobs") + "prompt logprobs" + ) try: # We start the output_handler on the first call to generate() so @@ -389,15 +402,17 @@ class AsyncLLM(EngineClient): tokenization_kwargs, ) - q = await self.add_request(request_id, - prompt, - sampling_params, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - trace_headers=trace_headers, - priority=priority, - data_parallel_rank=data_parallel_rank, - prompt_text=prompt_text) + q = await self.add_request( + request_id, + prompt, + sampling_params, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + trace_headers=trace_headers, + priority=priority, + data_parallel_rank=data_parallel_rank, + prompt_text=prompt_text, + ) # The output_handler task pushes items into the queue. # This task pulls from the queue and yields to caller. @@ -460,23 +475,26 @@ class AsyncLLM(EngineClient): outputs = await engine_core.get_output_async() num_outputs = len(outputs.outputs) - iteration_stats = IterationStats() if ( - log_stats and num_outputs) else None + iteration_stats = ( + IterationStats() if (log_stats and num_outputs) else None + ) # Split outputs into chunks of at most # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the # event loop for too long. 
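The chunking just below caps how many outputs are handed to the output processor before control returns to the event loop. With assumed sizes it behaves like this (`chunk_size` stands in for VLLM_V1_OUTPUT_PROC_CHUNK_SIZE):

import numpy as np

def cdiv(a: int, b: int) -> int:
    # Ceiling division, the same semantics as the cdiv helper imported above.
    return -(-a // b)

chunk_size = 128                 # assumed value for this sketch
outputs = list(range(300))       # pretend 300 per-request outputs arrived in one step

if len(outputs) <= chunk_size:
    slices = (outputs,)
else:
    slices = np.array_split(outputs, cdiv(len(outputs), chunk_size))

assert len(slices) == 3                              # ceil(300 / 128) == 3 chunks
assert [len(s) for s in slices] == [100, 100, 100]   # np.array_split balances the sizes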
if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: - slices = (outputs.outputs, ) + slices = (outputs.outputs,) else: slices = np.array_split( outputs.outputs, - cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)) + cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE), + ) for i, outputs_slice in enumerate(slices): # 2) Process EngineCoreOutputs. processed_outputs = output_processor.process_outputs( - outputs_slice, outputs.timestamp, iteration_stats) + outputs_slice, outputs.timestamp, iteration_stats + ) # NOTE: RequestOutputs are pushed to their queues. assert not processed_outputs.request_outputs @@ -486,7 +504,8 @@ class AsyncLLM(EngineClient): # 3) Abort any reqs that finished due to stop strings. await engine_core.abort_requests_async( - processed_outputs.reqs_to_abort) + processed_outputs.reqs_to_abort + ) # 4) Logging. # TODO(rob): make into a coroutine and launch it in @@ -506,8 +525,9 @@ class AsyncLLM(EngineClient): async def abort(self, request_id: Union[str, Iterable[str]]) -> None: """Abort RequestId in OutputProcessor and EngineCore.""" - request_ids = (request_id, ) if isinstance( - request_id, str) else as_list(request_id) + request_ids = ( + (request_id,) if isinstance(request_id, str) else as_list(request_id) + ) all_request_ids = self.output_processor.abort_requests(request_ids) await self.engine_core.abort_requests_async(all_request_ids) @@ -614,8 +634,9 @@ class AsyncLLM(EngineClient): async def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: - raise ValueError("Unable to get tokenizer because " - "skip_tokenizer_init is True") + raise ValueError( + "Unable to get tokenizer because skip_tokenizer_init is True" + ) return self.tokenizer @@ -647,8 +668,7 @@ class AsyncLLM(EngineClient): self.processor.clear_cache() await self.engine_core.reset_mm_cache_async() - async def reset_prefix_cache(self, - device: Optional[Device] = None) -> None: + async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: if device == Device.CPU: raise ValueError("Not supported on CPU.") await self.engine_core.reset_prefix_cache_async() @@ -679,16 +699,19 @@ class AsyncLLM(EngineClient): """Prevent an adapter from being evicted.""" return await self.engine_core.pin_lora_async(lora_id) - async def collective_rpc(self, - method: str, - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None): + async def collective_rpc( + self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + ): """ Perform a collective RPC call to the given path. """ return await self.engine_core.collective_rpc_async( - method, timeout, args, kwargs) + method, timeout, args, kwargs + ) async def wait_for_requests_to_drain(self, drain_timeout: int = 300): """Wait for all requests to be drained.""" @@ -698,16 +721,17 @@ class AsyncLLM(EngineClient): logger.info("Engines are idle, requests have been drained") return - logger.info( - "Engines are still running, waiting for requests to drain...") + logger.info("Engines are still running, waiting for requests to drain...") await asyncio.sleep(1) # Wait 1 second before checking again - raise TimeoutError(f"Timeout reached after {drain_timeout} seconds " - "waiting for requests to drain.") + raise TimeoutError( + f"Timeout reached after {drain_timeout} seconds " + "waiting for requests to drain." 
+ ) - async def scale_elastic_ep(self, - new_data_parallel_size: int, - drain_timeout: int = 300): + async def scale_elastic_ep( + self, new_data_parallel_size: int, drain_timeout: int = 300 + ): """ Scale up or down the data parallel size by adding or removing engine cores. @@ -716,22 +740,24 @@ class AsyncLLM(EngineClient): drain_timeout: Maximum time to wait for requests to drain (seconds) """ - old_data_parallel_size = \ - self.vllm_config.parallel_config.data_parallel_size + old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size if old_data_parallel_size == new_data_parallel_size: - logger.info("Data parallel size is already %s, skipping scale", - new_data_parallel_size) + logger.info( + "Data parallel size is already %s, skipping scale", + new_data_parallel_size, + ) return logger.info( - "Waiting for requests to drain before " - "scaling up to %s engines...", new_data_parallel_size) + "Waiting for requests to drain before scaling up to %s engines...", + new_data_parallel_size, + ) await self.wait_for_requests_to_drain(drain_timeout) logger.info( - "Requests have been drained, proceeding with scale " - "to %s engines", new_data_parallel_size) + "Requests have been drained, proceeding with scale to %s engines", + new_data_parallel_size, + ) await self.engine_core.scale_elastic_ep(new_data_parallel_size) - self.vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size # recreate stat loggers if new_data_parallel_size > old_data_parallel_size and self.log_stats: diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 596edfdbe2..9bb08e6db7 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -56,7 +56,6 @@ class DPCoordinator: """ def __init__(self, parallel_config: ParallelConfig): - dp_size = parallel_config.data_parallel_size assert dp_size > 1, "Coordinator only used for data parallel" @@ -68,7 +67,8 @@ class DPCoordinator: # either external or hybrid DP LB mode. 
local_only = not (external_lb or hybrid_lb) front_publish_address = get_engine_client_zmq_addr( - local_only=local_only, host=host) + local_only=local_only, host=host + ) local_only_eng = dp_size == parallel_config.data_parallel_size_local back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) @@ -84,7 +84,8 @@ class DPCoordinator: "back_output_address": back_output_address, "back_publish_address": back_publish_address, }, - daemon=True) + daemon=True, + ) self.proc.start() self.stats_publish_address = front_publish_address @@ -104,16 +105,12 @@ class DPCoordinator: class EngineState: - def __init__(self): self.request_counts = [0, 0] # [waiting, running] class DPCoordinatorProc: - - def __init__(self, - engine_count: int, - min_stats_update_interval_ms: int = 100): + def __init__(self, engine_count: int, min_stats_update_interval_ms: int = 100): set_process_title("DPCoordinator") self.ctx = zmq.Context() @@ -131,7 +128,8 @@ class DPCoordinatorProc: ): coordinator = DPCoordinatorProc( engine_count=engine_count, - min_stats_update_interval_ms=min_stats_update_interval_ms) + min_stats_update_interval_ms=min_stats_update_interval_ms, + ) try: coordinator.process_input_socket( front_publish_address, @@ -141,10 +139,12 @@ class DPCoordinatorProc: except KeyboardInterrupt: logger.info("DP Coordinator process exiting") - def process_input_socket(self, front_publish_address: str, - back_output_address: str, - back_publish_address: str): - + def process_input_socket( + self, + front_publish_address: str, + back_output_address: str, + back_publish_address: str, + ): decoder = MsgpackDecoder(EngineCoreOutputs) # For tracking request wave progression. @@ -157,29 +157,33 @@ class DPCoordinatorProc: last_stats_wave = -1 last_step_counts: Optional[list[list[int]]] = None - with make_zmq_socket( + with ( + make_zmq_socket( path=front_publish_address, # IPC ctx=self.ctx, socket_type=zmq.XPUB, bind=True, - ) as publish_front, make_zmq_socket( + ) as publish_front, + make_zmq_socket( path=back_output_address, # IPC or TCP ctx=self.ctx, socket_type=zmq.PULL, bind=True, - ) as output_back, make_zmq_socket( + ) as output_back, + make_zmq_socket( path=back_publish_address, # IPC or TCP ctx=self.ctx, socket_type=zmq.XPUB, bind=True, - ) as publish_back: - + ) as publish_back, + ): # Wait until all engines subscribe. for _ in self.engines: - if publish_back.recv() != b'\x01': + if publish_back.recv() != b"\x01": logger.error( "DP Coordinator received unexpected message while " - "waiting for engines to subscribe") + "waiting for engines to subscribe" + ) return # Send ready message to engines. publish_back.send(b"READY") @@ -194,15 +198,13 @@ class DPCoordinatorProc: elapsed = int(time.time() * 1000) - last_publish_time # Send at stats_update_interval_ms interval if the stats have # changed, or otherwise every 5 seconds. - wait_for = (self.stats_update_interval_ms - if stats_changed else 5000) + wait_for = self.stats_update_interval_ms if stats_changed else 5000 # Wait at least 50ms to ensure we've received all stats for # the current step. min_timeout = 50 if last_step_counts is None else 0 - events = poller.poll(timeout=max(min_timeout, wait_for - - elapsed)) + events = poller.poll(timeout=max(min_timeout, wait_for - elapsed)) if not events: # Poller timeout - publish current stats to front-ends. 
if last_step_counts is not None: @@ -212,8 +214,7 @@ class DPCoordinatorProc: engine_req_counts_list = self._get_engine_counts() stats_changed = False - to_publish = (engine_req_counts_list, current_wave, - engines_running) + to_publish = (engine_req_counts_list, current_wave, engines_running) publish_front.send(msgspec.msgpack.encode(to_publish)) last_publish_time = int(time.time() * 1000) continue @@ -223,13 +224,16 @@ class DPCoordinatorProc: if publish_front in events: buffer = publish_front.recv() - if buffer in (b'\x01', b'\x00'): + if buffer in (b"\x01", b"\x00"): # Ignore subscription messages. continue decoded = msgspec.msgpack.decode(buffer) - if isinstance(decoded, (list, tuple)) and len( - decoded) == 2 and decoded[0] == "SCALE_ELASTIC_EP": + if ( + isinstance(decoded, (list, tuple)) + and len(decoded) == 2 + and decoded[0] == "SCALE_ELASTIC_EP" + ): # Handle scale up notification new_engine_count = decoded[1] current_count = len(self.engines) @@ -248,13 +252,17 @@ class DPCoordinatorProc: # engine engines_running = False logger.info( - "DPCoordinator scaled up from %s to %s " - "engines", current_count, new_engine_count) + "DPCoordinator scaled up from %s to %s engines", + current_count, + new_engine_count, + ) else: self.engines = self.engines[:new_engine_count] logger.info( - "DPCoordinator scaled down from %s to %s " - "engines", current_count, new_engine_count) + "DPCoordinator scaled down from %s to %s engines", + current_count, + new_engine_count, + ) continue # Skip normal engine notification processing # We received a message on the front-end XPUB socket, @@ -270,8 +278,9 @@ class DPCoordinatorProc: engines_running = True wave_state_changed = True - self._send_start_wave(publish_back, current_wave, - engine_to_exclude) + self._send_start_wave( + publish_back, current_wave, engine_to_exclude + ) if output_back in events: # We received a message from one of the engines. @@ -290,21 +299,28 @@ class DPCoordinatorProc: stats = self.engines[eng_index].request_counts stats_step = scheduler_stats.step_counter stats_wave = scheduler_stats.current_wave - if (stats_wave > last_stats_wave - or stats_wave == last_stats_wave - and stats_step > last_stats_step): + if ( + stats_wave > last_stats_wave + or stats_wave == last_stats_wave + and stats_step > last_stats_step + ): if stats_changed: - last_step_counts = self._get_engine_counts( - do_copy=True) + last_step_counts = self._get_engine_counts(do_copy=True) last_stats_step = stats_step last_stats_wave = stats_wave elif stats_wave != last_stats_wave or ( - stats_step != last_stats_step): + stats_step != last_stats_step + ): logger.warning( "Received stats for out-of-order " "step (%d, %d) from engine %d (expected " - "> (%d, %d))", stats_wave, stats_step, - eng_index, last_stats_wave, last_stats_step) + "> (%d, %d))", + stats_wave, + stats_step, + eng_index, + last_stats_wave, + last_stats_step, + ) stats[0] = scheduler_stats.num_waiting_reqs stats[1] = scheduler_stats.num_running_reqs stats_changed = True @@ -315,20 +331,24 @@ class DPCoordinatorProc: # (engines_running==False). 
if current_wave <= wave: new_wave = wave + 1 - logger.debug("Moving DP wave from %d to %d.", - current_wave, new_wave) + logger.debug( + "Moving DP wave from %d to %d.", current_wave, new_wave + ) current_wave = new_wave engines_running = False wave_state_changed = True elif (wave := outputs.start_wave) is not None and ( - wave > current_wave or - (wave == current_wave and not engines_running)): + wave > current_wave + or (wave == current_wave and not engines_running) + ): # 3. The engine received request for a non-current wave # so we must ensure that other engines progress to the # next wave (race condition handling). logger.debug( "Starting wave %d after notification of " - "stale wave request from engine.", wave) + "stale wave request from engine.", + wave, + ) current_wave = wave engines_running = True wave_state_changed = True @@ -339,16 +359,16 @@ class DPCoordinatorProc: publish_front.send(msgspec.msgpack.encode(message)) @staticmethod - def _send_start_wave(socket: zmq.Socket, wave: int, - exclude_engine_index: Optional[int]): + def _send_start_wave( + socket: zmq.Socket, wave: int, exclude_engine_index: Optional[int] + ): """Broadcast the START_DP_WAVE message to all the engines. It includes the current wave number and index of engine which has already received a request with this wave number and so doesn't require additional notification. """ wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index)) - socket.send_multipart( - (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded)) + socket.send_multipart((EngineCoreRequestType.START_DP_WAVE.value, wave_encoded)) def _get_engine_counts(self, do_copy=False) -> list[list[int]]: """Return list of [waiting, running] count lists for each engine.""" diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 3ee804f10c..4826d7c589 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -25,25 +25,39 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import engine_receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask -from vllm.transformers_utils.config import ( - maybe_register_config_serialize_by_value) -from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket, - resolve_obj_by_qualname, set_process_title) +from vllm.transformers_utils.config import maybe_register_config_serialize_by_value +from vllm.utils import ( + decorate_logs, + get_hash_fn_by_name, + make_zmq_socket, + resolve_obj_by_qualname, + set_process_title, +) from vllm.utils.gc_utils import maybe_attach_gc_debug_callback -from vllm.v1.core.kv_cache_utils import (BlockHash, - generate_scheduler_kv_cache_config, - get_kv_cache_configs, - get_request_block_hasher, - init_none_hash) +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + generate_scheduler_kv_cache_config, + get_kv_cache_configs, + get_request_block_hasher, + init_none_hash, +) from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler -from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, - EngineCoreRequestType, - ReconfigureDistributedRequest, ReconfigureRankType, - UtilityOutput, UtilityResult) -from vllm.v1.engine.utils import (EngineHandshakeMetadata, EngineZmqAddresses, - get_device_indices) +from vllm.v1.engine import ( + EngineCoreOutputs, + EngineCoreRequest, + EngineCoreRequestType, + ReconfigureDistributedRequest, + 
ReconfigureRankType, + UtilityOutput, + UtilityResult, +) +from vllm.v1.engine.utils import ( + EngineHandshakeMetadata, + EngineZmqAddresses, + get_device_indices, +) from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats @@ -58,51 +72,56 @@ logger = init_logger(__name__) POLLING_TIMEOUT_S = 2.5 HANDSHAKE_TIMEOUT_MINS = 5 -_R = TypeVar('_R') # Return type for collective_rpc +_R = TypeVar("_R") # Return type for collective_rpc class EngineCore: """Inner loop of vLLM's Engine.""" - def __init__(self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - executor_fail_callback: Optional[Callable] = None): - + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + executor_fail_callback: Optional[Callable] = None, + ): # plugins need to be loaded at the engine/scheduler level too from vllm.plugins import load_general_plugins + load_general_plugins() self.vllm_config = vllm_config - logger.info("Initializing a V1 LLM engine (v%s) with config: %s", - VLLM_VERSION, vllm_config) + logger.info( + "Initializing a V1 LLM engine (v%s) with config: %s", + VLLM_VERSION, + vllm_config, + ) self.log_stats = log_stats # Setup Model. self.model_executor = executor_class(vllm_config) if executor_fail_callback is not None: - self.model_executor.register_failure_callback( - executor_fail_callback) + self.model_executor.register_failure_callback(executor_fail_callback) self.available_gpu_memory_for_kv_cache = -1 # Setup KV Caches and update CacheConfig after profiling. - num_gpu_blocks, num_cpu_blocks, kv_cache_config = \ - self._initialize_kv_caches(vllm_config) + num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches( + vllm_config + ) vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks - self.collective_rpc("initialize_cache", - args=(num_gpu_blocks, num_cpu_blocks)) + self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) self.structured_output_manager = StructuredOutputManager(vllm_config) # Setup scheduler. if isinstance(vllm_config.scheduler_config.scheduler_cls, str): Scheduler = resolve_obj_by_qualname( - vllm_config.scheduler_config.scheduler_cls) + vllm_config.scheduler_config.scheduler_cls + ) else: Scheduler = vllm_config.scheduler_config.scheduler_cls @@ -114,7 +133,8 @@ class EngineCore: "Using configured V1 scheduler class %s. 
" "This scheduler interface is not public and " "compatibility may not be maintained.", - vllm_config.scheduler_config.scheduler_cls) + vllm_config.scheduler_config.scheduler_cls, + ) if len(kv_cache_config.kv_cache_groups) == 0: # Encoder models without KV cache don't support @@ -126,49 +146,54 @@ class EngineCore: vllm_config=vllm_config, kv_cache_config=kv_cache_config, structured_output_manager=self.structured_output_manager, - include_finished_set=vllm_config.parallel_config.data_parallel_size - > 1, + include_finished_set=vllm_config.parallel_config.data_parallel_size > 1, log_stats=self.log_stats, ) self.use_spec_decode = vllm_config.speculative_config is not None if self.scheduler.connector is not None: # type: ignore self.model_executor.init_kv_output_aggregator( - self.scheduler.connector.get_finished_count()) # type: ignore + self.scheduler.connector.get_finished_count() # type: ignore + ) self.mm_registry = mm_registry = MULTIMODAL_REGISTRY self.mm_receiver_cache = engine_receiver_cache_from_config( - vllm_config, mm_registry) + vllm_config, mm_registry + ) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. This enables us to asynchronously # schedule and execute batches, and is required by pipeline parallelism # to eliminate pipeline bubbles. self.batch_queue_size = self.model_executor.max_concurrent_batches - self.batch_queue: Optional[deque[tuple[Future[ModelRunnerOutput], - SchedulerOutput]]] = None + self.batch_queue: Optional[ + deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] + ] = None if self.batch_queue_size > 1: - logger.info("Batch queue is enabled with size %d", - self.batch_queue_size) + logger.info("Batch queue is enabled with size %d", self.batch_queue_size) self.batch_queue = deque(maxlen=self.batch_queue_size) - self.request_block_hasher: Optional[Callable[[Request], - list[BlockHash]]] = None - if (self.vllm_config.cache_config.enable_prefix_caching - or self.scheduler.get_kv_connector() is not None): - + self.request_block_hasher: Optional[Callable[[Request], list[BlockHash]]] = None + if ( + self.vllm_config.cache_config.enable_prefix_caching + or self.scheduler.get_kv_connector() is not None + ): block_size = vllm_config.cache_config.block_size caching_hash_fn = get_hash_fn_by_name( - vllm_config.cache_config.prefix_caching_hash_algo) + vllm_config.cache_config.prefix_caching_hash_algo + ) init_none_hash(caching_hash_fn) self.request_block_hasher = get_request_block_hasher( - block_size, caching_hash_fn) + block_size, caching_hash_fn + ) - self.step_fn = (self.step if self.batch_queue is None else - self.step_with_batch_queue) + self.step_fn = ( + self.step if self.batch_queue is None else self.step_with_batch_queue + ) def _initialize_kv_caches( - self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: + self, vllm_config: VllmConfig + ) -> tuple[int, int, KVCacheConfig]: start = time.time() # Get all kv cache needed by the model @@ -179,28 +204,27 @@ class EngineCore: if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1": dp_group = getattr(self, "dp_group", None) assert dp_group is not None - self.available_gpu_memory_for_kv_cache = \ + self.available_gpu_memory_for_kv_cache = ( ParallelConfig.sync_kv_cache_memory_size(dp_group, -1) - available_gpu_memory = [ - self.available_gpu_memory_for_kv_cache - ] * len(kv_cache_specs) + ) + available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len( + kv_cache_specs + ) else: # Profiles the peak memory usage of the model to determine how # much 
memory can be allocated for kv cache. - available_gpu_memory = ( - self.model_executor.determine_available_memory()) - self.available_gpu_memory_for_kv_cache = \ - available_gpu_memory[0] + available_gpu_memory = self.model_executor.determine_available_memory() + self.available_gpu_memory_for_kv_cache = available_gpu_memory[0] else: # Attention free models don't need memory for kv cache available_gpu_memory = [0] * len(kv_cache_specs) assert len(kv_cache_specs) == len(available_gpu_memory) - kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs, - available_gpu_memory) - scheduler_kv_cache_config = generate_scheduler_kv_cache_config( - kv_cache_configs) + kv_cache_configs = get_kv_cache_configs( + vllm_config, kv_cache_specs, available_gpu_memory + ) + scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) num_gpu_blocks = scheduler_kv_cache_config.num_blocks num_cpu_blocks = 0 @@ -208,8 +232,10 @@ class EngineCore: self.model_executor.initialize_from_config(kv_cache_configs) elapsed = time.time() - start - logger.info(("init engine (profile, create kv cache, " - "warmup model) took %.2f seconds"), elapsed) + logger.info( + ("init engine (profile, create kv cache, warmup model) took %.2f seconds"), + elapsed, + ) return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config def get_supported_tasks(self) -> tuple[SupportedTask, ...]: @@ -224,22 +250,27 @@ class EngineCore: # Validate the request_id type. if not isinstance(request.request_id, str): raise TypeError( - f"request_id must be a string, got {type(request.request_id)}") + f"request_id must be a string, got {type(request.request_id)}" + ) if pooling_params := request.pooling_params: supported_pooling_tasks = [ - task for task in self.get_supported_tasks() - if task in POOLING_TASKS + task for task in self.get_supported_tasks() if task in POOLING_TASKS ] if pooling_params.task not in supported_pooling_tasks: - raise ValueError(f"Unsupported task: {pooling_params.task!r} " - f"Supported tasks: {supported_pooling_tasks}") + raise ValueError( + f"Unsupported task: {pooling_params.task!r} " + f"Supported tasks: {supported_pooling_tasks}" + ) if request.kv_transfer_params is not None and ( - not self.scheduler.get_kv_connector()): - logger.warning("Got kv_transfer_params, but no KVConnector found. " - "Disabling KVTransfer for this request.") + not self.scheduler.get_kv_connector() + ): + logger.warning( + "Got kv_transfer_params, but no KVConnector found. " + "Disabling KVTransfer for this request." + ) self.scheduler.add_request(request) @@ -249,8 +280,7 @@ class EngineCore: # TODO: The scheduler doesn't really need to know the # specific finish reason, TBD whether we propagate that # (i.e. client-aborted vs stop criteria met). - self.scheduler.finish_requests(request_ids, - RequestStatus.FINISHED_ABORTED) + self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) def execute_model_with_error_logging( self, @@ -266,8 +296,9 @@ class EngineCore: # error from execute_model itself. 
# NOTE: This method is exception-free - dump_engine_exception(self.vllm_config, scheduler_output, - self.scheduler.make_stats()) + dump_engine_exception( + self.vllm_config, scheduler_output, self.scheduler.make_stats() + ) raise err def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: @@ -284,12 +315,13 @@ class EngineCore: scheduler_output = self.scheduler.schedule() model_output = self.execute_model_with_error_logging( self.model_executor.execute_model, # type: ignore - scheduler_output) + scheduler_output, + ) engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output) # type: ignore + scheduler_output, model_output + ) # type: ignore - return (engine_core_outputs, - scheduler_output.total_num_scheduled_tokens > 0) + return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0) def post_step(self, model_executed: bool) -> None: if self.use_spec_decode and model_executed: @@ -299,7 +331,8 @@ class EngineCore: self.scheduler.update_draft_token_ids(draft_token_ids) def step_with_batch_queue( - self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]: + self, + ) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]: """Schedule and execute batches with the batch queue. Note that if nothing to output in this step, None is returned. @@ -324,14 +357,15 @@ class EngineCore: model_executed = False if self.scheduler.has_requests(): scheduler_output = self.scheduler.schedule() - future = self.model_executor.execute_model(scheduler_output, - non_block=True) - batch_queue.appendleft( - (future, scheduler_output)) # type: ignore[arg-type] + future = self.model_executor.execute_model(scheduler_output, non_block=True) + batch_queue.appendleft((future, scheduler_output)) # type: ignore[arg-type] model_executed = scheduler_output.total_num_scheduled_tokens > 0 - if model_executed and len(batch_queue) < self.batch_queue_size \ - and not batch_queue[-1][0].done(): + if ( + model_executed + and len(batch_queue) < self.batch_queue_size + and not batch_queue[-1][0].done() + ): # Don't block on next worker response unless the queue is full # or there are no more requests to schedule. return None, True @@ -345,10 +379,12 @@ class EngineCore: # Block until the next result is available. future, scheduler_output = batch_queue.pop() model_output = self.execute_model_with_error_logging( - lambda _: future.result(), scheduler_output) + lambda _: future.result(), scheduler_output + ) engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output) + scheduler_output, model_output + ) return engine_core_outputs, model_executed @@ -366,8 +402,10 @@ class EngineCore: # NOTE: Since this is mainly for debugging, we don't attempt to # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror) if self.scheduler.has_unfinished_requests(): - logger.warning("Resetting the multi-modal cache when requests are " - "in progress may lead to desynced internal caches.") + logger.warning( + "Resetting the multi-modal cache when requests are " + "in progress may lead to desynced internal caches." 
+ ) if self.mm_receiver_cache is not None: self.mm_receiver_cache.clear_cache() @@ -405,27 +443,28 @@ class EngineCore: pattern: Optional[str] = None, max_size: Optional[int] = None, ) -> None: - self.model_executor.save_sharded_state(path=path, - pattern=pattern, - max_size=max_size) + self.model_executor.save_sharded_state( + path=path, pattern=pattern, max_size=max_size + ) - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: - return self.model_executor.collective_rpc(method, timeout, args, - kwargs) + def collective_rpc( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: + return self.model_executor.collective_rpc(method, timeout, args, kwargs) def save_tensorized_model( self, tensorizer_config, ) -> None: self.model_executor.save_tensorized_model( - tensorizer_config=tensorizer_config, ) + tensorizer_config=tensorizer_config, + ) - def preprocess_add_request( - self, request: EngineCoreRequest) -> tuple[Request, int]: + def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]: """Preprocess the request. This function could be directly used in input processing thread to allow @@ -435,12 +474,11 @@ class EngineCore: # `mm_receiver_cache` is reset at the end of LLMEngine init, # and will only be accessed in the input processing thread afterwards. if self.mm_receiver_cache is not None and request.mm_features: - request.mm_features = ( - self.mm_receiver_cache.get_and_update_features( - request.mm_features)) + request.mm_features = self.mm_receiver_cache.get_and_update_features( + request.mm_features + ) - req = Request.from_engine_core_request(request, - self.request_block_hasher) + req = Request.from_engine_core_request(request, self.request_block_hasher) if req.use_structured_output: # Note on thread safety: no race condition. # `grammar_init` is only invoked in input processing thread. For @@ -454,7 +492,7 @@ class EngineCore: class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" - ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD' + ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD" def __init__( self, @@ -467,37 +505,46 @@ class EngineCoreProc(EngineCore): engine_index: int = 0, ): self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() - self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], - bytes]]() + self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], bytes]]() executor_fail_callback = lambda: self.input_queue.put_nowait( - (EngineCoreRequestType.EXECUTOR_FAILED, b'')) + (EngineCoreRequestType.EXECUTOR_FAILED, b"") + ) self.engine_index = engine_index identity = self.engine_index.to_bytes(length=2, byteorder="little") self.engines_running = False - with self._perform_handshakes(handshake_address, identity, - local_client, vllm_config, - client_handshake_address) as addresses: + with self._perform_handshakes( + handshake_address, + identity, + local_client, + vllm_config, + client_handshake_address, + ) as addresses: self.client_count = len(addresses.outputs) # Set up data parallel environment. 
self.has_coordinator = addresses.coordinator_output is not None self.frontend_stats_publish_address = ( - addresses.frontend_stats_publish_address) - logger.debug("Has DP Coordinator: %s, stats publish address: %s", - self.has_coordinator, - self.frontend_stats_publish_address) + addresses.frontend_stats_publish_address + ) + logger.debug( + "Has DP Coordinator: %s, stats publish address: %s", + self.has_coordinator, + self.frontend_stats_publish_address, + ) # Only publish request queue stats to coordinator for "internal" # and "hybrid" LB modes . self.publish_dp_lb_stats = ( self.has_coordinator - and not vllm_config.parallel_config.data_parallel_external_lb) + and not vllm_config.parallel_config.data_parallel_external_lb + ) self._init_data_parallel(vllm_config) - super().__init__(vllm_config, executor_class, log_stats, - executor_fail_callback) + super().__init__( + vllm_config, executor_class, log_stats, executor_fail_callback + ) # Background Threads and Queues for IO. These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, @@ -505,26 +552,34 @@ class EngineCoreProc(EngineCore): # model forward pass. # Threads handle Socket <-> Queues and core_busy_loop uses Queue. ready_event = threading.Event() - input_thread = threading.Thread(target=self.process_input_sockets, - args=(addresses.inputs, - addresses.coordinator_input, - identity, ready_event), - daemon=True) + input_thread = threading.Thread( + target=self.process_input_sockets, + args=( + addresses.inputs, + addresses.coordinator_input, + identity, + ready_event, + ), + daemon=True, + ) input_thread.start() self.output_thread = threading.Thread( target=self.process_output_sockets, - args=(addresses.outputs, addresses.coordinator_output, - self.engine_index), - daemon=True) + args=( + addresses.outputs, + addresses.coordinator_output, + self.engine_index, + ), + daemon=True, + ) self.output_thread.start() # Don't complete handshake until DP coordinator ready message is # received. 
while not ready_event.wait(timeout=10): if not input_thread.is_alive(): - raise RuntimeError( - "Input socket thread died during startup") + raise RuntimeError("Input socket thread died during startup") assert addresses.coordinator_input is not None logger.info("Waiting for READY message from DP Coordinator...") @@ -570,18 +625,23 @@ class EngineCoreProc(EngineCore): input_ctx = zmq.Context() is_local = local_client and client_handshake_address is None headless = not local_client - handshake = self._perform_handshake(input_ctx, handshake_address, - identity, is_local, headless, - vllm_config, - vllm_config.parallel_config) + handshake = self._perform_handshake( + input_ctx, + handshake_address, + identity, + is_local, + headless, + vllm_config, + vllm_config.parallel_config, + ) if client_handshake_address is None: with handshake as addresses: yield addresses else: assert local_client local_handshake = self._perform_handshake( - input_ctx, client_handshake_address, identity, True, False, - vllm_config) + input_ctx, client_handshake_address, identity, True, False, vllm_config + ) with handshake as addresses, local_handshake as client_addresses: addresses.inputs = client_addresses.inputs addresses.outputs = client_addresses.outputs @@ -601,16 +661,18 @@ class EngineCoreProc(EngineCore): vllm_config: VllmConfig, parallel_config_to_update: Optional[ParallelConfig] = None, ) -> Generator[EngineZmqAddresses, None, None]: - with make_zmq_socket(ctx, - handshake_address, - zmq.DEALER, - identity=identity, - linger=5000, - bind=False) as handshake_socket: + with make_zmq_socket( + ctx, + handshake_address, + zmq.DEALER, + identity=identity, + linger=5000, + bind=False, + ) as handshake_socket: # Register engine with front-end. - addresses = self.startup_handshake(handshake_socket, local_client, - headless, - parallel_config_to_update) + addresses = self.startup_handshake( + handshake_socket, local_client, headless, parallel_config_to_update + ) yield addresses # Send ready message. @@ -620,13 +682,16 @@ class EngineCoreProc(EngineCore): # only runs with rank 0). dp_stats_address = self.frontend_stats_publish_address handshake_socket.send( - msgspec.msgpack.encode({ - "status": "READY", - "local": local_client, - "headless": headless, - "num_gpu_blocks": num_gpu_blocks, - "dp_stats_address": dp_stats_address, - })) + msgspec.msgpack.encode( + { + "status": "READY", + "local": local_client, + "headless": headless, + "num_gpu_blocks": num_gpu_blocks, + "dp_stats_address": dp_stats_address, + } + ) + ) @staticmethod def startup_handshake( @@ -635,24 +700,29 @@ class EngineCoreProc(EngineCore): headless: bool, parallel_config: Optional[ParallelConfig] = None, ) -> EngineZmqAddresses: - # Send registration message. handshake_socket.send( - msgspec.msgpack.encode({ - "status": "HELLO", - "local": local_client, - "headless": headless, - })) + msgspec.msgpack.encode( + { + "status": "HELLO", + "local": local_client, + "headless": headless, + } + ) + ) # Receive initialization message. 
logger.info("Waiting for init message from front-end.") if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000): - raise RuntimeError("Did not receive response from front-end " - f"process within {HANDSHAKE_TIMEOUT_MINS} " - f"minutes") + raise RuntimeError( + "Did not receive response from front-end " + f"process within {HANDSHAKE_TIMEOUT_MINS} " + f"minutes" + ) init_bytes = handshake_socket.recv() init_message: EngineHandshakeMetadata = msgspec.msgpack.decode( - init_bytes, type=EngineHandshakeMetadata) + init_bytes, type=EngineHandshakeMetadata + ) logger.debug("Received init message: %s", init_message) if parallel_config is not None: @@ -662,10 +732,7 @@ class EngineCoreProc(EngineCore): return init_message.addresses @staticmethod - def run_engine_core(*args, - dp_rank: int = 0, - local_dp_rank: int = 0, - **kwargs): + def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs): """Launch EngineCore busy loop in background process.""" # Signal handler used for graceful termination. @@ -688,8 +755,7 @@ class EngineCoreProc(EngineCore): engine_core: Optional[EngineCoreProc] = None try: - parallel_config: ParallelConfig = kwargs[ - "vllm_config"].parallel_config + parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config if parallel_config.data_parallel_size > 1 or dp_rank > 0: set_process_title("EngineCore", f"DP{dp_rank}") decorate_logs() @@ -735,8 +801,11 @@ class EngineCoreProc(EngineCore): """Exits when an engine step needs to be performed.""" waited = False - while not self.engines_running and not self.scheduler.has_requests() \ - and not self.batch_queue: + while ( + not self.engines_running + and not self.scheduler.has_requests() + and not self.batch_queue + ): if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): logger.debug("EngineCore waiting for work.") waited = True @@ -757,15 +826,16 @@ class EngineCoreProc(EngineCore): # Step the engine core. outputs, model_executed = self.step_fn() # Put EngineCoreOutputs into the output queue. - for output in (outputs.items() if outputs else ()): + for output in outputs.items() if outputs else (): self.output_queue.put_nowait(output) # Post-step hook. 
self.post_step(model_executed) return model_executed - def _handle_client_request(self, request_type: EngineCoreRequestType, - request: Any) -> None: + def _handle_client_request( + self, request_type: EngineCoreRequestType, request: Any + ) -> None: """Dispatch request from client.""" if request_type == EngineCoreRequestType.ADD: @@ -782,29 +852,35 @@ class EngineCoreProc(EngineCore): output.result = UtilityResult(result) except BaseException as e: logger.exception("Invocation of %s method failed", method_name) - output.failure_message = (f"Call to {method_name} method" - f" failed: {str(e)}") + output.failure_message = ( + f"Call to {method_name} method failed: {str(e)}" + ) self.output_queue.put_nowait( - (client_idx, EngineCoreOutputs(utility_output=output))) + (client_idx, EngineCoreOutputs(utility_output=output)) + ) elif request_type == EngineCoreRequestType.EXECUTOR_FAILED: raise RuntimeError("Executor failed.") else: - logger.error("Unrecognized input request type encountered: %s", - request_type) + logger.error( + "Unrecognized input request type encountered: %s", request_type + ) @staticmethod def _convert_msgspec_args(method, args): """If a provided arg type doesn't match corresponding target method - arg type, try converting to msgspec object.""" + arg type, try converting to msgspec object.""" if not args: return args arg_types = signature(method).parameters.values() assert len(args) <= len(arg_types) return tuple( - msgspec.convert(v, type=p.annotation) if isclass(p.annotation) + msgspec.convert(v, type=p.annotation) + if isclass(p.annotation) and issubclass(p.annotation, msgspec.Struct) - and not isinstance(v, p.annotation) else v - for v, p in zip(args, arg_types)) + and not isinstance(v, p.annotation) + else v + for v, p in zip(args, arg_types) + ) def _send_engine_dead(self): """Send EngineDead status to the EngineCoreClient.""" @@ -815,12 +891,18 @@ class EngineCoreProc(EngineCore): # Wait until msg sent by the daemon before shutdown. self.output_thread.join(timeout=5.0) if self.output_thread.is_alive(): - logger.fatal("vLLM shutdown signal from EngineCore failed " - "to send. Please report this issue.") + logger.fatal( + "vLLM shutdown signal from EngineCore failed " + "to send. Please report this issue." + ) - def process_input_sockets(self, input_addresses: list[str], - coord_input_address: Optional[str], - identity: bytes, ready_event: threading.Event): + def process_input_sockets( + self, + input_addresses: list[str], + coord_input_address: Optional[str], + identity: bytes, + ready_event: threading.Event, + ): """Input socket IO thread.""" # Msgpack serialization decoding. @@ -830,24 +912,26 @@ class EngineCoreProc(EngineCore): with ExitStack() as stack, zmq.Context() as ctx: input_sockets = [ stack.enter_context( - make_zmq_socket(ctx, - input_address, - zmq.DEALER, - identity=identity, - bind=False)) + make_zmq_socket( + ctx, input_address, zmq.DEALER, identity=identity, bind=False + ) + ) for input_address in input_addresses ] if coord_input_address is None: coord_socket = None else: coord_socket = stack.enter_context( - make_zmq_socket(ctx, - coord_input_address, - zmq.XSUB, - identity=identity, - bind=False)) + make_zmq_socket( + ctx, + coord_input_address, + zmq.XSUB, + identity=identity, + bind=False, + ) + ) # Send subscription message to coordinator. - coord_socket.send(b'\x01') + coord_socket.send(b"\x01") # Register sockets with poller. 
poller = zmq.Poller() @@ -855,7 +939,7 @@ class EngineCoreProc(EngineCore): # Send initial message to each input socket - this is required # before the front-end ROUTER socket can send input messages # back to us. - input_socket.send(b'') + input_socket.send(b"") poller.register(input_socket, zmq.POLLIN) if coord_socket is not None: @@ -868,10 +952,8 @@ class EngineCoreProc(EngineCore): while True: for input_socket, _ in poller.poll(): # (RequestType, RequestData) - type_frame, *data_frames = input_socket.recv_multipart( - copy=False) - request_type = EngineCoreRequestType( - bytes(type_frame.buffer)) + type_frame, *data_frames = input_socket.recv_multipart(copy=False) + request_type = EngineCoreRequestType(bytes(type_frame.buffer)) # Deserialize the request data. if request_type == EngineCoreRequestType.ADD: @@ -883,9 +965,12 @@ class EngineCoreProc(EngineCore): # Push to input queue for core busy loop. self.input_queue.put_nowait((request_type, request)) - def process_output_sockets(self, output_paths: list[str], - coord_output_path: Optional[str], - engine_index: int): + def process_output_sockets( + self, + output_paths: list[str], + coord_output_path: Optional[str], + engine_index: int, + ): """Output socket IO thread.""" # Msgpack serialization encoding. @@ -902,13 +987,19 @@ class EngineCoreProc(EngineCore): with ExitStack() as stack, zmq.Context() as ctx: sockets = [ stack.enter_context( - make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)) + make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000) + ) for output_path in output_paths ] - coord_socket = stack.enter_context( - make_zmq_socket( - ctx, coord_output_path, zmq.PUSH, bind=False, - linger=4000)) if coord_output_path is not None else None + coord_socket = ( + stack.enter_context( + make_zmq_socket( + ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000 + ) + ) + if coord_output_path is not None + else None + ) max_reuse_bufs = len(sockets) + 1 while True: @@ -934,9 +1025,9 @@ class EngineCoreProc(EngineCore): buffer = reuse_buffers.pop() if reuse_buffers else bytearray() buffers = encoder.encode_into(outputs, buffer) - tracker = sockets[client_index].send_multipart(buffers, - copy=False, - track=True) + tracker = sockets[client_index].send_multipart( + buffers, copy=False, track=True + ) if not tracker.done: ref = outputs if len(buffers) > 1 else None pending.appendleft((tracker, ref, buffer)) @@ -966,12 +1057,17 @@ class DPEngineCoreProc(EngineCoreProc): # Initialize the engine. dp_rank = vllm_config.parallel_config.data_parallel_rank - super().__init__(vllm_config, local_client, handshake_address, - executor_class, log_stats, client_handshake_address, - dp_rank) + super().__init__( + vllm_config, + local_client, + handshake_address, + executor_class, + log_stats, + client_handshake_address, + dp_rank, + ) def _init_data_parallel(self, vllm_config: VllmConfig): - # Configure GPUs and stateless process group for data parallel. 
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_size = vllm_config.parallel_config.data_parallel_size @@ -986,8 +1082,10 @@ class DPEngineCoreProc(EngineCoreProc): vllm_config.kv_transfer_config.engine_id = ( f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}" ) - logger.debug("Setting kv_transfer_config.engine_id to %s", - vllm_config.kv_transfer_config.engine_id) + logger.debug( + "Setting kv_transfer_config.engine_id to %s", + vllm_config.kv_transfer_config.engine_id, + ) self.dp_rank = dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() @@ -1005,20 +1103,22 @@ class DPEngineCoreProc(EngineCoreProc): # Request received for an already-completed wave, notify # front-end that we need to start the next one. self.output_queue.put_nowait( - (-1, EngineCoreOutputs(start_wave=self.current_wave))) + (-1, EngineCoreOutputs(start_wave=self.current_wave)) + ) super().add_request(request, request_wave) - def _handle_client_request(self, request_type: EngineCoreRequestType, - request: Any) -> None: + def _handle_client_request( + self, request_type: EngineCoreRequestType, request: Any + ) -> None: if request_type == EngineCoreRequestType.START_DP_WAVE: new_wave, exclude_eng_index = request if exclude_eng_index != self.engine_index and ( - new_wave >= self.current_wave): + new_wave >= self.current_wave + ): self.current_wave = new_wave if not self.engines_running: - logger.debug("EngineCore starting idle loop for wave %d.", - new_wave) + logger.debug("EngineCore starting idle loop for wave %d.", new_wave) self.engines_running = True else: super()._handle_client_request(request_type, request) @@ -1031,11 +1131,10 @@ class DPEngineCoreProc(EngineCoreProc): counts = self.scheduler.get_request_counts() if counts != self.last_counts: self.last_counts = counts - stats = SchedulerStats(*counts, - step_counter=self.step_counter, - current_wave=self.current_wave) - self.output_queue.put_nowait( - (-1, EngineCoreOutputs(scheduler_stats=stats))) + stats = SchedulerStats( + *counts, step_counter=self.step_counter, current_wave=self.current_wave + ) + self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats))) def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" @@ -1061,58 +1160,65 @@ class DPEngineCoreProc(EngineCoreProc): # 3) All-reduce operation to determine global unfinished reqs. self.engines_running = self._has_global_unfinished_reqs( - local_unfinished_reqs) + local_unfinished_reqs + ) if not self.engines_running: if self.dp_rank == 0 or not self.has_coordinator: # Notify client that we are pausing the loop. - logger.debug("Wave %d finished, pausing engine loop.", - self.current_wave) + logger.debug( + "Wave %d finished, pausing engine loop.", self.current_wave + ) # In the coordinator case, dp rank 0 sends updates to the # coordinator. Otherwise (offline spmd case), each rank # sends the update to its colocated front-end process. client_index = -1 if self.has_coordinator else 0 self.output_queue.put_nowait( - (client_index, - EngineCoreOutputs(wave_complete=self.current_wave))) + ( + client_index, + EngineCoreOutputs(wave_complete=self.current_wave), + ) + ) # Increment wave count and reset step counter. self.current_wave += 1 self.step_counter = 0 def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: - # Optimization - only perform finish-sync all-reduce every 32 steps. 
self.step_counter += 1 if self.step_counter % 32 != 0: return True - return ParallelConfig.has_unfinished_dp(self.dp_group, - local_unfinished) + return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest) -> None: + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: stateless_destroy_torch_distributed_process_group(self.dp_group) self.shutdown() parallel_config = self.vllm_config.parallel_config old_dp_size = parallel_config.data_parallel_size - parallel_config.data_parallel_size = \ - reconfig_request.new_data_parallel_size + parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size if reconfig_request.new_data_parallel_rank != -1: - parallel_config.data_parallel_rank = \ - reconfig_request.new_data_parallel_rank + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank # local rank specifies device visibility, it should not be changed - assert reconfig_request.new_data_parallel_rank_local == \ - ReconfigureRankType.KEEP_CURRENT_RANK - parallel_config.data_parallel_master_ip = \ + assert ( + reconfig_request.new_data_parallel_rank_local + == ReconfigureRankType.KEEP_CURRENT_RANK + ) + parallel_config.data_parallel_master_ip = ( reconfig_request.new_data_parallel_master_ip - parallel_config.data_parallel_master_port = \ + ) + parallel_config.data_parallel_master_port = ( reconfig_request.new_data_parallel_master_port + ) if reconfig_request.new_data_parallel_rank != -2: self.dp_rank = parallel_config.data_parallel_rank self.dp_group = parallel_config.stateless_init_dp_group() - reconfig_request.new_data_parallel_master_port = \ + reconfig_request.new_data_parallel_master_port = ( parallel_config.data_parallel_master_port + ) self.model_executor.reinitialize_distributed(reconfig_request) if reconfig_request.new_data_parallel_size > old_dp_size: @@ -1121,17 +1227,21 @@ class DPEngineCoreProc(EngineCoreProc): # engine-cores to new engine-cores so they can directly # use it in _initialize_kv_caches() rather than profiling. ParallelConfig.sync_kv_cache_memory_size( - self.dp_group, self.available_gpu_memory_for_kv_cache) + self.dp_group, self.available_gpu_memory_for_kv_cache + ) # NOTE(yongji): newly joined workers require dummy_run even # CUDA graph is not used self.model_executor.collective_rpc("compile_or_warm_up_model") - if reconfig_request.new_data_parallel_rank == \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + if ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ): self.shutdown() logger.info("DPEngineCoreProc %s shutdown", self.dp_rank) else: - logger.info("Distributed environment reinitialized for DP rank %s", - self.dp_rank) + logger.info( + "Distributed environment reinitialized for DP rank %s", self.dp_rank + ) class DPEngineCoreActor(DPEngineCoreProc): @@ -1151,8 +1261,7 @@ class DPEngineCoreActor(DPEngineCoreProc): ): self.addresses = addresses vllm_config.parallel_config.data_parallel_rank = dp_rank - vllm_config.parallel_config.data_parallel_rank_local = \ - local_dp_rank + vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time, @@ -1173,39 +1282,46 @@ class DPEngineCoreActor(DPEngineCoreProc): # of ray. 
self._set_visible_devices(vllm_config, local_dp_rank) - super().__init__(vllm_config, local_client, "", executor_class, - log_stats) + super().__init__(vllm_config, local_client, "", executor_class, log_stats) - def _set_visible_devices(self, vllm_config: VllmConfig, - local_dp_rank: int): + def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int): from vllm.platforms import current_platform + if current_platform.is_xpu(): pass else: device_control_env_var = current_platform.device_control_env_var - self._set_cuda_visible_devices(vllm_config, local_dp_rank, - device_control_env_var) + self._set_cuda_visible_devices( + vllm_config, local_dp_rank, device_control_env_var + ) - def _set_cuda_visible_devices(self, vllm_config: VllmConfig, - local_dp_rank: int, - device_control_env_var: str): + def _set_cuda_visible_devices( + self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str + ): world_size = vllm_config.parallel_config.world_size # Set CUDA_VISIBLE_DEVICES or equivalent. try: - value = get_device_indices(device_control_env_var, local_dp_rank, - world_size) + value = get_device_indices( + device_control_env_var, local_dp_rank, world_size + ) os.environ[device_control_env_var] = value except IndexError as e: raise Exception( f"Error setting {device_control_env_var}: " f"local range: [{local_dp_rank * world_size}, " f"{(local_dp_rank + 1) * world_size}) " - f"base value: \"{os.getenv(device_control_env_var)}\"") from e + f'base value: "{os.getenv(device_control_env_var)}"' + ) from e @contextmanager - def _perform_handshakes(self, handshake_address: str, identity: bytes, - local_client: bool, vllm_config: VllmConfig, - client_handshake_address: Optional[str]): + def _perform_handshakes( + self, + handshake_address: str, + identity: bytes, + local_client: bool, + vllm_config: VllmConfig, + client_handshake_address: Optional[str], + ): """ For Ray, we don't need to actually perform handshake. All addresses information is known before the actor creation. 
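Note: the hunks above and below are mechanical output of the ruff formatter replacing the previous yapf-style layout; no runtime behavior changes are intended. As a rough illustration of the kind of rewrite involved — a hypothetical snippet, not taken from this patch; `some_client.send_request` and its arguments are made-up names — a yapf-style aligned call that no longer fits in 88 columns, plus a backslash continuation, such as:

    result = some_client.send_request(request_identifier,
                                      payload_bytes,
                                      retry_on_failure=True,
                                      timeout_seconds=30)
    total_blocks = \
        num_gpu_blocks + num_cpu_blocks

comes out of ruff format as a hanging-indent call with a trailing comma, with the continuation joined onto one line (and, as seen throughout the patch, single-quoted literals normalized to double quotes):

    result = some_client.send_request(
        request_identifier,
        payload_bytes,
        retry_on_failure=True,
        timeout_seconds=30,
    )
    total_blocks = num_gpu_blocks + num_cpu_blocks

The same style accounts for the other recurring patterns in these hunks: single-element tuples written as `(x,)` without an inner space, and multiple context managers grouped under one parenthesized `with (...)` block instead of a backslash- or comma-continued line.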
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index a84b0e5510..27283411ea 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -23,17 +23,29 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.tasks import SupportedTask -from vllm.utils import (close_sockets, get_open_port, get_open_zmq_inproc_path, - in_loop, make_zmq_socket) -from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, - EngineCoreRequestType, - ReconfigureDistributedRequest, ReconfigureRankType, - UtilityOutput) +from vllm.utils import ( + close_sockets, + get_open_port, + get_open_zmq_inproc_path, + in_loop, + make_zmq_socket, +) +from vllm.v1.engine import ( + EngineCoreOutputs, + EngineCoreRequest, + EngineCoreRequestType, + ReconfigureDistributedRequest, + ReconfigureRankType, + UtilityOutput, +) from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError -from vllm.v1.engine.utils import (CoreEngineActorManager, - CoreEngineProcManager, launch_core_engines) +from vllm.v1.engine.utils import ( + CoreEngineActorManager, + CoreEngineProcManager, + launch_core_engines, +) from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr @@ -41,14 +53,14 @@ logger = init_logger(__name__) AnyFuture = Union[asyncio.Future[Any], Future[Any]] -_R = TypeVar('_R') # Return type for collective_rpc +_R = TypeVar("_R") # Return type for collective_rpc EngineIdentity = bytes class EngineCoreClient(ABC): """ - EngineCoreClient: subclasses handle different methods for pushing + EngineCoreClient: subclasses handle different methods for pushing and pulling from the EngineCore for asyncio / multiprocessing. Subclasses: @@ -65,16 +77,17 @@ class EngineCoreClient(ABC): executor_class: type[Executor], log_stats: bool, ) -> "EngineCoreClient": - # TODO: support this for debugging purposes. if asyncio_mode and not multiprocess_mode: raise NotImplementedError( "Running EngineCore in asyncio without multiprocessing " - "is not currently supported.") + "is not currently supported." + ) if multiprocess_mode and asyncio_mode: return EngineCoreClient.make_async_mp_client( - vllm_config, executor_class, log_stats) + vllm_config, executor_class, log_stats + ) if multiprocess_mode and not asyncio_mode: return SyncMPClient(vllm_config, executor_class, log_stats) @@ -91,8 +104,14 @@ class EngineCoreClient(ABC): client_index: int = 0, ) -> "MPClient": parallel_config = vllm_config.parallel_config - client_args = (vllm_config, executor_class, log_stats, - client_addresses, client_count, client_index) + client_args = ( + vllm_config, + executor_class, + log_stats, + client_addresses, + client_count, + client_index, + ) if parallel_config.data_parallel_size > 1: if parallel_config.data_parallel_external_lb: # External load balancer - client per DP rank. @@ -102,8 +121,7 @@ class EngineCoreClient(ABC): return AsyncMPClient(*client_args) @abstractmethod - def shutdown(self): - ... + def shutdown(self): ... 
def get_output(self) -> EngineCoreOutputs: raise NotImplementedError @@ -153,17 +171,18 @@ class EngineCoreClient(ABC): def pin_lora(self, lora_id: int) -> bool: raise NotImplementedError - def save_sharded_state(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: + def save_sharded_state( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: raise NotImplementedError - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + def collective_rpc( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: raise NotImplementedError def dp_engines_running(self) -> bool: @@ -216,24 +235,24 @@ class EngineCoreClient(ABC): async def pin_lora_async(self, lora_id: int) -> bool: raise NotImplementedError - async def save_sharded_state_async(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: + async def save_sharded_state_async( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: raise NotImplementedError async def collective_rpc_async( - self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: raise NotImplementedError class InprocClient(EngineCoreClient): """ - InprocClient: client for in-process EngineCore. Intended + InprocClient: client for in-process EngineCore. Intended for use in LLMEngine for V0-style add_request() and step() EngineCore setup in this process (no busy loop). @@ -295,17 +314,18 @@ class InprocClient(EngineCoreClient): def pin_lora(self, lora_id: int) -> bool: return self.engine_core.pin_lora(lora_id) - def save_sharded_state(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: + def save_sharded_state( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: self.engine_core.save_sharded_state(path, pattern, max_size) - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + def collective_rpc( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: return self.engine_core.collective_rpc(method, timeout, args, kwargs) def dp_engines_running(self) -> bool: @@ -320,8 +340,9 @@ class BackgroundResources: ctx: zmq.Context # If CoreEngineProcManager, it manages local engines; # if CoreEngineActorManager, it manages all engines. - engine_manager: Optional[Union[CoreEngineProcManager, - CoreEngineActorManager]] = None + engine_manager: Optional[Union[CoreEngineProcManager, CoreEngineActorManager]] = ( + None + ) coordinator: Optional[DPCoordinator] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None @@ -347,12 +368,15 @@ class BackgroundResources: if isinstance(self.output_socket, zmq.asyncio.Socket): # Async case. 
- loop = self.output_queue_task._loop \ - if self.output_queue_task else None + loop = self.output_queue_task._loop if self.output_queue_task else None - sockets = (self.output_socket, self.input_socket, - self.first_req_send_socket, self.first_req_rcv_socket, - self.stats_update_socket) + sockets = ( + self.output_socket, + self.input_socket, + self.first_req_send_socket, + self.first_req_rcv_socket, + self.stats_update_socket, + ) tasks = (self.output_queue_task, self.stats_update_task) @@ -387,11 +411,10 @@ class BackgroundResources: with self.ctx.socket(zmq.PAIR) as shutdown_sender: shutdown_sender.connect(self.shutdown_path) # Send shutdown signal. - shutdown_sender.send(b'') + shutdown_sender.send(b"") def validate_alive(self, frames: Sequence[zmq.Frame]): - if len(frames) == 1 and (frames[0].buffer - == EngineCoreProc.ENGINE_CORE_DEAD): + if len(frames) == 1 and (frames[0].buffer == EngineCoreProc.ENGINE_CORE_DEAD): self.engine_dead = True raise EngineDeadError() @@ -404,7 +427,7 @@ class MPClient(EngineCoreClient): * pushes EngineCoreRequests via input_socket * pulls EngineCoreOutputs via output_socket - + * AsyncMPClient subclass for AsyncLLM usage * SyncMPClient subclass for LLM usage """ @@ -441,30 +464,32 @@ class MPClient(EngineCoreClient): # Engines are managed externally to this client. input_address = client_addresses["input_address"] output_address = client_addresses["output_address"] - self.stats_update_address = client_addresses.get( - "stats_update_address") + self.stats_update_address = client_addresses.get("stats_update_address") else: # Engines are managed by this client. - with launch_core_engines(vllm_config, executor_class, - log_stats) as (engine_manager, - coordinator, - addresses): + with launch_core_engines(vllm_config, executor_class, log_stats) as ( + engine_manager, + coordinator, + addresses, + ): self.resources.coordinator = coordinator self.resources.engine_manager = engine_manager - (input_address, ) = addresses.inputs - (output_address, ) = addresses.outputs - self.stats_update_address = ( - addresses.frontend_stats_publish_address) + (input_address,) = addresses.inputs + (output_address,) = addresses.outputs + self.stats_update_address = addresses.frontend_stats_publish_address if coordinator is not None: assert self.stats_update_address == ( - coordinator.get_stats_publish_address()) + coordinator.get_stats_publish_address() + ) # Create input and output sockets. self.input_socket = self.resources.input_socket = make_zmq_socket( - self.ctx, input_address, zmq.ROUTER, bind=True) + self.ctx, input_address, zmq.ROUTER, bind=True + ) self.resources.output_socket = make_zmq_socket( - self.ctx, output_address, zmq.PULL) + self.ctx, output_address, zmq.PULL + ) parallel_config = vllm_config.parallel_config dp_size = parallel_config.data_parallel_size @@ -473,19 +498,22 @@ class MPClient(EngineCoreClient): offline_mode = parallel_config.data_parallel_rank_local is not None # Client manages local+remote EngineCores in pure internal LB case. # Client manages local EngineCores in hybrid and external LB case. 
- local_engines_only = (parallel_config.data_parallel_hybrid_lb - or parallel_config.data_parallel_external_lb) + local_engines_only = ( + parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb + ) num_ranks = dp_local_size if local_engines_only else dp_size - self.engine_ranks_managed = [dp_rank] if offline_mode else list( - range(dp_rank, dp_rank + num_ranks)) + self.engine_ranks_managed = ( + [dp_rank] if offline_mode else list(range(dp_rank, dp_rank + num_ranks)) + ) assert parallel_config.data_parallel_size_local <= len( - self.engine_ranks_managed) + self.engine_ranks_managed + ) # ZMQ identity of each engine that this client will talk to. self.core_engines: list[EngineIdentity] = [ - rank.to_bytes(2, "little") - for rank in self.engine_ranks_managed + rank.to_bytes(2, "little") for rank in self.engine_ranks_managed ] # Wait for ready messages from each engine on the input socket. @@ -493,8 +521,10 @@ class MPClient(EngineCoreClient): sync_input_socket = zmq.Socket.shadow(self.input_socket) while identities: if not sync_input_socket.poll(timeout=600_000): - raise TimeoutError("Timed out waiting for engines to send" - "initial message on input socket.") + raise TimeoutError( + "Timed out waiting for engines to send" + "initial message on input socket." + ) identity, _ = sync_input_socket.recv_multipart() identities.remove(identity) @@ -520,8 +550,9 @@ class MPClient(EngineCoreClient): def _format_exception(self, e: Exception) -> Exception: """If errored, use EngineDeadError so root cause is clear.""" - return EngineDeadError( - suppress_context=True) if self.resources.engine_dead else e + return ( + EngineDeadError(suppress_context=True) if self.resources.engine_dead else e + ) def ensure_alive(self): if self.resources.engine_dead: @@ -541,8 +572,11 @@ class MPClient(EngineCoreClient): def start_engine_core_monitor(self): """Start a monitor thread for engine core processes.""" engine_manager = self.resources.engine_manager - if (engine_manager is None or not hasattr(engine_manager, 'processes') - or not engine_manager.processes): + if ( + engine_manager is None + or not hasattr(engine_manager, "processes") + or not engine_manager.processes + ): # No engine processes to monitor return @@ -559,23 +593,26 @@ class MPClient(EngineCoreClient): if not _self or _self.resources.engine_dead: return _self.resources.engine_dead = True - proc_name = next(proc.name for proc in engine_processes - if proc.sentinel == died[0]) + proc_name = next( + proc.name for proc in engine_processes if proc.sentinel == died[0] + ) logger.error( - "Engine core proc %s died unexpectedly, " - "shutting down client.", proc_name) + "Engine core proc %s died unexpectedly, shutting down client.", + proc_name, + ) _self.shutdown() # Note: For MPClient, we don't have a failure callback mechanism # like MultiprocExecutor, but we set engine_dead flag which will # cause subsequent operations to raise EngineDeadError - Thread(target=monitor_engine_cores, - daemon=True, - name="MPClientEngineMonitor").start() + Thread( + target=monitor_engine_cores, daemon=True, name="MPClientEngineMonitor" + ).start() -def _process_utility_output(output: UtilityOutput, - utility_results: dict[int, AnyFuture]): +def _process_utility_output( + output: UtilityOutput, utility_results: dict[int, AnyFuture] +): """Set the result from a utility method in the waiting future.""" future = utility_results.pop(output.call_id) failure_message = output.failure_message @@ -590,15 +627,17 @@ def _process_utility_output(output: 
UtilityOutput, # original calling task being cancelled. if failure_message is not None: logger.error( - "Cancelled call to utility method failed " - "with error: %s", failure_message) + "Cancelled call to utility method failed with error: %s", + failure_message, + ) class SyncMPClient(MPClient): """Synchronous client for multi-proc EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], - log_stats: bool): + def __init__( + self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool + ): super().__init__( asyncio_mode=False, vllm_config=vllm_config, @@ -641,8 +680,7 @@ class SyncMPClient(MPClient): resources.validate_alive(frames) outputs: EngineCoreOutputs = decoder.decode(frames) if outputs.utility_output: - _process_utility_output(outputs.utility_output, - utility_results) + _process_utility_output(outputs.utility_output, utility_results) else: outputs_queue.put_nowait(outputs) except Exception as e: @@ -653,9 +691,11 @@ class SyncMPClient(MPClient): out_socket.close(linger=0) # Process outputs from engine in separate thread. - self.output_queue_thread = Thread(target=process_outputs_socket, - name="EngineCoreOutputQueueThread", - daemon=True) + self.output_queue_thread = Thread( + target=process_outputs_socket, + name="EngineCoreOutputQueueThread", + daemon=True, + ) self.output_queue_thread.start() # The thread takes on responsibility for closing the socket. @@ -676,8 +716,7 @@ class SyncMPClient(MPClient): self.ensure_alive() self.free_pending_messages() # (Identity, RequestType, SerializedRequest) - msg = (self.core_engine, request_type.value, - *self.encoder.encode(request)) + msg = (self.core_engine, request_type.value, *self.encoder.encode(request)) if len(msg) <= 3: # No auxiliary buffers => no tensor backing buffers in request. 
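MPClient above binds a ROUTER socket for inputs and a PULL socket for outputs, learns each engine's identity from its ready message, and then frames every request as (identity, request type, payload). A self-contained pyzmq sketch of that pattern; this is illustrative only, not the vLLM wire protocol:

    # Illustration of the ROUTER/PULL + identity-frame pattern, not vLLM code.
    import zmq

    ctx = zmq.Context()

    # Client side: ROUTER for requests (addressed per engine), PULL for outputs.
    input_sock = ctx.socket(zmq.ROUTER)
    input_sock.bind("inproc://inputs")
    output_sock = ctx.socket(zmq.PULL)
    output_sock.bind("inproc://outputs")

    # "Engine" side: DEALER with a 2-byte rank identity, PUSH for outputs.
    rank_id = (0).to_bytes(2, "little")
    engine_in = ctx.socket(zmq.DEALER)
    engine_in.setsockopt(zmq.IDENTITY, rank_id)
    engine_in.connect("inproc://inputs")
    engine_out = ctx.socket(zmq.PUSH)
    engine_out.connect("inproc://outputs")

    # Engine announces readiness; the ROUTER prepends its identity frame.
    engine_in.send(b"READY")
    identity, _ = input_sock.recv_multipart()
    assert identity == rank_id

    # Client addresses that engine directly: (identity, request type, payload).
    input_sock.send_multipart([identity, b"ADD", b"<serialized request>"])
    print(engine_in.recv_multipart())   # [b'ADD', b'<serialized request>']

    # Engine pushes outputs back on the separate PULL pipe.
    engine_out.send(b"<serialized outputs>")
    print(output_sock.recv())           # b'<serialized outputs>'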
@@ -691,8 +730,7 @@ class SyncMPClient(MPClient): call_id = uuid.uuid1().int >> 64 future: Future[Any] = Future() self.utility_results[call_id] = future - self._send_input(EngineCoreRequestType.UTILITY, - (0, call_id, method, args)) + self._send_input(EngineCoreRequestType.UTILITY, (0, call_id, method, args)) return future.result() @@ -741,31 +779,33 @@ class SyncMPClient(MPClient): def execute_dummy_batch(self) -> None: self.call_utility("execute_dummy_batch") - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: - return self.call_utility("collective_rpc", method, timeout, args, - kwargs) + def collective_rpc( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: + return self.call_utility("collective_rpc", method, timeout, args, kwargs) - def save_sharded_state(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: + def save_sharded_state( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: self.call_utility("save_sharded_state", path, pattern, max_size) class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" - def __init__(self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, - client_count: int = 1, - client_index: int = 0): + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, + client_index: int = 0, + ): super().__init__( asyncio_mode=True, vllm_config=vllm_config, @@ -776,8 +816,7 @@ class AsyncMPClient(MPClient): self.client_count = client_count self.client_index = client_index - self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, - Exception]]() + self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, Exception]]() try: # If we are running in an asyncio event loop, start the queue task. # Otherwise, it will be started lazily. 
If it is not started here, @@ -798,10 +837,9 @@ class AsyncMPClient(MPClient): decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue - output_handler: Optional[Callable[[AsyncMPClient, EngineCoreOutputs], - Awaitable[None]]] = getattr( - self.__class__, - "process_engine_outputs", None) + output_handler: Optional[ + Callable[[AsyncMPClient, EngineCoreOutputs], Awaitable[None]] + ] = getattr(self.__class__, "process_engine_outputs", None) _self_ref = weakref.ref(self) if output_handler else None output_socket = resources.output_socket assert output_socket is not None @@ -813,8 +851,7 @@ class AsyncMPClient(MPClient): resources.validate_alive(frames) outputs: EngineCoreOutputs = decoder.decode(frames) if outputs.utility_output: - _process_utility_output(outputs.utility_output, - utility_results) + _process_utility_output(outputs.utility_output, utility_results) continue if output_handler is not None: @@ -833,7 +870,8 @@ class AsyncMPClient(MPClient): outputs_queue.put_nowait(EngineDeadError()) resources.output_queue_task = asyncio.create_task( - process_outputs_socket(), name="EngineCoreOutputQueueTask") + process_outputs_socket(), name="EngineCoreOutputQueueTask" + ) async def get_output_async(self) -> EngineCoreOutputs: self._ensure_output_queue_task() @@ -846,19 +884,21 @@ class AsyncMPClient(MPClient): raise self._format_exception(outputs) from None return outputs - def _send_input(self, - request_type: EngineCoreRequestType, - request: Any, - engine: Optional[EngineIdentity] = None) -> Awaitable[Any]: + def _send_input( + self, + request_type: EngineCoreRequestType, + request: Any, + engine: Optional[EngineIdentity] = None, + ) -> Awaitable[Any]: if engine is None: engine = self.core_engine message = (request_type.value, *self.encoder.encode(request)) return self._send_input_message(message, engine, request) - def _send_input_message(self, message: tuple[bytestr, - ...], engine: EngineIdentity, - objects: Any) -> Awaitable[Any]: + def _send_input_message( + self, message: tuple[bytestr, ...], engine: EngineIdentity, objects: Any + ) -> Awaitable[Any]: """ objects is a reference to retain until zmq is finished with the buffers, in case they were extracted from tensors in the request. @@ -866,7 +906,7 @@ class AsyncMPClient(MPClient): self.ensure_alive() self.free_pending_messages() - msg = (engine, ) + message + msg = (engine,) + message if not objects or len(msg) <= 3: # No auxiliary buffers => no tensor backing buffers in request. 
return self.input_socket.send_multipart(msg, copy=False) @@ -882,17 +922,18 @@ class AsyncMPClient(MPClient): return future async def call_utility_async(self, method: str, *args) -> Any: - return await self._call_utility_async(method, - *args, - engine=self.core_engine) + return await self._call_utility_async(method, *args, engine=self.core_engine) - async def _call_utility_async(self, method: str, *args, - engine: EngineIdentity) -> Any: + async def _call_utility_async( + self, method: str, *args, engine: EngineIdentity + ) -> Any: call_id = uuid.uuid1().int >> 64 future = asyncio.get_running_loop().create_future() self.utility_results[call_id] = future - message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( - (self.client_index, call_id, method, args))) + message = ( + EngineCoreRequestType.UTILITY.value, + *self.encoder.encode((self.client_index, call_id, method, args)), + ) await self._send_input_message(message, engine, args) self._ensure_output_queue_task() return await future @@ -942,38 +983,46 @@ class AsyncMPClient(MPClient): async def pin_lora_async(self, lora_id: int) -> bool: return await self.call_utility_async("pin_lora", lora_id) - async def save_sharded_state_async(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: - await self.call_utility_async("save_sharded_state", path, pattern, - max_size) + async def save_sharded_state_async( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: + await self.call_utility_async("save_sharded_state", path, pattern, max_size) async def collective_rpc_async( - self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: - return await self.call_utility_async("collective_rpc", method, timeout, - args, kwargs) + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: + return await self.call_utility_async( + "collective_rpc", method, timeout, args, kwargs + ) class DPAsyncMPClient(AsyncMPClient): """Asyncio-compatible client for multi-proc, multi-engine (data parallel) EngineCore. Assumes external load-balancing by default.""" - def __init__(self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, - client_count: int = 1, - client_index: int = 0): + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, + client_index: int = 0, + ): self.current_wave = 0 - super().__init__(vllm_config, executor_class, log_stats, - client_addresses, client_count, client_index) + super().__init__( + vllm_config, + executor_class, + log_stats, + client_addresses, + client_count, + client_index, + ) # List of [waiting, running] pair per engine. # Used only by DPLBAsyncMPClient subclass. @@ -981,10 +1030,8 @@ class DPAsyncMPClient(AsyncMPClient): self.first_req_sock_addr = get_open_zmq_inproc_path() self.first_req_send_socket = self.resources.first_req_send_socket = ( - make_zmq_socket(self.ctx, - self.first_req_sock_addr, - zmq.PAIR, - bind=True)) + make_zmq_socket(self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=True) + ) try: # If we are running in an asyncio event loop, start the stats task. # Otherwise, it will be started lazily. 
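Both call_utility() and _call_utility_async() above park a Future under a 64-bit call id and let the output path resolve it through _process_utility_output(). A toy, synchronous version of that bookkeeping (not vLLM code):

    # Toy version of the call_id -> Future bookkeeping described above.
    import uuid
    from concurrent.futures import Future
    from typing import Any, Optional

    utility_results: dict[int, Future] = {}

    def call_utility(send, method: str, *args) -> Any:
        call_id = uuid.uuid1().int >> 64      # unique-enough 64-bit id
        future: Future = Future()
        utility_results[call_id] = future     # parked until the reply arrives
        send((call_id, method, args))
        return future.result()                # block for the engine's reply

    def on_utility_reply(call_id: int, result: Any = None,
                         failure_message: Optional[str] = None) -> None:
        future = utility_results.pop(call_id)
        if failure_message is not None:
            future.set_exception(Exception(failure_message))
        else:
            future.set_result(result)

    # Wire "send" so a fake engine replies immediately.
    print(call_utility(lambda msg: on_utility_reply(msg[0], result=42), "noop"))  # 42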
@@ -1003,25 +1050,25 @@ class DPAsyncMPClient(AsyncMPClient): # NOTE: running and waiting counts are all global from # the Coordinator include all global EngineCores. This # slice includes just the cores managed by this client. - count_slice = slice(self.engine_ranks_managed[0], - self.engine_ranks_managed[-1] + 1) + count_slice = slice( + self.engine_ranks_managed[0], self.engine_ranks_managed[-1] + 1 + ) async def run_engine_stats_update_task(): - with (make_zmq_socket(self.ctx, - self.stats_update_address, - zmq.XSUB, - linger=0) as socket, - make_zmq_socket(self.ctx, - self.first_req_sock_addr, - zmq.PAIR, - bind=False, - linger=0) as first_req_rcv_socket): + with ( + make_zmq_socket( + self.ctx, self.stats_update_address, zmq.XSUB, linger=0 + ) as socket, + make_zmq_socket( + self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0 + ) as first_req_rcv_socket, + ): assert isinstance(socket, zmq.asyncio.Socket) assert isinstance(first_req_rcv_socket, zmq.asyncio.Socket) self.resources.stats_update_socket = socket self.resources.first_req_rcv_socket = first_req_rcv_socket # Send subscription message. - await socket.send(b'\x01') + await socket.send(b"\x01") poller = zmq.asyncio.Poller() poller.register(socket, zmq.POLLIN) @@ -1029,23 +1076,27 @@ class DPAsyncMPClient(AsyncMPClient): while True: events = await poller.poll() - if not self.engines_running and len(events) == 2 or ( - events[0][0] == first_req_rcv_socket): + if ( + not self.engines_running + and len(events) == 2 + or (events[0][0] == first_req_rcv_socket) + ): # Check if this is a regular request notification or # scale up notification - buf = first_req_rcv_socket.recv( - flags=zmq.NOBLOCK).result() + buf = first_req_rcv_socket.recv(flags=zmq.NOBLOCK).result() decoded = msgspec.msgpack.decode(buf) - if isinstance( - decoded, - (list, tuple)) and len(decoded) == 2 and decoded[ - 0] == "SCALE_ELASTIC_EP": + if ( + isinstance(decoded, (list, tuple)) + and len(decoded) == 2 + and decoded[0] == "SCALE_ELASTIC_EP" + ): # Extract new engine count from the decoded message new_engine_count = decoded[1] # Send scale up notification to coordinator scale_msg = msgspec.msgpack.encode( - ("SCALE_ELASTIC_EP", new_engine_count)) + ("SCALE_ELASTIC_EP", new_engine_count) + ) await socket.send(scale_msg) continue @@ -1056,14 +1107,14 @@ class DPAsyncMPClient(AsyncMPClient): target_eng_index = decoded[1] self.engines_running = True msg = msgspec.msgpack.encode( - (target_eng_index, self.current_wave)) + (target_eng_index, self.current_wave) + ) await socket.send(msg) buf = None while True: # Drain all stats events (we only care about latest). 
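The notifications exchanged with the coordinator above (and in the drain loop that continues below) are plain msgpack-encoded tuples. A round-trip sketch, assuming only that msgspec is installed; note that msgpack arrays decode back as lists, which is why the handler checks isinstance(decoded, (list, tuple)) instead of matching on a tuple:

    # Round-trip of the coordinator message framing, for illustration only.
    import msgspec

    scale_msg = msgspec.msgpack.encode(("SCALE_ELASTIC_EP", 4))
    assert msgspec.msgpack.decode(scale_msg) == ["SCALE_ELASTIC_EP", 4]

    wave_msg = msgspec.msgpack.encode((3, 0))        # (target_eng_index, wave)
    target_eng_index, current_wave = msgspec.msgpack.decode(wave_msg)
    assert (target_eng_index, current_wave) == (3, 0)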
- future: asyncio.Future[bytes] = socket.recv( - flags=zmq.NOBLOCK) + future: asyncio.Future[bytes] = socket.recv(flags=zmq.NOBLOCK) if isinstance(future.exception(), zmq.Again): break buf = future.result() @@ -1077,11 +1128,13 @@ class DPAsyncMPClient(AsyncMPClient): if counts is not None: sliced_counts = counts[count_slice] self.lb_engines = sliced_counts - logger.debug("Received counts: %s (%s)", sliced_counts, - count_slice) + logger.debug( + "Received counts: %s (%s)", sliced_counts, count_slice + ) resources.stats_update_task = asyncio.create_task( - run_engine_stats_update_task()) + run_engine_stats_update_task() + ) async def add_request_async(self, request: EngineCoreRequest) -> None: self._ensure_stats_update_task() @@ -1090,8 +1143,7 @@ class DPAsyncMPClient(AsyncMPClient): request.client_index = self.client_index chosen_engine = self.get_core_engine_for_request(request) - to_await = self._send_input(EngineCoreRequestType.ADD, request, - chosen_engine) + to_await = self._send_input(EngineCoreRequestType.ADD, request, chosen_engine) if not self.engines_running: # Notify coordinator that we're sending a request req_msg = msgspec.msgpack.encode(("FIRST_REQ", chosen_engine)) @@ -1109,29 +1161,36 @@ class DPLBAsyncMPClient(DPAsyncMPClient): """Asyncio-compatible client for multi-proc, multi-engine (data parallel) EngineCore. Load-balances between multiple engine processes.""" - def __init__(self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, - client_count: int = 1, - client_index: int = 0): - + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, + client_index: int = 0, + ): self.client_count = client_count # To route aborts to the correct engine. self.reqs_in_flight: dict[str, EngineIdentity] = {} - super().__init__(vllm_config, executor_class, log_stats, - client_addresses, client_count, client_index) + super().__init__( + vllm_config, + executor_class, + log_stats, + client_addresses, + client_count, + client_index, + ) assert len(self.core_engines) > 1 - self.eng_start_index = (len(self.core_engines) * - self.client_index) // client_count + self.eng_start_index = ( + len(self.core_engines) * self.client_index + ) // client_count - def get_core_engine_for_request( - self, request: EngineCoreRequest) -> EngineIdentity: + def get_core_engine_for_request(self, request: EngineCoreRequest) -> EngineIdentity: # Engines are in rank order. if (eng_index := request.data_parallel_rank) is None: current_counts = self.lb_engines @@ -1159,14 +1218,19 @@ class DPLBAsyncMPClient(DPAsyncMPClient): async def call_utility_async(self, method: str, *args) -> Any: # Only the result from the first engine is returned. 
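The comment above notes that only the first engine's result is kept; the hunk below implements that with asyncio.gather(). A stripped-down version of the same fan-out, with a hypothetical call function standing in for the real per-engine RPC:

    # Broadcast one utility call to every engine, keep the first result.
    import asyncio

    async def call_on_all(engines, call):
        results = await asyncio.gather(*[call(engine) for engine in engines])
        return results[0]

    async def _demo():
        async def call(engine):              # stand-in for the per-engine RPC
            await asyncio.sleep(0)
            return f"result-from-{engine}"
        print(await call_on_all(["engine-0", "engine-1"], call))  # result-from-engine-0

    asyncio.run(_demo())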
- return (await asyncio.gather(*[ - self._call_utility_async(method, *args, engine=engine) - for engine in self.core_engines - ]))[0] + return ( + await asyncio.gather( + *[ + self._call_utility_async(method, *args, engine=engine) + for engine in self.core_engines + ] + ) + )[0] @staticmethod - async def process_engine_outputs(self: "DPLBAsyncMPClient", - outputs: EngineCoreOutputs): + async def process_engine_outputs( + self: "DPLBAsyncMPClient", outputs: EngineCoreOutputs + ): if outputs.finished_requests and self.reqs_in_flight: for req_id in outputs.finished_requests: self.reqs_in_flight.pop(req_id, None) @@ -1188,10 +1252,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient): for engine, req_ids in by_engine.items(): await self._abort_requests(req_ids, engine) - async def _abort_requests(self, request_ids: list[str], - engine: EngineIdentity) -> None: - await self._send_input(EngineCoreRequestType.ABORT, request_ids, - engine) + async def _abort_requests( + self, request_ids: list[str], engine: EngineIdentity + ) -> None: + await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine) async def scale_elastic_ep(self, new_data_parallel_size: int) -> None: """Scale elastic EP data parallel size""" @@ -1199,22 +1263,27 @@ class DPLBAsyncMPClient(DPAsyncMPClient): assert new_data_parallel_size != cur_data_parallel_size, ( f"new_data_parallel_size {new_data_parallel_size} must be " - f"different from cur_data_parallel_size {cur_data_parallel_size}") + f"different from cur_data_parallel_size {cur_data_parallel_size}" + ) - assert self.vllm_config.parallel_config.data_parallel_backend == \ - "ray", "Only ray DP backend supports scaling elastic EP" + assert self.vllm_config.parallel_config.data_parallel_backend == "ray", ( + "Only ray DP backend supports scaling elastic EP" + ) scale_up = new_data_parallel_size > cur_data_parallel_size if scale_up: - await self._scale_up_elastic_ep(cur_data_parallel_size, - new_data_parallel_size) + await self._scale_up_elastic_ep( + cur_data_parallel_size, new_data_parallel_size + ) else: - await self._scale_down_elastic_ep(cur_data_parallel_size, - new_data_parallel_size) + await self._scale_down_elastic_ep( + cur_data_parallel_size, new_data_parallel_size + ) - async def _scale_up_elastic_ep(self, cur_data_parallel_size: int, - new_data_parallel_size: int) -> None: + async def _scale_up_elastic_ep( + self, cur_data_parallel_size: int, new_data_parallel_size: int + ) -> None: """Scale up the data parallel size by creating new engine cores and reconfiguring existing ones.""" cur_data_parallel_size = len(self.core_engines) @@ -1222,21 +1291,18 @@ class DPLBAsyncMPClient(DPAsyncMPClient): # Phase 1: Send reconfigure messages to all existing engines and wait # for them to be sent reconfig_futures = [] - self.vllm_config.parallel_config.data_parallel_master_port = \ - get_open_port() + self.vllm_config.parallel_config.data_parallel_master_port = get_open_port() for engine in self.core_engines: reconfig_request = ReconfigureDistributedRequest( new_data_parallel_size=new_data_parallel_size, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_rank_local=\ - ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_master_ip=self.vllm_config.parallel_config. - data_parallel_master_ip, - new_data_parallel_master_port=self.vllm_config.parallel_config. 
- data_parallel_master_port) - coro = self._call_utility_async("reinitialize_distributed", - reconfig_request, - engine=engine) + new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, + new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip, + new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port, + ) + coro = self._call_utility_async( + "reinitialize_distributed", reconfig_request, engine=engine + ) reconfig_futures.append(asyncio.create_task(coro)) logger.info("All reconfigure messages sent, starting engine creation") @@ -1244,10 +1310,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient): # Phase 2: Create new engines now that reconfig messages have been sent # self.resources.engine_manager is guaranteed to be # CoreEngineActorManager for RayDPClient - assert isinstance(self.resources.engine_manager, - CoreEngineActorManager) + assert isinstance(self.resources.engine_manager, CoreEngineActorManager) self.resources.engine_manager.scale_up_elastic_ep( - self.vllm_config, new_data_parallel_size) + self.vllm_config, new_data_parallel_size + ) # Create new CoreEngine objects for the new engines new_engine_identities = set() @@ -1262,7 +1328,8 @@ class DPLBAsyncMPClient(DPAsyncMPClient): if not sync_input_socket.poll(timeout=600_000): raise TimeoutError( "Timed out waiting for new engines to send initial " - "message on input socket.") + "message on input socket." + ) identity, _ = sync_input_socket.recv_multipart() new_engine_identities.discard(identity) @@ -1274,42 +1341,42 @@ class DPLBAsyncMPClient(DPAsyncMPClient): # stats_update_task connection self._ensure_stats_update_task() scale_up_marker = msgspec.msgpack.encode( - ("SCALE_ELASTIC_EP", new_data_parallel_size)) + ("SCALE_ELASTIC_EP", new_data_parallel_size) + ) await self.first_req_send_socket.send(scale_up_marker) # Update the parallel config - self.vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size logger.info( "[Elastic EP] Scale up completed, new data parallel size: %s", - new_data_parallel_size) + new_data_parallel_size, + ) - async def _scale_down_elastic_ep(self, cur_data_parallel_size: int, - new_data_parallel_size: int) -> None: + async def _scale_down_elastic_ep( + self, cur_data_parallel_size: int, new_data_parallel_size: int + ) -> None: """Scale down the data parallel size by shutting down and reconfiguring existing engine cores.""" cur_data_parallel_size = len(self.core_engines) - self.vllm_config.parallel_config.data_parallel_master_port = \ - get_open_port() + self.vllm_config.parallel_config.data_parallel_master_port = get_open_port() reconfig_futures = [] for cur_dp_rank, engine in enumerate(self.core_engines): reconfig_request = ReconfigureDistributedRequest( new_data_parallel_size=new_data_parallel_size, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_rank_local=\ - ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_master_ip=self.vllm_config.parallel_config. - data_parallel_master_ip, - new_data_parallel_master_port=self.vllm_config.parallel_config. 
- data_parallel_master_port) + new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, + new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip, + new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port, + ) if cur_dp_rank >= new_data_parallel_size: - reconfig_request.new_data_parallel_rank = \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK - coro = self._call_utility_async("reinitialize_distributed", - reconfig_request, - engine=engine) + reconfig_request.new_data_parallel_rank = ( + ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ) + coro = self._call_utility_async( + "reinitialize_distributed", reconfig_request, engine=engine + ) reconfig_futures.append(asyncio.create_task(coro)) for _ in range(new_data_parallel_size, cur_data_parallel_size): @@ -1317,18 +1384,19 @@ class DPLBAsyncMPClient(DPAsyncMPClient): await asyncio.gather(*reconfig_futures) - assert isinstance(self.resources.engine_manager, - CoreEngineActorManager) + assert isinstance(self.resources.engine_manager, CoreEngineActorManager) self.resources.engine_manager.scale_down_elastic_ep( - cur_data_parallel_size, new_data_parallel_size) + cur_data_parallel_size, new_data_parallel_size + ) self._ensure_stats_update_task() scale_down_marker = msgspec.msgpack.encode( - ("SCALE_ELASTIC_EP", new_data_parallel_size)) + ("SCALE_ELASTIC_EP", new_data_parallel_size) + ) await self.first_req_send_socket.send(scale_down_marker) - self.vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size logger.info( "[Elastic EP] Scale down completed, new data parallel size: %s", - new_data_parallel_size) + new_data_parallel_size, + ) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 0f993a74c8..9d1d7558b1 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -11,7 +11,10 @@ from transformers import PreTrainedTokenizerFast from vllm.logger import init_logger from vllm.transformers_utils.detokenizer_utils import ( - AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) + AnyTokenizer, + convert_prompt_ids_to_tokens, + detokenize_incrementally, +) from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreRequest @@ -19,15 +22,13 @@ logger = init_logger(__name__) # Only tokenizers >= 0.21.1 supports DecodeStream used for # FastIncrementalDetokenizer. 
-USE_FAST_DETOKENIZER = version.parse( - tokenizers.__version__) >= version.parse("0.21.1") +USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.21.1") # Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042 INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered" class IncrementalDetokenizer: - def __init__(self): self.token_ids: list[int] = [] @@ -35,8 +36,7 @@ class IncrementalDetokenizer: def output_token_ids(self) -> list[int]: return self.token_ids - def update(self, new_token_ids: list[int], - stop_terminated: bool) -> Optional[str]: + def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[str]: self.token_ids.extend(new_token_ids) return None @@ -49,15 +49,13 @@ class IncrementalDetokenizer: tokenizer: Optional[AnyTokenizer], request: EngineCoreRequest, ) -> "IncrementalDetokenizer": - assert request.sampling_params is not None if tokenizer is None: # No tokenizer => skipping detokenization. return IncrementalDetokenizer() - if USE_FAST_DETOKENIZER and isinstance(tokenizer, - PreTrainedTokenizerFast): + if USE_FAST_DETOKENIZER and isinstance(tokenizer, PreTrainedTokenizerFast): # Fast tokenizer => use tokenizers library DecodeStream. return FastIncrementalDetokenizer(tokenizer, request) @@ -66,7 +64,6 @@ class IncrementalDetokenizer: class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): - def __init__(self, request: EngineCoreRequest): super().__init__() @@ -88,8 +85,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): # Generation data self.output_text = "" - def update(self, new_token_ids: list[int], - stop_terminated: bool) -> Optional[str]: + def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[str]: """ Update RequestState for the request_id by: 1) Detokenize the new token ids incrementally. @@ -117,8 +113,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): self.token_ids.append(new_token_id) self.output_text += self.decode_next(new_token_id) # Support min_tokens, see https://github.com/vllm-project/vllm/pull/22014 - if self.min_tokens and len( - self.output_token_ids) <= self.min_tokens: + if self.min_tokens and len(self.output_token_ids) <= self.min_tokens: stop_check_offset = len(self.output_text) if skipped_stop_token_id is not None: @@ -152,8 +147,11 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): # We return the full output text if the sequence is finished. 
buffer_length = 0 if finished else self.stop_buffer_length if not delta: - return self.output_text[:-buffer_length] if buffer_length else ( - self.output_text) + return ( + self.output_text[:-buffer_length] + if buffer_length + else (self.output_text) + ) length = len(self.output_text) - buffer_length last_offset = self._last_output_text_offset if last_offset < length: @@ -163,9 +161,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): - - def __init__(self, tokenizer: PreTrainedTokenizerFast, - request: EngineCoreRequest): + def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreRequest): super().__init__(request) sampling_params = request.sampling_params @@ -173,8 +169,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): self.request_id = request.request_id self.skip_special_tokens = sampling_params.skip_special_tokens - self.stream = DecodeStream( - skip_special_tokens=self.skip_special_tokens) + self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens) self.tokenizer: Tokenizer = tokenizer._tokenizer @@ -185,7 +180,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): if prompt_len > 4: for i in range(4, min(prompt_len + 1, 24)): suffix = prompt_token_ids[-i:] - if '�' not in self.tokenizer.decode(suffix): + if "�" not in self.tokenizer.decode(suffix): prompt_suffix = suffix break @@ -195,17 +190,18 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): self.spaces_between_special_tokens = ( sampling_params.skip_special_tokens - or sampling_params.spaces_between_special_tokens) + or sampling_params.spaces_between_special_tokens + ) if not self.spaces_between_special_tokens: # Store dict of added token ids so that we can suppress # the spaces between them. - if (added_token_ids := getattr(self.tokenizer, "added_token_ids", - None)) is None: + if ( + added_token_ids := getattr(self.tokenizer, "added_token_ids", None) + ) is None: self.tokenizer.added_token_ids = added_token_ids = { tid: tok.content - for tid, tok in - self.tokenizer.get_added_tokens_decoder().items() + for tid, tok in self.tokenizer.get_added_tokens_decoder().items() } if added_token_ids: @@ -245,15 +241,15 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): # See https://github.com/vllm-project/vllm/issues/17448. logger.warning( "Encountered invalid prefix detokenization error" - " for request %s, resetting decode stream.", self.request_id) - self.stream = DecodeStream( - skip_special_tokens=self.skip_special_tokens) + " for request %s, resetting decode stream.", + self.request_id, + ) + self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens) token = self.stream.step(self.tokenizer, next_token_id) return token class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer): - def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest): super().__init__(request) @@ -262,7 +258,8 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer): assert params is not None self.prompt_len = length_from_prompt_token_ids_or_embeds( - request.prompt_token_ids, request.prompt_embeds) + request.prompt_token_ids, request.prompt_embeds + ) # Metadata for incremental detokenization. 
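FastIncrementalDetokenizer above drives the tokenizers DecodeStream one token at a time via stream.step(). A hedged standalone sketch of that flow; it assumes tokenizers >= 0.21.1, network access to fetch a tokenizer, and that the import paths match your tokenizers version, and the model name is only an example:

    # Incremental detokenization with DecodeStream, per the usage shown above.
    from tokenizers import Tokenizer
    from tokenizers.decoders import DecodeStream

    tokenizer = Tokenizer.from_pretrained("gpt2")    # example tokenizer
    stream = DecodeStream(skip_special_tokens=True)

    text = ""
    for token_id in tokenizer.encode("incremental detokenization").ids:
        # step() returns the newly printable fragment, or None if this token
        # does not yet complete one (e.g. a partial multi-byte character).
        piece = stream.step(tokenizer, token_id)
        if piece is not None:
            text += piece
    print(text)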
if request.prompt_token_ids is not None: @@ -271,37 +268,37 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer): tokenizer=tokenizer, prompt_ids=request.prompt_token_ids, skip_special_tokens=params.skip_special_tokens, - )) + ) + ) else: # Prompt embedding requests cannot be detokenized, in general. self.tokens = [""] * self.prompt_len self.prefix_offset = 0 self.read_offest = 0 - self.token_ids.extend(request.prompt_token_ids - or [0] * self.prompt_len) + self.token_ids.extend(request.prompt_token_ids or [0] * self.prompt_len) self.skip_special_tokens = params.skip_special_tokens - self.spaces_between_special_tokens = ( - params.spaces_between_special_tokens) + self.spaces_between_special_tokens = params.spaces_between_special_tokens @property def output_token_ids(self) -> list[int]: - return self.token_ids if not self.prompt_len else ( - self.token_ids[self.prompt_len:]) + return ( + self.token_ids + if not self.prompt_len + else (self.token_ids[self.prompt_len :]) + ) def decode_next(self, next_token_id: int) -> str: - new_tokens, decoded_text, prefix_offset, read_offset = ( - detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=self.token_ids, - prev_tokens=self.tokens, - prefix_offset=self.prefix_offset, - read_offset=self.read_offset, - skip_special_tokens=self.skip_special_tokens, - spaces_between_special_tokens=self. - spaces_between_special_tokens, - )) + new_tokens, decoded_text, prefix_offset, read_offset = detokenize_incrementally( + tokenizer=self.tokenizer, + all_input_ids=self.token_ids, + prev_tokens=self.tokens, + prefix_offset=self.prefix_offset, + read_offset=self.read_offset, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + ) self.tokens.extend(new_tokens) self.prefix_offset = prefix_offset @@ -331,8 +328,7 @@ def check_stop_strings( for stop_str in stop: stop_string_len = len(stop_str) # Avoid searching already-searched text. - stop_index = output_text.find(stop_str, - 1 - new_char_count - stop_string_len) + stop_index = output_text.find(stop_str, 1 - new_char_count - stop_string_len) if stop_index == -1: continue diff --git a/vllm/v1/engine/exceptions.py b/vllm/v1/engine/exceptions.py index 692ba9dc84..d9f79a019e 100644 --- a/vllm/v1/engine/exceptions.py +++ b/vllm/v1/engine/exceptions.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project class EngineGenerateError(Exception): """Raised when a AsyncLLM.generate() fails. Recoverable.""" + pass diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 3734c20800..9da25c0662 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -23,8 +23,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask from vllm.tracing import init_tracer -from vllm.transformers_utils.tokenizer import (AnyTokenizer, - init_tokenizer_from_configs) +from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext from vllm.utils import Device from vllm.v1.engine import EngineCoreRequest @@ -62,12 +61,14 @@ class LLMEngine: "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. " "This should not happen. As a workaround, try using " "LLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github.") + "VLLM_USE_V1=0 or 1 and report this issue on Github." 
+ ) if stat_loggers is not None: raise NotImplementedError( "Passing StatLoggers to LLMEngine in V1 is not yet supported. " - "Set VLLM_USE_V1=0 and file and issue on Github.") + "Set VLLM_USE_V1=0 and file and issue on Github." + ) self.vllm_config = vllm_config self.observability_config = vllm_config.observability_config @@ -76,15 +77,19 @@ class LLMEngine: self.log_stats = log_stats - executor_backend = ( - self.vllm_config.parallel_config.distributed_executor_backend) + executor_backend = self.vllm_config.parallel_config.distributed_executor_backend parallel_config = vllm_config.parallel_config - self.external_launcher_dp = (parallel_config.data_parallel_size > 1 and - executor_backend == "external_launcher") + self.external_launcher_dp = ( + parallel_config.data_parallel_size > 1 + and executor_backend == "external_launcher" + ) # important: init dp group before init the engine_core # In the decoupled engine case this is handled in EngineCoreProc. - if not multiprocess_mode and parallel_config.data_parallel_size > 1 \ - and not self.external_launcher_dp: + if ( + not multiprocess_mode + and parallel_config.data_parallel_size > 1 + and not self.external_launcher_dp + ): self.dp_group = parallel_config.stateless_init_dp_group() else: self.dp_group = None @@ -95,20 +100,22 @@ class LLMEngine: else: # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config) + model_config=vllm_config.model_config + ) # Processor (convert Inputs --> EngineCoreRequests) - self.processor = Processor(vllm_config=vllm_config, - tokenizer=self.tokenizer, - mm_registry=mm_registry) + self.processor = Processor( + vllm_config=vllm_config, tokenizer=self.tokenizer, mm_registry=mm_registry + ) # OutputProcessor (convert EngineCoreOutputs --> RequestOutput). - self.output_processor = OutputProcessor(self.tokenizer, - log_stats=self.log_stats) + self.output_processor = OutputProcessor( + self.tokenizer, log_stats=self.log_stats + ) if self.observability_config.otlp_traces_endpoint is not None: tracer = init_tracer( - "vllm.llm_engine", - self.observability_config.otlp_traces_endpoint) + "vllm.llm_engine", self.observability_config.otlp_traces_endpoint + ) self.output_processor.tracer = tracer # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) @@ -149,12 +156,14 @@ class LLMEngine: stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_stats: bool = False, ) -> "LLMEngine": - return cls(vllm_config=vllm_config, - executor_class=Executor.get_class(vllm_config), - log_stats=(not disable_log_stats), - usage_context=usage_context, - stat_loggers=stat_loggers, - multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING) + return cls( + vllm_config=vllm_config, + executor_class=Executor.get_class(vllm_config), + log_stats=(not disable_log_stats), + usage_context=usage_context, + stat_loggers=stat_loggers, + multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING, + ) @classmethod def from_engine_args( @@ -175,12 +184,14 @@ class LLMEngine: enable_multiprocessing = True # Create the LLMEngine. 
- return cls(vllm_config=vllm_config, - executor_class=executor_class, - log_stats=not engine_args.disable_log_stats, - usage_context=usage_context, - stat_loggers=stat_loggers, - multiprocess_mode=enable_multiprocessing) + return cls( + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + multiprocess_mode=enable_multiprocessing, + ) def get_num_unfinished_requests(self) -> int: return self.output_processor.get_num_unfinished_requests() @@ -193,7 +204,8 @@ class LLMEngine: def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool: aggregated_has_unfinished = ParallelConfig.has_unfinished_dp( - self.dp_group, has_unfinished) + self.dp_group, has_unfinished + ) if not has_unfinished and aggregated_has_unfinished: self.should_execute_dummy_batch = True return aggregated_has_unfinished @@ -225,22 +237,28 @@ class LLMEngine: ) -> None: # Validate the request_id type. if not isinstance(request_id, str): - raise TypeError( - f"request_id must be a string, got {type(request_id)}") + raise TypeError(f"request_id must be a string, got {type(request_id)}") # Process raw inputs into the request. if isinstance(prompt, EngineCoreRequest): request = prompt else: assert prompt_text is None - logger.warning_once("Processor has been moved under LLM and will " - "be removed from LLMEngine in v0.13.") - request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - tokenization_kwargs, - trace_headers, priority) - prompt_text = (prompt if isinstance(prompt, str) else - prompt.get("prompt")) + logger.warning_once( + "Processor has been moved under LLM and will " + "be removed from LLMEngine in v0.13." + ) + request = self.processor.process_inputs( + request_id, + prompt, + params, + arrival_time, + lora_request, + tokenization_kwargs, + trace_headers, + priority, + ) + prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") n = params.n if isinstance(params, SamplingParams) else 1 @@ -260,13 +278,13 @@ class LLMEngine: child_request.sampling_params = params # Make a new RequestState and queue. - self.output_processor.add_request(child_request, prompt_text, - parent_req, idx) + self.output_processor.add_request( + child_request, prompt_text, parent_req, idx + ) # Add the request to EngineCore. self.engine_core.add_request(child_request) def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]: - if self.should_execute_dummy_batch: self.should_execute_dummy_batch = False self.engine_core.execute_dummy_batch() @@ -280,7 +298,8 @@ class LLMEngine: processed_outputs = self.output_processor.process_outputs( outputs.outputs, engine_core_timestamp=outputs.timestamp, - iteration_stats=iteration_stats) + iteration_stats=iteration_stats, + ) # 3) Abort any reqs that finished due to stop strings. 
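The add_request()/step() pair above is driven by a plain synchronous loop. A hedged usage sketch, assuming the usual EngineArgs/SamplingParams entry points; the model name and construction details are illustrative, not taken from this diff:

    # Hedged driver-loop sketch for the V1 LLMEngine shown above.
    from vllm import EngineArgs, SamplingParams
    from vllm.v1.engine.llm_engine import LLMEngine

    engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
    engine.add_request("req-0", "Hello, my name is", SamplingParams(max_tokens=16))

    while engine.has_unfinished_requests():
        for output in engine.step():
            if output.finished:
                print(output.outputs[0].text)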
self.engine_core.abort_requests(processed_outputs.reqs_to_abort) @@ -330,8 +349,9 @@ class LLMEngine: def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: - raise ValueError("Unable to get tokenizer because " - "skip_tokenizer_init is True") + raise ValueError( + "Unable to get tokenizer because skip_tokenizer_init is True" + ) return self.tokenizer @@ -365,17 +385,21 @@ class LLMEngine: """Prevent an adapter from being evicted.""" return self.engine_core.pin_lora(lora_id) - def collective_rpc(self, - method: Union[str, Callable[[WorkerBase], _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + def collective_rpc( + self, + method: Union[str, Callable[[WorkerBase], _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: return self.engine_core.collective_rpc(method, timeout, args, kwargs) def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: - return self.collective_rpc("apply_model", args=(func, )) + return self.collective_rpc("apply_model", args=(func,)) def __del__(self): - if dp_group := getattr(self, "dp_group", - None) and not self.external_launcher_dp: + if ( + dp_group := getattr(self, "dp_group", None) + and not self.external_launcher_dp + ): stateless_destroy_torch_distributed_process_group(dp_group) diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index 133122b6fc..ab0e44fce1 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -9,7 +9,9 @@ from typing import Optional from vllm.logger import init_logger from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs from vllm.transformers_utils.detokenizer_utils import ( - AnyTokenizer, convert_ids_list_to_tokens) + AnyTokenizer, + convert_ids_list_to_tokens, +) from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest from vllm.v1.outputs import LogprobsLists, LogprobsTensors @@ -20,7 +22,6 @@ NONES = itertools.repeat(None) @dataclass class LogprobsProcessor: - # Tokenizer for this request, # None if detokenization is disabled. tokenizer: Optional[AnyTokenizer] @@ -43,7 +44,7 @@ class LogprobsProcessor: num_prompt_logprobs = request.sampling_params.prompt_logprobs return cls( tokenizer=tokenizer, - cumulative_logprob=(None if num_logprobs is None else 0.), + cumulative_logprob=(None if num_logprobs is None else 0.0), logprobs=(None if num_logprobs is None else []), # NOTE: logprob of first prompt token is None. prompt_logprobs=(None if num_prompt_logprobs is None else [None]), @@ -68,12 +69,13 @@ class LogprobsProcessor: token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists - for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, - token_ids_lst): - + for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst): # Detokenize (non-incrementally). - decoded_tokens = NONES if self.tokenizer is None else ( - convert_ids_list_to_tokens(self.tokenizer, token_ids)) + decoded_tokens = ( + NONES + if self.tokenizer is None + else (convert_ids_list_to_tokens(self.tokenizer, token_ids)) + ) # Sampler puts the sampled logprob in first. sampled_token_logprob = logprobs[0] @@ -87,7 +89,8 @@ class LogprobsProcessor: decoded_tokens, rank, self.num_logprobs, - )) + ) + ) def _update_prompt_logprobs( self, @@ -109,9 +112,13 @@ class LogprobsProcessor: # Detokenize non-incrementally. 
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] - decoded_tokens = None if self.tokenizer is None else ( - convert_ids_list_to_tokens(self.tokenizer, - token_ids.flatten().tolist())) + decoded_tokens = ( + None + if self.tokenizer is None + else ( + convert_ids_list_to_tokens(self.tokenizer, token_ids.flatten().tolist()) + ) + ) # Recover shapes. num_prompt_tokens, num_logprobs = logprobs.shape @@ -126,15 +133,20 @@ class LogprobsProcessor: # Handle flattening. offset = pos * num_logprobs offset_end = offset + num_logprobs - decoded_tokens_for_pos = NONES \ - if decoded_tokens is None else decoded_tokens[offset:offset_end] + decoded_tokens_for_pos = ( + NONES if decoded_tokens is None else decoded_tokens[offset:offset_end] + ) # Update with the Logprob dictionary for this pos. self.prompt_logprobs.append( - self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos], - decoded_tokens_for_pos, - prompt_token_ranks[pos], - self.num_prompt_logprobs)) + self._make_logprob_dict( + prompt_logprobs[pos], + token_ids[pos], + decoded_tokens_for_pos, + prompt_token_ranks[pos], + self.num_prompt_logprobs, + ) + ) def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]: """Pop and return all request prompt logprobs @@ -182,7 +194,7 @@ class LogprobsProcessor: # being in the topk, since inserting duplicated data # into a dictionary twice is the same as doing it once. topk_ranks = range(1, num_logprobs + 1) - ranks = itertools.chain((rank, ), topk_ranks) + ranks = itertools.chain((rank,), topk_ranks) return { token_id: Logprob( @@ -191,7 +203,8 @@ class LogprobsProcessor: decoded_token=token, ) for token_id, logprob, rank, token in zip( - logprob_token_ids, logprobs, ranks, decoded_tokens) + logprob_token_ids, logprobs, ranks, decoded_tokens + ) } def update_from_output(self, output: EngineCoreOutput) -> None: diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 46cb97d4e7..eb65b68969 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -8,19 +8,21 @@ from typing import Any, Optional, Union, cast import torch -from vllm.outputs import (CompletionOutput, PoolingOutput, - PoolingRequestOutput, RequestOutput) +from vllm.outputs import ( + CompletionOutput, + PoolingOutput, + PoolingRequestOutput, + RequestOutput, +) from vllm.sampling_params import RequestOutputKind -from vllm.tracing import (SpanAttributes, SpanKind, Tracer, - extract_trace_context) +from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.engine.parallel_sampling import ParentRequest -from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates, - RequestStateStats) +from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats class RequestOutputCollector: @@ -34,12 +36,14 @@ class RequestOutputCollector: def __init__(self, output_kind: RequestOutputKind): self.aggregate = output_kind == RequestOutputKind.DELTA - self.output: Optional[Union[RequestOutput, PoolingRequestOutput, - Exception]] = None + self.output: Optional[Union[RequestOutput, PoolingRequestOutput, Exception]] = ( + None + ) self.ready = asyncio.Event() - def put(self, output: Union[RequestOutput, 
PoolingRequestOutput, - Exception]) -> None: + def put( + self, output: Union[RequestOutput, PoolingRequestOutput, Exception] + ) -> None: """Non-blocking put operation.""" if self.output is None or isinstance(output, Exception): self.output = output @@ -59,8 +63,7 @@ class RequestOutputCollector: raise output return output - def get_nowait( - self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: + def get_nowait(self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: """Non-blocking get operation.""" output = self.output if output is not None: @@ -78,7 +81,6 @@ class OutputProcessorOutput: class RequestState: - def __init__( self, request_id: str, @@ -108,7 +110,8 @@ class RequestState: self.prompt_token_ids = prompt_token_ids self.prompt_embeds = prompt_embeds self.prompt_len = length_from_prompt_token_ids_or_embeds( - self.prompt_token_ids, self.prompt_embeds) + self.prompt_token_ids, self.prompt_embeds + ) self.logprobs_processor = logprobs_processor self.detokenizer = detokenizer self.max_tokens_param = max_tokens_param @@ -119,8 +122,7 @@ class RequestState: self.queue = queue self.num_cached_tokens = 0 - self.stats = RequestStateStats( - arrival_time=arrival_time) if log_stats else None + self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None @classmethod def from_new_request( @@ -133,7 +135,6 @@ class RequestState: queue: Optional[RequestOutputCollector], log_stats: bool, ) -> "RequestState": - if sampling_params := request.sampling_params: if not sampling_params.detokenize: tokenizer = None @@ -164,8 +165,9 @@ class RequestState: request_id=request.request_id, parent_req=parent_req, request_index=request_index, - lora_name=(request.lora_request.name - if request.lora_request is not None else None), + lora_name=( + request.lora_request.name if request.lora_request is not None else None + ), output_kind=output_kind, prompt=prompt, prompt_token_ids=request.prompt_token_ids, @@ -189,7 +191,6 @@ class RequestState: stop_reason: Union[int, str, None], kv_transfer_params: Optional[dict[str, Any]] = None, ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: - finished = finish_reason is not None final_only = self.output_kind == RequestOutputKind.FINAL_ONLY @@ -200,22 +201,23 @@ class RequestState: request_id = self.request_id if pooling_output is not None: return self._new_request_output( - request_id, [self._new_pooling_output(pooling_output)], - finished) + request_id, [self._new_pooling_output(pooling_output)], finished + ) - output = self._new_completion_output(new_token_ids, finish_reason, - stop_reason) + output = self._new_completion_output(new_token_ids, finish_reason, stop_reason) if self.parent_req is None: outputs = [output] else: request_id, outputs, finished = self.parent_req.get_outputs( - request_id, output) + request_id, output + ) if not outputs: return None - return self._new_request_output(request_id, outputs, finished, - kv_transfer_params) + return self._new_request_output( + request_id, outputs, finished, kv_transfer_params + ) def _new_request_output( self, @@ -224,7 +226,6 @@ class RequestState: finished: bool, kv_transfer_params: Optional[dict[str, Any]] = None, ) -> Union[RequestOutput, PoolingRequestOutput]: - first_output = outputs[0] if isinstance(first_output, PoolingOutput): assert len(outputs) == 1 @@ -248,15 +249,17 @@ class RequestState: if prompt_token_ids is None and self.prompt_embeds is not None: prompt_token_ids = [0] * len(self.prompt_embeds) - return RequestOutput(request_id=request_id, - 
prompt=self.prompt, - prompt_token_ids=prompt_token_ids, - prompt_logprobs=prompt_logprobs, - outputs=cast(list[CompletionOutput], outputs), - finished=finished, - kv_transfer_params=kv_transfer_params, - num_cached_tokens=self.num_cached_tokens, - metrics=self.stats) + return RequestOutput( + request_id=request_id, + prompt=self.prompt, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=prompt_logprobs, + outputs=cast(list[CompletionOutput], outputs), + finished=finished, + kv_transfer_params=kv_transfer_params, + num_cached_tokens=self.num_cached_tokens, + metrics=self.stats, + ) def _new_completion_output( self, @@ -264,7 +267,6 @@ class RequestState: finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], ) -> CompletionOutput: - assert self.detokenizer is not None assert self.logprobs_processor is not None finished = finish_reason is not None @@ -278,7 +280,7 @@ class RequestState: # Prepare logprobs, based on delta mode logprobs = self.logprobs_processor.logprobs if delta and logprobs: - logprobs = logprobs[-len(token_ids):] + logprobs = logprobs[-len(token_ids) :] return CompletionOutput( index=self.request_index, @@ -287,13 +289,13 @@ class RequestState: logprobs=logprobs, cumulative_logprob=self.logprobs_processor.cumulative_logprob, finish_reason=str(finish_reason) if finished else None, - stop_reason=stop_reason if finished else None) + stop_reason=stop_reason if finished else None, + ) def _new_pooling_output( self, pooling_output: torch.Tensor, ) -> PoolingOutput: - return PoolingOutput(data=pooling_output) @@ -333,15 +335,18 @@ class OutputProcessor: request_ids_to_abort.append(request_id) # Produce final abort output. if req_state.queue is not None and ( - request_output := req_state.make_request_output( - new_token_ids=[], - # Set pooling_output is not None to - # correctly enter the abort pooling branch - pooling_output=torch.randn(0, device="cpu") - if req_state.detokenizer is None else None, - finish_reason=FinishReason.ABORT, - stop_reason=None, - kv_transfer_params=None)): + request_output := req_state.make_request_output( + new_token_ids=[], + # Set pooling_output is not None to + # correctly enter the abort pooling branch + pooling_output=torch.randn(0, device="cpu") + if req_state.detokenizer is None + else None, + finish_reason=FinishReason.ABORT, + stop_reason=None, + kv_transfer_params=None, + ) + ): req_state.queue.put(request_output) elif parent := self.parent_requests.get(request_id): # Abort children prior to removing the parent. @@ -364,13 +369,15 @@ class OutputProcessor: if request_id in self.request_states: raise ValueError(f"Request id {request_id} already running.") - req_state = RequestState.from_new_request(tokenizer=self.tokenizer, - request=request, - prompt=prompt, - parent_req=parent_req, - request_index=request_index, - queue=queue, - log_stats=self.log_stats) + req_state = RequestState.from_new_request( + tokenizer=self.tokenizer, + request=request, + prompt=prompt, + parent_req=parent_req, + request_index=request_index, + queue=queue, + log_stats=self.log_stats, + ) self.request_states[request_id] = req_state self.lora_states.add_request(req_state) if parent_req: @@ -404,8 +411,7 @@ class OutputProcessor: within the loop below. 
""" - request_outputs: Union[list[RequestOutput], - list[PoolingRequestOutput]] = [] + request_outputs: Union[list[RequestOutput], list[PoolingRequestOutput]] = [] reqs_to_abort: list[str] = [] for engine_core_output in engine_core_outputs: req_id = engine_core_output.request_id @@ -415,9 +421,9 @@ class OutputProcessor: continue # 1) Compute stats for this iteration. - self._update_stats_from_output(req_state, engine_core_output, - engine_core_timestamp, - iteration_stats) + self._update_stats_from_output( + req_state, engine_core_output, engine_core_timestamp, iteration_stats + ) new_token_ids = engine_core_output.new_token_ids pooling_output = engine_core_output.pooling_output @@ -432,20 +438,24 @@ class OutputProcessor: assert req_state.logprobs_processor is not None # 2) Detokenize the token ids into text and perform stop checks. stop_string = req_state.detokenizer.update( - new_token_ids, finish_reason == FinishReason.STOP) + new_token_ids, finish_reason == FinishReason.STOP + ) if stop_string: finish_reason = FinishReason.STOP stop_reason = stop_string # 3) Compute sample and prompt logprobs for request, # if required. - req_state.logprobs_processor.update_from_output( - engine_core_output) + req_state.logprobs_processor.update_from_output(engine_core_output) # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( - new_token_ids, pooling_output, finish_reason, stop_reason, - kv_transfer_params): + new_token_ids, + pooling_output, + finish_reason, + stop_reason, + kv_transfer_params, + ): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put(request_output) @@ -466,11 +476,11 @@ class OutputProcessor: reqs_to_abort.append(req_id) # Track per-request stats - self._update_stats_from_finished(req_state, finish_reason, - iteration_stats) + self._update_stats_from_finished( + req_state, finish_reason, iteration_stats + ) if self.tracer: - self.do_tracing(engine_core_output, req_state, - iteration_stats) + self.do_tracing(engine_core_output, req_state, iteration_stats) self.lora_states.update_iteration_stats(iteration_stats) return OutputProcessorOutput( @@ -478,9 +488,12 @@ class OutputProcessor: reqs_to_abort=reqs_to_abort, ) - def do_tracing(self, engine_core_output: EngineCoreOutput, - req_state: RequestState, - iteration_stats: Optional[IterationStats]) -> None: + def do_tracing( + self, + engine_core_output: EngineCoreOutput, + req_state: RequestState, + iteration_stats: Optional[IterationStats], + ) -> None: assert req_state.stats is not None assert iteration_stats is not None assert self.tracer is not None @@ -488,59 +501,63 @@ class OutputProcessor: arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9) trace_context = extract_trace_context(engine_core_output.trace_headers) prompt_length = length_from_prompt_token_ids_or_embeds( - req_state.prompt_token_ids, req_state.prompt_embeds) - with (self.tracer.start_as_current_span( - "llm_request", - kind=SpanKind.SERVER, - context=trace_context, - start_time=arrival_time_nano_seconds) as span): + req_state.prompt_token_ids, req_state.prompt_embeds + ) + with self.tracer.start_as_current_span( + "llm_request", + kind=SpanKind.SERVER, + context=trace_context, + start_time=arrival_time_nano_seconds, + ) as span: metrics = req_state.stats - e2e_time = iteration_stats.iteration_timestamp - \ - metrics.arrival_time + e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time queued_time = metrics.scheduled_ts - 
metrics.queued_ts prefill_time = metrics.first_token_ts - metrics.scheduled_ts decode_time = metrics.last_token_ts - metrics.first_token_ts inference_time = metrics.last_token_ts - metrics.scheduled_ts span.set_attribute( SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, - metrics.first_token_latency) + metrics.first_token_latency, + ) span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time) - span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, - queued_time) - span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, - prompt_length) - span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, - metrics.num_generation_tokens) + span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time) + span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length) span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, - prefill_time) + SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, + metrics.num_generation_tokens, + ) span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, - decode_time) + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time + ) span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, - inference_time) + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time + ) + span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time + ) # meta - span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, - req_state.request_id) + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id) if req_state.top_p: - span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, - req_state.top_p) + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p) if req_state.max_tokens_param: - span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, - req_state.max_tokens_param) + span.set_attribute( + SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param + ) if req_state.temperature: - span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, - req_state.temperature) + span.set_attribute( + SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature + ) if req_state.n: - span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, - req_state.n) + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n) - def _update_stats_from_output(self, req_state: RequestState, - engine_core_output: EngineCoreOutput, - engine_core_timestamp: Optional[float], - iteration_stats: Optional[IterationStats]): + def _update_stats_from_output( + self, + req_state: RequestState, + engine_core_output: EngineCoreOutput, + engine_core_timestamp: Optional[float], + iteration_stats: Optional[IterationStats], + ): if iteration_stats is None: return @@ -548,15 +565,21 @@ class OutputProcessor: assert engine_core_timestamp is not None assert req_state.stats is not None - iteration_stats.update_from_output(engine_core_output, - engine_core_timestamp, - req_state.is_prefilling, - req_state.prompt_len, - req_state.stats, lora_stats) + iteration_stats.update_from_output( + engine_core_output, + engine_core_timestamp, + req_state.is_prefilling, + req_state.prompt_len, + req_state.stats, + lora_stats, + ) - def _update_stats_from_finished(self, req_state: RequestState, - finish_reason: Optional[FinishReason], - iteration_stats: Optional[IterationStats]): + def _update_stats_from_finished( + self, + req_state: RequestState, + finish_reason: Optional[FinishReason], + iteration_stats: Optional[IterationStats], + ): if 
iteration_stats is None: return @@ -565,11 +588,13 @@ class OutputProcessor: iteration_stats.update_from_finished_request( finish_reason=finish_reason, num_prompt_tokens=length_from_prompt_token_ids_or_embeds( - req_state.prompt_token_ids, req_state.prompt_embeds), + req_state.prompt_token_ids, req_state.prompt_embeds + ), max_tokens_param=req_state.max_tokens_param, - req_stats=req_state.stats) + req_stats=req_state.stats, + ) self.lora_states.finish_request(req_state) ParentRequest.observe_finished_request( - req_state.parent_req, iteration_stats, - req_state.stats.num_generation_tokens) + req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens + ) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 1e9911152c..daf115c032 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -31,15 +31,16 @@ class ParentRequest: # To efficiently obtain child sampling params cached_child_sampling_params: Optional[SamplingParams] - def __init__(self, request_id: str, - sampling_params: SamplingParams) -> None: + def __init__(self, request_id: str, sampling_params: SamplingParams) -> None: self.request_id = request_id self.sampling_params = sampling_params self.child_requests = set() - self.output_aggregator = [None] * sampling_params.n if ( - sampling_params.output_kind - == RequestOutputKind.FINAL_ONLY) else [] + self.output_aggregator = ( + [None] * sampling_params.n + if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY) + else [] + ) self.max_num_generation_tokens = 0 self.cached_child_sampling_params = None @@ -49,7 +50,7 @@ class ParentRequest: ) -> SamplingParams: """Efficiently obtain child `sampling_params` - If `sampling_params.seed` is not `None` then + If `sampling_params.seed` is not `None` then each child request requires a unique clone of parent `sampling_params` with a unique seed. @@ -76,10 +77,10 @@ class ParentRequest: def get_child_info(self, index: int) -> tuple[str, SamplingParams]: """Get child request ID and sampling params. - + Args: index: index within `n` child requests. 
- + Returns: (request ID, sampling_params) tuple """ @@ -111,23 +112,25 @@ class ParentRequest: return self.request_id, outputs, finished def observe_num_generation_tokens(self, num_generation_tokens: int): - self.max_num_generation_tokens = max(num_generation_tokens, - self.max_num_generation_tokens) + self.max_num_generation_tokens = max( + num_generation_tokens, self.max_num_generation_tokens + ) return self.max_num_generation_tokens @staticmethod - def observe_finished_request(parent_req: Optional['ParentRequest'], - iteration_stats: IterationStats, - num_generation_tokens: int): - + def observe_finished_request( + parent_req: Optional["ParentRequest"], + iteration_stats: IterationStats, + num_generation_tokens: int, + ): n_param = parent_req.n if parent_req is not None else 1 if parent_req is not None: num_generation_tokens = parent_req.observe_num_generation_tokens( - num_generation_tokens) + num_generation_tokens + ) # Child requests finished, we can now record to iteration stats if parent_req is None or not parent_req.child_requests: - iteration_stats.max_num_generation_tokens_iter.append( - num_generation_tokens) + iteration_stats.max_num_generation_tokens_iter.append(num_generation_tokens) iteration_stats.n_params_iter.append(n_param) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index c30ceb96a5..8a6ac0927e 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -21,27 +21,25 @@ from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreRequest -from vllm.v1.structured_output.backend_guidance import ( - validate_guidance_grammar) +from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar from vllm.v1.structured_output.backend_lm_format_enforcer import ( - validate_structured_output_request_lm_format_enforcer) + validate_structured_output_request_lm_format_enforcer, +) from vllm.v1.structured_output.backend_outlines import ( - validate_structured_output_request_outlines) -from vllm.v1.structured_output.backend_xgrammar import ( - validate_xgrammar_grammar) + validate_structured_output_request_outlines, +) +from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar logger = init_logger(__name__) class Processor: - def __init__( self, vllm_config: VllmConfig, tokenizer: AnyTokenizer, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): - self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -49,12 +47,10 @@ class Processor: self.structured_outputs_config = vllm_config.structured_outputs_config self.tokenizer = tokenizer - self.generation_config_fields = ( - self.model_config.try_get_generation_config()) + self.generation_config_fields = self.model_config.try_get_generation_config() self.mm_registry = mm_registry - self.mm_processor_cache = processor_cache_from_config( - vllm_config, mm_registry) + self.mm_processor_cache = processor_cache_from_config(vllm_config, mm_registry) self.input_preprocessor = InputPreprocessor( self.model_config, @@ -79,7 +75,8 @@ class Processor: if num_logprobs > max_logprobs: raise ValueError( f"Requested sample logprobs of {num_logprobs}, " - f"which is greater than max allowed: {max_logprobs}") + f"which is greater than max allowed: {max_logprobs}" + ) # Validate prompt logprobs. 
if params.prompt_logprobs: @@ -89,7 +86,8 @@ class Processor: if num_prompt_logprobs > max_logprobs: raise ValueError( f"Requested prompt logprobs of {num_prompt_logprobs}, " - f"which is greater than max allowed: {max_logprobs}") + f"which is greater than max allowed: {max_logprobs}" + ) def _validate_sampling_params( self, @@ -108,8 +106,7 @@ class Processor: return vocab_size = len(self.tokenizer) if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids): - raise ValueError( - "allowed_token_ids contains out-of-vocab token id!") + raise ValueError("allowed_token_ids contains out-of-vocab token id!") def _validate_logit_bias( self, @@ -129,7 +126,8 @@ class Processor: if invalid_token_ids: raise ValueError( f"token_id(s) {invalid_token_ids} in logit_bias contain " - f"out-of-vocab token ids. Vocabulary size: {vocab_size}") + f"out-of-vocab token ids. Vocabulary size: {vocab_size}" + ) def _validate_supported_sampling_params( self, @@ -140,8 +138,9 @@ class Processor: raise ValueError("vLLM V1 does not yet support best_of.") # Logits processors not supported. if params.logits_processors: - raise ValueError("vLLM V1 does not support per request " - "user provided logits processors.") + raise ValueError( + "vLLM V1 does not support per request user provided logits processors." + ) def _validate_params( self, @@ -178,18 +177,23 @@ class Processor: for modality, items in mm_data.items(): if modality in mm_uuids: data_len = len(items) if isinstance(items, list) else 1 - uuid_len = len(mm_uuids[modality]) if isinstance( - mm_uuids[modality], list) else 1 + uuid_len = ( + len(mm_uuids[modality]) + if isinstance(mm_uuids[modality], list) + else 1 + ) if uuid_len != data_len: raise ValueError( f"multi_modal_uuids for modality '{modality}' " "must have same length as data: got " f"{uuid_len} uuids vs " - f"{data_len} items.") + f"{data_len} items." + ) else: raise ValueError( f"multi_modal_uuids for modality '{modality}' must " - "be provided if multi_modal_data is provided.") + "be provided if multi_modal_data is provided." + ) # Handle explicit encoder/decoder prompts or singleton prompt if isinstance(prompt, dict) and "encoder_prompt" in prompt: @@ -208,8 +212,9 @@ class Processor: # LoRA request passed in while LoRA is not enabled if not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") + raise ValueError( + f"Got lora_request {lora_request} but LoRA is not enabled!" + ) if self.tokenizer is not None: logger.warning_once( @@ -217,7 +222,8 @@ class Processor: "tokenizers for different LoRAs. By default, vLLM uses base " "model's tokenizer. If you are using a LoRA " "with its own tokenizer, consider specifying `--tokenizer " - "[lora_path]` to use the LoRA tokenizer.") + "[lora_path]` to use the LoRA tokenizer." + ) def _validate_structured_output(self, params: SamplingParams) -> None: if not params.structured_outputs or not self.structured_outputs_config: @@ -235,20 +241,23 @@ class Processor: # to a specific backend based on `auto` behavior in a previous # request. We remember that it was set as a result of `auto` # using the `_backend_was_auto` field set in the params. - if (backend != _backend - and not (backend == "auto" - and params.structured_outputs._backend_was_auto)): + if backend != _backend and not ( + backend == "auto" and params.structured_outputs._backend_was_auto + ): raise ValueError( "Request-level structured output backend selection is not " f"supported. 
The request specified '{_backend}', but vLLM " f"was initialised with '{backend}'. This error can be " - "resolved by removing '_backend' from the request.") + "resolved by removing '_backend' from the request." + ) else: params.structured_outputs._backend = backend # Request content validation - if (isinstance(params.structured_outputs.choice, list) - and not params.structured_outputs.choice): + if ( + isinstance(params.structured_outputs.choice, list) + and not params.structured_outputs.choice + ): # It is invalid for choice to be an empty list raise ValueError( f"Choice '{params.structured_outputs.choice}' cannot be an empty list" # noqa: E501 @@ -318,9 +327,7 @@ class Processor: mm_uuids: MultiModalUUIDDict = {} for modality, data in mm_data.items(): n = len(data) if isinstance(data, list) else 1 - mm_uuids[modality] = [ - f"{request_id}-{modality}-{i}" for i in range(n) - ] + mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)] return mm_uuids def process_inputs( @@ -339,10 +346,13 @@ class Processor: self._validate_params(params) data_parallel_size = self.vllm_config.parallel_config.data_parallel_size - if data_parallel_rank is not None and not (0 <= data_parallel_rank < - data_parallel_size): - raise ValueError(f"data_parallel_rank {data_parallel_rank} " - f"is out of range [0, {data_parallel_size}).") + if data_parallel_rank is not None and not ( + 0 <= data_parallel_rank < data_parallel_size + ): + raise ValueError( + f"data_parallel_rank {data_parallel_rank} " + f"is out of range [0, {data_parallel_size})." + ) if arrival_time is None: arrival_time = time.time() @@ -355,9 +365,11 @@ class Processor: # reused across requests, therefore identifying multimodal data items # by their content is no longer necessary, and we create uuids with # request id-modality-index as multimodal hash overrides. - if (self.model_config.multimodal_config and - self.model_config.multimodal_config.mm_processor_cache_gb == 0 - and not self.cache_config.enable_prefix_caching): + if ( + self.model_config.multimodal_config + and self.model_config.multimodal_config.mm_processor_cache_gb == 0 + and not self.cache_config.enable_prefix_caching + ): mm_uuids = self._maybe_build_mm_uuids(request_id, prompt) else: # Otherwise, use user-provided uuids as multimodal hash overrides @@ -378,6 +390,7 @@ class Processor: mm_uuids=mm_uuids, ) from vllm.platforms import current_platform + current_platform.validate_request( prompt=prompt, params=params, @@ -393,10 +406,16 @@ class Processor: # discriminated unions of TypedDicts, because of how it handles # inheritance of TypedDict. If we explicitly extract the items we want # we can avoid type errors from using `dict.get` later in the method. - prompt_token_ids = decoder_inputs[ - "prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None - prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[ - "type"] == "embeds" else None + prompt_token_ids = ( + decoder_inputs["prompt_token_ids"] + if decoder_inputs["type"] != "embeds" + else None + ) + prompt_embeds = ( + decoder_inputs["prompt_embeds"] + if decoder_inputs["type"] == "embeds" + else None + ) sampling_params = None pooling_params = None @@ -406,11 +425,12 @@ class Processor: # If unset max tokens, then generate up to the max_model_len. 
if sampling_params.max_tokens is None: seq_len = length_from_prompt_token_ids_or_embeds( - prompt_token_ids, prompt_embeds) - sampling_params.max_tokens = \ - self.model_config.max_model_len - seq_len + prompt_token_ids, prompt_embeds + ) + sampling_params.max_tokens = self.model_config.max_model_len - seq_len sampling_params.update_from_generation_config( - self.generation_config_fields, eos_token_id) + self.generation_config_fields, eos_token_id + ) if self.tokenizer is not None: sampling_params.update_from_tokenizer(self.tokenizer) else: @@ -436,7 +456,9 @@ class Processor: data=decoder_mm_inputs[modality][idx], modality=modality, identifier=decoder_mm_hashes[modality][idx], - mm_position=decoder_mm_positions[modality][idx])) + mm_position=decoder_mm_positions[modality][idx], + ) + ) return EngineCoreRequest( request_id=request_id, @@ -454,8 +476,9 @@ class Processor: trace_headers=trace_headers, ) - def _validate_model_inputs(self, encoder_inputs: Optional[SingletonInputs], - decoder_inputs: SingletonInputs): + def _validate_model_inputs( + self, encoder_inputs: Optional[SingletonInputs], decoder_inputs: SingletonInputs + ): if encoder_inputs is not None: self._validate_model_input(encoder_inputs, prompt_type="encoder") @@ -469,12 +492,17 @@ class Processor: ): model_config = self.model_config - prompt_ids = None if prompt_inputs[ - "type"] == "embeds" else prompt_inputs["prompt_token_ids"] - prompt_embeds = prompt_inputs["prompt_embeds"] if prompt_inputs[ - "type"] == "embeds" else None - prompt_len = length_from_prompt_token_ids_or_embeds( - prompt_ids, prompt_embeds) + prompt_ids = ( + None + if prompt_inputs["type"] == "embeds" + else prompt_inputs["prompt_token_ids"] + ) + prompt_embeds = ( + prompt_inputs["prompt_embeds"] + if prompt_inputs["type"] == "embeds" + else None + ) + prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds) if not prompt_ids: if prompt_type == "encoder" and model_config.is_multimodal_model: pass # Mllama may have empty encoder inputs for text-only data @@ -499,10 +527,10 @@ class Processor: # Here we take the max of the two to determine if a token id is # truly out-of-vocabulary. - if max_input_id > max(tokenizer.max_token_id, - self.model_config.get_vocab_size() - 1): - raise ValueError( - f"Token id {max_input_id} is out of vocabulary") + if max_input_id > max( + tokenizer.max_token_id, self.model_config.get_vocab_size() - 1 + ): + raise ValueError(f"Token id {max_input_id} is out of vocabulary") max_prompt_len = self.model_config.max_model_len if prompt_len > max_prompt_len: @@ -522,16 +550,19 @@ class Processor: "Make sure that `max_model_len` is no smaller than the " "number of text tokens plus multimodal tokens. For image " "inputs, the number of image tokens depends on the number " - "of images, and possibly their aspect ratios as well.") + "of images, and possibly their aspect ratios as well." + ) else: suggestion = ( "Make sure that `max_model_len` is no smaller than the " - "number of text tokens.") + "number of text tokens." + ) raise ValueError( f"The {prompt_type} prompt (length {prompt_len}) is " f"longer than the maximum model length of {max_prompt_len}. 
" - f"{suggestion}") + f"{suggestion}" + ) # TODO: Find out how many placeholder tokens are there so we can # check that chunked prefill does not truncate them diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 18ef25ceb6..c78d71c323 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -70,6 +70,7 @@ class EngineHandshakeMetadata: including addresses of the front-end ZMQ queues that they should connect to. """ + addresses: EngineZmqAddresses parallel_config: dict[str, Union[int, str, list[int]]] @@ -103,8 +104,7 @@ class CoreEngineProcManager: } if client_handshake_address: - common_kwargs[ - "client_handshake_address"] = client_handshake_address + common_kwargs["client_handshake_address"] = client_handshake_address self.processes: list[BaseProcess] = [] local_dp_ranks = [] @@ -115,21 +115,27 @@ class CoreEngineProcManager: # Start EngineCore in background process. local_dp_ranks.append(local_index) self.processes.append( - context.Process(target=target_fn, - name=f"EngineCore_DP{global_index}", - kwargs=common_kwargs | { - "dp_rank": global_index, - "local_dp_rank": local_index, - })) + context.Process( + target=target_fn, + name=f"EngineCore_DP{global_index}", + kwargs=common_kwargs + | { + "dp_rank": global_index, + "local_dp_rank": local_index, + }, + ) + ) self._finalizer = weakref.finalize(self, shutdown, self.processes) data_parallel = vllm_config.parallel_config.data_parallel_size > 1 try: for proc, local_dp_rank in zip(self.processes, local_dp_ranks): - with set_device_control_env_var( - vllm_config, local_dp_rank) if ( - data_parallel) else contextlib.nullcontext(): + with ( + set_device_control_env_var(vllm_config, local_dp_rank) + if (data_parallel) + else contextlib.nullcontext() + ): proc.start() finally: # Kill other procs if not all are running. @@ -151,13 +157,15 @@ class CoreEngineProcManager: """Returns dict of proc name -> exit code for any finished procs.""" return { proc.name: proc.exitcode - for proc in self.processes if proc.exitcode is not None + for proc in self.processes + if proc.exitcode is not None } @contextlib.contextmanager -def set_device_control_env_var(vllm_config: VllmConfig, - local_dp_rank: int) -> Iterator[None]: +def set_device_control_env_var( + vllm_config: VllmConfig, local_dp_rank: int +) -> Iterator[None]: """ Temporarily set CUDA_VISIBLE_DEVICES or equivalent for engine subprocess. @@ -166,12 +174,13 @@ def set_device_control_env_var(vllm_config: VllmConfig, evar = current_platform.device_control_env_var value = get_device_indices(evar, local_dp_rank, world_size) - with patch.dict(os.environ, values=((evar, value), )): + with patch.dict(os.environ, values=((evar, value),)): yield -def get_device_indices(device_control_env_var: str, local_dp_rank: int, - world_size: int): +def get_device_indices( + device_control_env_var: str, local_dp_rank: int, world_size: int +): """ Returns a comma-separated string of device indices for the specified data parallel rank. 
@@ -182,14 +191,16 @@ def get_device_indices(device_control_env_var: str, local_dp_rank: int, try: value = ",".join( str(current_platform.device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * - world_size)) + for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size) + ) except IndexError as e: - raise Exception(f"Error setting {device_control_env_var}: " - f"local range: [{local_dp_rank * world_size}, " - f"{(local_dp_rank + 1) * world_size}) " - "base value: " - f"\"{os.getenv(device_control_env_var)}\"") from e + raise Exception( + f"Error setting {device_control_env_var}: " + f"local range: [{local_dp_rank * world_size}, " + f"{(local_dp_rank + 1) * world_size}) " + "base value: " + f'"{os.getenv(device_control_env_var)}"' + ) from e return value @@ -215,8 +226,7 @@ class CoreEngineActorManager: import ray from ray.runtime_env import RuntimeEnv - from ray.util.scheduling_strategies import ( - PlacementGroupSchedulingStrategy) + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from vllm.v1.engine.core import DPEngineCoreActor @@ -225,8 +235,7 @@ class CoreEngineActorManager: env_vars_list = get_env_vars_to_copy(destination="DPEngineCoreActor") self.env_vars_dict = { - name: os.environ[name] - for name in env_vars_list if name in os.environ + name: os.environ[name] for name in env_vars_list if name in os.environ } runtime_env = RuntimeEnv(env_vars=self.env_vars_dict) @@ -234,37 +243,38 @@ class CoreEngineActorManager: self.executor_class = executor_class self.log_stats = log_stats dp_size = vllm_config.parallel_config.data_parallel_size - local_engine_count = \ - vllm_config.parallel_config.data_parallel_size_local + local_engine_count = vllm_config.parallel_config.data_parallel_size_local world_size = vllm_config.parallel_config.world_size if ray.is_initialized(): - logger.info( - "Ray is already initialized. Skipping Ray initialization.") + logger.info("Ray is already initialized. 
Skipping Ray initialization.") else: ray.init() if placement_groups is not None: assert local_dp_ranks is not None, ( - "local_dp_ranks must be provided if " - "placement_groups is provided") + "local_dp_ranks must be provided if placement_groups is provided" + ) assert len(placement_groups) == len(local_dp_ranks), ( - "placement_groups and local_dp_ranks must " - "have the same length") + "placement_groups and local_dp_ranks must have the same length" + ) logger.info("Using provided placement groups") # TODO(rui): validate passed-in placement groups self.created_placement_groups = [] else: - placement_groups, local_dp_ranks = \ + placement_groups, local_dp_ranks = ( CoreEngineActorManager.create_dp_placement_groups(vllm_config) + ) self.created_placement_groups = placement_groups assert len(placement_groups) == dp_size, ( - "Number of placement groups must match data parallel size") + "Number of placement groups must match data parallel size" + ) self.placement_group_is_local = [] refs = [] - for index, local_index, pg in zip(range(dp_size), local_dp_ranks, - placement_groups): + for index, local_index, pg in zip( + range(dp_size), local_dp_ranks, placement_groups + ): dp_vllm_config = copy.deepcopy(vllm_config) dp_vllm_config.parallel_config.placement_group = pg local_client = index < local_engine_count @@ -275,24 +285,32 @@ class CoreEngineActorManager: # https://github.com/ray-project/ray/blob/master/python/ray/_private/accelerators/intel_gpu.py#L56 # noqa: E501 if current_platform.is_xpu(): device_evar = current_platform.device_control_env_var - device_indices = get_device_indices(device_evar, local_index, - world_size) + device_indices = get_device_indices( + device_evar, local_index, world_size + ) actor_env_vars = self.env_vars_dict.copy() actor_env_vars[device_evar] = device_indices runtime_env = RuntimeEnv(env_vars=actor_env_vars) - actor = ray.remote(DPEngineCoreActor).options( - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_bundle_index=world_size, - ), - runtime_env=runtime_env).remote(vllm_config=dp_vllm_config, - executor_class=executor_class, - log_stats=log_stats, - local_client=local_client, - addresses=addresses, - dp_rank=index, - local_dp_rank=local_index) + actor = ( + ray.remote(DPEngineCoreActor) + .options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=world_size, + ), + runtime_env=runtime_env, + ) + .remote( + vllm_config=dp_vllm_config, + executor_class=executor_class, + log_stats=log_stats, + local_client=local_client, + addresses=addresses, + dp_rank=index, + local_dp_rank=local_index, + ) + ) if local_client: self.local_engine_actors.append(actor) else: @@ -307,7 +325,7 @@ class CoreEngineActorManager: @staticmethod def create_dp_placement_groups( - vllm_config: VllmConfig + vllm_config: VllmConfig, ) -> tuple[list["PlacementGroup"], list[int]]: """ Create placement groups for data parallel. 
@@ -317,23 +335,23 @@ class CoreEngineActorManager: from ray._private.state import available_resources_per_node logger.info("Creating placement groups for data parallel") - dp_master_ip = \ - vllm_config.parallel_config.data_parallel_master_ip + dp_master_ip = vllm_config.parallel_config.data_parallel_master_ip num_pg_to_create = vllm_config.parallel_config.data_parallel_size - local_engine_count = \ - vllm_config.parallel_config.data_parallel_size_local + local_engine_count = vllm_config.parallel_config.data_parallel_size_local available_resources = available_resources_per_node() world_size = vllm_config.parallel_config.world_size placement_groups: list[PlacementGroup] = [] local_dp_ranks: list[int] = [] - dp_master_ip_key = f'node:{dp_master_ip}' - nodes = sorted(available_resources.values(), - key=lambda x: dp_master_ip_key not in x) - assert len(nodes) > 0, ( - "No nodes with resources found in Ray cluster.") + dp_master_ip_key = f"node:{dp_master_ip}" + nodes = sorted( + available_resources.values(), key=lambda x: dp_master_ip_key not in x + ) + assert len(nodes) > 0, "No nodes with resources found in Ray cluster." assert dp_master_ip_key in nodes[0], ( - "The DP master node (ip: %s) is missing or dead", dp_master_ip) + "The DP master node (ip: %s) is missing or dead", + dp_master_ip, + ) device_str = current_platform.ray_device_key for node_resources in nodes: if device_str not in node_resources: @@ -341,19 +359,16 @@ class CoreEngineActorManager: # For now, each DP rank can only be assigned to one node # TODO(rui): support allocating a single DP rank # to multiple nodes - available_engine_count = int( - node_resources[device_str]) // world_size + available_engine_count = int(node_resources[device_str]) // world_size if dp_master_ip_key in node_resources: assert available_engine_count >= local_engine_count, ( "Not enough resources to allocate DP ranks " - f"on DP master node {dp_master_ip}") + f"on DP master node {dp_master_ip}" + ) for i in range(local_engine_count): - bundles = [{ - device_str: 1.0, - "node:" + dp_master_ip: 0.001 - }] * world_size + [{ - "CPU": 1.0 - }] + bundles = [ + {device_str: 1.0, "node:" + dp_master_ip: 0.001} + ] * world_size + [{"CPU": 1.0}] pg = ray.util.placement_group( name=f"dp_rank_{len(placement_groups)}", strategy="STRICT_PACK", @@ -379,7 +394,8 @@ class CoreEngineActorManager: "placement groups, only created " f"{len(placement_groups)} placement groups. " "Available resources: " - f"{available_resources}") + f"{available_resources}" + ) return placement_groups, local_dp_ranks @staticmethod @@ -390,8 +406,10 @@ class CoreEngineActorManager: Add placement groups for new data parallel size. 
""" import ray - from ray._private.state import (available_resources_per_node, - total_resources_per_node) + from ray._private.state import ( + available_resources_per_node, + total_resources_per_node, + ) from ray.util.state import list_nodes old_dp_size = old_vllm_config.parallel_config.data_parallel_size @@ -405,10 +423,10 @@ class CoreEngineActorManager: nodes = list_nodes() nodes = sorted(nodes, key=lambda node: node.node_ip != dp_master_ip) - assert nodes[0].node_ip == dp_master_ip, ( - "The first node must be the head node") + assert nodes[0].node_ip == dp_master_ip, "The first node must be the head node" assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, ( - "There can only be one head node") + "There can only be one head node" + ) available_resources = available_resources_per_node() total_resources = total_resources_per_node() @@ -446,12 +464,9 @@ class CoreEngineActorManager: # Create bundles with node constraint for master node if node_ip == dp_master_ip: - bundles = [{ - device_str: 1.0, - "node:" + dp_master_ip: 0.001 - }] * world_size + [{ - "CPU": 1.0 - }] + bundles = [ + {device_str: 1.0, "node:" + dp_master_ip: 0.001} + ] * world_size + [{"CPU": 1.0}] else: bundles = [{device_str: 1.0}] * world_size + [{"CPU": 1.0}] @@ -470,69 +485,76 @@ class CoreEngineActorManager: return placement_groups, local_dp_ranks - def scale_up_elastic_ep(self, cur_vllm_config: VllmConfig, - new_data_parallel_size: int) -> None: + def scale_up_elastic_ep( + self, cur_vllm_config: VllmConfig, new_data_parallel_size: int + ) -> None: import copy import ray from ray.runtime_env import RuntimeEnv - from ray.util.scheduling_strategies import ( - PlacementGroupSchedulingStrategy) + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from vllm.v1.engine.core import DPEngineCoreActor - cur_data_parallel_size = len(self.local_engine_actors) + \ - len(self.remote_engine_actors) + cur_data_parallel_size = len(self.local_engine_actors) + len( + self.remote_engine_actors + ) assert new_data_parallel_size > cur_data_parallel_size, ( f"New data parallel size {new_data_parallel_size} must be greater " f"than current data parallel size {cur_data_parallel_size} " - "for scale up") + "for scale up" + ) - placement_groups, local_dp_ranks = \ - self.add_dp_placement_groups( - cur_vllm_config, new_data_parallel_size) + placement_groups, local_dp_ranks = self.add_dp_placement_groups( + cur_vllm_config, new_data_parallel_size + ) world_size = cur_vllm_config.parallel_config.world_size dp_master_ip = cur_vllm_config.parallel_config.data_parallel_master_ip new_local_engines = 0 - runtime_env = RuntimeEnv(env_vars=self.env_vars_dict - | {"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": "1"}) - for i, (pg, - local_rank) in enumerate(zip(placement_groups, - local_dp_ranks)): + runtime_env = RuntimeEnv( + env_vars=self.env_vars_dict | {"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": "1"} + ) + for i, (pg, local_rank) in enumerate(zip(placement_groups, local_dp_ranks)): rank = cur_data_parallel_size + i dp_vllm_config = copy.deepcopy(cur_vllm_config) - dp_vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + dp_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size dp_vllm_config.parallel_config.placement_group = pg # Check if this placement group is on the head node local_client = any( - bundle.get("node:" + dp_master_ip, 0) > 0 - for bundle in pg.bundle_specs) + bundle.get("node:" + dp_master_ip, 0) > 0 for bundle in pg.bundle_specs + ) if local_client: new_local_engines += 1 # 
Update data_parallel_size_local dp_vllm_config.parallel_config.data_parallel_size_local = ( - cur_vllm_config.parallel_config.data_parallel_size_local + - new_local_engines) + cur_vllm_config.parallel_config.data_parallel_size_local + + new_local_engines + ) - actor = ray.remote(DPEngineCoreActor).options( - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_bundle_index=world_size, - ), - runtime_env=runtime_env).remote( + actor = ( + ray.remote(DPEngineCoreActor) + .options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=world_size, + ), + runtime_env=runtime_env, + ) + .remote( vllm_config=dp_vllm_config, executor_class=self.executor_class, log_stats=self.log_stats, local_client=local_client, addresses=self.addresses, dp_rank=rank, - local_dp_rank=local_rank) + local_dp_rank=local_rank, + ) + ) if local_client: self.local_engine_actors.append(actor) @@ -541,37 +563,47 @@ class CoreEngineActorManager: self.created_placement_groups.append(pg) self.placement_group_is_local.append(local_client) - ray.get([ - actor.wait_for_init.remote() - for actor in (self.local_engine_actors[-new_local_engines:] - if new_local_engines > 0 else []) + - self.remote_engine_actors[-(len(placement_groups) - - new_local_engines):] - ]) + ray.get( + [ + actor.wait_for_init.remote() + for actor in ( + self.local_engine_actors[-new_local_engines:] + if new_local_engines > 0 + else [] + ) + + self.remote_engine_actors[ + -(len(placement_groups) - new_local_engines) : + ] + ] + ) - actors = (self.local_engine_actors[-new_local_engines:] - if new_local_engines > 0 else []) + \ - self.remote_engine_actors[-(len(placement_groups) - - new_local_engines):] + actors = ( + self.local_engine_actors[-new_local_engines:] + if new_local_engines > 0 + else [] + ) + self.remote_engine_actors[-(len(placement_groups) - new_local_engines) :] for actor in actors: self.run_refs.append(actor.run.remote()) - cur_vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + cur_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size # Update old_vllm_config with new data_parallel_size_local if any new # local engines were added if new_local_engines > 0: - cur_vllm_config.parallel_config.data_parallel_size_local += \ + cur_vllm_config.parallel_config.data_parallel_size_local += ( new_local_engines + ) - def scale_down_elastic_ep(self, cur_data_parallel_size: int, - new_data_parallel_size: int) -> None: + def scale_down_elastic_ep( + self, cur_data_parallel_size: int, new_data_parallel_size: int + ) -> None: import ray + assert cur_data_parallel_size > new_data_parallel_size, ( f"cur_data_parallel_size {cur_data_parallel_size} must be greater " f"than new_data_parallel_size {new_data_parallel_size} " - "for scale down") + "for scale down" + ) for _ in range(cur_data_parallel_size - new_data_parallel_size): pg = self.created_placement_groups.pop() is_local = self.placement_group_is_local.pop() @@ -586,6 +618,7 @@ class CoreEngineActorManager: def close(self): import ray + for actor in self.local_engine_actors + self.remote_engine_actors: ray.kill(actor) for pg in self.created_placement_groups: @@ -598,11 +631,13 @@ def launch_core_engines( executor_class: type[Executor], log_stats: bool, num_api_servers: int = 1, -) -> Iterator[tuple[ +) -> Iterator[ + tuple[ Optional[Union[CoreEngineProcManager, CoreEngineActorManager]], Optional[DPCoordinator], EngineZmqAddresses, -]]: + ] +]: """Launch engine and DP 
coordinator processes as needed.""" parallel_config = vllm_config.parallel_config @@ -611,8 +646,10 @@ def launch_core_engines( local_start_index = parallel_config.data_parallel_rank_local dp_rank = parallel_config.data_parallel_rank host = parallel_config.data_parallel_master_ip - local_engines_only = (parallel_config.data_parallel_hybrid_lb - or parallel_config.data_parallel_external_lb) + local_engines_only = ( + parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb + ) # In offline mode there is an LLM instance per DP rank and # one core engine per LLM, see @@ -621,8 +658,9 @@ def launch_core_engines( # client_local_only = True for cases where this front-end # sends requests only to colocated engines. - client_local_only = (offline_mode or local_engines_only - or (local_engine_count == dp_size)) + client_local_only = ( + offline_mode or local_engines_only or (local_engine_count == dp_size) + ) # Set up input and output addresses. addresses = EngineZmqAddresses( @@ -644,12 +682,13 @@ def launch_core_engines( coordinator = DPCoordinator(parallel_config) addresses.coordinator_input, addresses.coordinator_output = ( - coordinator.get_engine_socket_addresses()) + coordinator.get_engine_socket_addresses() + ) addresses.frontend_stats_publish_address = ( - coordinator.get_stats_publish_address()) + coordinator.get_stats_publish_address() + ) - logger.info("Started DP Coordinator process (PID: %d)", - coordinator.proc.pid) + logger.info("Started DP Coordinator process (PID: %d)", coordinator.proc.pid) else: coordinator = None @@ -675,14 +714,14 @@ def launch_core_engines( # Note this also covers the case where we have zero local engines # and rank 0 is headless. engines_to_handshake = [ - CoreEngine(index=i, local=(i < local_engine_count)) - for i in range(dp_size) + CoreEngine(index=i, local=(i < local_engine_count)) for i in range(dp_size) ] else: # Rank > 0 handshakes with just the local cores it is managing. assert local_engines_only, ( "Attempting to launch core_engines from dp_rank > 0, but " - "found internal DPLB, which is incompatible.") + "found internal DPLB, which is incompatible." + ) engines_to_handshake = [ CoreEngine(index=i, local=True) for i in range(dp_rank, dp_rank + local_engine_count) @@ -695,7 +734,8 @@ def launch_core_engines( handshake_local_only = offline_mode or local_engine_count == dp_size handshake_address = get_engine_client_zmq_addr( - handshake_local_only, host, parallel_config.data_parallel_rpc_port) + handshake_local_only, host, parallel_config.data_parallel_rpc_port + ) if local_engines_only and dp_rank > 0: assert not handshake_local_only @@ -705,9 +745,9 @@ def launch_core_engines( local_handshake_address = handshake_address client_handshake_address = None - with zmq_socket_ctx(local_handshake_address, zmq.ROUTER, - bind=True) as handshake_socket: - + with zmq_socket_ctx( + local_handshake_address, zmq.ROUTER, bind=True + ) as handshake_socket: from vllm.v1.engine.core import EngineCoreProc # Start local engines. 
@@ -722,7 +762,8 @@ def launch_core_engines( local_client=True, local_engine_count=local_engine_count, start_index=dp_rank, - local_start_index=local_start_index or 0) + local_start_index=local_start_index or 0, + ) else: local_engine_manager = None @@ -757,8 +798,10 @@ def wait_for_engine_startup( poller = zmq.Poller() poller.register(handshake_socket, zmq.POLLIN) - remote_should_be_headless = not parallel_config.data_parallel_hybrid_lb \ + remote_should_be_headless = ( + not parallel_config.data_parallel_hybrid_lb and not parallel_config.data_parallel_external_lb + ) if proc_manager is not None: for sentinel in proc_manager.sentinels(): @@ -770,67 +813,73 @@ def wait_for_engine_startup( if not events: if any(conn_pending): logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to connect.", *conn_pending) + "Waiting for %d local, %d remote core engine proc(s) to connect.", + *conn_pending, + ) if any(start_pending): logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to start.", *start_pending) + "Waiting for %d local, %d remote core engine proc(s) to start.", + *start_pending, + ) continue if len(events) > 1 or events[0][0] != handshake_socket: # One of the local core processes exited. finished = proc_manager.finished_procs() if proc_manager else {} if coord_process is not None and coord_process.exitcode is not None: finished[coord_process.name] = coord_process.exitcode - raise RuntimeError("Engine core initialization failed. " - "See root cause above. " - f"Failed core proc(s): {finished}") + raise RuntimeError( + "Engine core initialization failed. " + "See root cause above. " + f"Failed core proc(s): {finished}" + ) # Receive HELLO and READY messages from the input socket. eng_identity, ready_msg_bytes = handshake_socket.recv_multipart() eng_index = int.from_bytes(eng_identity, "little") - engine = next((e for e in core_engines if e.identity == eng_identity), - None) + engine = next((e for e in core_engines if e.identity == eng_identity), None) if engine is None: - raise RuntimeError(f"Message from engine with unexpected data " - f"parallel rank: {eng_index}") + raise RuntimeError( + f"Message from engine with unexpected data parallel rank: {eng_index}" + ) msg = msgspec.msgpack.decode(ready_msg_bytes) status, local, headless = msg["status"], msg["local"], msg["headless"] if local != engine.local: - raise RuntimeError(f"{status} message from " - f"{'local' if local else 'remote'} " - f"engine {eng_index}, expected it to be " - f"{'local' if engine.local else 'remote'}") + raise RuntimeError( + f"{status} message from " + f"{'local' if local else 'remote'} " + f"engine {eng_index}, expected it to be " + f"{'local' if engine.local else 'remote'}" + ) # Remote engines must be headless iff we aren't in hybrid dp lb mode. if not local and headless != remote_should_be_headless: if headless: - raise RuntimeError(f"Remote engine {eng_index} must not use " - f"--headless in external or hybrid dp lb " - f"mode") + raise RuntimeError( + f"Remote engine {eng_index} must not use " + f"--headless in external or hybrid dp lb " + f"mode" + ) else: - raise RuntimeError(f"Remote engine {eng_index} must use " - f"--headless unless in external or hybrid " - f"dp lb mode") + raise RuntimeError( + f"Remote engine {eng_index} must use " + f"--headless unless in external or hybrid " + f"dp lb mode" + ) if status == "HELLO" and engine.state == CoreEngineState.NEW: - # Send init message with DP config info. 
init_message = msgspec.msgpack.encode( EngineHandshakeMetadata( addresses=addresses, parallel_config={ - "data_parallel_master_ip": - parallel_config.data_parallel_master_ip, - "data_parallel_master_port": - parallel_config.data_parallel_master_port, - "_data_parallel_master_port_list": - parallel_config._data_parallel_master_port_list, - "data_parallel_size": - parallel_config.data_parallel_size, - })) - handshake_socket.send_multipart((eng_identity, init_message), - copy=False) + "data_parallel_master_ip": parallel_config.data_parallel_master_ip, + "data_parallel_master_port": parallel_config.data_parallel_master_port, + "_data_parallel_master_port_list": parallel_config._data_parallel_master_port_list, + "data_parallel_size": parallel_config.data_parallel_size, + }, + ) + ) + handshake_socket.send_multipart((eng_identity, init_message), copy=False) conn_pending[0 if local else 1] -= 1 start_pending[0 if local else 1] += 1 engine.state = CoreEngineState.CONNECTED @@ -846,15 +895,20 @@ def wait_for_engine_startup( # one of the engine handshakes, and passed to the local # front-end process in the response from the other. if addresses.frontend_stats_publish_address is None: - addresses.frontend_stats_publish_address = msg.get( - "dp_stats_address") + addresses.frontend_stats_publish_address = msg.get("dp_stats_address") start_pending[0 if local else 1] -= 1 engine.state = CoreEngineState.READY else: - raise RuntimeError(f"Unexpected {status} message for " - f"{'local' if local else 'remote'} engine " - f"{eng_index} in {engine.state} state.") + raise RuntimeError( + f"Unexpected {status} message for " + f"{'local' if local else 'remote'} engine " + f"{eng_index} in {engine.state} state." + ) - logger.debug("%s from %s core engine process %s.", status, - "local" if local else "remote", eng_index) + logger.debug( + "%s from %s core engine process %s.", + status, + "local" if local else "remote", + eng_index, + ) diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 625017d52f..064e4b2bbf 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -10,9 +10,9 @@ import torch.distributed as dist from vllm.config import VllmConfig from vllm.executor.executor_base import ExecutorBase from vllm.executor.uniproc_executor import ( # noqa - ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0) -from vllm.executor.uniproc_executor import ( # noqa - UniProcExecutor as UniProcExecutorV0) + ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0, +) +from vllm.executor.uniproc_executor import UniProcExecutor as UniProcExecutorV0 # noqa from vllm.utils import resolve_obj_by_qualname from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec @@ -30,21 +30,24 @@ class Executor(ExecutorBase): def get_class(vllm_config: VllmConfig) -> type["Executor"]: executor_class: type[Executor] parallel_config = vllm_config.parallel_config - distributed_executor_backend = ( - parallel_config.distributed_executor_backend) + distributed_executor_backend = parallel_config.distributed_executor_backend # distributed_executor_backend must be set in VllmConfig.__post_init__ if isinstance(distributed_executor_backend, type): if not issubclass(distributed_executor_backend, ExecutorBase): raise TypeError( "distributed_executor_backend must be a subclass of " - f"ExecutorBase. Got {distributed_executor_backend}.") + f"ExecutorBase. Got {distributed_executor_backend}." 
+ ) executor_class = distributed_executor_backend elif distributed_executor_backend == "ray": from vllm.v1.executor.ray_distributed_executor import ( # noqa - RayDistributedExecutor) + RayDistributedExecutor, + ) + executor_class = RayDistributedExecutor elif distributed_executor_backend == "mp": from vllm.v1.executor.multiproc_executor import MultiprocExecutor + executor_class = MultiprocExecutor elif distributed_executor_backend == "uni": executor_class = UniProcExecutor @@ -53,25 +56,24 @@ class Executor(ExecutorBase): # to support external launcher executor_class = ExecutorWithExternalLauncher elif isinstance(distributed_executor_backend, str): - executor_class = resolve_obj_by_qualname( - distributed_executor_backend) + executor_class = resolve_obj_by_qualname(distributed_executor_backend) if not issubclass(executor_class, ExecutorBase): raise TypeError( "distributed_executor_backend must be a subclass of " - f"ExecutorBase. Got {executor_class}.") + f"ExecutorBase. Got {executor_class}." + ) else: - raise ValueError("Unknown distributed executor backend: " - f"{distributed_executor_backend}") + raise ValueError( + f"Unknown distributed executor backend: {distributed_executor_backend}" + ) return executor_class - def initialize_from_config(self, - kv_cache_configs: list[KVCacheConfig]) -> None: + def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None: """ Initialize the KV caches and begin the model execution loop of the underlying workers. """ - self.collective_rpc("initialize_from_config", - args=(kv_cache_configs, )) + self.collective_rpc("initialize_from_config", args=(kv_cache_configs,)) self.collective_rpc("compile_or_warm_up_model") def register_failure_callback(self, callback: FailureCallback): @@ -87,12 +89,14 @@ class Executor(ExecutorBase): def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: return self.collective_rpc("get_kv_cache_spec") - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None, - non_block: bool = False) -> list[Any]: + def collective_rpc( + self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False, + ) -> list[Any]: raise NotImplementedError def execute_model( @@ -100,9 +104,9 @@ class Executor(ExecutorBase): scheduler_output: SchedulerOutput, non_block: bool = False, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: - output = self.collective_rpc("execute_model", - args=(scheduler_output, ), - non_block=non_block) + output = self.collective_rpc( + "execute_model", args=(scheduler_output,), non_block=non_block + ) return output[0] def execute_dummy_batch(self) -> None: @@ -117,7 +121,7 @@ class Executor(ExecutorBase): return 1 def profile(self, is_start: bool = True): - self.collective_rpc("profile", args=(is_start, )) + self.collective_rpc("profile", args=(is_start,)) class UniProcExecutor(UniProcExecutorV0, Executor): @@ -125,12 +129,12 @@ class UniProcExecutor(UniProcExecutorV0, Executor): class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor): - def determine_available_memory(self) -> list[int]: # in bytes # same as determine_num_available_blocks in v0, # we need to get the min across all ranks. 
memory = super().determine_available_memory() from vllm.distributed.parallel_state import get_world_group + cpu_group = get_world_group().cpu_group memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64) dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index eecdf8def6..062b604269 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -24,30 +24,36 @@ import torch import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed import (destroy_distributed_environment, - destroy_model_parallel) -from vllm.distributed.device_communicators.shm_broadcast import (Handle, - MessageQueue) -from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, - get_pp_group, get_tp_group) +from vllm.distributed import destroy_distributed_environment, destroy_model_parallel +from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue +from vllm.distributed.parallel_state import ( + get_dp_group, + get_ep_group, + get_pp_group, + get_tp_group, +) from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import worker_receiver_cache_from_config -from vllm.utils import (_maybe_force_spawn, decorate_logs, - get_distributed_init_method, get_loopback_ip, - get_mp_context, get_open_port, set_process_title) +from vllm.utils import ( + _maybe_force_spawn, + decorate_logs, + get_distributed_init_method, + get_loopback_ip, + get_mp_context, + get_open_port, + set_process_title, +) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.executor.utils import get_and_update_mm_cache -from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds, - ModelRunnerOutput) +from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput from vllm.v1.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) class MultiprocExecutor(Executor): - supports_pp: bool = True def _init_executor(self) -> None: @@ -65,7 +71,8 @@ class MultiprocExecutor(Executor): assert self.world_size == tensor_parallel_size * pp_parallel_size, ( f"world_size ({self.world_size}) must be equal to the " f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" - f"_parallel_size ({pp_parallel_size}). ") + f"_parallel_size ({pp_parallel_size}). " + ) # Set multiprocessing envs set_multiprocessing_worker_envs() @@ -74,14 +81,15 @@ class MultiprocExecutor(Executor): # Since it only works for single node, we can use the loopback address # get_loopback_ip() for communication. 
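The get_distributed_init_method(get_loopback_ip(), get_open_port()) call just below builds a tcp:// rendezvous URL for torch.distributed on the local machine. A minimal sketch of the same idea using only the standard library (illustrative only; the helper names above are vLLM's, the code below is not):

import socket

def loopback_init_method() -> str:
    # Bind to port 0 so the OS hands back a currently-free port, then release it.
    # Like any "find a free port" helper this is slightly racy: the port could be
    # reused by another process before torch.distributed binds it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        _, port = s.getsockname()
    return f"tcp://127.0.0.1:{port}"

print(loopback_init_method())  # e.g. tcp://127.0.0.1:53421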
distributed_init_method = get_distributed_init_method( - get_loopback_ip(), get_open_port()) + get_loopback_ip(), get_open_port() + ) # Initialize worker and set up message queues for SchedulerOutputs # and ModelRunnerOutputs max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024 - self.rpc_broadcast_mq = MessageQueue(self.world_size, - self.world_size, - max_chunk_bytes=max_chunk_bytes) + self.rpc_broadcast_mq = MessageQueue( + self.world_size, self.world_size, max_chunk_bytes=max_chunk_bytes + ) scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Create workers @@ -99,7 +107,8 @@ class MultiprocExecutor(Executor): distributed_init_method=distributed_init_method, input_shm_handle=scheduler_output_handle, shared_worker_lock=shared_worker_lock, - )) + ) + ) # Workers must be created before wait_for_ready to avoid # deadlock, since worker.init_device() does a device sync. @@ -120,8 +129,7 @@ class MultiprocExecutor(Executor): for uw in unready_workers: if uw.death_writer is not None: uw.death_writer.close() - self._ensure_worker_termination( - [uw.proc for uw in unready_workers]) + self._ensure_worker_termination([uw.proc for uw in unready_workers]) # For pipeline parallel, we use a thread pool for asynchronous # execute_model. @@ -130,7 +138,8 @@ class MultiprocExecutor(Executor): # from the response queue # _async_aggregate_workers_output also assumes a single IO thread self.io_thread_pool = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="mp_exec_io") + max_workers=1, thread_name_prefix="mp_exec_io" + ) self.output_rank = self._get_output_rank() self.has_connector = self.vllm_config.kv_transfer_config is not None @@ -146,23 +155,22 @@ class MultiprocExecutor(Executor): sentinels = [h.proc.sentinel for h in workers] died = multiprocessing.connection.wait(sentinels) _self = self_ref() - if not _self or getattr(_self, 'shutting_down', False): + if not _self or getattr(_self, "shutting_down", False): return _self.is_failed = True - proc_name = next(h.proc.name for h in workers - if h.proc.sentinel == died[0]) + proc_name = next(h.proc.name for h in workers if h.proc.sentinel == died[0]) logger.error( - "Worker proc %s died unexpectedly, " - "shutting down executor.", proc_name) + "Worker proc %s died unexpectedly, shutting down executor.", proc_name + ) _self.shutdown() callback = _self.failure_callback if callback is not None: _self.failure_callback = None callback() - Thread(target=monitor_workers, - daemon=True, - name="MultiprocWorkerMonitor").start() + Thread( + target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor" + ).start() def register_failure_callback(self, callback: FailureCallback): if self.is_failed: @@ -175,47 +183,49 @@ class MultiprocExecutor(Executor): scheduler_output: SchedulerOutput, non_block: bool = False, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: - if not self.has_connector: # get output only from a single worker (output_rank) - (output, ) = self.collective_rpc( + (output,) = self.collective_rpc( "execute_model", - args=(scheduler_output, ), + args=(scheduler_output,), unique_reply_rank=self.output_rank, non_block=non_block, - timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) + timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, + ) return output # get output from all workers outputs = self.collective_rpc( "execute_model", - args=(scheduler_output, ), + args=(scheduler_output,), non_block=non_block, - timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) + timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, + ) # aggregate 
all workers output to a single output if non_block: - return self.kv_output_aggregator.async_aggregate( - outputs, self.output_rank) + return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank) return self.kv_output_aggregator.aggregate(outputs, self.output_rank) def execute_dummy_batch(self) -> None: - self.collective_rpc("execute_dummy_batch", - unique_reply_rank=self.output_rank) + self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank) def take_draft_token_ids(self) -> Optional[DraftTokenIds]: # OPTIMIZATION: Get output only from a single worker (output_rank) - outputs = self.collective_rpc("take_draft_token_ids", - unique_reply_rank=self.output_rank) + outputs = self.collective_rpc( + "take_draft_token_ids", unique_reply_rank=self.output_rank + ) return outputs[0] - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None, - non_block: bool = False, - unique_reply_rank: Optional[int] = None) -> list[Any]: + def collective_rpc( + self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False, + unique_reply_rank: Optional[int] = None, + ) -> list[Any]: if self.is_failed: raise RuntimeError("Executor failed.") @@ -230,42 +240,53 @@ class MultiprocExecutor(Executor): send_method = method else: send_method = cloudpickle.dumps( - method, protocol=pickle.HIGHEST_PROTOCOL) + method, protocol=pickle.HIGHEST_PROTOCOL + ) self.rpc_broadcast_mq.enqueue( - (send_method, args, kwargs, unique_reply_rank)) + (send_method, args, kwargs, unique_reply_rank) + ) - workers = (self.workers[unique_reply_rank], - ) if unique_reply_rank is not None else self.workers + workers = ( + (self.workers[unique_reply_rank],) + if unique_reply_rank is not None + else self.workers + ) responses = [] - def get_response(w: WorkerProcHandle, - dequeue_timeout: Optional[float] = None, - cancel_event: Optional[threading.Event] = None): + def get_response( + w: WorkerProcHandle, + dequeue_timeout: Optional[float] = None, + cancel_event: Optional[threading.Event] = None, + ): status, result = w.worker_response_mq.dequeue( - timeout=dequeue_timeout, cancel=cancel_event) + timeout=dequeue_timeout, cancel=cancel_event + ) if status != WorkerProc.ResponseStatus.SUCCESS: raise RuntimeError( f"Worker failed with error '{result}', please check the" - " stack trace above for the root cause") + " stack trace above for the root cause" + ) return result for w in workers: - dequeue_timeout = None if deadline is None else ( - deadline - time.monotonic()) + dequeue_timeout = ( + None if deadline is None else (deadline - time.monotonic()) + ) if self.io_thread_pool is not None: # We must consume worker_response_mq from a single thread. 
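The submit(...) call just below funnels every dequeue through the one-thread io_thread_pool, so the worker response queue is only ever consumed from a single thread while non_block callers still receive a Future. A minimal sketch of that pattern with a plain standard-library queue (illustrative names, not vLLM's MessageQueue):

import queue
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional, Union

responses: "queue.Queue[str]" = queue.Queue()
io_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="io")  # single consumer thread

def dequeue_one(timeout: Optional[float] = None) -> str:
    return responses.get(timeout=timeout)

def rpc(non_block: bool) -> Union[str, "Future[str]"]:
    fut = io_pool.submit(dequeue_one, 1.0)      # all reads happen on io_pool's one thread
    return fut if non_block else fut.result()   # blocking callers just wait on the Future

responses.put("ok")
print(rpc(non_block=False))  # ok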
result = self.io_thread_pool.submit( # type: ignore - get_response, w, dequeue_timeout, self.shutdown_event) + get_response, w, dequeue_timeout, self.shutdown_event + ) if not non_block: result = result.result() elif not non_block: - result = get_response(w, dequeue_timeout, - self.shutdown_event) + result = get_response(w, dequeue_timeout, self.shutdown_event) else: - raise RuntimeError("non_block can only be used when" - " max_concurrent_batches > 1") + raise RuntimeError( + "non_block can only be used when max_concurrent_batches > 1" + ) responses.append(result) return responses @@ -302,11 +323,11 @@ class MultiprocExecutor(Executor): def shutdown(self): """Properly shut down the executor and its workers""" - if not getattr(self, 'shutting_down', False): + if not getattr(self, "shutting_down", False): self.shutting_down = True # Make sure all the worker processes are terminated first. - if workers := getattr(self, 'workers', None): + if workers := getattr(self, "workers", None): for w in workers: # Close death_writer to signal child processes to exit if w.death_writer is not None: @@ -348,6 +369,7 @@ class MultiprocExecutor(Executor): @dataclass class UnreadyWorkerProcHandle: """WorkerProcess handle before READY.""" + proc: BaseProcess rank: int ready_pipe: Connection @@ -363,8 +385,8 @@ class WorkerProcHandle: @classmethod def from_unready_handle( - cls, unready_handle: UnreadyWorkerProcHandle, - worker_response_mq: MessageQueue) -> "WorkerProcHandle": + cls, unready_handle: UnreadyWorkerProcHandle, worker_response_mq: MessageQueue + ) -> "WorkerProcHandle": return cls( proc=unready_handle.proc, rank=unready_handle.rank, @@ -393,8 +415,7 @@ class WorkerProc: all_kwargs: list[dict] = [ {} for _ in range(vllm_config.parallel_config.world_size) ] - is_driver_worker = ( - rank % vllm_config.parallel_config.tensor_parallel_size == 0) + is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0 all_kwargs[rank] = { "vllm_config": vllm_config, "local_rank": local_rank, @@ -407,7 +428,8 @@ class WorkerProc: # Initialize MessageQueue for receiving SchedulerOutput self.rpc_broadcast_mq = MessageQueue.create_from_handle( - input_shm_handle, self.worker.rank) + input_shm_handle, self.worker.rank + ) # Initializes a message queue for sending the model output self.worker_response_mq = MessageQueue(1, 1) @@ -419,19 +441,22 @@ class WorkerProc: self.async_output_copy_thread = Thread( target=self.async_output_busy_loop, daemon=True, - name="WorkerAsyncOutputCopy") + name="WorkerAsyncOutputCopy", + ) self.async_output_copy_thread.start() # Initialize multimodal receiver cache if needed self.mm_receiver_cache = worker_receiver_cache_from_config( - vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock) + vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock + ) # Initialize device self.worker.init_device() # Set process title and log prefix self.setup_proc_title_and_log_prefix( - enable_ep=vllm_config.parallel_config.enable_expert_parallel) + enable_ep=vllm_config.parallel_config.enable_expert_parallel + ) # Load model self.worker.load_model() @@ -463,10 +488,12 @@ class WorkerProc: "shared_worker_lock": shared_worker_lock, } # Run EngineCore busy loop in background process. 
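The Process(...) creation just below starts each worker as a daemon process and then closes the parent's copy of one pipe end; the end kept by the child later sees EOF when the parent exits or shuts down. A minimal standard-library sketch of that death-pipe pattern (simplified; the real WorkerProc also exchanges a READY message and a message-queue handle):

import multiprocessing as mp

def child_main(death_reader) -> None:
    # Block until the write end disappears, i.e. the parent exited or closed it.
    try:
        death_reader.recv()
    except EOFError:
        print("parent gone, child shutting down")

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    death_reader, death_writer = ctx.Pipe(duplex=False)
    proc = ctx.Process(target=child_main, args=(death_reader,), daemon=True)
    proc.start()
    death_reader.close()   # the parent only needs the write end
    death_writer.close()   # simulate parent shutdown -> the child sees EOF
    proc.join(timeout=5)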
- proc = context.Process(target=WorkerProc.worker_main, - kwargs=process_kwargs, - name=f"VllmWorker-{rank}", - daemon=True) + proc = context.Process( + target=WorkerProc.worker_main, + kwargs=process_kwargs, + name=f"VllmWorker-{rank}", + daemon=True, + ) proc.start() writer.close() @@ -476,16 +503,18 @@ class WorkerProc: @staticmethod def wait_for_ready( - unready_proc_handles: list[UnreadyWorkerProcHandle] + unready_proc_handles: list[UnreadyWorkerProcHandle], ) -> list[WorkerProcHandle]: - - e = Exception("WorkerProc initialization failed due to " - "an exception in a background process. " - "See stack trace for root cause.") + e = Exception( + "WorkerProc initialization failed due to " + "an exception in a background process. " + "See stack trace for root cause." + ) pipes = {handle.ready_pipe: handle for handle in unready_proc_handles} - ready_proc_handles: list[Optional[WorkerProcHandle]] = ( - [None] * len(unready_proc_handles)) + ready_proc_handles: list[Optional[WorkerProcHandle]] = [None] * len( + unready_proc_handles + ) while pipes: ready = multiprocessing.connection.wait(pipes.keys()) for pipe in ready: @@ -499,10 +528,13 @@ class WorkerProc: # Extract the message queue handle. worker_response_mq = MessageQueue.create_from_handle( - response["handle"], 0) + response["handle"], 0 + ) ready_proc_handles[unready_proc_handle.rank] = ( WorkerProcHandle.from_unready_handle( - unready_proc_handle, worker_response_mq)) + unready_proc_handle, worker_response_mq + ) + ) except EOFError: e.__suppress_context__ = True @@ -523,8 +555,8 @@ class WorkerProc: @staticmethod def worker_main(*args, **kwargs): - """ Worker initialization and execution loops. - This runs a background process """ + """Worker initialization and execution loops. + This runs a background process""" # Signal handler used for graceful termination. # SystemExit exception is only raised once to allow this and worker @@ -561,9 +593,9 @@ class WorkerProc: except Exception as e: logger.warning("Death monitoring error: %s", e) - death_monitor = Thread(target=monitor_parent_death, - daemon=True, - name="WorkerDeathMonitor") + death_monitor = Thread( + target=monitor_parent_death, daemon=True, name="WorkerDeathMonitor" + ) death_monitor.start() try: @@ -571,12 +603,12 @@ class WorkerProc: worker = WorkerProc(*args, **kwargs) # Send READY once we know everything is loaded - ready_writer.send({ - "status": - WorkerProc.READY_STR, - "handle": - worker.worker_response_mq.export_handle(), - }) + ready_writer.send( + { + "status": WorkerProc.READY_STR, + "handle": worker.worker_response_mq.export_handle(), + } + ) # Ensure message queues are ready. Will deadlock if re-ordered. 
# Must be kept consistent with the Executor @@ -653,15 +685,18 @@ class WorkerProc: """Main busy loop for Multiprocessing Workers""" while True: method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue( - cancel=cancel, indefinite=True) + cancel=cancel, indefinite=True + ) try: if isinstance(method, str): func = getattr(self.worker, method) elif isinstance(method, bytes): func = partial(cloudpickle.loads(method), self.worker) # retrieve from shm cache if available - if self.mm_receiver_cache is not None \ - and func.__name__ == "execute_model": + if ( + self.mm_receiver_cache is not None + and func.__name__ == "execute_model" + ): get_and_update_mm_cache(self.mm_receiver_cache, args) output = func(*args, **kwargs) except Exception as e: @@ -701,7 +736,7 @@ class WorkerProc: def set_multiprocessing_worker_envs(): - """ Set up environment variables that should be used when there are workers + """Set up environment variables that should be used when there are workers in a multiprocessing environment. This should be called by the parent process before worker processes are created""" @@ -714,13 +749,16 @@ def set_multiprocessing_worker_envs(): # impact on performance. The contention is amplified when running in a # container where CPU limits can cause throttling. default_omp_num_threads = 1 - if "OMP_NUM_THREADS" not in os.environ and ( - current_parallelism := - torch.get_num_threads()) > default_omp_num_threads: + if ( + "OMP_NUM_THREADS" not in os.environ + and (current_parallelism := torch.get_num_threads()) > default_omp_num_threads + ): logger.warning( "Reducing Torch parallelism from %d threads to %d to avoid " "unnecessary CPU contention. Set OMP_NUM_THREADS in the " "external environment to tune this value as needed.", - current_parallelism, default_omp_num_threads) + current_parallelism, + default_omp_num_threads, + ) os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) torch.set_num_threads(default_omp_num_threads) diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index aadb5fd1dd..e2c2bfd45d 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -6,7 +6,8 @@ from typing import Optional, Union from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.executor.ray_distributed_executor import ( # noqa - RayDistributedExecutor as RayDistributedExecutorV0) + RayDistributedExecutor as RayDistributedExecutorV0, +) from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType @@ -18,10 +19,10 @@ logger = init_logger(__name__) class FutureWrapper(Future): """A wrapper around Ray output reference to meet the interface - of .execute_model(): The top level (core busy loop) expects .result() api + of .execute_model(): The top level (core busy loop) expects .result() api to block and return a single output. - - If aggregator is provided, the outputs from all workers are aggregated upon + + If aggregator is provided, the outputs from all workers are aggregated upon the result() call. If not only the first worker's output is returned. 
""" @@ -101,8 +102,11 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): return FutureWrapper(refs, self.kv_output_aggregator) def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest) -> None: + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: self._run_workers("reinitialize_distributed", reconfig_request) - if reconfig_request.new_data_parallel_rank == \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + if ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ): self.shutdown() diff --git a/vllm/v1/executor/utils.py b/vllm/v1/executor/utils.py index 1855bc9963..884068a438 100644 --- a/vllm/v1/executor/utils.py +++ b/vllm/v1/executor/utils.py @@ -20,4 +20,5 @@ def get_and_update_mm_cache( scheduler_output = args[0] for request_data in scheduler_output.scheduled_new_reqs: request_data.mm_features = receiver_cache.get_and_update_features( - request_data.mm_features) + request_data.mm_features + ) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 054ab591b8..9c28eb92c1 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -50,7 +50,8 @@ class KVCacheSpec: Merge a list of KVCacheSpec objects into a single KVCacheSpec object. """ assert all(spec == specs[0] for spec in specs[1:]), ( - "All layers in the same KV cache group must be the same.") + "All layers in the same KV cache group must be the same." + ) return copy.deepcopy(specs[0]) @@ -62,8 +63,13 @@ class AttentionSpec(KVCacheSpec): @property def page_size_bytes(self) -> int: - return 2 * self.block_size * self.num_kv_heads * self.head_size \ - * get_dtype_size(self.dtype) + return ( + 2 + * self.block_size + * self.num_kv_heads + * self.head_size + * get_dtype_size(self.dtype) + ) @dataclass(frozen=True) @@ -82,8 +88,7 @@ class FullAttentionSpec(AttentionSpec): def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len - dcp_world_size = \ - vllm_config.parallel_config.decode_context_parallel_size + dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size # Note(hc): each dcp rank only need save # (max_model_len//dcp_world_size) tokens locally. if dcp_world_size > 1: @@ -99,24 +104,30 @@ class FullAttentionSpec(AttentionSpec): else: raise ValueError( "All attention layers in the same KV cache group must have the " - "same window size.") + "same window size." + ) @classmethod def merge(cls, specs: list[Self]) -> Self: """ - Merge a list of FullAttentionSpec objects into a single + Merge a list of FullAttentionSpec objects into a single FullAttentionSpec object. """ assert all(isinstance(spec, FullAttentionSpec) for spec in specs), ( - "All attention layers in the same KV cache group must be " - "FullAttentionSpec.") + "All attention layers in the same KV cache group must be FullAttentionSpec." 
+ ) - sliding_window = set(spec.sliding_window for spec in specs - if spec.sliding_window is not None) - attention_chunk_size = set(spec.attention_chunk_size for spec in specs - if spec.attention_chunk_size is not None) + sliding_window = set( + spec.sliding_window for spec in specs if spec.sliding_window is not None + ) + attention_chunk_size = set( + spec.attention_chunk_size + for spec in specs + if spec.attention_chunk_size is not None + ) assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), ( - "MLAAttentionSpec should be merged in MLAAttentionSpec.merge") + "MLAAttentionSpec should be merged in MLAAttentionSpec.merge" + ) merged_spec = cls( block_size=specs[0].block_size, num_kv_heads=specs[0].num_kv_heads, @@ -129,12 +140,14 @@ class FullAttentionSpec(AttentionSpec): for f in fields(AttentionSpec): assert getattr(spec, f.name) == getattr(merged_spec, f.name), ( "All attention layers in the same KV cache group must have " - "the same attention spec.") - assert ( - (merged_spec.sliding_window is not None) + - (merged_spec.attention_chunk_size is not None) <= 1 - ), ("Model with both sliding window layers and chunked local attention " - "layers is not supported.") + "the same attention spec." + ) + assert (merged_spec.sliding_window is not None) + ( + merged_spec.attention_chunk_size is not None + ) <= 1, ( + "Model with both sliding window layers and chunked local attention " + "layers is not supported." + ) return merged_spec @@ -149,18 +162,23 @@ class MLAAttentionSpec(FullAttentionSpec): # See `vllm/v1/attention/backends/mla/flashmla_sparse.py` # for details. return self.block_size * 656 - return self.block_size * self.num_kv_heads * self.head_size \ - * get_dtype_size(self.dtype) + return ( + self.block_size + * self.num_kv_heads + * self.head_size + * get_dtype_size(self.dtype) + ) @classmethod def merge(cls, specs: list[Self]) -> Self: assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), ( - "All attention layers in the same KV cache group must be " - "MLAAttentionSpec.") + "All attention layers in the same KV cache group must be MLAAttentionSpec." + ) cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs) assert len(cache_dtype_str_set) == 1, ( "All attention layers in the same KV cache group must use the same " - "quantization method.") + "quantization method." + ) return cls( block_size=specs[0].block_size, num_kv_heads=specs[0].num_kv_heads, @@ -176,15 +194,15 @@ class ChunkedLocalAttentionSpec(AttentionSpec): def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len - max_num_batched_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens # During chunked prefill, we allocate KV cache for at most # `self.attention_chunk_size` computed tokens plus the newly scheduled # tokens. And we won't allocate KV cache for more than `max_model_len` # tokens. 
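As a small worked example of the bound computed just below (made-up sizes): with attention_chunk_size=8192, max_num_batched_tokens=2048, max_model_len=4096, block_size=16, 8 KV heads, head_size=128 and an fp16 cache, at most max_model_len tokens ever need KV cache for such a layer:

def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division, mirroring the cdiv helper used here

attention_chunk_size = 8192
max_num_batched_tokens = 2048
max_model_len = 4096
block_size = 16
page_size_bytes = 2 * block_size * 8 * 128 * 2  # K+V * block_size * kv_heads * head_size * 2 bytes

num_tokens = min(attention_chunk_size + max_num_batched_tokens, max_model_len)  # 4096
print(cdiv(num_tokens, block_size) * page_size_bytes)  # 256 pages * 64 KiB = 16 MiB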
- num_tokens = min(self.attention_chunk_size + max_num_batched_tokens, - max_model_len) + num_tokens = min( + self.attention_chunk_size + max_num_batched_tokens, max_model_len + ) return cdiv(num_tokens, self.block_size) * self.page_size_bytes @@ -194,18 +212,19 @@ class SlidingWindowSpec(AttentionSpec): sliding_window: int def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - assert vllm_config.parallel_config.decode_context_parallel_size == 1, \ + assert vllm_config.parallel_config.decode_context_parallel_size == 1, ( "DCP not support sliding window." + ) max_model_len = vllm_config.model_config.max_model_len - max_num_batched_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens # During chunked prefill, we allocate KV cache for the last # `self.sliding_window-1` computed tokens plus the newly scheduled # tokens. And we won't allocate KV cache for more than `max_model_len` # tokens. - num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens, - max_model_len) + num_tokens = min( + self.sliding_window - 1 + max_num_batched_tokens, max_model_len + ) # +1 here because the sliding window may not start from the beginning # of the block. For example, if the block size is 4 and num_token @@ -226,7 +245,8 @@ class MambaSpec(KVCacheSpec): def page_size_bytes(self) -> int: page_size = sum( prod(shape) * get_dtype_size(dtype) - for (shape, dtype) in zip(self.shapes, self.dtypes)) + for (shape, dtype) in zip(self.shapes, self.dtypes) + ) if self.page_size_padded is not None: assert self.page_size_padded >= page_size return self.page_size_padded @@ -239,7 +259,6 @@ class MambaSpec(KVCacheSpec): @dataclass(frozen=True) class EncoderOnlyAttentionSpec(AttentionSpec): - def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: # Encoder-only layers do not need KV cache return 0 @@ -254,8 +273,7 @@ class CrossAttentionSpec(AttentionSpec): def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: # For cross-attention, we need to cache encoder states # Get encoder length (e.g., 1500 for Whisper). - max_encoder_len = vllm_config.scheduler_config.\ - max_num_encoder_input_tokens + max_encoder_len = vllm_config.scheduler_config.max_num_encoder_input_tokens return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes @@ -267,18 +285,18 @@ class UniformTypeKVCacheSpecs(KVCacheSpec): sliding window attentions with different window sizes are not the same type and should not be merged into one UniformTypeKVCacheSpecs. 
""" + kv_cache_specs: dict[str, KVCacheSpec] @property def page_size_bytes(self) -> int: - return sum(spec.page_size_bytes - for spec in self.kv_cache_specs.values()) + return sum(spec.page_size_bytes for spec in self.kv_cache_specs.values()) def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_num_pages = max( - cdiv(spec.max_memory_usage_bytes(vllm_config), - spec.page_size_bytes) - for spec in self.kv_cache_specs.values()) + cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes) + for spec in self.kv_cache_specs.values() + ) return max_num_pages * self.page_size_bytes @classmethod @@ -293,35 +311,38 @@ class UniformTypeKVCacheSpecs(KVCacheSpec): one_spec = next(iter(kv_cache_specs.values())) if isinstance(one_spec, FullAttentionSpec): return all( - isinstance(spec, FullAttentionSpec) - for spec in kv_cache_specs.values()) + isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values() + ) elif isinstance(one_spec, CrossAttentionSpec): return all( - isinstance(spec, CrossAttentionSpec) - for spec in kv_cache_specs.values()) + isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values() + ) elif isinstance(one_spec, SlidingWindowSpec): return all( isinstance(spec, SlidingWindowSpec) and spec.sliding_window == one_spec.sliding_window - for spec in kv_cache_specs.values()) + for spec in kv_cache_specs.values() + ) elif isinstance(one_spec, ChunkedLocalAttentionSpec): return all( isinstance(spec, ChunkedLocalAttentionSpec) and spec.attention_chunk_size == one_spec.attention_chunk_size - for spec in kv_cache_specs.values()) + for spec in kv_cache_specs.values() + ) elif isinstance(one_spec, MambaSpec): return all( - isinstance(spec, MambaSpec) and spec.num_speculative_blocks == - one_spec.num_speculative_blocks - for spec in kv_cache_specs.values()) + isinstance(spec, MambaSpec) + and spec.num_speculative_blocks == one_spec.num_speculative_blocks + for spec in kv_cache_specs.values() + ) else: # NOTE(Chen): Please add new branches for new KV cache spec types. raise NotImplementedError( - f"Unsupported KV cache spec type: {type(one_spec)}") + f"Unsupported KV cache spec type: {type(one_spec)}" + ) @classmethod - def from_specs(cls, kv_cache_specs: dict[str, - KVCacheSpec]) -> Optional[Self]: + def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Optional[Self]: """ Return a SameTypeKVCacheSpecs object if all layers have the same type of KV cache spec. Return None if not. @@ -338,6 +359,7 @@ class KVCacheTensor: """ A class for specifying how the workers should initialize the KV cache. """ + size: int # size of the KV cache tensor in bytes shared_by: list[str] # layer names that share the same KV cache tensor @@ -348,6 +370,7 @@ class KVCacheGroupSpec: Represents a group of model layers that share the same KV cache block table. These layers are regarded as one layer in the KV cache manager. """ + # The names of model layers in this group layer_names: list[str] # The KV cache spec of this manager layer @@ -359,6 +382,7 @@ class KVCacheConfig: """ The KV cache configuration of a model. 
""" + """The number of KV cache blocks""" num_blocks: int """How should model runner initialize the KV cache tensors for each layer""" diff --git a/vllm/v1/kv_offload/abstract.py b/vllm/v1/kv_offload/abstract.py index 9f9c044ea1..ce2d0dffc0 100644 --- a/vllm/v1/kv_offload/abstract.py +++ b/vllm/v1/kv_offload/abstract.py @@ -68,7 +68,6 @@ class OffloadingEvent: class OffloadingManager(ABC): - @abstractmethod def lookup(self, block_hashes: Iterable[BlockHash]) -> int: """ @@ -122,8 +121,8 @@ class OffloadingManager(ABC): @abstractmethod def prepare_store( - self, - block_hashes: Iterable[BlockHash]) -> Optional[PrepareStoreOutput]: + self, block_hashes: Iterable[BlockHash] + ) -> Optional[PrepareStoreOutput]: """ Prepare the given blocks to be offloaded. The given blocks will be protected from eviction until @@ -140,9 +139,7 @@ class OffloadingManager(ABC): """ pass - def complete_store(self, - block_hashes: Iterable[BlockHash], - success: bool = True): + def complete_store(self, block_hashes: Iterable[BlockHash], success: bool = True): """ Marks blocks which were previously prepared to be stored, as stored. Following this call, the blocks become loadable. diff --git a/vllm/v1/kv_offload/backend.py b/vllm/v1/kv_offload/backend.py index 87a7420011..538f7bf058 100644 --- a/vllm/v1/kv_offload/backend.py +++ b/vllm/v1/kv_offload/backend.py @@ -18,6 +18,7 @@ class BlockStatus(ctypes.Structure): load_store_spec - backend-specific information on how to actually read/write the block. """ + _fields_ = [("ref_cnt", ctypes.c_int32)] def __init__(self): @@ -51,8 +52,7 @@ class Backend(ABC): pass @abstractmethod - def allocate_blocks(self, - block_hashes: list[BlockHash]) -> list[BlockStatus]: + def allocate_blocks(self, block_hashes: list[BlockHash]) -> list[BlockStatus]: """ Allocate space for writing blocks. This method assumes there is enough space for allocation. @@ -80,8 +80,9 @@ class Backend(ABC): """ pass - def get_load_store_spec(self, block_hashes: Iterable[BlockHash], - blocks: Iterable[BlockStatus]) -> LoadStoreSpec: + def get_load_store_spec( + self, block_hashes: Iterable[BlockHash], blocks: Iterable[BlockStatus] + ) -> LoadStoreSpec: """ Get backend-specific information on how to read/write blocks. 
diff --git a/vllm/v1/kv_offload/backends/cpu.py b/vllm/v1/kv_offload/backends/cpu.py index eb1123d1d8..736cf37853 100644 --- a/vllm/v1/kv_offload/backends/cpu.py +++ b/vllm/v1/kv_offload/backends/cpu.py @@ -10,8 +10,7 @@ from vllm.v1.kv_offload.mediums import CPULoadStoreSpec class CPUBlockStatus(BlockStatus): - _fields_ = BlockStatus._fields_ + [("block_id", ctypes.c_int64) - ] # type: ignore + _fields_ = BlockStatus._fields_ + [("block_id", ctypes.c_int64)] # type: ignore def __init__(self, block_id: int): super().__init__() @@ -19,23 +18,24 @@ class CPUBlockStatus(BlockStatus): class CPUBackend(Backend): - def __init__(self, block_size: int, num_blocks: int): - super().__init__(block_size=block_size, - medium=CPULoadStoreSpec.medium()) + super().__init__(block_size=block_size, medium=CPULoadStoreSpec.medium()) self.num_blocks: int = num_blocks self.num_allocated_blocks: int = 0 self.allocated_blocks_free_list: list[int] = [] def get_num_free_blocks(self): - return (len(self.allocated_blocks_free_list) + self.num_blocks - - self.num_allocated_blocks) + return ( + len(self.allocated_blocks_free_list) + + self.num_blocks + - self.num_allocated_blocks + ) - def allocate_blocks(self, - block_hashes: list[BlockHash]) -> list[BlockStatus]: - num_fresh_blocks = min(len(block_hashes), - self.num_blocks - self.num_allocated_blocks) + def allocate_blocks(self, block_hashes: list[BlockHash]) -> list[BlockStatus]: + num_fresh_blocks = min( + len(block_hashes), self.num_blocks - self.num_allocated_blocks + ) num_reused_blocks = len(block_hashes) - num_fresh_blocks assert len(self.allocated_blocks_free_list) >= num_reused_blocks @@ -56,6 +56,7 @@ class CPUBackend(Backend): assert isinstance(block, CPUBlockStatus) self.allocated_blocks_free_list.append(block.block_id) - def get_load_store_spec(self, block_hashes: Iterable[BlockHash], - blocks: Iterable[BlockStatus]) -> LoadStoreSpec: + def get_load_store_spec( + self, block_hashes: Iterable[BlockHash], blocks: Iterable[BlockStatus] + ) -> LoadStoreSpec: return CPULoadStoreSpec([block.block_id for block in blocks]) diff --git a/vllm/v1/kv_offload/cpu.py b/vllm/v1/kv_offload/cpu.py index b85d375fe6..0c1cf64a23 100644 --- a/vllm/v1/kv_offload/cpu.py +++ b/vllm/v1/kv_offload/cpu.py @@ -18,14 +18,14 @@ from vllm.v1.kv_offload.worker.worker import OffloadingHandler class CPUOffloadingSpec(OffloadingSpec): - def __init__(self, vllm_config: VllmConfig): super().__init__(vllm_config) num_cpu_blocks = self.extra_config.get("num_cpu_blocks") if not num_cpu_blocks: - raise Exception("num_cpu_blocks must be specified " - "in kv_connector_extra_config") + raise Exception( + "num_cpu_blocks must be specified in kv_connector_extra_config" + ) self.num_cpu_blocks: int = num_cpu_blocks # scheduler-side @@ -37,27 +37,30 @@ class CPUOffloadingSpec(OffloadingSpec): def get_manager(self) -> OffloadingManager: if not self._manager: kv_events_config = self.vllm_config.kv_events_config - enable_events = (kv_events_config is not None - and kv_events_config.enable_kv_cache_events) - self._manager = LRUOffloadingManager(CPUBackend( - block_size=self.offloaded_block_size, - num_blocks=self.num_cpu_blocks), - enable_events=enable_events) + enable_events = ( + kv_events_config is not None and kv_events_config.enable_kv_cache_events + ) + self._manager = LRUOffloadingManager( + CPUBackend( + block_size=self.offloaded_block_size, num_blocks=self.num_cpu_blocks + ), + enable_events=enable_events, + ) return self._manager def get_handlers( self, kv_caches: dict[str, torch.Tensor] - ) -> 
Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], - OffloadingHandler]]: + ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: if not self._handler: if not current_platform.is_cuda(): - raise Exception("CPU Offloading is currently only supported" - " on CUDA GPUs") + raise Exception( + "CPU Offloading is currently only supported on CUDA GPUs" + ) layer_names = list(kv_caches.keys()) - layers = get_layers_from_vllm_config(self.vllm_config, - AttentionLayerBase, - layer_names) + layers = get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase, layer_names + ) attn_backends = { layer_name: layers[layer_name].get_attn_backend() for layer_name in layer_names @@ -68,7 +71,8 @@ class CPUOffloadingSpec(OffloadingSpec): gpu_block_size=self.gpu_block_size, cpu_block_size=self.offloaded_block_size, num_cpu_blocks=self.num_cpu_blocks, - gpu_caches=kv_caches) + gpu_caches=kv_caches, + ) assert self._handler is not None yield GPULoadStoreSpec, CPULoadStoreSpec, self._handler diff --git a/vllm/v1/kv_offload/factory.py b/vllm/v1/kv_offload/factory.py index f9bef6cea9..e0a53460e8 100644 --- a/vllm/v1/kv_offload/factory.py +++ b/vllm/v1/kv_offload/factory.py @@ -16,8 +16,7 @@ class OffloadingSpecFactory: _registry: dict[str, Callable[[], type[OffloadingSpec]]] = {} @classmethod - def register_spec(cls, name: str, module_path: str, - class_name: str) -> None: + def register_spec(cls, name: str, module_path: str, class_name: str) -> None: """Register a spec with a lazy-loading module and class name.""" if name in cls._registry: raise ValueError(f"Connector '{name}' is already registered.") @@ -51,6 +50,6 @@ class OffloadingSpecFactory: # Register various specs here. -OffloadingSpecFactory.register_spec("CPUOffloadingSpec", - "vllm.v1.kv_offload.cpu", - "CPUOffloadingSpec") +OffloadingSpecFactory.register_spec( + "CPUOffloadingSpec", "vllm.v1.kv_offload.cpu", "CPUOffloadingSpec" +) diff --git a/vllm/v1/kv_offload/lru_manager.py b/vllm/v1/kv_offload/lru_manager.py index 18d3b1d637..36f5eb4a0a 100644 --- a/vllm/v1/kv_offload/lru_manager.py +++ b/vllm/v1/kv_offload/lru_manager.py @@ -5,8 +5,12 @@ from collections.abc import Iterable from typing import Optional from vllm.v1.core.kv_cache_utils import BlockHash -from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent, - OffloadingManager, PrepareStoreOutput) +from vllm.v1.kv_offload.abstract import ( + LoadStoreSpec, + OffloadingEvent, + OffloadingManager, + PrepareStoreOutput, +) from vllm.v1.kv_offload.backend import Backend, BlockStatus @@ -19,8 +23,7 @@ class LRUOffloadingManager(OffloadingManager): self.backend: Backend = backend # block_hash -> BlockStatus self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict() - self.events: Optional[list[OffloadingEvent]] = \ - [] if enable_events else None + self.events: Optional[list[OffloadingEvent]] = [] if enable_events else None def lookup(self, block_hashes: Iterable[BlockHash]) -> int: hit_count = 0 @@ -53,16 +56,16 @@ class LRUOffloadingManager(OffloadingManager): block.ref_cnt -= 1 def prepare_store( - self, - block_hashes: Iterable[BlockHash]) -> Optional[PrepareStoreOutput]: + self, block_hashes: Iterable[BlockHash] + ) -> Optional[PrepareStoreOutput]: # filter out blocks that are already stored block_hashes_to_store = [ - block_hash for block_hash in block_hashes - if block_hash not in self.blocks + block_hash for block_hash in block_hashes if block_hash not in self.blocks ] - num_blocks_to_evict = (len(block_hashes_to_store) - - 
self.backend.get_num_free_blocks()) + num_blocks_to_evict = ( + len(block_hashes_to_store) - self.backend.get_num_free_blocks() + ) # build list of blocks to evict to_evict = [] @@ -83,10 +86,13 @@ class LRUOffloadingManager(OffloadingManager): if to_evict and self.events is not None: self.events.append( - OffloadingEvent(block_hashes=to_evict, - block_size=self.backend.block_size, - medium=self.backend.medium, - removed=True)) + OffloadingEvent( + block_hashes=to_evict, + block_size=self.backend.block_size, + medium=self.backend.medium, + removed=True, + ) + ) blocks = self.backend.allocate_blocks(block_hashes_to_store) assert len(blocks) == len(block_hashes_to_store) @@ -95,16 +101,15 @@ class LRUOffloadingManager(OffloadingManager): self.blocks[block_hash] = block # build store specs for allocated blocks - store_spec = self.backend.get_load_store_spec(block_hashes_to_store, - blocks) + store_spec = self.backend.get_load_store_spec(block_hashes_to_store, blocks) - return PrepareStoreOutput(block_hashes_to_store=block_hashes_to_store, - store_spec=store_spec, - block_hashes_evicted=to_evict) + return PrepareStoreOutput( + block_hashes_to_store=block_hashes_to_store, + store_spec=store_spec, + block_hashes_evicted=to_evict, + ) - def complete_store(self, - block_hashes: Iterable[BlockHash], - success: bool = True): + def complete_store(self, block_hashes: Iterable[BlockHash], success: bool = True): stored_block_hashes: list[BlockHash] = [] if success: for block_hash in block_hashes: @@ -121,10 +126,13 @@ class LRUOffloadingManager(OffloadingManager): if stored_block_hashes and self.events is not None: self.events.append( - OffloadingEvent(block_hashes=stored_block_hashes, - block_size=self.backend.block_size, - medium=self.backend.medium, - removed=False)) + OffloadingEvent( + block_hashes=stored_block_hashes, + block_size=self.backend.block_size, + medium=self.backend.medium, + removed=False, + ) + ) def take_events(self) -> Iterable[OffloadingEvent]: if self.events is not None: diff --git a/vllm/v1/kv_offload/spec.py b/vllm/v1/kv_offload/spec.py index ed23d5e519..a3c539a47d 100644 --- a/vllm/v1/kv_offload/spec.py +++ b/vllm/v1/kv_offload/spec.py @@ -22,7 +22,8 @@ class OffloadingSpec(ABC): def __init__(self, vllm_config: "VllmConfig"): logger.warning( "Initializing OffloadingSpec. This API is experimental and " - "subject to change in the future as we iterate the design.") + "subject to change in the future as we iterate the design." + ) self.vllm_config = vllm_config kv_transfer_config = vllm_config.kv_transfer_config @@ -31,7 +32,8 @@ class OffloadingSpec(ABC): self.gpu_block_size = vllm_config.cache_config.block_size self.offloaded_block_size = int( - self.extra_config.get("block_size", self.gpu_block_size)) + self.extra_config.get("block_size", self.gpu_block_size) + ) assert self.offloaded_block_size % self.gpu_block_size == 0 @@ -47,8 +49,7 @@ class OffloadingSpec(ABC): @abstractmethod def get_handlers( self, kv_caches: dict[str, torch.Tensor] - ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], - OffloadingHandler]]: + ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: """ Get offloading handlers along with their respective src and dst types. 
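The CpuGpuOffloadingHandler in the next file bridges the two block sizes checked above: one offloaded block spans block_size_factor GPU-sized blocks, so each block id fans out into that many consecutive sub-block ids (the numpy-based expand_block_ids there implements a mapping along these lines, writing into a preallocated array). A rough pure-Python illustration:

def to_sub_blocks(block_ids: list[int], block_size_factor: int) -> list[int]:
    # Block id b covers sub-block ids b * factor .. (b + 1) * factor - 1.
    out: list[int] = []
    for block_id in block_ids:
        out.extend(range(block_id * block_size_factor, (block_id + 1) * block_size_factor))
    return out

print(to_sub_blocks([3, 7], block_size_factor=4))  # [12, 13, 14, 15, 28, 29, 30, 31]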
diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py index 556c29247e..eb7117a400 100644 --- a/vllm/v1/kv_offload/worker/cpu_gpu.py +++ b/vllm/v1/kv_offload/worker/cpu_gpu.py @@ -9,16 +9,21 @@ from vllm.attention import AttentionBackend from vllm.logger import init_logger from vllm.utils import is_pin_memory_available from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec -from vllm.v1.kv_offload.worker.worker import (OffloadingHandler, - TransferResult, TransferSpec) +from vllm.v1.kv_offload.worker.worker import ( + OffloadingHandler, + TransferResult, + TransferSpec, +) logger = init_logger(__name__) -def expand_block_ids(block_ids: np.ndarray, - block_size_factor: int, - output: np.ndarray, - skip_count: int = 0): +def expand_block_ids( + block_ids: np.ndarray, + block_size_factor: int, + output: np.ndarray, + skip_count: int = 0, +): """ Convert a list of block IDs to a list of matching block ids, assuming each block is composed of actual block_size_factor blocks. @@ -47,10 +52,14 @@ def expand_block_ids(block_ids: np.ndarray, class CpuGpuOffloadingHandler(OffloadingHandler): - - def __init__(self, gpu_block_size: int, cpu_block_size: int, - num_cpu_blocks: int, gpu_caches: dict[str, torch.Tensor], - attn_backends: dict[str, type[AttentionBackend]]): + def __init__( + self, + gpu_block_size: int, + cpu_block_size: int, + num_cpu_blocks: int, + gpu_caches: dict[str, torch.Tensor], + attn_backends: dict[str, type[AttentionBackend]], + ): assert cpu_block_size % gpu_block_size == 0 self.block_size_factor = cpu_block_size // gpu_block_size @@ -75,7 +84,8 @@ class CpuGpuOffloadingHandler(OffloadingHandler): gpu_shape = gpu_tensor.shape test_shape = attn_backends[layer_name].get_kv_cache_shape( - num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256) + num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256 + ) if test_shape[0] == 1234: # shape is (num_blocks, ...) 
num_blocks_idx = 0 @@ -94,10 +104,13 @@ class CpuGpuOffloadingHandler(OffloadingHandler): logger.debug("Allocating CPU tensor of shape %r", cpu_shape) self.cpu_tensors.append( - torch.zeros(cpu_shape, - dtype=gpu_tensor.dtype, - device="cpu", - pin_memory=pin_memory)) + torch.zeros( + cpu_shape, + dtype=gpu_tensor.dtype, + device="cpu", + pin_memory=pin_memory, + ) + ) def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: src_spec, dst_spec = spec @@ -122,35 +135,36 @@ class CpuGpuOffloadingHandler(OffloadingHandler): assert src_blocks.ndim == 1 assert dst_blocks.ndim == 1 - dst_sub_blocks_to_skip = (-src_blocks.size % dst_block_size_factor) + dst_sub_blocks_to_skip = -src_blocks.size % dst_block_size_factor src_sub_block_count = src_blocks.size * src_block_size_factor assert ( - src_sub_block_count == dst_blocks.size * dst_block_size_factor - - dst_sub_blocks_to_skip) + src_sub_block_count + == dst_blocks.size * dst_block_size_factor - dst_sub_blocks_to_skip + ) src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64) expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0]) - expand_block_ids(dst_blocks, - dst_block_size_factor, - src_to_dst[:, 1], - skip_count=dst_sub_blocks_to_skip) + expand_block_ids( + dst_blocks, + dst_block_size_factor, + src_to_dst[:, 1], + skip_count=dst_sub_blocks_to_skip, + ) src_to_dst_tensor = torch.from_numpy(src_to_dst) - event = self.events_pool.pop() if self.events_pool \ - else torch.cuda.Event() + event = self.events_pool.pop() if self.events_pool else torch.cuda.Event() with torch.cuda.stream(stream): for src_tensor, dst_tensor, kv_dim in zip( - src_tensors, dst_tensors, self.kv_dim_before_num_blocks): + src_tensors, dst_tensors, self.kv_dim_before_num_blocks + ): if kv_dim: src_key_cache = src_tensor[0] dst_key_cache = dst_tensor[0] - ops.swap_blocks(src_key_cache, dst_key_cache, - src_to_dst_tensor) + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor) src_value_cache = src_tensor[1] dst_value_cache = dst_tensor[1] - ops.swap_blocks(src_value_cache, dst_value_cache, - src_to_dst_tensor) + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor) else: ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor) event.record(stream) diff --git a/vllm/v1/kv_offload/worker/worker.py b/vllm/v1/kv_offload/worker/worker.py index b7a52a088f..58ba082497 100644 --- a/vllm/v1/kv_offload/worker/worker.py +++ b/vllm/v1/kv_offload/worker/worker.py @@ -74,12 +74,14 @@ class OffloadingWorker: def __init__(self): self.handlers: set[OffloadingHandler] = set() - self.transfer_type_to_handler: dict[TransferType, - OffloadingHandler] = {} + self.transfer_type_to_handler: dict[TransferType, OffloadingHandler] = {} - def register_handler(self, src_cls: type[LoadStoreSpec], - dst_cls: type[LoadStoreSpec], - handler: OffloadingHandler) -> None: + def register_handler( + self, + src_cls: type[LoadStoreSpec], + dst_cls: type[LoadStoreSpec], + handler: OffloadingHandler, + ) -> None: """ Registers a new handler. 
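register_handler above maps each supported source/destination combination to the handler responsible for it, and transfer submissions are then dispatched through that mapping. A minimal sketch of such a dispatch table, keyed here directly by the spec classes (the real TransferType key and the class names below are illustrative, not vLLM's):

class GpuSpec: ...
class CpuSpec: ...

class PrintHandler:
    def transfer_async(self, job_id: int, spec) -> bool:
        print(f"job {job_id}: {spec}")
        return True

handlers: dict[tuple[type, type], PrintHandler] = {
    (GpuSpec, CpuSpec): PrintHandler(),   # offload direction only, in this toy table
}

def submit_transfer(job_id: int, src, dst) -> bool:
    handler = handlers[(type(src), type(dst))]
    return handler.transfer_async(job_id, (src, dst))

print(submit_transfer(1, GpuSpec(), CpuSpec()))  # prints the job, then True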
@@ -113,19 +115,19 @@ class OffloadingWorker: try: success = handler.transfer_async(job_id, spec) except Exception as e: - logger.warning("Exception in %r transfer %d: %r", - transfer_type, - job_id, - e, - exc_info=True) + logger.warning( + "Exception in %r transfer %d: %r", + transfer_type, + job_id, + e, + exc_info=True, + ) return False if not success: - logger.warning("Failed to submit %r transfer %d", transfer_type, - job_id) + logger.warning("Failed to submit %r transfer %d", transfer_type, job_id) else: - logger.debug("Submitted %r transfer %d: %r", transfer_type, job_id, - spec) + logger.debug("Submitted %r transfer %d: %r", transfer_type, job_id, spec) return success diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index ef95f03e88..541af7af17 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -9,8 +9,7 @@ from typing import Callable, Optional, Union import prometheus_client from vllm.config import SupportsMetricsInfo, VllmConfig -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorLogging) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason @@ -32,26 +31,24 @@ class StatLoggerBase(ABC): """ @abstractmethod - def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): - ... + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): ... @abstractmethod - def record(self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], - engine_idx: int = 0): - ... + def record( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_idx: int = 0, + ): ... @abstractmethod - def log_engine_initialized(self): - ... + def log_engine_initialized(self): ... 
def log(self): # noqa pass class LoggingStatLogger(StatLoggerBase): - def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.engine_index = engine_index self.vllm_config = vllm_config @@ -85,21 +82,21 @@ class LoggingStatLogger(StatLoggerBase): return 0.0 return float(tracked_stats / delta_time) - def record(self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], - engine_idx: int = 0): + def record( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_idx: int = 0, + ): """Log Stats to standard output.""" if iteration_stats: self._track_iteration_stats(iteration_stats) if scheduler_stats is not None: - self.prefix_caching_metrics.observe( - scheduler_stats.prefix_cache_stats) + self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_logging.observe( - scheduler_stats.spec_decoding_stats) + self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats) if kv_connector_stats := scheduler_stats.kv_connector_stats: self.kv_connector_logging.observe(kv_connector_stats) self.last_scheduler_stats = scheduler_stats @@ -107,8 +104,7 @@ class LoggingStatLogger(StatLoggerBase): def log(self): now = time.monotonic() prompt_throughput = self._get_throughput(self.num_prompt_tokens, now) - generation_throughput = self._get_throughput( - self.num_generation_tokens, now) + generation_throughput = self._get_throughput(self.num_generation_tokens, now) self._reset(now) @@ -116,8 +112,13 @@ class LoggingStatLogger(StatLoggerBase): log_fn = logger.info if not any( - (prompt_throughput, generation_throughput, - self.last_prompt_throughput, self.last_generation_throughput)): + ( + prompt_throughput, + generation_throughput, + self.last_prompt_throughput, + self.last_generation_throughput, + ) + ): # Avoid log noise on an idle production system log_fn = logger.debug self.last_generation_throughput = generation_throughput @@ -146,8 +147,10 @@ class LoggingStatLogger(StatLoggerBase): if self.vllm_config.cache_config.num_gpu_blocks: logger.info( "Engine %03d: vllm cache_config_info with initialization " - "after num_gpu_blocks is: %d", self.engine_index, - self.vllm_config.cache_config.num_gpu_blocks) + "after num_gpu_blocks is: %d", + self.engine_index, + self.vllm_config.cache_config.num_gpu_blocks, + ) class PrometheusStatLogger(StatLoggerBase): @@ -156,9 +159,9 @@ class PrometheusStatLogger(StatLoggerBase): _histogram_cls = prometheus_client.Histogram _spec_decoding_cls = SpecDecodingProm - def __init__(self, - vllm_config: VllmConfig, - engine_indexes: Optional[list[int]] = None): + def __init__( + self, vllm_config: VllmConfig, engine_indexes: Optional[list[int]] = None + ): if engine_indexes is None: engine_indexes = [0] self.engine_indexes = engine_indexes @@ -167,21 +170,19 @@ class PrometheusStatLogger(StatLoggerBase): self.vllm_config = vllm_config # Use this flag to hide metrics that were deprecated in # a previous release and which will be removed future - self.show_hidden_metrics = \ - vllm_config.observability_config.show_hidden_metrics + self.show_hidden_metrics = vllm_config.observability_config.show_hidden_metrics labelnames = ["model_name", "engine"] model_name = vllm_config.model_config.served_model_name max_model_len = vllm_config.model_config.max_model_len spec_decode_labelvalues: dict[int, list[str]] = { - idx: [model_name, str(idx)] - for idx in engine_indexes + idx: [model_name, 
str(idx)] for idx in engine_indexes } self.spec_decoding_prom = self._spec_decoding_cls( - vllm_config.speculative_config, labelnames, - spec_decode_labelvalues) + vllm_config.speculative_config, labelnames, spec_decode_labelvalues + ) # # Scheduler state @@ -190,19 +191,21 @@ class PrometheusStatLogger(StatLoggerBase): name="vllm:num_requests_running", documentation="Number of requests in model execution batches.", multiprocess_mode="mostrecent", - labelnames=labelnames) - self.gauge_scheduler_running = make_per_engine(gauge_scheduler_running, - engine_indexes, - model_name) + labelnames=labelnames, + ) + self.gauge_scheduler_running = make_per_engine( + gauge_scheduler_running, engine_indexes, model_name + ) gauge_scheduler_waiting = self._gauge_cls( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", multiprocess_mode="mostrecent", - labelnames=labelnames) - self.gauge_scheduler_waiting = make_per_engine(gauge_scheduler_waiting, - engine_indexes, - model_name) + labelnames=labelnames, + ) + self.gauge_scheduler_waiting = make_per_engine( + gauge_scheduler_waiting, engine_indexes, model_name + ) # # GPU cache @@ -215,11 +218,14 @@ class PrometheusStatLogger(StatLoggerBase): name="vllm:gpu_cache_usage_perc", documentation=( "GPU KV-cache usage. 1 means 100 percent usage." - "DEPRECATED: Use vllm:kv_cache_usage_perc instead."), + "DEPRECATED: Use vllm:kv_cache_usage_perc instead." + ), multiprocess_mode="mostrecent", - labelnames=labelnames) + labelnames=labelnames, + ) self.gauge_gpu_cache_usage = make_per_engine( - gauge_gpu_cache_usage, engine_indexes, model_name) + gauge_gpu_cache_usage, engine_indexes, model_name + ) # Deprecated in 0.9.2 - Renamed as vllm:prefix_cache_queries # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10 @@ -231,9 +237,11 @@ class PrometheusStatLogger(StatLoggerBase): "GPU prefix cache queries, in terms of number of queried" "tokens. DEPRECATED: Use vllm:prefix_cache_queries instead." ), - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_gpu_prefix_cache_queries = make_per_engine( - counter_gpu_prefix_cache_queries, engine_indexes, model_name) + counter_gpu_prefix_cache_queries, engine_indexes, model_name + ) # Deprecated in 0.9.2 - Renamed as vllm:prefix_cache_hits # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10 @@ -243,33 +251,42 @@ class PrometheusStatLogger(StatLoggerBase): name="vllm:gpu_prefix_cache_hits", documentation=( "GPU prefix cache hits, in terms of number of cached " - "tokens. DEPRECATED: Use vllm:prefix_cache_hits instead."), - labelnames=labelnames) + "tokens. DEPRECATED: Use vllm:prefix_cache_hits instead." + ), + labelnames=labelnames, + ) self.counter_gpu_prefix_cache_hits = make_per_engine( - counter_gpu_prefix_cache_hits, engine_indexes, model_name) + counter_gpu_prefix_cache_hits, engine_indexes, model_name + ) gauge_kv_cache_usage = self._gauge_cls( name="vllm:kv_cache_usage_perc", documentation="KV-cache usage. 
1 means 100 percent usage.", - labelnames=labelnames) - self.gauge_kv_cache_usage = make_per_engine(gauge_kv_cache_usage, - engine_indexes, model_name) + labelnames=labelnames, + ) + self.gauge_kv_cache_usage = make_per_engine( + gauge_kv_cache_usage, engine_indexes, model_name + ) counter_prefix_cache_queries = self._counter_cls( name="vllm:prefix_cache_queries", documentation=( - "Prefix cache queries, in terms of number of queried tokens."), - labelnames=labelnames) + "Prefix cache queries, in terms of number of queried tokens." + ), + labelnames=labelnames, + ) self.counter_prefix_cache_queries = make_per_engine( - counter_prefix_cache_queries, engine_indexes, model_name) + counter_prefix_cache_queries, engine_indexes, model_name + ) counter_prefix_cache_hits = self._counter_cls( name="vllm:prefix_cache_hits", - documentation=( - "Prefix cache hits, in terms of number of cached tokens."), - labelnames=labelnames) + documentation=("Prefix cache hits, in terms of number of cached tokens."), + labelnames=labelnames, + ) self.counter_prefix_cache_hits = make_per_engine( - counter_prefix_cache_hits, engine_indexes, model_name) + counter_prefix_cache_hits, engine_indexes, model_name + ) # # Counters @@ -277,36 +294,43 @@ class PrometheusStatLogger(StatLoggerBase): counter_num_preempted_reqs = self._counter_cls( name="vllm:num_preemptions", documentation="Cumulative number of preemption from the engine.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_num_preempted_reqs = make_per_engine( - counter_num_preempted_reqs, engine_indexes, model_name) + counter_num_preempted_reqs, engine_indexes, model_name + ) counter_prompt_tokens = self._counter_cls( name="vllm:prompt_tokens", documentation="Number of prefill tokens processed.", - labelnames=labelnames) - self.counter_prompt_tokens = make_per_engine(counter_prompt_tokens, - engine_indexes, - model_name) + labelnames=labelnames, + ) + self.counter_prompt_tokens = make_per_engine( + counter_prompt_tokens, engine_indexes, model_name + ) counter_generation_tokens = self._counter_cls( name="vllm:generation_tokens", documentation="Number of generation tokens processed.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_generation_tokens = make_per_engine( - counter_generation_tokens, engine_indexes, model_name) + counter_generation_tokens, engine_indexes, model_name + ) - self.counter_request_success: dict[FinishReason, dict[ - int, prometheus_client.Counter]] = {} + self.counter_request_success: dict[ + FinishReason, dict[int, prometheus_client.Counter] + ] = {} counter_request_success_base = self._counter_cls( name="vllm:request_success", documentation="Count of successfully processed requests.", - labelnames=labelnames + ["finished_reason"]) + labelnames=labelnames + ["finished_reason"], + ) for reason in FinishReason: self.counter_request_success[reason] = { - idx: - counter_request_success_base.labels(model_name, str(idx), - str(reason)) + idx: counter_request_success_base.labels( + model_name, str(idx), str(reason) + ) for idx in engine_indexes } @@ -317,18 +341,21 @@ class PrometheusStatLogger(StatLoggerBase): name="vllm:request_prompt_tokens", documentation="Number of prefill tokens processed.", buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_num_prompt_tokens_request = make_per_engine( - histogram_num_prompt_tokens_request, engine_indexes, model_name) + histogram_num_prompt_tokens_request, engine_indexes, model_name + ) 
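The request-size histograms here use build_1_2_5_buckets(max_model_len) for their bucket edges. A hedged reimplementation of what such a 1-2-5 series looks like (illustrative; the actual helper in vLLM's metrics code may differ in its edge handling):

def one_two_five_buckets(max_value: int) -> list[int]:
    buckets: list[int] = []
    exponent = 0
    while True:
        for mantissa in (1, 2, 5):
            value = mantissa * 10**exponent
            if value > max_value:
                return buckets
            buckets.append(value)
        exponent += 1

print(one_two_five_buckets(4096))  # [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000]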
histogram_num_generation_tokens_request = self._histogram_cls( name="vllm:request_generation_tokens", documentation="Number of generation tokens processed.", buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_num_generation_tokens_request = make_per_engine( - histogram_num_generation_tokens_request, engine_indexes, - model_name) + histogram_num_generation_tokens_request, engine_indexes, model_name + ) # TODO: This metric might be incorrect in case of using multiple # api_server counts which uses prometheus mp. @@ -336,38 +363,42 @@ class PrometheusStatLogger(StatLoggerBase): histogram_iteration_tokens = self._histogram_cls( name="vllm:iteration_tokens_total", documentation="Histogram of number of tokens per engine_step.", - buckets=[ - 1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384 - ], - labelnames=labelnames) + buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + labelnames=labelnames, + ) self.histogram_iteration_tokens = make_per_engine( - histogram_iteration_tokens, engine_indexes, model_name) + histogram_iteration_tokens, engine_indexes, model_name + ) histogram_max_num_generation_tokens_request = self._histogram_cls( name="vllm:request_max_num_generation_tokens", - documentation= - "Histogram of maximum number of requested generation tokens.", + documentation="Histogram of maximum number of requested generation tokens.", buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_max_num_generation_tokens_request = make_per_engine( - histogram_max_num_generation_tokens_request, engine_indexes, - model_name) + histogram_max_num_generation_tokens_request, engine_indexes, model_name + ) histogram_n_request = self._histogram_cls( name="vllm:request_params_n", documentation="Histogram of the n request parameter.", buckets=[1, 2, 5, 10, 20], - labelnames=labelnames) - self.histogram_n_request = make_per_engine(histogram_n_request, - engine_indexes, model_name) + labelnames=labelnames, + ) + self.histogram_n_request = make_per_engine( + histogram_n_request, engine_indexes, model_name + ) histogram_max_tokens_request = self._histogram_cls( name="vllm:request_params_max_tokens", documentation="Histogram of the max_tokens request parameter.", buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_max_tokens_request = make_per_engine( - histogram_max_tokens_request, engine_indexes, model_name) + histogram_max_tokens_request, engine_indexes, model_name + ) # # Histogram of timing intervals @@ -376,13 +407,34 @@ class PrometheusStatLogger(StatLoggerBase): name="vllm:time_to_first_token_seconds", documentation="Histogram of time to first token in seconds.", buckets=[ - 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, - 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, - 2560.0 + 0.001, + 0.005, + 0.01, + 0.02, + 0.04, + 0.06, + 0.08, + 0.1, + 0.25, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, + 160.0, + 640.0, + 2560.0, ], - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_time_to_first_token = make_per_engine( - histogram_time_to_first_token, engine_indexes, model_name) + histogram_time_to_first_token, engine_indexes, model_name + ) # Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds # TODO: in 0.12, only enable if show_hidden_metrics=True @@ -390,86 +442,167 @@ class 
PrometheusStatLogger(StatLoggerBase): name="vllm:time_per_output_token_seconds", documentation=( "Histogram of time per output token in seconds." - "DEPRECATED: Use vllm:inter_token_latency_seconds instead."), + "DEPRECATED: Use vllm:inter_token_latency_seconds instead." + ), buckets=[ - 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, - 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, ], - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_time_per_output_token = make_per_engine( - histogram_time_per_output_token, engine_indexes, model_name) + histogram_time_per_output_token, engine_indexes, model_name + ) histogram_inter_token_latency = self._histogram_cls( name="vllm:inter_token_latency_seconds", documentation="Histogram of inter-token latency in seconds.", buckets=[ - 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, - 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, ], - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_inter_token_latency = make_per_engine( - histogram_inter_token_latency, engine_indexes, model_name) + histogram_inter_token_latency, engine_indexes, model_name + ) histogram_request_time_per_output_token = self._histogram_cls( name="vllm:request_time_per_output_token_seconds", - documentation= - "Histogram of time_per_output_token_seconds per request.", + documentation="Histogram of time_per_output_token_seconds per request.", buckets=[ - 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, - 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, ], - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_request_time_per_output_token = make_per_engine( - histogram_request_time_per_output_token, engine_indexes, - model_name) + histogram_request_time_per_output_token, engine_indexes, model_name + ) request_latency_buckets = [ - 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, - 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 + 0.3, + 0.5, + 0.8, + 1.0, + 1.5, + 2.0, + 2.5, + 5.0, + 10.0, + 15.0, + 20.0, + 30.0, + 40.0, + 50.0, + 60.0, + 120.0, + 240.0, + 480.0, + 960.0, + 1920.0, + 7680.0, ] histogram_e2e_time_request = self._histogram_cls( name="vllm:e2e_request_latency_seconds", documentation="Histogram of e2e request latency in seconds.", buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_e2e_time_request = make_per_engine( - histogram_e2e_time_request, engine_indexes, model_name) + histogram_e2e_time_request, engine_indexes, model_name + ) histogram_queue_time_request = self._histogram_cls( name="vllm:request_queue_time_seconds", - documentation= - "Histogram of time spent in WAITING phase for request.", + documentation="Histogram of time spent in WAITING phase for request.", buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_queue_time_request = make_per_engine( - histogram_queue_time_request, engine_indexes, model_name) + histogram_queue_time_request, engine_indexes, model_name + ) histogram_inference_time_request = self._histogram_cls( 
name="vllm:request_inference_time_seconds", - documentation= - "Histogram of time spent in RUNNING phase for request.", + documentation="Histogram of time spent in RUNNING phase for request.", buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_inference_time_request = make_per_engine( - histogram_inference_time_request, engine_indexes, model_name) + histogram_inference_time_request, engine_indexes, model_name + ) histogram_prefill_time_request = self._histogram_cls( name="vllm:request_prefill_time_seconds", - documentation= - "Histogram of time spent in PREFILL phase for request.", + documentation="Histogram of time spent in PREFILL phase for request.", buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_prefill_time_request = make_per_engine( - histogram_prefill_time_request, engine_indexes, model_name) + histogram_prefill_time_request, engine_indexes, model_name + ) histogram_decode_time_request = self._histogram_cls( name="vllm:request_decode_time_seconds", - documentation= - "Histogram of time spent in DECODE phase for request.", + documentation="Histogram of time spent in DECODE phase for request.", buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_decode_time_request = make_per_engine( - histogram_decode_time_request, engine_indexes, model_name) + histogram_decode_time_request, engine_indexes, model_name + ) # # LoRA metrics @@ -480,23 +613,21 @@ class PrometheusStatLogger(StatLoggerBase): self.gauge_lora_info: Optional[prometheus_client.Gauge] = None if vllm_config.lora_config is not None: if len(self.engine_indexes) > 1: - raise NotImplementedError( - "LoRA in DP mode is not supported yet.") + raise NotImplementedError("LoRA in DP mode is not supported yet.") self.labelname_max_lora = "max_lora" self.labelname_waiting_lora_adapters = "waiting_lora_adapters" self.labelname_running_lora_adapters = "running_lora_adapters" self.max_lora = vllm_config.lora_config.max_loras - self.gauge_lora_info = \ - self._gauge_cls( - name="vllm:lora_requests_info", - documentation="Running stats on lora requests.", - multiprocess_mode="sum", - labelnames=[ - self.labelname_max_lora, - self.labelname_waiting_lora_adapters, - self.labelname_running_lora_adapters, - ], - ) + self.gauge_lora_info = self._gauge_cls( + name="vllm:lora_requests_info", + documentation="Running stats on lora requests.", + multiprocess_mode="sum", + labelnames=[ + self.labelname_max_lora, + self.labelname_waiting_lora_adapters, + self.labelname_running_lora_adapters, + ], + ) def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): metrics_info = config_obj.metrics_info() @@ -522,54 +653,65 @@ class PrometheusStatLogger(StatLoggerBase): metrics_info["engine"] = str(engine_index) info_gauge.labels(**metrics_info).set(1) - def record(self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], - engine_idx: int = 0): + def record( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_idx: int = 0, + ): """Log to prometheus.""" if scheduler_stats is not None: self.gauge_scheduler_running[engine_idx].set( - scheduler_stats.num_running_reqs) + scheduler_stats.num_running_reqs + ) self.gauge_scheduler_waiting[engine_idx].set( - scheduler_stats.num_waiting_reqs) + scheduler_stats.num_waiting_reqs + ) if self.show_hidden_metrics: self.gauge_gpu_cache_usage[engine_idx].set( - 
scheduler_stats.kv_cache_usage) - self.gauge_kv_cache_usage[engine_idx].set( - scheduler_stats.kv_cache_usage) + scheduler_stats.kv_cache_usage + ) + self.gauge_kv_cache_usage[engine_idx].set(scheduler_stats.kv_cache_usage) if self.show_hidden_metrics: self.counter_gpu_prefix_cache_queries[engine_idx].inc( - scheduler_stats.prefix_cache_stats.queries) + scheduler_stats.prefix_cache_stats.queries + ) self.counter_gpu_prefix_cache_hits[engine_idx].inc( - scheduler_stats.prefix_cache_stats.hits) + scheduler_stats.prefix_cache_stats.hits + ) self.counter_prefix_cache_queries[engine_idx].inc( - scheduler_stats.prefix_cache_stats.queries) + scheduler_stats.prefix_cache_stats.queries + ) self.counter_prefix_cache_hits[engine_idx].inc( - scheduler_stats.prefix_cache_stats.hits) + scheduler_stats.prefix_cache_stats.hits + ) if scheduler_stats.spec_decoding_stats is not None: self.spec_decoding_prom.observe( - scheduler_stats.spec_decoding_stats, engine_idx) + scheduler_stats.spec_decoding_stats, engine_idx + ) if iteration_stats is None: return self.counter_num_preempted_reqs[engine_idx].inc( - iteration_stats.num_preempted_reqs) - self.counter_prompt_tokens[engine_idx].inc( - iteration_stats.num_prompt_tokens) + iteration_stats.num_preempted_reqs + ) + self.counter_prompt_tokens[engine_idx].inc(iteration_stats.num_prompt_tokens) self.counter_generation_tokens[engine_idx].inc( - iteration_stats.num_generation_tokens) + iteration_stats.num_generation_tokens + ) self.histogram_iteration_tokens[engine_idx].observe( - iteration_stats.num_prompt_tokens + \ - iteration_stats.num_generation_tokens) + iteration_stats.num_prompt_tokens + iteration_stats.num_generation_tokens + ) for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter: - self.histogram_max_num_generation_tokens_request[ - engine_idx].observe(max_gen_tokens) + self.histogram_max_num_generation_tokens_request[engine_idx].observe( + max_gen_tokens + ) for n_param in iteration_stats.n_params_iter: self.histogram_n_request[engine_idx].observe(n_param) for ttft in iteration_stats.time_to_first_tokens_iter: @@ -579,40 +721,51 @@ class PrometheusStatLogger(StatLoggerBase): self.histogram_time_per_output_token[engine_idx].observe(itl) for finished_request in iteration_stats.finished_requests: - self.counter_request_success[ - finished_request.finish_reason][engine_idx].inc() + self.counter_request_success[finished_request.finish_reason][ + engine_idx + ].inc() self.histogram_e2e_time_request[engine_idx].observe( - finished_request.e2e_latency) + finished_request.e2e_latency + ) self.histogram_queue_time_request[engine_idx].observe( - finished_request.queued_time) + finished_request.queued_time + ) self.histogram_prefill_time_request[engine_idx].observe( - finished_request.prefill_time) + finished_request.prefill_time + ) self.histogram_inference_time_request[engine_idx].observe( - finished_request.inference_time) + finished_request.inference_time + ) self.histogram_decode_time_request[engine_idx].observe( - finished_request.decode_time) + finished_request.decode_time + ) self.histogram_num_prompt_tokens_request[engine_idx].observe( - finished_request.num_prompt_tokens) + finished_request.num_prompt_tokens + ) self.histogram_num_generation_tokens_request[engine_idx].observe( - finished_request.num_generation_tokens) + finished_request.num_generation_tokens + ) self.histogram_request_time_per_output_token[engine_idx].observe( - finished_request.mean_time_per_output_token) + finished_request.mean_time_per_output_token + ) if 
finished_request.max_tokens_param: self.histogram_max_tokens_request[engine_idx].observe( - finished_request.max_tokens_param) + finished_request.max_tokens_param + ) if self.gauge_lora_info is not None: - running_lora_adapters = \ - ",".join(iteration_stats.running_lora_adapters.keys()) - waiting_lora_adapters = \ - ",".join(iteration_stats.waiting_lora_adapters.keys()) + running_lora_adapters = ",".join( + iteration_stats.running_lora_adapters.keys() + ) + waiting_lora_adapters = ",".join( + iteration_stats.waiting_lora_adapters.keys() + ) lora_info_labels = { self.labelname_running_lora_adapters: running_lora_adapters, self.labelname_waiting_lora_adapters: waiting_lora_adapters, self.labelname_max_lora: self.max_lora, } - self.gauge_lora_info.labels(**lora_info_labels)\ - .set_to_current_time() + self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time() def log_engine_initialized(self): self.log_metrics_info("cache_config", self.vllm_config.cache_config) @@ -625,8 +778,9 @@ PromMetric = Union[ ] -def make_per_engine(metric: PromMetric, engine_idxs: list[int], - model_name: str) -> dict[int, PromMetric]: +def make_per_engine( + metric: PromMetric, engine_idxs: list[int], model_name: str +) -> dict[int, PromMetric]: return {idx: metric.labels(model_name, str(idx)) for idx in engine_idxs} @@ -688,7 +842,8 @@ class StatLoggerManager: if client_count > 1: logger.warning( "AsyncLLM created with api_server_count more than 1; " - "disabling stats logging to avoid incomplete stats.") + "disabling stats logging to avoid incomplete stats." + ) else: factories.append(LoggingStatLogger) @@ -700,12 +855,12 @@ class StatLoggerManager: for logger_factory in factories: # If we get a custom prometheus logger, use that # instead. This is typically used for the ray case. - if (isinstance(logger_factory, type) - and issubclass(logger_factory, PrometheusStatLogger)): + if isinstance(logger_factory, type) and issubclass( + logger_factory, PrometheusStatLogger + ): prometheus_factory = logger_factory continue - loggers.append(logger_factory(vllm_config, - engine_idx)) # type: ignore + loggers.append(logger_factory(vllm_config, engine_idx)) # type: ignore self.per_engine_logger_dict[engine_idx] = loggers # For Prometheus, need to share the metrics between EngineCores. @@ -725,8 +880,7 @@ class StatLoggerManager: for logger in per_engine_loggers: logger.record(scheduler_stats, iteration_stats, engine_idx) - self.prometheus_logger.record(scheduler_stats, iteration_stats, - engine_idx) + self.prometheus_logger.record(scheduler_stats, iteration_stats, engine_idx) def log(self): for per_engine_loggers in self.per_engine_logger_dict.values(): diff --git a/vllm/v1/metrics/prometheus.py b/vllm/v1/metrics/prometheus.py index a43cf9ce25..5823737968 100644 --- a/vllm/v1/metrics/prometheus.py +++ b/vllm/v1/metrics/prometheus.py @@ -16,9 +16,7 @@ _prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None def setup_multiprocess_prometheus(): - """Set up prometheus multiprocessing directory if not already configured. - - """ + """Set up prometheus multiprocessing directory if not already configured.""" global _prometheus_multiproc_dir if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: @@ -27,19 +25,22 @@ def setup_multiprocess_prometheus(): # cleaned up upon exit. 
_prometheus_multiproc_dir = tempfile.TemporaryDirectory() os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name - logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s", - _prometheus_multiproc_dir.name) + logger.debug( + "Created PROMETHEUS_MULTIPROC_DIR at %s", _prometheus_multiproc_dir.name + ) else: - logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. " - "This directory must be wiped between vLLM runs or " - "you will find inaccurate metrics. Unset the variable " - "and vLLM will properly handle cleanup.") + logger.warning( + "Found PROMETHEUS_MULTIPROC_DIR was set by user. " + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. Unset the variable " + "and vLLM will properly handle cleanup." + ) def get_prometheus_registry() -> CollectorRegistry: - """Get the appropriate prometheus registry based on multiprocessing + """Get the appropriate prometheus registry based on multiprocessing configuration. - + Returns: Registry: A prometheus registry """ @@ -54,11 +55,11 @@ def get_prometheus_registry() -> CollectorRegistry: def unregister_vllm_metrics(): """Unregister any existing vLLM collectors from the prometheus registry. - + This is useful for testing and CI/CD where metrics may be registered multiple times across test runs. - - Also, in case of multiprocess, we need to unregister the metrics from the + + Also, in case of multiprocess, we need to unregister the metrics from the global registry. """ registry = REGISTRY diff --git a/vllm/v1/metrics/ray_wrappers.py b/vllm/v1/metrics/ray_wrappers.py index 6091857538..a6fe2062f7 100644 --- a/vllm/v1/metrics/ray_wrappers.py +++ b/vllm/v1/metrics/ray_wrappers.py @@ -15,11 +15,9 @@ import regex as re class RayPrometheusMetric: - def __init__(self): if ray_metrics is None: - raise ImportError( - "RayPrometheusMetric requires Ray to be installed.") + raise ImportError("RayPrometheusMetric requires Ray to be installed.") self.metric: Metric = None @@ -38,15 +36,14 @@ class RayPrometheusMetric: f"Expected {len(self.metric._tag_keys)}, got {len(labels)}" ) - self.metric.set_default_tags( - dict(zip(self.metric._tag_keys, labels))) + self.metric.set_default_tags(dict(zip(self.metric._tag_keys, labels))) return self @staticmethod def _get_sanitized_opentelemetry_name(name: str) -> str: """ - For compatibility with Ray + OpenTelemetry, the metric name must be + For compatibility with Ray + OpenTelemetry, the metric name must be sanitized. In particular, this replaces disallowed character (e.g., ':') with '_' in the metric name. Allowed characters: a-z, A-Z, 0-9, _ @@ -63,21 +60,22 @@ class RayGaugeWrapper(RayPrometheusMetric): """Wraps around ray.util.metrics.Gauge to provide same API as prometheus_client.Gauge""" - def __init__(self, - name: str, - documentation: Optional[str] = "", - labelnames: Optional[list[str]] = None, - multiprocess_mode: Optional[str] = ""): - + def __init__( + self, + name: str, + documentation: Optional[str] = "", + labelnames: Optional[list[str]] = None, + multiprocess_mode: Optional[str] = "", + ): # All Ray metrics are keyed by WorkerId, so multiprocess modes like # "mostrecent", "all", "sum" do not apply. This logic can be manually # implemented at the observability layer (Prometheus/Grafana). 
del multiprocess_mode labelnames_tuple = tuple(labelnames) if labelnames else None name = self._get_sanitized_opentelemetry_name(name) - self.metric = ray_metrics.Gauge(name=name, - description=documentation, - tag_keys=labelnames_tuple) + self.metric = ray_metrics.Gauge( + name=name, description=documentation, tag_keys=labelnames_tuple + ) def set(self, value: Union[int, float]): return self.metric.set(value) @@ -91,15 +89,17 @@ class RayCounterWrapper(RayPrometheusMetric): """Wraps around ray.util.metrics.Counter to provide same API as prometheus_client.Counter""" - def __init__(self, - name: str, - documentation: Optional[str] = "", - labelnames: Optional[list[str]] = None): + def __init__( + self, + name: str, + documentation: Optional[str] = "", + labelnames: Optional[list[str]] = None, + ): labelnames_tuple = tuple(labelnames) if labelnames else None name = self._get_sanitized_opentelemetry_name(name) - self.metric = ray_metrics.Counter(name=name, - description=documentation, - tag_keys=labelnames_tuple) + self.metric = ray_metrics.Counter( + name=name, description=documentation, tag_keys=labelnames_tuple + ) def inc(self, value: Union[int, float] = 1.0): if value == 0: @@ -111,18 +111,22 @@ class RayHistogramWrapper(RayPrometheusMetric): """Wraps around ray.util.metrics.Histogram to provide same API as prometheus_client.Histogram""" - def __init__(self, - name: str, - documentation: Optional[str] = "", - labelnames: Optional[list[str]] = None, - buckets: Optional[list[float]] = None): + def __init__( + self, + name: str, + documentation: Optional[str] = "", + labelnames: Optional[list[str]] = None, + buckets: Optional[list[float]] = None, + ): labelnames_tuple = tuple(labelnames) if labelnames else None name = self._get_sanitized_opentelemetry_name(name) boundaries = buckets if buckets else [] - self.metric = ray_metrics.Histogram(name=name, - description=documentation, - tag_keys=labelnames_tuple, - boundaries=boundaries) + self.metric = ray_metrics.Histogram( + name=name, + description=documentation, + tag_keys=labelnames_tuple, + boundaries=boundaries, + ) def observe(self, value: Union[int, float]): return self.metric.observe(value) diff --git a/vllm/v1/metrics/reader.py b/vllm/v1/metrics/reader.py index 4d6e599841..5d50fa9461 100644 --- a/vllm/v1/metrics/reader.py +++ b/vllm/v1/metrics/reader.py @@ -17,6 +17,7 @@ class Metric: in some cases a single vLLM instance may have multiple metrics with the same name but different sets of labels. """ + name: str labels: dict[str, str] @@ -24,6 +25,7 @@ class Metric: @dataclass class Counter(Metric): """A monotonically increasing integer counter.""" + value: int @@ -34,12 +36,14 @@ class Vector(Metric): This type - which doesn't exist in Prometheus - models one very specific metric, vllm:spec_decode_num_accepted_tokens_per_pos. """ + values: list[int] @dataclass class Gauge(Metric): """A numerical value that can go up or down.""" + value: float @@ -58,6 +62,7 @@ class Histogram(Metric): The sum property is the total sum of all observed values. 
""" + count: int sum: float buckets: dict[str, int] @@ -87,7 +92,8 @@ def get_metrics_snapshot() -> list[Metric]: samples = _get_samples(metric) for s in samples: collected.append( - Gauge(name=metric.name, labels=s.labels, value=s.value)) + Gauge(name=metric.name, labels=s.labels, value=s.value) + ) elif metric.type == "counter": samples = _get_samples(metric, "_total") if metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": @@ -99,16 +105,15 @@ def get_metrics_snapshot() -> list[Metric]: # accepted tokens using a Counter labeled with 'position'. # We convert these into a vector of integer values. # - for labels, values in _digest_num_accepted_by_pos_samples( - samples): + for labels, values in _digest_num_accepted_by_pos_samples(samples): collected.append( - Vector(name=metric.name, labels=labels, values=values)) + Vector(name=metric.name, labels=labels, values=values) + ) else: for s in samples: collected.append( - Counter(name=metric.name, - labels=s.labels, - value=int(s.value))) + Counter(name=metric.name, labels=s.labels, value=int(s.value)) + ) elif metric.type == "histogram": # @@ -122,21 +127,24 @@ def get_metrics_snapshot() -> list[Metric]: count_samples = _get_samples(metric, "_count") sum_samples = _get_samples(metric, "_sum") for labels, buckets, count_value, sum_value in _digest_histogram( - bucket_samples, count_samples, sum_samples): + bucket_samples, count_samples, sum_samples + ): collected.append( - Histogram(name=metric.name, - labels=labels, - buckets=buckets, - count=count_value, - sum=sum_value)) + Histogram( + name=metric.name, + labels=labels, + buckets=buckets, + count=count_value, + sum=sum_value, + ) + ) else: raise AssertionError(f"Unknown metric type {metric.type}") return collected -def _get_samples(metric: PromMetric, - suffix: Optional[str] = None) -> list[Sample]: +def _get_samples(metric: PromMetric, suffix: Optional[str] = None) -> list[Sample]: name = (metric.name + suffix) if suffix is not None else metric.name return [s for s in metric.samples if s.name == name] @@ -148,8 +156,7 @@ def _strip_label(labels: dict[str, str], key_to_remove: str) -> dict[str, str]: def _digest_histogram( - bucket_samples: list[Sample], count_samples: list[Sample], - sum_samples: list[Sample] + bucket_samples: list[Sample], count_samples: list[Sample], sum_samples: list[Sample] ) -> list[tuple[dict[str, str], dict[str, int], int, float]]: # # In the case of DP, we have an indigestable @@ -192,20 +199,25 @@ def _digest_histogram( labels_key = frozenset(s.labels.items()) sums_by_labels[labels_key] = s.value - assert set(buckets_by_labels.keys()) == set( - counts_by_labels.keys()) == set(sums_by_labels.keys()) + assert ( + set(buckets_by_labels.keys()) + == set(counts_by_labels.keys()) + == set(sums_by_labels.keys()) + ) output = [] label_keys = list(buckets_by_labels.keys()) for k in label_keys: labels = dict(k) - output.append((labels, buckets_by_labels[k], counts_by_labels[k], - sums_by_labels[k])) + output.append( + (labels, buckets_by_labels[k], counts_by_labels[k], sums_by_labels[k]) + ) return output def _digest_num_accepted_by_pos_samples( - samples: list[Sample]) -> list[tuple[dict[str, str], list[int]]]: + samples: list[Sample], +) -> list[tuple[dict[str, str], list[int]]]: # # In the case of DP, we have an indigestable # per-position-per-engine count as a list of diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 489b8da5c0..5564718d51 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -15,6 +15,7 @@ if 
TYPE_CHECKING: @dataclass class PrefixCacheStats: """Stores prefix cache hit statistics.""" + # Whether reset_prefix_cache was invoked. reset: bool = False # The number of new requests in this update. @@ -45,8 +46,7 @@ class SchedulerStats: kv_cache_usage: float = 0.0 - prefix_cache_stats: PrefixCacheStats = field( - default_factory=PrefixCacheStats) + prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats) spec_decoding_stats: Optional[SpecDecodingStats] = None kv_connector_stats: Optional[dict[str, Any]] = None @@ -112,18 +112,22 @@ class IterationStats: self.running_lora_adapters: dict[str, int] = {} def __repr__(self) -> str: - field_to_value_str = ", ".join(f"{k}={v}" - for k, v in vars(self).items()) + field_to_value_str = ", ".join(f"{k}={v}" for k, v in vars(self).items()) return f"{self.__class__.__name__}({field_to_value_str})" def _time_since(self, start: float) -> float: """Calculate an interval relative to this iteration's timestamp.""" return self.iteration_timestamp - start - def update_from_output(self, output: "EngineCoreOutput", - engine_core_timestamp: float, is_prefilling: bool, - prompt_len: int, req_stats: RequestStateStats, - lora_stats: Optional[LoRAStats]): + def update_from_output( + self, + output: "EngineCoreOutput", + engine_core_timestamp: float, + is_prefilling: bool, + prompt_len: int, + req_stats: RequestStateStats, + lora_stats: Optional[LoRAStats], + ): num_new_generation_tokens = len(output.new_token_ids) self.num_generation_tokens += num_new_generation_tokens @@ -138,8 +142,9 @@ class IterationStats: # Process request-level engine core events if output.events is not None: - self.update_from_events(output.request_id, output.events, - is_prefilling, req_stats, lora_stats) + self.update_from_events( + output.request_id, output.events, is_prefilling, req_stats, lora_stats + ) # Process the batch-level "new tokens" engine core event if is_prefilling: @@ -150,11 +155,17 @@ class IterationStats: req_stats.last_token_ts = engine_core_timestamp - def update_from_events(self, req_id: str, events: list["EngineCoreEvent"], - is_prefilling: bool, req_stats: RequestStateStats, - lora_stats: Optional[LoRAStats]): + def update_from_events( + self, + req_id: str, + events: list["EngineCoreEvent"], + is_prefilling: bool, + req_stats: RequestStateStats, + lora_stats: Optional[LoRAStats], + ): # Avoid circular dependency from vllm.v1.engine import EngineCoreEventType + for event in events: if event.type == EngineCoreEventType.QUEUED: req_stats.queued_ts = event.timestamp @@ -168,10 +179,13 @@ class IterationStats: self.num_preempted_reqs += 1 LoRARequestStates.preempted_request(lora_stats, req_id) - def update_from_finished_request(self, finish_reason: "FinishReason", - num_prompt_tokens: int, - max_tokens_param: Optional[int], - req_stats: RequestStateStats): + def update_from_finished_request( + self, + finish_reason: "FinishReason", + num_prompt_tokens: int, + max_tokens_param: Optional[int], + req_stats: RequestStateStats, + ): e2e_latency = self._time_since(req_stats.arrival_time) # Queued interval is from first QUEUED event to first SCHEDULED @@ -190,22 +204,24 @@ class IterationStats: inference_time = req_stats.last_token_ts - req_stats.scheduled_ts # Do not count the token generated by the prefill phase - mean_time_per_output_token = (decode_time / - (req_stats.num_generation_tokens - 1) - if req_stats.num_generation_tokens - - 1 > 0 else 0) + mean_time_per_output_token = ( + decode_time / (req_stats.num_generation_tokens - 1) + if 
req_stats.num_generation_tokens - 1 > 0 + else 0 + ) - finished_req = \ - FinishedRequestStats(finish_reason=finish_reason, - e2e_latency=e2e_latency, - num_prompt_tokens=num_prompt_tokens, - num_generation_tokens=req_stats.num_generation_tokens, - max_tokens_param=max_tokens_param, - queued_time=queued_time, - prefill_time=prefill_time, - inference_time=inference_time, - decode_time=decode_time, - mean_time_per_output_token=mean_time_per_output_token) + finished_req = FinishedRequestStats( + finish_reason=finish_reason, + e2e_latency=e2e_latency, + num_prompt_tokens=num_prompt_tokens, + num_generation_tokens=req_stats.num_generation_tokens, + max_tokens_param=max_tokens_param, + queued_time=queued_time, + prefill_time=prefill_time, + inference_time=inference_time, + decode_time=decode_time, + mean_time_per_output_token=mean_time_per_output_token, + ) self.finished_requests.append(finished_req) @@ -215,24 +231,24 @@ class LoRARequestStates: def __init__(self): self.lora_name_to_stats: dict[str, LoRAStats] = {} - def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]: + def get_stats(self, req_state: "RequestState") -> Optional[LoRAStats]: if req_state.lora_name is None: return None if req_state.lora_name not in self.lora_name_to_stats: self.lora_name_to_stats[req_state.lora_name] = LoRAStats() return self.lora_name_to_stats[req_state.lora_name] - def add_request(self, req_state: 'RequestState'): + def add_request(self, req_state: "RequestState"): if (lora_stats := self.get_stats(req_state)) is not None: lora_stats.waiting_requests.add(req_state.request_id) - def finish_request(self, req_state: 'RequestState'): + def finish_request(self, req_state: "RequestState"): if req_state.lora_name is None: return lora_stats = self.lora_name_to_stats[req_state.lora_name] lora_stats.running_requests.remove(req_state.request_id) - def abort_request(self, req_state: 'RequestState'): + def abort_request(self, req_state: "RequestState"): if req_state.lora_name is None: return lora_stats = self.lora_name_to_stats[req_state.lora_name] @@ -255,14 +271,15 @@ class LoRARequestStates: lora_stats.running_requests.remove(request_id) lora_stats.waiting_requests.add(request_id) - def update_iteration_stats(self, - iteration_stats: Optional[IterationStats]): + def update_iteration_stats(self, iteration_stats: Optional[IterationStats]): if iteration_stats is None: return for lora_name, stats in self.lora_name_to_stats.items(): if stats.waiting_requests: - iteration_stats.waiting_lora_adapters[lora_name] = \ - len(stats.waiting_requests) + iteration_stats.waiting_lora_adapters[lora_name] = len( + stats.waiting_requests + ) if stats.running_requests: - iteration_stats.running_lora_adapters[lora_name] = \ - len(stats.running_requests) + iteration_stats.running_lora_adapters[lora_name] = len( + stats.running_requests + ) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index d15cdf3659..d647b20757 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -8,12 +8,10 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Union import torch if TYPE_CHECKING: - from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorStats) + from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats class LogprobsLists(NamedTuple): - # [num_reqs, max_num_logprobs + 1] logprob_token_ids: list[list[int]] # [num_reqs, max_num_logprobs + 1] @@ -30,7 +28,6 @@ class LogprobsLists(NamedTuple): class LogprobsTensors(NamedTuple): - # [num_reqs, max_num_logprobs + 1] 
logprob_token_ids: torch.Tensor # [num_reqs, max_num_logprobs + 1] @@ -46,18 +43,18 @@ class LogprobsTensors(NamedTuple): ) @staticmethod - def empty_cpu(num_positions: int, - num_tokens_per_position: int) -> "LogprobsTensors": + def empty_cpu( + num_positions: int, num_tokens_per_position: int + ) -> "LogprobsTensors": """Create empty LogprobsTensors on CPU.""" logprob_token_ids = torch.empty( - (num_positions, num_tokens_per_position), - dtype=torch.int32, - device="cpu") + (num_positions, num_tokens_per_position), dtype=torch.int32, device="cpu" + ) logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32) - selected_token_ranks = torch.empty(num_positions, - dtype=torch.int32, - device="cpu") + selected_token_ranks = torch.empty( + num_positions, dtype=torch.int32, device="cpu" + ) return LogprobsTensors( logprob_token_ids=logprob_token_ids, logprobs=logprobs, @@ -72,7 +69,6 @@ PoolerOutput = Union[torch.Tensor, list[torch.Tensor]] @dataclass class SamplerOutput: - # [num_reqs, max_num_generated_tokens] # Different requests can have different number of generated tokens. # All requests are padded to max_num_generated_tokens. @@ -92,15 +88,18 @@ class KVConnectorOutput: invalid_block_ids: set[int] = field(default_factory=set) def is_empty(self): - return (not self.finished_sending and not self.finished_recving - and not self.kv_connector_stats and not self.invalid_block_ids) + return ( + not self.finished_sending + and not self.finished_recving + and not self.kv_connector_stats + and not self.invalid_block_ids + ) # ModelRunnerOutput is serialized and sent to the scheduler process. # This is expensive for torch.Tensor so prefer to use list instead. @dataclass class ModelRunnerOutput: - # [num_reqs] req_ids: list[str] # req_id -> index @@ -134,11 +133,10 @@ class ModelRunnerOutput: # ModelRunnerOutput wrapper for async scheduling. class AsyncModelRunnerOutput(ABC): - @abstractmethod def get_output(self) -> ModelRunnerOutput: """Get the ModelRunnerOutput for this async output. - + This is a blocking call that waits until the results are ready, which might involve copying device tensors to the host. This method should only be called once per AsyncModelRunnerOutput. 
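
The abstract get_output() above is the synchronization point for async scheduling: a concrete implementation is expected to start a non-blocking device-to-host copy up front and block only when the scheduler actually needs the result, exactly once per object. A hypothetical subclass, assuming a CUDA device; the class and argument names here are illustrative, not vLLM APIs:

    import torch

    class _CopyOnGetOutput(AsyncModelRunnerOutput):
        # Illustrative sketch: wraps a ModelRunnerOutput whose sampled token ids
        # still live on the GPU and materializes them on the single read.
        def __init__(self, output: ModelRunnerOutput,
                     sampled_token_ids_gpu: torch.Tensor):
            self._output = output
            # Kick off the copy without blocking the model runner loop.
            self._sampled_cpu = sampled_token_ids_gpu.to("cpu", non_blocking=True)
            self._copy_done = torch.cuda.Event()
            self._copy_done.record()

        def get_output(self) -> ModelRunnerOutput:
            # Blocking call: wait for the copy, then fill in the CPU-side lists.
            self._copy_done.synchronize()
            self._output.sampled_token_ids = [row.tolist() for row in self._sampled_cpu]
            return self._output
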
@@ -148,17 +146,18 @@ class AsyncModelRunnerOutput(ABC): @dataclass class DraftTokenIds: - # [num_reqs] req_ids: list[str] # num_reqs x num_draft_tokens draft_token_ids: list[list[int]] -EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], - req_id_to_index={}, - sampled_token_ids=[], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - num_nans_in_logits=None) +EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + num_nans_in_logits=None, +) diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 46506d272e..36ae5b40a3 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -29,13 +29,13 @@ class PoolingCursor: ) def is_partial_prefill(self): - return not torch.all( - self.prompt_lens_cpu == self.num_scheduled_tokens_cpu) + return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu) @dataclass class PoolingMetadata: """Tensors for pooling.""" + prompt_lens: torch.Tensor # CPU Tensor prompt_token_ids: Optional[torch.Tensor] pooling_params: list[PoolingParams] @@ -44,34 +44,40 @@ class PoolingMetadata: def __getitem__(self, indices: slice): return PoolingMetadata( prompt_lens=self.prompt_lens[indices], - prompt_token_ids=None if self.prompt_token_ids is None else - self.prompt_token_ids[indices], + prompt_token_ids=None + if self.prompt_token_ids is None + else self.prompt_token_ids[indices], pooling_params=self.pooling_params[indices], pooling_cursor=None - if self.pooling_cursor is None else self.pooling_cursor[indices], + if self.pooling_cursor is None + else self.pooling_cursor[indices], ) - def build_pooling_cursor(self, num_scheduled_tokens: list[int], - device: torch.device): - self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens, - self.prompt_lens, device) + def build_pooling_cursor( + self, num_scheduled_tokens: list[int], device: torch.device + ): + self.pooling_cursor = build_pooling_cursor( + num_scheduled_tokens, self.prompt_lens, device + ) -def build_pooling_cursor(num_scheduled_tokens: list[int], - prompt_lens: torch.Tensor, device: torch.device): +def build_pooling_cursor( + num_scheduled_tokens: list[int], prompt_lens: torch.Tensor, device: torch.device +): assert len(prompt_lens) == len(num_scheduled_tokens) n_seq = len(num_scheduled_tokens) index = list(range(n_seq)) num_scheduled_tokens = torch.tensor(num_scheduled_tokens, device="cpu") - cumsum = torch.zeros(n_seq + 1, - dtype=torch.int64, - pin_memory=pin_memory, - device="cpu") + cumsum = torch.zeros( + n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu" + ) torch.cumsum(num_scheduled_tokens, dim=0, out=cumsum[1:]) cumsum = cumsum.to(device, non_blocking=True) - return PoolingCursor(index=index, - first_token_indices_gpu=cumsum[:n_seq], - last_token_indices_gpu=cumsum[1:] - 1, - prompt_lens_cpu=prompt_lens, - num_scheduled_tokens_cpu=num_scheduled_tokens) + return PoolingCursor( + index=index, + first_token_indices_gpu=cumsum[:n_seq], + last_token_indices_gpu=cumsum[1:] - 1, + prompt_lens_cpu=prompt_lens, + num_scheduled_tokens_cpu=num_scheduled_tokens, + ) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index dd0aea645d..ac6e583099 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -13,8 +13,12 @@ from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import 
length_from_prompt_token_ids_or_embeds -from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, - EngineCoreRequest, FinishReason) +from vllm.v1.engine import ( + EngineCoreEvent, + EngineCoreEventType, + EngineCoreRequest, + FinishReason, +) from vllm.v1.structured_output.request import StructuredOutputRequest from vllm.v1.utils import ConstantList @@ -24,7 +28,6 @@ if TYPE_CHECKING: class Request: - def __init__( self, request_id: str, @@ -41,8 +44,7 @@ class Request: cache_salt: Optional[str] = None, priority: int = 0, trace_headers: Optional[Mapping[str, str]] = None, - block_hasher: Optional[Callable[["Request"], - list["BlockHash"]]] = None, + block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] = None, ) -> None: self.request_id = request_id self.client_index = client_index @@ -53,8 +55,7 @@ class Request: self.eos_token_id = eos_token_id self.lora_request = lora_request self.structured_output_request = structured_output_request - self.arrival_time = arrival_time if arrival_time is not None else \ - time.time() + self.arrival_time = arrival_time if arrival_time is not None else time.time() self.status = RequestStatus.WAITING self.use_structured_output = False @@ -76,20 +77,23 @@ class Request: self.use_structured_output = True if sampling_params.extra_args is not None: - self.kv_transfer_params = \ - sampling_params.extra_args.get("kv_transfer_params") + self.kv_transfer_params = sampling_params.extra_args.get( + "kv_transfer_params" + ) else: - raise ValueError( - "sampling_params and pooling_params can't both be unset") + raise ValueError("sampling_params and pooling_params can't both be unset") self.prompt_token_ids = prompt_token_ids self.prompt_embeds = prompt_embeds self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - prompt_token_ids, prompt_embeds) + prompt_token_ids, prompt_embeds + ) self._output_token_ids: list[int] = [] - self._all_token_ids: list[int] = self.prompt_token_ids.copy( - ) if self.prompt_token_ids is not None else [0 - ] * self.num_prompt_tokens + self._all_token_ids: list[int] = ( + self.prompt_token_ids.copy() + if self.prompt_token_ids is not None + else [0] * self.num_prompt_tokens + ) self.num_output_placeholders = 0 # Used in async scheduling. 
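
When a request carries prompt embeddings rather than token ids, _all_token_ids above is seeded with zero placeholders of matching length, so the bookkeeping hinges on length_from_prompt_token_ids_or_embeds from vllm.utils. A minimal sketch of that helper, assuming it simply prefers explicit token ids when both inputs are present:

    from typing import Optional

    import torch

    def length_from_prompt_token_ids_or_embeds(
        prompt_token_ids: Optional[list[int]],
        prompt_embeds: Optional[torch.Tensor],
    ) -> int:
        # Prefer explicit token ids; otherwise count embedding rows.
        if prompt_token_ids is not None:
            return len(prompt_token_ids)
        if prompt_embeds is not None:
            return prompt_embeds.shape[0]
        raise ValueError("need prompt_token_ids or prompt_embeds")
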
self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 @@ -119,16 +123,16 @@ class Request: self.num_preemptions = 0 self.block_hashes: list[BlockHash] = [] - self.get_hash_new_full_blocks: Optional[Callable[ - [], list[BlockHash]]] = None + self.get_hash_new_full_blocks: Optional[Callable[[], list[BlockHash]]] = None if block_hasher is not None: self.get_hash_new_full_blocks = partial(block_hasher, self) self.block_hashes = self.get_hash_new_full_blocks() @classmethod def from_engine_core_request( - cls, request: EngineCoreRequest, - block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] + cls, + request: EngineCoreRequest, + block_hasher: Optional[Callable[["Request"], list["BlockHash"]]], ) -> "Request": return cls( request_id=request.request_id, @@ -142,8 +146,10 @@ class Request: arrival_time=request.arrival_time, lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( - sampling_params=request.sampling_params) \ - if request.sampling_params else None, + sampling_params=request.sampling_params + ) + if request.sampling_params + else None, cache_salt=request.cache_salt, priority=request.priority, trace_headers=request.trace_headers, @@ -207,6 +213,7 @@ class Request: class RequestStatus(enum.IntEnum): """Status of a request.""" + WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() WAITING_FOR_REMOTE_KVS = enum.auto() @@ -227,8 +234,7 @@ class RequestStatus(enum.IntEnum): return status > RequestStatus.PREEMPTED @staticmethod - def get_finished_reason( - status: "RequestStatus") -> Union[FinishReason, None]: + def get_finished_reason(status: "RequestStatus") -> Union[FinishReason, None]: return _FINISHED_REASON_MAP.get(status) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 10cad5b530..98c4d8bad0 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -13,15 +13,18 @@ import torch from vllm.logger import init_logger from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor from vllm.sampling_params import SamplingParams -from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor, - MinPLogitsProcessor, - MinTokensLogitsProcessor, - process_dict_updates) -from vllm.v1.sample.logits_processor.interface import (BatchUpdate, - LogitsProcessor, - MoveDirectionality) -from vllm.v1.sample.logits_processor.state import (BatchUpdateBuilder, - LogitsProcessors) +from vllm.v1.sample.logits_processor.builtin import ( + LogitBiasLogitsProcessor, + MinPLogitsProcessor, + MinTokensLogitsProcessor, + process_dict_updates, +) +from vllm.v1.sample.logits_processor.interface import ( + BatchUpdate, + LogitsProcessor, + MoveDirectionality, +) +from vllm.v1.sample.logits_processor.state import BatchUpdateBuilder, LogitsProcessors if TYPE_CHECKING: from vllm.config import VllmConfig @@ -30,10 +33,11 @@ logger = init_logger(__name__) # Error message when the user tries to initialize vLLM with a pooling model # and custom logitsproces -STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom" - " logits processors.") +STR_POOLING_REJECTS_LOGITSPROCS = ( + "Pooling models do not support custom logits processors." 
+) -LOGITSPROCS_GROUP = 'vllm.logits_processors' +LOGITSPROCS_GROUP = "vllm.logits_processors" BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [ MinTokensLogitsProcessor, @@ -54,27 +58,29 @@ def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]: installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP) if len(installed_logitsprocs_plugins) == 0: - logger.debug("No logitsprocs plugins installed (group %s).", - LOGITSPROCS_GROUP) + logger.debug("No logitsprocs plugins installed (group %s).", LOGITSPROCS_GROUP) return [] # Load logitsprocs plugins - logger.debug("Loading installed logitsprocs plugins (group %s):", - LOGITSPROCS_GROUP) + logger.debug("Loading installed logitsprocs plugins (group %s):", LOGITSPROCS_GROUP) classes: list[type[LogitsProcessor]] = [] for entrypoint in installed_logitsprocs_plugins: try: - logger.debug("- Loading logitproc plugin entrypoint=%s target=%s", - entrypoint.name, entrypoint.value) + logger.debug( + "- Loading logitproc plugin entrypoint=%s target=%s", + entrypoint.name, + entrypoint.value, + ) classes.append(entrypoint.load()) except Exception as e: raise RuntimeError( - f"Failed to load LogitsProcessor plugin {entrypoint}") from e + f"Failed to load LogitsProcessor plugin {entrypoint}" + ) from e return classes def _load_logitsprocs_by_fqcns( - logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]] + logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]], ) -> list[type[LogitsProcessor]]: """Load logit processor types, identifying them by fully-qualified class names (FQCNs). @@ -99,13 +105,14 @@ def _load_logitsprocs_by_fqcns( logger.debug( "%s additional custom logits processors specified, checking whether " - "they need to be loaded.", len(logits_processors)) + "they need to be loaded.", + len(logits_processors), + ) classes: list[type[LogitsProcessor]] = [] for ldx, logitproc in enumerate(logits_processors): if isinstance(logitproc, type): - logger.debug(" - Already-loaded logit processor: %s", - logitproc.__name__) + logger.debug(" - Already-loaded logit processor: %s", logitproc.__name__) if not issubclass(logitproc, LogitsProcessor): raise ValueError( f"{logitproc.__name__} is not a subclass of LogitsProcessor" @@ -131,8 +138,7 @@ def _load_logitsprocs_by_fqcns( if not isinstance(obj, type): raise ValueError("Loaded logit processor must be a type.") if not issubclass(obj, LogitsProcessor): - raise ValueError( - f"{obj.__name__} must be a subclass of LogitsProcessor") + raise ValueError(f"{obj.__name__} must be a subclass of LogitsProcessor") classes.append(obj) return classes @@ -155,13 +161,13 @@ def _load_custom_logitsprocs( A list of all loaded logitproc types """ from vllm.platforms import current_platform + if current_platform.is_tpu(): # No logitsprocs specified by caller # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs return [] - return (_load_logitsprocs_plugins() + - _load_logitsprocs_by_fqcns(logits_processors)) + return _load_logitsprocs_plugins() + _load_logitsprocs_by_fqcns(logits_processors) def build_logitsprocs( @@ -174,23 +180,28 @@ def build_logitsprocs( if is_pooling_model: if custom_logitsprocs: raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS) - logger.debug("Skipping logits processor loading because pooling models" - " do not support logits processors.") + logger.debug( + "Skipping logits processor loading because pooling models" + " do not support logits processors." 
+ ) return LogitsProcessors() custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs) return LogitsProcessors( - ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain( - BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes)) + ctor(vllm_config, device, is_pin_memory) + for ctor in itertools.chain( + BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes + ) + ) class AdapterLogitsProcessor(LogitsProcessor): """Wrapper for per-request logits processors - + To wrap a specific per-request logits processor, * Subclass `AdapterLogitsProcessor` * Implement `self.is_argmax_invariant()` base-class method * Implement `self.new_req_logits_processor(params)` - + `self.__init__(vllm_config, device, is_pin_memory)` does not need to be overridden in general. However, to implement custom constructor behavior - especially any logic which operates on or stores `vllm_config`, `device`, @@ -199,8 +210,9 @@ class AdapterLogitsProcessor(LogitsProcessor): `super().__init__(vllm_config, device, is_pin_memory)` """ - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool): + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ): """Subclass must invoke `super().__init__(vllm_config, device, is_pin_memory)`. @@ -236,7 +248,7 @@ class AdapterLogitsProcessor(LogitsProcessor): Returns: None if logits processor should not be applied to request; otherwise returns a `RequestLogitsProcessor` instance - + """ raise NotImplementedError @@ -257,11 +269,14 @@ class AdapterLogitsProcessor(LogitsProcessor): Returns: logits processor partial[Tensor] or None - + """ if req_lp := self.new_req_logits_processor(params): - args = [prompt_ids, output_ids] if (len( - inspect.signature(req_lp).parameters) == 3) else [output_ids] + args = ( + [prompt_ids, output_ids] + if (len(inspect.signature(req_lp).parameters) == 3) + else [output_ids] + ) return partial(req_lp, *args) return None @@ -286,9 +301,16 @@ class AdapterLogitsProcessor(LogitsProcessor): __all__ = [ - "LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor", - "MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder", - "MoveDirectionality", "LogitsProcessors", "build_logitsprocs", - "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP", - "AdapterLogitsProcessor" + "LogitsProcessor", + "LogitBiasLogitsProcessor", + "MinPLogitsProcessor", + "MinTokensLogitsProcessor", + "BatchUpdate", + "BatchUpdateBuilder", + "MoveDirectionality", + "LogitsProcessors", + "build_logitsprocs", + "STR_POOLING_REJECTS_LOGITSPROCS", + "LOGITSPROCS_GROUP", + "AdapterLogitsProcessor", ] diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index fc655d993c..3c3ddda7fb 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -6,9 +6,11 @@ from typing import TYPE_CHECKING, Callable, Optional, TypeVar import torch from vllm import SamplingParams -from vllm.v1.sample.logits_processor.interface import (BatchUpdate, - LogitsProcessor, - MoveDirectionality) +from vllm.v1.sample.logits_processor.interface import ( + BatchUpdate, + LogitsProcessor, + MoveDirectionality, +) if TYPE_CHECKING: from vllm.config import VllmConfig @@ -17,25 +19,24 @@ T = TypeVar("T") class MinPLogitsProcessor(LogitsProcessor): - - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool): + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, 
is_pin_memory: bool + ): max_num_reqs = vllm_config.scheduler_config.max_num_seqs self.min_p_count: int = 0 - self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=is_pin_memory) + self.min_p_cpu_tensor = torch.zeros( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=is_pin_memory + ) self.min_p_cpu = self.min_p_cpu_tensor.numpy() self.use_double_tensor = torch.device(device).type != "cpu" if self.use_double_tensor: # Pre-allocated device tensor - self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) + self.min_p_device: torch.Tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device=device + ) else: self.min_p_device = self.min_p_cpu_tensor # Current slice of the device tensor @@ -93,8 +94,7 @@ class MinPLogitsProcessor(LogitsProcessor): if self.min_p_count and (needs_update or self.min_p.shape[0] != size): self.min_p = self.min_p_device[:size] if self.use_double_tensor: - self.min_p.copy_(self.min_p_cpu_tensor[:size], - non_blocking=True) + self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True) self.min_p.unsqueeze_(1) def apply(self, logits: torch.Tensor) -> torch.Tensor: @@ -104,28 +104,27 @@ class MinPLogitsProcessor(LogitsProcessor): # Convert logits to probability distribution probability_values = torch.nn.functional.softmax(logits, dim=-1) # Calculate maximum probabilities per sequence - max_probabilities = torch.amax(probability_values, - dim=-1, - keepdim=True) + max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True) # Adjust min_p adjusted_min_p = max_probabilities.mul_(self.min_p) # Identify valid tokens using threshold comparison invalid_token_mask = probability_values < adjusted_min_p # Apply mask using boolean indexing - logits[invalid_token_mask] = -float('inf') + logits[invalid_token_mask] = -float("inf") return logits class LogitBiasLogitsProcessor(LogitsProcessor): - def __init__(self, _, device: torch.device, is_pin_memory: bool): self.device = device self.pin_memory = is_pin_memory self.biases: dict[int, dict[int, float]] = {} self.bias_tensor: torch.Tensor = torch.tensor(()) - self.logits_slice = (self._device_tensor([], torch.int32), - self._device_tensor([], torch.int32)) + self.logits_slice = ( + self._device_tensor([], torch.int32), + self._device_tensor([], torch.int32), + ) def is_argmax_invariant(self) -> bool: """Logit bias can rebalance token probabilities and change the @@ -134,8 +133,8 @@ class LogitBiasLogitsProcessor(LogitsProcessor): def update_state(self, batch_update: Optional[BatchUpdate]): needs_update = process_dict_updates( - self.biases, batch_update, - lambda params, _, __: params.logit_bias or None) + self.biases, batch_update, lambda params, _, __: params.logit_bias or None + ) # Update tensors if needed. 
if needs_update: @@ -148,15 +147,15 @@ class LogitBiasLogitsProcessor(LogitsProcessor): biases.extend(lb.values()) self.bias_tensor = self._device_tensor(biases, torch.float32) - self.logits_slice = (self._device_tensor(reqs, torch.int32), - self._device_tensor(tok_ids, torch.int32)) + self.logits_slice = ( + self._device_tensor(reqs, torch.int32), + self._device_tensor(tok_ids, torch.int32), + ) def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: - return (torch.tensor(data, - device="cpu", - dtype=dtype, - pin_memory=self.pin_memory).to(device=self.device, - non_blocking=True)) + return torch.tensor( + data, device="cpu", dtype=dtype, pin_memory=self.pin_memory + ).to(device=self.device, non_blocking=True) def apply(self, logits: torch.Tensor) -> torch.Tensor: if self.biases: @@ -165,20 +164,19 @@ class LogitBiasLogitsProcessor(LogitsProcessor): class MinTokensLogitsProcessor(LogitsProcessor): - - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool): + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ): # index -> (min_toks, output_token_ids, stop_token_ids) self.device = device self.pin_memory = is_pin_memory self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {} # (req_idx_tensor,eos_tok_id_tensor) - self.logits_slice: tuple[torch.Tensor, - torch.Tensor] = (self._device_tensor( - [], torch.int32), - self._device_tensor( - [], torch.int32)) + self.logits_slice: tuple[torch.Tensor, torch.Tensor] = ( + self._device_tensor([], torch.int32), + self._device_tensor([], torch.int32), + ) def is_argmax_invariant(self) -> bool: """By censoring stop tokens, min-tokens can change the outcome @@ -187,8 +185,7 @@ class MinTokensLogitsProcessor(LogitsProcessor): @staticmethod def add_request( - params: SamplingParams, _: Optional[list[int]], - output_tok_ids: list[int] + params: SamplingParams, _: Optional[list[int]], output_tok_ids: list[int] ) -> Optional[tuple[int, Sequence[int], set[int]]]: min_tokens = params.min_tokens if not min_tokens or len(output_tok_ids) >= min_tokens: @@ -196,13 +193,16 @@ class MinTokensLogitsProcessor(LogitsProcessor): return min_tokens, output_tok_ids, params.all_stop_token_ids def update_state(self, batch_update: Optional[BatchUpdate]): - needs_update = process_dict_updates(self.min_toks, batch_update, - self.add_request) + needs_update = process_dict_updates( + self.min_toks, batch_update, self.add_request + ) if self.min_toks: # Check for any requests that have attained their min tokens. 
- to_remove = tuple(index for index, (min_toks, out_tok_ids, - _) in self.min_toks.items() - if len(out_tok_ids) >= min_toks) + to_remove = tuple( + index + for index, (min_toks, out_tok_ids, _) in self.min_toks.items() + if len(out_tok_ids) >= min_toks + ) if to_remove: needs_update = True for index in to_remove: @@ -216,15 +216,15 @@ class MinTokensLogitsProcessor(LogitsProcessor): reqs.extend([req] * len(stop_tok_ids)) tok_ids.extend(stop_tok_ids) - self.logits_slice = (self._device_tensor(reqs, torch.int32), - self._device_tensor(tok_ids, torch.int32)) + self.logits_slice = ( + self._device_tensor(reqs, torch.int32), + self._device_tensor(tok_ids, torch.int32), + ) def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: - return (torch.tensor(data, - device="cpu", - dtype=dtype, - pin_memory=self.pin_memory).to(device=self.device, - non_blocking=True)) + return torch.tensor( + data, device="cpu", dtype=dtype, pin_memory=self.pin_memory + ).to(device=self.device, non_blocking=True) def apply(self, logits: torch.Tensor) -> torch.Tensor: if self.min_toks: @@ -234,9 +234,9 @@ class MinTokensLogitsProcessor(LogitsProcessor): def process_dict_updates( - req_entries: dict[int, T], batch_update: Optional[BatchUpdate], - new_state: Callable[[SamplingParams, Optional[list[int]], list[int]], - Optional[T]] + req_entries: dict[int, T], + batch_update: Optional[BatchUpdate], + new_state: Callable[[SamplingParams, Optional[list[int]], list[int]], Optional[T]], ) -> bool: """Utility function to update dict state for sparse LogitsProcessors.""" @@ -246,8 +246,7 @@ def process_dict_updates( updated = False for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: - if (state := new_state(params, prompt_tok_ids, - output_tok_ids)) is not None: + if (state := new_state(params, prompt_tok_ids, output_tok_ids)) is not None: req_entries[index] = state updated = True elif req_entries.pop(index, None) is not None: diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py index a84afc2f34..713bd21d38 100644 --- a/vllm/v1/sample/logits_processor/interface.py +++ b/vllm/v1/sample/logits_processor/interface.py @@ -36,6 +36,7 @@ MovedRequest = tuple[int, int, MoveDirectionality] @dataclass(frozen=True) class BatchUpdate: """Persistent batch state change info for logitsprocs""" + batch_size: int # Current num reqs in batch # Metadata for requests added to, removed from, and moved @@ -57,10 +58,10 @@ class BatchUpdate: class LogitsProcessor(ABC): - @abstractmethod - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool) -> None: + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ) -> None: raise NotImplementedError @abstractmethod diff --git a/vllm/v1/sample/logits_processor/state.py b/vllm/v1/sample/logits_processor/state.py index 0a1196559d..a601f66415 100644 --- a/vllm/v1/sample/logits_processor/state.py +++ b/vllm/v1/sample/logits_processor/state.py @@ -4,10 +4,12 @@ from collections.abc import Iterator from itertools import chain from typing import TYPE_CHECKING, Optional -from vllm.v1.sample.logits_processor.interface import (AddedRequest, - BatchUpdate, - MovedRequest, - RemovedRequest) +from vllm.v1.sample.logits_processor.interface import ( + AddedRequest, + BatchUpdate, + MovedRequest, + RemovedRequest, +) if TYPE_CHECKING: from vllm.v1.sample.logits_processor.interface import LogitsProcessor @@ -81,8 +83,9 @@ class BatchUpdateBuilder: index: 
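# [Illustrative sketch] How the (request_indices, token_indices) pair built by
# MinTokensLogitsProcessor (and, analogously, LogitBiasLogitsProcessor) is
# consumed: advanced indexing writes into selected (row, column) positions in
# one shot. The values below are toy assumptions.
import torch

logits = torch.zeros(3, 8)                         # [num_reqs, vocab_size]
# Request 0 still has to block stop tokens {5, 7}; request 2 blocks {5}.
req_indices = torch.tensor([0, 0, 2])
tok_indices = torch.tensor([5, 7, 5])
logits[req_indices, tok_indices] = -float("inf")   # censor those stop tokens
print(logits)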
request index """ if self._is_removed_sorted: - raise RuntimeError("Cannot register new removed request after" - " self.removed has been read.") + raise RuntimeError( + "Cannot register new removed request after self.removed has been read." + ) self._removed.append(index) self.batch_changed = True @@ -116,7 +119,7 @@ class BatchUpdateBuilder: def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: """Generate a logitsprocs batch update data structure and reset internal batch update builder state. - + Args: batch_size: current persistent batch size @@ -146,14 +149,17 @@ class LogitsProcessors: """Encapsulates initialized logitsproc objects.""" def __init__( - self, - logitsprocs: Optional[Iterator["LogitsProcessor"]] = None) -> None: + self, logitsprocs: Optional[Iterator["LogitsProcessor"]] = None + ) -> None: self.argmax_invariant: list[LogitsProcessor] = [] self.non_argmax_invariant: list[LogitsProcessor] = [] if logitsprocs: for logitproc in logitsprocs: - (self.argmax_invariant if logitproc.is_argmax_invariant() else - self.non_argmax_invariant).append(logitproc) + ( + self.argmax_invariant + if logitproc.is_argmax_invariant() + else self.non_argmax_invariant + ).append(logitproc) @property def all(self) -> Iterator["LogitsProcessor"]: diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 9d6a87cea3..14895db1bd 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -11,7 +11,6 @@ from vllm.v1.sample.logits_processor import LogitsProcessors @dataclass class SamplingMetadata: - temperature: Optional[torch.Tensor] all_greedy: bool all_random: bool diff --git a/vllm/v1/sample/ops/bad_words.py b/vllm/v1/sample/ops/bad_words.py index 1b699565f2..25cbb2619e 100644 --- a/vllm/v1/sample/ops/bad_words.py +++ b/vllm/v1/sample/ops/bad_words.py @@ -35,5 +35,4 @@ def apply_bad_words( past_tokens_ids: list[list[int]], ) -> None: for i, bad_words_ids in bad_words_token_ids.items(): - _apply_bad_words_single_batch(logits[i], bad_words_ids, - past_tokens_ids[i]) + _apply_bad_words_single_batch(logits[i], bad_words_ids, past_tokens_ids[i]) diff --git a/vllm/v1/sample/ops/logprobs.py b/vllm/v1/sample/ops/logprobs.py index 82875b7c84..cf36d46e13 100644 --- a/vllm/v1/sample/ops/logprobs.py +++ b/vllm/v1/sample/ops/logprobs.py @@ -8,8 +8,7 @@ from vllm.platforms import current_platform @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) -def batched_count_greater_than(x: torch.Tensor, - values: torch.Tensor) -> torch.Tensor: +def batched_count_greater_than(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor: """ Counts elements in each row of x that are greater than the corresponding value in values. Use torch.compile to generate an optimized kernel for diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py index 5d54f6679a..e49b8db478 100644 --- a/vllm/v1/sample/ops/penalties.py +++ b/vllm/v1/sample/ops/penalties.py @@ -19,15 +19,20 @@ def apply_all_penalties( Applies presence, frequency and repetition penalties to the logits. 
""" _, vocab_size = logits.shape - output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, - logits.device) - return apply_penalties(logits, prompt_token_ids, output_tokens_t, - presence_penalties, frequency_penalties, - repetition_penalties) + output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, logits.device) + return apply_penalties( + logits, + prompt_token_ids, + output_tokens_t, + presence_penalties, + frequency_penalties, + repetition_penalties, + ) -def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int, - device: torch.device) -> torch.Tensor: +def _convert_to_tensors( + output_token_ids: list[list[int]], vocab_size: int, device: torch.device +) -> torch.Tensor: """ Convert the different list data structures to tensors. """ diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 5bcf1b5854..dbcdad07e4 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -16,6 +16,7 @@ logger = init_logger(__name__) try: import flashinfer.sampling + is_flashinfer_available = True except ImportError: is_flashinfer_available = False @@ -34,14 +35,17 @@ class TopKTopPSampler(nn.Module): self.logprobs_mode = logprobs_mode # flashinfer optimization does not apply if intermediate # logprobs/logits after top_k/top_p need to be returned - if logprobs_mode not in ("processed_logits", "processed_logprobs" - ) and current_platform.is_cuda(): + if ( + logprobs_mode not in ("processed_logits", "processed_logprobs") + and current_platform.is_cuda() + ): if is_flashinfer_available: flashinfer_version = flashinfer.__version__ if version.parse(flashinfer_version) < version.parse("0.2.3"): logger.warning_once( "FlashInfer version >= 0.2.3 required. " - "Falling back to default sampling implementation.") + "Falling back to default sampling implementation." + ) self.forward = self.forward_native elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False: # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for @@ -52,21 +56,22 @@ class TopKTopPSampler(nn.Module): # None means False, while in V1, None means True. This is # why we use the condition # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. - logger.info_once( - "Using FlashInfer for top-p & top-k sampling.") + logger.info_once("Using FlashInfer for top-p & top-k sampling.") self.forward = self.forward_cuda else: logger.warning_once( "FlashInfer is available, but it is not enabled. " "Falling back to the PyTorch-native implementation of " "top-p & top-k sampling. For the best performance, " - "please set VLLM_USE_FLASHINFER_SAMPLER=1.") + "please set VLLM_USE_FLASHINFER_SAMPLER=1." + ) self.forward = self.forward_native else: logger.warning_once( "FlashInfer is not available. Falling back to the PyTorch-" "native implementation of top-p & top-k sampling. For the " - "best performance, please install FlashInfer.") + "best performance, please install FlashInfer." + ) self.forward = self.forward_native elif current_platform.is_cpu(): self.forward = self.forward_cpu @@ -109,13 +114,15 @@ class TopKTopPSampler(nn.Module): # CPU-GPU synchronization while `flashinfer_sample` does. if (k is None and p is None) or generators: if generators: - logger.debug_once("FlashInfer 0.2.3+ does not support " - "per-request generators. Falling back to " - "PyTorch-native implementation.") + logger.debug_once( + "FlashInfer 0.2.3+ does not support " + "per-request generators. Falling back to " + "PyTorch-native implementation." 
+ ) return self.forward_native(logits, generators, k, p) - assert self.logprobs_mode not in ( - "processed_logits", "processed_logprobs" - ), "FlashInfer does not support returning logits/logprobs" + assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), ( + "FlashInfer does not support returning logits/logprobs" + ) # flashinfer sampling functions expect contiguous logits. # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous # because of slicing operation in logits_processor. @@ -278,15 +285,18 @@ def flashinfer_sample( # Top-p only. probs = logits.softmax(dim=-1, dtype=torch.float32) next_token_ids = flashinfer.sampling.top_p_sampling_from_probs( - probs, p, deterministic=True) + probs, p, deterministic=True + ) elif p is None: # Top-k only. probs = logits.softmax(dim=-1, dtype=torch.float32) next_token_ids = flashinfer.sampling.top_k_sampling_from_probs( - probs, k, deterministic=True) + probs, k, deterministic=True + ) else: # Both top-k and top-p. next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits( - logits, k, p, deterministic=True) + logits, k, p, deterministic=True + ) return next_token_ids.view(-1) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 37ce5bef84..ec9366aa25 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -54,7 +54,7 @@ class RejectionSampler(nn.Module): bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - ''' + """ Args: metadata: Metadata for spec decoding. @@ -81,7 +81,7 @@ class RejectionSampler(nn.Module): Returns: output_token_ids (torch.Tensor): A tensor containing the final output token IDs. - ''' + """ assert metadata.max_spec_len <= MAX_SPEC_LEN # [num_tokens, vocab_size] # NOTE(woosuk): `target_logits` can be updated in place inside the @@ -123,11 +123,11 @@ class RejectionSampler(nn.Module): """ output_token_ids_np = output_token_ids.cpu().numpy() # Create mask for valid tokens. - valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & - (output_token_ids_np < vocab_size)) + valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & ( + output_token_ids_np < vocab_size + ) outputs = [ - row[valid_mask[i]].tolist() - for i, row in enumerate(output_token_ids_np) + row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np) ] return outputs @@ -178,7 +178,7 @@ def rejection_sample( if not sampling_metadata.all_random: # Rejection sampling for greedy sampling requests. target_argmax = target_probs.argmax(dim=-1) - rejection_greedy_sample_kernel[(batch_size, )]( + rejection_greedy_sample_kernel[(batch_size,)]( output_token_ids, cu_num_draft_tokens, draft_token_ids, @@ -213,7 +213,7 @@ def rejection_sample( ) # Rejection sampling for random sampling requests. - rejection_random_sample_kernel[(batch_size, )]( + rejection_random_sample_kernel[(batch_size,)]( output_token_ids, cu_num_draft_tokens, draft_token_ids, @@ -320,7 +320,7 @@ def expand_batch_to_tokens( batch_size = x.shape[0] assert cu_num_tokens.shape[0] == batch_size expanded_x = x.new_empty(num_tokens) - expand_kernel[(batch_size, )]( + expand_kernel[(batch_size,)]( expanded_x, x, cu_num_tokens, @@ -368,7 +368,7 @@ def generate_uniform_probs( # https://github.com/pytorch/pytorch/issues/16706. Using float64 # mitigates the issue. 
uniform_probs = torch.rand( - (num_tokens, ), + (num_tokens,), dtype=torch.float64, device=device, ) @@ -464,8 +464,10 @@ def rejection_greedy_sample_kernel( if not rejected: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos) - tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, - target_argmax_id) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, + target_argmax_id, + ) if draft_token_id != target_argmax_id: # Reject. rejected = True @@ -474,8 +476,9 @@ def rejection_greedy_sample_kernel( # If all tokens are accepted, append the bonus token. bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) tl.store( - output_token_ids_ptr + req_idx * (max_spec_len + 1) + - num_draft_tokens, bonus_token_id) + output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, + bonus_token_id, + ) # NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. @@ -514,12 +517,12 @@ def rejection_random_sample_kernel( if NO_DRAFT_PROBS: draft_prob = 1 else: - draft_prob = tl.load(draft_probs_ptr + - (start_idx + pos) * vocab_size + - draft_token_id) - target_prob = tl.load(target_probs_ptr + - (start_idx + pos) * vocab_size + - draft_token_id) + draft_prob = tl.load( + draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id + ) + target_prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id + ) uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) # NOTE(woosuk): While the draft probability should never be 0, # we check it to avoid NaNs. If it happens to be 0, we reject. @@ -530,15 +533,17 @@ def rejection_random_sample_kernel( # Reject. Use recovered token. rejected = True token_id = tl.load(recovered_token_ids_ptr + start_idx + pos) - tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, - token_id) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id + ) if not rejected: # If all tokens are accepted, append the bonus token. bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) tl.store( - output_token_ids_ptr + req_idx * (max_spec_len + 1) + - num_draft_tokens, bonus_token_id) + output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, + bonus_token_id, + ) # NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. 
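# [Illustrative sketch] The accept/recover rule that the Triton kernels above
# implement, written for a single request in plain PyTorch: accept draft token
# t when u <= p_target(t) / p_draft(t); on the first rejection, sample a
# recovered token from max(p_target - p_draft, 0) and stop; if every draft is
# accepted, append the bonus token. Batching and padding are omitted.
import torch

def rejection_sample_one(draft_ids, draft_probs, target_probs, bonus_id,
                         generator=None):
    # draft_ids: [num_draft]; draft_probs / target_probs: [num_draft, vocab]
    out = []
    for pos, tok in enumerate(draft_ids.tolist()):
        p_draft = draft_probs[pos, tok]
        p_target = target_probs[pos, tok]
        u = torch.rand((), dtype=torch.float64, generator=generator)
        if p_draft > 0 and u <= p_target / p_draft:
            out.append(tok)                       # accept the draft token
            continue
        # Reject: sample a recovered token from the (assumed non-degenerate)
        # residual distribution, then stop drafting for this request.
        residual = (target_probs[pos] - draft_probs[pos]).clamp(min=0)
        out.append(int(torch.multinomial(residual / residual.sum(), 1,
                                         generator=generator)))
        return out
    out.append(int(bonus_id))                     # all drafts accepted
    return out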
@@ -562,9 +567,7 @@ def expand_kernel( src_val = tl.load(input_ptr + req_idx) src_val = tl.where(src_val == replace_from, replace_to, src_val) offset = tl.arange(0, MAX_NUM_TOKENS) - tl.store(output_ptr + start_idx + offset, - src_val, - mask=offset < num_tokens) + tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens) @triton.jit @@ -595,26 +598,30 @@ def sample_recovered_tokens_kernel( vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) if NO_DRAFT_PROBS: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + - vocab_offset, - mask=((vocab_offset < vocab_size) & - (vocab_offset != draft_token_id)), - other=0) + prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, + mask=((vocab_offset < vocab_size) & (vocab_offset != draft_token_id)), + other=0, + ) else: - draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + - vocab_offset, - mask=vocab_offset < vocab_size, - other=0) - target_prob = tl.load(target_probs_ptr + - (start_idx + pos) * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, - other=0) + draft_prob = tl.load( + draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=0, + ) + target_prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=0, + ) prob = tl.maximum(target_prob - draft_prob, 0) # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because # `tl.argmax` will select the maximum value. - q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, - other=float("-inf")) + q = tl.load( + q_ptr + req_idx * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=float("-inf"), + ) recovered_id = tl.argmax(prob / q, axis=-1) tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 83ea766b1b..d4d3fb0295 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -24,39 +24,39 @@ class Sampler(nn.Module): A layer that samples the next tokens from the model's outputs with the following steps in order: - 1. If logprobs are requested: + 1. If logprobs are requested: a) If `logprobs_mode` is `raw_logprobs`, compute logprobs - as the final logprobs to return. + as the final logprobs to return. b) If `logprobs_mode` is `raw_logits`, clone the logits - as the final logprobs to return. - 2. Convert logits to float32. - 3. Apply allowed token ids whitelist. - 4. Apply bad words exclusion. + as the final logprobs to return. + 2. Convert logits to float32. + 3. Apply allowed token ids whitelist. + 4. Apply bad words exclusion. 5. Apply logit processors which are not argmax-invariant, - i.e. that can impact greedy sampling. - a) Min tokens processor - b) Logit bias processor - 6. Apply penalties - a) Repetition penalty - b) Frequency penalty - c) Presence penalty - 7. Sample the next tokens. `sample` method performs the following steps: + i.e. that can impact greedy sampling. + a) Min tokens processor + b) Logit bias processor + 6. Apply penalties + a) Repetition penalty + b) Frequency penalty + c) Presence penalty + 7. Sample the next tokens. `sample` method performs the following steps: a) If not `all_random`, perform greedy sampling. If `all_greedy`, - return the greedily sampled tokens and final logprobs if requested. - b) Apply temperature. 
+ return the greedily sampled tokens and final logprobs if requested. + b) Apply temperature. c) Apply logit processors which are argmax-invariant, by default - the min_p processor. - d) Apply top_k and/or top_p. - e) Sample the next tokens with the probability distribution. + the min_p processor. + d) Apply top_k and/or top_p. + e) Sample the next tokens with the probability distribution. f) If `all_random` or temperature >= epsilon (1e-5), return the randomly sampled tokens and final logprobs if requested. Else, - return the greedily sampled tokens and logprobs if requested. + return the greedily sampled tokens and logprobs if requested. 8. Gather the logprobs of the top `max_num_logprobs` and sampled token (if requested). Note that if the sampled token is within the top `max_num_logprobs`, the logprob will be eventually merged in `LogprobsProcessor` during output processing. Therefore, the final output may contain either `max_num_logprobs + 1` or - `max_num_logprobs` logprobs. + `max_num_logprobs` logprobs. 9. Return the final `SamplerOutput`. """ @@ -108,8 +108,11 @@ class Sampler(nn.Module): # Gather the logprobs of the topk and sampled token (if requested). # Get logprobs and rank tensors (if requested) - logprobs_tensors = None if num_logprobs is None else \ - self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled) + logprobs_tensors = ( + None + if num_logprobs is None + else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled) + ) # Use int32 to reduce the tensor size. sampled = sampled.to(torch.int32) @@ -150,8 +153,7 @@ class Sampler(nn.Module): may update the logits tensor in-place. """ - assert not (sampling_metadata.all_greedy - and sampling_metadata.all_random) + assert not (sampling_metadata.all_greedy and sampling_metadata.all_random) if sampling_metadata.all_random: greedy_sampled = None else: @@ -168,8 +170,9 @@ class Sampler(nn.Module): assert sampling_metadata.temperature is not None # Apply temperature. - logits = self.apply_temperature(logits, sampling_metadata.temperature, - sampling_metadata.all_random) + logits = self.apply_temperature( + logits, sampling_metadata.temperature, sampling_metadata.all_random + ) # Apply logits processors that only apply to random sampling # (argmax invariant) @@ -224,9 +227,7 @@ class Sampler(nn.Module): """ assert token_ids.dtype == torch.int64 # Find the topK values. - topk_logprobs, topk_indices = torch.topk(logprobs, - num_logprobs, - dim=-1) + topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1) # Get with the logprob of the prompt or sampled token. 
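# [Illustrative sketch] What gather_logprobs assembles, shown with plain tensor
# ops: the top-k logprobs per row plus the logprob and rank of the sampled
# token. A simplified sketch of the idea, not the exact vLLM return layout.
import torch

logprobs = torch.log_softmax(torch.randn(4, 100), dim=-1)  # [num_reqs, vocab]
sampled = torch.randint(0, 100, (4,))                      # sampled token ids
num_logprobs = 5

topk_vals, topk_ids = torch.topk(logprobs, num_logprobs, dim=-1)
sampled_vals = logprobs.gather(-1, sampled.unsqueeze(-1))  # [num_reqs, 1]
# Rank = number of tokens with a strictly larger logprob than the sampled one.
sampled_rank = (logprobs > sampled_vals).sum(dim=-1)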
token_ids = token_ids.unsqueeze(-1) @@ -267,8 +268,7 @@ class Sampler(nn.Module): sampling_metadata: SamplingMetadata, ) -> torch.Tensor: if sampling_metadata.allowed_token_ids_mask is not None: - logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, - float("-inf")) + logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf")) return logits def apply_bad_words( diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py index 6491c84f60..b58a94d0bf 100644 --- a/vllm/v1/sample/tpu/metadata.py +++ b/vllm/v1/sample/tpu/metadata.py @@ -48,15 +48,13 @@ class TPUSupportedSamplingMetadata: min_tokens = None # impl is not vectorized - logit_bias: list[Optional[dict[int, float]]] = field( - default_factory=lambda: list()) + logit_bias: list[Optional[dict[int, float]]] = field(default_factory=lambda: list()) allowed_token_ids_mask = None bad_words_token_ids = None # Generator not supported by xla - _generators: dict[int, - torch.Generator] = field(default_factory=lambda: dict()) + _generators: dict[int, torch.Generator] = field(default_factory=lambda: dict()) @property def generators(self) -> dict[int, torch.Generator]: @@ -69,13 +67,13 @@ class TPUSupportedSamplingMetadata: input_batch: InputBatch, padded_num_reqs: int, xla_device: torch.device, - generate_params_if_all_greedy: bool = False + generate_params_if_all_greedy: bool = False, ) -> "TPUSupportedSamplingMetadata": """ Copy sampling tensors slices from `input_batch` to on device tensors. - `InputBatch._make_sampling_metadata` causes recompilation on XLA as it - slices dynamic shapes on device tensors. This impl moves the dynamic + `InputBatch._make_sampling_metadata` causes recompilation on XLA as it + slices dynamic shapes on device tensors. This impl moves the dynamic ops to CPU and produces tensors of fixed `padded_num_reqs` size. Args: @@ -87,11 +85,11 @@ class TPUSupportedSamplingMetadata: we want to pre-compile a graph with sampling parameters, even if they are not strictly needed for greedy decoding. """ - needs_logprobs = input_batch.max_num_logprobs>0 if \ - input_batch.max_num_logprobs else False + needs_logprobs = ( + input_batch.max_num_logprobs > 0 if input_batch.max_num_logprobs else False + ) # Early return to avoid unnecessary cpu to tpu copy - if (input_batch.all_greedy is True - and generate_params_if_all_greedy is False): + if input_batch.all_greedy is True and generate_params_if_all_greedy is False: return cls(all_greedy=True, logprobs=needs_logprobs) num_reqs = input_batch.num_reqs @@ -100,25 +98,22 @@ class TPUSupportedSamplingMetadata: # Pad value is the default one. cpu_tensor[num_reqs:padded_num_reqs] = fill_val - fill_slice(input_batch.temperature_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["temperature"]) - fill_slice(input_batch.min_p_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["min_p"]) - fill_slice(input_batch.top_k_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["top_k"]) - fill_slice(input_batch.top_p_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["top_p"]) + fill_slice( + input_batch.temperature_cpu_tensor, DEFAULT_SAMPLING_PARAMS["temperature"] + ) + fill_slice(input_batch.min_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["min_p"]) + fill_slice(input_batch.top_k_cpu_tensor, DEFAULT_SAMPLING_PARAMS["top_k"]) + fill_slice(input_batch.top_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["top_p"]) # Slice persistent device tensors to a fixed pre-compiled padded shape. return cls( - temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs]. 
- to(xla_device), + temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].to( + xla_device + ), all_greedy=input_batch.all_greedy, # TODO enable more and avoid returning None values - top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to( - xla_device), - top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to( - xla_device), - min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to( - xla_device), - logprobs=needs_logprobs) + top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(xla_device), + top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(xla_device), + min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(xla_device), + logprobs=needs_logprobs, + ) diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py index 17b83a4ba0..ccef283a81 100644 --- a/vllm/v1/sample/tpu/sampler.py +++ b/vllm/v1/sample/tpu/sampler.py @@ -14,7 +14,6 @@ _SAMPLING_EPS = 1e-5 class Sampler(nn.Module): - def __init__(self): # TODO(houseroad): Add support for logprobs_mode. super().__init__() @@ -35,7 +34,8 @@ class Sampler(nn.Module): # [num_requests, 1], where each row represents one generated # token per request. sampled_token_ids=sampled.unsqueeze(-1), - logprobs_tensors=None) + logprobs_tensors=None, + ) return sampler_output def apply_temperature( @@ -73,11 +73,13 @@ class Sampler(nn.Module): # Random sample. probs = logits.softmax(dim=-1, dtype=torch.float32) - random_sampled = self.random_sample(probs, - sampling_metadata.generators) + random_sampled = self.random_sample(probs, sampling_metadata.generators) - sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS, - greedy_sampled, random_sampled) + sampled = torch.where( + sampling_metadata.temperature < _SAMPLING_EPS, + greedy_sampled, + random_sampled, + ) return sampled def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: @@ -107,9 +109,7 @@ class Sampler(nn.Module): Sampled token rank tensor, (num tokens) """ # Find the topK values. - topk_logprobs, topk_indices = torch.topk(logprobs, - num_logprobs, - dim=-1) + topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1) # Get with the logprob of the prompt or sampled token. 
token_ids = token_ids.unsqueeze(-1) @@ -138,9 +138,7 @@ class Sampler(nn.Module): # Convert logits to probability distribution probability_values = torch.nn.functional.softmax(logits, dim=-1) # Calculate maximum probabilities per sequence - max_probabilities = torch.amax(probability_values, - dim=-1, - keepdim=True) + max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True) # Reshape min_p for broadcasting adjusted_min_p = min_p.unsqueeze(1) * max_probabilities # Identify valid tokens using threshold comparison diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 876838084b..f7a73cba60 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -18,14 +18,21 @@ from msgspec import msgpack from vllm import envs from vllm.logger import init_logger + # yapf: disable -from vllm.multimodal.inputs import (BaseMultiModalField, - MultiModalBatchedField, - MultiModalFieldConfig, MultiModalFieldElem, - MultiModalFlatField, MultiModalKwargs, - MultiModalKwargsItem, - MultiModalKwargsItems, - MultiModalSharedField, NestedTensors) +from vllm.multimodal.inputs import ( + BaseMultiModalField, + MultiModalBatchedField, + MultiModalFieldConfig, + MultiModalFieldElem, + MultiModalFlatField, + MultiModalKwargs, + MultiModalKwargsItem, + MultiModalKwargsItems, + MultiModalSharedField, + NestedTensors, +) + # yapf: enable from vllm.v1.engine import UtilityResult @@ -48,8 +55,10 @@ bytestr = Union[bytes, bytearray, memoryview, zmq.Frame] def _log_insecure_serialization_warning(): - logger.warning_once("Allowing insecure serialization using pickle due to " - "VLLM_ALLOW_INSECURE_SERIALIZATION=1") + logger.warning_once( + "Allowing insecure serialization using pickle due to " + "VLLM_ALLOW_INSECURE_SERIALIZATION=1" + ) def _typestr(val: Any) -> Optional[tuple[str, str]]: @@ -72,8 +81,8 @@ def _encode_type_info_recursive(obj: Any) -> Any: def _decode_type_info_recursive( - type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], - Any]) -> Any: + type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any] +) -> Any: """Recursively decode type information for nested structures of lists/dicts.""" if type_info is None: @@ -85,8 +94,9 @@ def _decode_type_info_recursive( for k in type_info } if isinstance(type_info, list) and ( - # Exclude serialized tensors/numpy arrays. - len(type_info) != 2 or not isinstance(type_info[0], str)): + # Exclude serialized tensors/numpy arrays. + len(type_info) != 2 or not isinstance(type_info[0], str) + ): assert isinstance(data, list) return [ _decode_type_info_recursive(ti, d, convert_fn) @@ -101,7 +111,7 @@ class MsgpackEncoder: Note that unlike vanilla `msgspec` Encoders, this interface is generally not thread-safe when encoding tensors / numpy arrays. - By default, arrays below 256B are serialized inline Larger will get sent + By default, arrays below 256B are serialized inline Larger will get sent via dedicated messages. Note that this is a per-tensor limit. """ @@ -119,7 +129,7 @@ class MsgpackEncoder: def encode(self, obj: Any) -> Sequence[bytestr]: try: - self.aux_buffers = bufs = [b''] + self.aux_buffers = bufs = [b""] bufs[0] = self.encoder.encode(obj) # This `bufs` list allows us to collect direct pointers to backing # buffers of tensors and np arrays, and return them along with the @@ -143,14 +153,15 @@ class MsgpackEncoder: return self._encode_tensor(obj) # Fall back to pickle for object or void kind ndarrays. 
- if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): + if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"): return self._encode_ndarray(obj) if isinstance(obj, slice): # We are assuming only int-based values will be used here. return tuple( int(v) if v is not None else None - for v in (obj.start, obj.stop, obj.step)) + for v in (obj.start, obj.stop, obj.step) + ) if isinstance(obj, MultiModalKwargsItem): return self._encode_mm_item(obj) @@ -171,17 +182,20 @@ class MsgpackEncoder: return _encode_type_info_recursive(result), result if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: - raise TypeError(f"Object of type {type(obj)} is not serializable" - "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow " - "fallback to pickle-based serialization.") + raise TypeError( + f"Object of type {type(obj)} is not serializable" + "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow " + "fallback to pickle-based serialization." + ) if isinstance(obj, FunctionType): # `pickle` is generally faster than cloudpickle, but can have # problems serializing methods. return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj)) - return msgpack.Ext(CUSTOM_TYPE_PICKLE, - pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)) + return msgpack.Ext( + CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + ) def _encode_ndarray( self, obj: np.ndarray @@ -225,27 +239,22 @@ class MsgpackEncoder: for modality, itemlist in items.items() } - def _encode_mm_item(self, - item: MultiModalKwargsItem) -> list[dict[str, Any]]: + def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]: return [self._encode_mm_field_elem(elem) for elem in item.values()] - def _encode_mm_field_elem(self, - elem: MultiModalFieldElem) -> dict[str, Any]: + def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]: return { - "modality": - elem.modality, - "key": - elem.key, - "data": (None if elem.data is None else - self._encode_nested_tensors(elem.data)), - "field": - self._encode_mm_field(elem.field), + "modality": elem.modality, + "key": elem.key, + "data": ( + None if elem.data is None else self._encode_nested_tensors(elem.data) + ), + "field": self._encode_mm_field(elem.field), } def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]: return { - modality: self._encode_nested_tensors(data) - for modality, data in kw.items() + modality: self._encode_nested_tensors(data) for modality, data in kw.items() } def _encode_nested_tensors(self, nt: NestedTensors) -> Any: @@ -264,8 +273,7 @@ class MsgpackEncoder: raise TypeError(f"Unsupported field type: {field.__class__}") # We just need to copy all of the field values in order # which will be then used to reconstruct the field. 
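# [Illustrative sketch] The core idea behind MsgpackEncoder/MsgpackDecoder's
# tensor handling: a tensor travels as a (dtype name, shape, raw bytes) triple
# and is rebuilt with frombuffer + view. This minimal round-trip ignores the
# aux-buffer / zero-copy machinery of the real classes.
import torch

def encode_tensor(t: torch.Tensor):
    t = t.contiguous()
    dtype = str(t.dtype).removeprefix("torch.")        # e.g. "float16"
    data = t.flatten().view(torch.uint8).numpy().tobytes()
    return dtype, tuple(t.shape), data

def decode_tensor(dtype: str, shape, data: bytes) -> torch.Tensor:
    torch_dtype = getattr(torch, dtype)
    # bytearray keeps the buffer writable so torch doesn't warn about it.
    arr = torch.frombuffer(bytearray(data), dtype=torch.uint8)
    return arr.view(torch_dtype).view(shape)

t = torch.randn(2, 3, dtype=torch.float16)
assert torch.equal(t, decode_tensor(*encode_tensor(t)))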
- field_values = (getattr(field, f.name) - for f in dataclasses.fields(field)) + field_values = (getattr(field, f.name) for f in dataclasses.fields(field)) return name, *field_values @@ -277,10 +285,10 @@ class MsgpackDecoder: """ def __init__(self, t: Optional[Any] = None): - args = () if t is None else (t, ) - self.decoder = msgpack.Decoder(*args, - ext_hook=self.ext_hook, - dec_hook=self.dec_hook) + args = () if t is None else (t,) + self.decoder = msgpack.Decoder( + *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook + ) self.aux_buffers: Sequence[bytestr] = () if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: _log_insecure_serialization_warning() @@ -320,11 +328,14 @@ class MsgpackDecoder: result_type, result = obj if result_type is not None: if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: - raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must " - "be set to use custom utility result types") + raise TypeError( + "VLLM_ALLOW_INSECURE_SERIALIZATION must " + "be set to use custom utility result types" + ) # Use recursive decoding to handle nested structures - result = _decode_type_info_recursive(result_type, result, - self._convert_result) + result = _decode_type_info_recursive( + result_type, result, self._convert_result + ) return UtilityResult(result) def _convert_result(self, result_type: Sequence[str], result: Any) -> Any: @@ -347,8 +358,7 @@ class MsgpackDecoder: # Copy from inline representation, to decouple the memory storage # of the message from the original buffer. And also make Torch # not complain about a readonly memoryview. - buffer = self.aux_buffers[data] if isinstance(data, int) \ - else bytearray(data) + buffer = self.aux_buffers[data] if isinstance(data, int) else bytearray(data) torch_dtype = getattr(torch, dtype) assert isinstance(torch_dtype, torch.dtype) if not buffer: # torch.frombuffer doesn't like empty buffers @@ -360,17 +370,19 @@ class MsgpackDecoder: return arr.view(torch_dtype).view(shape) def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems: - return MultiModalKwargsItems({ - modality: [self._decode_mm_item(item) for item in itemlist] - for modality, itemlist in obj.items() - }) + return MultiModalKwargsItems( + { + modality: [self._decode_mm_item(item) for item in itemlist] + for modality, itemlist in obj.items() + } + ) def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem: return MultiModalKwargsItem.from_elems( - [self._decode_mm_field_elem(v) for v in obj]) + [self._decode_mm_field_elem(v) for v in obj] + ) - def _decode_mm_field_elem(self, obj: dict[str, - Any]) -> MultiModalFieldElem: + def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem: if obj["data"] is not None: obj["data"] = self._decode_nested_tensors(obj["data"]) @@ -387,10 +399,12 @@ class MsgpackDecoder: return MultiModalFieldElem(**obj) def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs: - return MultiModalKwargs({ - modality: self._decode_nested_tensors(data) - for modality, data in obj.items() - }) + return MultiModalKwargs( + { + modality: self._decode_nested_tensors(data) + for modality, data in obj.items() + } + ) def _decode_nested_tensors(self, obj: Any) -> NestedTensors: if isinstance(obj, (int, float)): @@ -419,5 +433,4 @@ class MsgpackDecoder: if code == CUSTOM_TYPE_CLOUDPICKLE: return cloudpickle.loads(data) - raise NotImplementedError( - f"Extension type code {code} is not supported") + raise NotImplementedError(f"Extension type code {code} is not supported") diff --git a/vllm/v1/spec_decode/eagle.py 
b/vllm/v1/spec_decode/eagle.py index dc6db01388..5d4822a627 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -10,8 +10,7 @@ import torch import torch.nn as nn from vllm.attention.layer import Attention -from vllm.config import (CompilationLevel, VllmConfig, - get_layers_from_vllm_config) +from vllm.config import CompilationLevel, VllmConfig, get_layers_from_vllm_config from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger @@ -23,11 +22,15 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, - TreeAttentionMetadataBuilder) +from vllm.v1.attention.backends.tree_attn import ( + TreeAttentionMetadata, + TreeAttentionMetadataBuilder, +) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -41,7 +44,6 @@ PADDING_SLOT_ID = -1 class EagleProposer: - def __init__( self, vllm_config: VllmConfig, @@ -59,10 +61,8 @@ class EagleProposer: self.dtype = vllm_config.model_config.dtype self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size - self.num_speculative_tokens = ( - self.speculative_config.num_speculative_tokens) - self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) + self.num_speculative_tokens = self.speculative_config.num_speculative_tokens + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens self.token_arange_np = np.arange(self.max_num_tokens) # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's @@ -72,62 +72,64 @@ class EagleProposer: # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( - vllm_config.model_config) + vllm_config.model_config + ) self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None - self.draft_indexer_metadata_builder: Optional[ - AttentionMetadataBuilder] = None + self.draft_indexer_metadata_builder: Optional[AttentionMetadataBuilder] = None self.attn_layer_names: list[str] = [] self.indexer_layer_names: list[str] = [] - self.use_cuda_graph = (not current_platform.is_xpu() - and self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE and - not self.vllm_config.model_config.enforce_eager - and not self.speculative_config.enforce_eager) - self.cudagraph_batch_sizes = list( - reversed(self.vllm_config.compilation_config. 
- cudagraph_capture_sizes)) if self.use_cuda_graph else [] + self.use_cuda_graph = ( + not current_platform.is_xpu() + and self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE + and not self.vllm_config.model_config.enforce_eager + and not self.speculative_config.enforce_eager + ) + self.cudagraph_batch_sizes = ( + list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes)) + if self.use_cuda_graph + else [] + ) # persistent buffers for cuda graph - self.input_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=device) + self.input_ids = torch.zeros( + self.max_num_tokens, dtype=torch.int32, device=device + ) self.uses_mrope = self.vllm_config.model_config.uses_mrope if self.uses_mrope: # M-RoPE need (3, max_num_tokens) - self.mrope_positions = torch.zeros((3, self.max_num_tokens), - dtype=torch.int64, - device=device) + self.mrope_positions = torch.zeros( + (3, self.max_num_tokens), dtype=torch.int64, device=device + ) else: # RoPE need (max_num_tokens,) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=device) + self.positions = torch.zeros( + self.max_num_tokens, dtype=torch.int64, device=device + ) self.hidden_states = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=device) + (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device + ) # We need +1 here because the arange is used to set query_start_loc, # which has one more element than batch_size. max_batch_size = vllm_config.scheduler_config.max_num_seqs max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) - self.arange = torch.arange(max_num_slots_for_arange, - device=device, - dtype=torch.int32) + self.arange = torch.arange( + max_num_slots_for_arange, device=device, dtype=torch.int32 + ) self.inputs_embeds = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=device) + (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device + ) self.backup_next_token_ids = CpuGpuBuffer( max_batch_size, dtype=torch.int32, pin_memory=is_pin_memory_available(), device=device, - with_numpy=True) + with_numpy=True, + ) # Determine allowed attention backends once during initialization. self.allowed_attn_types: Optional[tuple] = None @@ -136,14 +138,15 @@ class EagleProposer: # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): from vllm.v1.attention.backends.rocm_aiter_fa import ( - AiterFlashAttentionMetadata) + AiterFlashAttentionMetadata, + ) + rocm_types.append(AiterFlashAttentionMetadata) self.allowed_attn_types = tuple(rocm_types) # Parse the speculative token tree. spec_token_tree = self.speculative_config.speculative_token_tree - self.tree_choices: list[tuple[int, - ...]] = ast.literal_eval(spec_token_tree) + self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree) tree_depth = len(self.tree_choices[-1]) # Precompute per-level properties of the tree. 
num_drafts_per_level = [0] * tree_depth @@ -152,10 +155,12 @@ class EagleProposer: self.cu_drafts_per_level = [num_drafts_per_level[0]] self.child_drafts_per_level = [num_drafts_per_level[0]] for level in range(1, tree_depth): - self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] + - num_drafts_per_level[level]) - self.child_drafts_per_level.append(num_drafts_per_level[level] // - num_drafts_per_level[level - 1]) + self.cu_drafts_per_level.append( + self.cu_drafts_per_level[-1] + num_drafts_per_level[level] + ) + self.child_drafts_per_level.append( + num_drafts_per_level[level] // num_drafts_per_level[level - 1] + ) # Precompute draft position offsets in flattened tree. self.tree_draft_pos_offsets = torch.arange( 1, @@ -188,8 +193,7 @@ class EagleProposer: last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, - mm_embed_inputs: Optional[tuple[list[torch.Tensor], - torch.Tensor]] = None, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -200,11 +204,12 @@ class EagleProposer: if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) target_hidden_states = self.model.combine_hidden_states( - target_hidden_states) + target_hidden_states + ) assert target_hidden_states.shape[-1] == self.hidden_size # Shift the input ids by one token. # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] + self.input_ids[: num_tokens - 1] = target_token_ids[1:] # Replace the last token with the next token. # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] self.input_ids[last_token_indices] = next_token_ids @@ -213,17 +218,20 @@ class EagleProposer: # FIXME: need to consider multiple kv_cache_groups ubatch_id = dbo_current_ubatch_id() - attn_metadata_builder = \ - self.runner.attn_groups[0][0].metadata_builders[ubatch_id] + attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builders[ + ubatch_id + ] attn_metadata = attn_metadata_builder.build_for_drafting( - common_attn_metadata=common_attn_metadata, draft_index=0) + common_attn_metadata=common_attn_metadata, draft_index=0 + ) # FIXME: support hybrid kv for draft model (remove separate indexer) if self.draft_indexer_metadata_builder: draft_indexer_metadata = ( self.draft_indexer_metadata_builder.build_for_drafting( common_attn_metadata=common_attn_metadata, draft_index=0, - )) + ) + ) else: draft_indexer_metadata = None # At this moment, we assume all eagle layers belong to the same KV @@ -235,8 +243,7 @@ class EagleProposer: assert draft_indexer_metadata is not None per_layer_attn_metadata[layer_name] = draft_indexer_metadata - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens @@ -259,9 +266,9 @@ class EagleProposer: input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + with set_forward_context( + per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens + ): ret_hidden_states = self.model( input_ids=input_ids, positions=self._get_positions(num_input_tokens), @@ -304,28 +311,30 @@ class EagleProposer: draft_token_ids = 
logits.argmax(dim=-1) - if self.allowed_attn_types is not None and \ - not isinstance(attn_metadata, self.allowed_attn_types): + if self.allowed_attn_types is not None and not isinstance( + attn_metadata, self.allowed_attn_types + ): raise ValueError( f"Unsupported attention metadata type for speculative " "decoding with num_speculative_tokens > 1: " f"{type(attn_metadata)}. Supported types are: " - f"{self.allowed_attn_types}") + f"{self.allowed_attn_types}" + ) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] - if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: + if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.max_query_len = 1 - common_attn_metadata.query_start_loc = self.arange[:batch_size + 1] + common_attn_metadata.query_start_loc = self.arange[: batch_size + 1] common_attn_metadata.query_start_loc_cpu = torch.from_numpy( - self.token_arange_np[:batch_size + 1]).clone() + self.token_arange_np[: batch_size + 1] + ).clone() for token_index in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. @@ -344,14 +353,15 @@ class EagleProposer: exceeds_max_model_len = positions[0] >= self.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where\ - (exceeds_max_model_len.unsqueeze(0), \ - torch.zeros_like(positions), positions) + clamped_positions = torch.where( + exceeds_max_model_len.unsqueeze(0), + torch.zeros_like(positions), + positions, + ) else: positions += 1 exceeds_max_model_len = positions >= self.max_model_len - clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) + clamped_positions = torch.where(exceeds_max_model_len, 0, positions) # Increment the sequence lengths. common_attn_metadata.seq_lens += 1 @@ -359,11 +369,11 @@ class EagleProposer: # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. - common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, - 1) + common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) - common_attn_metadata.num_computed_tokens_cpu = \ + common_attn_metadata.num_computed_tokens_cpu = ( common_attn_metadata.seq_lens_cpu - 1 + ) # Compute the slot mapping. if self.uses_mrope: @@ -372,26 +382,28 @@ class EagleProposer: else: block_numbers = clamped_positions // self.block_size block_ids = common_attn_metadata.block_table_tensor.gather( - dim=1, index=block_numbers.view(-1, 1)) + dim=1, index=block_numbers.view(-1, 1) + ) block_ids = block_ids.view(-1) if self.uses_mrope: common_attn_metadata.slot_mapping = ( - block_ids * self.block_size + - clamped_positions[0] % self.block_size) + block_ids * self.block_size + clamped_positions[0] % self.block_size + ) else: common_attn_metadata.slot_mapping = ( - block_ids * self.block_size + - clamped_positions % self.block_size) + block_ids * self.block_size + clamped_positions % self.block_size + ) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. 
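# [Illustrative sketch] The slot-mapping arithmetic used above: a token at
# logical position p in a request lives in physical KV-cache slot
#   block_table[req, p // block_size] * block_size + p % block_size.
# Toy values below.
import torch

block_size = 4
block_table = torch.tensor([[7, 2, 9],   # request 0's physical block ids
                            [3, 8, 1]])  # request 1's physical block ids
positions = torch.tensor([5, 10])        # next token position per request

block_numbers = positions // block_size  # [1, 2]
block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * block_size + positions % block_size
print(slot_mapping)                      # tensor([9, 6])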
common_attn_metadata.slot_mapping.masked_fill_( - exceeds_max_model_len, PADDING_SLOT_ID) + exceeds_max_model_len, PADDING_SLOT_ID + ) # Rebuild attention metadata attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore - common_attn_metadata=common_attn_metadata, - draft_index=token_index + 1) + common_attn_metadata=common_attn_metadata, draft_index=token_index + 1 + ) for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata @@ -400,8 +412,9 @@ class EagleProposer: self._set_positions(batch_size, clamped_positions) self.hidden_states[:batch_size] = hidden_states if self.supports_mm_inputs: - self.inputs_embeds[:batch_size] = \ - self.model.get_input_embeddings(input_ids) + self.inputs_embeds[:batch_size] = self.model.get_input_embeddings( + input_ids + ) input_ids = None inputs_embeds = self.inputs_embeds[:input_batch_size] @@ -410,9 +423,9 @@ class EagleProposer: inputs_embeds = None # Run the model. - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=input_batch_size): + with set_forward_context( + per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size + ): ret_hidden_states = self.model( input_ids=input_ids, positions=self._get_positions(input_batch_size), @@ -434,10 +447,12 @@ class EagleProposer: return draft_token_ids def prepare_next_token_ids_cpu( - self, sampled_token_ids: list[list[int]], - requests: dict[str, - CachedRequestState], gpu_input_batch: InputBatch, - num_scheduled_tokens: dict[str, int]) -> torch.Tensor: + self, + sampled_token_ids: list[list[int]], + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + num_scheduled_tokens: dict[str, int], + ) -> torch.Tensor: """ This function is used to prepare the inputs for speculative decoding. It calculates the next token ids for each request based on the sampled @@ -456,23 +471,23 @@ class EagleProposer: # Get the next token id from the request state. req_id = req_ids[i] req_state = requests[req_id] - seq_len = (req_state.num_computed_tokens + - num_scheduled_tokens[req_id]) + seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id] next_token_id = req_state.get_token_id(seq_len) next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.input_ids.device) + next_token_ids = torch.tensor( + next_token_ids, dtype=torch.int32, device=self.input_ids.device + ) return next_token_ids - def prepare_next_token_ids_padded(self, - common_attn_metadata: CommonAttentionMetadata, - sampled_token_ids: torch.Tensor, - requests: dict[str, CachedRequestState], - gpu_input_batch: InputBatch, - discard_request_indices: torch.Tensor, - num_discarded_requests: int) -> \ - tuple[torch.Tensor, torch.Tensor]: + def prepare_next_token_ids_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: torch.Tensor, + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + discard_request_indices: torch.Tensor, + num_discarded_requests: int, + ) -> tuple[torch.Tensor, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding. 
It calculates the next token ids and the number of valid sampled tokens @@ -486,30 +501,34 @@ class EagleProposer: # Precompute get_token_id for when there is no valid next token num_reqs = gpu_input_batch.num_reqs - self.backup_next_token_ids.np[:num_reqs] = np.array([ - requests[gpu_input_batch.req_ids[i]].get_token_id( - common_attn_metadata.seq_lens_cpu[i].item()) - for i in range(num_reqs) - ]) + self.backup_next_token_ids.np[:num_reqs] = np.array( + [ + requests[gpu_input_batch.req_ids[i]].get_token_id( + common_attn_metadata.seq_lens_cpu[i].item() + ) + for i in range(num_reqs) + ] + ) self.backup_next_token_ids.copy_to_gpu(num_reqs) # Mask out the sampled tokens indices that should not be sampled. - discard_sampled_tokens_req_indices = \ - discard_request_indices[:num_discarded_requests] + discard_sampled_tokens_req_indices = discard_request_indices[ + :num_discarded_requests + ] valid_sampled_token_ids_gpu = sampled_token_ids.clone() valid_sampled_token_ids_gpu.index_fill_( - 0, discard_sampled_tokens_req_indices, -1) + 0, discard_sampled_tokens_req_indices, -1 + ) # Generate a mask for all valid tokens within those requests max_gen_len = sampled_token_ids.shape[-1] if max_gen_len == 1: - valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, - dtype=torch.bool) + valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool) else: - valid_mask = ( - (valid_sampled_token_ids_gpu != -1) & - (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)) + valid_mask = (valid_sampled_token_ids_gpu != -1) & ( + valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size + ) # Count the number of valid tokens in each request valid_sampled_tokens_count = valid_mask.sum(dim=1) @@ -521,22 +540,25 @@ class EagleProposer: # Get last valid token from each row # (assume undefined state where there is no valid token) selected_tokens = torch.gather( - valid_sampled_token_ids_gpu, 1, - last_valid_indices_safe.unsqueeze(1)).squeeze(1) + valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1) + ).squeeze(1) # Use last token if valid, pre-computed backup if not batch_size = valid_sampled_token_ids_gpu.shape[0] next_token_ids = torch.where( - last_valid_indices != -1, selected_tokens, - self.backup_next_token_ids.gpu[:batch_size]) + last_valid_indices != -1, + selected_tokens, + self.backup_next_token_ids.gpu[:batch_size], + ) return next_token_ids, valid_sampled_tokens_count - def prepare_inputs_padded(self, - common_attn_metadata: CommonAttentionMetadata, - spec_decode_metadata: SpecDecodeMetadata, - valid_sampled_tokens_count: torch.Tensor) -> \ - tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: + def prepare_inputs_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + spec_decode_metadata: SpecDecodeMetadata, + valid_sampled_tokens_count: torch.Tensor, + ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding It updates the common_attn_metadata for speculative decoding, @@ -545,21 +567,23 @@ class EagleProposer: used as padding and filtered out later by `token_indices_to_sample`. No blocking CPU operations should be introduced in this function. 
""" - num_draft_tokens_gpu = torch.cat([ - spec_decode_metadata.cu_num_draft_tokens[0:1], - spec_decode_metadata.cu_num_draft_tokens[1:] - - spec_decode_metadata.cu_num_draft_tokens[:-1] - ]) + num_draft_tokens_gpu = torch.cat( + [ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] + - spec_decode_metadata.cu_num_draft_tokens[:-1], + ] + ) num_rejected_tokens_gpu = torch.where( num_draft_tokens_gpu > 0, num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, - torch.zeros_like(num_draft_tokens_gpu)) + torch.zeros_like(num_draft_tokens_gpu), + ) query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_query_len_per_req = (query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1]) + new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] total_num_tokens = query_start_loc_cpu[-1].item() token_indices = self.arange[:total_num_tokens] @@ -569,8 +593,7 @@ class EagleProposer: seq_lens=common_attn_metadata.seq_lens, query_start_loc_cpu=query_start_loc_cpu, seq_lens_cpu=common_attn_metadata.seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata. - num_computed_tokens_cpu, + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), @@ -580,8 +603,9 @@ class EagleProposer: causal=True, ) - token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \ - - num_rejected_tokens_gpu + token_indices_to_sample = ( + common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu + ) return spec_common_attn_metadata, token_indices, token_indices_to_sample @@ -596,10 +620,10 @@ class EagleProposer: hidden_states: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, ) -> list[torch.Tensor]: - tree_attn_metadata_builder = \ - self.runner.attn_groups[0][0].get_metadata_builder() - assert isinstance(tree_attn_metadata_builder, - TreeAttentionMetadataBuilder) + tree_attn_metadata_builder = self.runner.attn_groups[0][ + 0 + ].get_metadata_builder() + assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder) total_num_drafts = self.cu_drafts_per_level[0] level_num_drafts = total_num_drafts @@ -608,31 +632,31 @@ class EagleProposer: if num_children == 1: draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1) else: - draft_token_ids = torch.topk(logits, num_children, - dim=-1).indices.view(batch_size, -1) + draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view( + batch_size, -1 + ) draft_token_ids_list = [draft_token_ids] draft_hidden_states = hidden_states.view(batch_size, 1, -1) # Initialize empty tensors for concatenation with the level outputs. - tree_input_ids = torch.empty(0, - device=self.input_ids.device, - dtype=self.input_ids.dtype) - tree_positions = torch.empty(0, - device=self.positions.device, - dtype=self.positions.dtype) - tree_hidden_states = torch.empty(0, - device=self.hidden_states.device, - dtype=self.hidden_states.dtype) + tree_input_ids = torch.empty( + 0, device=self.input_ids.device, dtype=self.input_ids.dtype + ) + tree_positions = torch.empty( + 0, device=self.positions.device, dtype=self.positions.dtype + ) + tree_hidden_states = torch.empty( + 0, device=self.hidden_states.device, dtype=self.hidden_states.dtype + ) # Precompute the draft token positions. 
flattened_draft_positions = ( - positions.view(batch_size, -1) + - self.tree_draft_pos_offsets[:batch_size, :]) + positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :] + ) tree_depth = len(self.cu_drafts_per_level) for level in range(tree_depth - 1): # Get draft positions for RoPE. draft_positions = positions + (level + 1) - exceeds_max_model_len = (positions + - total_num_drafts) >= self.max_model_len + exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. draft_positions = torch.where( @@ -644,27 +668,28 @@ class EagleProposer: if level_num_drafts > 1: # Repeat the positions for each draft at this level. draft_positions = draft_positions.repeat_interleave( - level_num_drafts, dim=1) + level_num_drafts, dim=1 + ) if num_children > 1: # Repeat draft hidden states for each child. draft_hidden_states = draft_hidden_states.repeat_interleave( - num_children, dim=1) + num_children, dim=1 + ) # Concatenate the draft tokens, positions, and hidden states. - tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], - dim=1) - tree_positions = torch.cat([tree_positions, draft_positions], - dim=1) + tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1) + tree_positions = torch.cat([tree_positions, draft_positions], dim=1) tree_hidden_states = torch.cat( - [tree_hidden_states, draft_hidden_states], dim=1) + [tree_hidden_states, draft_hidden_states], dim=1 + ) # Build new attention metadata for the next level of drafts. # This is necessary to support tree attention. query_len = total_num_drafts common_attn_metadata = replace( common_attn_metadata, - query_start_loc=query_len * self.arange[:batch_size + 1], + query_start_loc=query_len * self.arange[: batch_size + 1], seq_lens=common_attn_metadata.seq_lens + level_num_drafts, num_actual_tokens=batch_size * query_len, max_query_len=query_len, @@ -680,20 +705,20 @@ class EagleProposer: per_layer_attn_metadata[layer_name] = attn_metadata # Consider max model length. - attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, - self.max_model_len) + attn_metadata.max_seq_len = min( + attn_metadata.max_seq_len, self.max_model_len + ) # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) # Compute the slot mapping. - query_positions = flattened_draft_positions[:, level:level + - query_len] + query_positions = flattened_draft_positions[:, level : level + query_len] block_numbers = query_positions // self.block_size - block_ids = attn_metadata.block_table.gather(dim=1, - index=block_numbers) - slot_mapping = (block_ids * self.block_size + - query_positions % self.block_size) + block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers) + slot_mapping = ( + block_ids * self.block_size + query_positions % self.block_size + ) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. 
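The slot-mapping arithmetic used for the tree drafts can be reproduced in isolation: a token at position p lives in logical block p // block_size, the per-request block table maps that to a physical block id, and the slot is physical_block_id * block_size + p % block_size. A small example with an invented block table and block_size of 4 (the real code additionally masks out positions beyond the max model length, as shown above):

    import torch

    block_size = 4
    # block_table[i, j] = physical block id backing logical block j of request i
    block_table = torch.tensor([[7, 2, 9],
                                [5, 0, 3]])
    positions = torch.tensor([[4, 5],      # request 0 drafts at positions 4 and 5
                              [9, 10]])    # request 1 drafts at positions 9 and 10

    block_numbers = positions // block_size                  # logical block index
    block_ids = block_table.gather(dim=1, index=block_numbers)
    slot_mapping = block_ids * block_size + positions % block_size
    # request 0: logical blocks [1, 1] -> physical [2, 2] -> slots [8, 9]
    # request 1: logical blocks [2, 2] -> physical [3, 3] -> slots [13, 14]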
@@ -705,19 +730,16 @@ class EagleProposer: input_ids = tree_input_ids.view(-1) self.input_ids[:num_tokens] = input_ids self.positions[:num_tokens] = tree_positions.view(-1) - self.hidden_states[:num_tokens] = tree_hidden_states.view( - num_tokens, -1) + self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1) - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_tokens) + if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens # Run the model. - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + with set_forward_context( + per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens + ): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], positions=self.positions[:num_input_tokens], @@ -727,28 +749,29 @@ class EagleProposer: # Get the output hidden states for the draft tokens. draft_hidden_states = hidden_states[:num_tokens].view( - batch_size, query_len, -1)[:, -level_num_drafts:] + batch_size, query_len, -1 + )[:, -level_num_drafts:] draft_last_hidden_states = last_hidden_states[:num_tokens].view( - batch_size, query_len, -1)[:, -level_num_drafts:] + batch_size, query_len, -1 + )[:, -level_num_drafts:] # Get the output logits for the draft tokens. logits = self.model.compute_logits( - draft_last_hidden_states.reshape(batch_size * level_num_drafts, - -1)) + draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1) + ) # Sample a draft token for each child at the next tree level. num_children = self.child_drafts_per_level[level + 1] if num_children == 1: draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1) else: - draft_token_ids = torch.topk(logits, num_children, - dim=-1).indices.view( - batch_size, -1) + draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view( + batch_size, -1 + ) draft_token_ids_list.append(draft_token_ids) # Update the # drafts counters for the next tree level. 
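Per tree level, child drafts are picked either by argmax (one child per parent) or by top-k over the vocabulary. A standalone shape check with random logits (sizes are arbitrary, chosen only for the example):

    import torch

    batch_size, num_parents, vocab = 2, 3, 10
    logits = torch.randn(batch_size * num_parents, vocab)

    num_children = 2
    if num_children == 1:
        draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
    else:
        draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
            batch_size, -1
        )
    # draft_token_ids.shape == (2, num_parents * num_children) == (2, 6)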
- level_num_drafts = self.cu_drafts_per_level[level + - 1] - total_num_drafts + level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts total_num_drafts = self.cu_drafts_per_level[level + 1] return draft_token_ids_list @@ -784,17 +807,14 @@ class EagleProposer: n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens) ] - num_rejected_tokens = torch.tensor(num_rejected_tokens, - dtype=torch.int32) + num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32) device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ - - num_rejected_tokens + new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] - new_query_len_per_req = (query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1]) + new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() @@ -804,7 +824,8 @@ class EagleProposer: new_query_start_loc_cpu = torch.zeros( query_start_loc_cpu.shape, dtype=torch.int32, - pin_memory=is_pin_memory_available()) + pin_memory=is_pin_memory_available(), + ) new_query_start_loc_np = new_query_start_loc_cpu.numpy() np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) @@ -814,36 +835,36 @@ class EagleProposer: # [0, 2, 6, 9] -> # [0, 0, 2, 2, 2, 2, 6, 6, 6] # _r1_ ____r2____ ___r3__ - new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], - new_num_tokens_per_req_np) + new_query_start_locs_expanded = np.repeat( + new_query_start_loc_np[:-1], new_num_tokens_per_req_np + ) # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> # [0, 1, 0, 1, 2, 3, 0, 1, 2] # _r1_ ____r2____ ___r3__ - token_offests = self.token_arange_np[:total_num_tokens] \ - - new_query_start_locs_expanded + token_offests = ( + self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded + ) # Expand starting positions to match token pattern # [0, q1, q1 + q2] -> # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] # _r1_ _____r2_______ ___________r3____________ old_query_start_locs_expanded = np.repeat( - query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) + query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np + ) # Final token indices are: # [0, 1, // req 1 # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 token_indices_np = token_offests + old_query_start_locs_expanded - token_indices = torch.from_numpy(token_indices_np).to( - device, non_blocking=True) + token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True) spec_common_attn_metadata = CommonAttentionMetadata( - query_start_loc=new_query_start_loc_cpu.to(device, - non_blocking=True), + query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), query_start_loc_cpu=new_query_start_loc_cpu, seq_lens_cpu=new_seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata. 
- num_computed_tokens_cpu, + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), @@ -856,45 +877,52 @@ class EagleProposer: return spec_common_attn_metadata, token_indices def get_model_name(self, model: nn.Module) -> str: - if hasattr(model, 'module'): # multi-GPU + if hasattr(model, "module"): # multi-GPU model = model.module return model.__class__.__name__ def load_model(self, target_model: nn.Module) -> None: - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config + draft_model_config = self.vllm_config.speculative_config.draft_model_config target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + get_layers_from_vllm_config(self.vllm_config, Attention).keys() + ) # FIXME: support hybrid kv for draft model target_indexer_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, - DeepseekV32IndexerCache).keys()) + get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache + ).keys() + ) from vllm.compilation.backends import set_model_tag + with set_model_tag("eagle_head"): - self.model = get_model(vllm_config=self.vllm_config, - model_config=draft_model_config) + self.model = get_model( + vllm_config=self.vllm_config, model_config=draft_model_config + ) draft_attn_layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() - - target_attn_layer_names) - indexer_layers = get_layers_from_vllm_config(self.vllm_config, - DeepseekV32IndexerCache) - draft_indexer_layer_names = (indexer_layers.keys() - - target_indexer_layer_names) + get_layers_from_vllm_config(self.vllm_config, Attention).keys() + - target_attn_layer_names + ) + indexer_layers = get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache + ) + draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names self.attn_layer_names = list(draft_attn_layer_names) self.indexer_layer_names = list(draft_indexer_layer_names) if self.indexer_layer_names: first_layer = self.indexer_layer_names[0] self.draft_indexer_metadata_builder = ( - indexer_layers[first_layer].get_attn_backend().get_builder_cls( - )( + indexer_layers[first_layer] + .get_attn_backend() + .get_builder_cls()( indexer_layers[first_layer].get_kv_cache_spec(), self.indexer_layer_names, self.vllm_config, self.device, - )) + ) + ) else: self.draft_indexer_metadata_builder = None @@ -902,38 +930,41 @@ class EagleProposer: # Even if the target model is multimodal, we can also use # text-only draft models try: - dummy_input_ids = torch.tensor([[1]], - device=self.input_ids.device) - self.model.get_input_embeddings(dummy_input_ids, - multimodal_embeddings=None) + dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device) + self.model.get_input_embeddings( + dummy_input_ids, multimodal_embeddings=None + ) except (NotImplementedError, AttributeError, TypeError): logger.warning( "Draft model does not support multimodal inputs, " - "falling back to text-only mode") + "falling back to text-only mode" + ) self.supports_mm_inputs = False if supports_multimodal(target_model): # handle multimodality - if (self.get_model_name(target_model) == - "Qwen2_5_VLForConditionalGeneration"): - self.model.config.image_token_index = ( - target_model.config.image_token_id) + if ( + self.get_model_name(target_model) + == "Qwen2_5_VLForConditionalGeneration" + ): + self.model.config.image_token_index = 
target_model.config.image_token_id else: self.model.config.image_token_index = ( - target_model.config.image_token_index) + target_model.config.image_token_index + ) target_language_model = target_model.get_language_model() else: target_language_model = target_model # share embed_tokens with the target model if needed if get_pp_group().world_size == 1: - if hasattr(target_language_model.model, 'embed_tokens'): + if hasattr(target_language_model.model, "embed_tokens"): target_embed_tokens = target_language_model.model.embed_tokens - elif hasattr(target_language_model.model, 'embedding'): + elif hasattr(target_language_model.model, "embedding"): target_embed_tokens = target_language_model.model.embedding else: raise AttributeError( - "Target model does not have 'embed_tokens' or 'embedding' " - "attribute") + "Target model does not have 'embed_tokens' or 'embedding' attribute" + ) # Check if shapes match and we found the embedding eagle_shape = self.model.model.embed_tokens.weight.shape @@ -941,47 +972,53 @@ class EagleProposer: if eagle_shape == target_shape: logger.info( "Assuming the EAGLE head shares the same vocab embedding" - " with the target model.") + " with the target model." + ) del self.model.model.embed_tokens self.model.model.embed_tokens = target_embed_tokens else: logger.info( "The EAGLE head's vocab embedding will be loaded separately" - " from the target model.") + " from the target model." + ) else: logger.info( "The EAGLE head's vocab embedding will be loaded separately" - " from the target model.") + " from the target model." + ) # share lm_head with the target model if needed # some model definition do not define lm_head explicitly # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM if self.vllm_config.speculative_config.method != "eagle3": if hasattr(target_language_model, "lm_head"): - logger.info( - "Loading EAGLE LM head weights from the target model.") + logger.info("Loading EAGLE LM head weights from the target model.") self.model.lm_head = target_language_model.lm_head else: - if (hasattr(self.model, "lm_head") - and hasattr(target_language_model, "lm_head") - and self.model.lm_head.weight.shape - == target_language_model.lm_head.weight.shape): - logger.info("Assuming the EAGLE head shares the same lm_head" - " with the target model.") + if ( + hasattr(self.model, "lm_head") + and hasattr(target_language_model, "lm_head") + and self.model.lm_head.weight.shape + == target_language_model.lm_head.weight.shape + ): + logger.info( + "Assuming the EAGLE head shares the same lm_head" + " with the target model." + ) del self.model.lm_head self.model.lm_head = target_language_model.lm_head else: logger.info( "The EAGLE head's lm_head will be loaded separately" - " from the target model.") + " from the target model." + ) @torch.inference_mode() def dummy_run( self, num_tokens: int, ) -> None: - with set_forward_context(None, self.vllm_config, - num_tokens=num_tokens): + with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): if self.supports_mm_inputs: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -996,8 +1033,7 @@ class EagleProposer: inputs_embeds=inputs_embeds, ) - def _get_attention_metadata_builder( - self) -> list[AttentionMetadataBuilder]: + def _get_attention_metadata_builder(self) -> list[AttentionMetadataBuilder]: """Find and return the attention metadata builders for EAGLE layers. 
Returns: @@ -1018,11 +1054,11 @@ class EagleProposer: break assert builder is not None, ( - "Failed to find attention metadata builder for EAGLE layers.") + "Failed to find attention metadata builder for EAGLE layers." + ) return builder - def validate_same_kv_cache_group(self, - kv_cache_config: KVCacheConfig) -> None: + def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: """ Validate that all eagle layers belong to the same KVCacheGroup. Need this assumption to ensure all eagle layers can use the @@ -1033,12 +1069,17 @@ class EagleProposer: for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): for layer_name in kv_cache_group.layer_names: kv_cache_groups[layer_name] = id - assert len( - set([ - kv_cache_groups[layer_name] - for layer_name in self.attn_layer_names - ]) - ) == 1, "All eagle layers should belong to the same kv cache group" + assert ( + len( + set( + [ + kv_cache_groups[layer_name] + for layer_name in self.attn_layer_names + ] + ) + ) + == 1 + ), "All eagle layers should belong to the same kv cache group" # NOTE(woosuk): Currently, the below code is not used and we always use argmax diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 70b29c05c2..150dde177c 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -27,10 +27,9 @@ class MedusaProposer: # Save config parameters self.vllm_config = vllm_config self.device = device - self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) - self.hidden_size = vllm_config.speculative_config.\ - draft_model_config.get_hidden_size( + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + self.hidden_size = ( + vllm_config.speculative_config.draft_model_config.get_hidden_size() ) self.dtype = vllm_config.model_config.dtype @@ -51,16 +50,19 @@ class MedusaProposer: def load_model(self, target_model: nn.Module) -> None: from vllm.compilation.backends import set_model_tag + with set_model_tag("medusa_head"): - self.model = get_model(vllm_config=self.vllm_config, - model_config=self.vllm_config. 
- speculative_config.draft_model_config) + self.model = get_model( + vllm_config=self.vllm_config, + model_config=self.vllm_config.speculative_config.draft_model_config, + ) @torch.inference_mode() def dummy_run(self, num_tokens: int) -> None: - hidden_states = torch.zeros((self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) - with set_forward_context(None, self.vllm_config, - num_tokens=num_tokens): + hidden_states = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device, + ) + with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): self.model(hidden_states) diff --git a/vllm/v1/spec_decode/metadata.py b/vllm/v1/spec_decode/metadata.py index b1efb40612..d0695244cb 100644 --- a/vllm/v1/spec_decode/metadata.py +++ b/vllm/v1/spec_decode/metadata.py @@ -8,7 +8,6 @@ import torch @dataclass class SpecDecodeMetadata: - # [num_tokens] draft_token_ids: torch.Tensor # [batch_size] @@ -36,22 +35,19 @@ class SpecDecodeMetadata: flattened_draft_token_ids = sum(draft_token_ids, []) num_tokens = len(flattened_draft_token_ids) - draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids, - dtype=torch.int32, - device=device) + draft_token_ids_tensor = torch.tensor( + flattened_draft_token_ids, dtype=torch.int32, device=device + ) cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) - cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to( - device) + cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device) - target_logits_indices = torch.zeros(num_tokens, - dtype=torch.int32, - device=device) - bonus_logits_indices = torch.zeros(batch_size, - dtype=torch.int32, - device=device) - logits_indices = torch.zeros(num_tokens + batch_size, - dtype=torch.int32, - device=device) + target_logits_indices = torch.zeros( + num_tokens, dtype=torch.int32, device=device + ) + bonus_logits_indices = torch.zeros(batch_size, dtype=torch.int32, device=device) + logits_indices = torch.zeros( + num_tokens + batch_size, dtype=torch.int32, device=device + ) return cls( draft_token_ids=draft_token_ids_tensor, num_draft_tokens=num_draft_tokens, diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 282e6f65e7..89a8a11a3d 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -31,8 +31,10 @@ class SpecDecodingStats: @classmethod def new(cls, num_spec_tokens: int) -> "SpecDecodingStats": - return cls(num_spec_tokens=num_spec_tokens, - num_accepted_tokens_per_pos=[0] * num_spec_tokens) + return cls( + num_spec_tokens=num_spec_tokens, + num_accepted_tokens_per_pos=[0] * num_spec_tokens, + ) def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int): self.num_drafts += 1 @@ -64,10 +66,10 @@ class SpecDecodingLogging: def observe(self, spec_decoding_stats: SpecDecodingStats): self.num_drafts.append(spec_decoding_stats.num_drafts) self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens) - self.num_accepted_tokens.append( - spec_decoding_stats.num_accepted_tokens) + self.num_accepted_tokens.append(spec_decoding_stats.num_accepted_tokens) self.accepted_tokens_per_pos_lists.append( - spec_decoding_stats.num_accepted_tokens_per_pos) + spec_decoding_stats.num_accepted_tokens_per_pos + ) def log(self, log_fn=logger.info): if not self.num_drafts: @@ -83,8 +85,11 @@ class SpecDecodingLogging: draft_throughput = num_draft_tokens / elapsed_time accepted_throughput = num_accepted_tokens / elapsed_time - draft_acceptance_rate = 
(num_accepted_tokens / num_draft_tokens * - 100 if num_draft_tokens > 0 else float("nan")) + draft_acceptance_rate = ( + num_accepted_tokens / num_draft_tokens * 100 + if num_draft_tokens > 0 + else float("nan") + ) # Conventionally, mean acceptance length includes the bonus token mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts) @@ -149,27 +154,36 @@ class SpecDecodingProm: counter_drafts = self._counter_cls( name="vllm:spec_decode_num_drafts", documentation="Number of spec decoding drafts.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_spec_decode_num_drafts = make_per_engine( - counter_drafts, per_engine_labelvalues) + counter_drafts, per_engine_labelvalues + ) counter_draft_tokens = self._counter_cls( name="vllm:spec_decode_num_draft_tokens", documentation="Number of draft tokens.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_spec_decode_num_draft_tokens = make_per_engine( - counter_draft_tokens, per_engine_labelvalues) + counter_draft_tokens, per_engine_labelvalues + ) counter_accepted_tokens = self._counter_cls( name="vllm:spec_decode_num_accepted_tokens", documentation="Number of accepted tokens.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_spec_decode_num_accepted_tokens = make_per_engine( - counter_accepted_tokens, per_engine_labelvalues) + counter_accepted_tokens, per_engine_labelvalues + ) assert speculative_config is not None - num_spec_tokens = (speculative_config.num_speculative_tokens - if self.spec_decoding_enabled else 0) + num_spec_tokens = ( + speculative_config.num_speculative_tokens + if self.spec_decoding_enabled + else 0 + ) pos_labelnames = labelnames + ["position"] base_counter = self._counter_cls( name="vllm:spec_decode_num_accepted_tokens_per_pos", @@ -177,33 +191,33 @@ class SpecDecodingProm: labelnames=pos_labelnames, ) self.counter_spec_decode_num_accepted_tokens_per_pos: dict[ - int, list[prometheus_client.Counter]] = { - idx: [ - base_counter.labels(*lv, str(pos)) - for pos in range(num_spec_tokens) - ] - for idx, lv in per_engine_labelvalues.items() - } + int, list[prometheus_client.Counter] + ] = { + idx: [base_counter.labels(*lv, str(pos)) for pos in range(num_spec_tokens)] + for idx, lv in per_engine_labelvalues.items() + } - def observe(self, - spec_decoding_stats: SpecDecodingStats, - engine_idx: int = 0): + def observe(self, spec_decoding_stats: SpecDecodingStats, engine_idx: int = 0): if not self.spec_decoding_enabled: return self.counter_spec_decode_num_drafts[engine_idx].inc( - spec_decoding_stats.num_drafts) + spec_decoding_stats.num_drafts + ) self.counter_spec_decode_num_draft_tokens[engine_idx].inc( - spec_decoding_stats.num_draft_tokens) + spec_decoding_stats.num_draft_tokens + ) self.counter_spec_decode_num_accepted_tokens[engine_idx].inc( - spec_decoding_stats.num_accepted_tokens) + spec_decoding_stats.num_accepted_tokens + ) for pos, counter in enumerate( - self. 
- counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]): + self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx] + ): counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos]) -def make_per_engine(counter: prometheus_client.Counter, - per_engine_labelvalues: dict[int, list[str]]): +def make_per_engine( + counter: prometheus_client.Counter, per_engine_labelvalues: dict[int, list[str]] +): """Create a counter for each label value.""" return { idx: counter.labels(*labelvalues) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index aed050a354..e2f83cb24a 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -9,7 +9,6 @@ from vllm.config import VllmConfig class NgramProposer: - def __init__(self, vllm_config: VllmConfig): assert vllm_config.speculative_config is not None assert vllm_config.speculative_config.prompt_lookup_min is not None @@ -28,8 +27,7 @@ class NgramProposer: # Pre-allocate buffers for numba batch propose. max_num_seqs = vllm_config.scheduler_config.max_num_seqs - self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), - dtype=np.int32) + self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), dtype=np.int32) self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32) # Threshold of total number of tokens in the batch to enable @@ -55,9 +53,13 @@ class NgramProposer: # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. - self.propose([[]] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32), - np.zeros((1024, self.max_model_len), dtype=np.int32), - set()) + self.propose( + [[]] * 1024, + [""] * 1024, + np.zeros(1024, dtype=np.int32), + np.zeros((1024, self.max_model_len), dtype=np.int32), + set(), + ) def batch_propose( self, @@ -67,20 +69,20 @@ class NgramProposer: token_ids_cpu: np.ndarray, ) -> list[list[int]]: """Batch version of ngram proposer using numba for acceleration. - + Args: - valid_ngram_requests: + valid_ngram_requests: Set of indices of requests that need ngram proposals. - num_tokens_no_spec: - Numpy array of shape (batch_size,) representing the number + num_tokens_no_spec: + Numpy array of shape (batch_size,) representing the number of tokens without speculative tokens for each request. - token_ids_cpu: - Numpy array of shape (batch_size, max_model_len) + token_ids_cpu: + Numpy array of shape (batch_size, max_model_len) representing the token IDs for each request. Returns: - list[list[int]]: - A list where each element is a list of proposed + list[list[int]]: + A list where each element is a list of proposed token IDs for the corresponding request. """ draft_token_ids: list[list[int]] = [] @@ -96,26 +98,32 @@ class NgramProposer: total_tokens = np.sum(num_tokens_no_spec) if total_tokens >= self.num_tokens_threshold: final_num_threads = max( - 1, min(self.num_numba_thread_available, - num_ngram_requests)) + 1, min(self.num_numba_thread_available, num_ngram_requests) + ) set_num_threads(final_num_threads) else: set_num_threads(1) - batch_propose_numba(valid_ngram_requests, num_tokens_no_spec, - token_ids_cpu, self.min_n, self.max_n, - self.max_model_len, self.k, - self.valid_ngram_draft, - self.valid_ngram_num_drafts) + batch_propose_numba( + valid_ngram_requests, + num_tokens_no_spec, + token_ids_cpu, + self.min_n, + self.max_n, + self.max_model_len, + self.k, + self.valid_ngram_draft, + self.valid_ngram_num_drafts, + ) # Restore original number of threads. 
set_num_threads(original_num_numba_threads) for i in range(num_requests): - if i in valid_ngram_requests and \ - self.valid_ngram_num_drafts[i] > 0: - draft_token_ids.append(self.valid_ngram_draft[ - i, :self.valid_ngram_num_drafts[i]].tolist()) + if i in valid_ngram_requests and self.valid_ngram_num_drafts[i] > 0: + draft_token_ids.append( + self.valid_ngram_draft[i, : self.valid_ngram_num_drafts[i]].tolist() + ) else: draft_token_ids.append([]) @@ -129,7 +137,6 @@ class NgramProposer: token_ids_cpu: np.ndarray, spec_decode_unsupported_reqs: set, ) -> list[list[int]]: - # find which requests need ngram proposals valid_ngram_requests = [] for i, sampled_ids in enumerate(sampled_token_ids): @@ -166,12 +173,17 @@ class NgramProposer: @njit(parallel=True) -def batch_propose_numba(valid_ngram_requests: list, - num_tokens_no_spec: np.ndarray, - token_ids_cpu: np.ndarray, min_n: int, max_n: int, - max_model_len: int, k: int, - valid_ngram_draft: np.ndarray, - valid_ngram_num_drafts: np.ndarray): +def batch_propose_numba( + valid_ngram_requests: list, + num_tokens_no_spec: np.ndarray, + token_ids_cpu: np.ndarray, + min_n: int, + max_n: int, + max_model_len: int, + k: int, + valid_ngram_draft: np.ndarray, + valid_ngram_num_drafts: np.ndarray, +): for i in prange(len(valid_ngram_requests)): idx = valid_ngram_requests[i] num_tokens = num_tokens_no_spec[idx] @@ -181,19 +193,22 @@ def batch_propose_numba(valid_ngram_requests: list, min_ngram=min_n, max_ngram=max_n, max_model_len=max_model_len, - k=k) + k=k, + ) valid_ngram_num_drafts[i] = drafter_output.shape[0] if len(drafter_output): - valid_ngram_draft[i, :drafter_output.shape[0]] = drafter_output + valid_ngram_draft[i, : drafter_output.shape[0]] = drafter_output @jit(nopython=True) -def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray, - min_ngram: int, - max_ngram: int, - max_model_len: int, - k: int) -> np.ndarray: +def _find_longest_matched_ngram_and_propose_tokens( + origin_tokens: np.ndarray, + min_ngram: int, + max_ngram: int, + max_model_len: int, + k: int, +) -> np.ndarray: """ Find the longest n-gram which matches the suffix of the given tokens whose length is within [min_ngram, max_ngram] (inclusive). @@ -203,12 +218,12 @@ def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray, # Do not generate draft tokens is context is shorter than minimum n-gram total_token = origin_tokens.shape[0] if total_token < min_ngram: - return np.empty((0, ), dtype=origin_tokens.dtype) + return np.empty((0,), dtype=origin_tokens.dtype) # Do not generate draft tokens beyond the max model length. 
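The numba kernel referenced here implements the n-gram lookup with a reversed z-algorithm for speed; the underlying idea can be written in a few lines of plain Python. This is a simplified rendition for illustration only, not the vLLM kernel:

    # Find the longest suffix of length in [min_ngram, max_ngram] that occurred
    # earlier in the context, and propose the k tokens that followed it.
    def propose_ngram(tokens: list[int], min_ngram: int, max_ngram: int, k: int) -> list[int]:
        n = len(tokens)
        for size in range(min(max_ngram, n - 1), min_ngram - 1, -1):
            suffix = tokens[n - size:]
            # Search for the most recent earlier occurrence of the suffix.
            for start in range(n - size - 1, -1, -1):
                if tokens[start:start + size] == suffix:
                    follow = tokens[start + size:start + size + k]
                    if follow:
                        return follow
            # fall through to a shorter n-gram if no match was found
        return []

    print(propose_ngram([1, 2, 3, 9, 1, 2, 3], min_ngram=2, max_ngram=3, k=2))
    # -> [9, 1]: the suffix [1, 2, 3] matched at position 0 and was followed by 9, 1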
k = min(k, max_model_len - total_token) if k <= 0: - return np.empty((0, ), dtype=origin_tokens.dtype) + return np.empty((0,), dtype=origin_tokens.dtype) # Flip tokens, and the goal become to find longest ngram # on the rightmost position which matches the prefix with @@ -265,7 +280,7 @@ def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray, if longest_ngram < min_ngram: # No valid ngram is found - return np.empty((0, ), dtype=origin_tokens.dtype) + return np.empty((0,), dtype=origin_tokens.dtype) # Flip the position back, so in origin_tokens, # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram] @@ -273,4 +288,4 @@ def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray, # total_token-1-position+longest_ngram start_position = total_token - 1 - position + longest_ngram k = min(k, total_token - start_position) - return origin_tokens[start_position:start_position + k] + return origin_tokens[start_position : start_position + k] diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 1116179dc5..1901c6fc9f 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -7,8 +7,10 @@ _SAMPLING_EPS = 1e-5 def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool: """True if request is incompatible with speculative decoding""" - return (sampling_params.frequency_penalty != 0.0 - or sampling_params.presence_penalty != 0.0 - or sampling_params.repetition_penalty != 1.0 - or sampling_params.min_p > _SAMPLING_EPS - or sampling_params.logprobs is not None) + return ( + sampling_params.frequency_penalty != 0.0 + or sampling_params.presence_penalty != 0.0 + or sampling_params.repetition_penalty != 1.0 + or sampling_params.min_p > _SAMPLING_EPS + or sampling_params.logprobs is not None + ) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 13c33d3edf..1f51f98ca9 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -12,8 +12,10 @@ from vllm.reasoning import ReasoningParserManager from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_guidance import GuidanceBackend -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, +) from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend if TYPE_CHECKING: @@ -48,8 +50,7 @@ class StructuredOutputManager: # - at least 1 CPU # - at most half the number of CPUs or 8, whichever is less max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8)) - self.executor_for_fillmask = ThreadPoolExecutor( - max_workers=max_workers) + self.executor_for_fillmask = ThreadPoolExecutor(max_workers=max_workers) if not self.vllm_config.model_config.skip_tokenizer_init: # The default max_workers if not specified is the number of @@ -60,12 +61,15 @@ class StructuredOutputManager: max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) self.tokenizer = init_tokenizer_from_configs( - model_config=self.vllm_config.model_config) - reasoning_parser = \ - self.vllm_config.structured_outputs_config.reasoning_parser + model_config=self.vllm_config.model_config + ) + reasoning_parser = ( + self.vllm_config.structured_outputs_config.reasoning_parser + ) if 
reasoning_parser: reasoner_cls = ReasoningParserManager.get_reasoning_parser( - reasoning_parser) + reasoning_parser + ) self.reasoner = reasoner_cls(tokenizer=self.tokenizer) def grammar_init(self, request: Request) -> None: @@ -73,8 +77,10 @@ class StructuredOutputManager: return if TYPE_CHECKING: - assert request.sampling_params is not None and \ - request.sampling_params.structured_outputs is not None + assert ( + request.sampling_params is not None + and request.sampling_params.structured_outputs is not None + ) # Initialize the backend the first time it is needed. # @@ -98,8 +104,7 @@ class StructuredOutputManager: vocab_size=vocab_size, ) elif backend == "outlines": - from vllm.v1.structured_output.backend_outlines import ( - OutlinesBackend) + from vllm.v1.structured_output.backend_outlines import OutlinesBackend self.backend = OutlinesBackend( self.vllm_config, @@ -108,15 +113,16 @@ class StructuredOutputManager: ) elif backend == "lm-format-enforcer": from vllm.v1.structured_output.backend_lm_format_enforcer import ( # noqa: E501 - LMFormatEnforcerBackend) + LMFormatEnforcerBackend, + ) + self.backend = LMFormatEnforcerBackend( self.vllm_config, tokenizer=self.tokenizer, vocab_size=vocab_size, ) else: - raise ValueError( - f"Unsupported structured output backend: {backend}") + raise ValueError(f"Unsupported structured output backend: {backend}") grammar = self.executor.submit(self._async_create_grammar, request) request.structured_output_request.grammar = grammar # type: ignore[assignment] @@ -169,8 +175,9 @@ class StructuredOutputManager: max_num_spec_tokens = 0 if self.vllm_config.speculative_config is not None: - max_num_spec_tokens = \ + max_num_spec_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens + ) if self._grammar_bitmask is None: assert self.backend is not None @@ -179,22 +186,23 @@ class StructuredOutputManager: # Allocate a bitmask for each token needing to be checked: # one for each speculative position, and one more for the # bonus token / non-speculative token. - self._grammar_bitmask = \ - self.backend.allocate_token_bitmask( - max_batch_size * (1 + max_num_spec_tokens)) + self._grammar_bitmask = self.backend.allocate_token_bitmask( + max_batch_size * (1 + max_num_spec_tokens) + ) # Generate a batched bitmask for all structured output requests. # When speculative decoding is enabled, we need to include multiple # masks for each request, one for each possible bonus token position. # These are stored inline in the tensor and unpacked by the gpu runner. 
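For a sense of scale of the bitmask allocated above (illustrative numbers only): each request needs one row per speculative position plus one for the bonus token, and each row typically packs the vocabulary 32 tokens per int32 word, as in xgrammar-style token bitmasks.

    import math

    max_batch_size = 256
    num_speculative_tokens = 3
    vocab_size = 152064

    rows = max_batch_size * (1 + num_speculative_tokens)   # 1024 rows
    words_per_row = math.ceil(vocab_size / 32)             # 4752 int32 words per row
    print(rows, words_per_row, rows * words_per_row * 4, "bytes")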
cumulative_index = 0 - ordered_seq = sorted(structured_output_request_ids.items(), - key=lambda x: x[1]) + ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1]) # Optimized parallel filling of bitmasks for # non-spec, large-batch-size cases - if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \ - max_num_spec_tokens == 0: + if ( + len(ordered_seq) > self.fill_bitmask_parallel_threshold + and max_num_spec_tokens == 0 + ): promises = [] batch = [] for req_id, _ in ordered_seq: @@ -205,8 +213,9 @@ class StructuredOutputManager: assert structured_output_request.grammar is not None apply_bitmask = self.should_fill_bitmask(request) - batch.append((structured_output_request.grammar, - cumulative_index, apply_bitmask)) + batch.append( + (structured_output_request.grammar, cumulative_index, apply_bitmask) + ) if len(batch) == self.fill_bitmask_parallel_batch_size: promises.append(self._async_submit_fill_bitmask(batch)) batch = [] @@ -232,18 +241,28 @@ class StructuredOutputManager: state_advancements = 0 req_tokens = scheduled_spec_decode_tokens.get(req_id, []) for i, token in enumerate(req_tokens + [None]): - self._fill_bitmasks([(structured_output_request.grammar, - cumulative_index, apply_bitmask)]) + self._fill_bitmasks( + [ + ( + structured_output_request.grammar, + cumulative_index, + apply_bitmask, + ) + ] + ) - if apply_bitmask and token is not None and \ - not structured_output_request.grammar.is_terminated(): + if ( + apply_bitmask + and token is not None + and not structured_output_request.grammar.is_terminated() + ): assert structured_output_request.grammar.accept_tokens( - req_id, [token]) + req_id, [token] + ) state_advancements += 1 cumulative_index += 1 if state_advancements > 0: - structured_output_request.grammar.rollback( - state_advancements) + structured_output_request.grammar.rollback(state_advancements) bitmask_tensor = self._grammar_bitmask if cumulative_index < bitmask_tensor.shape[0]: @@ -258,8 +277,9 @@ class StructuredOutputManager: if self.reasoner is not None: assert request.structured_output_request is not None if request.structured_output_request.reasoning_ended is None: - request.structured_output_request.reasoning_ended = \ + request.structured_output_request.reasoning_ended = ( self.reasoner.is_reasoning_end(request.prompt_token_ids) + ) return request.structured_output_request.reasoning_ended return True diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index e06ab6377d..a48a705e8f 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -14,9 +14,11 @@ import torch from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.utils import LazyLoader -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar, - StructuredOutputOptions) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions, +) from vllm.v1.structured_output.request import get_structured_output_key if TYPE_CHECKING: @@ -26,8 +28,7 @@ if TYPE_CHECKING: else: llguidance = LazyLoader("llguidance", globals(), "llguidance") llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf") - llguidance_torch = LazyLoader("llguidance.torch", globals(), - "llguidance.torch") + llguidance_torch = LazyLoader("llguidance.torch", globals(), "llguidance.torch") logger = init_logger(__name__) @@ 
-36,16 +37,18 @@ def _walk_json_for_additional_properties(data: object): if isinstance(data, dict): for value in data.values(): _walk_json_for_additional_properties(value) - if 'additionalProperties' not in data and \ - ('properties' in data or 'patternProperties' in data): - data['additionalProperties'] = False + if "additionalProperties" not in data and ( + "properties" in data or "patternProperties" in data + ): + data["additionalProperties"] = False elif isinstance(data, list): for item in data: _walk_json_for_additional_properties(item) def process_for_additional_properties( - guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]: + guide_json: Union[str, dict[str, Any]], +) -> dict[str, Any]: if isinstance(guide_json, str): guide_json_obj = json.loads(guide_json) else: @@ -57,21 +60,27 @@ def process_for_additional_properties( @dataclass class GuidanceBackend(StructuredOutputBackend): - def __post_init__(self): - self.disable_any_whitespace = \ + self.disable_any_whitespace = ( self.vllm_config.structured_outputs_config.disable_any_whitespace - self.disable_additional_properties = \ + ) + self.disable_additional_properties = ( self.vllm_config.structured_outputs_config.disable_additional_properties + ) self.ll_tokenizer = llguidance_hf.from_tokenizer( - self.tokenizer, self.vocab_size) + self.tokenizer, self.vocab_size + ) - def compile_grammar(self, request_type: StructuredOutputOptions, - grammar_spec: str) -> StructuredOutputGrammar: + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: self.serialized_grammar = serialize_guidance_grammar( - request_type, grammar_spec, self.disable_any_whitespace, - self.disable_additional_properties) + request_type, + grammar_spec, + self.disable_any_whitespace, + self.disable_additional_properties, + ) ll_matcher = llguidance.LLMatcher( self.ll_tokenizer, @@ -90,7 +99,8 @@ class GuidanceBackend(StructuredOutputBackend): def allocate_token_bitmask(self, max_num_seqs: int): return llguidance_torch.allocate_token_bitmask( - max_num_seqs, self.ll_tokenizer.vocab_size) + max_num_seqs, self.ll_tokenizer.vocab_size + ) def destroy(self): pass @@ -178,15 +188,17 @@ def serialize_guidance_grammar( disable_any_whitespace: bool = False, disable_additional_properties: bool = False, ) -> str: - - def _process_schema(grammar_spec: Union[str, dict[str, Any]], ) -> str: + def _process_schema( + grammar_spec: Union[str, dict[str, Any]], + ) -> str: if disable_additional_properties: grammar_spec = process_for_additional_properties(grammar_spec) return llguidance.LLMatcher.grammar_from_json_schema( grammar_spec, defaults={ "whitespace_flexible": not disable_any_whitespace, - }) + }, + ) if request_type == StructuredOutputOptions.JSON: return _process_schema(grammar_spec) @@ -195,7 +207,8 @@ def serialize_guidance_grammar( '{"type": "object"}', defaults={ "whitespace_flexible": not disable_any_whitespace, - }) + }, + ) else: if request_type == StructuredOutputOptions.REGEX: tp = "regex" @@ -215,29 +228,32 @@ def serialize_guidance_grammar( trig = next((t for t in triggers if begin.startswith(t)), None) if trig is None: raise ValueError( - f"Trigger {begin} not found in triggers {triggers}") + f"Trigger {begin} not found in triggers {triggers}" + ) tags.append( llguidance.StructTag( trigger=trig, begin=s["begin"], grammar=_process_schema(s["schema"]), end=s["end"], - )) + ) + ) if not tags: - raise ValueError( - "No structural tags found in the grammar spec.") + raise ValueError("No structural 
tags found in the grammar spec.") return llguidance.StructTag.to_grammar(tags) else: - logger.error("Validation should have already occurred. " - "Please file an issue.") - raise ValueError("grammar is not of valid supported types. " - f"({request_type!s})") + logger.error( + "Validation should have already occurred. Please file an issue." + ) + raise ValueError( + f"grammar is not of valid supported types. ({request_type!s})" + ) return llguidance.grammar_from(tp, grammar_spec) def validate_guidance_grammar( - sampling_params: SamplingParams, - tokenizer: Optional[llguidance.LLTokenizer] = None) -> None: + sampling_params: SamplingParams, tokenizer: Optional[llguidance.LLTokenizer] = None +) -> None: tp, grm = get_structured_output_key(sampling_params) guidance_grm = serialize_guidance_grammar(tp, grm) err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer) diff --git a/vllm/v1/structured_output/backend_lm_format_enforcer.py b/vllm/v1/structured_output/backend_lm_format_enforcer.py index 465b2428f8..d9e484092d 100644 --- a/vllm/v1/structured_output/backend_lm_format_enforcer.py +++ b/vllm/v1/structured_output/backend_lm_format_enforcer.py @@ -13,26 +13,31 @@ from transformers import PreTrainedTokenizerBase from vllm.sampling_params import SamplingParams from vllm.utils import LazyLoader -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar, - StructuredOutputOptions) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions, +) if TYPE_CHECKING: import lmformatenforcer import lmformatenforcer.integrations.vllm as lmfe_vllm else: - lmformatenforcer = LazyLoader("lmformatenforcer", globals(), - "lmformatenforcer") - lmfe_vllm = LazyLoader("lmformatenforcer.integrations.vllm", globals(), - "lmformatenforcer.integrations.vllm") + lmformatenforcer = LazyLoader("lmformatenforcer", globals(), "lmformatenforcer") + lmfe_vllm = LazyLoader( + "lmformatenforcer.integrations.vllm", + globals(), + "lmformatenforcer.integrations.vllm", + ) @lru_cache def _cached_build_vllm_token_enforcer_tokenizer_data( - tokenizer: PreTrainedTokenizerBase, - vocab_size: int) -> lmfe_vllm.TokenEnforcerTokenizerData: + tokenizer: PreTrainedTokenizerBase, vocab_size: int +) -> lmfe_vllm.TokenEnforcerTokenizerData: return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data( - tokenizer, use_bitmask=True, vocab_size=vocab_size) + tokenizer, use_bitmask=True, vocab_size=vocab_size + ) @dataclass @@ -44,7 +49,8 @@ class LMFormatEnforcerGrammar(StructuredOutputGrammar): original_len = len(self.current_tokens_prefix) for token in tokens: if not self.token_enforcer.get_allowed_tokens( - self.current_tokens_prefix).is_token_allowed(token): + self.current_tokens_prefix + ).is_token_allowed(token): # Rollback partial updates to ensure atomicity. 
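The accept/rollback pattern in the hunk above can be shown without any lmformatenforcer dependency; this toy version (names invented) keeps the same invariant that a rejected batch leaves the token prefix untouched:

    def accept_tokens(prefix: list[int], tokens: list[int], is_allowed) -> bool:
        original_len = len(prefix)
        for token in tokens:
            if not is_allowed(prefix, token):
                del prefix[original_len:]   # roll back partial updates
                return False
            prefix.append(token)
        return True

    # Example: only even tokens are allowed.
    prefix: list[int] = [2, 4]
    ok = accept_tokens(prefix, [6, 7, 8], lambda p, t: t % 2 == 0)
    print(ok, prefix)   # False [2, 4] -- the partially accepted 6 was rolled back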
del self.current_tokens_prefix[original_len:] return False @@ -56,8 +62,8 @@ class LMFormatEnforcerGrammar(StructuredOutputGrammar): prefix = tokens[:prefix_length] next_token = tokens[prefix_length] if not self.token_enforcer.get_allowed_tokens( - self.current_tokens_prefix + - prefix).is_token_allowed(next_token): + self.current_tokens_prefix + prefix + ).is_token_allowed(next_token): break else: return tokens @@ -69,14 +75,16 @@ class LMFormatEnforcerGrammar(StructuredOutputGrammar): def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None: allowed_tokens = self.token_enforcer.get_allowed_tokens( - self.current_tokens_prefix) + self.current_tokens_prefix + ) bitmask[batch_index] = allowed_tokens.allowed_tokens def is_terminated(self) -> bool: # We are considered terminated if the prefix ends with eos_token_id - return_value = len( - self.current_tokens_prefix) > 0 and self.current_tokens_prefix[ - -1] == self.token_enforcer.eos_token_id + return_value = ( + len(self.current_tokens_prefix) > 0 + and self.current_tokens_prefix[-1] == self.token_enforcer.eos_token_id + ) return return_value def reset(self): @@ -85,18 +93,18 @@ class LMFormatEnforcerGrammar(StructuredOutputGrammar): @dataclass class LMFormatEnforcerBackend(StructuredOutputBackend): - def __post_init__(self): self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( - self.tokenizer, self.vocab_size) + self.tokenizer, self.vocab_size + ) - def compile_grammar(self, request_type: StructuredOutputOptions, - grammar_spec: str) -> StructuredOutputGrammar: + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: character_level_parser: lmformatenforcer.CharacterLevelParser if request_type == StructuredOutputOptions.JSON: spec_dict = json.loads(grammar_spec) - character_level_parser = lmformatenforcer.JsonSchemaParser( - spec_dict) + character_level_parser = lmformatenforcer.JsonSchemaParser(spec_dict) elif request_type == StructuredOutputOptions.JSON_OBJECT: character_level_parser = lmformatenforcer.JsonSchemaParser(None) elif request_type == StructuredOutputOptions.REGEX: @@ -104,14 +112,17 @@ class LMFormatEnforcerBackend(StructuredOutputBackend): elif request_type == StructuredOutputOptions.CHOICE: choices = ast.literal_eval(grammar_spec) character_level_parser = lmformatenforcer.UnionParser( - [lmformatenforcer.StringParser(choice) for choice in choices]) + [lmformatenforcer.StringParser(choice) for choice in choices] + ) else: raise ValueError( - "Invalid request type for LM Format Enforcer backend" - f"({request_type!s})") + f"Invalid request type for LM Format Enforcer backend({request_type!s})" + ) max_rollback_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config is not None else 0) + if self.vllm_config.speculative_config is not None + else 0 + ) if max_rollback_tokens > 0: raise ValueError( @@ -136,8 +147,7 @@ class LMFormatEnforcerBackend(StructuredOutputBackend): pass -def validate_structured_output_request_lm_format_enforcer( - params: SamplingParams): +def validate_structured_output_request_lm_format_enforcer(params: SamplingParams): if params.structured_outputs is None: return @@ -163,5 +173,7 @@ def validate_structured_output_request_lm_format_enforcer( elif so_params.choice: return elif so_params.grammar: - raise ValueError("LM Format Enforcer structured outputs backend " - "does not support grammar specifications") + raise ValueError( + "LM Format Enforcer 
structured outputs backend " + "does not support grammar specifications" + ) diff --git a/vllm/v1/structured_output/backend_outlines.py b/vllm/v1/structured_output/backend_outlines.py index e5e638a6ad..c987533717 100644 --- a/vllm/v1/structured_output/backend_outlines.py +++ b/vllm/v1/structured_output/backend_outlines.py @@ -15,20 +15,23 @@ from regex import escape as regex_escape from vllm.sampling_params import SamplingParams from vllm.utils import LazyLoader -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar, - StructuredOutputOptions) -from vllm.v1.structured_output.utils import (OutlinesVocabulary, - get_outlines_cache, - get_outlines_vocabulary) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions, +) +from vllm.v1.structured_output.utils import ( + OutlinesVocabulary, + get_outlines_cache, + get_outlines_vocabulary, +) if TYPE_CHECKING: import outlines_core as oc import outlines_core.json_schema as json_schema else: oc = LazyLoader("oc", globals(), "outlines_core") - json_schema = LazyLoader("json_schema", globals(), - "outlines_core.json_schema") + json_schema = LazyLoader("json_schema", globals(), "outlines_core.json_schema") # Python 3.11+ sre_parse and sre_constants # are deprecated, so we must import them from re @@ -46,13 +49,13 @@ else: @dataclass class OutlinesBackend(StructuredOutputBackend): - def __post_init__(self): self.vocabulary = get_outlines_vocabulary(self.tokenizer) self.cache = get_outlines_cache() - def _compile_index(self, regex_string: str, - vocabulary: OutlinesVocabulary) -> oc.Index: + def _compile_index( + self, regex_string: str, vocabulary: OutlinesVocabulary + ) -> oc.Index: cache_key = f"{vocabulary._hash}_{regex_string}" if cache_key in self.cache: return self.cache[cache_key] @@ -62,8 +65,9 @@ class OutlinesBackend(StructuredOutputBackend): return index - def compile_grammar(self, request_type: StructuredOutputOptions, - grammar_spec: str) -> StructuredOutputGrammar: + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: if request_type == StructuredOutputOptions.JSON: regex = json_schema.build_regex_from_schema(grammar_spec) elif request_type == StructuredOutputOptions.REGEX: @@ -79,10 +83,13 @@ class OutlinesBackend(StructuredOutputBackend): index = self._compile_index(regex, self.vocabulary) max_rollback_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config is not None else 0) - return OutlinesGrammar(vocab_size=self.vocab_size, - guide=oc.Guide( - index, max_rollback=max_rollback_tokens)) + if self.vllm_config.speculative_config is not None + else 0 + ) + return OutlinesGrammar( + vocab_size=self.vocab_size, + guide=oc.Guide(index, max_rollback=max_rollback_tokens), + ) def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: return torch.full( @@ -98,20 +105,15 @@ class OutlinesBackend(StructuredOutputBackend): @dataclass class OutlinesGrammar(StructuredOutputGrammar): - vocab_size: int guide: oc.Guide = field(hash=False) - num_processed_tokens: int = field(default_factory=lambda: 0, - repr=False, - hash=False, - init=False) + num_processed_tokens: int = field( + default_factory=lambda: 0, repr=False, hash=False, init=False + ) # outlines_core signals done on DFA accept; vLLM expects done after EOS. # We delay the finished flag by one step so EOS can still be emitted. 
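The one-step delay described in the comment above is easy to isolate; a minimal sketch of the same state machine, independent of outlines_core:

    # The DFA reports finished as soon as it accepts, but the engine still has
    # to emit EOS, so termination is only reported on the step *after* the
    # guide first says it is finished.
    class DelayedTermination:
        def __init__(self) -> None:
            self._prev_finished = False

        def is_terminated(self, guide_finished: bool) -> bool:
            terminated = self._prev_finished
            self._prev_finished = guide_finished
            return terminated

    d = DelayedTermination()
    print(d.is_terminated(guide_finished=True))   # False: EOS can still be emitted
    print(d.is_terminated(guide_finished=True))   # True: now report termination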
- _prev_finished: bool = field(default=False, - init=False, - repr=False, - hash=False) + _prev_finished: bool = field(default=False, init=False, repr=False, hash=False) def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: """Accepts a list of tokens and advances the FSM. @@ -142,8 +144,7 @@ class OutlinesGrammar(StructuredOutputGrammar): def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: mask = bitmask[idx] - self.guide.write_mask_into(mask.data_ptr(), mask.numel(), - mask.element_size()) + self.guide.write_mask_into(mask.data_ptr(), mask.numel(), mask.element_size()) def is_terminated(self) -> bool: curr = self.guide.is_finished() @@ -187,8 +188,10 @@ def validate_structured_output_request_outlines(params: SamplingParams): regex = "(" + "|".join(choices) + ")" validate_regex_is_buildable(regex) elif so_params.grammar: - raise ValueError("Outlines structured outputs backend " - "does not support grammar specifications") + raise ValueError( + "Outlines structured outputs backend " + "does not support grammar specifications" + ) def _prefix_needs_context(parsed) -> bool: @@ -196,7 +199,7 @@ def _prefix_needs_context(parsed) -> bool: def subpattern_consumes(parsed) -> bool: """Return True if subpattern can consume at least one character.""" - tokens = parsed.data if hasattr(parsed, 'data') else parsed + tokens = parsed.data if hasattr(parsed, "data") else parsed for ttype, tval in tokens: # literal, character class, or dot always consumes if ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY): @@ -212,17 +215,18 @@ def _prefix_needs_context(parsed) -> bool: if any(subpattern_consumes(br) for br in branches): return True # grouped subpattern: recurse into its contents - elif ttype == sre_parse.SUBPATTERN and subpattern_consumes( - tval[3]): + elif ttype == sre_parse.SUBPATTERN and subpattern_consumes(tval[3]): return True # No consumers, return False return False - tokens = parsed.data if hasattr(parsed, 'data') else parsed + tokens = parsed.data if hasattr(parsed, "data") else parsed for ttype, tval in tokens: # Direct anchors or look-around - if ttype == sre_parse.AT or ttype in (sre_constants.ASSERT, - sre_constants.ASSERT_NOT): + if ttype == sre_parse.AT or ttype in ( + sre_constants.ASSERT, + sre_constants.ASSERT_NOT, + ): return True # Nested subpattern: check @@ -261,9 +265,8 @@ def _prefix_needs_context(parsed) -> bool: def _check_unsupported(parsed) -> None: """Check for regex features unsupported by regex-automata""" - tokens = parsed.data if hasattr(parsed, 'data') else parsed + tokens = parsed.data if hasattr(parsed, "data") else parsed for ttype, tval in tokens: - # backreference if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS): raise ValueError("Backreferences are unsupported.") @@ -274,8 +277,7 @@ def _check_unsupported(parsed) -> None: # unicode word boundaries elif ttype == sre_parse.AT: - if tval in (sre_constants.AT_BOUNDARY, - sre_constants.AT_NON_BOUNDARY): + if tval in (sre_constants.AT_BOUNDARY, sre_constants.AT_NON_BOUNDARY): raise ValueError("Unicode word boundaries are unsupported.") elif ttype == sre_parse.BRANCH: @@ -308,7 +310,8 @@ def validate_regex_is_buildable(pattern: str) -> None: raise ValueError( f"Regex uses unsupported feature for structured outputs: {e}. " "Only basic matching constructs are supported—lookarounds, " - "backreferences, and unicode boundaries are not.") from e + "backreferences, and unicode boundaries are not." 
+ ) from e if _prefix_needs_context(parsed): raise ValueError( @@ -317,4 +320,5 @@ def validate_regex_is_buildable(pattern: str) -> None: "in a way which requires context before any token is matched." "structured outputs needs regexes that can match without needing " "that context. Try rewriting the pattern without using these " - f"constructs. Pattern:\n{pattern}") + f"constructs. Pattern:\n{pattern}" + ) diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index 9a53aa7a1a..2051b336e5 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -103,8 +103,9 @@ class StructuredOutputBackend(ABC): vocab_size: int @abstractmethod - def compile_grammar(self, request_type: StructuredOutputOptions, - grammar_spec: str) -> StructuredOutputGrammar: + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: """ Compiles a grammar specification into a structured output grammar. diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index a853e65407..9f81d09633 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -14,12 +14,16 @@ from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import LazyLoader -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar, - StructuredOutputOptions) -from vllm.v1.structured_output.utils import (choice_as_grammar, - convert_lark_to_ebnf, - grammar_is_likely_lark) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions, +) +from vllm.v1.structured_output.utils import ( + choice_as_grammar, + convert_lark_to_ebnf, + grammar_is_likely_lark, +) if TYPE_CHECKING: import xgrammar as xgr @@ -31,10 +35,10 @@ logger = init_logger(__name__) @dataclass class XgrammarBackend(StructuredOutputBackend): - def __post_init__(self): - self.disable_any_whitespace = \ + self.disable_any_whitespace = ( self.vllm_config.structured_outputs_config.disable_any_whitespace + ) if isinstance(self.tokenizer, MistralTokenizer): # NOTE: ideally, xgrammar should handle this accordingly. @@ -44,27 +48,33 @@ class XgrammarBackend(StructuredOutputBackend): encoded_vocab = self.tokenizer._vocab else: encoded_vocab = [ - token for token, _ in sorted( + token + for token, _ in sorted( self.tokenizer.get_vocab().items(), key=lambda x: x[1], ) ] stop_token_ids = None - if (hasattr( + if ( + hasattr( self.tokenizer, "eos_token_id", - ) and self.tokenizer.eos_token_id is not None): + ) + and self.tokenizer.eos_token_id is not None + ): stop_token_ids = [self.tokenizer.eos_token_id] except AttributeError as e: raise ValueError( f"Cannot get the vocabulary of the tokenizer " f"{type(self.tokenizer)}. The tokenizer should have a " - "get_vocab method.") from e + "get_vocab method." 
+ ) from e tokenizer_info = xgr.TokenizerInfo( # type: ignore encoded_vocab=encoded_vocab, # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 vocab_type=xgr.VocabType.RAW - if self.tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK, + if self.tokenizer.is_tekken + else xgr.VocabType.BYTE_FALLBACK, vocab_size=self.vocab_size, stop_token_ids=stop_token_ids, add_prefix_space=True, @@ -83,18 +93,21 @@ class XgrammarBackend(StructuredOutputBackend): self.num_speculative_tokens = 0 if self.vllm_config.speculative_config is not None: - self.num_speculative_tokens = \ + self.num_speculative_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens + ) - def compile_grammar(self, request_type: StructuredOutputOptions, - grammar_spec: str) -> StructuredOutputGrammar: + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: if request_type == StructuredOutputOptions.JSON: ctx = self.compiler.compile_json_schema( - grammar_spec, any_whitespace=not self.disable_any_whitespace) + grammar_spec, any_whitespace=not self.disable_any_whitespace + ) elif request_type == StructuredOutputOptions.JSON_OBJECT: ctx = self.compiler.compile_json_schema( - '{"type": "object"}', - any_whitespace=not self.disable_any_whitespace) + '{"type": "object"}', any_whitespace=not self.disable_any_whitespace + ) elif request_type == StructuredOutputOptions.GRAMMAR: ctx = self.compiler.compile_grammar(grammar_spec) elif request_type == StructuredOutputOptions.REGEX: @@ -106,17 +119,20 @@ class XgrammarBackend(StructuredOutputBackend): begin=s["begin"], schema=json.dumps(s["schema"]), end=s["end"], - ) for s in s_tag["structures"] + ) + for s in s_tag["structures"] ] structural_tag = xgr.StructuralTag.from_legacy_structural_tag( - tags, s_tag["triggers"]) + tags, s_tag["triggers"] + ) ctx = self.compiler.compile_structural_tag(structural_tag) else: logger.error( "Validation should have already occurred. Please file an issue." ) raise ValueError( - f"grammar is not of valid supported types. ({request_type!s})") + f"grammar is not of valid supported types. ({request_type!s})" + ) return XgrammarGrammar( matcher=xgr.GrammarMatcher( @@ -146,10 +162,9 @@ class XgrammarGrammar(StructuredOutputGrammar): vocab_size: int matcher: xgr.GrammarMatcher = field(hash=False) ctx: xgr.CompiledGrammar = field(hash=False) - num_processed_tokens: int = field(default_factory=lambda: 0, - repr=False, - hash=False, - init=False) + num_processed_tokens: int = field( + default_factory=lambda: 0, repr=False, hash=False, init=False + ) _is_terminated: bool = field(default=False, repr=False, hash=False) def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: @@ -164,7 +179,10 @@ class XgrammarGrammar(StructuredOutputGrammar): if not self.matcher.accept_token(token): logger.error( "Failed to advance FSM for request %s " - "for tokens %s. Please file an issue.", request_id, token) + "for tokens %s. 
Please file an issue.", + request_id, + token, + ) return False self.num_processed_tokens += 1 self._is_terminated = self.matcher.is_terminated() @@ -216,8 +234,9 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: # Check for array unsupported keywords if obj.get("type") == "array" and any( - key in obj for key in ("uniqueItems", "contains", - "minContains", "maxContains")): + key in obj + for key in ("uniqueItems", "contains", "minContains", "maxContains") + ): return True # Unsupported keywords for strings @@ -226,8 +245,14 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: # Unsupported keywords for objects if obj.get("type") == "object" and any( - key in obj for key in ("minProperties", "maxProperties", - "propertyNames", "patternProperties")): + key in obj + for key in ( + "minProperties", + "maxProperties", + "propertyNames", + "patternProperties", + ) + ): return True # Recursively check all nested objects and arrays @@ -259,16 +284,18 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: try: xgr.Grammar.from_regex(so_params.regex) except Exception as err: - raise ValueError("Failed to transform regex into a grammar: " - f"{err}") from err + raise ValueError( + f"Failed to transform regex into a grammar: {err}" + ) from err if so_params.choice: choice_grammar = choice_as_grammar(so_params.choice) try: xgr.Grammar.from_ebnf(choice_grammar) except Exception as err: - raise ValueError("Failed to transform choices into a grammar: " - "{err}") from err + raise ValueError( + "Failed to transform choices into a grammar: {err}" + ) from err so_params.choice = None so_params.grammar = choice_grammar return @@ -285,12 +312,14 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: try: xgr.Grammar.from_json_schema(schema) except Exception as err: - raise ValueError("Failed to transform json schema into a grammar: " - f"{err}") from err + raise ValueError( + f"Failed to transform json schema into a grammar: {err}" + ) from err if has_xgrammar_unsupported_json_features(schema): - raise ValueError("The provided JSON schema contains features not " - "supported by xgrammar.") + raise ValueError( + "The provided JSON schema contains features not supported by xgrammar." + ) return if so_params.grammar: @@ -300,7 +329,8 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: so_params.grammar = convert_lark_to_ebnf(so_params.grammar) except ValueError as e: raise ValueError( - "Failed to convert the grammar from Lark to EBNF. ") from e + "Failed to convert the grammar from Lark to EBNF. 
" + ) from e # Test parsing EBNF grammar, possibly already converted from Lark try: @@ -318,10 +348,12 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: begin=s["begin"], schema=json.dumps(s["schema"]), end=s["end"], - ) for s in s_tag["structures"] + ) + for s in s_tag["structures"] ] structural_tag = xgr.StructuralTag.from_legacy_structural_tag( - tags, s_tag["triggers"]) + tags, s_tag["triggers"] + ) xgr.Grammar.from_structural_tag(structural_tag) except Exception as e: raise ValueError("Invalid structural tag specification.") from e diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index 99974ef46e..26f72ae50c 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -10,17 +10,19 @@ from concurrent.futures._base import TimeoutError from typing import Optional, Union, cast from vllm.sampling_params import SamplingParams -from vllm.v1.structured_output.backend_types import (StructuredOutputGrammar, - StructuredOutputKey, - StructuredOutputOptions) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputGrammar, + StructuredOutputKey, + StructuredOutputOptions, +) @dataclasses.dataclass class StructuredOutputRequest: - sampling_params: SamplingParams - _grammar: Optional[Union[Future[StructuredOutputGrammar], - StructuredOutputGrammar]] = None + _grammar: Optional[ + Union[Future[StructuredOutputGrammar], StructuredOutputGrammar] + ] = None reasoning_ended: Optional[bool] = None def _check_grammar_completion(self) -> bool: @@ -43,13 +45,15 @@ class StructuredOutputRequest: @property def grammar(self) -> Optional[StructuredOutputGrammar]: completed = self._check_grammar_completion() - return cast(Optional[StructuredOutputGrammar], - self._grammar) if completed else None + return ( + cast(Optional[StructuredOutputGrammar], self._grammar) + if completed + else None + ) @grammar.setter def grammar( - self, grammar: Union[StructuredOutputGrammar, - Future[StructuredOutputGrammar]] + self, grammar: Union[StructuredOutputGrammar, Future[StructuredOutputGrammar]] ) -> None: self._grammar = grammar @@ -58,8 +62,7 @@ class StructuredOutputRequest: return get_structured_output_key(self.sampling_params) -def get_structured_output_key( - sampling_params: SamplingParams) -> StructuredOutputKey: +def get_structured_output_key(sampling_params: SamplingParams) -> StructuredOutputKey: params = sampling_params.structured_outputs assert params is not None, "params can't be None." if params.json is not None: diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index b9b09bea1e..b7326847d0 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -76,27 +76,31 @@ def apply_grammar_bitmask( for req_id, batch_index in seq: logit_index = batch_index + cumulative_offset cumulative_offset += len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) + ) if req_id in scheduler_output.structured_output_request_ids: struct_out_req_batch_indices[req_id] = logit_index out_indices = [] # Reorder the bitmask to match the order of the requests in the batch. 
- sorted_bitmask = np.full(shape=(logits.shape[0], grammar_bitmask.shape[1]), - fill_value=-1, - dtype=grammar_bitmask.dtype) + sorted_bitmask = np.full( + shape=(logits.shape[0], grammar_bitmask.shape[1]), + fill_value=-1, + dtype=grammar_bitmask.dtype, + ) cumulative_index = 0 - seq = sorted(scheduler_output.structured_output_request_ids.items(), - key=lambda x: x[1]) + seq = sorted( + scheduler_output.structured_output_request_ids.items(), key=lambda x: x[1] + ) for req_id, _ in seq: num_spec_tokens = len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) + ) if req_id in struct_out_req_batch_indices: logit_index = struct_out_req_batch_indices[req_id] for i in range(1 + num_spec_tokens): - sorted_bitmask[logit_index + i] = \ - grammar_bitmask[cumulative_index + i] + sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i] out_indices.append(logit_index + i) cumulative_index += 1 + num_spec_tokens grammar_bitmask = sorted_bitmask @@ -128,8 +132,7 @@ class OutlinesVocabulary: self.inner = vocabulary # Have to do abs(hash()) because python hashes can # be negative, and we are using hash as a cache key. - hex_str = hashlib.sha256( - vocabulary.__repr__().encode('utf-8')).hexdigest() + hex_str = hashlib.sha256(vocabulary.__repr__().encode("utf-8")).hexdigest() hash_int = int(hex_str, 16) self._hash = hash_int @@ -165,16 +168,18 @@ def get_outlines_cache(): cache_dir = get_outlines_cache_path() if envs.VLLM_V1_USE_OUTLINES_CACHE: - logger.warning("Enabling outlines cache. This is an unbounded on-disk " - "cache. It may consume a lot of disk space and should " - "not be used with untrusted clients.") + logger.warning( + "Enabling outlines cache. This is an unbounded on-disk " + "cache. It may consume a lot of disk space and should " + "not be used with untrusted clients." + ) cache = Cache(cache_dir, eviction_policy="none", cull_limit=0) outlines_version = importlib.metadata.version("outlines_core") - cached_version = cache.get('__version__', None) + cached_version = cache.get("__version__", None) if cached_version != outlines_version: cache.clear() - cache.set('__version__', outlines_version) + cache.set("__version__", outlines_version) return cache else: return LRUCache(maxsize=128) @@ -194,19 +199,17 @@ def _reduced_vocabulary( A Dict of token string -> equivalent token ids """ - unicode_to_bytes = { - v: k - for k, v in tokenization_gpt2.bytes_to_unicode().items() - } + unicode_to_bytes = {v: k for k, v in tokenization_gpt2.bytes_to_unicode().items()} def convert_token_to_string(token: str) -> str: - string = tokenizer.convert_tokens_to_string([token]) # A hack to handle missing spaces to HF's Llama tokenizers - if (type(token) is str - and token.startswith(file_utils.SPIECE_UNDERLINE) - or token == "<0x20>"): + if ( + type(token) is str + and token.startswith(file_utils.SPIECE_UNDERLINE) + or token == "<0x20>" + ): return " " + string return string @@ -226,8 +229,7 @@ def _reduced_vocabulary( # by this point. token_bytes = bytes(token_str) # type: ignore[arg-type] - elif "\ufffd" in token_str and not re_replacement_seq.match( - token_str): + elif "\ufffd" in token_str and not re_replacement_seq.match(token_str): # Handle tokens with invalid UTF-8 sequences. if re_llama_byte_token.match(token): # Llama-like tokenizers use <0xXX> for incomplete sequences. 
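For context on the byte-fallback branch above: Llama-style tokens such as `<0x20>` encode a single raw byte in hex, which is what the decode path has to recover. An illustrative helper, with an assumed regex rather than vLLM's exact `re_llama_byte_token`:

```python
# Hypothetical sketch of "<0xXX>" handling; the pattern below is an assumption
# for illustration, not the library's actual compiled regex.
import re

BYTE_TOKEN = re.compile(r"^<0x([0-9A-Fa-f]{2})>$")

def byte_fallback_to_bytes(token: str) -> bytes:
    match = BYTE_TOKEN.match(token)
    if match:
        return bytes([int(match.group(1), 16)])  # "<0x20>" -> b" "
    return token.encode("utf-8")

assert byte_fallback_to_bytes("<0x20>") == b" "
```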
@@ -238,12 +240,13 @@ def _reduced_vocabulary( if None in byte_vals: raise RuntimeError( f"Cannot convert token `{token}`" - f" ({token_idx}) to bytes: {token_str}") + f" ({token_idx}) to bytes: {token_str}" + ) # safe to ignore, since if None in byte_vals, # an error is thrown. token_bytes = bytes(byte_vals) # type: ignore[arg-type] else: - token_bytes = token_str.encode('utf-8') + token_bytes = token_str.encode("utf-8") if token_idx != eos_token_id: vocabulary.setdefault(token_bytes, []).append(token_idx) @@ -254,16 +257,18 @@ def _reduced_vocabulary( def get_outlines_vocabulary(tokenizer: AnyTokenizer) -> oc.Vocabulary: - """Get the `Vocabulary` object for a given tokenizer. - """ + """Get the `Vocabulary` object for a given tokenizer.""" if hasattr(tokenizer, "_outlines_vocabulary"): return tokenizer._outlines_vocabulary # type: ignore try: - if hasattr( + if ( + hasattr( tokenizer, "eos_token_id", - ) and tokenizer.eos_token_id is not None: + ) + and tokenizer.eos_token_id is not None + ): eos_token_id = tokenizer.eos_token_id else: raise ValueError( @@ -272,17 +277,18 @@ def get_outlines_vocabulary(tokenizer: AnyTokenizer) -> oc.Vocabulary: reduced_vocab = _reduced_vocabulary( tokenizer, - eos_token_id #type: ignore + eos_token_id, # type: ignore ) - vocabulary = OutlinesVocabulary( - oc.Vocabulary(eos_token_id, reduced_vocab)) + vocabulary = OutlinesVocabulary(oc.Vocabulary(eos_token_id, reduced_vocab)) tokenizer._outlines_vocabulary = vocabulary # type: ignore return vocabulary except AttributeError as e: - raise ValueError(f"Cannot get the vocabulary of the tokenizer " - f"({type(tokenizer)}). The tokenizer should have a " - "get_vocab method.") from e + raise ValueError( + f"Cannot get the vocabulary of the tokenizer " + f"({type(tokenizer)}). The tokenizer should have a " + "get_vocab method." 
+ ) from e def grammar_is_likely_lark(grammar_str: str) -> bool: @@ -304,14 +310,14 @@ def grammar_is_likely_lark(grammar_str: str) -> bool: if not grammar_str or not isinstance(grammar_str, str): return False - for line in grammar_str.split('\n'): + for line in grammar_str.split("\n"): # Remove both comment styles - line = re.sub(r'(#|//).*$', '', line).strip() + line = re.sub(r"(#|//).*$", "", line).strip() if not line: continue # Look for EBNF rule definition - if '::=' in line: + if "::=" in line: return False return True @@ -348,40 +354,41 @@ def convert_lark_to_ebnf(grammar_str: str) -> str: def clean_line(line: str) -> str: """Remove comments and whitespace from line.""" - return re.sub(r'(#|//).*$', '', line).strip() + return re.sub(r"(#|//).*$", "", line).strip() def check_quotes(text: str, rule_name: str, line_num: int) -> None: """Validate quote matching in text.""" if text.count("'") % 2 != 0 or text.count('"') % 2 != 0: - raise ValueError( - f"Mismatched quotes in {rule_name} on line {line_num}") + raise ValueError(f"Mismatched quotes in {rule_name} on line {line_num}") def extract_references(text: str) -> set[str]: """Extract rule references from text.""" # Remove quoted strings and special characters - text = re.sub(r'"[^"]*"', '', text) - text = re.sub(r'[+*?()|\[\]{}]', ' ', text) - return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text)) + text = re.sub(r'"[^"]*"', "", text) + text = re.sub(r"[+*?()|\[\]{}]", " ", text) + return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", text)) # First pass: Find root rule and validate rule definitions - lines = [clean_line(line) for line in grammar_str.split('\n')] + lines = [clean_line(line) for line in grammar_str.split("\n")] first_rule = None for line_num, line in enumerate(lines, 1): - if not line or line.startswith('|'): + if not line or line.startswith("|"): continue - if ':' in line: + if ":" in line: try: - name = line.split(':', 1)[0].strip().strip('?') + name = line.split(":", 1)[0].strip().strip("?") defined_rules.add(name) if first_rule is None: first_rule = name - if name == 'start': - first_rule = 'start' + if name == "start": + first_rule = "start" except IndexError as e: - raise ValueError(f"Invalid rule format on line {line_num}. " - "Expected 'rule_name: definition'") from e + raise ValueError( + f"Invalid rule format on line {line_num}. 
" + "Expected 'rule_name: definition'" + ) from e if not defined_rules: raise ValueError("No valid rules found in grammar") @@ -398,29 +405,33 @@ def convert_lark_to_ebnf(grammar_str: str) -> str: continue try: - if ':' in line and not line.startswith('|'): + if ":" in line and not line.startswith("|"): # Save previous rule if exists if current_rule: output_lines.append( - f"{current_rule} ::= {' | '.join(current_definition)}") + f"{current_rule} ::= {' | '.join(current_definition)}" + ) # Process new rule - name, definition = line.split(':', 1) - current_rule = name.strip().strip('?') + name, definition = line.split(":", 1) + current_rule = name.strip().strip("?") check_quotes(definition, f"rule '{current_rule}'", line_num) definition = re.sub(r"'([^']*)'", r'"\1"', definition) referenced_rules.update(extract_references(definition)) current_definition = [definition.strip()] - elif line.startswith('|'): + elif line.startswith("|"): if not current_rule: - raise ValueError(f"Alternative '|' on line {line_num} " - "without a preceding rule definition") + raise ValueError( + f"Alternative '|' on line {line_num} " + "without a preceding rule definition" + ) alt_def = line[1:].strip() - check_quotes(alt_def, f"alternative for rule '{current_rule}'", - line_num) + check_quotes( + alt_def, f"alternative for rule '{current_rule}'", line_num + ) alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def) referenced_rules.update(extract_references(alt_def)) current_definition.append(alt_def) @@ -430,25 +441,24 @@ def convert_lark_to_ebnf(grammar_str: str) -> str: # Add final rule if exists if current_rule: - output_lines.append( - f"{current_rule} ::= {' | '.join(current_definition)}") + output_lines.append(f"{current_rule} ::= {' | '.join(current_definition)}") # Validate all rules are defined - undefined_rules = referenced_rules - defined_rules - {'root'} + undefined_rules = referenced_rules - defined_rules - {"root"} if undefined_rules: - raise ValueError("Referenced rules are not defined: " - f"{', '.join(sorted(undefined_rules))}") + raise ValueError( + f"Referenced rules are not defined: {', '.join(sorted(undefined_rules))}" + ) - return '\n'.join(output_lines) + return "\n".join(output_lines) def choice_as_grammar(choice: list[str]) -> str: - def escape_ebnf_string(s: str) -> str: """Escape special characters in a EBNF string.""" # Escape double quotes and backslashes - return re.sub(r'(["\\])', r'\\\1', s) + return re.sub(r'(["\\])', r"\\\1", s) escaped_choices = (escape_ebnf_string(c) for c in choice) - grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices)) + grammar = "root ::= " + " | ".join(f'"{c}"' for c in escaped_choices) return grammar diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index ee0c1168f3..c96f221228 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -9,25 +9,35 @@ from collections.abc import Sequence from contextlib import AbstractContextManager from multiprocessing import connection from multiprocessing.process import BaseProcess -from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, - Union, overload) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Optional, + TypeVar, + Union, + overload, +) import torch from torch.autograd.profiler import record_function import vllm.envs as envs from vllm.logger import init_logger -from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, - usage_message) -from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri, - kill_process_tree) +from 
vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message +from vllm.utils import ( + get_open_port, + get_open_zmq_ipc_path, + get_tcp_uri, + kill_process_tree, +) if TYPE_CHECKING: import numpy as np from vllm.v1.engine.coordinator import DPCoordinator - from vllm.v1.engine.utils import (CoreEngineActorManager, - CoreEngineProcManager) + from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager logger = init_logger(__name__) @@ -35,7 +45,6 @@ T = TypeVar("T") class ConstantList(Generic[T], Sequence): - def __init__(self, x: list[T]) -> None: self._x = x @@ -57,31 +66,23 @@ class ConstantList(Generic[T], Sequence): def clear(self): raise TypeError("Cannot clear a constant list") - def index(self, - item: T, - start: int = 0, - stop: Optional[int] = None) -> int: - return self._x.index(item, start, - stop if stop is not None else len(self._x)) + def index(self, item: T, start: int = 0, stop: Optional[int] = None) -> int: + return self._x.index(item, start, stop if stop is not None else len(self._x)) @overload - def __getitem__(self, item: int) -> T: - ... + def __getitem__(self, item: int) -> T: ... @overload - def __getitem__(self, s: slice, /) -> list[T]: - ... + def __getitem__(self, s: slice, /) -> list[T]: ... def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]: return self._x[item] @overload - def __setitem__(self, item: int, value: T): - ... + def __setitem__(self, item: int, value: T): ... @overload - def __setitem__(self, s: slice, value: T, /): - ... + def __setitem__(self, s: slice, value: T, /): ... def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]): raise TypeError("Cannot set item in a constant list") @@ -113,10 +114,7 @@ class CpuGpuBuffer: pin_memory: bool, with_numpy: bool = True, ) -> None: - self.cpu = torch.zeros(*size, - dtype=dtype, - device="cpu", - pin_memory=pin_memory) + self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=pin_memory) self.gpu = torch.zeros_like(self.cpu, device=device) self.np: np.ndarray # To keep type hints simple (avoiding generics and subclasses), we @@ -126,7 +124,8 @@ class CpuGpuBuffer: if dtype == torch.bfloat16: raise ValueError( "Bfloat16 torch tensors cannot be directly cast to a " - "numpy array, so call CpuGpuBuffer with with_numpy=False") + "numpy array, so call CpuGpuBuffer with with_numpy=False" + ) self.np = self.cpu.numpy() def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor: @@ -142,9 +141,7 @@ class CpuGpuBuffer: return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True) -def get_engine_client_zmq_addr(local_only: bool, - host: str, - port: int = 0) -> str: +def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str: """Assign a new ZMQ socket address. 
If local_only is True, participants are colocated and so a unique IPC @@ -153,8 +150,11 @@ def get_engine_client_zmq_addr(local_only: bool, Otherwise, the provided host and port will be used to construct a TCP address (port == 0 means assign an available port).""" - return get_open_zmq_ipc_path() if local_only else (get_tcp_uri( - host, port or get_open_port())) + return ( + get_open_zmq_ipc_path() + if local_only + else (get_tcp_uri(host, port or get_open_port())) + ) class APIServerProcessManager: @@ -195,21 +195,23 @@ class APIServerProcessManager: spawn_context = multiprocessing.get_context("spawn") self.processes: list[BaseProcess] = [] - for i, in_addr, out_addr in zip(range(num_servers), input_addresses, - output_addresses): + for i, in_addr, out_addr in zip( + range(num_servers), input_addresses, output_addresses + ): client_config = { "input_address": in_addr, "output_address": out_addr, "client_count": num_servers, - "client_index": i + "client_index": i, } if stats_update_address is not None: client_config["stats_update_address"] = stats_update_address - proc = spawn_context.Process(target=target_server_fn, - name=f"ApiServer_{i}", - args=(listen_address, sock, args, - client_config)) + proc = spawn_context.Process( + target=target_server_fn, + name=f"ApiServer_{i}", + args=(listen_address, sock, args, client_config), + ) self.processes.append(proc) proc.start() @@ -224,10 +226,12 @@ class APIServerProcessManager: def wait_for_completion_or_failure( - api_server_manager: APIServerProcessManager, - engine_manager: Optional[Union["CoreEngineProcManager", - "CoreEngineActorManager"]] = None, - coordinator: Optional["DPCoordinator"] = None) -> None: + api_server_manager: APIServerProcessManager, + engine_manager: Optional[ + Union["CoreEngineProcManager", "CoreEngineActorManager"] + ] = None, + coordinator: Optional["DPCoordinator"] = None, +) -> None: """Wait for all processes to complete or detect if any fail. Raises an exception if any process exits with a non-zero status. @@ -240,16 +244,14 @@ def wait_for_completion_or_failure( coordinator: The coordinator for data parallel. 
""" - from vllm.v1.engine.utils import (CoreEngineActorManager, - CoreEngineProcManager) + from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager try: logger.info("Waiting for API servers to complete ...") # Create a mapping of sentinels to their corresponding processes # for efficient lookup sentinel_to_proc: dict[Any, BaseProcess] = { - proc.sentinel: proc - for proc in api_server_manager.processes + proc.sentinel: proc for proc in api_server_manager.processes } if coordinator: @@ -265,8 +267,7 @@ def wait_for_completion_or_failure( # Check if any process terminates while sentinel_to_proc or actor_run_refs: # Wait for any process to terminate - ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, - timeout=5) + ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, timeout=5) # Process any terminated processes for sentinel in ready_sentinels: @@ -276,17 +277,18 @@ def wait_for_completion_or_failure( if proc.exitcode != 0: raise RuntimeError( f"Process {proc.name} (PID: {proc.pid}) " - f"died with exit code {proc.exitcode}") + f"died with exit code {proc.exitcode}" + ) if actor_run_refs: import ray + _, actor_run_refs = ray.wait(actor_run_refs, timeout=5) except KeyboardInterrupt: logger.info("Received KeyboardInterrupt, shutting down API servers...") except Exception as e: - logger.exception("Exception occurred while running API servers: %s", - str(e)) + logger.exception("Exception occurred while running API servers: %s", str(e)) raise finally: logger.info("Terminating remaining processes ...") @@ -319,8 +321,9 @@ def shutdown(procs: list[BaseProcess]): kill_process_tree(pid) -def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, - length: int) -> torch.Tensor: +def copy_slice( + from_tensor: torch.Tensor, to_tensor: torch.Tensor, length: int +) -> torch.Tensor: """ Copy the first length elements of a tensor into another tensor in a non-blocking manner. 
@@ -333,8 +336,8 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, def report_usage_stats( - vllm_config, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None: + vllm_config, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT +) -> None: """Report usage statistics if enabled.""" if not is_usage_stats_enabled(): @@ -347,32 +350,21 @@ def report_usage_stats( usage_context, extra_kvs={ # Common configuration - "dtype": - str(vllm_config.model_config.dtype), - "tensor_parallel_size": - vllm_config.parallel_config.tensor_parallel_size, - "block_size": - vllm_config.cache_config.block_size, - "gpu_memory_utilization": - vllm_config.cache_config.gpu_memory_utilization, - "kv_cache_memory_bytes": - vllm_config.cache_config.kv_cache_memory_bytes, + "dtype": str(vllm_config.model_config.dtype), + "tensor_parallel_size": vllm_config.parallel_config.tensor_parallel_size, + "block_size": vllm_config.cache_config.block_size, + "gpu_memory_utilization": vllm_config.cache_config.gpu_memory_utilization, + "kv_cache_memory_bytes": vllm_config.cache_config.kv_cache_memory_bytes, # Quantization - "quantization": - vllm_config.model_config.quantization, - "kv_cache_dtype": - str(vllm_config.cache_config.cache_dtype), - + "quantization": vllm_config.model_config.quantization, + "kv_cache_dtype": str(vllm_config.cache_config.cache_dtype), # Feature flags - "enable_lora": - bool(vllm_config.lora_config), - "enable_prefix_caching": - vllm_config.cache_config.enable_prefix_caching, - "enforce_eager": - vllm_config.model_config.enforce_eager, - "disable_custom_all_reduce": - vllm_config.parallel_config.disable_custom_all_reduce, - }) + "enable_lora": bool(vllm_config.lora_config), + "enable_prefix_caching": vllm_config.cache_config.enable_prefix_caching, + "enforce_eager": vllm_config.model_config.enforce_eager, + "disable_custom_all_reduce": vllm_config.parallel_config.disable_custom_all_reduce, + }, + ) _PROFILER_FUNC = None @@ -390,6 +382,7 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager: func = record_function elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING: import nvtx + func = nvtx.annotate _PROFILER_FUNC = func diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 82b6d1b514..4d3688453c 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -14,7 +14,6 @@ logger = init_logger(__name__) class BlockTable: - def __init__( self, block_size: int, @@ -31,13 +30,14 @@ class BlockTable: self.pin_memory = pin_memory self.device = device - self.block_table = self._make_buffer(max_num_reqs, - max_num_blocks_per_req, - dtype=torch.int32) + self.block_table = self._make_buffer( + max_num_reqs, max_num_blocks_per_req, dtype=torch.int32 + ) self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) - self.slot_mapping = self._make_buffer(self.max_num_batched_tokens, - dtype=torch.int64) + self.slot_mapping = self._make_buffer( + self.max_num_batched_tokens, dtype=torch.int64 + ) try: self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group @@ -56,7 +56,7 @@ class BlockTable: num_blocks = len(block_ids) start = self.num_blocks_per_row[row_idx] self.num_blocks_per_row[row_idx] += num_blocks - self.block_table.np[row_idx, start:start + num_blocks] = block_ids + self.block_table.np[row_idx, start : start + num_blocks] = block_ids def add_row(self, block_ids: list[int], row_idx: int) -> None: self.num_blocks_per_row[row_idx] = 0 @@ -73,8 +73,9 @@ class BlockTable: 
self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src] self.block_table.np[src_tgt] = self.block_table.np[tgt_src] - def compute_slot_mapping(self, req_indices: np.ndarray, - positions: np.ndarray) -> None: + def compute_slot_mapping( + self, req_indices: np.ndarray, positions: np.ndarray + ) -> None: # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] # where K is the max_num_blocks_per_req and the block size is 2. @@ -89,8 +90,10 @@ class BlockTable: # Use a "virtual block" which equals to world_size * block_size # for block_table_indices calculation. virtual_block_size = self.block_size * self.dcp_world_size - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions // virtual_block_size) + block_table_indices = ( + req_indices * self.max_num_blocks_per_req + + positions // virtual_block_size + ) block_numbers = self.block_table.np.ravel()[block_table_indices] # Use virtual_block_size for mask calculation, which marks local # tokens. @@ -101,16 +104,20 @@ class BlockTable: # Calculate slot_mapping slot_mapping = block_numbers * self.block_size + block_offsets # Write final slots, use -1 for not-local - self.slot_mapping.np[:req_indices.shape[0]] = np.where( - mask, slot_mapping, -1) + self.slot_mapping.np[: req_indices.shape[0]] = np.where( + mask, slot_mapping, -1 + ) else: - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions // self.block_size) + block_table_indices = ( + req_indices * self.max_num_blocks_per_req + positions // self.block_size + ) block_numbers = self.block_table.np.ravel()[block_table_indices] block_offsets = positions % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping.np[:req_indices.shape[0]]) + np.add( + block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping.np[: req_indices.shape[0]], + ) def commit_block_table(self, num_reqs: int) -> None: self.block_table.copy_to_gpu(num_reqs) @@ -134,25 +141,27 @@ class BlockTable: """Returns the numpy array of the block table.""" return self.block_table.np - def _make_buffer(self, *size: Union[int, torch.SymInt], - dtype: torch.dtype) -> CpuGpuBuffer: - return CpuGpuBuffer(*size, - dtype=dtype, - device=self.device, - pin_memory=self.pin_memory) + def _make_buffer( + self, *size: Union[int, torch.SymInt], dtype: torch.dtype + ) -> CpuGpuBuffer: + return CpuGpuBuffer( + *size, dtype=dtype, device=self.device, pin_memory=self.pin_memory + ) class MultiGroupBlockTable: """The BlockTables for each KV cache group.""" - def __init__(self, - max_num_reqs: int, - max_model_len: int, - max_num_batched_tokens: int, - pin_memory: bool, - device: torch.device, - block_sizes: list[int], - num_speculative_tokens: int = 0) -> None: + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + block_sizes: list[int], + num_speculative_tokens: int = 0, + ) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, # so the block_size which used for calc max_num_blocks_per_req @@ -165,14 +174,20 @@ class MultiGroupBlockTable: self.block_tables = [ BlockTable( - block_size, max_num_reqs, - max(cdiv(max_model_len, block_size * dcp_world_size), - 1 + num_speculative_tokens), max_num_batched_tokens, - pin_memory, device) for block_size in block_sizes + block_size, + max_num_reqs, + max( + cdiv(max_model_len, block_size * dcp_world_size), + 1 + 
num_speculative_tokens, + ), + max_num_batched_tokens, + pin_memory, + device, + ) + for block_size in block_sizes ] - def append_row(self, block_ids: tuple[list[int], ...], - row_idx: int) -> None: + def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: for i, block_table in enumerate(self.block_tables): block_table.append_row(block_ids[i], row_idx) @@ -188,8 +203,9 @@ class MultiGroupBlockTable: for block_table in self.block_tables: block_table.swap_row(src, tgt) - def compute_slot_mapping(self, req_indices: np.ndarray, - positions: np.ndarray) -> None: + def compute_slot_mapping( + self, req_indices: np.ndarray, positions: np.ndarray + ) -> None: for block_table in self.block_tables: block_table.compute_slot_mapping(req_indices, positions) diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 964e4c6b23..f48b354e8a 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -19,7 +19,6 @@ logger = init_logger(__name__) class CPUModelRunner(GPUModelRunner): - def __init__(self, vllm_config: VllmConfig, device: torch.device): with _torch_cuda_wrapper(): super().__init__(vllm_config, device) @@ -35,14 +34,15 @@ class CPUModelRunner(GPUModelRunner): # Note: Remove the override after new attention backend finished def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: if len(self.kv_cache_config.kv_cache_groups) > 1: - raise ValueError("Multiple KVCacheGroups is not" - "currently supported with CPU model runner.") + raise ValueError( + "Multiple KVCacheGroups is not" + "currently supported with CPU model runner." + ) super()._may_reorder_batch(scheduler_output) def _postprocess_tensors(self) -> None: # Note: replace device tensors with cpu tensors - def replace_tensor(obj: Any, cpu_attr_name: str, - device_attr_name) -> None: + def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None: cpu_tensor = getattr(obj, cpu_attr_name, None) device_tensor = getattr(obj, device_attr_name, None) if cpu_tensor is not None and device_tensor is not None: @@ -68,8 +68,7 @@ class CPUModelRunner(GPUModelRunner): self.model = get_model(vllm_config=self.vllm_config) if self.lora_config: - self.model = self.load_lora_model(self.model, self.vllm_config, - self.device) + self.model = self.load_lora_model(self.model, self.vllm_config, self.device) def get_model(self) -> nn.Module: return self.model @@ -90,23 +89,19 @@ class CPUModelRunner(GPUModelRunner): def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: return sampled_token_ids.tolist() - def get_dp_padding(self, - num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: # Note: For CPU backend, dp padding is not required for now. 
return 0, None @contextmanager def _torch_cuda_wrapper(): - class _EventPlaceholder: - def __init__(self, *args, **kwargs) -> None: self.record = lambda: None self.synchronize = lambda: None class _StreamPlaceholder: - def __init__(self, *args, **kwargs) -> None: pass diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index c6a686d6b7..ee865ec8e6 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -13,25 +13,27 @@ from vllm.model_executor.utils import set_random_seed from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo from vllm.v1.worker.cpu_model_runner import CPUModelRunner -from vllm.v1.worker.gpu_worker import (Worker, - init_worker_distributed_environment) +from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment logger = init_logger(__name__) class CPUWorker(Worker): - - def __init__(self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False): - super().__init__(vllm_config, - local_rank, - rank, - distributed_init_method, - is_driver_worker=is_driver_worker) + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + super().__init__( + vllm_config, + local_rank, + rank, + distributed_init_method, + is_driver_worker=is_driver_worker, + ) self.parallel_config.disable_custom_all_reduce = True @@ -43,11 +45,13 @@ class CPUWorker(Worker): if cpu_arch in (CpuArchEnum.POWERPC, CpuArchEnum.S390X): # For S390X/POWERPC SMT-8/4/2 self.local_omp_cpuid = self._get_autobind_cpu_ids( - lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]) + lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4] + ) elif current_platform.get_cpu_architecture() == CpuArchEnum.X86: # For x86 SMT-2, use 1 CPU per core self.local_omp_cpuid = self._get_autobind_cpu_ids( - lambda cpus: cpus[-1:]) + lambda cpus: cpus[-1:] + ) else: self.local_omp_cpuid = "all" else: @@ -55,9 +59,9 @@ class CPUWorker(Worker): omp_cpuids = omp_cpuids.split("|") if local_dp_rank is not None: world_size = self.parallel_config.world_size - omp_cpuids = omp_cpuids[local_dp_rank * - world_size:(local_dp_rank + 1) * - world_size] + omp_cpuids = omp_cpuids[ + local_dp_rank * world_size : (local_dp_rank + 1) * world_size + ] self.local_omp_cpuid = omp_cpuids[self.rank] if self.local_omp_cpuid != "all": @@ -66,19 +70,22 @@ class CPUWorker(Worker): logger.info(ret) # Note: unique identifier for creating allreduce shared memory - os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split( - ":")[-1] + os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(":")[-1] # Initialize the distributed environment. - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank, - current_platform.dist_backend) + init_worker_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + current_platform.dist_backend, + ) # Set random seed. 
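The `cpu_selector` callables passed above (for example `lambda cpus: cpus[-1:]` on x86 SMT-2) all follow the same convention: given the logical CPUs of one physical core, return the subset to bind. An illustrative selector with a stand-in data type, not vLLM's `LogicalCPUInfo`:

```python
# Stand-in sketch of the cpu_selector convention used above.
from dataclasses import dataclass

@dataclass
class LogicalCPU:            # simplified stand-in for LogicalCPUInfo
    id: int
    physical_core: int

def pick_one_sibling(cpus: list[LogicalCPU]) -> list[LogicalCPU]:
    # Same rule as the x86 SMT-2 branch: keep one hyperthread per core.
    return cpus[-1:]

core0 = [LogicalCPU(0, 0), LogicalCPU(64, 0)]   # two hyperthreads of core 0
assert [c.id for c in pick_one_sibling(core0)] == [64]
```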
set_random_seed(self.model_config.seed) # Construct the model runner self.model_runner: CPUModelRunner = CPUModelRunner( - self.vllm_config, torch.device("cpu")) + self.vllm_config, torch.device("cpu") + ) def sleep(self, level: int = 1) -> None: logger.warning("sleep mode is not supported on CPU, ignore it.") @@ -98,31 +105,31 @@ class CPUWorker(Worker): self.model_runner.warming_up_model() def _get_autobind_cpu_ids( - self, cpu_selector: Callable[[list[LogicalCPUInfo]], - list[LogicalCPUInfo]] + self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]] ) -> str: """ - Return CPU ids to bind based on NUMA nodes. - Currently for rank N, only CPU ids on the N-th node in available NUMA + Return CPU ids to bind based on NUMA nodes. + Currently for rank N, only CPU ids on the N-th node in available NUMA node list will be selected. Args: - cpu_selector: a callable object to select CPUs from a CPU list + cpu_selector: a callable object to select CPUs from a CPU list of a physical core. The input is a LogicalCPUInfo list, sorted by - the LogicalCPUInfo.id. A selected LogicalCPUInfo list should be + the LogicalCPUInfo.id. A selected LogicalCPUInfo list should be returned. """ - allowed_numa_nodes, logical_cpu_list = \ + allowed_numa_nodes, logical_cpu_list = ( CpuPlatform.get_allowed_cpu_core_node_list() + ) assert len(allowed_numa_nodes) >= self.parallel_config.world_size, ( f"No enough allowed NUMA nodes to bind threads of " f"{self.parallel_config.world_size} CPUWorkers. " f"Allowed NUMA nodes are {allowed_numa_nodes}. " - "Please try to bind threads manually.") + "Please try to bind threads manually." + ) # Get CPUs on NUMA node `allowed_numa_nodes[local_rank]`` - selected_numa_node = allowed_numa_nodes[ - self.local_rank] # type: ignore + selected_numa_node = allowed_numa_nodes[self.local_rank] # type: ignore logical_cpu_list = [ x for x in logical_cpu_list if x.numa_node == selected_numa_node ] @@ -142,15 +149,20 @@ class CPUWorker(Worker): # Reserve CPUs for other processes reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU if reserve_cpu_num is None: - need_reserve = (self.parallel_config.world_size > 1 or - self.parallel_config.data_parallel_size_local > 1) + need_reserve = ( + self.parallel_config.world_size > 1 + or self.parallel_config.data_parallel_size_local > 1 + ) reserve_cpu_num = 1 if need_reserve else 0 assert len(logical_cpu_list) > reserve_cpu_num, ( f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) " - f"should less than {len(logical_cpu_list)}.") + f"should less than {len(logical_cpu_list)}." 
+ ) if reserve_cpu_num != 0: logical_cpu_list = logical_cpu_list[:-reserve_cpu_num] - logger.info("auto thread-binding list (id, physical core): %s", - [(x.id, x.physical_core) for x in logical_cpu_list]) + logger.info( + "auto thread-binding list (id, physical core): %s", + [(x.id, x.physical_core) for x in logical_cpu_list], + ) return ",".join([str(x.id) for x in logical_cpu_list]) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 7848283a90..06f9354236 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -15,9 +15,11 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values from vllm.v1.outputs import LogprobsTensors from vllm.v1.pool.metadata import PoolingMetadata -from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, - LogitsProcessors, - MoveDirectionality) +from vllm.v1.sample.logits_processor import ( + BatchUpdateBuilder, + LogitsProcessors, + MoveDirectionality, +) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import is_spec_decode_unsupported from vllm.v1.utils import copy_slice @@ -26,7 +28,6 @@ from vllm.v1.worker.block_table import MultiGroupBlockTable @dataclass class CachedRequestState: - req_id: str prompt_token_ids: Optional[list[int]] mm_features: list[MultiModalFeatureSpec] @@ -46,7 +47,8 @@ class CachedRequestState: def __post_init__(self): self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - self.prompt_token_ids, self.prompt_embeds) + self.prompt_token_ids, self.prompt_embeds + ) @property def num_tokens(self) -> int: @@ -57,7 +59,8 @@ class CachedRequestState: if self.prompt_token_ids is None: raise ValueError( f"Tried to access token index {idx}, but that token was " - "provided via prompt_embeds, and its ID is unknown.") + "provided via prompt_embeds, and its ID is unknown." + ) return self.prompt_token_ids[idx] elif idx - self.num_prompt_tokens < len(self.output_token_ids): return self.output_token_ids[idx - self.num_prompt_tokens] @@ -66,7 +69,6 @@ class CachedRequestState: class InputBatch: - def __init__( self, max_num_reqs: int, @@ -104,10 +106,9 @@ class InputBatch: pin_memory=False, ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() - self.is_token_ids = torch.zeros((max_num_reqs, max_model_len), - device="cpu", - dtype=bool, - pin_memory=False) + self.is_token_ids = torch.zeros( + (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False + ) # Store prompt embeddings per request to avoid OOM from large upfront # allocation if max_model_len is big. # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size) @@ -116,13 +117,12 @@ class InputBatch: self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu_tensor = torch.zeros( - (max_num_reqs, ), + (max_num_reqs,), device="cpu", dtype=torch.int32, pin_memory=pin_memory, ) - self.num_computed_tokens_cpu = \ - self.num_computed_tokens_cpu_tensor.numpy() + self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy() # Block table. self.block_table = MultiGroupBlockTable( @@ -136,34 +136,27 @@ class InputBatch: ) # Sampling-related. 
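The sampling state that follows uses one recurring pattern: a pinned CPU staging tensor paired with a GPU tensor, filled via numpy-friendly writes on the CPU side and copied asynchronously to the device. A reduced sketch of that pattern:

```python
# Sketch of the paired pinned-CPU / GPU buffer pattern used below.
import torch

max_num_reqs = 8
device = "cuda" if torch.cuda.is_available() else "cpu"
pin = torch.cuda.is_available()  # pinning only helps (and works) with CUDA

temperature_cpu = torch.empty(
    (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin
)
temperature_gpu = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)

temperature_cpu[:4] = torch.tensor([1.0, 0.7, 0.9, 1.2])
# Async copy of only the live prefix of the batch.
temperature_gpu[:4].copy_(temperature_cpu[:4], non_blocking=True)
```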
- self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.temperature = torch.empty( + (max_num_reqs,), dtype=torch.float32, device=device + ) + self.temperature_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.temperature_cpu = self.temperature_cpu_tensor.numpy() self.greedy_reqs: set[str] = set() self.random_reqs: set[str] = set() - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device) + self.top_p_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.top_p_cpu = self.top_p_cpu_tensor.numpy() self.top_p_reqs: set[str] = set() - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) + self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device) + self.top_k_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: set[str] = set() @@ -171,54 +164,43 @@ class InputBatch: self.spec_decode_unsupported_reqs: set[str] = set() # Frequency penalty related data structures - self.frequency_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) + self.frequency_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device + ) self.frequency_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.frequency_penalties_cpu = \ - self.frequency_penalties_cpu_tensor.numpy() + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy() self.frequency_penalties_reqs: set[str] = set() # Presence penalty related data structures - self.presence_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) - self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy( + self.presence_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device ) + self.presence_penalties_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy() self.presence_penalties_reqs: set[str] = set() # Repetition penalty related data structures - self.repetition_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) + self.repetition_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device + ) self.repetition_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.repetition_penalties_cpu = \ - self.repetition_penalties_cpu_tensor.numpy() + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.repetition_penalties_cpu = 
self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: set[str] = set() # Speculative decoding - self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ), - dtype=torch.int64, - device="cpu", - pin_memory=pin_memory) - self.num_accepted_tokens_cpu = \ - self.num_accepted_tokens_cpu_tensor.numpy() + self.num_accepted_tokens_cpu_tensor = torch.ones( + (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory + ) + self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy() # lora related - self.request_lora_mapping = np.zeros((self.max_num_reqs, ), - dtype=np.int32) + self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int32) self.lora_id_to_request_ids: dict[int, set[str]] = {} self.lora_id_to_lora_request: dict[int, LoRARequest] = {} @@ -250,8 +232,7 @@ class InputBatch: # req_index -> bad_words_token_ids self.bad_words_token_ids: dict[int, list[list[int]]] = {} - self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, - dtype=bool) + self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool) self.req_output_token_ids: list[Optional[list[int]]] = [] @@ -291,8 +272,13 @@ class InputBatch: # Detailed added request metadata is only required for non-pooling # models, to support logitsprocs. self.batch_update_builder.added.append( - (new_req_index, request.sampling_params, - request.prompt_token_ids, request.output_token_ids)) + ( + new_req_index, + request.sampling_params, + request.prompt_token_ids, + request.output_token_ids, + ) + ) return new_req_index @@ -314,20 +300,19 @@ class InputBatch: # Copy the prompt token ids and output token ids. num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - request.prompt_token_ids, request.prompt_embeds) + request.prompt_token_ids, request.prompt_embeds + ) self.num_prompt_tokens[req_index] = num_prompt_tokens start_idx = num_prompt_tokens end_idx = start_idx + len(request.output_token_ids) if request.prompt_token_ids is not None: - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids + self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids self.is_token_ids[req_index, :num_prompt_tokens] = True else: self.is_token_ids[req_index, :num_prompt_tokens] = False if request.prompt_embeds is not None: self.req_prompt_embeds[req_index] = request.prompt_embeds - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids + self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids self.is_token_ids[req_index, start_idx:end_idx] = True # Number of token ids in prompt (token_ids_cpu or prompt_embeds). # NOTE(woosuk): This may include spec decode tokens. @@ -339,8 +324,7 @@ class InputBatch: self.block_table.add_row(request.block_ids, req_index) if sampling_params := request.sampling_params: - if (self.is_spec_decode - and is_spec_decode_unsupported(sampling_params)): + if self.is_spec_decode and is_spec_decode_unsupported(sampling_params): self.spec_decode_unsupported_reqs.add(req_id) if sampling_params.sampling_type == SamplingType.GREEDY: # Should avoid division by zero later when apply_temperature. 
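On the division-by-zero comment above: temperature scaling divides logits row-wise by a per-request temperature, so greedy requests must store a non-zero placeholder (assumed to be a sentinel like -1.0 here) and are resolved by argmax rather than by sampling from the scaled distribution.

```python
# Hedged sketch of why a greedy request needs a non-zero temperature entry.
import torch

logits = torch.randn(2, 8)
temperature = torch.tensor([0.7, -1.0])       # request 1 is greedy (placeholder value)
scaled = logits / temperature.unsqueeze(1)    # safe: no zero divisor in the batch
greedy_token = int(torch.argmax(logits[1]))   # greedy path ignores `scaled`
```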
@@ -359,16 +343,15 @@ class InputBatch: else: top_k = self.vocab_size self.top_k_cpu[req_index] = top_k - self.frequency_penalties_cpu[ - req_index] = sampling_params.frequency_penalty + self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty if sampling_params.frequency_penalty != 0.0: self.frequency_penalties_reqs.add(req_id) - self.presence_penalties_cpu[ - req_index] = sampling_params.presence_penalty + self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty if sampling_params.presence_penalty != 0.0: self.presence_penalties_reqs.add(req_id) - self.repetition_penalties_cpu[ - req_index] = sampling_params.repetition_penalty + self.repetition_penalties_cpu[req_index] = ( + sampling_params.repetition_penalty + ) if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) @@ -378,13 +361,17 @@ class InputBatch: self.generators[req_index] = request.generator if sampling_params.logprobs is not None: - self.num_logprobs[req_id] = (self.vocab_size - if sampling_params.logprobs == -1 - else sampling_params.logprobs) + self.num_logprobs[req_id] = ( + self.vocab_size + if sampling_params.logprobs == -1 + else sampling_params.logprobs + ) if sampling_params.prompt_logprobs is not None: self.num_prompt_logprobs[req_id] = ( - self.vocab_size if sampling_params.prompt_logprobs == -1 - else sampling_params.prompt_logprobs) + self.vocab_size + if sampling_params.prompt_logprobs == -1 + else sampling_params.prompt_logprobs + ) if sampling_params.allowed_token_ids: self.has_allowed_token_ids.add(req_id) @@ -395,24 +382,29 @@ class InputBatch: self.max_num_reqs, self.vocab_size, dtype=torch.bool, - device=self.device) + device=self.device, + ) self.allowed_token_ids_mask_cpu_tensor = torch.zeros( self.max_num_reqs, self.vocab_size, dtype=torch.bool, - device="cpu") + device="cpu", + ) self.allowed_token_ids_mask_cpu_tensor[req_index] = True # False means we don't fill with -inf. 
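On the "False means we don't fill with -inf" comment below: the mask starts all-True for a request and the allowed token ids are flipped to False, so that a later masked fill pushes every disallowed logit to -inf. A small sketch of how such a mask is applied:

```python
# Illustrative application of an allowed-token-ids mask of the kind built below.
import torch

vocab_size = 6
mask = torch.ones(vocab_size, dtype=torch.bool)   # start: everything disallowed
mask[[1, 4]] = False                              # allowed token ids keep their logits

logits = torch.zeros(vocab_size)
logits.masked_fill_(mask, float("-inf"))
# Only indices 1 and 4 remain finite.
```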
self.allowed_token_ids_mask_cpu_tensor[req_index][ - sampling_params.allowed_token_ids] = False + sampling_params.allowed_token_ids + ] = False if sampling_params.bad_words_token_ids: - self.bad_words_token_ids[ - req_index] = sampling_params.bad_words_token_ids + self.bad_words_token_ids[req_index] = ( + sampling_params.bad_words_token_ids + ) elif pooling_params := request.pooling_params: self.pooling_params[req_id] = pooling_params self.logits_processing_needs_token_ids[req_index] = ( - pooling_params.requires_token_ids) + pooling_params.requires_token_ids + ) else: raise NotImplementedError("Unrecognized request type") @@ -489,21 +481,32 @@ class InputBatch: def swap_states(self, i1: int, i2: int) -> None: old_id_i1 = self._req_ids[i1] old_id_i2 = self._req_ids[i2] - self._req_ids[i1], self._req_ids[i2] =\ - self._req_ids[i2], self._req_ids[i1] # noqa - self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\ - self.req_output_token_ids[i2], self.req_output_token_ids[i1] + self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1] # noqa + self.req_output_token_ids[i1], self.req_output_token_ids[i2] = ( + self.req_output_token_ids[i2], + self.req_output_token_ids[i1], + ) assert old_id_i1 is not None and old_id_i2 is not None - self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\ - self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1] - self.num_tokens[i1], self.num_tokens[i2] =\ - self.num_tokens[i2], self.num_tokens[i1] - self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\ - self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1] - self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\ - self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] - self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ - self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] + self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = ( + self.req_id_to_index[old_id_i2], + self.req_id_to_index[old_id_i1], + ) + self.num_tokens[i1], self.num_tokens[i2] = ( + self.num_tokens[i2], + self.num_tokens[i1], + ) + self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = ( + self.num_tokens_no_spec[i2], + self.num_tokens_no_spec[i1], + ) + self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = ( + self.num_prompt_tokens[i2], + self.num_prompt_tokens[i1], + ) + self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = ( + self.num_computed_tokens_cpu[i2], + self.num_computed_tokens_cpu[i1], + ) # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ @@ -530,8 +533,10 @@ class InputBatch: self.block_table.swap_row(i1, i2) - self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \ - self.request_lora_mapping[i2], self.request_lora_mapping[i1] + self.request_lora_mapping[i1], self.request_lora_mapping[i2] = ( + self.request_lora_mapping[i2], + self.request_lora_mapping[i1], + ) if self.is_pooling_model: # Sampling and logits parameters don't apply to pooling models. @@ -539,32 +544,42 @@ class InputBatch: # For autoregressive models, track detailed request reordering info # to support logitsprocs. 
- self.batch_update_builder.moved.append( - (i1, i2, MoveDirectionality.SWAP)) + self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP)) - self.temperature_cpu[i1], self.temperature_cpu[i2] = \ - self.temperature_cpu[i2], self.temperature_cpu[i1] - self.top_p_cpu[i1], self.top_p_cpu[i2] = \ - self.top_p_cpu[i2], self.top_p_cpu[i1] - self.top_k_cpu[i1], self.top_k_cpu[i2] = \ - self.top_k_cpu[i2], self.top_k_cpu[i1] - self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \ - self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] - self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \ - self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] - self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \ - self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] - self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\ - self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1] + self.temperature_cpu[i1], self.temperature_cpu[i2] = ( + self.temperature_cpu[i2], + self.temperature_cpu[i1], + ) + self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1] + self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1] + self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = ( + self.frequency_penalties_cpu[i2], + self.frequency_penalties_cpu[i1], + ) + self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = ( + self.presence_penalties_cpu[i2], + self.presence_penalties_cpu[i1], + ) + self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = ( + self.repetition_penalties_cpu[i2], + self.repetition_penalties_cpu[i1], + ) + self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = ( + self.num_accepted_tokens_cpu[i2], + self.num_accepted_tokens_cpu[i1], + ) swap_dict_values(self.generators, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2) if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[i1], \ - self.allowed_token_ids_mask_cpu_tensor[i2] =\ - self.allowed_token_ids_mask_cpu_tensor[i2], \ - self.allowed_token_ids_mask_cpu_tensor[i1] + ( + self.allowed_token_ids_mask_cpu_tensor[i1], + self.allowed_token_ids_mask_cpu_tensor[i2], + ) = ( + self.allowed_token_ids_mask_cpu_tensor[i2], + self.allowed_token_ids_mask_cpu_tensor[i1], + ) def condense(self) -> None: """Slide non-empty requests down into lower, empty indices. 
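The `swap_states` hunk above keeps the pre-existing note that tuple-swapping whole `token_ids_cpu` rows is unsafe. The reason is that NumPy row indexing returns views into the same buffer, so the first assignment of the unpacking overwrites data the second assignment still reads; scalar fields such as `temperature_cpu[i]` are safe to swap this way because integer indexing of a 1-D array yields a copy. Below is a minimal standalone demonstration of the hazard and the usual temporary-copy workaround (illustrative arrays only, not vLLM code):

```python
import numpy as np

token_ids = np.array([[1, 2, 3],
                      [4, 5, 6]])

# Unsafe: a[1] and a[0] on the right-hand side are views, so after the first
# row assignment the second one reads already-overwritten data.
a = token_ids.copy()
a[0], a[1] = a[1], a[0]
print(a)   # [[4 5 6], [4 5 6]] -- both rows now hold the old row 1

# Safe: materialize one side with an explicit copy before assigning.
b = token_ids.copy()
tmp = b[0].copy()
b[0] = b[1]
b[1] = tmp
print(b)   # [[4 5 6], [1 2 3]] -- rows swapped as intended
```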
@@ -616,23 +631,28 @@ class InputBatch: num_tokens = self.num_tokens[last_req_index] self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ - last_req_index, :num_tokens] + last_req_index, :num_tokens + ] self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[ - last_req_index, :num_tokens] + last_req_index, :num_tokens + ] if last_req_index in self.req_prompt_embeds: - self.req_prompt_embeds[ - empty_index] = self.req_prompt_embeds.pop(last_req_index) + self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop( + last_req_index + ) self.num_tokens[empty_index] = num_tokens self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ - last_req_index] - self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] + last_req_index + ] + self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index] + self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[ + last_req_index + ] self.block_table.move_row(last_req_index, empty_index) self.request_lora_mapping[empty_index] = self.request_lora_mapping[ - last_req_index] + last_req_index + ] if self.is_pooling_model: last_req_index -= 1 @@ -642,33 +662,35 @@ class InputBatch: # Autoregressive models require detailed tracking of condense # operations to support logitsprocs self.batch_update_builder.moved.append( - (last_req_index, empty_index, - MoveDirectionality.UNIDIRECTIONAL)) + (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL) + ) - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - self.frequency_penalties_cpu[ - empty_index] = self.frequency_penalties_cpu[last_req_index] - self.presence_penalties_cpu[ - empty_index] = self.presence_penalties_cpu[last_req_index] - self.repetition_penalties_cpu[ - empty_index] = self.repetition_penalties_cpu[last_req_index] - self.num_accepted_tokens_cpu[ - empty_index] = self.num_accepted_tokens_cpu[last_req_index] + self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[ + last_req_index + ] + self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[ + last_req_index + ] + self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[ + last_req_index + ] + self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[ + last_req_index + ] generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator # TODO convert these to LogitsProcessors if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[ - empty_index] = self.allowed_token_ids_mask_cpu_tensor[ - last_req_index] + self.allowed_token_ids_mask_cpu_tensor[empty_index] = ( + self.allowed_token_ids_mask_cpu_tensor[last_req_index] + ) - bad_words_token_ids = self.bad_words_token_ids.pop( - last_req_index, None) + bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None) if bad_words_token_ids is not None: self.bad_words_token_ids[empty_index] = bad_words_token_ids @@ -700,8 +722,9 @@ class InputBatch: def _make_sampling_metadata(self) -> SamplingMetadata: num_reqs = self.num_reqs if not self.all_greedy: - temperature = copy_slice(self.temperature_cpu_tensor, - self.temperature, 
num_reqs) + temperature = copy_slice( + self.temperature_cpu_tensor, self.temperature, num_reqs + ) else: temperature = None if not self.no_top_p: @@ -713,16 +736,22 @@ class InputBatch: # Since syncing these tensors is expensive only copy them # if necessary i.e. if there are requests which require # penalties to be applied during sampling. - copy_slice(self.frequency_penalties_cpu_tensor, - self.frequency_penalties, num_reqs) - copy_slice(self.presence_penalties_cpu_tensor, - self.presence_penalties, num_reqs) - copy_slice(self.repetition_penalties_cpu_tensor, - self.repetition_penalties, num_reqs) + copy_slice( + self.frequency_penalties_cpu_tensor, self.frequency_penalties, num_reqs + ) + copy_slice( + self.presence_penalties_cpu_tensor, self.presence_penalties, num_reqs + ) + copy_slice( + self.repetition_penalties_cpu_tensor, + self.repetition_penalties, + num_reqs, + ) needs_prompt_token_ids = ( not self.no_penalties - or self.logits_processing_needs_token_ids[:num_reqs].any()) + or self.logits_processing_needs_token_ids[:num_reqs].any() + ) if needs_prompt_token_ids: # The prompt tokens are used only for applying penalties or # step pooling during the sampling/pooling process. @@ -735,8 +764,11 @@ class InputBatch: allowed_token_ids_mask: Optional[torch.Tensor] = None if not self.no_allowed_token_ids: assert self.allowed_token_ids_mask is not None - copy_slice(self.allowed_token_ids_mask_cpu_tensor, - self.allowed_token_ids_mask, num_reqs) + copy_slice( + self.allowed_token_ids_mask_cpu_tensor, + self.allowed_token_ids_mask, + num_reqs, + ) allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs] return SamplingMetadata( @@ -766,8 +798,7 @@ class InputBatch: pooling_params = self.get_pooling_params() return PoolingMetadata( - prompt_lens=torch.from_numpy( - self.num_prompt_tokens[:self.num_reqs]), + prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]), prompt_token_ids=self.sampling_metadata.prompt_token_ids, pooling_params=pooling_params, ) @@ -786,9 +817,8 @@ class InputBatch: # Use the value of vocab_size as a pad since we don't have a # token_id of this value. for i in range(num_reqs): - prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size - return prompt_token_ids_cpu_tensor.to(device=self.device, - non_blocking=True) + prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size + return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) def make_lora_inputs( self, num_scheduled_tokens: np.ndarray @@ -804,12 +834,12 @@ class InputBatch: 3. lora_requests: Set of relevant LoRA requests. 
""" - req_lora_mapping = self.request_lora_mapping[:self.num_reqs] + req_lora_mapping = self.request_lora_mapping[: self.num_reqs] prompt_lora_mapping = tuple(req_lora_mapping) - token_lora_mapping = tuple( - req_lora_mapping.repeat(num_scheduled_tokens)) + token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens)) active_lora_requests: set[LoRARequest] = set( - self.lora_id_to_lora_request.values()) + self.lora_id_to_lora_request.values() + ) return prompt_lora_mapping, token_lora_mapping, active_lora_requests @@ -835,9 +865,11 @@ class InputBatch: @property def no_penalties(self) -> bool: - return (len(self.presence_penalties_reqs) == 0 - and len(self.frequency_penalties_reqs) == 0 - and len(self.repetition_penalties_reqs) == 0) + return ( + len(self.presence_penalties_reqs) == 0 + and len(self.frequency_penalties_reqs) == 0 + and len(self.repetition_penalties_reqs) == 0 + ) @property def max_num_logprobs(self) -> Optional[int]: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 11e24e4d13..90429b6b0c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -24,70 +24,112 @@ from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled -from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, - get_layers_from_vllm_config, update_config) +from vllm.config import ( + CompilationLevel, + CUDAGraphMode, + VllmConfig, + get_layers_from_vllm_config, + update_config, +) from vllm.distributed.eplb.eplb_state import EplbState -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group) +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( - get_pp_group, get_tp_group, graph_capture, is_global_first_rank, - prepare_communication_buffer_for_model) -from vllm.forward_context import (BatchDescriptor, DPMetadata, - set_forward_context) + get_pp_group, + get_tp_group, + graph_capture, + is_global_first_rank, + prepare_communication_buffer_for_model, +) +from vllm.forward_context import BatchDescriptor, DPMetadata, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache + # yapf conflicts with isort for this block # yapf: disable -from vllm.model_executor.models.interfaces import (SupportsMultiModal, - is_mixture_of_experts, - supports_eagle3, - supports_mrope, - supports_multimodal_pruning, - supports_transcription) +from vllm.model_executor.models.interfaces import ( + SupportsMultiModal, + is_mixture_of_experts, + supports_eagle3, + supports_mrope, + supports_multimodal_pruning, + supports_transcription, +) + # yapf: enable from vllm.model_executor.models.interfaces_base import ( - VllmModelForPooling, is_pooling_model, is_text_generation_model) + VllmModelForPooling, + is_pooling_model, + is_text_generation_model, +) from vllm.multimodal import 
MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import ( + BatchedTensorInputs, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, cdiv, check_use_alibi, get_dtype_size, - is_pin_memory_available, - length_from_prompt_token_ids_or_embeds, round_up, - supports_dynamo) +from vllm.utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + DeviceMemoryProfiler, + GiB_bytes, + cdiv, + check_use_alibi, + get_dtype_size, + is_pin_memory_available, + length_from_prompt_token_ids_or_embeds, + round_up, + supports_dynamo, +) from vllm.utils.jsontree import json_map_leaves from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, create_fast_prefill_custom_backend, - reorder_batch_to_split_decodes_and_prefills, split_attn_metadata) + reorder_batch_to_split_decodes_and_prefills, + split_attn_metadata, +) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher + # yapf conflicts with isort for this block # yapf: disable -from vllm.v1.kv_cache_interface import (AttentionSpec, - ChunkedLocalAttentionSpec, - CrossAttentionSpec, - EncoderOnlyAttentionSpec, - FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - MambaSpec, MLAAttentionSpec, - SlidingWindowSpec, - UniformTypeKVCacheSpecs) +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + ChunkedLocalAttentionSpec, + CrossAttentionSpec, + EncoderOnlyAttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + MambaSpec, + MLAAttentionSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) + # yapf: enable -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - DraftTokenIds, LogprobsLists, LogprobsTensors, - ModelRunnerOutput, PoolerOutput, SamplerOutput) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + AsyncModelRunnerOutput, + DraftTokenIds, + LogprobsLists, + LogprobsTensors, + ModelRunnerOutput, + PoolerOutput, + SamplerOutput, +) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata @@ -101,18 +143,21 @@ from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper -from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorModelRunnerMixin) +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds, - ubatch_split) +from vllm.v1.worker.ubatch_splitting import check_ubatch_thresholds, ubatch_split from vllm.v1.worker.ubatch_utils import 
UBatchSlice, UBatchSlices from vllm.v1.worker.utils import is_residual_scattered_for_sp -from .utils import (AttentionGroup, MultiModalBudget, - add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, - gather_mm_placeholders, sanity_check_mm_encoder_outputs, - scatter_mm_placeholders) +from .utils import ( + AttentionGroup, + MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, + bind_kv_cache, + gather_mm_placeholders, + sanity_check_mm_encoder_outputs, + scatter_mm_placeholders, +) if TYPE_CHECKING: from vllm.model_executor.model_loader.tensorizer import TensorizerConfig @@ -122,13 +167,11 @@ logger = init_logger(__name__) AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] # list when ubatching is enabled -PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], - AttnMetadataDict] +PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], AttnMetadataDict] # Wrapper for ModelRunnerOutput to support overlapped execution. class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): - def __init__( self, model_runner_output: ModelRunnerOutput, @@ -151,12 +194,13 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): with torch.cuda.stream(async_output_copy_stream): async_output_copy_stream.wait_stream(default_stream) self._sampled_token_ids_cpu = self._sampled_token_ids.to( - 'cpu', non_blocking=True) + "cpu", non_blocking=True + ) self._async_copy_ready_event.record() def get_output(self) -> ModelRunnerOutput: """Copy the device tensors to the host and return a ModelRunnerOutput. - + This function blocks until the copy is finished. """ self._async_copy_ready_event.synchronize() @@ -174,7 +218,6 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - def __init__( self, vllm_config: VllmConfig, @@ -192,10 +235,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.observability_config = vllm_config.observability_config from vllm.model_executor.models.utils import set_cpu_offload_max_bytes - set_cpu_offload_max_bytes( - int(self.cache_config.cpu_offload_gb * 1024**3)) - from vllm.model_executor.layers.batch_invariant import ( - init_batch_invariance) + + set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3)) + from vllm.model_executor.layers.batch_invariant import init_batch_invariance + init_batch_invariance() model_config = self.model_config @@ -208,13 +251,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if cache_config.cache_dtype == "auto": self.kv_cache_dtype = self.dtype else: - self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - cache_config.cache_dtype] + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - self.is_pooling_model = (model_config.runner_type == 'pooling') + self.is_pooling_model = model_config.runner_type == "pooling" self.enable_prompt_embeds = model_config.enable_prompt_embeds self.is_multimodal_raw_input_only_model = ( - model_config.is_multimodal_raw_input_only_model) + model_config.is_multimodal_raw_input_only_model + ) # This will be overridden in load_model() self.is_multimodal_pruning_enabled = False self.max_model_len = model_config.max_model_len @@ -227,12 +270,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # TODO: Support overlapping mirco-batches # https://github.com/vllm-project/vllm/issues/18019 self.broadcast_pp_output = ( - self.parallel_config.distributed_executor_backend - == "external_launcher" and 
len(get_pp_group().ranks) > 0) + self.parallel_config.distributed_executor_backend == "external_launcher" + and len(get_pp_group().ranks) > 0 + ) # Model-related. - self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) + self.num_query_heads = model_config.get_num_attention_heads(parallel_config) self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size # Only relevant for models using ALiBi (e.g, MPT) @@ -244,13 +287,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( - model_config) + model_config + ) if self.model_config.is_encoder_decoder: # Maximum length of the encoder input, only for encoder-decoder # models. - self.max_encoder_len = scheduler_config.\ - max_num_encoder_input_tokens + self.max_encoder_len = scheduler_config.max_num_encoder_input_tokens else: self.max_encoder_len = 0 @@ -284,17 +327,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer(self.vllm_config, self.device, - self) # type: ignore + self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = True elif self.speculative_config.method == "medusa": self.drafter = MedusaProposer( - vllm_config=self.vllm_config, - device=self.device) # type: ignore + vllm_config=self.vllm_config, device=self.device + ) # type: ignore else: - raise ValueError("Unknown speculative decoding method: " - f"{self.speculative_config.method}") + raise ValueError( + "Unknown speculative decoding method: " + f"{self.speculative_config.method}" + ) self.rejection_sampler = RejectionSampler() # Request states. @@ -322,58 +366,64 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( - self.vllm_config, self.device, self.pin_memory, + self.vllm_config, + self.device, + self.pin_memory, self.is_pooling_model, - self.vllm_config.model_config.logits_processors), + self.vllm_config.model_config.logits_processors, + ), is_pooling_model=self.is_pooling_model, ) self.use_async_scheduling = self.scheduler_config.async_scheduling - self.async_output_copy_stream = torch.cuda.Stream() if \ - self.use_async_scheduling else None + self.async_output_copy_stream = ( + torch.cuda.Stream() if self.use_async_scheduling else None + ) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. # The batch sizes in the config are in descending order. - if self.compilation_config.cudagraph_capture_sizes and \ - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + if ( + self.compilation_config.cudagraph_capture_sizes + and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): self.cudagraph_batch_sizes = list( - reversed(self.compilation_config.cudagraph_capture_sizes)) + reversed(self.compilation_config.cudagraph_capture_sizes) + ) # Cache the device properties. self._init_device_properties() # Persistent buffers for CUDA graphs. 
- self.input_ids = self._make_buffer(self.max_num_tokens, - dtype=torch.int32) - self.positions = self._make_buffer(self.max_num_tokens, - dtype=torch.int64) - self.query_start_loc = self._make_buffer(self.max_num_reqs + 1, - dtype=torch.int32) + self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) + self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) + self.query_start_loc = self._make_buffer( + self.max_num_reqs + 1, dtype=torch.int32 + ) self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) # Because inputs_embeds may be bfloat16 and we don't need a numpy # version of this tensor, avoid a RuntimeError by not creating a # numpy buffer. - self.inputs_embeds = self._make_buffer(self.max_num_tokens, - self.hidden_size, - dtype=self.dtype, - numpy=False) - self.is_token_ids = self._make_buffer(self.max_num_tokens, - dtype=torch.bool) - self.discard_request_indices = self._make_buffer(self.max_num_reqs, - dtype=torch.int64) + self.inputs_embeds = self._make_buffer( + self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False + ) + self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + self.discard_request_indices = self._make_buffer( + self.max_num_reqs, dtype=torch.int64 + ) self.num_discarded_requests = 0 - self.num_decode_draft_tokens = self._make_buffer(self.max_num_reqs, - dtype=torch.int32) - self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, - dtype=torch.int64) + self.num_decode_draft_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + self.num_accepted_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int64 + ) # Only relevant for multimodal models if self.supports_mm_inputs: - self.is_mm_embed = self._make_buffer(self.max_num_tokens, - dtype=torch.bool) + self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -388,7 +438,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # 1D-RoPE. # See page 5 of https://arxiv.org/abs/2409.12191 self.mrope_positions = self._make_buffer( - (3, self.max_num_tokens + 1), dtype=torch.int64) + (3, self.max_num_tokens + 1), dtype=torch.int64 + ) # CUDA event to synchronize use of reused CPU tensors between steps # when async scheduling is enabled. @@ -403,10 +454,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context - self.arange_np = np.arange(max(self.max_num_reqs + 1, - self.max_model_len, - self.max_num_tokens), - dtype=np.int64) + self.arange_np = np.arange( + max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), + dtype=np.int64, + ) # Layer pairings for cross-layer KV sharing. 
# If an Attention layer `layer_name` is in the keys of this dict, it @@ -418,19 +469,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.kv_sharing_fast_prefill_logits_indices = None if self.cache_config.kv_sharing_fast_prefill: self.kv_sharing_fast_prefill_logits_indices = torch.zeros( - self.max_num_tokens, dtype=torch.int32, device=self.device) + self.max_num_tokens, dtype=torch.int32, device=self.device + ) - self.uniform_decode_query_len = 1 if not self.speculative_config else \ - 1 + self.speculative_config.num_speculative_tokens + self.uniform_decode_query_len = ( + 1 + if not self.speculative_config + else 1 + self.speculative_config.num_speculative_tokens + ) # Cudagraph dispatcher for runtime cudagraph dispatching. self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) - self.mm_budget = MultiModalBudget( - self.model_config, - self.scheduler_config, - self.mm_registry, - ) if self.supports_mm_inputs else None + self.mm_budget = ( + MultiModalBudget( + self.model_config, + self.scheduler_config, + self.mm_registry, + ) + if self.supports_mm_inputs + else None + ) self.reorder_batch_threshold: Optional[int] = None @@ -440,14 +499,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.runner_only_attn_layers: set[str] = set() # Cached outputs. - self._draft_token_ids: Optional[Union[list[list[int]], - torch.Tensor]] = None + self._draft_token_ids: Optional[Union[list[list[int]], torch.Tensor]] = None self.transfer_event = torch.cuda.Event() self.sampled_token_ids_pinned_cpu = torch.empty( (self.max_model_len, 1), dtype=torch.int64, device="cpu", - pin_memory=self.pin_memory) + pin_memory=self.pin_memory, + ) def _get_positions(self, num_tokens: Any): if isinstance(num_tokens, int): @@ -459,15 +518,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return self.mrope_positions.gpu[:, num_tokens] return self.positions.gpu[num_tokens] - def _make_buffer(self, - *size: Union[int, torch.SymInt], - dtype: torch.dtype, - numpy: bool = True) -> CpuGpuBuffer: - return CpuGpuBuffer(*size, - dtype=dtype, - device=self.device, - pin_memory=self.pin_memory, - with_numpy=numpy) + def _make_buffer( + self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True + ) -> CpuGpuBuffer: + return CpuGpuBuffer( + *size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory, + with_numpy=numpy, + ) def _init_model_kwargs(self, num_tokens: int): model_kwargs = dict[str, Any]() @@ -480,9 +540,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): token_type_id_requests = dict[int, Any]() for i, param in enumerate(pooling_params): - if param.extra_kwargs is not None and \ - (token_types := param.extra_kwargs.get( - "compressed_token_type_ids")) is not None: + if ( + param.extra_kwargs is not None + and (token_types := param.extra_kwargs.get("compressed_token_type_ids")) + is not None + ): token_type_id_requests[i] = token_types if len(token_type_id_requests) == 0: @@ -497,7 +559,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): token_type_ids.append(ids) model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( - device=self.device) + device=self.device + ) return model_kwargs def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: @@ -523,17 +586,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # required for DCP with q_len > 1, so we assert here. 
Remove this # assert once the custom mask is support is added to FA3. if self.dcp_world_size > 1: - assert self.reorder_batch_threshold == 1, \ + assert self.reorder_batch_threshold == 1, ( "DCP not support reorder_batch_threshold > 1 now." + ) reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, - decode_threshold=self.reorder_batch_threshold) + decode_threshold=self.reorder_batch_threshold, + ) # Note: used for model runner override. def _init_device_properties(self) -> None: - """Initialize attributes from torch.cuda.get_device_properties - """ + """Initialize attributes from torch.cuda.get_device_properties""" self.device_properties = torch.cuda.get_device_properties(self.device) self.num_sms = self.device_properties.multi_processor_count @@ -589,8 +653,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params - if sampling_params and \ - sampling_params.sampling_type == SamplingType.RANDOM_SEED: + if ( + sampling_params + and sampling_params.sampling_type == SamplingType.RANDOM_SEED + ): generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) else: @@ -647,14 +713,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): new_token_ids = req_data.new_token_ids[i] # Add the sampled token(s) from the previous step (if any). # This doesn't include "unverified" tokens like spec tokens. - num_new_tokens = (num_computed_tokens + len(new_token_ids) - - req_state.num_tokens) + num_new_tokens = ( + num_computed_tokens + len(new_token_ids) - req_state.num_tokens + ) if num_new_tokens == 1: # Avoid slicing list in most common case. req_state.output_token_ids.append(new_token_ids[-1]) elif num_new_tokens > 0: - req_state.output_token_ids.extend( - new_token_ids[-num_new_tokens:]) + req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) elif num_output_tokens < len(req_state.output_token_ids): # Some output tokens were discarded due to a sync-KV-load # failure. Align the cached state. @@ -662,21 +728,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is not None: - old_end_idx = self.input_batch.num_tokens_no_spec[ - req_index] - end_idx = self.input_batch.num_prompt_tokens[ - req_index] + num_output_tokens + old_end_idx = self.input_batch.num_tokens_no_spec[req_index] + end_idx = ( + self.input_batch.num_prompt_tokens[req_index] + + num_output_tokens + ) self.input_batch.num_tokens[req_index] = end_idx self.input_batch.num_tokens_no_spec[req_index] = end_idx - self.input_batch.is_token_ids[req_index, - end_idx:old_end_idx] = False + self.input_batch.is_token_ids[req_index, end_idx:old_end_idx] = ( + False + ) # Update the block IDs. if not resumed_from_preemption: if new_block_ids is not None: # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): + for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) else: assert new_block_ids is not None @@ -693,11 +760,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): continue # Update the persistent batch. 
- self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens if new_block_ids is not None: - self.input_batch.block_table.append_row( - new_block_ids, req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu # because the sampled tokens are already cached. @@ -706,21 +771,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): start_token_index = num_computed_tokens end_token_index = num_computed_tokens + len(new_token_ids) self.input_batch.token_ids_cpu[ - req_index, - start_token_index:end_token_index] = new_token_ids - self.input_batch.num_tokens_no_spec[ - req_index] = end_token_index + req_index, start_token_index:end_token_index + ] = new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index self.input_batch.num_tokens[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. - spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, () + ) if spec_token_ids: num_spec_tokens = len(spec_token_ids) start_index = self.input_batch.num_tokens_no_spec[req_index] end_token_index = start_index + num_spec_tokens self.input_batch.token_ids_cpu[ - req_index, start_index:end_token_index] = spec_token_ids + req_index, start_index:end_token_index + ] = spec_token_ids # NOTE(woosuk): `num_tokens` here may include spec tokens. self.input_batch.num_tokens[req_index] += num_spec_tokens @@ -737,7 +803,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.input_batch.refresh_metadata() def _update_states_after_model_execute( - self, output_token_ids: torch.Tensor) -> None: + self, output_token_ids: torch.Tensor + ) -> None: """Update the cached states after model execution. This is used for MTP/EAGLE for hybrid models, as in linear attention, @@ -750,14 +817,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return # Find the number of accepted tokens for each sequence. 
- num_accepted_tokens = (torch.cat( - [ - output_token_ids, - torch.full((output_token_ids.size(0), 1), - -1, - device=output_token_ids.device), - ], - dim=1) == -1).int().argmax(-1).cpu().numpy() + num_accepted_tokens = ( + ( + torch.cat( + [ + output_token_ids, + torch.full( + (output_token_ids.size(0), 1), + -1, + device=output_token_ids.device, + ), + ], + dim=1, + ) + == -1 + ) + .int() + .argmax(-1) + .cpu() + .numpy() + ) for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens @@ -784,7 +863,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): use_audio_in_video = True if supports_mrope(self.model): - req_state.mrope_positions, req_state.mrope_position_delta = \ + req_state.mrope_positions, req_state.mrope_position_delta = ( self.model.get_mrope_input_positions( req_state.prompt_token_ids, hf_config=self.model_config.hf_config, @@ -794,8 +873,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): audio_feature_lengths=audio_feature_lengths, use_audio_in_video=use_audio_in_video, ) + ) else: - req_state.mrope_positions, req_state.mrope_position_delta = \ + req_state.mrope_positions, req_state.mrope_position_delta = ( MRotaryEmbedding.get_input_positions_tensor( req_state.prompt_token_ids, hf_config=self.model_config.hf_config, @@ -805,6 +885,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): audio_feature_lengths=audio_feature_lengths, use_audio_in_video=use_audio_in_video, ) + ) def _extract_mm_kwargs( self, @@ -823,10 +904,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): model = cast(SupportsMultiModal, self.model) mm_kwargs_combined: BatchedTensorInputs = {} for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): mm_kwargs_combined.update(mm_kwargs_group) @@ -862,10 +943,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return cu_num_tokens, arange - def _prepare_input_ids(self, total_num_scheduled_tokens: int, - cu_num_tokens: np.ndarray) -> None: + def _prepare_input_ids( + self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray + ) -> None: """Prepare the input IDs for the current batch. - + Carefully handles the `prev_sampled_token_ids` which can be cached from the previous engine iteration, in which case those tokens on the GPU need to be copied into the corresponding slots into input_ids.""" @@ -894,7 +976,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # last token in each common request. flattened_index = cu_num_tokens[cur_index].item() - 1 flattened_indices.append(flattened_index) - indices_match &= (prev_index == flattened_index) + indices_match &= prev_index == flattened_index max_flattened_index = max(max_flattened_index, flattened_index) num_commmon_tokens = len(flattened_indices) if num_commmon_tokens < total_num_scheduled_tokens: @@ -914,28 +996,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # The indices are both the same permutation of 0..N-1 so # we can copy directly using a single slice. 
self.input_ids.gpu[:num_commmon_tokens].copy_( - self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, - 0], - non_blocking=True) + self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0], + non_blocking=True, + ) if self.enable_prompt_embeds: self.is_token_ids.gpu[:num_commmon_tokens] = True return # Upload the index tensors asynchronously # so the scatter can be non-blocking. - input_ids_index_tensor = torch.tensor(flattened_indices, - dtype=torch.int64, - pin_memory=self.pin_memory).to( - self.device, - non_blocking=True) + input_ids_index_tensor = torch.tensor( + flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) prev_common_req_indices_tensor = torch.tensor( - prev_common_req_indices, - dtype=torch.int64, - pin_memory=self.pin_memory).to(self.device, non_blocking=True) + prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) self.input_ids.gpu.scatter_( dim=0, index=input_ids_index_tensor, src=self.input_batch.prev_sampled_token_ids[ - prev_common_req_indices_tensor, 0]) + prev_common_req_indices_tensor, 0 + ], + ) def _get_encoder_seq_lens( self, @@ -957,10 +1038,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _prepare_inputs( self, scheduler_output: "SchedulerOutput" - ) -> tuple[PerLayerAttnMetadata, torch.Tensor, - Optional[SpecDecodeMetadata], np.ndarray, - Optional[CommonAttentionMetadata], int, Optional[UBatchSlices], - Optional[torch.Tensor], bool]: + ) -> tuple[ + PerLayerAttnMetadata, + torch.Tensor, + Optional[SpecDecodeMetadata], + np.ndarray, + Optional[CommonAttentionMetadata], + int, + Optional[UBatchSlices], + Optional[torch.Tensor], + bool, + ]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -986,19 +1074,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange( - num_scheduled_tokens) + cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) # Get positions. positions_np = self.positions.np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, + ) # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -1009,24 +1097,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. - token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) + token_indices = ( + positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + ) token_indices_tensor = torch.from_numpy(token_indices) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. 
- torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - token_indices_tensor, - out=self.input_ids.cpu[:total_num_scheduled_tokens]) + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + token_indices_tensor, + out=self.input_ids.cpu[:total_num_scheduled_tokens], + ) if self.enable_prompt_embeds: is_token_ids = self.input_batch.is_token_ids.flatten() torch.index_select( is_token_ids, 0, token_indices_tensor, - out=self.is_token_ids.cpu[:total_num_scheduled_tokens]) + out=self.is_token_ids.cpu[:total_num_scheduled_tokens], + ) # Because we did not pre-allocate a massive prompt_embeds CPU tensor on # the InputBatch, we need to fill in the prompt embeds into the expected @@ -1060,52 +1152,49 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): actual_num_sched = actual_end - start_pos if actual_num_sched > 0: - self.inputs_embeds.cpu[output_idx:output_idx + - actual_num_sched].copy_( - req_embeds[start_pos:actual_end] - ) + self.inputs_embeds.cpu[ + output_idx : output_idx + actual_num_sched + ].copy_(req_embeds[start_pos:actual_end]) output_idx += num_sched - self.input_batch.block_table.compute_slot_mapping( - req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping( - total_num_scheduled_tokens) + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) # Prepare the attention metadata. self.query_start_loc.np[0] = 0 - self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens + self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens # Note: pad query_start_loc to be non-decreasing, as kernels # like FlashAttention requires that - self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1]) + self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1]) self.query_start_loc.copy_to_gpu() - query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] + query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens num_tokens_padded = num_tokens_unpadded + self.get_local_padding( - num_tokens_unpadded) - uniform_decode = \ - (max_num_scheduled_tokens == self.uniform_decode_query_len) and \ - (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) - ubatch_slices, num_tokens_after_padding = \ - ubatch_split(num_scheduled_tokens, - num_tokens_unpadded, - num_tokens_padded, - uniform_decode=uniform_decode, - vllm_config=self.vllm_config) + num_tokens_unpadded + ) + uniform_decode = ( + max_num_scheduled_tokens == self.uniform_decode_query_len + ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) + ubatch_slices, num_tokens_after_padding = ubatch_split( + num_scheduled_tokens, + num_tokens_unpadded, + num_tokens_padded, + uniform_decode=uniform_decode, + vllm_config=self.vllm_config, + ) self.seq_lens.np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) + self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens + ) # Fill unused with 0 for full cuda graph mode. 
self.seq_lens.np[num_reqs:].fill(0) self.seq_lens.copy_to_gpu() seq_lens = self.seq_lens.gpu[:num_reqs] max_seq_len = self.seq_lens.np[:num_reqs].max().item() - num_tokens = [ - self.requests[r].num_tokens for r in self.input_batch.req_ids - ] + num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] num_tokens_np = np.array(num_tokens, dtype=np.int32) # Record the index of requests that should not be sampled, @@ -1113,8 +1202,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np discard_request_indices = np.nonzero(discard_requests_mask)[0] self.num_discarded_requests = len(discard_request_indices) - self.discard_request_indices.np[:self.num_discarded_requests] = ( - discard_request_indices) + self.discard_request_indices.np[: self.num_discarded_requests] = ( + discard_request_indices + ) self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) @@ -1125,13 +1215,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( self.mrope_positions.cpu[:, :total_num_scheduled_tokens], - non_blocking=True) + non_blocking=True, + ) else: # Common case (1D positions) self.positions.copy_to_gpu(total_num_scheduled_tokens) - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. While we should not sample any token @@ -1149,27 +1239,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # For chunked prefills, use -1 as mask rather than 0, as guided # decoding may rollback speculative tokens. num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32) - for req_id, draft_token_ids in ( - scheduler_output.scheduled_spec_decode_tokens.items()): + for ( + req_id, + draft_token_ids, + ) in scheduler_output.scheduled_spec_decode_tokens.items(): req_idx = self.input_batch.req_id_to_index[req_id] num_draft_tokens[req_idx] = len(draft_token_ids) - num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if ( - self.input_batch.num_computed_tokens_cpu[req_idx] - >= self.input_batch.num_prompt_tokens[req_idx]) else -1) + num_decode_draft_tokens[req_idx] = ( + len(draft_token_ids) + if ( + self.input_batch.num_computed_tokens_cpu[req_idx] + >= self.input_batch.num_prompt_tokens[req_idx] + ) + else -1 + ) spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens) + num_draft_tokens, cu_num_tokens + ) logits_indices = spec_decode_metadata.logits_indices # For DECODE only cuda graph of some attention backends (e.g., GDN). - self.num_decode_draft_tokens.np[: - num_reqs] = num_decode_draft_tokens + self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens self.num_decode_draft_tokens.np[num_reqs:].fill(-1) self.num_decode_draft_tokens.copy_to_gpu() logits_indices_padded = None if self.cache_config.kv_sharing_fast_prefill: logits_indices_padded = self._prepare_kv_sharing_fast_prefill( - logits_indices) + logits_indices + ) attn_metadata: PerLayerAttnMetadata = {} if ubatch_slices is not None: @@ -1177,26 +1275,29 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): use_cascade_attn = False # Used in the below loop. 
- query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] + query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1] seq_lens_cpu = self.seq_lens.cpu[:num_reqs] - num_computed_tokens_cpu = ( - self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs + ] spec_decode_common_attn_metadata = None if use_spec_decode: self.num_accepted_tokens.np[:num_reqs] = ( - self.input_batch.num_accepted_tokens_cpu[:num_reqs]) + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): + self.kv_cache_config.kv_cache_groups + ): encoder_seq_lens = self._get_encoder_seq_lens( - scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs) + scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs + ) - if isinstance(kv_cache_group_spec.kv_cache_spec, - EncoderOnlyAttentionSpec): + if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): # Encoder-only layers do not have KV cache, so we need to # create a dummy block table and slot mapping for them. blk_table_tensor = torch.zeros( @@ -1205,7 +1306,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): device=self.device, ) slot_mapping = torch.zeros( - (total_num_scheduled_tokens, ), + (total_num_scheduled_tokens,), dtype=torch.int64, device=self.device, ) @@ -1213,16 +1314,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor(num_reqs) - slot_mapping = blk_table.slot_mapping.gpu[: - total_num_scheduled_tokens] + slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens] # Fill unused with -1. Needed for reshape_and_cache in full cuda # graph mode. - blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_( - -1) - num_common_prefix_blocks = ( - scheduler_output. - num_common_prefix_blocks[kv_cache_group_id]) + blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) + num_common_prefix_blocks = scheduler_output.num_common_prefix_blocks[ + kv_cache_group_id + ] common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -1242,11 +1341,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): encoder_seq_lens=encoder_seq_lens, ) - if (self.speculative_config - and spec_decode_common_attn_metadata is None): + if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance(self.drafter, EagleProposer): - if (self.drafter.attn_layer_names[0] - in kv_cache_group_spec.layer_names): + if ( + self.drafter.attn_layer_names[0] + in kv_cache_group_spec.layer_names + ): spec_decode_common_attn_metadata = common_attn_metadata else: spec_decode_common_attn_metadata = common_attn_metadata @@ -1264,24 +1364,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) extra_attn_metadata_args = {} - if use_spec_decode and isinstance(builder, - GDNAttentionMetadataBuilder): + if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): extra_attn_metadata_args = dict( - num_accepted_tokens=self.num_accepted_tokens. - gpu[:num_reqs], - num_decode_draft_tokens_cpu=self. 
- num_decode_draft_tokens.cpu[:num_reqs], + num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs], + num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[ + :num_reqs + ], ) if ubatch_slices is not None: common_attn_metadata_list = split_attn_metadata( - ubatch_slices, common_attn_metadata) + ubatch_slices, common_attn_metadata + ) for ubid, common_attn_metadata in enumerate( - common_attn_metadata_list): - attn_metadata_i = (attn_group.get_metadata_builder( - ubatch_id=ubid).build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata)) + common_attn_metadata_list + ): + attn_metadata_i = attn_group.get_metadata_builder( + ubatch_id=ubid + ).build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + ) for layer_name in kv_cache_group_spec.layer_names: assert type(attn_metadata) is list attn_metadata[ubid][layer_name] = attn_metadata_i @@ -1290,9 +1393,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args) - use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", - False) + **extra_attn_metadata_args, + ) + use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", False) for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -1304,10 +1407,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return (attn_metadata, logits_indices, spec_decode_metadata, - num_scheduled_tokens, spec_decode_common_attn_metadata, - max_num_scheduled_tokens, ubatch_slices, - num_tokens_after_padding, use_cascade_attn) + return ( + attn_metadata, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens, + spec_decode_common_attn_metadata, + max_num_scheduled_tokens, + ubatch_slices, + num_tokens_after_padding, + use_cascade_attn, + ) def _compute_cascade_attn_prefix_len( self, @@ -1379,18 +1489,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # this case. num_reqs = len(num_scheduled_tokens) common_prefix_len = min( - common_prefix_len, - self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min() + ) # common_prefix_len should be a multiple of the block size. 
- common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * - kv_cache_spec.block_size) - use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or - (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.sliding_window is not None)) - use_local_attention = ( - isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) - or (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.attention_chunk_size is not None)) + common_prefix_len = ( + common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size + ) + use_sliding_window = isinstance(kv_cache_spec, SlidingWindowSpec) or ( + isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None + ) + use_local_attention = isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) or ( + isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.attention_chunk_size is not None + ) assert isinstance(kv_cache_spec, AttentionSpec) use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, @@ -1410,18 +1522,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req = self.requests[req_id] assert req.mrope_positions is not None - num_computed_tokens = \ - self.input_batch.num_computed_tokens_cpu[index] - num_scheduled_tokens = \ - scheduler_output.num_scheduled_tokens[req_id] + num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - req.prompt_token_ids, req.prompt_embeds) + req.prompt_token_ids, req.prompt_embeds + ) if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: - prompt_part_len = max(0, - num_prompt_tokens - num_computed_tokens) - completion_part_len = max( - 0, num_scheduled_tokens - prompt_part_len) + prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens) + completion_part_len = max(0, num_scheduled_tokens - prompt_part_len) else: prompt_part_len = num_scheduled_tokens completion_part_len = 0 @@ -1435,8 +1544,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): src_start = num_computed_tokens src_end = num_computed_tokens + prompt_part_len - self.mrope_positions.cpu[:, dst_start:dst_end] = ( - req.mrope_positions[:, src_start:src_end]) + self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[ + :, src_start:src_end + ] mrope_pos_ptr += prompt_part_len if completion_part_len > 0: @@ -1476,10 +1586,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11] # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] cu_num_sampled_tokens, arange = self._get_cumsum_and_arange( - num_sampled_tokens, cumsum_dtype=np.int32) + num_sampled_tokens, cumsum_dtype=np.int32 + ) # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] logits_indices = np.repeat( - cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) + cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens + ) # Step 3. 
[0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] logits_indices += arange @@ -1490,22 +1602,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # cu_num_draft_tokens: [3, 3, 5, 5, 6] # arange: [0, 1, 2, 0, 1, 0] cu_num_draft_tokens, arange = self._get_cumsum_and_arange( - num_draft_tokens, cumsum_dtype=np.int32) + num_draft_tokens, cumsum_dtype=np.int32 + ) # [0, 0, 0, 5, 5, 9] target_logits_indices = np.repeat( - cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) + cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens + ) # [0, 1, 2, 5, 6, 9] target_logits_indices += arange # TODO: Optimize the CPU -> GPU copy. cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( - self.device, non_blocking=True) - logits_indices = torch.from_numpy(logits_indices).to(self.device, - non_blocking=True) + self.device, non_blocking=True + ) + logits_indices = torch.from_numpy(logits_indices).to( + self.device, non_blocking=True + ) target_logits_indices = torch.from_numpy(target_logits_indices).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) # Compute the draft token ids. # draft_token_indices: [ 1, 2, 3, 105, 106, 208] @@ -1529,23 +1647,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert self.kv_sharing_fast_prefill_logits_indices is not None num_logits = logits_indices.shape[0] assert num_logits > 0 - self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_( - logits_indices) + self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(logits_indices) # There might have leftover indices in logits_indices[num_logits:] # from previous iterations, whose values may be greater than the # batch size in the current iteration. To ensure indices are always # valid, we fill the padded indices with the last index. self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( - logits_indices[-1].item()) - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_logits <= self.cudagraph_batch_sizes[-1]): + logits_indices[-1].item() + ) + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_logits <= self.cudagraph_batch_sizes[-1] + ): # Use piecewise CUDA graphs. # Add padding to the batch size. num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) else: num_logits_padded = num_logits - logits_indices_padded = ( - self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) + logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[ + :num_logits_padded + ] return logits_indices_padded def _batch_mm_kwargs_from_scheduler( @@ -1584,7 +1705,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # Batch the multi-modal inputs using the helper method. 
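The Step 1–3 comments in the `_calc_spec_decode_metadata` hunk above can be reproduced with a few lines of NumPy. The sketch below is standalone and not part of the diff; `cumsum_and_arange` is a simplified stand-in for the runner's `_get_cumsum_and_arange` helper, and the two input arrays are simply the worked-example values implied by the comments.

```python
import numpy as np

def cumsum_and_arange(counts: np.ndarray, dtype=np.int32):
    # Inclusive cumulative sum plus a per-request arange that restarts at 0:
    # counts = [4, 1, 3, 1, 2] -> arange = [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
    cu = np.cumsum(counts, dtype=dtype)
    arange = np.arange(cu[-1], dtype=dtype) - np.repeat(cu - counts, counts)
    return cu, arange

# Cumulative scheduled tokens per request, and sampled tokens per request
# (draft tokens + 1 bonus token), matching the worked example in the comments.
cu_num_scheduled_tokens = np.array([4, 104, 107, 207, 209], dtype=np.int32)
num_sampled_tokens = np.array([4, 1, 3, 1, 2], dtype=np.int32)

# Step 1: [4, 5, 8, 9, 11] and [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
cu_num_sampled_tokens, arange = cumsum_and_arange(num_sampled_tokens)
# Step 2: repeat each request's first sampled position num_sampled_tokens times.
logits_indices = np.repeat(
    cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens
)
# Step 3: add the within-request offset.
logits_indices += arange
print(logits_indices.tolist())
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
```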
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( - scheduler_output) + scheduler_output + ) if not mm_kwargs: return @@ -1599,10 +1721,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): model = cast(SupportsMultiModal, self.model) encoder_outputs = [] for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): # (ekhvedchenia): Temporary hack to limit peak memory usage when # processing multimodal data.This solves the issue with scheduler @@ -1616,11 +1738,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): micro_batch_size = 1 for i in range(0, num_items, micro_batch_size): micro_batch_mm_inputs = dict( - (k, v[i:i + micro_batch_size]) - for k, v in mm_kwargs_group.items()) + (k, v[i : i + micro_batch_size]) + for k, v in mm_kwargs_group.items() + ) micro_batch_outputs = model.get_multimodal_embeddings( - **micro_batch_mm_inputs) + **micro_batch_mm_inputs + ) curr_group_outputs.extend(micro_batch_outputs) else: @@ -1631,8 +1755,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # 2. A list or tuple (length: num_items) of tensors, # each of shape (feature_size, hidden_size) in case the feature # size is dynamic depending on the input multimodal items. - curr_group_outputs = model.get_multimodal_embeddings( - **mm_kwargs_group) + curr_group_outputs = model.get_multimodal_embeddings(**mm_kwargs_group) sanity_check_mm_encoder_outputs( curr_group_outputs, @@ -1664,11 +1787,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for req_id in self.input_batch.req_ids: mm_embeds_req: list[torch.Tensor] = [] - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] req_state = self.requests[req_id] - num_computed_tokens = \ - req_state.num_computed_tokens + shift_computed_tokens + num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens for mm_feature in req_state.mm_features: pos_info = mm_feature.mm_position @@ -1696,15 +1817,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) - assert encoder_output is not None,\ - f"Encoder cache miss for {mm_hash}." + assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." 
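The micro-batching loop in the `_execute_mm_encoder` hunk above slices every entry of the grouped multimodal kwargs with the same window so the encoder only sees a handful of items at a time. Below is a minimal, framework-free sketch of that slicing pattern; `run_in_micro_batches`, `encode_fn`, and the toy kwargs are illustrative names, not vLLM APIs.

```python
from typing import Any, Callable

def run_in_micro_batches(
    mm_kwargs_group: dict[str, list[Any]],
    num_items: int,
    encode_fn: Callable[..., list[Any]],
    micro_batch_size: int = 1,
) -> list[Any]:
    outputs: list[Any] = []
    for i in range(0, num_items, micro_batch_size):
        # Slice every kwarg with the same [i : i + micro_batch_size] window.
        micro_batch = {
            k: v[i : i + micro_batch_size] for k, v in mm_kwargs_group.items()
        }
        outputs.extend(encode_fn(**micro_batch))
    return outputs

# Toy usage: "encode" three items, two at a time.
group = {"pixel_values": [10, 20, 30], "grid_thw": ["a", "b", "c"]}
encode = lambda pixel_values, grid_thw: [
    f"emb({p},{g})" for p, g in zip(pixel_values, grid_thw)
]
print(run_in_micro_batches(group, num_items=3, encode_fn=encode, micro_batch_size=2))
# ['emb(10,a)', 'emb(20,b)', 'emb(30,c)']
```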
if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] req_start_pos = req_start_idx + start_pos - num_computed_tokens - is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \ - = True if is_embed is None else is_embed + is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( + True if is_embed is None else is_embed + ) mm_embeds_item = gather_mm_placeholders( encoder_output[start_idx:end_idx], @@ -1721,7 +1842,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): multimodal_embeddings=mm_embeds_req, mrope_positions=req_state.mrope_positions, num_computed_tokens=req_state.num_computed_tokens, - )) + ) + ) req_state.mrope_positions.copy_(new_mrope_positions) req_state.mrope_position_delta = new_delta @@ -1755,10 +1877,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): model = cast(SupportsMultiModal, self.model) encoder_features = {} for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): # Add the grouped features to encoder_features dict # This allows the model to receive them as kwargs (e.g., @@ -1795,21 +1917,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): supported_tasks = list(model.pooler.get_supported_tasks()) - if (self.scheduler_config.chunked_prefill_enabled - and "encode" in supported_tasks): + if ( + self.scheduler_config.chunked_prefill_enabled + and "encode" in supported_tasks + ): supported_tasks.remove("encode") - logger.debug_once("Chunked prefill is not supported with " - "encode task which using ALL pooling. " - "Please turn off chunked prefill by " - "`--no-enable-chunked-prefill` before using it.") + logger.debug_once( + "Chunked prefill is not supported with " + "encode task which using ALL pooling. " + "Please turn off chunked prefill by " + "`--no-enable-chunked-prefill` before using it." 
+ ) if "score" in supported_tasks: num_labels = getattr(self.model_config.hf_config, "num_labels", 0) if num_labels != 1: supported_tasks.remove("score") - logger.debug_once( - "Score API is only enabled for num_labels == 1.") + logger.debug_once("Score API is only enabled for num_labels == 1.") return supported_tasks @@ -1824,9 +1949,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return tuple(tasks) def sync_and_slice_intermediate_tensors( - self, num_tokens: int, intermediate_tensors: IntermediateTensors, - sync_self: bool) -> IntermediateTensors: - + self, + num_tokens: int, + intermediate_tensors: IntermediateTensors, + sync_self: bool, + ) -> IntermediateTensors: assert self.intermediate_tensors is not None tp = self.vllm_config.parallel_config.tensor_parallel_size @@ -1838,21 +1965,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert intermediate_tensors is not None for k, v in intermediate_tensors.items(): is_scattered = k == "residual" and is_rs - copy_len = num_tokens // tp if is_scattered else \ - num_tokens + copy_len = num_tokens // tp if is_scattered else num_tokens self.intermediate_tensors[k][:copy_len].copy_( - v[:copy_len], non_blocking=True) + v[:copy_len], non_blocking=True + ) - return IntermediateTensors({ - k: - v[:num_tokens // - tp] if k == "residual" and is_rs else v[:num_tokens] - for k, v in self.intermediate_tensors.items() - }) + return IntermediateTensors( + { + k: v[: num_tokens // tp] + if k == "residual" and is_rs + else v[:num_tokens] + for k, v in self.intermediate_tensors.items() + } + ) - def eplb_step(self, - is_dummy: bool = False, - is_profile: bool = False) -> None: + def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: """ Step for the EPLB (Expert Parallelism Load Balancing) state. """ @@ -1869,8 +1996,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): log_stats=self.parallel_config.eplb_config.log_balancedness, ) - def get_dp_padding(self, - num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: """ Determines the total number of tokens that each rank will run. All ranks will be padded out so that they run with the same number @@ -1897,31 +2023,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return 0, None num_tokens_across_dp = DPMetadata.num_tokens_across_dp( - num_tokens, dp_size, dp_rank) + num_tokens, dp_size, dp_rank + ) max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() - num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * - dp_size, - device="cpu", - dtype=torch.int32) + num_tokens_after_padding = torch.tensor( + [max_tokens_across_dp_cpu] * dp_size, device="cpu", dtype=torch.int32 + ) return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding def get_local_padding(self, num_tokens_unpadded: int) -> int: - num_tokens_padded = num_tokens_unpadded - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]): + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1] + ): # Use piecewise CUDA graphs. # Add padding to the batch size. - num_tokens_padded = self.vllm_config.pad_for_cudagraph( - num_tokens_unpadded) + num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded) else: # Eager mode. 
# Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if self.vllm_config.compilation_config.pass_config. \ - enable_sequence_parallelism and tp_size > 1: + if ( + self.vllm_config.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1 + ): num_tokens_padded = round_up(num_tokens_unpadded, tp_size) num_pad_tokens = num_tokens_padded - num_tokens_unpadded @@ -1931,12 +2059,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Should be called after attention metadata creation. This just pads # the second ubatch slice out to the total number of tokens # (num_tokens + padding) - def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, - num_total_tokens: int): - padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start, - num_total_tokens) - ubatch_slices[1] = UBatchSlice(padded_second_ubatch_slice, - padded_second_ubatch_slice) + def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, num_total_tokens: int): + padded_second_ubatch_slice = slice( + ubatch_slices[1].token_slice.start, num_total_tokens + ) + ubatch_slices[1] = UBatchSlice( + padded_second_ubatch_slice, padded_second_ubatch_slice + ) def _pool( self, @@ -1944,16 +2073,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, ) -> ModelRunnerOutput: - assert self.input_batch.num_reqs ==\ - len(self.input_batch.pooling_params), \ - "Either all or none of the requests in" \ - " a batch must be pooling request" + assert self.input_batch.num_reqs == len(self.input_batch.pooling_params), ( + "Either all or none of the requests in a batch must be pooling request" + ) hidden_states = hidden_states[:num_scheduled_tokens] pooling_metadata = self.input_batch.get_pooling_metadata() - pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), - device=hidden_states.device) - seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs] + pooling_metadata.build_pooling_cursor( + num_scheduled_tokens_np.tolist(), device=hidden_states.device + ) + seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs] model = cast(VllmModelForPooling, self.model) raw_pooler_output: PoolerOutput = model.pooler( @@ -1968,8 +2097,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): pooler_output: list[Optional[torch.Tensor]] = [] for raw_output, seq_len, prompt_len in zip( - raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): - + raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens + ): output = raw_output if seq_len == prompt_len else None pooler_output.append(output) @@ -1983,11 +2112,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH - and hasattr(self, "cudagraph_batch_sizes") - and self.cudagraph_batch_sizes - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH + and hasattr(self, "cudagraph_batch_sizes") + and self.cudagraph_batch_sizes + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1] + ): # Use CUDA graphs. # Add padding to the batch size. 
return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens) @@ -1996,8 +2127,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if (self.compilation_config.pass_config.enable_sequence_parallelism - and tp_size > 1): + if ( + self.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1 + ): return round_up(num_scheduled_tokens, tp_size) return num_scheduled_tokens @@ -2007,10 +2140,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): intermediate_tensors: Optional[IntermediateTensors] = None, ubatch_slices: Optional[UBatchSlices] = None, num_tokens_after_padding: Optional[torch.Tensor] = None, - ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], torch.Tensor, - Optional[IntermediateTensors], dict[str, Any]]: - + ) -> tuple[ + int, + int, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + Optional[IntermediateTensors], + dict[str, Any], + ]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if ubatch_slices: assert num_tokens_after_padding is not None @@ -2018,18 +2157,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) elif ubatch_slices is None: num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) - num_pad, num_tokens_after_padding = self.get_dp_padding( - num_input_tokens) + num_pad, num_tokens_after_padding = self.get_dp_padding(num_input_tokens) num_input_tokens += num_pad # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order - if (self.supports_mm_inputs and get_pp_group().is_first_rank - and not self.model_config.is_encoder_decoder): + if ( + self.supports_mm_inputs + and get_pp_group().is_first_rank + and not self.model_config.is_encoder_decoder + ): # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) - mm_embeds, is_mm_embed = self._gather_mm_embeddings( - scheduler_output) + mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output) # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) @@ -2041,8 +2181,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds.gpu[:num_scheduled_tokens].copy_( - inputs_embeds_scheduled) + self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled) input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] @@ -2063,14 +2202,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # If a batch only has token ids, then including the embedding layer # in the CUDA graph will be more performant (like in the else case # below). 
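A simplified sketch of the token-count padding policy implemented by `_get_num_input_tokens`, `get_local_padding`, and `get_dp_padding` in the hunks above: pad up to the next captured CUDA-graph batch size when the batch is small enough, otherwise round up to a multiple of the TP size when sequence-parallel fusion is enabled, and finally pad every DP rank to the maximum across ranks. The function names are illustrative, `pad_for_cudagraph` is approximated here as "the smallest captured size that fits", and the DP step uses a plain list in place of the real collective.

```python
from bisect import bisect_left

def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple

def local_padding(
    num_tokens: int,
    cudagraph_sizes: list[int],  # sorted ascending, like cudagraph_batch_sizes
    tp_size: int,
    sp_enabled: bool,
) -> int:
    if cudagraph_sizes and num_tokens <= cudagraph_sizes[-1]:
        # CUDA graphs: pad to the next captured batch size.
        return cudagraph_sizes[bisect_left(cudagraph_sizes, num_tokens)]
    if sp_enabled and tp_size > 1:
        # Eager mode with sequence-parallel fusion: pad to a TP multiple.
        return round_up(num_tokens, tp_size)
    return num_tokens

def dp_padding(tokens_per_rank: list[int]) -> list[int]:
    # Every DP rank runs the same (maximum) number of tokens.
    return [max(tokens_per_rank)] * len(tokens_per_rank)

sizes = [1, 2, 4, 8, 16, 32, 64]  # assumed capture sizes
print(local_padding(5, sizes, tp_size=2, sp_enabled=True))    # 8
print(local_padding(101, sizes, tp_size=2, sp_enabled=True))  # 102
print(dp_padding([8, 3, 8, 1]))                               # [8, 8, 8, 8]
```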
- token_ids_idx = self.is_token_ids.gpu[:num_scheduled_tokens] \ - .nonzero(as_tuple=False) \ + token_ids_idx = ( + self.is_token_ids.gpu[:num_scheduled_tokens] + .nonzero(as_tuple=False) .squeeze(1) + ) # Some tokens ids may need to become embeds if token_ids_idx.numel() > 0: token_ids = self.input_ids.gpu[token_ids_idx] - tokens_to_embeds = self.model.get_input_embeddings( - input_ids=token_ids) + tokens_to_embeds = self.model.get_input_embeddings(input_ids=token_ids) self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] @@ -2093,10 +2233,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): intermediate_tensors = None else: intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_input_tokens, intermediate_tensors, True) + num_input_tokens, intermediate_tensors, True + ) - if (self.model_config.is_encoder_decoder - and scheduler_output.scheduled_encoder_inputs): + if ( + self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs + ): encoder_inputs = self._extract_encoder_inputs(scheduler_output) model_kwargs.update(encoder_inputs) @@ -2112,8 +2255,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) def _sample( - self, logits: Optional[torch.Tensor], - spec_decode_metadata: Optional[SpecDecodeMetadata] + self, + logits: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata], ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata @@ -2152,24 +2296,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return sampler_output def _bookkeeping_sync( - self, scheduler_output: "SchedulerOutput", - sampler_output: SamplerOutput, logits: Optional[torch.Tensor], - hidden_states: torch.Tensor, num_scheduled_tokens: int + self, + scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, + logits: Optional[torch.Tensor], + hidden_states: torch.Tensor, + num_scheduled_tokens: int, ) -> tuple[ - dict[str, int], - Optional[LogprobsLists], - list[list[int]], - dict[str, Optional[LogprobsTensors]], - list[str], - dict[str, int], - list[int], + dict[str, int], + Optional[LogprobsLists], + list[list[int]], + dict[str, Optional[LogprobsTensors]], + list[str], + dict[str, int], + list[int], ]: num_nans_in_logits = {} if envs.VLLM_COMPUTE_NANS_IN_LOGITS: num_nans_in_logits = self._get_nans_in_logits(logits) - discard_sampled_tokens_req_indices = \ - self.discard_request_indices.np[:self.num_discarded_requests] + discard_sampled_tokens_req_indices = self.discard_request_indices.np[ + : self.num_discarded_requests + ] for i in discard_sampled_tokens_req_indices: gen = self.input_batch.generators.get(int(i)) if gen is not None: @@ -2178,14 +2326,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Copy some objects so they don't get modified after returning. # This is important when using async scheduling. req_ids_output_copy = self.input_batch.req_ids.copy() - req_id_to_index_output_copy = \ - self.input_batch.req_id_to_index.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. 
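Earlier in this hunk (the `token_ids_idx` lines near the top), positions whose inputs arrived as raw token ids are looked up in the embedding table and scattered back into the shared `inputs_embeds` buffer so token-id and soft-embedding inputs can be mixed in one batch. A CPU-only sketch of that gather/embed/scatter pattern, with a toy embedding table standing in for `self.model.get_input_embeddings`:

```python
import torch

vocab_size, hidden_size, num_tokens = 100, 8, 6
embedding = torch.nn.Embedding(vocab_size, hidden_size)  # toy stand-in

input_ids = torch.tensor([7, 0, 42, 0, 13, 0])  # 0 where embeds were provided
is_token_ids = torch.tensor([True, False, True, False, True, False])
inputs_embeds = torch.randn(num_tokens, hidden_size)  # pre-filled soft embeddings

with torch.no_grad():
    # Indices of positions that still hold raw token ids.
    token_ids_idx = is_token_ids.nonzero(as_tuple=False).squeeze(1)
    if token_ids_idx.numel() > 0:
        tokens_to_embeds = embedding(input_ids[token_ids_idx])
        # Overwrite only those rows of the shared embedding buffer.
        inputs_embeds[token_ids_idx] = tokens_to_embeds

print(token_ids_idx.tolist())  # [0, 2, 4]
print(inputs_embeds.shape)     # torch.Size([6, 8])
```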
logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None + logprobs_lists = ( + logprobs_tensors.tolists() if logprobs_tensors is not None else None + ) # Compute prompt logprobs if needed. prompt_logprobs_dict = self._get_prompt_logprobs_dict( @@ -2220,10 +2368,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Cache the sampled tokens on the GPU and avoid CPU sync. # These will be copied into input_ids in the next step # when preparing inputs. - self.input_batch.prev_sampled_token_ids = \ - sampled_token_ids - self.input_batch.prev_sampled_token_ids_invalid_indices = \ + self.input_batch.prev_sampled_token_ids = sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = ( invalid_req_indices_set + ) self.input_batch.prev_req_id_to_index = { req_id: i for i, req_id in enumerate(self.input_batch.req_ids) @@ -2238,8 +2386,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_ids = self.input_batch.req_ids for req_idx in range(num_sampled_tokens): if self.use_async_scheduling: - sampled_ids = [-1] if \ - req_idx not in invalid_req_indices_set else None + sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None else: sampled_ids = valid_sampled_token_ids[req_idx] if not sampled_ids: @@ -2250,7 +2397,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert end_idx <= self.max_model_len + 1, ( "Sampled token IDs exceed the max model length + 1. " f"Total number of tokens: {end_idx} > max_model_len + 1: " - f"{self.max_model_len + 1}") + f"{self.max_model_len + 1}" + ) n_tokens_cache = len(sampled_ids) @@ -2263,11 +2411,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if end_idx == self.max_model_len + 1: n_tokens_cache -= 1 - self.input_batch.token_ids_cpu[req_idx, start_idx:( - start_idx + n_tokens_cache)] = sampled_ids[:n_tokens_cache] - self.input_batch.is_token_ids[req_idx, - start_idx:(start_idx + - n_tokens_cache)] = True + self.input_batch.token_ids_cpu[ + req_idx, start_idx : (start_idx + n_tokens_cache) + ] = sampled_ids[:n_tokens_cache] + self.input_batch.is_token_ids[ + req_idx, start_idx : (start_idx + n_tokens_cache) + ] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx @@ -2312,7 +2461,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): """Helper method to call the model forward pass. This method can be overridden by subclasses for model execution. - Motivation: We can inspect only this method versus + Motivation: We can inspect only this method versus the whole execute_model, which has additional logic. Args: @@ -2349,18 +2498,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Return empty ModelRunnerOutput if no work to do. return EMPTY_MODEL_RUNNER_OUTPUT return self.kv_connector_no_forward( - scheduler_output, self.vllm_config) + scheduler_output, self.vllm_config + ) if self.cache_config.kv_sharing_fast_prefill: assert not self.input_batch.num_prompt_logprobs, ( "--kv-sharing-fast-prefill produces incorrect " "logprobs for prompt tokens, tokens, please disable " - "it when the requests need prompt logprobs") + "it when the requests need prompt logprobs" + ) # Prepare the decoder inputs. 
- (attn_metadata, logits_indices, spec_decode_metadata, - num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len, ubatch_slices, num_tokens_after_padding, - use_cascade_attn) = self._prepare_inputs(scheduler_output) + ( + attn_metadata, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens_np, + spec_decode_common_attn_metadata, + max_query_len, + ubatch_slices, + num_tokens_after_padding, + use_cascade_attn, + ) = self._prepare_inputs(scheduler_output) ( num_scheduled_tokens, @@ -2371,26 +2529,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): positions, intermediate_tensors, model_kwargs, - ) = self._preprocess(scheduler_output, intermediate_tensors, - ubatch_slices, num_tokens_after_padding) + ) = self._preprocess( + scheduler_output, + intermediate_tensors, + ubatch_slices, + num_tokens_after_padding, + ) - uniform_decode = (max_query_len - == self.uniform_decode_query_len) and ( - num_scheduled_tokens - == self.input_batch.num_reqs * max_query_len) - batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, - uniform_decode=uniform_decode) - cudagraph_runtime_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch(batch_descriptor, - use_cascade_attn) + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + num_scheduled_tokens == self.input_batch.num_reqs * max_query_len + ) + batch_descriptor = BatchDescriptor( + num_tokens=num_input_tokens, uniform_decode=uniform_decode + ) + cudagraph_runtime_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn) + ) # Set cudagraph mode to none if calc_kv_scales is true. if attn_metadata is not None: - metadata_list = (attn_metadata.values() if isinstance( - attn_metadata, dict) else [attn_metadata]) + metadata_list = ( + attn_metadata.values() + if isinstance(attn_metadata, dict) + else [attn_metadata] + ) if any( - getattr(m, 'enable_kv_scales_calculation', False) - for m in metadata_list): + getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list + ): cudagraph_runtime_mode = CUDAGraphMode.NONE # This is currently to get around the assert in the DPMetadata @@ -2400,7 +2565,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Run the model. # Use persistent buffers for CUDA graphs. - with (set_forward_context( + with ( + set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_input_tokens, @@ -2408,9 +2574,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, - ), record_function_or_nullcontext("Forward"), - self.maybe_get_kv_connector_output(scheduler_output) as - kv_connector_output): + ), + record_function_or_nullcontext("Forward"), + self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + ): model_output = self._model_forward( input_ids=input_ids, positions=positions, @@ -2438,8 +2605,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.is_pooling_model: # Return the pooling output. 
- output = self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np) + output = self._pool( + hidden_states, num_scheduled_tokens, num_scheduled_tokens_np + ) output.kv_connector_output = kv_connector_output return output @@ -2451,14 +2619,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if not get_pp_group().is_last_rank: all_gather_tensors = { - "residual": - not is_residual_scattered_for_sp( - self.vllm_config, num_input_tokens) + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_input_tokens + ) } get_pp_group().send_tensor_dict( hidden_states.tensors, all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors) + all_gather_tensors=all_gather_tensors, + ) logits = None else: sample_hidden_states = hidden_states[logits_indices] @@ -2468,16 +2637,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if logits is not None: model_output_broadcast_data["logits"] = logits.contiguous() - model_output_broadcast_data = get_pp_group( - ).broadcast_tensor_dict(model_output_broadcast_data, - src=len(get_pp_group().ranks) - 1) + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) assert model_output_broadcast_data is not None logits = model_output_broadcast_data["logits"] # Apply structured output bitmasks if present if scheduler_output.grammar_bitmask is not None: - apply_grammar_bitmask(scheduler_output, self.input_batch, - logits, self.device) + apply_grammar_bitmask( + scheduler_output, self.input_batch, logits, self.device + ) with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) @@ -2496,22 +2666,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): spec_decode_common_attn_metadata, ) - use_padded_batch_for_eagle = self.speculative_config and \ - self.speculative_config.use_eagle() and \ - not self.speculative_config.disable_padded_drafter_batch + use_padded_batch_for_eagle = ( + self.speculative_config + and self.speculative_config.use_eagle() + and not self.speculative_config.disable_padded_drafter_batch + ) effective_drafter_max_model_len = self.max_model_len if effective_drafter_max_model_len is None: effective_drafter_max_model_len = self.model_config.max_model_len - if (self.speculative_config - and self.speculative_config.draft_model_config is not None - and self.speculative_config.draft_model_config.max_model_len - is not None): + if ( + self.speculative_config + and self.speculative_config.draft_model_config is not None + and self.speculative_config.draft_model_config.max_model_len is not None + ): effective_drafter_max_model_len = ( - self.speculative_config.draft_model_config.max_model_len) + self.speculative_config.draft_model_config.max_model_len + ) input_fits_in_drafter = spec_decode_common_attn_metadata and ( - spec_decode_common_attn_metadata.max_seq_len + - self.speculative_config.num_speculative_tokens - <= effective_drafter_max_model_len) + spec_decode_common_attn_metadata.max_seq_len + + self.speculative_config.num_speculative_tokens + <= effective_drafter_max_model_len + ) if use_padded_batch_for_eagle and input_fits_in_drafter: # EAGLE speculative decoding can use the GPU sampled tokens # as inputs, and does not need to wait for bookkeeping to finish. 
@@ -2526,12 +2701,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_ids_output_copy, req_id_to_index_output_copy, invalid_req_indices, - ) = self._bookkeeping_sync(scheduler_output, sampler_output, - logits, hidden_states, - num_scheduled_tokens) + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + num_scheduled_tokens, + ) - if (self.speculative_config and not use_padded_batch_for_eagle - and input_fits_in_drafter): + if ( + self.speculative_config + and not use_padded_batch_for_eagle + and input_fits_in_drafter + ): # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) @@ -2587,10 +2769,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.drafter.propose( - sampled_token_ids, self.input_batch.req_ids, + sampled_token_ids, + self.input_batch.req_ids, self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, - self.input_batch.spec_decode_unsupported_reqs) + self.input_batch.spec_decode_unsupported_reqs, + ) elif self.speculative_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) @@ -2603,8 +2787,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): offset = 0 assert spec_decode_metadata is not None for num_draft, tokens in zip( - spec_decode_metadata.num_draft_tokens, - sampled_token_ids): + spec_decode_metadata.num_draft_tokens, sampled_token_ids + ): indices.append(offset + len(tokens) - 1) offset += num_draft + 1 indices = torch.tensor(indices, device=self.device) @@ -2621,29 +2805,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # When padded-batch is disabled, the sampled_token_ids should be # the cpu-side list[list[int]] of valid sampled tokens for each # request, with invalid requests having empty lists. - assert isinstance(sampled_token_ids, list), \ - "sampled_token_ids should be a python list when" \ + assert isinstance(sampled_token_ids, list), ( + "sampled_token_ids should be a python list when" "padded-batch is disabled." + ) next_token_ids = self.drafter.prepare_next_token_ids_cpu( - sampled_token_ids, self.requests, self.input_batch, - scheduler_output.num_scheduled_tokens) + sampled_token_ids, + self.requests, + self.input_batch, + scheduler_output.num_scheduled_tokens, + ) else: # When using padded-batch, the sampled_token_ids should be # the gpu tensor of sampled tokens for each request, of shape # (num_reqs, num_spec_tokens + 1) with rejected tokens having # value -1. - assert isinstance(sampled_token_ids, torch.Tensor), \ - "sampled_token_ids should be a torch.Tensor when" \ + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor when" "padded-batch is enabled." 
- next_token_ids, valid_sampled_tokens_count = \ + ) + next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( common_attn_metadata, sampled_token_ids, self.requests, self.input_batch, self.discard_request_indices.gpu, - self.num_discarded_requests + self.num_discarded_requests, ) + ) if spec_decode_metadata is None: token_indices_to_sample = None @@ -2653,32 +2843,34 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( - [h[:num_scheduled_tokens] for h in aux_hidden_states], - dim=-1) + [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1 + ) else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: if self.speculative_config.disable_padded_drafter_batch: token_indices_to_sample = None - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs( - common_attn_metadata, - sampled_token_ids, - spec_decode_metadata.num_draft_tokens) + common_attn_metadata, token_indices = self.drafter.prepare_inputs( + common_attn_metadata, + sampled_token_ids, + spec_decode_metadata.num_draft_tokens, + ) else: - common_attn_metadata, token_indices, \ - token_indices_to_sample =\ + common_attn_metadata, token_indices, token_indices_to_sample = ( self.drafter.prepare_inputs_padded( common_attn_metadata, spec_decode_metadata, - valid_sampled_tokens_count) + valid_sampled_tokens_count, + ) + ) target_token_ids = self.input_ids.gpu[token_indices] target_positions = self._get_positions(token_indices) if self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( - [h[token_indices] for h in aux_hidden_states], dim=-1) + [h[token_indices] for h in aux_hidden_states], dim=-1 + ) else: target_hidden_states = hidden_states[token_indices] @@ -2706,9 +2898,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, \ - f"Config `{config_name}` not supported. " \ + assert config_name in allowed_config_names, ( + f"Config `{config_name}` not supported. 
" f"Allowed configs: {allowed_config_names}" + ) config = getattr(self, config_name) new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) @@ -2721,26 +2914,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logger.info("Starting to load model %s...", self.model_config.model) if eep_scale_up: from vllm.distributed.parallel_state import get_ep_group - num_local_physical_experts = torch.empty(1, - dtype=torch.int32, - device="cpu") - torch.distributed.broadcast(num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0) + + num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") + torch.distributed.broadcast( + num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0 + ) num_local_physical_experts = int(num_local_physical_experts.item()) new_ep_size = get_ep_group().world_size - global_expert_load, old_global_expert_indices = ( - EplbState.recv_state()) + global_expert_load, old_global_expert_indices = EplbState.recv_state() num_logical_experts = global_expert_load.shape[1] self.parallel_config.eplb_config.num_redundant_experts = ( - num_local_physical_experts * new_ep_size - num_logical_experts) - assert old_global_expert_indices.shape[ - 1] % num_local_physical_experts == 0 - old_ep_size = old_global_expert_indices.shape[ - 1] // num_local_physical_experts + num_local_physical_experts * new_ep_size - num_logical_experts + ) + assert old_global_expert_indices.shape[1] % num_local_physical_experts == 0 + old_ep_size = ( + old_global_expert_indices.shape[1] // num_local_physical_experts + ) rank_mapping = { - old_ep_rank: old_ep_rank - for old_ep_rank in range(old_ep_size) + old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size) } else: global_expert_load = None @@ -2752,36 +2943,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): model_loader = get_model_loader(self.load_config) logger.info("Loading model from scratch...") self.model = model_loader.load_model( - vllm_config=self.vllm_config, model_config=self.model_config) + vllm_config=self.vllm_config, model_config=self.model_config + ) if self.lora_config: - self.model = self.load_lora_model(self.model, self.vllm_config, - self.device) + self.model = self.load_lora_model( + self.model, self.vllm_config, self.device + ) if hasattr(self, "drafter"): logger.info("Loading drafter model...") self.drafter.load_model(self.model) if self.use_aux_hidden_state_outputs: if supports_eagle3(self.model): self.model.set_aux_hidden_state_layers( - self.model.get_eagle3_aux_hidden_state_layers()) + self.model.get_eagle3_aux_hidden_state_layers() + ) else: raise RuntimeError( "Model does not support EAGLE3 interface but " - "aux_hidden_state_outputs was requested") + "aux_hidden_state_outputs was requested" + ) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - logger.info("Model loading took %.4f GiB and %.6f seconds", - self.model_memory_usage / GiB_bytes, - time_after_load - time_before_load) + logger.info( + "Model loading took %.4f GiB and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load, + ) prepare_communication_buffer_for_model(self.model) - self.is_multimodal_pruning_enabled = (supports_multimodal_pruning( - self.model) and self.model_config.multimodal_config. 
- is_multimodal_pruning_enabled()) + self.is_multimodal_pruning_enabled = ( + supports_multimodal_pruning(self.model) + and self.model_config.multimodal_config.is_multimodal_pruning_enabled() + ) - if is_mixture_of_experts( - self.model) and self.parallel_config.enable_eplb: - logger.info("EPLB is enabled for model %s.", - self.model_config.model) + if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: + logger.info("EPLB is enabled for model %s.", self.model_config.model) self.eplb_state = EplbState.build( self.model, self.device, @@ -2792,11 +2988,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) if ( - self.vllm_config.compilation_config.level == \ - CompilationLevel.DYNAMO_AS_IS and supports_dynamo() + self.vllm_config.compilation_config.level == CompilationLevel.DYNAMO_AS_IS + and supports_dynamo() ): - backend = self.vllm_config.compilation_config.init_backend( - self.vllm_config) + backend = self.vllm_config.compilation_config.init_backend(self.vllm_config) compilation_counter.dynamo_as_is_count += 1 self.model.compile(fullgraph=True, backend=backend) return @@ -2804,26 +2999,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # CudagraphWraper and CudagraphDispatcher of vllm. # wrap the model with full cudagraph wrapper if needed. - if self.compilation_config.cudagraph_mode.has_full_cudagraphs() \ - and not self.parallel_config.enable_dbo: - self.model = CUDAGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + if ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + and not self.parallel_config.enable_dbo + ): + self.model = CUDAGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) elif self.parallel_config.enable_dbo: if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - self.model = UBatchWrapper(self.model, self.vllm_config, - CUDAGraphMode.FULL, self.device) + self.model = UBatchWrapper( + self.model, self.vllm_config, CUDAGraphMode.FULL, self.device + ) else: - self.model = UBatchWrapper(self.model, self.vllm_config, - CUDAGraphMode.NONE, self.device) + self.model = UBatchWrapper( + self.model, self.vllm_config, CUDAGraphMode.NONE, self.device + ) def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, \ + assert getattr(self, "model", None) is not None, ( "Cannot reload weights before model is loaded." + ) model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") - model_loader.load_weights(self.get_model(), - model_config=self.model_config) + model_loader.load_weights(self.get_model(), model_config=self.model_config) def save_tensorized_model( self, @@ -2861,7 +3060,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_prompt_tokens = len(request.prompt_token_ids) prompt_token_ids = torch.tensor(request.prompt_token_ids).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) # Set up target LogprobsTensors object. logprobs_tensors = in_progress_dict.get(req_id) @@ -2869,7 +3069,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Create empty logprobs CPU tensors for the entire prompt. # If chunked, we'll copy in slice by slice. logprobs_tensors = LogprobsTensors.empty_cpu( - num_prompt_tokens - 1, num_prompt_logprobs + 1) + num_prompt_tokens - 1, num_prompt_logprobs + 1 + ) in_progress_dict[req_id] = logprobs_tensors # Determine number of logits to retrieve. 
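The `_get_prompt_logprobs_dict` hunk that continues below computes logits for a chunk of prompt positions, takes a log-softmax, and gathers the log-prob and rank of each "target" token (the token at prompt index i + 1) before copying the slice into pre-allocated CPU tensors. A CPU sketch of just the gather step, with toy logits and a simplified rank definition standing in for `self.sampler.gather_logprobs`:

```python
import torch

torch.manual_seed(0)
num_logits, vocab_size = 4, 10
logits = torch.randn(num_logits, vocab_size)  # one row per prompt position
tgt_token_ids = torch.tensor([3, 1, 7, 0])    # the "next" token per position

logprobs = torch.log_softmax(logits, dim=-1)
# Log-prob of the target token at each position.
tgt_logprobs = logprobs.gather(1, tgt_token_ids.unsqueeze(1)).squeeze(1)
# One simple rank definition: 1 + number of tokens with a strictly higher log-prob.
ranks = (logprobs > tgt_logprobs.unsqueeze(1)).sum(dim=-1) + 1

print(tgt_logprobs.shape)  # torch.Size([4])
print(ranks.tolist())      # four 1-based ranks, one per prompt position
```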
@@ -2899,27 +3100,29 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # then there is prompt logprob generated for each index. req_idx = self.input_batch.req_id_to_index[req_id] offset = self.query_start_loc.np[req_idx].item() - prompt_hidden_states = hidden_states[offset:offset + num_logits] + prompt_hidden_states = hidden_states[offset : offset + num_logits] logits = self.model.compute_logits(prompt_hidden_states) # Get the "target" tokens for each index. For prompt at index i, # the token at prompt index i+1 is the "sampled" token we want # to gather the logprob for. - tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] + tgt_token_ids = prompt_token_ids[start_tok : start_tok + num_logits] # Compute prompt logprobs. logprobs = self.sampler.compute_logprobs(logits) token_ids, logprobs, ranks = self.sampler.gather_logprobs( - logprobs, num_prompt_logprobs, tgt_token_ids) + logprobs, num_prompt_logprobs, tgt_token_ids + ) # Transfer GPU->CPU async. chunk_slice = slice(start_idx, start_idx + num_logits) logprobs_tensors.logprob_token_ids[chunk_slice].copy_( - token_ids, non_blocking=True) - logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, - non_blocking=True) + token_ids, non_blocking=True + ) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True) logprobs_tensors.selected_token_ranks[chunk_slice].copy_( - ranks, non_blocking=True) + ranks, non_blocking=True + ) # Remove requests that have completed prefill from the batch # num_prompt_logprobs_dict. @@ -2947,8 +3150,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_index = self.input_batch.req_id_to_index[req_id] num_nans_in_logits[req_id] = ( int(num_nans_for_index[req_index]) - if num_nans_for_index is not None - and req_index < logits.shape[0] else 0) + if num_nans_for_index is not None and req_index < logits.shape[0] + else 0 + ) return num_nans_in_logits except IndexError: return {} @@ -2974,11 +3178,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.input_ids.gpu, low=0, high=self.model_config.get_vocab_size(), - dtype=input_ids.dtype) + dtype=input_ids.dtype, + ) logger.debug_once("Randomizing dummy data for DP Rank") - input_ids.copy_(rand_input_ids()[:input_ids.size(0)], - non_blocking=True) + input_ids.copy_(rand_input_ids()[: input_ids.size(0)], non_blocking=True) yield input_ids.fill_(0) @@ -3003,13 +3207,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dummy_mm_items = [dummy_mm_item] * max_items_per_batch model = cast(SupportsMultiModal, self.model) - return next(mm_kwargs_group - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - dummy_mm_items, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - )) + return next( + mm_kwargs_group + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + dummy_mm_items, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ) + ) @torch.inference_mode() def _dummy_run( @@ -3046,8 +3252,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): (1 token) and prefill (multiple tokens) requests. 
remove_lora: If False, dummy LoRAs are not destroyed after the run """ - assert cudagraph_runtime_mode is None or \ - cudagraph_runtime_mode.valid_runtime_modes() + assert ( + cudagraph_runtime_mode is None + or cudagraph_runtime_mode.valid_runtime_modes() + ) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). This means that we are using @@ -3062,8 +3270,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # When setting max_query_len = 1, we switch to and capture the optimized # routine of FA2 for pure decode, i.e., Flashdecode + an optimization # for GQA/MQA. - max_query_len = self.uniform_decode_query_len if uniform_decode else \ - num_tokens + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -3079,9 +3286,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_reqs = num_decode_tokens + 1 # Create decode requests (1 token each) followed by prefill request - num_scheduled_tokens_list = [1] * num_decode_tokens + [ - num_prefill_tokens - ] + num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens] # Note: Overriding max_query_len to be the prefill tokens max_query_len = num_prefill_tokens elif uniform_decode: @@ -3098,8 +3303,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs - num_scheduled_tokens = np.array(num_scheduled_tokens_list, - dtype=np.int32) + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) ubatch_slices = None @@ -3153,56 +3357,61 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() - cum_num_tokens, _ = self._get_cumsum_and_arange( - num_scheduled_tokens) - self.query_start_loc.np[1:num_reqs + 1] = cum_num_tokens + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): + self.kv_cache_config.kv_cache_groups + ): common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + - 1], + query_start_loc=self.query_start_loc.gpu[: num_reqs + 1], + query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1], seq_lens=self.seq_lens.gpu[:num_reqs], seq_lens_cpu=self.seq_lens.cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch. - num_computed_tokens_cpu_tensor[:num_reqs], + num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs + ], num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=max_query_len, max_seq_len=self.max_model_len, - block_table_tensor=self.input_batch. 
- block_table[kv_cache_group_id].get_device_tensor(num_reqs), + block_table_tensor=self.input_batch.block_table[ + kv_cache_group_id + ].get_device_tensor(num_reqs), slot_mapping=self.input_batch.block_table[ - kv_cache_group_id].slot_mapping.gpu[:num_tokens], - causal=True) + kv_cache_group_id + ].slot_mapping.gpu[:num_tokens], + causal=True, + ) for attn_group in self.attn_groups[kv_cache_group_id]: if ubatch_slices is not None: common_attn_metadata_list = split_attn_metadata( - ubatch_slices, common_attn_metadata) + ubatch_slices, common_attn_metadata + ) for ubid, common_attn_metadata in enumerate( - common_attn_metadata_list): + common_attn_metadata_list + ): assert common_attn_metadata.max_query_len == 1 - attn_metadata_i = (attn_group\ - .get_metadata_builder(ubatch_id=ubid)\ - .build_for_cudagraph_capture(common_attn_metadata)) + attn_metadata_i = attn_group.get_metadata_builder( + ubatch_id=ubid + ).build_for_cudagraph_capture(common_attn_metadata) for layer_name in attn_group.layer_names: assert type(attn_metadata) is list - attn_metadata[ubid][ - layer_name] = attn_metadata_i + attn_metadata[ubid][layer_name] = attn_metadata_i else: assert type(attn_metadata) is dict - attn_metadata_i = attn_group.get_metadata_builder()\ - .build_for_cudagraph_capture(common_attn_metadata) + attn_metadata_i = attn_group.get_metadata_builder().build_for_cudagraph_capture( + common_attn_metadata + ) for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i - with self.maybe_dummy_run_with_lora(self.lora_config, - num_scheduled_tokens, remove_lora): + with self.maybe_dummy_run_with_lora( + self.lora_config, num_scheduled_tokens, remove_lora + ): model_kwargs = self._init_model_kwargs(num_tokens) - if (self.supports_mm_inputs - and not self.model_config.is_encoder_decoder): + if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens] model_kwargs = { @@ -3230,23 +3439,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.model.make_empty_intermediate_tensors( batch_size=self.max_num_tokens, dtype=self.model_config.dtype, - device=self.device)) + device=self.device, + ) + ) intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_tokens, None, False) + num_tokens, None, False + ) # filter out the valid batch descriptor - _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( - BatchDescriptor(num_tokens=num_tokens_after_padding, - uniform_decode=uniform_decode)) \ - if not is_profile else (CUDAGraphMode.NONE, None) + _cg_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch( + BatchDescriptor( + num_tokens=num_tokens_after_padding, + uniform_decode=uniform_decode, + ) + ) + if not is_profile + else (CUDAGraphMode.NONE, None) + ) if cudagraph_runtime_mode is not None: # we allow forcing NONE when the dispatcher disagrees to support # warm ups for cudagraph capture - assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \ - cudagraph_runtime_mode == _cg_mode, ( + assert ( + cudagraph_runtime_mode == CUDAGraphMode.NONE + or cudagraph_runtime_mode == _cg_mode + ), ( f"Cudagraph runtime mode mismatch at dummy_run. " - f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") + f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}." 
+ ) else: cudagraph_runtime_mode = _cg_mode @@ -3258,14 +3479,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if num_tokens_across_dp is not None: num_tokens_across_dp[:] = num_tokens_after_padding - with self.maybe_randomize_inputs(input_ids), set_forward_context( + with ( + self.maybe_randomize_inputs(input_ids), + set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens_after_padding, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, - ubatch_slices=ubatch_slices): + ubatch_slices=ubatch_slices, + ), + ): outputs = self.model( input_ids=input_ids, positions=positions, @@ -3309,8 +3534,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logits = self.model.compute_logits(hidden_states) num_reqs = logits.size(0) - dummy_tensors = lambda v: torch.full( - (num_reqs, ), v, device=self.device) + dummy_tensors = lambda v: torch.full((num_reqs,), v, device=self.device) dummy_metadata = SamplingMetadata( temperature=dummy_tensors(0.5), @@ -3331,37 +3555,39 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logitsprocs=LogitsProcessors(), ) try: - sampler_output = self.sampler(logits=logits, - sampling_metadata=dummy_metadata) + sampler_output = self.sampler( + logits=logits, sampling_metadata=dummy_metadata + ) except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up sampler with " f"{num_reqs} dummy requests. Please try lowering " "`max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e + "initializing the engine." + ) from e else: raise e if self.speculative_config: draft_token_ids = [[0] for _ in range(num_reqs)] dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids, self.device) + draft_token_ids, self.device + ) num_tokens = sum(len(ids) for ids in draft_token_ids) # draft_probs = torch.randn( # num_tokens, logits.shape[-1], device=self.device, # dtype=logits.dtype) draft_probs = None - target_logits = torch.randn(num_tokens, - logits.shape[-1], - device=self.device, - dtype=logits.dtype) + target_logits = torch.randn( + num_tokens, logits.shape[-1], device=self.device, dtype=logits.dtype + ) # NOTE(woosuk): Here, we should use int32 because the sampler uses # int32 for bonus_token_ids. If the dtype mismatches, re-compilation # will occur at runtime. 
- bonus_token_ids = torch.zeros(num_reqs, - device=self.device, - dtype=torch.int32) + bonus_token_ids = torch.zeros( + num_reqs, device=self.device, dtype=torch.int32 + ) self.rejection_sampler( dummy_spec_decode_metadata, draft_probs, @@ -3391,9 +3617,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_scheduled_tokens_list, device="cpu", ) - dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), - dtype=torch.int32, - device=self.device) + dummy_token_ids = torch.zeros( + (num_reqs, req_num_tokens), dtype=torch.int32, device=self.device + ) model = cast(VllmModelForPooling, self.get_model()) dummy_pooling_params = PoolingParams(task=task) @@ -3407,19 +3633,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): pooling_params=[dummy_pooling_params] * num_reqs, ) - dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, - device=hidden_states.device) + dummy_metadata.build_pooling_cursor( + num_scheduled_tokens_list, device=hidden_states.device + ) try: - return model.pooler(hidden_states=hidden_states, - pooling_metadata=dummy_metadata) + return model.pooler( + hidden_states=hidden_states, pooling_metadata=dummy_metadata + ) except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up pooler " f"({task=}) with {num_reqs} dummy requests. Please try " "lowering `max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e + "initializing the engine." + ) from e else: raise e @@ -3445,7 +3674,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.model_config.multimodal_config.skip_mm_profiling: logger.info( "Skipping memory profiling for multimodal encoder and " - "encoder cache.") + "encoder cache." + ) else: mm_budget = self.mm_budget assert mm_budget is not None @@ -3455,8 +3685,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # modality with the max possible input tokens even when # it supports multiple. dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget \ - .max_items_per_batch_by_modality[dummy_modality] + max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ + dummy_modality + ] logger.info( "Encoder cache will be initialized with a budget of " @@ -3474,9 +3705,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) # Run multimodal encoder. - dummy_encoder_outputs = \ - self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs + ) sanity_check_mm_encoder_outputs( dummy_encoder_outputs, @@ -3493,7 +3724,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): expanded_outputs = [] for output in dummy_encoder_outputs: expanded = output.new_zeros( - (encoder_budget, encoder_output_shape[-1])) + (encoder_budget, encoder_output_shape[-1]) + ) num_tokens = output.shape[0] expanded[:num_tokens].copy_(output) expanded_outputs.append(expanded) @@ -3501,12 +3733,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dummy_encoder_outputs = expanded_outputs # Cache the dummy encoder outputs. 
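The sampler and pooler warm-up hunks above both follow the same pattern: catch `RuntimeError`, test for "out of memory", and re-raise a clearer, actionable message chained with `from e`. A minimal sketch under invented names (`fake_pooler` simulates the failure):

```python
def run_or_explain(fn, num_reqs: int):
    # Re-raise CUDA OOM errors with an actionable hint, keeping the original
    # exception attached via __cause__; everything else propagates unchanged.
    try:
        return fn()
    except RuntimeError as e:
        if "out of memory" in str(e):
            raise RuntimeError(
                f"CUDA out of memory while warming up with {num_reqs} dummy "
                "requests. Try lowering `max_num_seqs` or "
                "`gpu_memory_utilization`."
            ) from e
        raise


def fake_pooler():
    raise RuntimeError("CUDA error: out of memory")


try:
    run_or_explain(fake_pooler, num_reqs=8)
except RuntimeError as e:
    print(e, "| caused by:", e.__cause__)
```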
- self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Add `is_profile` here to pre-allocate communication buffers - hidden_states, last_hidden_states \ - = self._dummy_run(self.max_num_tokens, is_profile=True) + hidden_states, last_hidden_states = self._dummy_run( + self.max_num_tokens, is_profile=True + ) if get_pp_group().is_last_rank: if self.is_pooling_model: output = self._dummy_pooler_run(hidden_states) @@ -3523,7 +3755,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( "Skipping CUDA graph capture. To turn on CUDA graph capture, " - "ensure `cudagraph_mode` was not manually set to `NONE`") + "ensure `cudagraph_mode` was not manually set to `NONE`" + ) return 0 else: self.initialize_cudagraph_capture() @@ -3563,24 +3796,29 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self._capture_cudagraphs( compilation_cases, cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=False) + uniform_decode=False, + ) # Capture full cudagraph for uniform decode batches if we # don't already have full mixed prefill-decode cudagraphs. - if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ - cudagraph_mode.separate_routine(): - max_num_tokens = self.scheduler_config.max_num_seqs * \ - self.uniform_decode_query_len + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and cudagraph_mode.separate_routine() + ): + max_num_tokens = ( + self.scheduler_config.max_num_seqs * self.uniform_decode_query_len + ) decode_cudagraph_batch_sizes = [ - x for x in self.cudagraph_batch_sizes if - x <= max_num_tokens and x >= self.uniform_decode_query_len + x + for x in self.cudagraph_batch_sizes + if x <= max_num_tokens and x >= self.uniform_decode_query_len ] - compilation_cases_decode = list( - reversed(decode_cudagraph_batch_sizes)) + compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes)) self._capture_cudagraphs( compilation_cases=compilation_cases_decode, cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True) + uniform_decode=True, + ) torch.cuda.synchronize() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -3596,16 +3834,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory # This usually takes 5~20 seconds. 
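The decode-only capture hunk above filters the configured capture sizes down to uniform-decode batch sizes before capturing them largest-first. With assumed numbers (the real values come from `scheduler_config` and the compilation config), the filter reduces to:

```python
# Assumed example values; not taken from any particular vLLM config.
max_num_seqs = 8
uniform_decode_query_len = 2            # >1 when spec-decode adds draft tokens
cudagraph_batch_sizes = [1, 2, 4, 8, 16, 32, 64]

max_num_tokens = max_num_seqs * uniform_decode_query_len    # 16
decode_sizes = [
    x
    for x in cudagraph_batch_sizes
    if uniform_decode_query_len <= x <= max_num_tokens
]
compilation_cases_decode = list(reversed(decode_sizes))     # largest first
print(compilation_cases_decode)                             # [16, 8, 4, 2]
```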
- logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, cuda_graph_size / (1 << 30)) + logger.info( + "Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, + cuda_graph_size / (1 << 30), + ) return cuda_graph_size - def _capture_cudagraphs(self, compilation_cases: list[int], - cudagraph_runtime_mode: CUDAGraphMode, - uniform_decode: bool): - assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ - cudagraph_runtime_mode.valid_runtime_modes(), \ - f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" + def _capture_cudagraphs( + self, + compilation_cases: list[int], + cudagraph_runtime_mode: CUDAGraphMode, + uniform_decode: bool, + ): + assert ( + cudagraph_runtime_mode != CUDAGraphMode.NONE + and cudagraph_runtime_mode.valid_runtime_modes() + ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" # Only rank 0 should print progress bar during capture if is_global_first_rank(): @@ -3614,7 +3859,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): disable=not self.load_config.use_tqdm_on_load, desc="Capturing CUDA graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", - cudagraph_runtime_mode.name)) + cudagraph_runtime_mode.name, + ), + ) # We skip EPLB here since we don't want to record dummy metrics for num_tokens in compilation_cases: @@ -3622,14 +3869,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # cudagraph, a uniform decode batch, and the number of tokens # is above the threshold. Otherwise we just capture a non-ubatched # version of the graph - allow_microbatching = self.parallel_config.enable_dbo \ - and cudagraph_runtime_mode == CUDAGraphMode.FULL \ - and uniform_decode \ + allow_microbatching = ( + self.parallel_config.enable_dbo + and cudagraph_runtime_mode == CUDAGraphMode.FULL + and uniform_decode and check_ubatch_thresholds( config=self.vllm_config.parallel_config, num_tokens=num_tokens, uniform_decode=uniform_decode, ) + ) for _ in range(self.compilation_config.cudagraph_num_of_warmups): # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. @@ -3637,29 +3886,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # if we want to warm up attention or not. This is # different from the case where `FULL` implies capture # attention while `PIECEWISE` implies no attention. 
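The comment above explains the warm-up passes that precede each capture: warm-ups run eagerly (`CUDAGraphMode.NONE`), forcing attention only when the capture mode is `FULL`, and the subsequent pass in the requested mode is the one that records the graph. A control-flow sketch with `_dummy_run` replaced by a print (the names mirror the diff, the bodies are invented):

```python
NONE, PIECEWISE, FULL = "NONE", "PIECEWISE", "FULL"


def dummy_run(num_tokens, mode, force_attention=False):
    print(f"dummy_run(num_tokens={num_tokens}, mode={mode}, "
          f"force_attention={force_attention})")


def capture_cudagraphs(compilation_cases, runtime_mode, num_warmups=1):
    assert runtime_mode in (PIECEWISE, FULL), "NONE is never a capture mode"
    for num_tokens in compilation_cases:
        for _ in range(num_warmups):
            # Warm up eagerly; only FULL capture also needs attention warmed.
            dummy_run(num_tokens, NONE, force_attention=runtime_mode == FULL)
        # The pass under the requested mode records the CUDA graph.
        dummy_run(num_tokens, runtime_mode)


capture_cudagraphs([16, 8, 4, 2], FULL)
```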
- force_attention = ( - cudagraph_runtime_mode == CUDAGraphMode.FULL) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, - uniform_decode=uniform_decode, - allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=uniform_decode, - allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False) + force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL + self._dummy_run( + num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + ) + self._dummy_run( + num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + ) self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. """ - assert len(self.attn_groups) == 0, \ - "Attention backends are already initialized" + assert len(self.attn_groups) == 0, "Attention backends are already initialized" class AttentionGroupKey(NamedTuple): attn_backend: type[AttentionBackend] @@ -3669,8 +3920,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_cache_group_spec: KVCacheGroupSpec, ) -> dict[AttentionGroupKey, list[str]]: layers = get_layers_from_vllm_config( - self.vllm_config, AttentionLayerBase, - kv_cache_group_spec.layer_names) + self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names + ) attn_backends = {} attn_backend_layers = defaultdict(list) # Dedupe based on full class name; this is a bit safer than @@ -3690,23 +3941,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): full_cls_name = attn_backend.full_cls_name() layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): - layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[ - layer_name] + layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name] key = (full_cls_name, layer_kv_cache_spec) - attn_backends[key] = AttentionGroupKey(attn_backend, - layer_kv_cache_spec) + attn_backends[key] = AttentionGroupKey( + attn_backend, layer_kv_cache_spec + ) attn_backend_layers[key].append(layer_name) - return { - attn_backends[k]: v - for k, v in attn_backend_layers.items() - } + return {attn_backends[k]: v for k, v in attn_backend_layers.items()} def create_attn_groups( attn_backends_map: dict[AttentionGroupKey, list[str]], ) -> list[AttentionGroup]: attn_groups: list[AttentionGroup] = [] - for (attn_backend, - kv_cache_spec), layer_names in attn_backends_map.items(): + for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): attn_group = AttentionGroup.create_with_metadata_builders( attn_backend, layer_names, @@ -3714,7 +3961,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.vllm_config, self.device, num_metadata_builders=1 - if not self.parallel_config.enable_dbo else 2, + if not self.parallel_config.enable_dbo + else 2, ) attn_groups.append(attn_group) @@ -3729,7 +3977,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def initialize_cudagraph_capture(self) -> None: """ - Resolve the 
cudagraph_mode when there are multiple attention + Resolve the cudagraph_mode when there are multiple attention backends with potential conflicting CUDA graph support. Then initialize the cudagraph_dispatcher based on the resolved cudagraph_mode. @@ -3745,81 +3993,110 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Flexible resolve the cudagraph mode cudagraph_mode = self.compilation_config.cudagraph_mode # check cudagraph for mixed batch is supported - if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \ - and min_cg_support != AttentionCGSupport.ALWAYS: - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " - f"{min_cg_support})") + if ( + cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL + and min_cg_support != AttentionCGSupport.ALWAYS + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_builder_name} backend (support: " + f"{min_cg_support})" + ) if min_cg_support == AttentionCGSupport.NEVER: # if not supported any full cudagraphs, just raise it. - msg += "; please try cudagraph_mode=PIECEWISE, and "\ + msg += ( + "; please try cudagraph_mode=PIECEWISE, and " "make sure compilation level is piecewise" + ) raise ValueError(msg) # attempt to resolve the full cudagraph related mode if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_AND_PIECEWISE + ) else: msg += "; setting cudagraph_mode=FULL_DECODE_ONLY" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_DECODE_ONLY + ) logger.warning(msg) # check that if we are doing decode full-cudagraphs it is supported - if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and min_cg_support == AttentionCGSupport.NEVER): - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " - f"{min_cg_support})") - if (self.compilation_config.level == CompilationLevel.PIECEWISE and - (self.compilation_config.splitting_ops_contain_attention() - or self.compilation_config.use_inductor_graph_partition)): - msg += "; setting cudagraph_mode=PIECEWISE because "\ + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and min_cg_support == AttentionCGSupport.NEVER + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_builder_name} backend (support: " + f"{min_cg_support})" + ) + if self.compilation_config.level == CompilationLevel.PIECEWISE and ( + self.compilation_config.splitting_ops_contain_attention() + or self.compilation_config.use_inductor_graph_partition + ): + msg += ( + "; setting cudagraph_mode=PIECEWISE because " "attention is compiled piecewise" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + ) + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.PIECEWISE + ) else: - msg += "; setting cudagraph_mode=NONE because "\ + msg += ( + "; setting cudagraph_mode=NONE because " "attention is not compiled piecewise" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + ) + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.NONE + ) logger.warning(msg) # check that if we are doing spec-decode + decode full-cudagraphs it is # supported - if (cudagraph_mode.decode_mode() 
== CUDAGraphMode.FULL - and self.uniform_decode_query_len > 1 and min_cg_support.value - < AttentionCGSupport.UNIFORM_BATCH.value): - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported" - f" with spec-decode for attention backend " - f"{min_cg_builder_name} (support: {min_cg_support})") + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and self.uniform_decode_query_len > 1 + and min_cg_support.value < AttentionCGSupport.UNIFORM_BATCH.value + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported" + f" with spec-decode for attention backend " + f"{min_cg_builder_name} (support: {min_cg_support})" + ) if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=PIECEWISE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.PIECEWISE + ) else: msg += "; setting cudagraph_mode=NONE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.NONE + ) logger.warning(msg) # double check that we can support full cudagraph if they are requested # even after automatic downgrades - if cudagraph_mode.has_full_cudagraphs() \ - and min_cg_support == AttentionCGSupport.NEVER: - raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not " - f"supported with {min_cg_builder_name} backend (" - f"support:{min_cg_support}) " - "; please try cudagraph_mode=PIECEWISE, " - "and make sure compilation level is piecewise") + if ( + cudagraph_mode.has_full_cudagraphs() + and min_cg_support == AttentionCGSupport.NEVER + ): + raise ValueError( + f"CUDAGraphMode.{cudagraph_mode.name} is not " + f"supported with {min_cg_builder_name} backend (" + f"support:{min_cg_support}) " + "; please try cudagraph_mode=PIECEWISE, " + "and make sure compilation level is piecewise" + ) # Trigger cudagraph dispatching keys initialization here (after # initializing attn backends). self.cudagraph_dispatcher.initialize_cudagraph_keys( - self.compilation_config.cudagraph_mode, - self.uniform_decode_query_len) + self.compilation_config.cudagraph_mode, self.uniform_decode_query_len + ) def calculate_reorder_batch_threshold(self) -> None: """ @@ -3831,22 +4108,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # check that if any backends reorder batches; that the reordering # is compatible (e.g., decode threshold is the same) - reorder_batch_threshold_i = ( - attn_metadata_builder_i.reorder_batch_threshold) + reorder_batch_threshold_i = attn_metadata_builder_i.reorder_batch_threshold if reorder_batch_threshold_i is not None: if self.reorder_batch_threshold is not None: - if reorder_batch_threshold_i != \ - self.reorder_batch_threshold: + if reorder_batch_threshold_i != self.reorder_batch_threshold: raise ValueError( f"Attention backend reorders decodes with " f"threshold {reorder_batch_threshold_i} but other " f"backend uses threshold " - f"{self.reorder_batch_threshold}") + f"{self.reorder_batch_threshold}" + ) else: self.reorder_batch_threshold = reorder_batch_threshold_i - def may_reinitialize_input_batch(self, - kv_cache_config: KVCacheConfig) -> None: + def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: """ Re-initialize the input batch if the block sizes are different from `[self.cache_config.block_size]`. 
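The resolution logic above downgrades `cudagraph_mode` step by step based on what the least-capable attention backend supports. A condensed sketch of only the first branch (mixed prefill-decode batches); `CGSupport` is a stand-in for `AttentionCGSupport`, and the decode-only and spec-decode checks that follow in the diff are omitted:

```python
from enum import Enum


class CGSupport(Enum):
    NEVER = 0
    UNIFORM_BATCH = 1
    ALWAYS = 2


def resolve_mixed_mode(mode: str, support: CGSupport, piecewise_attn: bool) -> str:
    # Simplified: treat `mode` as the requested mixed-batch mode.
    if mode != "FULL" or support is CGSupport.ALWAYS:
        return mode
    if support is CGSupport.NEVER:
        raise ValueError("full cudagraphs unsupported; try cudagraph_mode=PIECEWISE")
    # Backend supports some full graphs: keep FULL where possible and fall
    # back to piecewise (or decode-only full graphs) for the rest.
    return "FULL_AND_PIECEWISE" if piecewise_attn else "FULL_DECODE_ONLY"


print(resolve_mixed_mode("FULL", CGSupport.UNIFORM_BATCH, piecewise_attn=True))
```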
This usually happens when there @@ -3863,7 +4138,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 - "for more details.") + "for more details." + ) self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=max(self.max_model_len, self.max_encoder_len), @@ -3877,11 +4153,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): is_pooling_model=self.is_pooling_model, num_speculative_tokens=( self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config else 0), + if self.vllm_config.speculative_config + else 0 + ), ) def _allocate_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + self, kv_cache_config: KVCacheConfig + ) -> dict[str, torch.Tensor]: """ Initializes the KV cache buffer with the correct size. The buffer needs to be reshaped to the desired shape before being used by the models. @@ -3891,12 +4170,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): Returns: dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. - """ + """ kv_cache_raw_tensors: dict[str, torch.Tensor] = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - tensor = torch.zeros(kv_cache_tensor.size, - dtype=torch.int8, - device=self.device) + tensor = torch.zeros( + kv_cache_tensor.size, dtype=torch.int8, device=self.device + ) for layer_name in kv_cache_tensor.shared_by: kv_cache_raw_tensors[layer_name] = tensor @@ -3906,8 +4185,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if layer_name in self.runner_only_attn_layers: continue layer_names.add(layer_name) - assert layer_names == set(kv_cache_raw_tensors.keys( - )), "Some layers are not correctly initialized" + assert layer_names == set(kv_cache_raw_tensors.keys()), ( + "Some layers are not correctly initialized" + ) return kv_cache_raw_tensors def _attn_group_iterator(self) -> Iterator[AttentionGroup]: @@ -3945,8 +4225,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): continue raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = (raw_tensor.numel() // - kv_cache_spec.page_size_bytes) + num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes if isinstance(kv_cache_spec, AttentionSpec): has_attn = True kv_cache_shape = attn_backend.get_kv_cache_shape( @@ -3954,41 +4233,43 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, - cache_dtype_str=self.cache_config.cache_dtype) + cache_dtype_str=self.cache_config.cache_dtype, + ) dtype = kv_cache_spec.dtype try: - kv_cache_stride_order = \ - attn_backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len( - kv_cache_shape) + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple( - range(len(kv_cache_shape))) + kv_cache_stride_order = tuple(range(len(kv_cache_shape))) # The allocation respects the backend-defined stride order # to ensure the semantic remains consistent for each # backend. 
We first obtain the generic kv cache shape and # then permute it according to the stride order which could # result in a non-contiguous tensor. - kv_cache_shape = tuple(kv_cache_shape[i] - for i in kv_cache_stride_order) + kv_cache_shape = tuple( + kv_cache_shape[i] for i in kv_cache_stride_order + ) # Maintain original KV shape view. inv_order = [ kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - kv_caches[layer_name] = kv_cache_raw_tensors[ - layer_name].view(dtype).view(kv_cache_shape).permute( - *inv_order) + kv_caches[layer_name] = ( + kv_cache_raw_tensors[layer_name] + .view(dtype) + .view(kv_cache_shape) + .permute(*inv_order) + ) elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] state_tensors = [] storage_offset_bytes = 0 - for (shape, dtype) in zip(kv_cache_spec.shapes, - kv_cache_spec.dtypes): + for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes): dtype_size = get_dtype_size(dtype) num_element_per_page = ( - kv_cache_spec.page_size_bytes // dtype_size) + kv_cache_spec.page_size_bytes // dtype_size + ) target_shape = (num_blocks, *shape) stride = torch.empty(target_shape).stride() target_stride = (num_element_per_page, *stride[1:]) @@ -4012,7 +4293,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return kv_caches def _update_hybrid_attention_mamba_layout( - self, kv_caches: dict[str, torch.Tensor]) -> None: + self, kv_caches: dict[str, torch.Tensor] + ) -> None: """ Update the layout of attention layers from (2, num_blocks, ...) to (num_blocks, 2, ...). @@ -4025,19 +4307,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_cache_spec = group.kv_cache_spec for layer_name in group.layer_names: kv_cache = kv_caches[layer_name] - if (isinstance(kv_cache_spec, AttentionSpec) - and kv_cache.shape[0] == 2): - assert kv_cache.shape[1] != 2, \ - "Fail to determine whether the layout is " \ - "(2, num_blocks, ...) or (num_blocks, 2, ...) for " \ + if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2: + assert kv_cache.shape[1] != 2, ( + "Fail to determine whether the layout is " + "(2, num_blocks, ...) or (num_blocks, 2, ...) for " f"a tensor of shape {kv_cache.shape}" + ) hidden_size = kv_cache.shape[2:].numel() - kv_cache.as_strided_(size=kv_cache.shape, - stride=(hidden_size, 2 * hidden_size, - *kv_cache.stride()[2:])) + kv_cache.as_strided_( + size=kv_cache.shape, + stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]), + ) def initialize_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + self, kv_cache_config: KVCacheConfig + ) -> dict[str, torch.Tensor]: """ Initialize the memory buffer for KV cache. 
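The reshaping hunk above allocates each cache in the backend's preferred stride order and then permutes the view back so callers still see the generic KV-cache shape. A small CPU-only sketch of the permute / inverse-permute step (shapes and the stride order are made up, and the dtype reinterpretation of the raw int8 buffer is omitted):

```python
import math

import torch

kv_cache_shape = (2, 10, 16, 4, 8)        # generic (K/V, blocks, block, heads, dim)
kv_cache_stride_order = (1, 0, 2, 3, 4)   # pretend the backend wants blocks outermost

raw = torch.zeros(math.prod(kv_cache_shape), dtype=torch.int8)

# Physically lay the buffer out in the backend's order...
alloc_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
# ...then permute with the inverse order to restore the logical shape.
inv_order = [kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))]
kv_cache = raw.view(alloc_shape).permute(*inv_order)

assert kv_cache.shape == torch.Size(kv_cache_shape)
print(kv_cache.stride())   # strides follow the backend layout; the shape does not
```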
@@ -4050,25 +4334,29 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Initialize the memory buffer for KV cache kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) # Change the memory buffer to the desired shape - kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, - kv_cache_raw_tensors) + kv_caches = self._reshape_kv_cache_tensors( + kv_cache_config, kv_cache_raw_tensors + ) # Set up cross-layer KV cache sharing - for layer_name, target_layer_name in self.shared_kv_cache_layers.items( - ): - logger.debug("%s reuses KV cache of %s", layer_name, - target_layer_name) + for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): + logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) kv_caches[layer_name] = kv_caches[target_layer_name] - num_attn_module = 2 \ - if self.model_config.hf_config.model_type == "longcat_flash" else 1 - bind_kv_cache(kv_caches, - self.compilation_config.static_forward_context, - self.kv_caches, num_attn_module) + num_attn_module = ( + 2 if self.model_config.hf_config.model_type == "longcat_flash" else 1 + ) + bind_kv_cache( + kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches, + num_attn_module, + ) return kv_caches def maybe_add_kv_sharing_layers_to_kv_cache_groups( - self, kv_cache_config: KVCacheConfig) -> None: + self, kv_cache_config: KVCacheConfig + ) -> None: """ Add layers that re-use KV cache to KV cache group of its target layer. Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` @@ -4087,12 +4375,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other # similar KV sharing setups, only the layers that generate KV caches # are involved in the prefill phase, enabling prefill to early exit. - attn_layers = get_layers_from_vllm_config(self.vllm_config, - Attention) + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name in reversed(attn_layers): if layer_name in self.shared_kv_cache_layers: - self.kv_sharing_fast_prefill_eligible_layers.add( - layer_name) + self.kv_sharing_fast_prefill_eligible_layers.add(layer_name) else: break @@ -4124,23 +4410,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.dcp_world_size > 1: layer_names = self.attn_groups[0][0].layer_names - layers = get_layers_from_vllm_config(self.vllm_config, - AttentionLayerBase, - layer_names) + layers = get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase, layer_names + ) for layer in layers.values(): assert layer.impl.need_to_return_lse_for_decode, ( "DCP requires attention impls to return" " the softmax lse for decode, but the impl " f"{layer.impl.__class__.__name__} " - "does not return the softmax lse for decode.") + "does not return the softmax lse for decode." + ) def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. 
""" block_size = self.vllm_config.cache_config.block_size - encoder_only_attn_specs: dict[AttentionSpec, - list[str]] = defaultdict(list) + encoder_only_attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): if attn_module.attn_type == AttentionType.ENCODER_ONLY: @@ -4148,16 +4434,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) + dtype=self.kv_cache_dtype, + ) encoder_only_attn_specs[attn_spec].append(layer_name) self.runner_only_attn_layers.add(layer_name) if len(encoder_only_attn_specs) > 0: - assert len( - encoder_only_attn_specs - ) == 1, "Only support one encoder-only attention spec now" + assert len(encoder_only_attn_specs) == 1, ( + "Only support one encoder-only attention spec now" + ) spec, layer_names = encoder_only_attn_specs.popitem() self.kv_cache_config.kv_cache_groups.append( - KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec) + ) def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ @@ -4174,8 +4462,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_cache_spec: dict[str, KVCacheSpec] = {} attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): - if (kv_tgt_layer := - attn_module.kv_sharing_target_layer_name) is not None: + if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: # The layer doesn't need its own KV cache and will use that of # the target layer. 
We skip creating a KVCacheSpec for it, so # that KV cache management logic will act as this layer does @@ -4190,59 +4477,67 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # the attention backends if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: - assert not use_mla, "MLA is not supported for sliding" \ - "window" + assert not use_mla, "MLA is not supported for slidingwindow" kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window) + sliding_window=attn_module.sliding_window, + ) elif use_mla: kv_cache_spec[layer_name] = MLAAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - cache_dtype_str=cache_dtype_str) - elif self.attention_chunk_size is not None \ - and isinstance(attn_module, ChunkedLocalAttention): + cache_dtype_str=cache_dtype_str, + ) + elif self.attention_chunk_size is not None and isinstance( + attn_module, ChunkedLocalAttention + ): kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - attention_chunk_size=self.attention_chunk_size) + attention_chunk_size=self.attention_chunk_size, + ) else: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) + dtype=self.kv_cache_dtype, + ) elif attn_module.attn_type == AttentionType.ENCODER_DECODER: kv_cache_spec[layer_name] = CrossAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): + dtype=self.kv_cache_dtype, + ) + elif attn_module.attn_type in ( + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + ): # encoder-only attention does not need KV cache. continue else: - raise ValueError( - f"Unknown attention type: {attn_module.attn_type}") + raise ValueError(f"Unknown attention type: {attn_module.attn_type}") mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) if len(mamba_layers) > 0: - if (self.vllm_config.speculative_config is not None - and self.vllm_config.model_config.hf_config.model_type - not in ["qwen3_next"]): + if ( + self.vllm_config.speculative_config is not None + and self.vllm_config.model_config.hf_config.model_type + not in ["qwen3_next"] + ): raise NotImplementedError( - "Mamba with speculative decoding is not supported yet.") + "Mamba with speculative decoding is not supported yet." 
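The per-layer dispatch above maps each attention module to a KV-cache spec. A condensed sketch of the selection order, with strings standing in for the `*Spec` classes (note the reformatted assert message above reads "slidingwindow" because the two implicitly concatenated literals had no space; the sketch restores it):

```python
def pick_kv_cache_spec(attn_type, *, sliding_window=None, use_mla=False,
                       attention_chunk_size=None, is_chunked_local=False):
    # Condensed selection order; strings stand in for SlidingWindowSpec,
    # MLAAttentionSpec, ChunkedLocalAttentionSpec, FullAttentionSpec, etc.
    if attn_type == "DECODER":
        if sliding_window is not None:
            assert not use_mla, "MLA is not supported for sliding window"
            return "SlidingWindowSpec"
        if use_mla:
            return "MLAAttentionSpec"
        if attention_chunk_size is not None and is_chunked_local:
            return "ChunkedLocalAttentionSpec"
        return "FullAttentionSpec"
    if attn_type == "ENCODER_DECODER":
        return "CrossAttentionSpec"
    if attn_type in ("ENCODER", "ENCODER_ONLY"):
        return None                      # encoder-only attention keeps no KV cache
    raise ValueError(f"Unknown attention type: {attn_type}")


print(pick_kv_cache_spec("DECODER", sliding_window=4096))   # SlidingWindowSpec
```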
+ ) mamba_block_size = self.vllm_config.cache_config.mamba_block_size - page_size_padded = ( - self.vllm_config.cache_config.mamba_page_size_padded) + page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded for layer_name, mamba_module in mamba_layers.items(): kv_cache_spec[layer_name] = MambaSpec( @@ -4253,10 +4548,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mamba_type=mamba_module.mamba_type, num_speculative_blocks=( self.speculative_config.num_speculative_tokens - if self.speculative_config else 0), + if self.speculative_config + else 0 + ), ) ds_indexer_layers = get_layers_from_vllm_config( - self.vllm_config, DeepseekV32IndexerCache) + self.vllm_config, DeepseekV32IndexerCache + ) for layer_name, ds_indexer_module in ds_indexer_layers.items(): kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec() @@ -4271,7 +4569,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # this is in the critical path of every single model # forward loop, this has caused perf issue for a disagg # setup. - pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]] + pinned = self.sampled_token_ids_pinned_cpu[: sampled_token_ids.shape[0]] pinned.copy_(sampled_token_ids, non_blocking=True) self.transfer_event.record() self.transfer_event.synchronize() diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 39be8c7410..3bd7c9d538 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -11,10 +11,12 @@ import vllm.envs as envs from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed import get_ep_group -from vllm.distributed.device_communicators.pynccl_allocator import ( - set_graph_pool_id) -from vllm.forward_context import (create_forward_context, get_forward_context, - override_forward_context) +from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id +from vllm.forward_context import ( + create_forward_context, + get_forward_context, + override_forward_context, +) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -42,27 +44,31 @@ class CUDAGraphMetaData: class SMControlContextManager: - - def __init__(self, comm_sms: int, set_comm_sms: Callable[[int], None], - set_compute_sms: Callable[[int], None]): + def __init__( + self, + comm_sms: int, + set_comm_sms: Callable[[int], None], + set_compute_sms: Callable[[int], None], + ): """ - Context manager for controlling SM (Streaming Multiprocessor) + Context manager for controlling SM (Streaming Multiprocessor) allocation. Upon entering the context, it sets the number of SMs allocated for communication and computation to comm_sms and total_sms - comm_sms respectively. Upon exiting, it restores the allocation to use all available SMs (i.e. total_sms). Args: - comm_sms (int): The number of SMs to allocate for communication. + comm_sms (int): The number of SMs to allocate for communication. (The remainder will be used for computation.) - set_comm_sms (Callable[[int], None]): + set_comm_sms (Callable[[int], None]): A function that sets the number of SMs for communication. - set_compute_sms (Callable[[int], None]): + set_compute_sms (Callable[[int], None]): A function that sets the number of SMs for computation. 
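The docstring above describes how the SM (streaming multiprocessor) split is applied on entry and undone on exit. A minimal context-manager sketch with plain callables in place of the all2all / deep_gemm hooks (the SM counts are arbitrary):

```python
class SMControl:
    def __init__(self, total_sms, comm_sms, set_comm_sms, set_compute_sms):
        assert 0 < comm_sms < total_sms
        self.total_sms, self.comm_sms = total_sms, comm_sms
        self.set_comm_sms, self.set_compute_sms = set_comm_sms, set_compute_sms

    def __enter__(self):
        # Reserve comm_sms for communication, the remainder for compute.
        self.set_comm_sms(self.comm_sms)
        self.set_compute_sms(self.total_sms - self.comm_sms)

    def __exit__(self, *exc):
        # Restore: both sides may use every SM again.
        self.set_comm_sms(self.total_sms)
        self.set_compute_sms(self.total_sms)


with SMControl(
    total_sms=132,
    comm_sms=32,
    set_comm_sms=lambda n: print("comm SMs:", n),
    set_compute_sms=lambda n: print("compute SMs:", n),
):
    pass  # overlapped compute and communication would run here
```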
""" - assert current_platform.is_cuda(), \ + assert current_platform.is_cuda(), ( "SM control is currently only supported on CUDA" + ) props = torch.cuda.get_device_properties(torch.cuda.current_device()) total_sms = props.multi_processor_count @@ -84,9 +90,13 @@ class SMControlContextManager: class UBatchWrapper: - - def __init__(self, runnable: Callable, vllm_config: VllmConfig, - runtime_mode: CUDAGraphMode, device: torch.cuda.device): + def __init__( + self, + runnable: Callable, + vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, + device: torch.cuda.device, + ): self.runnable = runnable self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config @@ -100,7 +110,8 @@ class UBatchWrapper: self.graph_pool = None if runtime_mode is not CUDAGraphMode.NONE: self.cudagraph_wrapper = CUDAGraphWrapper( - runnable, vllm_config, runtime_mode=runtime_mode) + runnable, vllm_config, runtime_mode=runtime_mode + ) self.graph_pool = current_platform.get_global_graph_pool() self.sm_control = self._create_sm_control_context(vllm_config) @@ -114,8 +125,7 @@ class UBatchWrapper: if vllm_config.parallel_config.enable_expert_parallel: # Currently only DeepEP highthroughput supports SM control so this # only affects that case. - all2all_manager = get_ep_group( - ).device_communicator.all2all_manager + all2all_manager = get_ep_group().device_communicator.all2all_manager if all2all_manager.max_sms_used() is not None: comm_sms = min(comm_sms, all2all_manager.max_sms_used()) @@ -127,18 +137,23 @@ class UBatchWrapper: set_compute_sms = lambda sms: None if has_deep_gemm() and comm_sms > 0: import deep_gemm as dg + set_compute_sms = lambda sms: dg.set_num_sms(sms) - return SMControlContextManager(comm_sms=comm_sms, - set_comm_sms=set_comm_sms, - set_compute_sms=set_compute_sms) + return SMControlContextManager( + comm_sms=comm_sms, + set_comm_sms=set_comm_sms, + set_compute_sms=set_compute_sms, + ) def __getattr__(self, key: str): # allow accessing the attributes of the runnable. if hasattr(self.runnable, key): return getattr(self.runnable, key) - raise AttributeError(f"Attribute {key} not exists in the runnable of " - f"cudagraph wrapper: {self.runnable}") + raise AttributeError( + f"Attribute {key} not exists in the runnable of " + f"cudagraph wrapper: {self.runnable}" + ) def unwrap(self) -> Callable: # in case we need to access the original runnable. @@ -153,14 +168,14 @@ class UBatchWrapper: the graph capture. The flow is as follows: - 1. The main thread starts up each ubatch thread. Each thread will + 1. The main thread starts up each ubatch thread. Each thread will initialize its cuda context (torch.cuda.current_blas_handle()) before going to sleep upon entering the ubatch_context. - 2. The main thread starts the graph capture and wakes up the first + 2. The main thread starts the graph capture and wakes up the first ubatch thread. - 3. Each ubatch thread runs the model to completion and returns the + 3. Each ubatch thread runs the model to completion and returns the completed output tensors back to the main thread. 4. 
The main thread stores the captured cudagraph along with its metadata @@ -187,36 +202,38 @@ class UBatchWrapper: results: list[tuple[int, torch.Tensor]] = [] compute_stream = ubatch_metadata[0].context.compute_stream - num_tokens = ubatch_metadata[0].num_tokens + \ - ubatch_metadata[1].num_tokens + num_tokens = ubatch_metadata[0].num_tokens + ubatch_metadata[1].num_tokens # Ubatches will manually manage the forward context, so we override # it to None here so we can have it restored correctly later with override_forward_context(None): ubatch_threads = [] for metadata in ubatch_metadata: - thread = threading.Thread(target=_capture_ubatch_thread, - args=( - results, - metadata, - )) + thread = threading.Thread( + target=_capture_ubatch_thread, + args=( + results, + metadata, + ), + ) ubatch_threads.append(thread) thread.start() self.ready_barrier.wait() # Wait for both threads to be ready # Capture the cudagraph - cudagraph_metadata = \ - CUDAGraphMetaData( - cudagraph=torch.cuda.CUDAGraph(), - ubatch_metadata=ubatch_metadata, - ) + cudagraph_metadata = CUDAGraphMetaData( + cudagraph=torch.cuda.CUDAGraph(), + ubatch_metadata=ubatch_metadata, + ) if self.graph_pool is not None: set_graph_pool_id(self.graph_pool) else: set_graph_pool_id(current_platform.graph_pool_handle()) - with torch.cuda.graph(cudagraph_metadata.cudagraph, - stream=compute_stream, - pool=self.graph_pool): + with torch.cuda.graph( + cudagraph_metadata.cudagraph, + stream=compute_stream, + pool=self.graph_pool, + ): ubatch_metadata[0].context.cpu_wait_event.set() for thread in ubatch_threads: thread.join() @@ -227,7 +244,6 @@ class UBatchWrapper: return cudagraph_metadata.outputs def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor: - @torch.inference_mode() def _ubatch_thread(results, model, ubatch_metadata): with ubatch_metadata.context: @@ -247,12 +263,14 @@ class UBatchWrapper: with override_forward_context(None): ubatch_threads = [] for metadata in ubatch_metadata: - thread = threading.Thread(target=_ubatch_thread, - args=( - results, - model, - metadata, - )) + thread = threading.Thread( + target=_ubatch_thread, + args=( + results, + model, + metadata, + ), + ) ubatch_threads.append(thread) thread.start() self.ready_barrier.wait() # Wait for both threads to be ready @@ -263,11 +281,19 @@ class UBatchWrapper: result = torch.cat(sorted_results, dim=0) return result - def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids, - positions, inputs_embeds, intermediate_tensors, - compute_stream, dp_metadata, batch_descriptor, - cudagraph_runtime_mode) -> list[UbatchMetadata]: - + def _make_ubatch_metadata( + self, + ubatch_slices, + attn_metadata, + input_ids, + positions, + inputs_embeds, + intermediate_tensors, + compute_stream, + dp_metadata, + batch_descriptor, + cudagraph_runtime_mode, + ) -> list[UbatchMetadata]: # Create one forward context per ubatch forward_contexts = [] for i, ubatch_slice in enumerate(ubatch_slices): @@ -277,22 +303,32 @@ class UBatchWrapper: self.vllm_config, dp_metadata=dp_metadata, batch_descriptor=batch_descriptor, - cudagraph_runtime_mode=cudagraph_runtime_mode)) + cudagraph_runtime_mode=cudagraph_runtime_mode, + ) + ) ubatch_ctxs = make_ubatch_contexts( num_micro_batches=len(ubatch_slices), comm_stream=self.comm_stream, compute_stream=compute_stream, forward_contexts=forward_contexts, - ready_barrier=self.ready_barrier) + ready_barrier=self.ready_barrier, + ) ubatch_metadata: list[UbatchMetadata] = [] for i, ubatch_slice in enumerate(ubatch_slices): - 
sliced_input_ids, sliced_positions, sliced_inputs_embeds, \ - sliced_intermediate_tensors = \ - self._slice_model_inputs( - ubatch_slice.token_slice, input_ids, positions, - inputs_embeds, intermediate_tensors) + ( + sliced_input_ids, + sliced_positions, + sliced_inputs_embeds, + sliced_intermediate_tensors, + ) = self._slice_model_inputs( + ubatch_slice.token_slice, + input_ids, + positions, + inputs_embeds, + intermediate_tensors, + ) ubatch_metadata.append( UbatchMetadata( context=ubatch_ctxs[i], @@ -300,13 +336,21 @@ class UBatchWrapper: positions=sliced_positions, inputs_embeds=sliced_inputs_embeds, intermediate_tensors=sliced_intermediate_tensors, - num_tokens=ubatch_slice.token_slice.stop - - ubatch_slice.token_slice.start)) + num_tokens=ubatch_slice.token_slice.stop + - ubatch_slice.token_slice.start, + ) + ) return ubatch_metadata - def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions, - inputs_embeds, intermediate_tensors): + def _slice_model_inputs( + self, + tokens_slice: slice, + input_ids, + positions, + inputs_embeds, + intermediate_tensors, + ): sliced_input_ids = input_ids[tokens_slice] # if we are using mrope. Mrope adds an additional dimension to the # positions tensor @@ -314,13 +358,17 @@ class UBatchWrapper: sliced_positions = positions[:, tokens_slice] else: sliced_positions = positions[tokens_slice] - sliced_inputs_embeds = inputs_embeds[ - tokens_slice] if inputs_embeds else None - sliced_intermediate_tensors = intermediate_tensors[ - tokens_slice] if intermediate_tensors else None + sliced_inputs_embeds = inputs_embeds[tokens_slice] if inputs_embeds else None + sliced_intermediate_tensors = ( + intermediate_tensors[tokens_slice] if intermediate_tensors else None + ) - return (sliced_input_ids, sliced_positions, sliced_inputs_embeds, - sliced_intermediate_tensors) + return ( + sliced_input_ids, + sliced_positions, + sliced_inputs_embeds, + sliced_intermediate_tensors, + ) def __call__(self, *args, **kwargs): forward_context = get_forward_context() @@ -330,7 +378,6 @@ class UBatchWrapper: # If there's no ubatching, just run the runnable object if ubatch_slices is None: - # This is to account for the case where ubatching was aborted. 
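`_slice_model_inputs` above slices positions along different axes because M-RoPE positions carry an extra leading dimension. A small sketch where a dimensionality check stands in for vLLM's model-config flag (the real condition lives in unchanged context lines not shown in the diff):

```python
import torch


def slice_positions(positions: torch.Tensor, tokens_slice: slice) -> torch.Tensor:
    # M-RoPE positions are (sections, num_tokens); plain RoPE is (num_tokens,).
    if positions.dim() == 2:
        return positions[:, tokens_slice]
    return positions[tokens_slice]


flat = torch.arange(10)                    # (num_tokens,)
mrope = torch.arange(30).reshape(3, 10)    # (3, num_tokens)
print(slice_positions(flat, slice(2, 6)).shape)    # torch.Size([4])
print(slice_positions(mrope, slice(2, 6)).shape)   # torch.Size([3, 4])
```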
# When we capture full graphs we only capture one graph per shape, # meaning that if we have a ubatched cudagraph for the current @@ -342,20 +389,20 @@ class UBatchWrapper: if batch_descriptor.num_tokens in self.cudagraphs: cudagraph_runtime_mode = CUDAGraphMode.NONE - if cudagraph_runtime_mode in (CUDAGraphMode.NONE, - CUDAGraphMode.PIECEWISE): + if cudagraph_runtime_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE): return self.runnable(*args, **kwargs) else: assert self.cudagraph_wrapper is not None return self.cudagraph_wrapper(*args, **kwargs) attn_metadata = forward_context.attn_metadata - num_tokens = (ubatch_slices[0].token_slice.stop - - ubatch_slices[0].token_slice.start) * 2 - input_ids = kwargs['input_ids'] - positions = kwargs['positions'] - intermediate_tensors = kwargs['intermediate_tensors'] - inputs_embeds = kwargs['inputs_embeds'] + num_tokens = ( + ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start + ) * 2 + input_ids = kwargs["input_ids"] + positions = kwargs["positions"] + intermediate_tensors = kwargs["intermediate_tensors"] + inputs_embeds = kwargs["inputs_embeds"] compute_stream = torch.cuda.current_stream() dp_metadata = forward_context.dp_metadata @@ -363,8 +410,10 @@ class UBatchWrapper: # We shouldn't be here unless we are running with multiple DP ranks assert dp_metadata is not None - if num_tokens not in self.cudagraphs \ - and cudagraph_runtime_mode is CUDAGraphMode.FULL: + if ( + num_tokens not in self.cudagraphs + and cudagraph_runtime_mode is CUDAGraphMode.FULL + ): ubatch_metadata = self._make_ubatch_metadata( ubatch_slices=ubatch_slices, attn_metadata=attn_metadata, @@ -375,11 +424,14 @@ class UBatchWrapper: compute_stream=compute_stream, dp_metadata=dp_metadata, batch_descriptor=batch_descriptor, - cudagraph_runtime_mode=CUDAGraphMode.NONE) + cudagraph_runtime_mode=CUDAGraphMode.NONE, + ) with self.sm_control: return self._capture_ubatches(ubatch_metadata, self.model) - elif num_tokens in self.cudagraphs \ - and cudagraph_runtime_mode is CUDAGraphMode.FULL: + elif ( + num_tokens in self.cudagraphs + and cudagraph_runtime_mode is CUDAGraphMode.FULL + ): cudagraph_metadata = self.cudagraphs[num_tokens] cudagraph_metadata.cudagraph.replay() return cudagraph_metadata.outputs @@ -394,6 +446,7 @@ class UBatchWrapper: compute_stream=compute_stream, dp_metadata=dp_metadata, batch_descriptor=batch_descriptor, - cudagraph_runtime_mode=CUDAGraphMode.NONE) + cudagraph_runtime_mode=CUDAGraphMode.NONE, + ) with self.sm_control: return self._run_ubatches(ubatch_metadata, self.model) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index a135a594ac..271aabb9e2 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A GPU worker class.""" + import copy import gc import os @@ -13,9 +14,11 @@ import torch.nn as nn import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce) +from vllm.distributed import ( + ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce, +) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger @@ -28,8 +31,12 @@ from vllm.tasks import SupportedTask from vllm.utils import 
GiB_bytes, MemorySnapshot, memory_profiling from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - DraftTokenIds, ModelRunnerOutput) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + AsyncModelRunnerOutput, + DraftTokenIds, + ModelRunnerOutput, +) from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.utils import is_residual_scattered_for_sp @@ -43,7 +50,6 @@ if TYPE_CHECKING: class Worker(WorkerBase): - def __init__( self, vllm_config: VllmConfig, @@ -52,16 +58,18 @@ class Worker(WorkerBase): distributed_init_method: str, is_driver_worker: bool = False, ): - - super().__init__(vllm_config=vllm_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - is_driver_worker=is_driver_worker) + super().__init__( + vllm_config=vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker, + ) if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() # Buffers saved before sleep @@ -71,8 +79,10 @@ class Worker(WorkerBase): # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace if envs.VLLM_TORCH_PROFILER_DIR: torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir) + logger.info( + "Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir, + ) logger.debug( "Profiler config: record_shapes=%s," "profile_memory=%s,with_stack=%s,with_flops=%s", @@ -91,7 +101,9 @@ class Worker(WorkerBase): with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) + torch_profiler_trace_dir, use_gzip=True + ), + ) else: self.profiler = None @@ -104,20 +116,20 @@ class Worker(WorkerBase): if level == 2: model = self.model_runner.model self._sleep_saved_buffers = { - name: buffer.cpu().clone() - for name, buffer in model.named_buffers() + name: buffer.cpu().clone() for name, buffer in model.named_buffers() } allocator = CuMemAllocator.get_instance() - allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) + allocator.sleep(offload_tags=("weights",) if level == 1 else tuple()) free_bytes_after_sleep, total = torch.cuda.mem_get_info() freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep used_bytes = total - free_bytes_after_sleep assert freed_bytes >= 0, "Memory usage increased after sleeping." 
logger.info( - "Sleep mode freed %.2f GiB memory, " - "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, - used_bytes / GiB_bytes) + "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.", + freed_bytes / GiB_bytes, + used_bytes / GiB_bytes, + ) def wake_up(self, tags: Optional[list[str]] = None) -> None: from vllm.device_allocator.cumem import CuMemAllocator @@ -133,23 +145,21 @@ class Worker(WorkerBase): buffer.data.copy_(self._sleep_saved_buffers[name].data) self._sleep_saved_buffers = {} - def _maybe_get_memory_pool_context(self, - tag: str) -> AbstractContextManager: + def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: if self.vllm_config.model_config.enable_sleep_mode: from vllm.device_allocator.cumem import CuMemAllocator allocator = CuMemAllocator.get_instance() if tag == "weights": assert allocator.get_current_usage() == 0, ( - "Sleep mode can only be " - "used for one instance per process.") + "Sleep mode can only be used for one instance per process." + ) context = allocator.use_memory_pool(tag=tag) else: context = nullcontext() return context - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks @@ -166,10 +176,13 @@ class Worker(WorkerBase): # memory snapshot # This ensures NCCL buffers are allocated before we measure # available memory - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank, - current_platform.dist_backend) + init_worker_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + current_platform.dist_backend, + ) # Set random seed. set_random_seed(self.model_config.seed) @@ -180,8 +193,10 @@ class Worker(WorkerBase): # take current memory snapshot self.init_snapshot = MemorySnapshot() - self.requested_memory = (self.init_snapshot.total_memory * - self.cache_config.gpu_memory_utilization) + self.requested_memory = ( + self.init_snapshot.total_memory + * self.cache_config.gpu_memory_utilization + ) if self.init_snapshot.free_memory < self.requested_memory: GiB = lambda b: round(b / GiB_bytes, 2) raise ValueError( @@ -194,12 +209,12 @@ class Worker(WorkerBase): f"utilization or reduce GPU memory used by other processes." ) else: - raise RuntimeError( - f"Not support device type: {self.device_config.device}") + raise RuntimeError(f"Not support device type: {self.device_config.device}") # Construct the model runner self.model_runner: GPUModelRunner = GPUModelRunner( - self.vllm_config, self.device) + self.vllm_config, self.device + ) if self.rank == 0: # If usage stat is enabled, collect relevant info. @@ -247,7 +262,8 @@ class Worker(WorkerBase): "size. If OOM'ed, check the difference of initial free " "memory between the current run and the previous run " "where kv_cache_memory_bytes is suggested and update it " - "correspondingly.") + "correspondingly." + ) logger.info(msg) return kv_cache_memory_bytes @@ -257,8 +273,8 @@ class Worker(WorkerBase): # Execute a forward pass with dummy inputs to profile the memory usage # of the model. 
with memory_profiling( - self.init_snapshot, - weights_memory=int(self.model_runner.model_memory_usage), + self.init_snapshot, + weights_memory=int(self.model_runner.model_memory_usage), ) as profile_result: self.model_runner.profile_run() @@ -275,15 +291,15 @@ class Worker(WorkerBase): "This happens when other processes sharing the same container " "release GPU memory while vLLM is profiling during initialization. " "To fix this, ensure consistent GPU memory allocation or " - "isolate vLLM in its own container.") - self.available_kv_cache_memory_bytes = self.requested_memory \ - - profile_result.non_kv_cache_memory + "isolate vLLM in its own container." + ) + self.available_kv_cache_memory_bytes = ( + self.requested_memory - profile_result.non_kv_cache_memory + ) - unrequested_memory = self.init_snapshot.free_memory \ - - self.requested_memory + unrequested_memory = self.init_snapshot.free_memory - self.requested_memory logger.debug( - "Initial free memory: %.2f GiB; " - "Requested memory: %.2f (util), %.2f GiB", + "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB", GiB(self.init_snapshot.free_memory), self.cache_config.gpu_memory_utilization, GiB(self.requested_memory), @@ -295,8 +311,10 @@ class Worker(WorkerBase): GiB(free_gpu_memory - unrequested_memory), ) logger.debug(profile_result) - logger.info("Available KV cache memory: %.2f GiB", - GiB(self.available_kv_cache_memory_bytes)) + logger.info( + "Available KV cache memory: %.2f GiB", + GiB(self.available_kv_cache_memory_bytes), + ) gc.collect() return int(self.available_kv_cache_memory_bytes) @@ -324,15 +342,14 @@ class Worker(WorkerBase): warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() if not self.model_config.enforce_eager: warmup_sizes = [ - x for x in warmup_sizes if x not in - self.vllm_config.compilation_config.cudagraph_capture_sizes + x + for x in warmup_sizes + if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes ] # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) - self.model_runner._dummy_run(size, - skip_eplb=True, - remove_lora=False) + self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False) self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config) # Warmup and tune the kernels used during model execution before @@ -343,8 +360,9 @@ class Worker(WorkerBase): if not self.model_config.enforce_eager: cuda_graph_memory_bytes = self.model_runner.capture_model() - if (self.cache_config.kv_cache_memory_bytes is None - and hasattr(self, "peak_activation_memory")): + if self.cache_config.kv_cache_memory_bytes is None and hasattr( + self, "peak_activation_memory" + ): # Suggests optimal kv cache memory size if we rely on # memory_profiling to guess the kv cache memory size which # provides peak_activation_memory and a few other memory @@ -358,16 +376,22 @@ class Worker(WorkerBase): # slightly underestimate the memory consumption. # So leave a small buffer (=150MiB) to avoid OOM. 
redundancy_buffer_memory = 150 * (1 << 20) - non_kv_cache_memory = (self.model_runner.model_memory_usage + - self.peak_activation_memory + - self.non_torch_memory + - cuda_graph_memory_bytes) + non_kv_cache_memory = ( + self.model_runner.model_memory_usage + + self.peak_activation_memory + + self.non_torch_memory + + cuda_graph_memory_bytes + ) kv_cache_memory_bytes_to_gpu_limit = ( - self.init_snapshot.free_memory - non_kv_cache_memory - - redundancy_buffer_memory) + self.init_snapshot.free_memory + - non_kv_cache_memory + - redundancy_buffer_memory + ) kv_cache_memory_bytes_to_requested_limit = ( - int(self.requested_memory) - non_kv_cache_memory - - redundancy_buffer_memory) + int(self.requested_memory) + - non_kv_cache_memory + - redundancy_buffer_memory + ) msg = ( f"Free memory on device " @@ -388,7 +412,8 @@ class Worker(WorkerBase): f"{kv_cache_memory_bytes_to_gpu_limit}` " f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully " f"utilize gpu memory. Current kv cache memory in use is " - f"{GiB(self.available_kv_cache_memory_bytes)} GiB.") + f"{GiB(self.available_kv_cache_memory_bytes)} GiB." + ) logger.debug(msg) @@ -398,20 +423,20 @@ class Worker(WorkerBase): # NOTE: This is called after `capture_model` on purpose to prevent # memory buffers from being cleared by `torch.cuda.empty_cache`. if get_pp_group().is_last_rank: - max_num_reqs = min(self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens) + max_num_reqs = min( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + ) # We skip EPLB here since we don't want to record dummy metrics - hidden_states, last_hidden_states = \ - self.model_runner._dummy_run( - num_tokens=max_num_reqs, - skip_eplb=True, - ) + hidden_states, last_hidden_states = self.model_runner._dummy_run( + num_tokens=max_num_reqs, + skip_eplb=True, + ) if self.model_runner.is_pooling_model: self.model_runner._dummy_pooler_run(hidden_states) else: - self.model_runner._dummy_sampler_run( - hidden_states=last_hidden_states) + self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states) # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. 
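The hunk above reports two suggested values for kv_cache_memory_bytes: one that stays within the requested gpu_memory_utilization budget and one that would use all free GPU memory, each minus a 150 MiB safety buffer. A minimal standalone sketch of that arithmetic, with hypothetical sizes standing in for the measured MemorySnapshot / memory_profiling values:

# Illustrative only; the real numbers come from MemorySnapshot and
# memory_profiling in the diff above.
GiB_bytes = 1 << 30

total_memory = 80 * GiB_bytes        # hypothetical 80 GiB device
free_memory = 78 * GiB_bytes         # hypothetical free memory at startup
gpu_memory_utilization = 0.9

# Memory vLLM is allowed to use overall.
requested_memory = total_memory * gpu_memory_utilization

# Everything that is not KV cache: weights, peak activations during the
# profile run, non-torch allocations (e.g. NCCL), and CUDA graphs.
model_memory = 14 * GiB_bytes
peak_activation_memory = 6 * GiB_bytes
non_torch_memory = 2 * GiB_bytes
cuda_graph_memory = 1 * GiB_bytes
non_kv_cache_memory = (
    model_memory + peak_activation_memory + non_torch_memory + cuda_graph_memory
)

# Small safety margin (150 MiB) because profiling can slightly
# underestimate the real consumption.
redundancy_buffer = 150 * (1 << 20)

# Suggested kv_cache_memory_bytes under the utilization limit, and the
# larger value that would use all free GPU memory.
to_requested_limit = int(requested_memory) - non_kv_cache_memory - redundancy_buffer
to_gpu_limit = free_memory - non_kv_cache_memory - redundancy_buffer

print(f"suggested (requested limit): {to_requested_limit / GiB_bytes:.2f} GiB")
print(f"suggested (GPU limit):       {to_gpu_limit / GiB_bytes:.2f} GiB")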
@@ -431,32 +456,36 @@ class Worker(WorkerBase): intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - num_input_tokens = self.model_runner._get_num_input_tokens( - num_scheduled_tokens) + num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens) all_gather_tensors = { - "residual": - not is_residual_scattered_for_sp(self.vllm_config, - num_input_tokens) + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_input_tokens + ) } if forward_pass and not get_pp_group().is_first_rank: intermediate_tensors = IntermediateTensors( get_pp_group().recv_tensor_dict( all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors)) + all_gather_tensors=all_gather_tensors, + ) + ) - output = self.model_runner.execute_model(scheduler_output, - intermediate_tensors) + output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)): return output assert isinstance(output, IntermediateTensors) parallel_config = self.vllm_config.parallel_config - assert parallel_config.distributed_executor_backend != ( - "external_launcher") and not get_pp_group().is_last_rank + assert ( + parallel_config.distributed_executor_backend != ("external_launcher") + and not get_pp_group().is_last_rank + ) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors) + get_pp_group().send_tensor_dict( + output.tensors, + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, + ) kv_connector_output = output.kv_connector_output if not kv_connector_output: @@ -483,8 +512,9 @@ class Worker(WorkerBase): self.profiler.stop() # only print profiler results on rank 0 if self.local_rank == 0: - print(self.profiler.key_averages().table( - sort_by="self_cuda_time_total")) + print( + self.profiler.key_averages().table(sort_by="self_cuda_time_total") + ) def execute_dummy_batch(self) -> None: self.model_runner._dummy_run(1, uniform_decode=True) @@ -505,68 +535,79 @@ class Worker(WorkerBase): # worker will always be healthy as long as it's running. return - def _eplb_before_scale_down(self, old_ep_size: int, - new_ep_size: int) -> None: + def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None: from vllm.distributed.parallel_state import get_ep_group + if get_ep_group().rank == 0: - logger.info("[Elastic EP] Starting expert resharding " - "before scaling down...") + logger.info( + "[Elastic EP] Starting expert resharding before scaling down..." 
+ ) rank_mapping = { old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1 for old_ep_rank in range(old_ep_size) } assert self.model_runner.eplb_state is not None - self.model_runner.eplb_state.rearrange(self.model_runner.model, - execute_shuffle=True, - global_expert_load=None, - rank_mapping=rank_mapping) - torch.cuda.synchronize() - if get_ep_group().rank == 0: - logger.info("[Elastic EP] Expert resharding completed!") - - def _eplb_after_scale_up( - self, old_ep_size: int, new_ep_size: int, - global_expert_load: Optional[torch.Tensor]) -> None: - from vllm.distributed.parallel_state import get_ep_group - if get_ep_group().rank == 0: - logger.info("[Elastic EP] Starting expert resharding " - "after scaling up...") - rank_mapping = { - old_ep_rank: old_ep_rank - for old_ep_rank in range(old_ep_size) - } - assert self.model_runner.eplb_state is not None + self.model_runner.eplb_state.rearrange( + self.model_runner.model, + execute_shuffle=True, + global_expert_load=None, + rank_mapping=rank_mapping, + ) + torch.cuda.synchronize() + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Expert resharding completed!") + + def _eplb_after_scale_up( + self, + old_ep_size: int, + new_ep_size: int, + global_expert_load: Optional[torch.Tensor], + ) -> None: + from vllm.distributed.parallel_state import get_ep_group + + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Starting expert resharding after scaling up...") + rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)} + assert self.model_runner.eplb_state is not None self.model_runner.eplb_state.rearrange( self.model_runner.model, execute_shuffle=True, global_expert_load=global_expert_load, - rank_mapping=rank_mapping) + rank_mapping=rank_mapping, + ) if get_ep_group().rank == 0: logger.info("[Elastic EP] Expert resharding completed!") def _reconfigure_parallel_config( - self, reconfig_request: ReconfigureDistributedRequest) -> None: + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: """ Update parallel config with provided reconfig_request """ parallel_config = self.vllm_config.parallel_config - parallel_config.data_parallel_size = \ - reconfig_request.new_data_parallel_size - if reconfig_request.new_data_parallel_rank != \ - ReconfigureRankType.KEEP_CURRENT_RANK: - parallel_config.data_parallel_rank = \ - reconfig_request.new_data_parallel_rank - if reconfig_request.new_data_parallel_rank_local != \ - ReconfigureRankType.KEEP_CURRENT_RANK: - parallel_config.data_parallel_rank_local = \ + parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size + if ( + reconfig_request.new_data_parallel_rank + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank + if ( + reconfig_request.new_data_parallel_rank_local + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank_local = ( reconfig_request.new_data_parallel_rank_local - parallel_config.data_parallel_master_ip = \ + ) + parallel_config.data_parallel_master_ip = ( reconfig_request.new_data_parallel_master_ip - parallel_config.data_parallel_master_port = \ + ) + parallel_config.data_parallel_master_port = ( reconfig_request.new_data_parallel_master_port + ) - def _reconfigure_moe(self, old_ep_size: int, - new_ep_size: int) -> Optional[torch.Tensor]: + def _reconfigure_moe( + self, old_ep_size: int, new_ep_size: int + ) -> Optional[torch.Tensor]: """ Reconfigure MoE modules with provided reconfig_request @@ -574,20 +615,26 
@@ class Worker(WorkerBase): otherwise None """ from vllm.distributed.parallel_state import ( - get_dp_group, get_ep_group, prepare_communication_buffer_for_model) - from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoEParallelConfig) + get_dp_group, + get_ep_group, + prepare_communication_buffer_for_model, + ) + from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig parallel_config = self.vllm_config.parallel_config moe_modules = [ - module for module in self.model_runner.model.modules() - if (module.__class__.__name__ == "FusedMoE" - or module.__class__.__name__ == "SharedFusedMoE") + module + for module in self.model_runner.model.modules() + if ( + module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE" + ) ] num_local_experts = moe_modules[0].moe_config.num_local_experts - assert all(module.moe_config.num_local_experts == num_local_experts - for module in moe_modules), ( - "All MoE modules must have the same number of experts") + assert all( + module.moe_config.num_local_experts == num_local_experts + for module in moe_modules + ), "All MoE modules must have the same number of experts" for module in moe_modules: module.moe_config.num_experts = num_local_experts * new_ep_size module.global_num_experts = module.moe_config.num_experts @@ -600,49 +647,62 @@ class Worker(WorkerBase): if new_ep_size < old_ep_size: num_local_physical_experts = num_local_experts assert self.model_runner.eplb_state is not None - new_physical_experts = \ + new_physical_experts = ( self.model_runner.eplb_state.physical_to_logical_map.shape[1] + ) parallel_config.eplb_config.num_redundant_experts = ( - new_physical_experts - - self.model_runner.eplb_state.logical_replica_count.shape[1]) + new_physical_experts + - self.model_runner.eplb_state.logical_replica_count.shape[1] + ) global_expert_load = None else: - num_local_physical_experts = torch.tensor([num_local_experts], - dtype=torch.int32, - device="cpu") - torch.distributed.broadcast(num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0) + num_local_physical_experts = torch.tensor( + [num_local_experts], dtype=torch.int32, device="cpu" + ) + torch.distributed.broadcast( + num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0 + ) num_local_physical_experts = num_local_physical_experts.item() new_physical_experts = num_local_physical_experts * new_ep_size assert self.model_runner.eplb_state is not None global_expert_load = self.model_runner.eplb_state.rearrange( - self.model_runner.model, execute_shuffle=False) + self.model_runner.model, execute_shuffle=False + ) parallel_config.eplb_config.num_redundant_experts = ( - new_physical_experts - global_expert_load.shape[1]) + new_physical_experts - global_expert_load.shape[1] + ) prepare_communication_buffer_for_model(self.model_runner.model) self.model_runner.model.update_physical_experts_metadata( num_physical_experts=new_physical_experts, - num_local_physical_experts=num_local_physical_experts) + num_local_physical_experts=num_local_physical_experts, + ) return global_expert_load def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest) -> None: + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: from vllm.config import set_current_vllm_config from vllm.distributed.parallel_state import ( - cleanup_dist_env_and_memory, get_ep_group) + cleanup_dist_env_and_memory, + get_ep_group, + ) old_ep_size = get_ep_group().world_size old_ep_rank = get_ep_group().rank - 
new_ep_size = reconfig_request.new_data_parallel_size * get_tp_group( - ).world_size * get_pp_group().world_size + new_ep_size = ( + reconfig_request.new_data_parallel_size + * get_tp_group().world_size + * get_pp_group().world_size + ) if new_ep_size < old_ep_size: self._eplb_before_scale_down(old_ep_size, new_ep_size) cleanup_dist_env_and_memory() - if reconfig_request.new_data_parallel_rank == \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + if ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ): assert old_ep_rank >= new_ep_size # shutdown return @@ -650,16 +710,18 @@ class Worker(WorkerBase): self._reconfigure_parallel_config(reconfig_request) with set_current_vllm_config(self.vllm_config): - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank) + init_worker_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + ) global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size) if new_ep_size > old_ep_size: assert global_expert_load is not None - self._eplb_after_scale_up(old_ep_size, new_ep_size, - global_expert_load) + self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_load) def save_sharded_state( self, @@ -668,6 +730,7 @@ class Worker(WorkerBase): max_size: Optional[int] = None, ) -> None: from vllm.model_executor.model_loader import ShardedStateLoader + ShardedStateLoader.save_model( self.model_runner.model, path, @@ -680,7 +743,8 @@ class Worker(WorkerBase): tensorizer_config: "TensorizerConfig", ) -> None: self.model_runner.save_tensorized_model( - tensorizer_config=tensorizer_config, ) + tensorizer_config=tensorizer_config, + ) def shutdown(self) -> None: if runner := getattr(self, "model_runner", None): @@ -698,12 +762,14 @@ def init_worker_distributed_environment( parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) - init_distributed_environment(parallel_config.world_size, rank, - distributed_init_method, local_rank, backend) + init_distributed_environment( + parallel_config.world_size, rank, distributed_init_method, local_rank, backend + ) ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size, - parallel_config.decode_context_parallel_size) + parallel_config.decode_context_parallel_size, + ) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index cdc0d317ff..473982bebb 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -3,22 +3,30 @@ """ Define KV connector functionality mixin for model runners. 
""" + import copy +from collections.abc import Generator from contextlib import AbstractContextManager, contextmanager, nullcontext -from typing import Generator # noqa: UP035 -from typing import TYPE_CHECKING, Optional +from typing import ( + TYPE_CHECKING, # noqa: UP035 + Optional, +) from vllm.config import VllmConfig -from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown, - get_kv_transfer_group, - has_kv_transfer_group) +from vllm.distributed.kv_transfer import ( + ensure_kv_transfer_shutdown, + get_kv_transfer_group, + has_kv_transfer_group, +) from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorStats) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput, - ModelRunnerOutput) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + KVConnectorOutput, + ModelRunnerOutput, +) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -28,7 +36,6 @@ logger = init_logger(__name__) # Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU) class KVConnectorModelRunnerMixin: - @staticmethod def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): # Update KVConnector with the KVConnector metadata forward(). @@ -36,8 +43,7 @@ class KVConnectorModelRunnerMixin: kv_connector = get_kv_transfer_group() assert isinstance(kv_connector, KVConnectorBase) assert scheduler_output.kv_connector_metadata is not None - kv_connector.bind_connector_metadata( - scheduler_output.kv_connector_metadata) + kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata) # Background KV cache transfers happen here. # These transfers are designed to be async and the requests @@ -62,17 +68,21 @@ class KVConnectorModelRunnerMixin: ) -> tuple[Optional[set[str]], Optional[set[str]]]: if has_kv_transfer_group(): return get_kv_transfer_group().get_finished( - scheduler_output.finished_req_ids) + scheduler_output.finished_req_ids + ) return None, None @staticmethod - def kv_connector_no_forward(scheduler_output: "SchedulerOutput", - vllm_config: VllmConfig) -> ModelRunnerOutput: + def kv_connector_no_forward( + scheduler_output: "SchedulerOutput", vllm_config: VllmConfig + ) -> ModelRunnerOutput: # KV send/recv even if no work to do. - with set_forward_context( - None, vllm_config - ), KVConnectorModelRunnerMixin._get_kv_connector_output( - scheduler_output, wait_for_save=False) as kv_connector_output: + with ( + set_forward_context(None, vllm_config), + KVConnectorModelRunnerMixin._get_kv_connector_output( + scheduler_output, wait_for_save=False + ) as kv_connector_output, + ): pass if kv_connector_output.is_empty(): @@ -84,18 +94,20 @@ class KVConnectorModelRunnerMixin: @staticmethod def maybe_get_kv_connector_output( - scheduler_output: "SchedulerOutput" + scheduler_output: "SchedulerOutput", ) -> AbstractContextManager[Optional[KVConnectorOutput]]: - return KVConnectorModelRunnerMixin._get_kv_connector_output( - scheduler_output) if has_kv_transfer_group() else nullcontext() + return ( + KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output) + if has_kv_transfer_group() + else nullcontext() + ) # This context manager must be used within an active forward context. 
# It encapsulates the entire KV connector lifecycle within execute_model @staticmethod @contextmanager def _get_kv_connector_output( - scheduler_output: "SchedulerOutput", - wait_for_save: bool = True + scheduler_output: "SchedulerOutput", wait_for_save: bool = True ) -> Generator[KVConnectorOutput, None, None]: output = KVConnectorOutput() @@ -103,8 +115,7 @@ class KVConnectorModelRunnerMixin: kv_connector = get_kv_transfer_group() assert isinstance(kv_connector, KVConnectorBase) assert scheduler_output.kv_connector_metadata is not None - kv_connector.bind_connector_metadata( - scheduler_output.kv_connector_metadata) + kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata) # Background KV cache transfers happen here. # These transfers are designed to be async and the requests @@ -118,12 +129,13 @@ class KVConnectorModelRunnerMixin: kv_connector.wait_for_save() output.finished_sending, output.finished_recving = ( - kv_connector.get_finished(scheduler_output.finished_req_ids)) - output.invalid_block_ids = ( - kv_connector.get_block_ids_with_load_errors()) + kv_connector.get_finished(scheduler_output.finished_req_ids) + ) + output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors() - output.kv_connector_stats = KVConnectorModelRunnerMixin.\ - get_kv_connector_stats() + output.kv_connector_stats = ( + KVConnectorModelRunnerMixin.get_kv_connector_stats() + ) kv_connector.clear_connector_metadata() @staticmethod diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index e416f50322..e7358c4271 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -28,19 +28,19 @@ logger = init_logger(__name__) # Defined as a mixin for GPUModelRunner class LoRAModelRunnerMixin: - LORA_WARMUP_RANK = 8 - def load_lora_model(self, model: nn.Module, vllm_config: VllmConfig, - device: torch.device) -> nn.Module: - + def load_lora_model( + self, model: nn.Module, vllm_config: VllmConfig, device: torch.device + ) -> nn.Module: if not supports_lora(model): - raise ValueError( - f"{model.__class__.__name__} does not support LoRA yet.") + raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.") if supports_multimodal(model): - logger.warning("Regarding multimodal models, vLLM currently " - "only supports adding LoRA to language model.") + logger.warning( + "Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model." + ) # Add LoRA Manager to the Model Runner self.lora_manager = LRUCacheWorkerLoRAManager( @@ -51,41 +51,44 @@ class LoRAModelRunnerMixin: ) return self.lora_manager.create_lora_manager(model) - def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...], - token_lora_mapping: tuple[int, ...], - lora_requests: set[LoRARequest]) -> None: + def _set_active_loras( + self, + prompt_lora_mapping: tuple[int, ...], + token_lora_mapping: tuple[int, ...], + lora_requests: set[LoRARequest], + ) -> None: self._ensure_lora_enabled() # Set is_prefill to True, so we always use the SGMV kernels on # non-cuda platforms. # On cuda platforms we use the same kernels for prefill and # decode and this flag is generally ignored. 
- lora_mapping = LoRAMapping(token_lora_mapping, - prompt_lora_mapping, - is_prefill=True) + lora_mapping = LoRAMapping( + token_lora_mapping, prompt_lora_mapping, is_prefill=True + ) self.lora_manager.set_active_adapters(lora_requests, lora_mapping) def _ensure_lora_enabled(self) -> None: if not hasattr(self, "lora_manager"): - raise RuntimeError( - "LoRA is not enabled. Use --enable-lora to enable LoRA.") - - def set_active_loras(self, input_batch: InputBatch, - num_scheduled_tokens: np.ndarray) -> None: + raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.") + def set_active_loras( + self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray + ) -> None: prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs - token_lora_mapping: tuple[int, - ...] # of size np.sum(num_scheduled_tokens) + token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens) lora_requests: set[LoRARequest] - prompt_lora_mapping, token_lora_mapping, lora_requests = \ - input_batch.make_lora_inputs(num_scheduled_tokens) - return self._set_active_loras(prompt_lora_mapping, token_lora_mapping, - lora_requests) + prompt_lora_mapping, token_lora_mapping, lora_requests = ( + input_batch.make_lora_inputs(num_scheduled_tokens) + ) + return self._set_active_loras( + prompt_lora_mapping, token_lora_mapping, lora_requests + ) @contextmanager - def maybe_setup_dummy_loras(self, - lora_config: Optional[LoRAConfig], - remove_lora: bool = True): + def maybe_setup_dummy_loras( + self, lora_config: Optional[LoRAConfig], remove_lora: bool = True + ): if lora_config is None: yield else: @@ -96,9 +99,11 @@ class LoRAModelRunnerMixin: # Make dummy lora requests lora_requests: set[LoRARequest] = { - LoRARequest(lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_path="/not/a/real/path") + LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path", + ) for lora_id in range(1, num_loras + 1) } @@ -106,8 +111,7 @@ class LoRAModelRunnerMixin: # Add the dummy LoRAs here so _set_active_loras doesn't try to # load from disk. for lr in lora_requests: - self.lora_manager.add_dummy_lora( - lr, rank=self.LORA_WARMUP_RANK) + self.lora_manager.add_dummy_lora(lr, rank=self.LORA_WARMUP_RANK) yield @@ -116,8 +120,9 @@ class LoRAModelRunnerMixin: self.lora_manager.remove_all_adapters() @contextmanager - def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig], - num_scheduled_tokens: np.ndarray): + def maybe_select_dummy_loras( + self, lora_config: Optional[LoRAConfig], num_scheduled_tokens: np.ndarray + ): if lora_config is None: yield else: @@ -129,35 +134,37 @@ class LoRAModelRunnerMixin: # Make prompt lora mapping # Assign LoRA IDs cyclically to simulate a worst-case scenario. 
- prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % - num_loras) + 1 + prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1 # Make token lora mapping - token_lora_mapping = np.repeat(prompt_lora_mapping, - num_scheduled_tokens) + token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens) # Make dummy lora requests lora_requests: set[LoRARequest] = { - LoRARequest(lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_path="/not/a/real/path") + LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path", + ) for lora_id in range(1, num_loras + 1) } - self._set_active_loras(tuple(prompt_lora_mapping), - tuple(token_lora_mapping), lora_requests) + self._set_active_loras( + tuple(prompt_lora_mapping), tuple(token_lora_mapping), lora_requests + ) yield @contextmanager - def maybe_dummy_run_with_lora(self, - lora_config: Optional[LoRAConfig], - num_scheduled_tokens: np.ndarray, - remove_lora: bool = True): + def maybe_dummy_run_with_lora( + self, + lora_config: Optional[LoRAConfig], + num_scheduled_tokens: np.ndarray, + remove_lora: bool = True, + ): with ( - self.maybe_setup_dummy_loras(lora_config, remove_lora), - self.maybe_select_dummy_loras(lora_config, - num_scheduled_tokens), + self.maybe_setup_dummy_loras(lora_config, remove_lora), + self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens), ): yield diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py index 4cd0ac352d..34fed8f964 100644 --- a/vllm/v1/worker/tpu_input_batch.py +++ b/vllm/v1/worker/tpu_input_batch.py @@ -18,16 +18,15 @@ _SAMPLING_EPS = 1e-5 class InputBatch: - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_batched_tokens: int, - device: torch.device, - pin_memory: bool, - vocab_size: int, - block_sizes: list[int], # The block_size of each kv cache group + self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + device: torch.device, + pin_memory: bool, + vocab_size: int, + block_sizes: list[int], # The block_size of each kv cache group ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len @@ -54,13 +53,12 @@ class InputBatch: self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu_tensor = torch.zeros( - (max_num_reqs, ), + (max_num_reqs,), device="cpu", dtype=torch.int32, pin_memory=pin_memory, ) - self.num_computed_tokens_cpu = \ - self.num_computed_tokens_cpu_tensor.numpy() + self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy() # Block table. self.block_table = MultiGroupBlockTable( @@ -73,91 +71,72 @@ class InputBatch: ) # Sampling-related. 
- self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.temperature = torch.empty( + (max_num_reqs,), dtype=torch.float32, device=device + ) + self.temperature_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.temperature_cpu = self.temperature_cpu_tensor.numpy() self.greedy_reqs: set[str] = set() self.random_reqs: set[str] = set() - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device) + self.top_p_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.top_p_cpu = self.top_p_cpu_tensor.numpy() self.top_p_reqs: set[str] = set() - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) + self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device) + self.top_k_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: set[str] = set() - self.min_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.min_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.min_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device) + self.min_p_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.min_p_cpu = self.min_p_cpu_tensor.numpy() self.min_p_reqs: set[str] = set() # Frequency penalty related data structures - self.frequency_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) + self.frequency_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device + ) self.frequency_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.frequency_penalties_cpu = \ - self.frequency_penalties_cpu_tensor.numpy() + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy() self.frequency_penalties_reqs: set[str] = set() # Presence penalty related data structures - self.presence_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) - self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy( + self.presence_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device ) + self.presence_penalties_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy() self.presence_penalties_reqs: set[str] = set() # Repetition penalty related data structures - self.repetition_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) + self.repetition_penalties = torch.empty( + 
(max_num_reqs,), dtype=torch.float, device=device + ) self.repetition_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.repetition_penalties_cpu = \ - self.repetition_penalties_cpu_tensor.numpy() + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: set[str] = set() # req_index -> (min_tokens, stop_token_ids) self.min_tokens: dict[int, tuple[int, set[int]]] = {} # lora related - self.request_lora_mapping = np.zeros((self.max_num_reqs, ), - dtype=np.int32) + self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int32) self.lora_id_to_request_ids: dict[int, set[str]] = {} self.lora_id_to_lora_request: dict[int, LoRARequest] = {} @@ -174,8 +153,7 @@ class InputBatch: # To accumulate prompt logprobs tensor chunks across prefill steps. self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {} - self.logit_bias: list[Optional[dict[int, - float]]] = [None] * max_num_reqs + self.logit_bias: list[Optional[dict[int, float]]] = [None] * max_num_reqs self.has_allowed_token_ids: set[str] = set() # NOTE(lufang): In the mask tensor, if the corresponding token allowed, # the value is False. Since we use masked_fill_ to set -inf. @@ -214,15 +192,14 @@ class InputBatch: # Copy the prompt token ids and output token ids. num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - request.prompt_token_ids, request.prompt_embeds) + request.prompt_token_ids, request.prompt_embeds + ) # TODO: copy prompt_embeds self.num_prompt_tokens[req_index] = num_prompt_tokens - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids + self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids start_idx = num_prompt_tokens end_idx = start_idx + len(request.output_token_ids) - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids + self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids # Number of token ids in token_ids_cpu. # NOTE(woosuk): This may include spec decode tokens. self.num_tokens[req_index] = request.num_tokens @@ -252,23 +229,22 @@ class InputBatch: top_k = self.vocab_size self.top_k_cpu[req_index] = top_k self.min_p_cpu[req_index] = sampling_params.min_p - self.frequency_penalties_cpu[ - req_index] = sampling_params.frequency_penalty + self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty if sampling_params.min_p > _SAMPLING_EPS: self.min_p_reqs.add(req_id) if sampling_params.frequency_penalty != 0.0: self.frequency_penalties_reqs.add(req_id) - self.presence_penalties_cpu[ - req_index] = sampling_params.presence_penalty + self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty if sampling_params.presence_penalty != 0.0: self.presence_penalties_reqs.add(req_id) - self.repetition_penalties_cpu[ - req_index] = sampling_params.repetition_penalty + self.repetition_penalties_cpu[req_index] = sampling_params.repetition_penalty if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) if sampling_params.min_tokens: - self.min_tokens[req_index] = (sampling_params.min_tokens, - sampling_params.all_stop_token_ids) + self.min_tokens[req_index] = ( + sampling_params.min_tokens, + sampling_params.all_stop_token_ids, + ) # NOTE(woosuk): self.generators should not include the requests that # do not have their own generator. 
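The InputBatch hunks above keep per-slot sampling state in preallocated arrays alongside per-feature request sets, so an empty set lets the sampler skip penalty or min-p work entirely. A minimal sketch of that bookkeeping pattern, with illustrative names and a simplified top-k fallback (not the exact vLLM condition):

import numpy as np

_SAMPLING_EPS = 1e-5
MAX_NUM_REQS = 8
VOCAB_SIZE = 32000

# Preallocated per-slot arrays, indexed by request position in the batch.
top_k = np.zeros(MAX_NUM_REQS, dtype=np.int32)
min_p = np.zeros(MAX_NUM_REQS, dtype=np.float32)
frequency_penalty = np.zeros(MAX_NUM_REQS, dtype=np.float32)

# Sets of request IDs that actually use a given feature.
min_p_reqs: set[str] = set()
frequency_penalty_reqs: set[str] = set()


def add_request(req_id: str, req_index: int, k: int, p: float, freq: float) -> None:
    # A disabled or out-of-range top_k falls back to "no truncation",
    # i.e. the vocabulary size.
    top_k[req_index] = k if 0 < k < VOCAB_SIZE else VOCAB_SIZE
    min_p[req_index] = p
    frequency_penalty[req_index] = freq
    # Only track requests that actually need the feature, so the batch can
    # skip the corresponding kernel when the set stays empty.
    if p > _SAMPLING_EPS:
        min_p_reqs.add(req_id)
    if freq != 0.0:
        frequency_penalty_reqs.add(req_id)


add_request("req-0", 0, k=40, p=0.0, freq=0.5)
add_request("req-1", 1, k=0, p=0.1, freq=0.0)
print(top_k[:2], sorted(min_p_reqs), sorted(frequency_penalty_reqs))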
@@ -287,23 +263,23 @@ class InputBatch: if self.allowed_token_ids_mask_cpu_tensor is None: # Lazy allocation for this tensor, which can be large. # False means we don't fill with -inf. - self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs, - self.vocab_size, - dtype=torch.bool, - device=self.device) - self.allowed_token_ids_mask_cpu_tensor = torch.zeros( + self.allowed_token_ids_mask = torch.zeros( self.max_num_reqs, self.vocab_size, dtype=torch.bool, - device="cpu") + device=self.device, + ) + self.allowed_token_ids_mask_cpu_tensor = torch.zeros( + self.max_num_reqs, self.vocab_size, dtype=torch.bool, device="cpu" + ) self.allowed_token_ids_mask_cpu_tensor[req_index] = True # False means we don't fill with -inf. self.allowed_token_ids_mask_cpu_tensor[req_index][ - sampling_params.allowed_token_ids] = False + sampling_params.allowed_token_ids + ] = False if sampling_params.bad_words_token_ids: - self.bad_words_token_ids[ - req_index] = sampling_params.bad_words_token_ids + self.bad_words_token_ids[req_index] = sampling_params.bad_words_token_ids # Add request lora ID if request.lora_request: @@ -361,35 +337,51 @@ class InputBatch: def swap_states(self, i1: int, i2: int) -> None: old_id_i1 = self._req_ids[i1] old_id_i2 = self._req_ids[i2] - self._req_ids[i1], self._req_ids[i2] =\ - self._req_ids[i2], self._req_ids[i1] # noqa - self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\ - self.req_output_token_ids[i2], self.req_output_token_ids[i1] + self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1] # noqa + self.req_output_token_ids[i1], self.req_output_token_ids[i2] = ( + self.req_output_token_ids[i2], + self.req_output_token_ids[i1], + ) assert old_id_i1 is not None and old_id_i2 is not None - self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\ - self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1] - self.num_tokens[i1], self.num_tokens[i2] =\ - self.num_tokens[i2], self.num_tokens[i1] - self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\ - self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1] - self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\ - self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] - self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ - self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] - self.temperature_cpu[i1], self.temperature_cpu[i2] =\ - self.temperature_cpu[i2], self.temperature_cpu[i1] - self.top_p_cpu[i1], self.top_p_cpu[i2] =\ - self.top_p_cpu[i2], self.top_p_cpu[i1] - self.top_k_cpu[i1], self.top_k_cpu[i2] =\ - self.top_k_cpu[i2], self.top_k_cpu[i1] - self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\ - self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] - self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\ - self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] - self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ - self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] - self.min_p_cpu[i1], self.min_p_cpu[i2] =\ - self.min_p_cpu[i2], self.min_p_cpu[i1] + self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = ( + self.req_id_to_index[old_id_i2], + self.req_id_to_index[old_id_i1], + ) + self.num_tokens[i1], self.num_tokens[i2] = ( + self.num_tokens[i2], + self.num_tokens[i1], + ) + self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = ( + self.num_tokens_no_spec[i2], + self.num_tokens_no_spec[i1], + ) + self.num_prompt_tokens[i1], 
self.num_prompt_tokens[i2] = ( + self.num_prompt_tokens[i2], + self.num_prompt_tokens[i1], + ) + self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = ( + self.num_computed_tokens_cpu[i2], + self.num_computed_tokens_cpu[i1], + ) + self.temperature_cpu[i1], self.temperature_cpu[i2] = ( + self.temperature_cpu[i2], + self.temperature_cpu[i1], + ) + self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1] + self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1] + self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = ( + self.frequency_penalties_cpu[i2], + self.frequency_penalties_cpu[i1], + ) + self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = ( + self.presence_penalties_cpu[i2], + self.presence_penalties_cpu[i1], + ) + self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = ( + self.repetition_penalties_cpu[i2], + self.repetition_penalties_cpu[i1], + ) + self.min_p_cpu[i1], self.min_p_cpu[i2] = self.min_p_cpu[i2], self.min_p_cpu[i1] # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ @@ -404,21 +396,28 @@ class InputBatch: swap_dict_values(self.min_tokens, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2) - self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ - self.request_lora_mapping[i2], self.request_lora_mapping[i1] - self.logit_bias[i1], self.logit_bias[i2] =\ - self.logit_bias[i2], self.logit_bias[i1] + self.request_lora_mapping[i1], self.request_lora_mapping[i2] = ( + self.request_lora_mapping[i2], + self.request_lora_mapping[i1], + ) + self.logit_bias[i1], self.logit_bias[i2] = ( + self.logit_bias[i2], + self.logit_bias[i1], + ) if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[i1], \ - self.allowed_token_ids_mask_cpu_tensor[i2] =\ - self.allowed_token_ids_mask_cpu_tensor[i2], \ - self.allowed_token_ids_mask_cpu_tensor[i1] + ( + self.allowed_token_ids_mask_cpu_tensor[i1], + self.allowed_token_ids_mask_cpu_tensor[i2], + ) = ( + self.allowed_token_ids_mask_cpu_tensor[i2], + self.allowed_token_ids_mask_cpu_tensor[i1], + ) self.block_table.swap_row(i1, i2) def condense(self, empty_req_indices: list[int]) -> None: """Move non-empty requests down into lower, empty indices. - + Args: empty_req_indices: empty batch indices, sorted descending. 
""" @@ -454,25 +453,29 @@ class InputBatch: num_tokens = self.num_tokens[last_req_index] self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ - last_req_index, :num_tokens] + last_req_index, :num_tokens + ] self.num_tokens[empty_index] = num_tokens self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ - last_req_index] - self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] + last_req_index + ] + self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index] + self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[ + last_req_index + ] self.block_table.move_row(last_req_index, empty_index) - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - self.frequency_penalties_cpu[ - empty_index] = self.frequency_penalties_cpu[last_req_index] - self.presence_penalties_cpu[ - empty_index] = self.presence_penalties_cpu[last_req_index] - self.repetition_penalties_cpu[ - empty_index] = self.repetition_penalties_cpu[last_req_index] + self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[ + last_req_index + ] + self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[ + last_req_index + ] + self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[ + last_req_index + ] self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: @@ -483,28 +486,28 @@ class InputBatch: self.min_tokens[empty_index] = min_token self.request_lora_mapping[empty_index] = self.request_lora_mapping[ - last_req_index] + last_req_index + ] self.logit_bias[empty_index] = self.logit_bias[last_req_index] if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[ - empty_index] = self.allowed_token_ids_mask_cpu_tensor[ - last_req_index] + self.allowed_token_ids_mask_cpu_tensor[empty_index] = ( + self.allowed_token_ids_mask_cpu_tensor[last_req_index] + ) - bad_words_token_ids = self.bad_words_token_ids.pop( - last_req_index, None) + bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None) if bad_words_token_ids is not None: self.bad_words_token_ids[empty_index] = bad_words_token_ids # Decrement last_req_index since it is now empty. last_req_index -= 1 # Trim lists to the batch size. - del self._req_ids[self.num_reqs:] - del self.req_output_token_ids[self.num_reqs:] + del self._req_ids[self.num_reqs :] + del self.req_output_token_ids[self.num_reqs :] def _make_prompt_token_ids_tensor(self) -> torch.Tensor: - max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() + max_prompt_len = self.num_prompt_tokens[: self.num_reqs].max() prompt_token_ids_cpu_tensor = torch.empty( (self.num_reqs, max_prompt_len), device="cpu", @@ -512,14 +515,12 @@ class InputBatch: pin_memory=self.pin_memory, ) prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() - prompt_token_ids[:] = self.token_ids_cpu[:self. - num_reqs, :max_prompt_len] + prompt_token_ids[:] = self.token_ids_cpu[: self.num_reqs, :max_prompt_len] # Use the value of vocab_size as a pad since we don't have a # token_id of this value. 
for i in range(self.num_reqs): - prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size - return prompt_token_ids_cpu_tensor.to(device=self.device, - non_blocking=True) + prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size + return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) def make_lora_inputs( self, num_scheduled_tokens: np.ndarray @@ -535,12 +536,12 @@ class InputBatch: 3. lora_requests: Set of relevant LoRA requests. """ - req_lora_mapping = self.request_lora_mapping[:self.num_reqs] + req_lora_mapping = self.request_lora_mapping[: self.num_reqs] prompt_lora_mapping = tuple(req_lora_mapping) - token_lora_mapping = tuple( - req_lora_mapping.repeat(num_scheduled_tokens)) + token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens)) active_lora_requests: set[LoRARequest] = set( - self.lora_id_to_lora_request.values()) + self.lora_id_to_lora_request.values() + ) return prompt_lora_mapping, token_lora_mapping, active_lora_requests @@ -570,9 +571,11 @@ class InputBatch: @property def no_penalties(self) -> bool: - return (len(self.presence_penalties_reqs) == 0 - and len(self.frequency_penalties_reqs) == 0 - and len(self.repetition_penalties_reqs) == 0) + return ( + len(self.presence_penalties_reqs) == 0 + and len(self.frequency_penalties_reqs) == 0 + and len(self.repetition_penalties_reqs) == 0 + ) @property def max_num_logprobs(self) -> Optional[int]: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 0b1c3d7c0e..5fe23c58ff 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -9,6 +9,7 @@ from unittest.mock import patch import numpy as np import torch import torch.nn as nn + # TPU XLA related import torch_xla import torch_xla.core.xla_model as xm @@ -20,46 +21,71 @@ from vllm.attention import Attention from vllm.attention.backends.abstract import AttentionType from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import (ParallelConfig, VllmConfig, - get_layers_from_vllm_config, update_config) -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group) +from vllm.config import ( + ParallelConfig, + VllmConfig, + get_layers_from_vllm_config, + update_config, +) +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.tpu import TPUModelLoader -from vllm.model_executor.models.interfaces import (SupportsMultiModal, - supports_transcription) +from vllm.model_executor.models.interfaces import ( + SupportsMultiModal, + supports_transcription, +) from vllm.model_executor.models.interfaces_base import ( - is_pooling_model, is_text_generation_model) + is_pooling_model, + is_text_generation_model, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import ( + BatchedTensorInputs, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.sequence import IntermediateTensors from vllm.tasks 
import GenerationTask, PoolingTask, SupportedTask -from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available, - prev_power_of_2) -from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE, - PallasAttentionBackend, - PallasMetadata, - get_page_size_bytes) -from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, - KVCacheConfig, KVCacheSpec, - SlidingWindowSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists, - LogprobsTensors, ModelRunnerOutput) +from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available, prev_power_of_2 +from vllm.v1.attention.backends.pallas import ( + TPU_STR_DTYPE_TO_TORCH_DTYPE, + PallasAttentionBackend, + PallasMetadata, + get_page_size_bytes, +) +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheSpec, + SlidingWindowSpec, +) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + LogprobsLists, + LogprobsTensors, + ModelRunnerOutput, +) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorModelRunnerMixin, KVConnectorOutput) + KVConnectorModelRunnerMixin, + KVConnectorOutput, +) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch -from .utils import (MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups, - bind_kv_cache, sanity_check_mm_encoder_outputs) +from .utils import ( + MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, + bind_kv_cache, + sanity_check_mm_encoder_outputs, +) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -107,7 +133,6 @@ MIN_NUM_SEQS = 8 # branch predictions are included as subgraph inputs to facilitate # pre-compilation. class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - def __init__( self, vllm_config: VllmConfig, @@ -139,7 +164,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_devices = xr.global_runtime_device_count() mesh_shape = (num_devices, 1) device_ids = np.array(range(num_devices)) - self.mesh = xs.Mesh(device_ids, mesh_shape, ('x', 'y')) + self.mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y")) self.enforce_eager = model_config.enforce_eager @@ -155,8 +180,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: self.kv_cache_dtype = model_dtype else: - self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[ - cache_config.cache_dtype] + self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] self._hidden_states_dtype = self.dtype self.sliding_window = model_config.get_sliding_window() @@ -164,25 +188,28 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.max_model_len = model_config.max_model_len self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.num_blocks_per_most_len_req = cdiv( - self.most_model_len, - self.block_size) if self.most_model_len is not None else None + self.num_blocks_per_most_len_req = ( + cdiv(self.most_model_len, self.block_size) + if self.most_model_len is not None + else None + ) # InputBatch needs to work with sampling tensors greater than padding # to avoid dynamic shapes. Also, avoid suboptimal alignment. 
self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS) self.num_tokens_paddings = _get_token_paddings( min_token_size=16, max_token_size=scheduler_config.max_num_batched_tokens, - padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP) + padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP, + ) # In case `max_num_tokens < max(num_tokens_paddings)` use the actual # padded max value to pre-allocate data structures and pre-compile. self.max_num_tokens = self.num_tokens_paddings[-1] # Model-related. self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) + parallel_config, LayerBlockType.attention + ) + self.num_query_heads = model_config.get_num_attention_heads(parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() @@ -195,17 +222,21 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( - model_config) + model_config + ) # TODO: Support M-RoPE (e.g, Qwen2-VL) assert not self.uses_mrope, "TPU does not support M-RoPE yet." - self._num_slices_per_kv_cache_update_block = \ - _get_num_slices_per_kv_cache_update_block(get_page_size_bytes( - block_size=self.block_size, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - kv_cache_dtype=self.kv_cache_dtype, - )) + self._num_slices_per_kv_cache_update_block = ( + _get_num_slices_per_kv_cache_update_block( + get_page_size_bytes( + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + kv_cache_dtype=self.kv_cache_dtype, + ) + ) + ) # Lazy initialization self.model: nn.Module # Set after load_model @@ -230,52 +261,68 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Cached torch/numpy tensor # The pytorch tensor and numpy array share the same buffer. # Sometimes the numpy op is faster so we create both. - self.input_ids_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu") + self.input_ids_cpu = torch.zeros( + self.max_num_tokens, dtype=torch.int32, device="cpu" + ) - self.positions_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu") + self.positions_cpu = torch.zeros( + self.max_num_tokens, dtype=torch.int32, device="cpu" + ) self.positions_np = self.positions_cpu.numpy() self.block_table_cpu = torch.zeros( (self.max_num_reqs, self.max_num_blocks_per_req), dtype=torch.int32, - device="cpu") + device="cpu", + ) # adjust num_reqs to avoid SMEM OOM. 
- self.num_reqs_most_model_len = min( - PallasAttentionBackend.get_max_num_seqs(self.most_model_len, - self.block_size), - self.max_num_reqs) if self.most_model_len is not None else None + self.num_reqs_most_model_len = ( + min( + PallasAttentionBackend.get_max_num_seqs( + self.most_model_len, self.block_size + ), + self.max_num_reqs, + ) + if self.most_model_len is not None + else None + ) self.num_reqs_max_model_len = min( - PallasAttentionBackend.get_max_num_seqs(self.max_model_len, - self.block_size), - self.max_num_reqs) - self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) + PallasAttentionBackend.get_max_num_seqs( + self.max_model_len, self.block_size + ), + self.max_num_reqs, + ) + self.query_start_loc_cpu = torch.zeros( + self.max_num_tokens + 1, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory, + ) self.query_start_loc_np = self.query_start_loc_cpu.numpy() - self.seq_lens_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) + self.seq_lens_cpu = torch.zeros( + self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory, + ) self.seq_lens_np = self.seq_lens_cpu.numpy() # Only relevant for multimodal models if self.supports_mm_inputs: - self.is_mm_embed_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.bool, - device="cpu", - pin_memory=self.pin_memory) + self.is_mm_embed_cpu = torch.zeros( + self.max_num_tokens, + dtype=torch.bool, + device="cpu", + pin_memory=self.pin_memory, + ) # Range tensor with values [0 .. self.max_num_tokens - 1]. # Used to initialize positions / context_lens / seq_lens # Keep in int64 to avoid overflow with long context self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64) self.num_reqs_paddings = _get_req_paddings( - min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs) + min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs + ) # Layer pairings for cross-layer KV sharing. 
# If an Attention layer `layer_name` is in the keys of this dict, it @@ -288,27 +335,35 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): (self.max_num_reqs, cdiv(self.vocab_size, 32)), dtype=torch.int32, device="cpu", - pin_memory=self.pin_memory) + pin_memory=self.pin_memory, + ) self.require_structured_out_cpu = torch.zeros( (self.max_num_reqs, 1), dtype=torch.bool, device="cpu", - pin_memory=self.pin_memory) + pin_memory=self.pin_memory, + ) self.structured_decode_arange = torch.arange( - 0, 32, device="cpu", pin_memory=self.pin_memory) + 0, 32, device="cpu", pin_memory=self.pin_memory + ) - self.mm_budget = (MultiModalBudget( - self.model_config, - self.scheduler_config, - self.mm_registry, - ) if self.supports_mm_inputs else None) + self.mm_budget = ( + MultiModalBudget( + self.model_config, + self.scheduler_config, + self.mm_registry, + ) + if self.supports_mm_inputs + else None + ) if not self.use_spmd: self.sample_from_logits_func = torch.compile( self.sample_from_logits, backend="openxla", fullgraph=True, - dynamic=False) + dynamic=False, + ) else: self.sample_from_logits_func = self.sample_from_logits @@ -322,8 +377,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if new_compiled_graphs == 0: return - logger.info("Add new %d compiled XLA graphs due to %s", - new_compiled_graphs, case_str) + logger.info( + "Add new %d compiled XLA graphs due to %s", new_compiled_graphs, case_str + ) self.num_xla_graphs += new_compiled_graphs def _verify_num_xla_graphs(self, case_str): @@ -335,7 +391,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert self.num_xla_graphs == curr_cached_graph, ( "Recompilation after warm up is detected during {}." " num_xla_graphs = {} curr_cached_graph = {}".format( - case_str, self.num_xla_graphs, curr_cached_graph)) + case_str, self.num_xla_graphs, curr_cached_graph + ) + ) def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: """Update the cached states and the persistent batch with the scheduler @@ -388,8 +446,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_ids_to_add: list[str] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: - assert new_req_data.sampling_params is not None,\ + assert new_req_data.sampling_params is not None, ( "Pooling is not supported in TPU yet" + ) req_id = new_req_data.req_id sampling_params = new_req_data.sampling_params @@ -422,8 +481,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if not resumed_from_preemption: if new_block_ids is not None: # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): + for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) else: assert new_block_ids is not None @@ -440,11 +498,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): continue # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens if new_block_ids is not None: - self.input_batch.block_table.append_row( - new_block_ids, req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. 
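The fixed padding lists built earlier in this file (num_tokens_paddings, num_reqs_paddings) exist so that XLA only ever sees a small, precompiled set of input shapes instead of recompiling per token or request count. A rough sketch of that bucketing idea, using simple doubling buckets as a stand-in for vLLM's actual _get_token_paddings / _get_req_paddings:

import bisect


def make_paddings(min_size: int, max_size: int) -> list[int]:
    # Hypothetical doubling buckets: 16, 32, 64, ... up to at least max_size.
    sizes = [min_size]
    while sizes[-1] < max_size:
        sizes.append(sizes[-1] * 2)
    return sizes


def pad_to_bucket(n: int, paddings: list[int]) -> int:
    # Smallest precomputed padding that fits n tokens (or requests).
    # Assumes n <= paddings[-1].
    return paddings[bisect.bisect_left(paddings, n)]


num_tokens_paddings = make_paddings(16, 8192)
print(num_tokens_paddings)                      # [16, 32, 64, ..., 8192]
print(pad_to_bucket(300, num_tokens_paddings))  # 512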
@@ -513,8 +569,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): block_size = self.vllm_config.cache_config.block_size kv_cache_spec: dict[str, KVCacheSpec] = {} for layer_name, attn_module in layers.items(): - if (kv_tgt_layer := - attn_module.kv_sharing_target_layer_name) is not None: + if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: # The layer doesn't need its own KV cache and will use that of # the target layer. We skip creating a KVCacheSpec for it, so # that KV cache management logic will act as this layer does @@ -529,7 +584,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if isinstance(attn_module, ChunkedLocalAttention): logger.warning_once( "Using irope in Pallas is not supported yet, it " - "will fall back to global attention for long context.") + "will fall back to global attention for long context." + ) if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, @@ -545,20 +601,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): head_size=attn_module.head_size, dtype=self.kv_cache_dtype, ) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): + elif attn_module.attn_type in ( + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + ): # encoder-only attention does not need KV cache. continue elif attn_module.attn_type == AttentionType.ENCODER_DECODER: raise NotImplementedError else: - raise ValueError( - f"Unknown attention type: {attn_module.attn_type}") + raise ValueError(f"Unknown attention type: {attn_module.attn_type}") return kv_cache_spec - def _get_slot_mapping_metadata(self, num_reqs, - num_scheduled_tokens_per_req) -> np.ndarray: + def _get_slot_mapping_metadata( + self, num_reqs, num_scheduled_tokens_per_req + ) -> np.ndarray: """ Computes metadata for mapping slots to blocks in the key-value (KV) cache for a batch of requests. @@ -583,14 +641,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - slice_len (int): The length of the slice. 
""" slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs] - slices_end = self.input_batch.num_computed_tokens_cpu[:num_reqs] + \ - num_scheduled_tokens_per_req + slices_end = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens_per_req + ) local_block_start_idx = slices_start // self.block_size local_block_end_idx = (slices_end - 1) // self.block_size no_repeat_req_indices = self.arange_np[:num_reqs] global_block_start_idx = ( - no_repeat_req_indices * self.max_num_blocks_per_req + - local_block_start_idx) + no_repeat_req_indices * self.max_num_blocks_per_req + local_block_start_idx + ) block_lens = local_block_end_idx - local_block_start_idx + 1 global_block_start_idx = np.repeat(global_block_start_idx, block_lens) slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens]) @@ -598,30 +658,31 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() block_numbers = block_table_cpu.flatten()[global_block_indices].numpy() total_block_len = np.sum(block_lens) - slot_mapping_slices = np.repeat(np.array([[0, self.block_size]], - dtype=np.int32), - total_block_len, - axis=0) + slot_mapping_slices = np.repeat( + np.array([[0, self.block_size]], dtype=np.int32), total_block_len, axis=0 + ) cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32) np.cumsum(block_lens, out=cu_block_lens[1:]) for req_idx in range(num_reqs): - slot_mapping_slices[cu_block_lens[req_idx]][ - 0] = slices_start[req_idx] % self.block_size - slot_mapping_slices[ - cu_block_lens[req_idx + 1] - - 1][1] = (slices_end[req_idx] - 1) % self.block_size + 1 + slot_mapping_slices[cu_block_lens[req_idx]][0] = ( + slices_start[req_idx] % self.block_size + ) + slot_mapping_slices[cu_block_lens[req_idx + 1] - 1][1] = ( + slices_end[req_idx] - 1 + ) % self.block_size + 1 slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0] cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32) np.cumsum(slice_lens, out=cu_slices_lens[1:]) - kv_cache_start_indices = slot_mapping_slices[:, 0] + \ - (block_numbers * self.block_size) + kv_cache_start_indices = slot_mapping_slices[:, 0] + ( + block_numbers * self.block_size + ) new_kv_start_indices = cu_slices_lens[:-1] slot_mapping_metadata = np.stack( - [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1) + [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1 + ) return slot_mapping_metadata - def _prepare_inputs(self, scheduler_output: "SchedulerOutput", - start_index: int): + def _prepare_inputs(self, scheduler_output: "SchedulerOutput", start_index: int): assert scheduler_output.total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 @@ -643,22 +704,24 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_scheduled_tokens_per_req.append(num_tokens) if use_max_model_len: if len(num_scheduled_tokens_per_req) > self.num_reqs_max_model_len: - num_scheduled_tokens_per_req = \ - num_scheduled_tokens_per_req[:self.num_reqs_max_model_len] + num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[ + : self.num_reqs_max_model_len + ] end_index = start_index + self.num_reqs_max_model_len else: end_index = num_reqs else: - if len(num_scheduled_tokens_per_req - ) > self.num_reqs_most_model_len: - num_scheduled_tokens_per_req = \ - num_scheduled_tokens_per_req[:self.num_reqs_most_model_len] + if len(num_scheduled_tokens_per_req) > self.num_reqs_most_model_len: + 
num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[ + : self.num_reqs_most_model_len + ] end_index = start_index + self.num_reqs_most_model_len else: end_index = num_reqs max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_per_req) - num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req, - dtype=np.int32) + num_scheduled_tokens_per_req = np.array( + num_scheduled_tokens_per_req, dtype=np.int32 + ) total_num_scheduled_tokens = sum(num_scheduled_tokens_per_req) assert max_num_scheduled_tokens_all_reqs > 0 @@ -667,121 +730,130 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] # For each scheduled token, what are the corresponding req index. - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens_per_req) + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens_per_req) # Get batched arange. # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # For each scheduled token, what is its position in corresponding req. arange = np.concatenate( - [self.arange_np[:n] for n in num_scheduled_tokens_per_req]) + [self.arange_np[:n] for n in num_scheduled_tokens_per_req] + ) # Get positions. positions_np = self.positions_np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, + ) # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. - token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) + token_indices = ( + positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + ) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens]) + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens], + ) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 - np.cumsum(num_scheduled_tokens_per_req, - out=self.query_start_loc_np[1:num_reqs + 1]) - self.query_start_loc_np[num_reqs + 1:] = 1 + np.cumsum( + num_scheduled_tokens_per_req, out=self.query_start_loc_np[1 : num_reqs + 1] + ) + self.query_start_loc_np[num_reqs + 1 :] = 1 self.seq_lens_np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens_per_req) + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens_per_req + ) # Do the padding and copy the tensors to the TPU. 
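The examples in the comments above can be checked directly; the flattening pattern that turns per-request token counts into per-token request indices, positions, and flat token-table indices is just np.repeat plus a concatenated arange:

import numpy as np

num_scheduled = np.array([2, 5, 3], dtype=np.int32)   # tokens scheduled per request

req_indices = np.repeat(np.arange(3), num_scheduled)
# -> [0 0 1 1 1 1 1 2 2 2]
arange = np.concatenate([np.arange(n) for n in num_scheduled])
# -> [0 1 0 1 2 3 4 0 1 2]

num_computed = np.array([4, 0, 9])                    # tokens already in the KV cache
positions = num_computed[req_indices] + arange        # e.g. request 2 continues at 9, 10, 11

M = 16                                                # row stride of the token table
token_indices = positions + req_indices * M           # flat indices into a (num_reqs, M) array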
padded_total_num_scheduled_tokens = _get_padded_token_len( - self.num_tokens_paddings, total_num_scheduled_tokens) + self.num_tokens_paddings, total_num_scheduled_tokens + ) # Zero out to avoid spurious values from prev iteration (last cp chunk) self.input_ids_cpu[ - total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0 - self.input_ids = self.input_ids_cpu[: - padded_total_num_scheduled_tokens].to( - self.device) - self.position_ids = self.positions_cpu[: - padded_total_num_scheduled_tokens].to( - self.device) + total_num_scheduled_tokens:padded_total_num_scheduled_tokens + ] = 0 + self.input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens].to( + self.device + ) + self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to( + self.device + ) if use_max_model_len: - block_tables = self.block_table_cpu[:self.num_reqs_max_model_len, : - self.max_num_blocks_per_req] - block_tables[:num_reqs, :self.max_num_blocks_per_req] = ( - self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs]) - query_start_loc = self.query_start_loc_cpu[:self. - num_reqs_max_model_len + - 1].to(self.device) - seq_lens = self.seq_lens_cpu[:self.num_reqs_max_model_len].to( - self.device) + block_tables = self.block_table_cpu[ + : self.num_reqs_max_model_len, : self.max_num_blocks_per_req + ] + block_tables[:num_reqs, : self.max_num_blocks_per_req] = ( + self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs] + ) + query_start_loc = self.query_start_loc_cpu[ + : self.num_reqs_max_model_len + 1 + ].to(self.device) + seq_lens = self.seq_lens_cpu[: self.num_reqs_max_model_len].to(self.device) else: - block_tables = self.block_table_cpu[:self. - num_reqs_most_model_len, :self. - num_blocks_per_most_len_req] - block_tables[:num_reqs, :self.num_blocks_per_most_len_req] = ( - self.input_batch.block_table[0].get_cpu_tensor() - [:num_reqs, :self.num_blocks_per_most_len_req]) - query_start_loc = self.query_start_loc_cpu[:self. 
- num_reqs_most_model_len + - 1].to(self.device) - seq_lens = self.seq_lens_cpu[:self.num_reqs_most_model_len].to( - self.device) + block_tables = self.block_table_cpu[ + : self.num_reqs_most_model_len, : self.num_blocks_per_most_len_req + ] + block_tables[:num_reqs, : self.num_blocks_per_most_len_req] = ( + self.input_batch.block_table[0].get_cpu_tensor()[ + :num_reqs, : self.num_blocks_per_most_len_req + ] + ) + query_start_loc = self.query_start_loc_cpu[ + : self.num_reqs_most_model_len + 1 + ].to(self.device) + seq_lens = self.seq_lens_cpu[: self.num_reqs_most_model_len].to(self.device) block_tables = block_tables.to(self.device) # Calculate the slot mapping slot_mapping_metadata = self._get_slot_mapping_metadata( - num_reqs, num_scheduled_tokens_per_req) + num_reqs, num_scheduled_tokens_per_req + ) num_kv_update_slices = slot_mapping_metadata.shape[0] padded_num_slices = _get_padded_num_kv_cache_update_slices( - padded_total_num_scheduled_tokens, self.max_num_reqs, - self.block_size) + padded_total_num_scheduled_tokens, self.max_num_reqs, self.block_size + ) slot_mapping_metadata = np.pad( slot_mapping_metadata, [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]], - constant_values=0) + constant_values=0, + ) slot_mapping_metadata = np.transpose(slot_mapping_metadata) - slot_mapping_metadata = torch.tensor(slot_mapping_metadata, - device=self.device) + slot_mapping_metadata = torch.tensor(slot_mapping_metadata, device=self.device) if self.lora_config is not None: # We need to respect padding when activating LoRA adapters padded_num_scheduled_tokens_per_req = np.copy( num_scheduled_tokens_per_req ) # Copying to avoid accidental state corruption bugs - padded_num_scheduled_tokens_per_req[-1] += \ + padded_num_scheduled_tokens_per_req[-1] += ( padded_total_num_scheduled_tokens - total_num_scheduled_tokens + ) - self.set_active_loras(self.input_batch, - padded_num_scheduled_tokens_per_req) + self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req) attn_metadata = PallasMetadata( slot_mapping=slot_mapping_metadata, block_tables=block_tables, context_lens=seq_lens, query_start_loc=query_start_loc, - num_seqs=torch.tensor([num_reqs], - dtype=torch.int32, - device=self.device), - num_kv_update_slices=torch.tensor([num_kv_update_slices], - dtype=torch.int32, - device=self.device), - num_slices_per_kv_cache_update_block=self. - _num_slices_per_kv_cache_update_block, + num_seqs=torch.tensor([num_reqs], dtype=torch.int32, device=self.device), + num_kv_update_slices=torch.tensor( + [num_kv_update_slices], dtype=torch.int32, device=self.device + ), + num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block, ) # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this @@ -789,10 +861,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # token from the partial request. # TODO: Support prompt logprobs. padded_num_reqs = _get_padded_num_reqs_with_upper_limit( - num_reqs, self.max_num_reqs) + num_reqs, self.max_num_reqs + ) # Indices at which we sample (positions of last token in the sequence). # Padded to avoid recompiling when `num_reqs` varies. 
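The slot-mapping metadata built and padded above boils down to one (kv_cache_start_index, new_kv_start_index, slice_len) row per KV block touched by the newly scheduled tokens, padded to a fixed row count so the XLA kernel always sees the same shape. A simplified, loop-based sketch of what the vectorized code computes (the block numbers and the padded bound are made up):

import numpy as np

block_size = 4
num_computed = np.array([5, 0])      # tokens already written per request
num_scheduled = np.array([3, 6])     # tokens to write this step
block_table = np.array([[10, 11, 12], [20, 21, 22]])  # toy physical block numbers

rows, new_kv_offset = [], 0
for req in range(2):
    start, end = num_computed[req], num_computed[req] + num_scheduled[req]
    for blk in range(start // block_size, (end - 1) // block_size + 1):
        lo, hi = max(start, blk * block_size), min(end, (blk + 1) * block_size)
        kv_start = block_table[req, blk] * block_size + lo % block_size
        rows.append((kv_start, new_kv_offset, hi - lo))
        new_kv_offset += hi - lo

slot_mapping = np.array(rows, dtype=np.int32)   # [[45 0 3] [80 3 4] [84 7 2]]
padded_num_slices = 16                          # illustrative fixed bound
slot_mapping = np.pad(
    slot_mapping,
    [[0, padded_num_slices - len(slot_mapping)], [0, 0]],
    constant_values=0,
).T                                             # shape (3, 16), as the kernel expects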
- logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1 + logits_indices = self.query_start_loc_cpu[1 : padded_num_reqs + 1] - 1 logits_indices = logits_indices.to(self.device) if self.lora_config is not None: @@ -800,20 +873,23 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): padded_num_scheduled_tokens_per_req = np.copy( num_scheduled_tokens_per_req ) # Copying to avoid accidental state corruption bugs - padded_num_scheduled_tokens_per_req[-1] += \ + padded_num_scheduled_tokens_per_req[-1] += ( padded_total_num_scheduled_tokens - total_num_scheduled_tokens + ) - self.set_active_loras(self.input_batch, - padded_num_scheduled_tokens_per_req) + self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req) - layer_names = get_layers_from_vllm_config(self.vllm_config, - Attention).keys() + layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys() per_layer_attn_metadata = { - layer_name: attn_metadata - for layer_name in layer_names + layer_name: attn_metadata for layer_name in layer_names } - return per_layer_attn_metadata, logits_indices, padded_num_reqs,\ - num_reqs, end_index + return ( + per_layer_attn_metadata, + logits_indices, + padded_num_reqs, + num_reqs, + end_index, + ) def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs @@ -843,10 +919,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): model = cast(SupportsMultiModal, self.model) encoder_outputs = [] for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): # Run the encoder. # `curr_group_outputs` is either of the following: @@ -856,8 +932,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # (feature_size, hidden_size) in case the feature size is dynamic # depending on the input multimodal items. torch_xla.sync(wait=False) - curr_group_outputs = model.get_multimodal_embeddings( - **mm_kwargs_group) + curr_group_outputs = model.get_multimodal_embeddings(**mm_kwargs_group) torch_xla.sync(wait=False) sanity_check_mm_encoder_outputs( @@ -877,8 +952,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # assume to only have whole mm items to process. Hence we avoid the # intrinsic dynamism that `scatter_mm_placeholders` introduces. for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): - assert pos_info.is_embed is None, "Expected all positions to be"\ - " contiguous and embeddings." + assert pos_info.is_embed is None, ( + "Expected all positions to be contiguous and embeddings." 
+ ) self.encoder_cache[mm_hash] = output def _gather_mm_embeddings( @@ -887,7 +963,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) -> tuple[list[torch.Tensor], torch.Tensor]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens padded_total_num_scheduled_tokens = _get_padded_token_len( - self.num_tokens_paddings, total_num_scheduled_tokens) + self.num_tokens_paddings, total_num_scheduled_tokens + ) is_mm_embed = self.is_mm_embed_cpu is_mm_embed[:padded_total_num_scheduled_tokens] = False @@ -895,8 +972,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_start_idx = 0 for req_id in self.input_batch.req_ids: - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens @@ -930,23 +1006,21 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) - assert encoder_output is not None,\ - f"Encoder cache miss for {mm_hash}." + assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." - assert pos_info.is_embed is None, "Expected all positions to"\ - " be contiguous and embeddings." + assert pos_info.is_embed is None, ( + "Expected all positions to be contiguous and embeddings." + ) req_start_pos = req_start_idx + start_pos - num_computed_tokens - is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \ - = True + is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = True # Only whole mm items are processed mm_embeds.append(encoder_output) req_start_idx += num_scheduled_tokens - is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens] \ - .to(self.device) + is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens].to(self.device) return mm_embeds, is_mm_embed @@ -988,8 +1062,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Return empty ModelRunnerOutput if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT - return self.kv_connector_no_forward(scheduler_output, - self.vllm_config) + return self.kv_connector_no_forward(scheduler_output, self.vllm_config) if self.supports_mm_inputs: # Run the multimodal encoder if any. 
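The logits_indices computed just above fall out of query_start_loc directly: the cumulative token counts minus one give the flat position of each request's last scheduled token, and because entries past the real request count are filled with 1, padded rows resolve to index 0 and are dropped later. For example:

import torch

# 3 real requests with 2, 5 and 3 scheduled tokens; one padded request slot.
query_start_loc = torch.tensor([0, 2, 7, 10, 1], dtype=torch.int32)
padded_num_reqs = 4
logits_indices = query_start_loc[1 : padded_num_reqs + 1] - 1
# -> tensor([1, 6, 9, 0]); the trailing 0 belongs to padding and its sample is discarded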
@@ -1011,41 +1084,48 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.maybe_setup_kv_connector(scheduler_output) while start_index < self.input_batch.num_reqs: - attn_metadata, logits_indices, padded_num_reqs, num_reqs,\ - end_index = self._prepare_inputs(scheduler_output, start_index) + attn_metadata, logits_indices, padded_num_reqs, num_reqs, end_index = ( + self._prepare_inputs(scheduler_output, start_index) + ) input_ids, inputs_embeds = self._get_model_inputs( - self.input_ids, mm_embed_inputs) + self.input_ids, mm_embed_inputs + ) torch_xla.sync(wait=False) # Run the decoder with set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=scheduler_output.total_num_scheduled_tokens): + attn_metadata, + self.vllm_config, + num_tokens=scheduler_output.total_num_scheduled_tokens, + ): hidden_states = self.model( input_ids=input_ids, positions=self.position_ids, inputs_embeds=inputs_embeds, ) - hidden_states = self.select_hidden_states(hidden_states, - logits_indices) + hidden_states = self.select_hidden_states(hidden_states, logits_indices) logits = self.compute_logits(hidden_states) - tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ - from_input_batch(self.input_batch, padded_num_reqs, self.device) + tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch( + self.input_batch, padded_num_reqs, self.device + ) if scheduler_output.grammar_bitmask is not None: - require_struct_decoding, grammar_bitmask_padded, arange = \ - self.prepare_structured_decoding_input(logits, - scheduler_output) - logits = self.structured_decode(require_struct_decoding, - grammar_bitmask_padded, logits, - arange) + require_struct_decoding, grammar_bitmask_padded, arange = ( + self.prepare_structured_decoding_input(logits, scheduler_output) + ) + logits = self.structured_decode( + require_struct_decoding, grammar_bitmask_padded, logits, arange + ) selected_token_ids = self.sample_from_logits_func( - logits, tpu_sampling_metadata) + logits, tpu_sampling_metadata + ) # NOTE (NickLucche) Use the original logits (before any penalties or # temperature scaling) for the top-k logprobs. We can't enforce it # due to recompilations outside torch.compiled code, so just make # sure `sample_from_logits` does not modify the logits in-place. - logprobs = self.gather_logprobs(logits, selected_token_ids) \ - if tpu_sampling_metadata.logprobs else None + logprobs = ( + self.gather_logprobs(logits, selected_token_ids) + if tpu_sampling_metadata.logprobs + else None + ) # Remove padding on cpu and keep dynamic op outside of xla graph. selected_token_ids = selected_token_ids.cpu()[:num_reqs] @@ -1061,8 +1141,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # should be called right after each single forward pass, # instead of the forwards of the entire input batch. 
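A little further down, the per-step outputs are trimmed with a mask-and-split idiom: positions holding the invalid-token sentinel are masked out, the per-request counts of valid tokens are recovered with a row-wise sum, and torch.split turns the flat result back into ragged per-request lists. A tiny standalone check of that idiom, using -1 as a stand-in for INVALID_TOKEN_ID:

import torch

INVALID_TOKEN_ID = -1  # stand-in sentinel for illustration
selected = torch.tensor([[5, INVALID_TOKEN_ID],
                         [7, 8],
                         [INVALID_TOKEN_ID, INVALID_TOKEN_ID]])

valid_mask = selected != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()                        # [1, 2, 0]
valid_sampled = [s.tolist() for s in selected[valid_mask].split(gen_lens)]
# -> [[5], [7, 8], []]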
self.maybe_wait_for_kv_save() - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) + finished_sending, finished_recving = self.get_finished_kv_transfers( + scheduler_output + ) selected_token_ids = torch.cat(combined_selected_tokens, dim=0) if tpu_sampling_metadata.logprobs: @@ -1073,16 +1154,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): result.extend(input_list) return result - logprobs_lists = LogprobsLists(logprob_token_ids=concat_lists( - [lp.logprob_token_ids for lp in combined_logprobs]), - logprobs=concat_lists([ - lp.logprobs - for lp in combined_logprobs - ]), - sampled_token_ranks=concat_lists([ - lp.sampled_token_ranks - for lp in combined_logprobs - ])) + logprobs_lists = LogprobsLists( + logprob_token_ids=concat_lists( + [lp.logprob_token_ids for lp in combined_logprobs] + ), + logprobs=concat_lists([lp.logprobs for lp in combined_logprobs]), + sampled_token_ranks=concat_lists( + [lp.sampled_token_ranks for lp in combined_logprobs] + ), + ) else: logprobs_lists = None @@ -1094,8 +1174,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) + seq_len = ( + req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id] + ) if seq_len >= req_state.num_tokens: request_seq_lens.append((i, req_state, seq_len)) else: @@ -1111,8 +1193,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): discard_sampled_tokens_req_indices.append(i) assert all( - req_id is not None for req_id in - self.input_batch.req_ids[:num_reqs]), "req_ids contains None" + req_id is not None for req_id in self.input_batch.req_ids[:num_reqs] + ), "req_ids contains None" req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} @@ -1140,22 +1222,24 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): valid_mask = selected_token_ids != INVALID_TOKEN_ID gen_lens = valid_mask.sum(dim=1).tolist() valid_sampled_token_ids = [ - seq.tolist() - for seq in selected_token_ids[valid_mask].split(gen_lens) + seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens) ] self.input_batch.num_tokens[:num_reqs] += gen_lens for i, req_state, seq_len in request_seq_lens: target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) - self.input_batch.token_ids_cpu[ - i, target_slice] = valid_sampled_token_ids[i] + self.input_batch.token_ids_cpu[i, target_slice] = ( + valid_sampled_token_ids[i] + ) req_state.output_token_ids.extend(valid_sampled_token_ids[i]) - kv_connector_output = None if ( - finished_sending is None - and finished_recving is None) else KVConnectorOutput( + kv_connector_output = ( + None + if (finished_sending is None and finished_recving is None) + else KVConnectorOutput( finished_sending=finished_sending, finished_recving=finished_recving, ) + ) model_runner_output = ModelRunnerOutput( req_ids=req_ids, @@ -1178,9 +1262,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754 allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, \ - f"Config `{config_name}` not supported. 
" \ + assert config_name in allowed_config_names, ( + f"Config `{config_name}` not supported. " f"Allowed configs: {allowed_config_names}" + ) config = getattr(self, config_name) new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) @@ -1199,30 +1284,34 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # the embedding weights. xm_tp_rank = xr.global_ordinal() with patch( - "vllm.model_executor.layers.vocab_parallel_embedding." - "get_tensor_model_parallel_rank", - return_value=xm_tp_rank): + "vllm.model_executor.layers.vocab_parallel_embedding." + "get_tensor_model_parallel_rank", + return_value=xm_tp_rank, + ): try: if self.use_spmd: tpu_loader = TPUModelLoader( - load_config=self.vllm_config.load_config) + load_config=self.vllm_config.load_config + ) model = tpu_loader.load_model( vllm_config=self.vllm_config, model_config=self.vllm_config.model_config, - mesh=self.mesh) + mesh=self.mesh, + ) else: model_loader = get_model_loader(self.load_config) logger.info("Loading model from scratch...") model = model_loader.load_model( - vllm_config=self.vllm_config, - model_config=self.model_config) + vllm_config=self.vllm_config, model_config=self.model_config + ) except RuntimeError as e: raise RuntimeError( f"Unable to load model, a likely reason is the model is " "too large for the current device's HBM memory. " "Consider switching to a smaller model " "or sharding the weights on more chips. " - f"See the detailed error: {e}") from e + f"See the detailed error: {e}" + ) from e if self.lora_config is not None: model = self.load_lora_model(model, self.vllm_config, self.device) replace_set_lora(model) @@ -1236,44 +1325,43 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.sampler = TPUSampler() def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, \ + assert getattr(self, "model", None) is not None, ( "Cannot reload weights before model is loaded." 
+ ) model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") model_loader.load_weights(self.model, model_config=self.model_config) @torch.no_grad() - def _dummy_run(self, num_tokens: int, num_reqs: int, - num_blocks: int) -> None: + def _dummy_run(self, num_tokens: int, num_reqs: int, num_blocks: int) -> None: if self.supports_mm_inputs: input_ids = None - inputs_embeds = torch.zeros((num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) + inputs_embeds = torch.zeros( + (num_tokens, self.hidden_size), dtype=self.dtype, device=self.device + ) else: - input_ids = torch.zeros((num_tokens), - dtype=torch.int32).to(self.device) + input_ids = torch.zeros((num_tokens), dtype=torch.int32).to(self.device) inputs_embeds = None actual_num_reqs = min(num_tokens, num_reqs) - position_ids = torch.zeros(num_tokens, - dtype=torch.int32).to(self.device) + position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(self.device) padded_num_slices = _get_padded_num_kv_cache_update_slices( - num_tokens, self.max_num_reqs, self.block_size) - num_kv_update_slices = torch.tensor([padded_num_slices], - dtype=torch.int32).to(self.device) - slot_mapping = torch.zeros((3, padded_num_slices), - dtype=torch.int32).to(self.device) - block_tables = torch.zeros((num_reqs, num_blocks), - dtype=torch.int32).to(self.device) + num_tokens, self.max_num_reqs, self.block_size + ) + num_kv_update_slices = torch.tensor([padded_num_slices], dtype=torch.int32).to( + self.device + ) + slot_mapping = torch.zeros((3, padded_num_slices), dtype=torch.int32).to( + self.device + ) + block_tables = torch.zeros((num_reqs, num_blocks), dtype=torch.int32).to( + self.device + ) query_lens = [1] * num_reqs - query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, - dtype=torch.int32), - dim=0, - dtype=torch.int32).to(self.device) - context_lens = torch.ones((num_reqs, ), - dtype=torch.int32).to(self.device) - num_seqs = torch.tensor([actual_num_reqs], - dtype=torch.int32).to(self.device) + query_start_loc = torch.cumsum( + torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32 + ).to(self.device) + context_lens = torch.ones((num_reqs,), dtype=torch.int32).to(self.device) + num_seqs = torch.tensor([actual_num_reqs], dtype=torch.int32).to(self.device) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, block_tables=block_tables, @@ -1281,8 +1369,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): query_start_loc=query_start_loc, num_seqs=num_seqs, num_kv_update_slices=num_kv_update_slices, - num_slices_per_kv_cache_update_block=self. 
- _num_slices_per_kv_cache_update_block, + num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block, ) if self.supports_mm_inputs: @@ -1295,27 +1382,29 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0) - layer_names = get_layers_from_vllm_config(self.vllm_config, - Attention).keys() + layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys() per_layer_attn_metadata = { - layer_name: attn_metadata - for layer_name in layer_names + layer_name: attn_metadata for layer_name in layer_names } - with self.maybe_select_dummy_loras( - self.lora_config, - np.array([num_tokens], dtype=np.int32)), set_forward_context( - per_layer_attn_metadata, self.vllm_config, 0): - out = self.model(input_ids=input_ids, - positions=position_ids, - inputs_embeds=inputs_embeds) + with ( + self.maybe_select_dummy_loras( + self.lora_config, np.array([num_tokens], dtype=np.int32) + ), + set_forward_context(per_layer_attn_metadata, self.vllm_config, 0), + ): + out = self.model( + input_ids=input_ids, positions=position_ids, inputs_embeds=inputs_embeds + ) self._hidden_states_dtype = out.dtype - def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping, - lora_requests) -> None: + def _set_active_loras( + self, prompt_lora_mapping, token_lora_mapping, lora_requests + ) -> None: torch_xla.sync(wait=False) # Captures input updates - super()._set_active_loras(prompt_lora_mapping, token_lora_mapping, - lora_requests) + super()._set_active_loras( + prompt_lora_mapping, token_lora_mapping, lora_requests + ) torch_xla.sync(wait=False) # Captures metadata updates def _precompile_mm_encoder(self) -> None: @@ -1332,8 +1421,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for mode, max_items_per_seq in max_items_per_seq_by_modality.items(): logger.info( - "Compiling Multimodal %s Encoder with different input" - " shapes.", mode) + "Compiling Multimodal %s Encoder with different input shapes.", mode + ) start = time.perf_counter() # No padding for MM encoder just yet. for num_items in range(1, max_items_per_seq + 1): @@ -1345,7 +1434,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Run multimodal encoder. torch_xla.sync(wait=False) mm_embeds = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) + **batched_dummy_mm_inputs + ) torch_xla.sync(wait=False) num_patches = mm_embeds[0].shape[0] items_size = num_patches * num_items @@ -1359,12 +1449,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # XLA Workaround: if torch.zeros(..device) is used, XLA # compiles a scalar+expansion op, which won't match # the graph generated at runtime. CPU->TPU must be used - placeholders_ids = torch.zeros(num_tokens, - dtype=torch.int32, - device="cpu") + placeholders_ids = torch.zeros( + num_tokens, dtype=torch.int32, device="cpu" + ) # Align placeholders and actual num mm_embeddings. - placeholders_ids[:items_size] = \ - hf_config.image_token_index + placeholders_ids[:items_size] = hf_config.image_token_index placeholders_ids = placeholders_ids.to(self.device) @@ -1382,9 +1471,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Pre-compile `get_input_embeddings` when mm_embeddings are not # present. Chunk is only made of text, no mm_placeholders. 
for num_tokens in self.num_tokens_paddings: - placeholders_ids = torch.zeros(num_tokens, - dtype=torch.int32, - device="cpu") + placeholders_ids = torch.zeros( + num_tokens, dtype=torch.int32, device="cpu" + ) placeholders_ids = placeholders_ids.to(self.device) a, b = self._get_model_inputs( placeholders_ids, @@ -1396,19 +1485,25 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): xm.wait_device_ops() end = time.perf_counter() logger.info( - "Multimodal %s Encoder compilation finished in in %.2f " - "[secs].", mode, end - start) + "Multimodal %s Encoder compilation finished in in %.2f [secs].", + mode, + end - start, + ) def _precompile_backbone(self) -> None: logger.info("Compiling the model with different input shapes.") start = time.perf_counter() for num_tokens in self.num_tokens_paddings: logger.info(" -- num_tokens: %d", num_tokens) - self._dummy_run(num_tokens, self.num_reqs_max_model_len, - self.max_num_blocks_per_req) + self._dummy_run( + num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req + ) if self.most_model_len is not None: - self._dummy_run(num_tokens, self.num_reqs_most_model_len, - self.num_blocks_per_most_len_req) + self._dummy_run( + num_tokens, + self.num_reqs_most_model_len, + self.num_blocks_per_most_len_req, + ) xm.wait_device_ops() end = time.perf_counter() logger.info("Compilation finished in %.2f [secs].", end - start) @@ -1417,23 +1512,19 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _precompile_select_hidden_states(self) -> None: # Compile hidden state selection function for bucketed # n_tokens x max_num_reqs. Graph is really small so this is fine. - logger.info( - "Compiling select_hidden_states with different input shapes.") + logger.info("Compiling select_hidden_states with different input shapes.") start = time.perf_counter() hsize = self.model_config.get_hidden_size() for num_tokens in self.num_tokens_paddings: - dummy_hidden = torch.zeros((num_tokens, hsize), - device=self.device, - dtype=self._hidden_states_dtype) + dummy_hidden = torch.zeros( + (num_tokens, hsize), device=self.device, dtype=self._hidden_states_dtype + ) torch._dynamo.mark_dynamic(dummy_hidden, 0) for num_reqs in self.num_reqs_paddings: - indices = torch.zeros(num_reqs, - dtype=torch.int32, - device=self.device) + indices = torch.zeros(num_reqs, dtype=torch.int32, device=self.device) torch._dynamo.mark_dynamic(indices, 0) self.select_hidden_states(dummy_hidden, indices) - logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, - num_reqs) + logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, num_reqs) # Requests can't be more than tokens. But do compile for the # next bigger value in case num_tokens uses bucketed padding. 
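Each of the _precompile_* helpers follows the same shape: run a dummy call once per padded bucket so every graph the runtime can request is already compiled, then block on xm.wait_device_ops() before reporting the elapsed time. A stripped-down sketch of that loop (assuming a TPU host with torch_xla installed; dummy_run stands in for any of the traced callables):

import time

import torch_xla.core.xla_model as xm


def precompile(paddings: list[int], dummy_run) -> float:
    start = time.perf_counter()
    for num_tokens in paddings:
        dummy_run(num_tokens)   # traces and compiles one XLA graph per bucket
    xm.wait_device_ops()        # wait for queued compilations/executions
    return time.perf_counter() - start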
if num_reqs >= min(num_tokens, self.max_num_reqs): @@ -1448,9 +1539,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): start = time.perf_counter() hsize = self.model_config.get_hidden_size() for num_reqs in self.num_reqs_paddings: - dummy_hidden = torch.zeros((num_reqs, hsize), - device=self.device, - dtype=self._hidden_states_dtype) + dummy_hidden = torch.zeros( + (num_reqs, hsize), device=self.device, dtype=self._hidden_states_dtype + ) torch._dynamo.mark_dynamic(dummy_hidden, 0) self.compute_logits(dummy_hidden) logger.info(" -- num_seqs: %d", num_reqs) @@ -1460,23 +1551,28 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self._update_num_xla_graphs("compute_logits") def _precompile_structured_decoding(self) -> None: - logger.info( - "Compiling structured_decoding with different input shapes.") + logger.info("Compiling structured_decoding with different input shapes.") start = time.perf_counter() for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros((num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype) - dummy_require_struct_decoding = \ - self.require_structured_out_cpu[:num_reqs].to(self.device) - dummy_grammar_bitmask = \ - self.grammar_bitmask_cpu[:num_reqs].to(self.device) + dummy_logits = torch.zeros( + (num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype, + ) + dummy_require_struct_decoding = self.require_structured_out_cpu[ + :num_reqs + ].to(self.device) + dummy_grammar_bitmask = self.grammar_bitmask_cpu[:num_reqs].to(self.device) # The first dimension of the above 3 dummy tensors cannot be # mark_dynamic because some operations in structured_decode require # them to be static. arange = self.structured_decode_arange.to(self.device) - self.structured_decode(dummy_require_struct_decoding, - dummy_grammar_bitmask, dummy_logits, arange) + self.structured_decode( + dummy_require_struct_decoding, + dummy_grammar_bitmask, + dummy_logits, + arange, + ) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() end = time.perf_counter() @@ -1484,30 +1580,29 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self._update_num_xla_graphs("structured_decoding") def _precompile_sample_from_logits(self) -> None: - logger.info( - "Compiling sample_from_logits with different input shapes.") + logger.info("Compiling sample_from_logits with different input shapes.") start = time.perf_counter() for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros((num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype) + dummy_logits = torch.zeros( + (num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype, + ) # The first dimension of dummy_logits cannot be mark_dynamic # because some operations in the sampler require it to be static. 
for all_greedy in [False, True]: generate_params_if_all_greedy = not all_greedy - sampling_metadata = ( - TPUSupportedSamplingMetadata.from_input_batch( - self.input_batch, - num_reqs, - self.device, - generate_params_if_all_greedy, - )) + sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch( + self.input_batch, + num_reqs, + self.device, + generate_params_if_all_greedy, + ) sampling_metadata.all_greedy = all_greedy with self.maybe_select_dummy_loras( - self.lora_config, np.array([num_reqs], - dtype=np.int32)): - self.sample_from_logits_func(dummy_logits, - sampling_metadata) + self.lora_config, np.array([num_reqs], dtype=np.int32) + ): + self.sample_from_logits_func(dummy_logits, sampling_metadata) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() end = time.perf_counter() @@ -1518,13 +1613,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logger.info("Compiling gather_logprobs with different input shapes.") start = time.perf_counter() for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros((num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype) - dummy_tokens = torch.zeros((num_reqs, 1), - dtype=torch.int64).to(self.device) + dummy_logits = torch.zeros( + (num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype, + ) + dummy_tokens = torch.zeros((num_reqs, 1), dtype=torch.int64).to(self.device) with self.maybe_select_dummy_loras( - self.lora_config, np.array([num_reqs], dtype=np.int32)): + self.lora_config, np.array([num_reqs], dtype=np.int32) + ): self.gather_logprobs(dummy_logits, dummy_tokens) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() @@ -1554,7 +1651,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.model_config.multimodal_config.skip_mm_profiling: logger.info( "Skipping memory profiling for multimodal encoder and " - "encoder cache.") + "encoder cache." + ) else: mm_budget = self.mm_budget assert mm_budget is not None @@ -1565,8 +1663,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # modality with the max possible input tokens even when # it supports multiple. dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget \ - .max_items_per_batch_by_modality[dummy_modality] + max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ + dummy_modality + ] logger.info( "Encoder cache will be initialized with a budget of " @@ -1588,15 +1687,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # impact of recompilation until it's fixed. start = time.perf_counter() torch_xla.sync(wait=False) - dummy_encoder_outputs = \ - self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs + ) torch_xla.sync(wait=False) xm.wait_device_ops() end = time.perf_counter() logger.info( "Multimodal Encoder profiling finished in %.2f [secs].", - end - start) + end - start, + ) sanity_check_mm_encoder_outputs( dummy_encoder_outputs, @@ -1604,15 +1704,18 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Trigger compilation for general shape. 
- self._dummy_run(num_tokens, self.num_reqs_max_model_len, - self.max_num_blocks_per_req) + self._dummy_run( + num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req + ) if self.most_model_len is not None: - self._dummy_run(num_tokens, self.num_reqs_most_model_len, - self.num_blocks_per_most_len_req) + self._dummy_run( + num_tokens, + self.num_reqs_most_model_len, + self.num_blocks_per_most_len_req, + ) torch_xla.sync(wait=False) xm.wait_device_ops() @@ -1637,10 +1740,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_cache_config.kv_cache_groups, ) - for layer_name, target_layer_name in self.shared_kv_cache_layers.items( - ): - logger.debug("%s reuses KV cache of %s", layer_name, - target_layer_name) + for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): + logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) kv_caches[layer_name] = kv_caches[target_layer_name] def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: @@ -1652,11 +1753,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): """ if len(kv_cache_config.kv_cache_groups) > 1: raise NotImplementedError( - "Hybrid models with more than one KV cache type are not " - "supported yet.") + "Hybrid models with more than one KV cache type are not supported yet." + ) - if kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.block_size != self.block_size: + if ( + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + != self.block_size + ): self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, @@ -1669,14 +1772,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ], ) # Verify dtype compatibility between block_table_cpu and input_batch - assert self.block_table_cpu.dtype == self.input_batch.block_table[ - 0].get_cpu_tensor().dtype + assert ( + self.block_table_cpu.dtype + == self.input_batch.block_table[0].get_cpu_tensor().dtype + ) kv_cache_sizes = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: assert len(kv_cache_tensor.shared_by) == 1, ( - "KV cache tensor shared by multiple layers is not supported in " - "TPU.") + "KV cache tensor shared by multiple layers is not supported in TPU." + ) kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size kv_caches: dict[str, torch.Tensor] = {} @@ -1690,19 +1795,23 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.use_spmd: num_kv_heads = kv_cache_spec.num_kv_heads assert self.original_parallel_config is not None - tp_size = \ - self.original_parallel_config.tensor_parallel_size + tp_size = self.original_parallel_config.tensor_parallel_size # TODO: Handle kv cache duplication under SPMD mode. 
assert num_kv_heads % tp_size == 0, ( f"num_kv_heads {num_kv_heads} must be divisible by " - f"tp_size {tp_size} under SPMD mode") + f"tp_size {tp_size} under SPMD mode" + ) kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + ) dtype = kv_cache_spec.dtype - tpu_kv_cache = torch.zeros(kv_cache_shape, - dtype=dtype).to(self.device) + tpu_kv_cache = torch.zeros(kv_cache_shape, dtype=dtype).to( + self.device + ) kv_caches[layer_name] = tpu_kv_cache else: @@ -1714,19 +1823,19 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, - self.kv_caches) + self.kv_caches, + ) if self.use_spmd: # Shard KV Cache for cache in self.kv_caches: - xs.mark_sharding(cache, self.mesh, (None, 'x', None, None)) + xs.mark_sharding(cache, self.mesh, (None, "x", None, None)) if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks) def reset_dynamo_cache(self): - # NOTE: We check `is_multimodal_model` instead of `supports_mm_inputs` # since the compiled model object of the language backbone of a # multimodal model needs to be extracted via `get_language_model`. @@ -1737,7 +1846,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher): logger.info("Clear dynamo cache and cached dynamo bytecode.") torch._dynamo.eval_frame.remove_from_cache( - compiled_model.original_code_object) + compiled_model.original_code_object + ) compiled_model.compiled_codes.clear() @torch.compile(backend="openxla", fullgraph=True, dynamic=False) @@ -1745,30 +1855,29 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return hidden_states[indices_do_sample] @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def compute_logits(self, - sample_hidden_states: torch.Tensor) -> torch.Tensor: + def compute_logits(self, sample_hidden_states: torch.Tensor) -> torch.Tensor: return self.model.compute_logits(sample_hidden_states) # TODO: Under SPMD mode, sample_from_logits has correctness issue. # Re-enable the torch.compile once the issue is fixed in torchxla. # @torch.compile(backend="openxla", fullgraph=True, dynamic=False) def sample_from_logits( - self, logits: torch.Tensor, - sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor: + self, logits: torch.Tensor, sampling_metadata: TPUSupportedSamplingMetadata + ) -> torch.Tensor: """ - Sample with xla-friendly function. This function is to be traced + Sample with xla-friendly function. This function is to be traced separately from `forward` for lighter compilation overhead. """ if sampling_metadata.all_greedy: out_tokens = torch.argmax(logits, dim=-1, keepdim=True) else: - out_tokens = self.sampler(logits, - sampling_metadata).sampled_token_ids + out_tokens = self.sampler(logits, sampling_metadata).sampled_token_ids return out_tokens @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def gather_logprobs(self, logits: torch.Tensor, - sampled_tokens: torch.Tensor) -> LogprobsTensors: + def gather_logprobs( + self, logits: torch.Tensor, sampled_tokens: torch.Tensor + ) -> LogprobsTensors: """ Gather the top_logprobs with corresponding tokens. 
Use a fixed number of logprobs as an alternative to having multiple pre-compiled graphs. @@ -1778,28 +1887,37 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return self.sampler.gather_logprobs( logprobs, self.model_config.max_logprobs, - token_ids=sampled_tokens.squeeze(-1)) + token_ids=sampled_tokens.squeeze(-1), + ) @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def structured_decode(self, require_struct_decoding: torch.Tensor, - grammar_bitmask: torch.Tensor, logits: torch.Tensor, - arange: torch.Tensor) -> torch.Tensor: + def structured_decode( + self, + require_struct_decoding: torch.Tensor, + grammar_bitmask: torch.Tensor, + logits: torch.Tensor, + arange: torch.Tensor, + ) -> torch.Tensor: return torch.where( require_struct_decoding, self.apply_grammar_bitmask(logits, grammar_bitmask, arange), - logits) + logits, + ) - def apply_grammar_bitmask(self, logits: torch.Tensor, - grammar_bitmask: torch.Tensor, - arange: torch.Tensor): - assert (logits.shape[0] == grammar_bitmask.shape[0]) + def apply_grammar_bitmask( + self, logits: torch.Tensor, grammar_bitmask: torch.Tensor, arange: torch.Tensor + ): + assert logits.shape[0] == grammar_bitmask.shape[0] logits_cloned = logits.clone() for i in range(logits.shape[0]): - unpacked_bitmask = (torch.bitwise_right_shift( - grammar_bitmask[i][:, None], arange[None, :]) & 1) == 0 - unpacked_bitmask = unpacked_bitmask.reshape(-1)[:self.vocab_size] + unpacked_bitmask = ( + torch.bitwise_right_shift(grammar_bitmask[i][:, None], arange[None, :]) + & 1 + ) == 0 + unpacked_bitmask = unpacked_bitmask.reshape(-1)[: self.vocab_size] logits_cloned[i] = logits_cloned[i].masked_fill( - unpacked_bitmask, -float("inf")) + unpacked_bitmask, -float("inf") + ) return logits_cloned def get_multimodal_embeddings(self, *args, **kwargs): @@ -1821,23 +1939,27 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): sorted_struct_requests = sorted( scheduler_output.structured_output_request_ids.items(), - key=lambda item: item[1]) + key=lambda item: item[1], + ) cumulative_mask_idx = 0 for req_id, _ in sorted_struct_requests: if req_id not in self.input_batch.req_id_to_index: continue batch_index = self.input_batch.req_id_to_index[req_id] self.grammar_bitmask_cpu[batch_index] = torch.from_numpy( - grammar_bitmask[cumulative_mask_idx]) + grammar_bitmask[cumulative_mask_idx] + ) # It's not guaranteed that all requests in this batch require # structured output, so create a bool tensor to represent # the requests that need structured output. 
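The bitmask handling above packs 32 vocabulary entries into each int32: shifting a row by arange and keeping the low bit recovers a per-token flag, and tokens whose bit is 0 get their logits pushed to -inf. A worked example with an 8-token vocabulary (one packed word per row):

import torch

vocab_size = 8
arange = torch.arange(32)
grammar_bitmask = torch.tensor([0b00001011], dtype=torch.int32)  # allows tokens 0, 1 and 3

shifted = torch.bitwise_right_shift(grammar_bitmask[:, None], arange[None, :])
disallowed = ((shifted & 1) == 0).reshape(-1)[:vocab_size]
# -> [False, False, True, False, True, True, True, True]

logits = torch.zeros(1, vocab_size)
logits[0] = logits[0].masked_fill(disallowed, -float("inf"))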
self.require_structured_out_cpu[batch_index] = True cumulative_mask_idx += 1 - return self.require_structured_out_cpu[:num_reqs].to(logits.device), \ - self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \ - self.structured_decode_arange.to(logits.device) + return ( + self.require_structured_out_cpu[:num_reqs].to(logits.device), + self.grammar_bitmask_cpu[:num_reqs].to(logits.device), + self.structured_decode_arange.to(logits.device), + ) def _get_mm_dummy_batch( self, @@ -1860,13 +1982,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dummy_mm_items = [dummy_mm_item] * max_items_per_batch model = cast(SupportsMultiModal, self.model) - return next(grouped_mm_kwargs - for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality( - dummy_mm_items, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - )) + return next( + grouped_mm_kwargs + for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality( + dummy_mm_items, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ) + ) def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: @@ -1887,9 +2011,10 @@ def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int: return min(res, upper_limit) -def _get_token_paddings(min_token_size: int, max_token_size: int, - padding_gap: int) -> list[int]: - """Generate a list of padding size, starting from min_token_size, +def _get_token_paddings( + min_token_size: int, max_token_size: int, padding_gap: int +) -> list[int]: + """Generate a list of padding size, starting from min_token_size, ending with a number that can cover max_token_size If padding_gap == 0 then: @@ -1927,15 +2052,15 @@ def _get_token_paddings(min_token_size: int, max_token_size: int, def _get_padded_token_len(paddings: list[int], x: int) -> int: - """Return the first element in paddings list greater or equal to x. 
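The padding helpers around here all serve the same goal: every batch dimension is rounded up to one of a small, precomputed set of sizes so the number of XLA graphs stays bounded. A sketch of the idea, assuming a simple power-of-two scheme purely for illustration (the real helper bodies are unchanged context in this diff and may differ in detail), plus the bisect-based lookup used by _get_padded_token_len:

import bisect


def padded_size_sketch(x: int, min_size: int, upper_limit: int) -> int:
    # Hypothetical bucketing: round x up to a power of two, clamped to the limit.
    size = min_size
    while size < x:
        size *= 2
    return min(size, upper_limit)


paddings = [16, 32, 64, 128, 256]                        # e.g. precomputed token paddings
padded_len = paddings[bisect.bisect_left(paddings, 48)]  # 64, the first bucket >= 48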
- """ + """Return the first element in paddings list greater or equal to x.""" index = bisect.bisect_left(paddings, x) assert index < len(paddings) return paddings[index] -def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int, - page_size: int) -> int: +def _get_padded_num_kv_cache_update_slices( + num_tokens: int, max_num_reqs: int, page_size: int +) -> int: """Calculates the padded number of KV cache update slices to avoid recompilation.""" # NOTE(chengjiyao): let's say R_i is the token num for i-th request, @@ -1971,7 +2096,6 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int: def replace_set_lora(model): - def _tpu_set_lora( self, index: int, @@ -1995,5 +2119,4 @@ def replace_set_lora(model): module._original_set_lora = module.set_lora module._original_reset_lora = module.reset_lora module.set_lora = _tpu_set_lora.__get__(module, module.__class__) - module.reset_lora = _tpu_reset_lora.__get__( - module, module.__class__) + module.reset_lora = _tpu_reset_lora.__get__(module, module.__class__) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index d4f0a65f2a..66515c7e57 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -11,10 +11,14 @@ import torch.nn as nn import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, - has_kv_transfer_group) +from vllm.distributed import ( + ensure_model_parallel_initialized, + init_distributed_environment, +) +from vllm.distributed.kv_transfer import ( + ensure_kv_transfer_initialized, + has_kv_transfer_group, +) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed @@ -23,8 +27,7 @@ from vllm.platforms.tpu import USE_TPU_COMMONS from vllm.tasks import SupportedTask from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, - KVCacheSpec) +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import report_usage_stats from vllm.v1.worker.utils import bind_kv_cache @@ -44,7 +47,6 @@ if not USE_TPU_COMMONS: class TPUWorker: - def __init__( self, vllm_config: VllmConfig, @@ -82,12 +84,12 @@ class TPUWorker: if self.cache_config.cache_dtype == "auto": self.cache_dtype = self.model_config.dtype else: - self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - self.cache_config.cache_dtype] + self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype] if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() # Delay profiler initialization to the start of the profiling. @@ -100,14 +102,14 @@ class TPUWorker: # For TPU, we can only have 1 active profiler session for 1 profiler # server. So we only profile on rank0. self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - self.profile_dir) + logger.info( + "Profiling enabled. 
Traces will be saved to: %s", self.profile_dir + ) if self.model_config.seed is None: self.model_config.seed = 0 - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks @@ -118,9 +120,10 @@ class TPUWorker: # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to # fix this. It will be removed after the bug in XLA compiler is fixed. os.environ["LIBTPU_INIT_ARGS"] = ( - os.environ.get("LIBTPU_INIT_ARGS", "") + - " --xla_tpu_force_1d_allreduce_at_chunk_count=1" - " --xla_jf_conv_input_fusion=False") + os.environ.get("LIBTPU_INIT_ARGS", "") + + " --xla_tpu_force_1d_allreduce_at_chunk_count=1" + " --xla_jf_conv_input_fusion=False" + ) # --xla_jf_conv_input_fusion=False is used to improve the perf of # quantized matmul. torch.set_grad_enabled(False) @@ -128,8 +131,8 @@ class TPUWorker: # Initialize the distributed environment. self._init_tpu_worker_distributed_environment( - self.vllm_config, self.rank, self.distributed_init_method, - self.local_rank) + self.vllm_config, self.rank, self.distributed_init_method, self.local_rank + ) # Device initialization should happen after initializing # the distributed runtime. @@ -158,14 +161,15 @@ class TPUWorker: # cache during development is recommended.We can disable it by # `export VLLM_XLA_CACHE_PATH=` if envs.VLLM_XLA_CACHE_PATH: - per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, - f"tp{world_size}_rank{rank}") + per_rank_path = os.path.join( + envs.VLLM_XLA_CACHE_PATH, f"tp{world_size}_rank{rank}" + ) xr.initialize_cache(per_rank_path, readonly=False) # Init ModelRunner here, so that we have access to self.device. - self.model_runner = \ - TPUModelRunner(self.vllm_config, self.device, - self.original_parallel_config) + self.model_runner = TPUModelRunner( + self.vllm_config, self.device, self.original_parallel_config + ) if rank == 0: # If usage stat is enabled, collect relevant info. @@ -184,13 +188,15 @@ class TPUWorker: kv_caches[layer_name] = tpu_kv_cache else: raise NotImplementedError( - f"Unsupported KV cache spec '{type(layer_spec)}'") + f"Unsupported KV cache spec '{type(layer_spec)}'" + ) runner_kv_caches: list[torch.Tensor] = [] bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, - runner_kv_caches) + runner_kv_caches, + ) # `max_num_tokens >= max_num_batched_tokens` due to padding. with self.model_runner.maybe_setup_dummy_loras(self.lora_config): @@ -215,6 +221,7 @@ class TPUWorker: # TODO: use xm.get_memory_info for SPMD once it's supported in # PyTorch/XLA. import tpu_info + chip_type, _ = tpu_info.device.get_local_chips() device_usage = tpu_info.metrics.get_chip_usage(chip_type) total_memory_size = device_usage[0].total_memory @@ -231,20 +238,20 @@ class TPUWorker: profiled = current_mem * 1.02 # Calculate the TPU KV cache size based on profiling. 
- usable_memory_size = int(total_memory_size * - self.cache_config.gpu_memory_utilization) + usable_memory_size = int( + total_memory_size * self.cache_config.gpu_memory_utilization + ) tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) head_size = self.model_config.get_head_size() if head_size > 0: - padded_head_size = cdiv( - head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + padded_head_size = ( + cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + ) if padded_head_size != head_size: - logger.warning_once("head size is padded to %d", - padded_head_size) + logger.warning_once("head size is padded to %d", padded_head_size) # We adjust the usable memory size for the KV cache to prevent OOM # errors, even after padding the head_size. - tpu_kv_cache_bytes = (tpu_kv_cache_bytes * head_size // - padded_head_size) + tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size return int(tpu_kv_cache_bytes) def execute_model( @@ -253,8 +260,7 @@ class TPUWorker: ) -> Optional[ModelRunnerOutput]: output = self.model_runner.execute_model(scheduler_output) # every worker's output is needed when kv_transfer_group is set up - return output if self.is_driver_worker or has_kv_transfer_group( - ) else None + return output if self.is_driver_worker or has_kv_transfer_group() else None def profile(self, is_start: bool = True): if self.rank < 1: @@ -327,8 +333,8 @@ class TPUWorker: backend=current_platform.dist_backend, ) ensure_model_parallel_initialized( - parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size + ) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm/v1/worker/ubatch_splitting.py b/vllm/v1/worker/ubatch_splitting.py index 7767750aa6..6723239e84 100644 --- a/vllm/v1/worker/ubatch_splitting.py +++ b/vllm/v1/worker/ubatch_splitting.py @@ -10,8 +10,11 @@ from vllm.config import ParallelConfig, VllmConfig from vllm.forward_context import DPMetadata from vllm.logger import init_logger from vllm.utils import round_up -from vllm.v1.worker.ubatch_utils import (UBatchSlice, UBatchSlices, - is_second_ubatch_empty) +from vllm.v1.worker.ubatch_utils import ( + UBatchSlice, + UBatchSlices, + is_second_ubatch_empty, +) logger = init_logger(__name__) @@ -24,14 +27,18 @@ def should_ubatch_with_num_tokens( ) -> tuple[bool, Optional[torch.Tensor]]: dp_size = vllm_config.parallel_config.data_parallel_size dp_rank = vllm_config.parallel_config.data_parallel_rank - return DPMetadata.should_ubatch_across_dp(should_ubatch, - orig_num_tokens_per_ubatch, - padded_num_tokens_per_ubatch, - dp_size, dp_rank) + return DPMetadata.should_ubatch_across_dp( + should_ubatch, + orig_num_tokens_per_ubatch, + padded_num_tokens_per_ubatch, + dp_size, + dp_rank, + ) -def check_ubatch_thresholds(config: ParallelConfig, num_tokens: int, - uniform_decode: bool) -> bool: +def check_ubatch_thresholds( + config: ParallelConfig, num_tokens: int, uniform_decode: bool +) -> bool: if not config.enable_dbo: return False if uniform_decode: @@ -41,9 +48,11 @@ def check_ubatch_thresholds(config: ParallelConfig, num_tokens: int, def get_dp_padding_ubatch( - num_tokens_unpadded: int, num_tokens_padded: int, - should_attempt_ubatching: bool, - vllm_config: VllmConfig) -> tuple[bool, Optional[torch.Tensor]]: + num_tokens_unpadded: int, + num_tokens_padded: int, + should_attempt_ubatching: bool, + vllm_config: VllmConfig, +) -> tuple[bool, Optional[torch.Tensor]]: """ 1. 
Decides if each DP rank is going to microbatch. Either all ranks run with microbatching or none of them do. If this function decides @@ -71,7 +80,8 @@ def get_dp_padding_ubatch( # If this DP rank doesn't want to attempt microbatching if not should_attempt_ubatching: (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens( - False, 0, 0, vllm_config) + False, 0, 0, vllm_config + ) assert should_ubatch is False assert num_tokens_across_dp is None return should_ubatch, num_tokens_across_dp @@ -85,14 +95,16 @@ def get_dp_padding_ubatch( # ubatch. Abort if so if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded): logger.debug( - "Empty second µbatch detected: unpadded tokens: %s, padded " - "tokens: %s", num_tokens_unpadded, num_tokens_padded) + "Empty second µbatch detected: unpadded tokens: %s, padded tokens: %s", + num_tokens_unpadded, + num_tokens_padded, + ) should_ubatch = False # Note that we compute the number of padded tokens per ubatch (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens( - should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch, - vllm_config) + should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch, vllm_config + ) if not should_ubatch: assert num_tokens_across_dp is None return should_ubatch, num_tokens_across_dp @@ -100,14 +112,15 @@ def get_dp_padding_ubatch( assert num_tokens_across_dp is not None max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item()) - num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * - dp_size, - device="cpu", - dtype=torch.int32) + num_tokens_after_padding = torch.tensor( + [max_tokens_across_dp_cpu] * dp_size, device="cpu", dtype=torch.int32 + ) return should_ubatch, num_tokens_after_padding -def create_ubatch_slices(num_scheduled_tokens: np.ndarray, split_point: int) \ - -> UBatchSlices: + +def create_ubatch_slices( + num_scheduled_tokens: np.ndarray, split_point: int +) -> UBatchSlices: # TODO(lucas): Refactor the gpu_model_runner.py so we can pass # in cu_num_tokens directly (i.e. query_start_loc) cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32) @@ -119,19 +132,20 @@ def create_ubatch_slices(num_scheduled_tokens: np.ndarray, split_point: int) \ # Determine request slices using exclusive stop semantics # First ubatch includes requests whose tokens overlap [0, split_point) first_ubatch_req_stop = int( - np.searchsorted(cu_num_tokens, split_point, side="left")) + np.searchsorted(cu_num_tokens, split_point, side="left") + ) first_ubatch_req_slice = slice(0, first_ubatch_req_stop) # Second ubatch starts at the request that contains the split_point # or the request starting exactly at split_point (if on boundary) second_ubatch_req_start = int( - np.searchsorted(cu_num_tokens, split_point, side="right") - 1) - second_ubatch_req_slice = slice(second_ubatch_req_start, - len(cu_num_tokens) - 1) + np.searchsorted(cu_num_tokens, split_point, side="right") - 1 + ) + second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1) return [ UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice), - UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice) + UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice), ] @@ -147,7 +161,7 @@ def ubatch_split( should be split into microbatches. 
Returns: tuple[ - ubatch_slices: if this is set then all DP ranks have agreed to + ubatch_slices: if this is set then all DP ranks have agreed to microbatch num_tokens_after_padding: A tensor containing the total number of tokens per-microbatch for each DP rank including padding. Will be @@ -186,7 +200,8 @@ def ubatch_split( assert num_tokens_after_padding is not None token_split_point = int(num_tokens_after_padding[0].item()) - ubatch_slices = create_ubatch_slices(num_scheduled_tokens_per_request, - token_split_point) + ubatch_slices = create_ubatch_slices( + num_scheduled_tokens_per_request, token_split_point + ) return (ubatch_slices, num_tokens_after_padding) diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index 33d58aa948..2deba16f8a 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -11,8 +11,10 @@ class UBatchSlice: token_slice: slice def is_empty(self) -> bool: - return self.request_slice.start == self.request_slice.stop \ + return ( + self.request_slice.start == self.request_slice.stop or self.token_slice.start == self.token_slice.stop + ) @property def num_tokens(self) -> int: @@ -22,6 +24,7 @@ class UBatchSlice: UBatchSlices: TypeAlias = list[UBatchSlice] -def is_second_ubatch_empty(orig_num_tokens_per_ubatch: int, - padded_num_tokens_per_ubatch: int) -> bool: +def is_second_ubatch_empty( + orig_num_tokens_per_ubatch: int, padded_num_tokens_per_ubatch: int +) -> bool: return padded_num_tokens_per_ubatch >= 2 * orig_num_tokens_per_ubatch diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index c26cb07123..867ce2b930 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -10,7 +10,7 @@ from vllm.forward_context import ForwardContext from vllm.utils import current_stream _THREAD_ID_TO_CONTEXT: dict = {} -_CURRENT_CONTEXTS: list[Optional['UBatchContext']] = [None, None] +_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None] class UBatchContext: @@ -18,17 +18,19 @@ class UBatchContext: Context manager for micro-batching synchronization using threading events. 
""" - def __init__(self, - id: int, - comm_stream: torch.cuda.Stream, - compute_stream: torch.cuda.Stream, - forward_context: ForwardContext, - ready_barrier: threading.Barrier, - cpu_wait_event: threading.Event, - cpu_signal_event: threading.Event, - gpu_comm_done_event: torch.cuda.Event, - gpu_compute_done_event: torch.cuda.Event, - schedule: str = "default"): + def __init__( + self, + id: int, + comm_stream: torch.cuda.Stream, + compute_stream: torch.cuda.Stream, + forward_context: ForwardContext, + ready_barrier: threading.Barrier, + cpu_wait_event: threading.Event, + cpu_signal_event: threading.Event, + gpu_comm_done_event: torch.cuda.Event, + gpu_compute_done_event: torch.cuda.Event, + schedule: str = "default", + ): self.id = id self.comm_stream = comm_stream self.compute_stream = compute_stream @@ -151,7 +153,6 @@ def dbo_current_ubatch_id() -> int: def _register_ubatch_function(func): - def wrapper(*args, **kwargs): if len(_THREAD_ID_TO_CONTEXT) > 0: ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] @@ -161,20 +162,20 @@ def _register_ubatch_function(func): return wrapper -dbo_maybe_run_recv_hook = _register_ubatch_function( - UBatchContext.maybe_run_recv_hook) +dbo_maybe_run_recv_hook = _register_ubatch_function(UBatchContext.maybe_run_recv_hook) dbo_yield = _register_ubatch_function(UBatchContext.yield_) dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function( - UBatchContext.yield_and_switch_from_compute_to_comm) + UBatchContext.yield_and_switch_from_compute_to_comm +) dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function( - UBatchContext.yield_and_switch_from_comm_to_compute) + UBatchContext.yield_and_switch_from_comm_to_compute +) dbo_switch_to_comm = _register_ubatch_function(UBatchContext.switch_to_comm) -dbo_switch_to_compute = _register_ubatch_function( - UBatchContext.switch_to_compute) -dbo_switch_to_comm_sync = _register_ubatch_function( - UBatchContext.switch_to_comm_sync) +dbo_switch_to_compute = _register_ubatch_function(UBatchContext.switch_to_compute) +dbo_switch_to_comm_sync = _register_ubatch_function(UBatchContext.switch_to_comm_sync) dbo_switch_to_compute_sync = _register_ubatch_function( - UBatchContext.switch_to_compute_sync) + UBatchContext.switch_to_compute_sync +) def dbo_register_recv_hook(recv_hook): @@ -197,28 +198,25 @@ def make_ubatch_contexts( Create a context manager for micro-batching synchronization. 
""" cpu_events = [threading.Event() for _ in range(num_micro_batches)] - gpu_comm_done_events = [ - torch.cuda.Event() for _ in range(num_micro_batches) - ] - gpu_compute_done_events = [ - torch.cuda.Event() for _ in range(num_micro_batches) - ] + gpu_comm_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)] + gpu_compute_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)] assert len(forward_contexts) == 2 ctxs = [] for i in range(num_micro_batches): - ctx = UBatchContext(id=i, - compute_stream=compute_stream, - comm_stream=comm_stream, - forward_context=forward_contexts[i], - ready_barrier=ready_barrier, - cpu_wait_event=cpu_events[i], - cpu_signal_event=cpu_events[(i + 1) % - num_micro_batches], - gpu_comm_done_event=gpu_comm_done_events[i], - gpu_compute_done_event=gpu_compute_done_events[i], - schedule=schedule) + ctx = UBatchContext( + id=i, + compute_stream=compute_stream, + comm_stream=comm_stream, + forward_context=forward_contexts[i], + ready_barrier=ready_barrier, + cpu_wait_event=cpu_events[i], + cpu_signal_event=cpu_events[(i + 1) % num_micro_batches], + gpu_comm_done_event=gpu_comm_done_events[i], + gpu_compute_done_event=gpu_compute_done_events[i], + schedule=schedule, + ) ctxs.append(ctx) return ctxs diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 3e0dbda594..c3d16827f1 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -35,18 +35,18 @@ class MultiModalBudget: self.model_config = model_config self.scheduler_config = scheduler_config self.mm_registry = mm_registry - self.cache = cache = processor_only_cache_from_config( - model_config, mm_registry) + self.cache = cache = processor_only_cache_from_config(model_config, mm_registry) self.max_model_len = model_config.max_model_len self.max_num_reqs = scheduler_config.max_num_seqs - self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, - cache=cache) + self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache) - max_tokens_by_modality = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config, - cache=cache) + max_tokens_by_modality = ( + mm_registry.get_max_tokens_per_item_by_nonzero_modality( + model_config, cache=cache + ) + ) encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( scheduler_config, @@ -145,17 +145,14 @@ class AttentionGroup: vllm_config: VllmConfig, device: torch.device, num_metadata_builders: int = 1, - ) -> 'AttentionGroup': + ) -> "AttentionGroup": metadata_builders = [ - backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, - device) + backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device) for _ in range(num_metadata_builders) ] - return AttentionGroup(backend, metadata_builders, layer_names, - kv_cache_spec) + return AttentionGroup(backend, metadata_builders, layer_names, kv_cache_spec) - def get_metadata_builder(self, - ubatch_id: int = 0) -> AttentionMetadataBuilder: + def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder: assert len(self.metadata_builders) > ubatch_id return self.metadata_builders[ubatch_id] @@ -172,19 +169,22 @@ def sanity_check_mm_encoder_outputs( "Expected multimodal embeddings to be a list/tuple of 2D tensors, " f"or a single 3D tensor, but got {type(mm_embeddings)} " "instead. This is most likely due to incorrect implementation " - "of the model's `get_multimodal_embeddings` method.") + "of the model's `get_multimodal_embeddings` method." 
+ ) assert len(mm_embeddings) == expected_num_items, ( "Expected number of multimodal embeddings to match number of " f"input items: {expected_num_items}, but got {len(mm_embeddings)=} " "instead. This is most likely due to incorrect implementation " - "of the model's `get_multimodal_embeddings` method.") + "of the model's `get_multimodal_embeddings` method." + ) assert all(e.ndim == 2 for e in mm_embeddings), ( "Expected multimodal embeddings to be a sequence of 2D tensors, " f"but got tensors with shapes {[e.shape for e in mm_embeddings]} " "instead. This is most likely due to incorrect implementation " - "of the model's `get_multimodal_embeddings` method.") + "of the model's `get_multimodal_embeddings` method." + ) def scatter_mm_placeholders( @@ -290,8 +290,7 @@ def bind_kv_cache( # Convert kv_caches dict to a list of tensors in the order of layer_index. index2name = defaultdict(list) for layer_name in kv_caches: - index2name[extract_layer_index(layer_name, - num_attn_module)].append(layer_name) + index2name[extract_layer_index(layer_name, num_attn_module)].append(layer_name) for layer_index in sorted(index2name.keys()): layer_names = index2name[layer_index] @@ -319,16 +318,16 @@ def bind_kv_cache( forward_context[layer_name].kv_cache = [kv_cache] -def is_residual_scattered_for_sp(vllm_config: VllmConfig, - num_input_tokens: int) -> bool: +def is_residual_scattered_for_sp( + vllm_config: VllmConfig, num_input_tokens: int +) -> bool: """Check if the residual tensor is scattered for sequence parallelism. The residual tensor is scattered across tensor parallel ranks when sequence parallelism and tensor parallelism is enabled, and the number of input tokens is one of the compilation sizes. """ - if not vllm_config.compilation_config.pass_config.\ - enable_sequence_parallelism: + if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism: return False tp = vllm_config.parallel_config.tensor_parallel_size @@ -341,4 +340,4 @@ def is_residual_scattered_for_sp(vllm_config: VllmConfig, assert num_input_tokens % tp == 0 # Currently, SP is only enabled for static size fx graphs. - return (num_input_tokens in vllm_config.compilation_config.compile_sizes) + return num_input_tokens in vllm_config.compilation_config.compile_sizes diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 5b393ee6bf..5f5c6bcea0 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -13,10 +13,13 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest -from vllm.utils import (enable_trace_function_call_for_thread, - resolve_obj_by_qualname, run_method, - update_environment_variables, - warn_for_unimplemented_methods) +from vllm.utils import ( + enable_trace_function_call_for_thread, + resolve_obj_by_qualname, + run_method, + update_environment_variables, + warn_for_unimplemented_methods, +) from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.v1.outputs import SamplerOutput @@ -65,6 +68,7 @@ class WorkerBase: self.compilation_config = vllm_config.compilation_config from vllm.platforms import current_platform + self.current_platform = current_platform self.parallel_config.rank = rank @@ -95,10 +99,8 @@ class WorkerBase: """ raise NotImplementedError - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache with the given size in blocks. 
- """ + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: + """Initialize the KV cache with the given size in blocks.""" raise NotImplementedError def get_model(self) -> nn.Module: @@ -113,8 +115,7 @@ class WorkerBase: raise NotImplementedError def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None + self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> Optional[list[SamplerOutput]]: raise NotImplementedError @@ -209,6 +210,7 @@ class WorkerWrapperBase: if trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() def shutdown(self) -> None: @@ -229,7 +231,7 @@ class WorkerWrapperBase: envs_list: list[dict[str, str]], ) -> None: envs = envs_list[self.rpc_rank] - key = 'CUDA_VISIBLE_DEVICES' + key = "CUDA_VISIBLE_DEVICES" if key in envs and key in os.environ: # overwriting CUDA_VISIBLE_DEVICES is desired behavior # suppress the warning in `update_environment_variables` @@ -244,22 +246,26 @@ class WorkerWrapperBase: kwargs = all_kwargs[self.rpc_rank] self.vllm_config = kwargs.get("vllm_config") assert self.vllm_config is not None, ( - "vllm_config is required to initialize the worker") + "vllm_config is required to initialize the worker" + ) enable_trace_function_call_for_thread(self.vllm_config) from vllm.plugins import load_general_plugins + load_general_plugins() if isinstance(self.vllm_config.parallel_config.worker_cls, str): worker_class = resolve_obj_by_qualname( - self.vllm_config.parallel_config.worker_cls) + self.vllm_config.parallel_config.worker_cls + ) else: raise ValueError( "passing worker_cls is no longer supported. Please pass keep the class in a separate module and pass the qualified name of the class as a string." # noqa: E501 ) if self.vllm_config.parallel_config.worker_extension_cls: worker_extension_cls = resolve_obj_by_qualname( - self.vllm_config.parallel_config.worker_extension_cls) + self.vllm_config.parallel_config.worker_extension_cls + ) extended_calls = [] if worker_extension_cls not in worker_class.__bases__: # check any conflicts between worker and worker_extension_cls @@ -269,15 +275,20 @@ class WorkerWrapperBase: assert not hasattr(worker_class, attr), ( f"Worker class {worker_class} already has an attribute" f" {attr}, which conflicts with the worker" - f" extension class {worker_extension_cls}.") + f" extension class {worker_extension_cls}." + ) if callable(getattr(worker_extension_cls, attr)): extended_calls.append(attr) # dynamically inherit the worker extension class worker_class.__bases__ = worker_class.__bases__ + ( - worker_extension_cls, ) + worker_extension_cls, + ) logger.info( "Injected %s into %s for extended collective_rpc calls %s", - worker_extension_cls, worker_class, extended_calls) + worker_extension_cls, + worker_class, + extended_calls, + ) with set_current_vllm_config(self.vllm_config): # To make vLLM config available during worker initialization self.worker = worker_class(**kwargs) @@ -305,8 +316,10 @@ class WorkerWrapperBase: # exceptions in the rest worker may cause deadlock in rpc like ray # see https://github.com/vllm-project/vllm/issues/3455 # print the error and inform the user to solve the error - msg = (f"Error executing method {method!r}. " - "This might cause deadlock in distributed execution.") + msg = ( + f"Error executing method {method!r}. " + "This might cause deadlock in distributed execution." 
+ ) logger.exception(msg) raise e diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py index 7becdd3924..4f82c18da7 100644 --- a/vllm/v1/worker/xpu_model_runner.py +++ b/vllm/v1/worker/xpu_model_runner.py @@ -37,9 +37,7 @@ class XPUModelRunner(GPUModelRunner): @contextmanager def _torch_cuda_wrapper(): - class _EventPlaceholder: - def __init__(self, *args, **kwargs) -> None: self.record = lambda: None self.synchronize = lambda: None diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py index 7355206f30..a1e54628d9 100644 --- a/vllm/v1/worker/xpu_worker.py +++ b/vllm/v1/worker/xpu_worker.py @@ -11,8 +11,7 @@ from vllm.distributed import get_world_group from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.platforms import current_platform -from vllm.v1.worker.gpu_worker import (Worker, - init_worker_distributed_environment) +from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment from vllm.v1.worker.xpu_model_runner import XPUModelRunner logger = init_logger(__name__) @@ -29,8 +28,9 @@ class XPUWorker(Worker): distributed_init_method: str, is_driver_worker: bool = False, ): - super().__init__(vllm_config, local_rank, rank, - distributed_init_method, is_driver_worker) + super().__init__( + vllm_config, local_rank, rank, distributed_init_method, is_driver_worker + ) device_config = self.device_config assert device_config.device_type == "xpu" assert current_platform.is_xpu() @@ -39,8 +39,10 @@ class XPUWorker(Worker): # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace if envs.VLLM_TORCH_PROFILER_DIR: torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir) + logger.info( + "Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir, + ) logger.debug( "Profiler config: record_shapes=%s," "profile_memory=%s,with_stack=%s,with_flops=%s", @@ -59,7 +61,9 @@ class XPUWorker(Worker): with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) + torch_profiler_trace_dir, use_gzip=True + ), + ) else: self.profiler = None @@ -75,8 +79,7 @@ class XPUWorker(Worker): # and we don't have any API to get it. so we mark it as 128MB. used_memory = torch.xpu.memory_allocated() non_torch_allocations = 128 * 1024 * 1024 - free_gpu_memory = total_gpu_memory - (used_memory + - non_torch_allocations) + free_gpu_memory = total_gpu_memory - (used_memory + non_torch_allocations) return free_gpu_memory, total_gpu_memory @torch.inference_mode() @@ -97,10 +100,12 @@ class XPUWorker(Worker): free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info() current_allocated_bytes = torch.xpu.memory_allocated() - msg = ("Before memory profiling run, " - f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, " - f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, " - f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.") + msg = ( + "Before memory profiling run, " + f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, " + f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, " + f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB." + ) logger.info(msg) # Execute a forward pass with dummy inputs to profile the memory usage # of the model. @@ -113,67 +118,73 @@ class XPUWorker(Worker): "Error in memory profiling. 
" f"Initial free memory {self.init_gpu_memory}, current free memory" f" {free_gpu_memory}. This happens when the GPU memory was " - "not properly cleaned up before initializing the vLLM instance.") + "not properly cleaned up before initializing the vLLM instance." + ) # Get the peak memory allocation recorded by torch peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"] torch.xpu.empty_cache() - torch_allocated_bytes = torch.xpu.memory_stats( - )["allocated_bytes.all.current"] - total_allocated_bytes = self.xpu_get_mem_info( - )[1] - self.xpu_get_mem_info()[0] + torch_allocated_bytes = torch.xpu.memory_stats()["allocated_bytes.all.current"] + total_allocated_bytes = self.xpu_get_mem_info()[1] - self.xpu_get_mem_info()[0] non_torch_allocations = total_allocated_bytes - torch_allocated_bytes if non_torch_allocations > 0: peak_memory += non_torch_allocations available_kv_cache_memory = ( - total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) + total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory + ) - msg = ("After memory profiling run, " - f"peak memory usage is {peak_memory / 1024**2:.2f} MB," - f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, " - f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, " - f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.") + msg = ( + "After memory profiling run, " + f"peak memory usage is {peak_memory / 1024**2:.2f} MB," + f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, " + f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, " + f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB." + ) logger.info(msg) return int(available_kv_cache_memory) def init_device(self): - if self.device_config.device.type == "xpu" and current_platform.is_xpu( - ): + if self.device_config.device.type == "xpu" and current_platform.is_xpu(): self.device = torch.device(f"xpu:{self.local_rank}") current_platform.set_device(self.device) current_platform.check_if_supports_dtype(self.model_config.dtype) torch.xpu.empty_cache() self.init_gpu_memory = torch.xpu.get_device_properties( - self.local_rank).total_memory + self.local_rank + ).total_memory else: - raise RuntimeError( - f"Not support device type: {self.device_config.device}") + raise RuntimeError(f"Not support device type: {self.device_config.device}") ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "pidfd") ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi") - ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE", - str(self.parallel_config.world_size)) + ENV_LOCAL_WORLD_SIZE = os.getenv( + "LOCAL_WORLD_SIZE", str(self.parallel_config.world_size) + ) os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE os.environ["LOCAL_RANK"] = str(self.local_rank) - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank, - current_platform.dist_backend) + init_worker_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + current_platform.dist_backend, + ) # global all_reduce needed for overall oneccl warm up - torch.distributed.all_reduce(torch.zeros(1).xpu(), - group=get_world_group().device_group) + torch.distributed.all_reduce( + torch.zeros(1).xpu(), group=get_world_group().device_group + ) # Set random seed. 
set_random_seed(self.model_config.seed) # Construct the model runner self.model_runner = XPUModelRunner( # type: ignore - self.vllm_config, self.device) + self.vllm_config, self.device + ) diff --git a/vllm/version.py b/vllm/version.py index 6c88b1b5a3..63095f8bce 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -6,9 +6,7 @@ try: except Exception as e: import warnings - warnings.warn(f"Failed to read commit hash:\n{e}", - RuntimeWarning, - stacklevel=2) + warnings.warn(f"Failed to read commit hash:\n{e}", RuntimeWarning, stacklevel=2) __version__ = "dev" __version_tuple__ = (0, 0, __version__)
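For context on the reformatted `_get_padded_token_len` in vllm/v1/worker/tpu_model_runner.py above: it returns the first precomputed padding bucket greater than or equal to the requested length via `bisect.bisect_left`. A minimal standalone sketch of that lookup follows; the paddings list here is made up purely for illustration (the real list comes from `_get_token_paddings`), and this sketch is not part of the patch.

import bisect

def get_padded_token_len(paddings: list[int], x: int) -> int:
    # First element in the sorted paddings list that is >= x.
    index = bisect.bisect_left(paddings, x)
    assert index < len(paddings), "x exceeds the largest padding bucket"
    return paddings[index]

# Hypothetical bucket list for illustration only.
paddings = [16, 32, 64, 128, 256]
print(get_padded_token_len(paddings, 1))    # -> 16
print(get_padded_token_len(paddings, 33))   # -> 64
print(get_padded_token_len(paddings, 128))  # -> 128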