break execute_model in gpu_model_runner into sub-functions for custom scopes (#24265)

Co-authored-by: Bangsheng Tang <bangsheng@meta.com>
Bangsheng Tang
2025-09-06 14:02:47 -07:00
committed by GitHub
parent e68dc2f014
commit 848562bd49
3 changed files with 208 additions and 109 deletions

vllm/envs.py

@@ -168,6 +168,7 @@ if TYPE_CHECKING:
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
def get_default_cache_root():
@@ -1200,6 +1201,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TUNED_CONFIG_FOLDER":
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
# Add optional custom scopes for profiling; disabled by default to avoid overhead
"VLLM_CUSTOM_SCOPES_FOR_PROFILING":
lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
}
# --8<-- [end:env-vars-definition]
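
Note on usage: the flag is parsed as bool(int(value)), so any nonzero integer string enables the scopes and "0" (or leaving it unset) keeps them off. Below is a minimal sketch of flipping it on from Python before vLLM reads the flag; the attribute-style lookup through vllm.envs is an assumption about the existing envs module, not something this diff changes.

# Sketch: enable the optional profiling scopes via the environment.
# Assumes a working vLLM install; only the new flag comes from this commit.
import os

os.environ["VLLM_CUSTOM_SCOPES_FOR_PROFILING"] = "1"

import vllm.envs as envs

# Parsed as bool(int(...)): "1" -> True, "0" or unset -> False.
assert envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING is True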

vllm/v1/utils.py

@@ -1,17 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import contextlib
import multiprocessing
import time
import weakref
from collections.abc import Sequence
from contextlib import AbstractContextManager
from multiprocessing import connection
from multiprocessing.process import BaseProcess
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Union, overload)
import torch
from torch.autograd.profiler import record_function
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
@@ -368,3 +372,10 @@ def report_usage_stats(
"disable_custom_all_reduce":
vllm_config.parallel_config.disable_custom_all_reduce,
})
def record_function_or_nullcontext(name: str) -> AbstractContextManager:
if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING:
return record_function(name)
else:
return contextlib.nullcontext()
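
For reference, here is a standalone sketch of the same flag-gated pattern that needs only torch; the helper above is the one this diff adds, while the re-implementation below reads the environment directly and is for illustration only. With the flag off, the wrapped region costs nothing beyond a nullcontext; with it on, the region is recorded as a named range in torch profiler output.

# Standalone sketch of the flag-gated scope pattern (torch only; not vLLM code).
import contextlib
import os
from contextlib import AbstractContextManager

import torch
from torch.autograd.profiler import record_function


def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    # Mirrors the helper above, but reads the environment directly.
    if bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))):
        return record_function(name)
    return contextlib.nullcontext()


with torch.profiler.profile() as prof:
    with record_function_or_nullcontext("Forward"):
        torch.randn(512, 512) @ torch.randn(512, 512)

# With VLLM_CUSTOM_SCOPES_FOR_PROFILING=1 the "Forward" range appears here;
# otherwise only the underlying aten ops are listed.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))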

vllm/v1/worker/gpu_model_runner.py

@@ -69,7 +69,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
KVCacheGroupSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
DraftTokenIds, LogprobsLists, LogprobsTensors,
ModelRunnerOutput, SamplerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
@@ -79,7 +80,7 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin, KVConnectorOutput)
@@ -1587,31 +1588,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_connector_output=kv_connector_output,
)
@torch.inference_mode()
def execute_model(
def _preprocess(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output,
self.vllm_config)
if self.cache_config.kv_sharing_fast_prefill:
assert not self.input_batch.num_prompt_logprobs, (
"--kv-sharing-fast-prefill produces incorrect logprobs for "
"prompt tokens, please disable it when the requests "
"need prompt logprobs")
# Prepare the decoder inputs.
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
max_query_len) = self._prepare_inputs(scheduler_output)
) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], torch.Tensor,
Optional[IntermediateTensors], dict[str, Any]]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
@@ -1683,75 +1666,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
num_scheduled_tokens == self.input_batch.num_reqs * max_query_len)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=uniform_decode)
cudagraph_runtime_mode, batch_descriptor = \
self.cudagraph_dispatcher.dispatch(batch_descriptor)
# Run the model.
# Use persistent buffers for CUDA graphs.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
), self.maybe_get_kv_connector_output(
scheduler_output) as kv_connector_output:
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping micro-batches
# https://github.com/vllm-project/vllm/issues/18019
broadcast_pp_output = \
self.parallel_config.distributed_executor_backend \
== "external_launcher" and len(get_pp_group().ranks) > 0
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
assert isinstance(hidden_states, IntermediateTensors)
if not broadcast_pp_output:
hidden_states.kv_connector_output = kv_connector_output
return hidden_states
get_pp_group().send_tensor_dict(hidden_states.tensors,
all_gather_group=get_tp_group())
logits = None
else:
if self.is_pooling_model:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np, kv_connector_output)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:
model_output_broadcast_data = {
"logits": logits.contiguous(),
} if logits is not None else {}
model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
assert model_output_broadcast_data is not None
logits = model_output_broadcast_data["logits"]
# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
self.apply_grammar_bitmask(scheduler_output, logits)
return (
num_scheduled_tokens,
num_input_tokens,
num_tokens_across_dp,
input_ids,
inputs_embeds,
positions,
intermediate_tensors,
model_kwargs,
)
def _sample(
self, logits: Optional[torch.Tensor],
spec_decode_metadata: Optional[SpecDecodeMetadata]
) -> SamplerOutput:
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
@@ -1785,6 +1714,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
sampler_output.sampled_token_ids = output_token_ids
return sampler_output
def _bookkeeping_sync(
self, scheduler_output: "SchedulerOutput",
sampler_output: SamplerOutput, logits: Optional[torch.Tensor],
hidden_states: torch.Tensor, num_scheduled_tokens: int
) -> tuple[
dict[str, int],
Optional[LogprobsLists],
list[list[int]],
dict[str, Optional[LogprobsTensors]],
list[str],
dict[str, int],
list[int],
]:
num_nans_in_logits = {}
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
num_nans_in_logits = self._get_nans_in_logits(logits)
@@ -1827,6 +1771,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids
invalid_req_indices = []
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
@@ -1892,20 +1837,159 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)
if self.speculative_config:
assert spec_decode_common_attn_metadata is not None
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
spec_decode_common_attn_metadata,
return (
num_nans_in_logits,
logprobs_lists,
valid_sampled_token_ids,
prompt_logprobs_dict,
req_ids_output_copy,
req_id_to_index_output_copy,
invalid_req_indices,
)
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
with record_function_or_nullcontext("Preprocess"):
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output,
self.vllm_config)
if self.cache_config.kv_sharing_fast_prefill:
assert not self.input_batch.num_prompt_logprobs, (
"--kv-sharing-fast-prefill produces incorrect logprobs for "
"prompt tokens, please disable it when the requests"
" need prompt logprobs")
# Prepare the decoder inputs.
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
max_query_len) = self._prepare_inputs(scheduler_output)
(
num_scheduled_tokens,
num_input_tokens,
num_tokens_across_dp,
input_ids,
inputs_embeds,
positions,
intermediate_tensors,
model_kwargs,
) = self._preprocess(scheduler_output, intermediate_tensors)
uniform_decode = (max_query_len
== self.uniform_decode_query_len) and (
num_scheduled_tokens
== self.input_batch.num_reqs * max_query_len)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=uniform_decode)
cudagraph_runtime_mode, batch_descriptor = \
self.cudagraph_dispatcher.dispatch(batch_descriptor)
# Run the model.
# Use persistent buffers for CUDA graphs.
with (set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
), record_function_or_nullcontext("Forward"),
self.maybe_get_kv_connector_output(scheduler_output) as
kv_connector_output):
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
self.eplb_step()
with record_function_or_nullcontext("Postprocess"):
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping micro-batches
# https://github.com/vllm-project/vllm/issues/18019
broadcast_pp_output = \
self.parallel_config.distributed_executor_backend \
== "external_launcher" and len(get_pp_group().ranks) > 0
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
assert isinstance(hidden_states, IntermediateTensors)
if not broadcast_pp_output:
hidden_states.kv_connector_output = kv_connector_output
return hidden_states
get_pp_group().send_tensor_dict(
hidden_states.tensors, all_gather_group=get_tp_group())
logits = None
else:
if self.is_pooling_model:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np,
kv_connector_output)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:
model_output_broadcast_data = {
"logits": logits.contiguous(),
} if logits is not None else {}
model_output_broadcast_data = get_pp_group(
).broadcast_tensor_dict(model_output_broadcast_data,
src=len(get_pp_group().ranks) - 1)
assert model_output_broadcast_data is not None
logits = model_output_broadcast_data["logits"]
# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
self.apply_grammar_bitmask(scheduler_output, logits)
with record_function_or_nullcontext("Sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
with record_function_or_nullcontext("Bookkeep"):
assert isinstance(hidden_states, torch.Tensor)
(
num_nans_in_logits,
logprobs_lists,
valid_sampled_token_ids,
prompt_logprobs_dict,
req_ids_output_copy,
req_id_to_index_output_copy,
invalid_req_indices,
) = self._bookkeeping_sync(scheduler_output, sampler_output,
logits, hidden_states,
num_scheduled_tokens)
if self.speculative_config:
assert spec_decode_common_attn_metadata is not None
with record_function_or_nullcontext("Draft"):
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
self.input_batch.sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
spec_decode_common_attn_metadata,
)
with record_function_or_nullcontext("EPLB"):
self.eplb_step()
output = ModelRunnerOutput(
req_ids=req_ids_output_copy,
@@ -1923,7 +2007,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return AsyncGPUModelRunnerOutput(
model_runner_output=output,
sampled_token_ids=sampled_token_ids,
sampled_token_ids=sampler_output.sampled_token_ids,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
)
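
To see how the new scopes surface in a trace, here is a hypothetical, self-contained sketch (torch only, with toy matmuls standing in for the real phases) that records the same scope names this diff introduces ("Preprocess", "Forward", "Postprocess", "Sample", "Bookkeep", "Draft", "EPLB") and exports them to a Chrome trace viewable in chrome://tracing or Perfetto.

# Toy stand-ins for the execute_model phases; only the scope names match the diff.
import torch
from torch.autograd.profiler import record_function
from torch.profiler import ProfilerActivity, profile

PHASES = ("Preprocess", "Forward", "Postprocess", "Sample",
          "Bookkeep", "Draft", "EPLB")


def toy_phase() -> torch.Tensor:
    # Placeholder work so each named range has a visible duration.
    return torch.randn(256, 256) @ torch.randn(256, 256)


with profile(activities=[ProfilerActivity.CPU]) as prof:
    for name in PHASES:
        with record_function(name):
            toy_phase()

# Each phase appears as a named range in the exported trace.
prof.export_chrome_trace("execute_model_scopes_trace.json")

In a real vLLM run these names only show up when VLLM_CUSTOM_SCOPES_FOR_PROFILING=1, since record_function_or_nullcontext otherwise returns a null context.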