Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 23:03:52 +08:00
break execute_model in gpu_model_runner into sub-functions for custom scopes (#24265)
Co-authored-by: Bangsheng Tang <bangsheng@meta.com>
vllm/envs.py
@@ -168,6 +168,7 @@ if TYPE_CHECKING:
     VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
     VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
+    VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False


 def get_default_cache_root():
@@ -1200,6 +1201,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_TUNED_CONFIG_FOLDER":
     lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),

+    # Add optional custom scopes for profiling, disable to avoid overheads
+    "VLLM_CUSTOM_SCOPES_FOR_PROFILING":
+    lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
 }

 # --8<-- [end:env-vars-definition]
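The flag defaults to off, so the extra scopes cost nothing unless a user opts in. As a minimal sketch (not part of this commit; the model name and profiler settings are placeholders), one way to enable the flag and see the custom scopes in a torch profiler trace:

import os
os.environ["VLLM_CUSTOM_SCOPES_FOR_PROFILING"] = "1"  # read lazily by vllm.envs

from torch.profiler import ProfilerActivity, profile
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    llm.generate(["Hello"], SamplingParams(max_tokens=8))

# With the flag set, execute_model's stages show up as user annotations
# ("Preprocess", "Forward", "Postprocess", "Sample", "Bookkeep", ...).
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))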
vllm/v1/utils.py
@@ -1,17 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import argparse
+import contextlib
 import multiprocessing
 import time
 import weakref
 from collections.abc import Sequence
+from contextlib import AbstractContextManager
 from multiprocessing import connection
 from multiprocessing.process import BaseProcess
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
                     Union, overload)

 import torch
+from torch.autograd.profiler import record_function

+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
@@ -368,3 +372,10 @@ def report_usage_stats(
             "disable_custom_all_reduce":
             vllm_config.parallel_config.disable_custom_all_reduce,
         })
+
+
+def record_function_or_nullcontext(name: str) -> AbstractContextManager:
+    if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING:
+        return record_function(name)
+    else:
+        return contextlib.nullcontext()
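The helper keeps call sites unconditional: callers always enter a context manager, which is a real record_function scope only when the env flag is set and a free contextlib.nullcontext() otherwise. An illustrative call pattern (not part of the diff; the function body is a placeholder):

from vllm.v1.utils import record_function_or_nullcontext

def run_stage():
    # Emits a "Preprocess" user annotation into profiler traces when
    # VLLM_CUSTOM_SCOPES_FOR_PROFILING=1; otherwise it is a no-op.
    with record_function_or_nullcontext("Preprocess"):
        pass  # placeholder for real work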
vllm/v1/worker/gpu_model_runner.py
@@ -69,7 +69,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         KVCacheGroupSpec, KVCacheSpec,
                                         MambaSpec, SlidingWindowSpec)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
-                             DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
+                             DraftTokenIds, LogprobsLists, LogprobsTensors,
+                             ModelRunnerOutput, SamplerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -79,7 +80,7 @@ from vllm.v1.spec_decode.eagle import EagleProposer
 from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
-from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin, KVConnectorOutput)
@@ -1587,31 +1588,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             kv_connector_output=kv_connector_output,
         )

-    @torch.inference_mode()
-    def execute_model(
+    def _preprocess(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
-        self._update_states(scheduler_output)
-        if not scheduler_output.total_num_scheduled_tokens:
-            if not has_kv_transfer_group():
-                # Return empty ModelRunnerOutput if there's no work to do.
-                return EMPTY_MODEL_RUNNER_OUTPUT
-
-            return self.kv_connector_no_forward(scheduler_output,
-                                                self.vllm_config)
-
-        if self.cache_config.kv_sharing_fast_prefill:
-            assert not self.input_batch.num_prompt_logprobs, (
-                "--kv-sharing-fast-prefill produces incorrect logprobs for "
-                "prompt tokens, tokens, please disable it when the requests "
-                "need prompt logprobs")
-
-        # Prepare the decoder inputs.
-        (attn_metadata, logits_indices, spec_decode_metadata,
-         num_scheduled_tokens_np, spec_decode_common_attn_metadata,
-         max_query_len) = self._prepare_inputs(scheduler_output)
-
+    ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor],
+               Optional[torch.Tensor], torch.Tensor,
+               Optional[IntermediateTensors], dict[str, Any]]:
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
@@ -1683,75 +1666,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_input_tokens, intermediate_tensors, True)

-        uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
-            num_scheduled_tokens == self.input_batch.num_reqs * max_query_len)
-        batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
-                                           uniform_decode=uniform_decode)
-        cudagraph_runtime_mode, batch_descriptor = \
-            self.cudagraph_dispatcher.dispatch(batch_descriptor)
-
-        # Run the model.
-        # Use persistent buffers for CUDA graphs.
-        with set_forward_context(
-                attn_metadata,
-                self.vllm_config,
-                num_tokens=num_input_tokens,
-                num_tokens_across_dp=num_tokens_across_dp,
-                cudagraph_runtime_mode=cudagraph_runtime_mode,
-                batch_descriptor=batch_descriptor,
-        ), self.maybe_get_kv_connector_output(
-                scheduler_output) as kv_connector_output:
-
-            model_output = self.model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds,
-                **model_kwargs,
-            )
-
-        if self.use_aux_hidden_state_outputs:
-            hidden_states, aux_hidden_states = model_output
-        else:
-            hidden_states = model_output
-            aux_hidden_states = None
-
-        # Broadcast PP output for external_launcher (torchrun)
-        # to make sure we are synced across pp ranks
-        # TODO: Support overlapping mirco-batches
-        # https://github.com/vllm-project/vllm/issues/18019
-        broadcast_pp_output = \
-            self.parallel_config.distributed_executor_backend \
-            == "external_launcher" and len(get_pp_group().ranks) > 0
-        if not get_pp_group().is_last_rank:
-            # For mid-pipeline stages, return the hidden states.
-            assert isinstance(hidden_states, IntermediateTensors)
-            if not broadcast_pp_output:
-                hidden_states.kv_connector_output = kv_connector_output
-                return hidden_states
-            get_pp_group().send_tensor_dict(hidden_states.tensors,
-                                            all_gather_group=get_tp_group())
-            logits = None
-        else:
-            if self.is_pooling_model:
-                return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np, kv_connector_output)
-
-            sample_hidden_states = hidden_states[logits_indices]
-            logits = self.model.compute_logits(sample_hidden_states, None)
-        if broadcast_pp_output:
-            model_output_broadcast_data = {
-                "logits": logits.contiguous(),
-            } if logits is not None else {}
-            model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
-                model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
-            assert model_output_broadcast_data is not None
-            logits = model_output_broadcast_data["logits"]
-
-        # Apply structured output bitmasks if present
-        if scheduler_output.grammar_bitmask is not None:
-            self.apply_grammar_bitmask(scheduler_output, logits)
-
+        return (
+            num_scheduled_tokens,
+            num_input_tokens,
+            num_tokens_across_dp,
+            input_ids,
+            inputs_embeds,
+            positions,
+            intermediate_tensors,
+            model_kwargs,
+        )
+
+    def _sample(
+        self, logits: Optional[torch.Tensor],
+        spec_decode_metadata: Optional[SpecDecodeMetadata]
+    ) -> SamplerOutput:
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         if spec_decode_metadata is None:
@@ -1785,6 +1714,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             )
             sampler_output.sampled_token_ids = output_token_ids

+        return sampler_output
+
+    def _bookkeeping_sync(
+        self, scheduler_output: "SchedulerOutput",
+        sampler_output: SamplerOutput, logits: Optional[torch.Tensor],
+        hidden_states: torch.Tensor, num_scheduled_tokens: int
+    ) -> tuple[
+            dict[str, int],
+            Optional[LogprobsLists],
+            list[list[int]],
+            dict[str, Optional[LogprobsTensors]],
+            list[str],
+            dict[str, int],
+            list[int],
+    ]:
         num_nans_in_logits = {}
         if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
             num_nans_in_logits = self._get_nans_in_logits(logits)
@@ -1827,6 +1771,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
         sampled_token_ids = sampler_output.sampled_token_ids
+        invalid_req_indices = []
         if not self.use_async_scheduling:
             # Get the valid generated tokens.
             max_gen_len = sampled_token_ids.shape[-1]
@@ -1892,20 +1837,159 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 req_state = self.requests[req_id]
                 req_state.output_token_ids.extend(sampled_ids)

-        if self.speculative_config:
-            assert spec_decode_common_attn_metadata is not None
-            self._draft_token_ids = self.propose_draft_token_ids(
-                scheduler_output,
-                valid_sampled_token_ids,
-                sampling_metadata,
-                hidden_states,
-                sample_hidden_states,
-                aux_hidden_states,
-                spec_decode_metadata,
-                spec_decode_common_attn_metadata,
-            )
-
-        self.eplb_step()
+        return (
+            num_nans_in_logits,
+            logprobs_lists,
+            valid_sampled_token_ids,
+            prompt_logprobs_dict,
+            req_ids_output_copy,
+            req_id_to_index_output_copy,
+            invalid_req_indices,
+        )
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
+        with record_function_or_nullcontext("Preprocess"):
+            self._update_states(scheduler_output)
+            if not scheduler_output.total_num_scheduled_tokens:
+                if not has_kv_transfer_group():
+                    # Return empty ModelRunnerOutput if there's no work to do.
+                    return EMPTY_MODEL_RUNNER_OUTPUT
+                return self.kv_connector_no_forward(scheduler_output,
+                                                    self.vllm_config)
+            if self.cache_config.kv_sharing_fast_prefill:
+                assert not self.input_batch.num_prompt_logprobs, (
+                    "--kv-sharing-fast-prefill produces incorrect logprobs for "
+                    "prompt tokens, tokens, please disable it when the requests"
+                    " need prompt logprobs")
+
+            # Prepare the decoder inputs.
+            (attn_metadata, logits_indices, spec_decode_metadata,
+             num_scheduled_tokens_np, spec_decode_common_attn_metadata,
+             max_query_len) = self._prepare_inputs(scheduler_output)
+
+            (
+                num_scheduled_tokens,
+                num_input_tokens,
+                num_tokens_across_dp,
+                input_ids,
+                inputs_embeds,
+                positions,
+                intermediate_tensors,
+                model_kwargs,
+            ) = self._preprocess(scheduler_output, intermediate_tensors)
+
+            uniform_decode = (max_query_len
+                              == self.uniform_decode_query_len) and (
+                                  num_scheduled_tokens
+                                  == self.input_batch.num_reqs * max_query_len)
+            batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
+                                               uniform_decode=uniform_decode)
+            cudagraph_runtime_mode, batch_descriptor = \
+                self.cudagraph_dispatcher.dispatch(batch_descriptor)
+
+        # Run the model.
+        # Use persistent buffers for CUDA graphs.
+        with (set_forward_context(
+                attn_metadata,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+                cudagraph_runtime_mode=cudagraph_runtime_mode,
+                batch_descriptor=batch_descriptor,
+        ), record_function_or_nullcontext("Forward"),
+              self.maybe_get_kv_connector_output(scheduler_output) as
+              kv_connector_output):
+            model_output = self.model(
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=inputs_embeds,
+                **model_kwargs,
+            )
+
+        with record_function_or_nullcontext("Postprocess"):
+            if self.use_aux_hidden_state_outputs:
+                hidden_states, aux_hidden_states = model_output
+            else:
+                hidden_states = model_output
+                aux_hidden_states = None
+
+            # Broadcast PP output for external_launcher (torchrun)
+            # to make sure we are synced across pp ranks
+            # TODO: Support overlapping mirco-batches
+            # https://github.com/vllm-project/vllm/issues/18019
+            broadcast_pp_output = \
+                self.parallel_config.distributed_executor_backend \
+                == "external_launcher" and len(get_pp_group().ranks) > 0
+            if not get_pp_group().is_last_rank:
+                # For mid-pipeline stages, return the hidden states.
+                assert isinstance(hidden_states, IntermediateTensors)
+                if not broadcast_pp_output:
+                    hidden_states.kv_connector_output = kv_connector_output
+                    return hidden_states
+                get_pp_group().send_tensor_dict(
+                    hidden_states.tensors, all_gather_group=get_tp_group())
+                logits = None
+            else:
+                if self.is_pooling_model:
+                    return self._pool(hidden_states, num_scheduled_tokens,
+                                      num_scheduled_tokens_np,
+                                      kv_connector_output)
+
+                sample_hidden_states = hidden_states[logits_indices]
+                logits = self.model.compute_logits(sample_hidden_states, None)
+            if broadcast_pp_output:
+                model_output_broadcast_data = {
+                    "logits": logits.contiguous(),
+                } if logits is not None else {}
+                model_output_broadcast_data = get_pp_group(
+                ).broadcast_tensor_dict(model_output_broadcast_data,
+                                        src=len(get_pp_group().ranks) - 1)
+                assert model_output_broadcast_data is not None
+                logits = model_output_broadcast_data["logits"]
+
+            # Apply structured output bitmasks if present
+            if scheduler_output.grammar_bitmask is not None:
+                self.apply_grammar_bitmask(scheduler_output, logits)
+
+        with record_function_or_nullcontext("Sample"):
+            sampler_output = self._sample(logits, spec_decode_metadata)
+
+        with record_function_or_nullcontext("Bookkeep"):
+            assert isinstance(hidden_states, torch.Tensor)
+            (
+                num_nans_in_logits,
+                logprobs_lists,
+                valid_sampled_token_ids,
+                prompt_logprobs_dict,
+                req_ids_output_copy,
+                req_id_to_index_output_copy,
+                invalid_req_indices,
+            ) = self._bookkeeping_sync(scheduler_output, sampler_output,
+                                       logits, hidden_states,
+                                       num_scheduled_tokens)
+
+        if self.speculative_config:
+            assert spec_decode_common_attn_metadata is not None
+            with record_function_or_nullcontext("Draft"):
+                self._draft_token_ids = self.propose_draft_token_ids(
+                    scheduler_output,
+                    valid_sampled_token_ids,
+                    self.input_batch.sampling_metadata,
+                    hidden_states,
+                    sample_hidden_states,
+                    aux_hidden_states,
+                    spec_decode_metadata,
+                    spec_decode_common_attn_metadata,
+                )
+
+        with record_function_or_nullcontext("EPLB"):
+            self.eplb_step()

         output = ModelRunnerOutput(
             req_ids=req_ids_output_copy,
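With execute_model split into these stages, a captured trace can be reduced to per-stage timings. A hypothetical post-processing snippet over a torch.profiler result (scope names match the ones introduced above; not part of the diff):

SCOPE_NAMES = {"Preprocess", "Forward", "Postprocess", "Sample",
               "Bookkeep", "Draft", "EPLB"}

def summarize_scopes(prof):
    # prof is a torch.profiler.profile; key_averages() groups events by name,
    # so each custom scope collapses into a single row.
    for evt in prof.key_averages():
        if evt.key in SCOPE_NAMES:
            print(f"{evt.key}: total_cpu={evt.cpu_time_total:.0f}us calls={evt.count}")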
@@ -1923,7 +2007,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         return AsyncGPUModelRunnerOutput(
             model_runner_output=output,
-            sampled_token_ids=sampled_token_ids,
+            sampled_token_ids=sampler_output.sampled_token_ids,
             invalid_req_indices=invalid_req_indices,
             async_output_copy_stream=self.async_output_copy_stream,
         )