[Misc] Avoid accessing req_ids inside a loop (#23159)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-08-19 02:39:38 -07:00
committed by GitHub
parent 5bfe0dea7a
commit 21bcc8263f

View File

@ -1748,6 +1748,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
req_ids = self.input_batch.req_ids
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
if not sampled_ids:
continue
@ -1763,7 +1764,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
start_idx:end_idx] = sampled_ids
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
req_id = self.input_batch.req_ids[req_idx]
req_id = req_ids[req_idx]
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)
@ -1843,6 +1844,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
elif self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop.
req_ids = self.input_batch.req_ids
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
if token_ids:
@ -1851,7 +1853,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_id = req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
@ -1914,6 +1916,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampled_token_ids: list[list[int]],
) -> list[list[int]]:
# TODO(woosuk): Optimize.
req_ids = self.input_batch.req_ids
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
@ -1924,7 +1927,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = self.input_batch.req_ids[i]
req_id = req_ids[i]
if req_id in self.input_batch.spec_decode_unsupported_reqs:
draft_token_ids.append([])
continue