[BugFix] Fix async scheduling CPU tensor race take 2 (#25279)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill
2025-09-19 16:34:07 -07:00
committed by GitHub
parent ee7a66dd9a
commit 14c1432789
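Background on the race being fixed: with async scheduling, the CPU-side input preparation for the next step overlaps with the previous step's execution, and the CPU->GPU transfer of the prepared inputs happens asynchronously (as the comment in the new context manager notes). If the next step overwrites the reused pinned CPU tensors before the previous step's copies have finished reading them, the GPU can see corrupted inputs. A minimal sketch of that hazard, with illustrative names rather than vLLM's:

```python
import torch

# Minimal sketch of the hazard (illustrative names, not vLLM code): a pinned
# CPU staging tensor is copied to the GPU asynchronously, then reused for the
# next step's inputs.
staging_cpu = torch.zeros(4, pin_memory=True)   # reused every step
gpu_input = torch.empty(4, device="cuda")

def prepare_step(step: int) -> None:
    staging_cpu.fill_(step)                          # CPU write into the reused buffer
    gpu_input.copy_(staging_cpu, non_blocking=True)  # async host-to-GPU copy

prepare_step(0)
# Without a synchronization point, this write can race with the still-in-flight
# copy from step 0, so step 0's forward pass may read step 1's data (or a mix).
prepare_step(1)
```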

@@ -1903,7 +1903,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
             }
-        elif (self.enable_prompt_embeds and get_pp_group().is_first_rank):
+        elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
             # Get the input embeddings for the tokens that are not input embeds,
             # then put them into the appropriate positions.
             # TODO(qthequartermasterman): Since even when prompt embeds are
@@ -2125,6 +2125,21 @@
             invalid_req_indices,
         )
 
+    @contextmanager
+    def synchronize_input_prep(self):
+        if self.prepare_inputs_event is None:
+            yield
+            return
+        # Ensure prior step has finished with reused CPU tensors.
+        # This is required in the async scheduling case because
+        # the CPU->GPU transfer happens async.
+        self.prepare_inputs_event.synchronize()
+        try:
+            yield
+        finally:
+            self.prepare_inputs_event.record()
+
     @torch.inference_mode()
     def execute_model(
         self,
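The `synchronize_input_prep` context manager added above brackets the CPU-side prep with a single CUDA event: entering it waits for the event recorded at the end of the previous step, so the prior step's async copies have finished reading the reused CPU tensors, and leaving it records a fresh event once this step's work has been enqueued. A standalone sketch of the same pattern, assuming the event is only created when async scheduling is enabled (illustrative class name, not the vLLM implementation):

```python
from contextlib import contextmanager
from typing import Optional

import torch

class InputPrepGuard:
    """Illustrative sketch of the event-guarded input-prep pattern."""

    def __init__(self, use_async_scheduling: bool):
        # The event is only needed when prep for the next step can overlap
        # with the previous step's in-flight host-to-GPU copies.
        self.prepare_inputs_event: Optional[torch.cuda.Event] = (
            torch.cuda.Event() if use_async_scheduling else None)

    @contextmanager
    def synchronize_input_prep(self):
        if self.prepare_inputs_event is None:
            # Synchronous scheduling: nothing to guard.
            yield
            return
        # Wait until the previous step's copies out of the reused CPU
        # tensors have completed (a no-op before the first record()).
        self.prepare_inputs_event.synchronize()
        try:
            yield
        finally:
            # Record on the current stream after this step's copies have been
            # enqueued, so the next step knows when the CPU tensors are free.
            self.prepare_inputs_event.record()
```

The `try`/`finally` matters because the guarded block in `execute_model` can return early (for example when there is no work to do), and the event still has to be recorded so a later step does not wait on a stale one.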
@@ -2132,33 +2147,28 @@
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
         with record_function_or_nullcontext("Preprocess"):
-            self._update_states(scheduler_output)
-            if not scheduler_output.total_num_scheduled_tokens:
-                if not has_kv_transfer_group():
-                    # Return empty ModelRunnerOutput if there's no work to do.
-                    return EMPTY_MODEL_RUNNER_OUTPUT
-                return self.kv_connector_no_forward(scheduler_output,
-                                                    self.vllm_config)
-            if self.cache_config.kv_sharing_fast_prefill:
-                assert not self.input_batch.num_prompt_logprobs, (
-                    "--kv-sharing-fast-prefill produces incorrect logprobs for "
-                    "prompt tokens, tokens, please disable it when the requests"
-                    " need prompt logprobs")
+            with self.synchronize_input_prep():
+                # Update persistent batch states.
+                self._update_states(scheduler_output)
+                if not scheduler_output.total_num_scheduled_tokens:
+                    if not has_kv_transfer_group():
+                        # Return empty ModelRunnerOutput if no work to do.
+                        return EMPTY_MODEL_RUNNER_OUTPUT
+                    return self.kv_connector_no_forward(
+                        scheduler_output, self.vllm_config)
+                if self.cache_config.kv_sharing_fast_prefill:
+                    assert not self.input_batch.num_prompt_logprobs, (
+                        "--kv-sharing-fast-prefill produces incorrect "
+                        "logprobs for prompt tokens, tokens, please disable "
+                        "it when the requests need prompt logprobs")
 
-            if self.prepare_inputs_event is not None:
-                # Ensure prior step has finished with reused CPU tensors.
-                self.prepare_inputs_event.synchronize()
-            try:
                 # Prepare the decoder inputs.
                 (attn_metadata, logits_indices, spec_decode_metadata,
                  num_scheduled_tokens_np, spec_decode_common_attn_metadata,
                  max_query_len, ubatch_slices, num_tokens_after_padding
                  ) = self._prepare_inputs(scheduler_output)
-            finally:
-                if self.prepare_inputs_event is not None:
-                    self.prepare_inputs_event.record()
 
             (
                 num_scheduled_tokens,
                 num_input_tokens,
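Relative to the first attempt, the guarded region in `execute_model` now also covers `_update_states`, not just `_prepare_inputs`: both mutate the reused CPU-side tensors, so both must wait for the previous step's copies to drain before writing, and both must complete before the event is re-recorded. A compact, self-contained sketch of the resulting per-step structure (illustrative, not vLLM code):

```python
import torch

# Per-step structure after this change (illustrative, not vLLM code): all
# CPU-side prep is bracketed by one CUDA event, mirroring
# `with self.synchronize_input_prep():` in execute_model above.
prep_done = torch.cuda.Event()
staging_cpu = torch.zeros(8, pin_memory=True)   # stands in for the reused batch tensors
gpu_input = torch.empty(8, device="cuda")

def execute_step(step: int) -> None:
    # Wait for the previous step's copies out of staging_cpu to finish.
    prep_done.synchronize()
    # Both the persistent-batch update and the decoder-input preparation
    # write staging_cpu, so both sit inside the guarded region.
    staging_cpu.fill_(step)
    gpu_input.copy_(staging_cpu, non_blocking=True)
    prep_done.record()      # the next step waits on this event
    # ... the model forward reading gpu_input would be launched here ...

execute_step(0)
execute_step(1)  # waits until step 0's copy has finished reading staging_cpu
```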