Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
break execute_model in gpu_model_runner into sub-functions for custom scopes (#24265)
Co-authored-by: Bangsheng Tang <bangsheng@meta.com>
vllm/envs.py

@@ -168,6 +168,7 @@ if TYPE_CHECKING:
     VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
     VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
+    VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
 
 
 def get_default_cache_root():
@@ -1200,6 +1201,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_TUNED_CONFIG_FOLDER":
     lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
 
+    # Add optional custom scopes for profiling, disable to avoid overheads
+    "VLLM_CUSTOM_SCOPES_FOR_PROFILING":
+    lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
 }
 
 # --8<-- [end:env-vars-definition]
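
The flag defaults to off, so the named scopes cost nothing unless a profiling run opts in. As a small illustration (not part of the diff; the helper function below is invented for the example), this is the same parsing the lambda above performs:

# Illustration only, not vLLM code: mirrors the lambda registered above.
# Unset or "0" disables the scopes; any non-zero integer string enables them.
# A value like "true" would raise ValueError under this parsing.
import os

def custom_scopes_enabled() -> bool:
    return bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0")))

os.environ.pop("VLLM_CUSTOM_SCOPES_FOR_PROFILING", None)
assert custom_scopes_enabled() is False   # default: scopes disabled, no overhead
os.environ["VLLM_CUSTOM_SCOPES_FOR_PROFILING"] = "1"
assert custom_scopes_enabled() is True    # opt in for profiling runs
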
vllm/v1/utils.py

@@ -1,17 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import argparse
+import contextlib
 import multiprocessing
 import time
 import weakref
 from collections.abc import Sequence
+from contextlib import AbstractContextManager
 from multiprocessing import connection
 from multiprocessing.process import BaseProcess
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
                     Union, overload)
 
 import torch
+from torch.autograd.profiler import record_function
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
@@ -368,3 +372,10 @@ def report_usage_stats(
             "disable_custom_all_reduce":
             vllm_config.parallel_config.disable_custom_all_reduce,
         })
+
+
+def record_function_or_nullcontext(name: str) -> AbstractContextManager:
+    if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING:
+        return record_function(name)
+    else:
+        return contextlib.nullcontext()
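
Both branches return a context manager, so call sites can wrap a phase in a plain `with` block unconditionally and stack the scope next to other context managers, which is how the refactored execute_model uses it below. A standalone sketch of the same pattern (illustration only, not vLLM code; it assumes only PyTorch, and SCOPES_ENABLED stands in for the env var):

# Standalone sketch of the same pattern, not vLLM code.
import contextlib

import torch
from torch.autograd.profiler import record_function

SCOPES_ENABLED = False  # stand-in for envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING


def scope(name: str) -> contextlib.AbstractContextManager:
    # Real profiler range when enabled, inert nullcontext otherwise, so every
    # call site can use a plain `with` and pay essentially nothing when disabled.
    return record_function(name) if SCOPES_ENABLED else contextlib.nullcontext()


# Because both branches are context managers, the scope stacks with other
# context managers in a single `with`, as the refactored execute_model does
# with set_forward_context.
with scope("Forward"), torch.no_grad():
    torch.randn(64, 64) @ torch.randn(64, 64)
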
vllm/v1/worker/gpu_model_runner.py

@@ -69,7 +69,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         KVCacheGroupSpec, KVCacheSpec,
                                         MambaSpec, SlidingWindowSpec)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
-                             DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
+                             DraftTokenIds, LogprobsLists, LogprobsTensors,
+                             ModelRunnerOutput, SamplerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -79,7 +80,7 @@ from vllm.v1.spec_decode.eagle import EagleProposer
 from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
-from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin, KVConnectorOutput)
@@ -1587,31 +1588,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             kv_connector_output=kv_connector_output,
         )
 
-    @torch.inference_mode()
-    def execute_model(
+    def _preprocess(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
-        self._update_states(scheduler_output)
-        if not scheduler_output.total_num_scheduled_tokens:
-            if not has_kv_transfer_group():
-                # Return empty ModelRunnerOutput if there's no work to do.
-                return EMPTY_MODEL_RUNNER_OUTPUT
-
-            return self.kv_connector_no_forward(scheduler_output,
-                                                self.vllm_config)
-
-        if self.cache_config.kv_sharing_fast_prefill:
-            assert not self.input_batch.num_prompt_logprobs, (
-                "--kv-sharing-fast-prefill produces incorrect logprobs for "
-                "prompt tokens, tokens, please disable it when the requests "
-                "need prompt logprobs")
-
-        # Prepare the decoder inputs.
-        (attn_metadata, logits_indices, spec_decode_metadata,
-         num_scheduled_tokens_np, spec_decode_common_attn_metadata,
-         max_query_len) = self._prepare_inputs(scheduler_output)
+    ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor],
+               Optional[torch.Tensor], torch.Tensor,
+               Optional[IntermediateTensors], dict[str, Any]]:
 
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
@@ -1683,75 +1666,21 @@
         intermediate_tensors = self.sync_and_slice_intermediate_tensors(
             num_input_tokens, intermediate_tensors, True)
 
-        uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
-            num_scheduled_tokens == self.input_batch.num_reqs * max_query_len)
-        batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
-                                           uniform_decode=uniform_decode)
-        cudagraph_runtime_mode, batch_descriptor = \
-            self.cudagraph_dispatcher.dispatch(batch_descriptor)
-
-        # Run the model.
-        # Use persistent buffers for CUDA graphs.
-        with set_forward_context(
-                attn_metadata,
-                self.vllm_config,
-                num_tokens=num_input_tokens,
-                num_tokens_across_dp=num_tokens_across_dp,
-                cudagraph_runtime_mode=cudagraph_runtime_mode,
-                batch_descriptor=batch_descriptor,
-        ), self.maybe_get_kv_connector_output(
-                scheduler_output) as kv_connector_output:
-
-            model_output = self.model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds,
-                **model_kwargs,
+        return (
+            num_scheduled_tokens,
+            num_input_tokens,
+            num_tokens_across_dp,
+            input_ids,
+            inputs_embeds,
+            positions,
+            intermediate_tensors,
+            model_kwargs,
         )
 
-        if self.use_aux_hidden_state_outputs:
-            hidden_states, aux_hidden_states = model_output
-        else:
-            hidden_states = model_output
-            aux_hidden_states = None
-
-        # Broadcast PP output for external_launcher (torchrun)
-        # to make sure we are synced across pp ranks
-        # TODO: Support overlapping mirco-batches
-        # https://github.com/vllm-project/vllm/issues/18019
-        broadcast_pp_output = \
-            self.parallel_config.distributed_executor_backend \
-            == "external_launcher" and len(get_pp_group().ranks) > 0
-        if not get_pp_group().is_last_rank:
-            # For mid-pipeline stages, return the hidden states.
-            assert isinstance(hidden_states, IntermediateTensors)
-            if not broadcast_pp_output:
-                hidden_states.kv_connector_output = kv_connector_output
-                return hidden_states
-            get_pp_group().send_tensor_dict(hidden_states.tensors,
-                                            all_gather_group=get_tp_group())
-            logits = None
-        else:
-            if self.is_pooling_model:
-                return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np, kv_connector_output)
-
-            sample_hidden_states = hidden_states[logits_indices]
-            logits = self.model.compute_logits(sample_hidden_states, None)
-        if broadcast_pp_output:
-            model_output_broadcast_data = {
-                "logits": logits.contiguous(),
-            } if logits is not None else {}
-            model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
-                model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
-            assert model_output_broadcast_data is not None
-            logits = model_output_broadcast_data["logits"]
-
-        # Apply structured output bitmasks if present
-        if scheduler_output.grammar_bitmask is not None:
-            self.apply_grammar_bitmask(scheduler_output, logits)
-
+    def _sample(
+        self, logits: Optional[torch.Tensor],
+        spec_decode_metadata: Optional[SpecDecodeMetadata]
+    ) -> SamplerOutput:
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         if spec_decode_metadata is None:
@@ -1785,6 +1714,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             )
             sampler_output.sampled_token_ids = output_token_ids
 
+        return sampler_output
+
+    def _bookkeeping_sync(
+        self, scheduler_output: "SchedulerOutput",
+        sampler_output: SamplerOutput, logits: Optional[torch.Tensor],
+        hidden_states: torch.Tensor, num_scheduled_tokens: int
+    ) -> tuple[
+            dict[str, int],
+            Optional[LogprobsLists],
+            list[list[int]],
+            dict[str, Optional[LogprobsTensors]],
+            list[str],
+            dict[str, int],
+            list[int],
+    ]:
         num_nans_in_logits = {}
         if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
             num_nans_in_logits = self._get_nans_in_logits(logits)
@@ -1827,6 +1771,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
+        sampled_token_ids = sampler_output.sampled_token_ids
         invalid_req_indices = []
         if not self.use_async_scheduling:
             # Get the valid generated tokens.
             max_gen_len = sampled_token_ids.shape[-1]
@@ -1892,12 +1837,150 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 req_state = self.requests[req_id]
                 req_state.output_token_ids.extend(sampled_ids)
 
+        return (
+            num_nans_in_logits,
+            logprobs_lists,
+            valid_sampled_token_ids,
+            prompt_logprobs_dict,
+            req_ids_output_copy,
+            req_id_to_index_output_copy,
+            invalid_req_indices,
+        )
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
+        with record_function_or_nullcontext("Preprocess"):
+            self._update_states(scheduler_output)
+            if not scheduler_output.total_num_scheduled_tokens:
+                if not has_kv_transfer_group():
+                    # Return empty ModelRunnerOutput if there's no work to do.
+                    return EMPTY_MODEL_RUNNER_OUTPUT
+                return self.kv_connector_no_forward(scheduler_output,
+                                                    self.vllm_config)
+            if self.cache_config.kv_sharing_fast_prefill:
+                assert not self.input_batch.num_prompt_logprobs, (
+                    "--kv-sharing-fast-prefill produces incorrect logprobs for "
+                    "prompt tokens, tokens, please disable it when the requests"
+                    " need prompt logprobs")
+
+            # Prepare the decoder inputs.
+            (attn_metadata, logits_indices, spec_decode_metadata,
+             num_scheduled_tokens_np, spec_decode_common_attn_metadata,
+             max_query_len) = self._prepare_inputs(scheduler_output)
+
+            (
+                num_scheduled_tokens,
+                num_input_tokens,
+                num_tokens_across_dp,
+                input_ids,
+                inputs_embeds,
+                positions,
+                intermediate_tensors,
+                model_kwargs,
+            ) = self._preprocess(scheduler_output, intermediate_tensors)
+
+            uniform_decode = (max_query_len
+                              == self.uniform_decode_query_len) and (
+                                  num_scheduled_tokens
+                                  == self.input_batch.num_reqs * max_query_len)
+            batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
+                                               uniform_decode=uniform_decode)
+            cudagraph_runtime_mode, batch_descriptor = \
+                self.cudagraph_dispatcher.dispatch(batch_descriptor)
+
+        # Run the model.
+        # Use persistent buffers for CUDA graphs.
+        with (set_forward_context(
+                attn_metadata,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+                cudagraph_runtime_mode=cudagraph_runtime_mode,
+                batch_descriptor=batch_descriptor,
+        ), record_function_or_nullcontext("Forward"),
+              self.maybe_get_kv_connector_output(scheduler_output) as
+              kv_connector_output):
+            model_output = self.model(
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=inputs_embeds,
+                **model_kwargs,
+            )
+
+        with record_function_or_nullcontext("Postprocess"):
+            if self.use_aux_hidden_state_outputs:
+                hidden_states, aux_hidden_states = model_output
+            else:
+                hidden_states = model_output
+                aux_hidden_states = None
+
+            # Broadcast PP output for external_launcher (torchrun)
+            # to make sure we are synced across pp ranks
+            # TODO: Support overlapping mirco-batches
+            # https://github.com/vllm-project/vllm/issues/18019
+            broadcast_pp_output = \
+                self.parallel_config.distributed_executor_backend \
+                == "external_launcher" and len(get_pp_group().ranks) > 0
+            if not get_pp_group().is_last_rank:
+                # For mid-pipeline stages, return the hidden states.
+                assert isinstance(hidden_states, IntermediateTensors)
+                if not broadcast_pp_output:
+                    hidden_states.kv_connector_output = kv_connector_output
+                    return hidden_states
+                get_pp_group().send_tensor_dict(
+                    hidden_states.tensors, all_gather_group=get_tp_group())
+                logits = None
+            else:
+                if self.is_pooling_model:
+                    return self._pool(hidden_states, num_scheduled_tokens,
+                                      num_scheduled_tokens_np,
+                                      kv_connector_output)
+
+                sample_hidden_states = hidden_states[logits_indices]
+                logits = self.model.compute_logits(sample_hidden_states, None)
+            if broadcast_pp_output:
+                model_output_broadcast_data = {
+                    "logits": logits.contiguous(),
+                } if logits is not None else {}
+                model_output_broadcast_data = get_pp_group(
+                ).broadcast_tensor_dict(model_output_broadcast_data,
+                                        src=len(get_pp_group().ranks) - 1)
+                assert model_output_broadcast_data is not None
+                logits = model_output_broadcast_data["logits"]
+
+            # Apply structured output bitmasks if present
+            if scheduler_output.grammar_bitmask is not None:
+                self.apply_grammar_bitmask(scheduler_output, logits)
+
+        with record_function_or_nullcontext("Sample"):
+            sampler_output = self._sample(logits, spec_decode_metadata)
+
+        with record_function_or_nullcontext("Bookkeep"):
+            assert isinstance(hidden_states, torch.Tensor)
+            (
+                num_nans_in_logits,
+                logprobs_lists,
+                valid_sampled_token_ids,
+                prompt_logprobs_dict,
+                req_ids_output_copy,
+                req_id_to_index_output_copy,
+                invalid_req_indices,
+            ) = self._bookkeeping_sync(scheduler_output, sampler_output,
+                                       logits, hidden_states,
+                                       num_scheduled_tokens)
+
         if self.speculative_config:
             assert spec_decode_common_attn_metadata is not None
+            with record_function_or_nullcontext("Draft"):
                 self._draft_token_ids = self.propose_draft_token_ids(
                     scheduler_output,
                     valid_sampled_token_ids,
-                    sampling_metadata,
+                    self.input_batch.sampling_metadata,
                     hidden_states,
                     sample_hidden_states,
                     aux_hidden_states,
@@ -1905,6 +1988,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     spec_decode_common_attn_metadata,
                 )
 
+        with record_function_or_nullcontext("EPLB"):
             self.eplb_step()
 
         output = ModelRunnerOutput(
@@ -1923,7 +2007,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         return AsyncGPUModelRunnerOutput(
             model_runner_output=output,
-            sampled_token_ids=sampled_token_ids,
+            sampled_token_ids=sampler_output.sampled_token_ids,
             invalid_req_indices=invalid_req_indices,
             async_output_copy_stream=self.async_output_copy_stream,
         )
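
Taken together, the refactor turns execute_model into a flat sequence of named phases (Preprocess, Forward, Postprocess, Sample, Bookkeep, Draft, EPLB), each of which can appear as its own range in a profiler trace. The toy program below is an illustration only, not vLLM code: ToyRunner and its phases are invented for the example, and it just shows how the same phase structure surfaces in a torch.profiler summary.

# Toy, runnable illustration (not vLLM code) of the phase structure this
# commit introduces: each stage of a step wrapped in its own named scope.
import contextlib

import torch
from torch.autograd.profiler import record_function
from torch.profiler import ProfilerActivity, profile

ENABLE_SCOPES = True  # stand-in for VLLM_CUSTOM_SCOPES_FOR_PROFILING


def scope(name: str):
    # Mirrors record_function_or_nullcontext: real range when enabled, no-op otherwise.
    return record_function(name) if ENABLE_SCOPES else contextlib.nullcontext()


class ToyRunner:
    """Mimics the Preprocess/Forward/Sample/Bookkeep split at miniature scale."""

    def __init__(self) -> None:
        self.weight = torch.randn(256, 256)

    def execute_model(self, batch: torch.Tensor) -> torch.Tensor:
        with scope("Preprocess"):
            x = batch.float()
        with scope("Forward"):
            hidden = x @ self.weight
        with scope("Sample"):
            tokens = hidden.argmax(dim=-1)
        with scope("Bookkeep"):
            _ = {"num_tokens": int(tokens.numel())}
        return tokens


runner = ToyRunner()
with profile(activities=[ProfilerActivity.CPU]) as prof:
    runner.execute_model(torch.randn(8, 256))
# Each phase appears as a named entry, so time per phase is easy to read off.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
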