Signed-off-by: fhl <2410591650@qq.com> Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
274 lines
11 KiB
Python
274 lines
11 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import time
|
|
from collections import defaultdict
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
import vllm.envs as envs
|
|
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
|
from vllm.logger import init_logger
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.attention.backends.abstract import AttentionMetadata
|
|
|
|
logger = init_logger(__name__)

# Batch-size forward-time statistics are only collected when the
# VLLM_LOG_BATCHSIZE_INTERVAL setting is non-negative.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# time.perf_counter() value at which the stats were last logged.
last_logging_time: float = 0
# time.perf_counter() value taken when the current forward pass started.
forward_start_time: float = 0
# Minimum time (seconds, compared against time.perf_counter() deltas)
# between two consecutive stats log lines.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# Maps batch size -> list of recorded forward-pass durations in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
|
|
|
|
|
|
class BatchDescriptor(NamedTuple):
    """
    Key describing a padded batch, used for cudagraph dispatching.

    Keep the set of fields as small as possible: together they must
    properly and uniquely identify the padded batch for cudagraph.
    """
    # Number of tokens in the (padded) batch.
    num_tokens: int
    # Whether this is a uniform decode batch. False may also be used for a
    # uniform decode batch in order to dispatch it to the cudagraph that
    # supports non-uniform batches.
    uniform_decode: bool = False

    @property
    def non_uniform(self) -> "BatchDescriptor":
        """Return a copy of this descriptor with ``uniform_decode=False``."""
        # uniform_decode defaults to False, so only num_tokens is forwarded.
        return BatchDescriptor(num_tokens=self.num_tokens)
|
|
|
|
|
|
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
|
|
max_num_tokens: int,
|
|
chunk_idx: int) -> list[int]:
|
|
dp_size = len(num_tokens_across_dp_cpu)
|
|
|
|
local_size = [-1] * dp_size
|
|
for i in range(dp_size):
|
|
dp_tokens = num_tokens_across_dp_cpu[i]
|
|
local_size[i] = min(max_num_tokens,
|
|
dp_tokens - (max_num_tokens * chunk_idx))
|
|
if local_size[i] <= 0:
|
|
local_size[i] = 1 # ensure lockstep even if done
|
|
return local_size
|
|
|
|
|
|
@dataclass
class DPMetadata:
    # Largest per-rank token count across the DP group (from torch.max).
    max_tokens_across_dp_cpu: torch.Tensor
    # Inclusive prefix sum of the per-rank token counts (from torch.cumsum).
    cu_tokens_across_dp_cpu: torch.Tensor
    # Per-rank token counts of the current chunk; only populated while
    # inside `chunked_sizes`, otherwise None.
    local_sizes: Optional[list[int]] = None

    @staticmethod
    def num_tokens_across_dp(num_tokens: int, dp_size: int,
                             dp_rank: int) -> torch.Tensor:
        """
        Collect every DP rank's token count into a CPU tensor of length
        `dp_size`.

        This rank contributes an int32 vector that is zero everywhere except
        at `dp_rank`, where it holds `num_tokens`. An all-reduce over the DP
        group's CPU process group then fills in the other ranks' counts.
        """
        contribution = [0] * dp_size
        contribution[dp_rank] = num_tokens
        gathered = torch.tensor(contribution,
                                device="cpu",
                                dtype=torch.int32)
        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(gathered, group=get_dp_group().cpu_group)
        return gathered

    @staticmethod
    def make(
        parallel_config: ParallelConfig,
        attn_metadata: Any,
        num_tokens: int,
        num_tokens_across_dp: Optional[torch.Tensor] = None
    ) -> "DPMetadata":
        """
        Build DP metadata for the current forward pass.

        The local batch size comes from `attn_metadata` when it carries
        `num_prefill_tokens` (v0 attention backends). Otherwise it is
        `num_tokens` (v1 backends, or no attn_metadata). If
        `num_tokens_across_dp` is not provided, it is computed with an
        all-reduce across the DP group. When it is provided, its entry for
        this rank must equal the local batch size.
        """
        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank

        has_v0_counts = (attn_metadata is not None
                         and hasattr(attn_metadata, "num_prefill_tokens"))
        if has_v0_counts:
            # for v0 attention backends
            batchsize = (attn_metadata.num_prefill_tokens +
                         attn_metadata.num_decode_tokens)
        else:
            # for v1 attention backends or no attn_metadata
            batchsize = num_tokens

        if num_tokens_across_dp is not None:
            # A caller-supplied tensor must agree with this rank's batch.
            assert num_tokens_across_dp[dp_rank] == batchsize
        else:
            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)

        return DPMetadata(torch.max(num_tokens_across_dp),
                          torch.cumsum(num_tokens_across_dp, dim=0))

    @contextmanager
    def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
        """
        Temporarily set `self.local_sizes` to the per-rank token counts of
        one chunk during chunked forward execution.

        Each rank's tokens are split into chunks of at most
        `max_chunk_size_per_rank`. For `chunk_idx`, every rank gets the size
        of its corresponding chunk. Ranks that have already run out of
        tokens get 1, so all DP ranks stay in lockstep even with uneven
        token counts.

        Per-rank totals are recovered from the prefix sums in
        `cu_tokens_across_dp_cpu` and split with
        `_compute_chunked_local_num_tokens`. `self.local_sizes` is only valid
        inside the context and is reset to None on exit.

        Args:
            max_chunk_size_per_rank: The max number of tokens each rank is
                allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.
        """
        cumulative = self.cu_tokens_across_dp_cpu
        # Undo the prefix sum: rank i holds cumulative[i] - cumulative[i-1],
        # and rank 0 holds cumulative[0].
        per_rank_tokens = torch.diff(
            cumulative, prepend=cumulative.new_zeros(1)).tolist()
        self.local_sizes = _compute_chunked_local_num_tokens(
            per_rank_tokens, max_chunk_size_per_rank, chunk_idx)
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
        """Return the current chunk's per-rank sizes, or None outside
        `chunked_sizes`."""
        return self.local_sizes
|
|
|
|
|
|
@dataclass
class ForwardContext:
    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # Attention metadata, set dynamically for each forward pass:
    #   - v0: a single AttentionMetadata;
    #   - v1: a dict mapping the layer_name of each attention layer to its
    #     attention metadata.
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: Optional[DPMetadata] = None
    # Cudagraph style chosen at runtime: FULL, PIECEWISE, or NONE.
    # Defaults to NONE, i.e. no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: Optional[BatchDescriptor] = None

    def __post_init__(self):
        # Only the three concrete runtime modes are valid here.
        assert self.cudagraph_runtime_mode in (
            CUDAGraphMode.NONE,
            CUDAGraphMode.PIECEWISE,
            CUDAGraphMode.FULL,
        ), f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
|
|
|
|
|
|
# Forward context of the forward pass currently in progress. It is None when
# no context is set. `set_forward_context` installs it and restores the
# previous value on exit. `get_forward_context` reads it.
_forward_context: Optional[ForwardContext] = None
|
|
|
|
|
|
def get_forward_context() -> ForwardContext:
    """Return the forward context installed by `set_forward_context`.

    Raises:
        AssertionError: when no forward context is currently set (the check
            is an ``assert``, so it is skipped under ``python -O``).
    """
    current = _forward_context
    assert current is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return current
|
|
|
|
|
|
def _track_batchsize_forward_time(attn_metadata: Any,
                                  num_tokens: Optional[int]) -> None:
    """Record the duration of the forward pass that just finished.

    The duration runs from the module-level `forward_start_time` to now, in
    milliseconds, and is bucketed by batch size. The current platform's
    `synchronize` hook is called first when it has one. At most once every
    `batchsize_logging_interval` seconds, the median forward time of every
    batch size seen more than once is logged.

    Args:
        attn_metadata: Attention metadata of the finished pass. If it has
            `num_prefill_tokens` (v0 backends), the batch size is
            prefill + decode tokens.
        num_tokens: Batch size used for v1 backends.
    """
    global last_logging_time
    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        batchsize = attn_metadata.num_prefill_tokens + \
            attn_metadata.num_decode_tokens
    else:
        # for v1 attention backends
        batchsize = num_tokens
    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    from vllm.platforms import current_platform
    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()
    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append(
        (now - forward_start_time) * 1000)
    if now - last_logging_time > batchsize_logging_interval:
        last_logging_time = now
        forward_stats = []
        for bs, times in batchsize_forward_time.items():
            if len(times) <= 1:
                # can be cudagraph / profiling run
                continue
            median = torch.quantile(torch.tensor(times), q=0.5).item()
            forward_stats.append((bs, len(times), round(median, 2)))
        # most frequently observed batch sizes first
        forward_stats.sort(key=lambda x: x[1], reverse=True)
        if forward_stats:
            logger.info(("Batchsize forward time stats "
                         "(batchsize, count, median_time(ms)): %s"),
                        forward_stats)


@contextmanager
def set_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        num_tokens: Optional[int] = None,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None):
    """A context manager that stores the current forward context,
    which can hold attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    On exit the previously active context is always restored, even if the
    optional batch-size timing/logging raises.

    Args:
        attn_metadata: Attention metadata for this forward pass (v0: a
            single object; v1: dict from attention layer name to metadata).
        vllm_config: Source of the parallel config and the compilation
            config's `static_forward_context`.
        virtual_engine: Virtual engine index stored on the context.
        num_tokens: Number of tokens in this pass. It is used for DP
            metadata and batch-size stats when `attn_metadata` does not
            carry token counts.
        num_tokens_across_dp: Optional precomputed per-DP-rank token counts.
            When None and data parallelism is enabled, they are gathered
            via an all-reduce.
        cudagraph_runtime_mode: Cudagraph mode for this pass.
        batch_descriptor: Descriptor used for cudagraph dispatching.
    """
    global forward_start_time, _forward_context
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()
    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
            attn_metadata is not None or num_tokens is not None):
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens or 0,
                                      num_tokens_across_dp)

    prev_context = _forward_context
    _forward_context = ForwardContext(
        no_compile_layers=vllm_config.compilation_config.
        static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
    )

    try:
        yield
    finally:
        try:
            if need_to_track_batchsize:
                _track_batchsize_forward_time(attn_metadata, num_tokens)
        finally:
            # Restore unconditionally: if synchronize()/quantile/logging
            # above raised, skipping this would leave a stale context
            # visible to every later get_forward_context() call.
            _forward_context = prev_context
|