Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Misc] Avoid accessing req_ids inside a loop (#23159)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
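The change is a standard CPython micro-optimization: inside a hot per-request loop, self.input_batch.req_ids re-resolves two attribute lookups (self.input_batch, then .req_ids) on every iteration, while a local variable bound once before the loop is a single fast index into the frame's locals. A minimal sketch of the before/after shape, using simplified stand-ins rather than vLLM's real GPUModelRunner and InputBatch classes:

class InputBatch:
    # Simplified stand-in for vLLM's persistent input batch.
    def __init__(self, req_ids: list[str]) -> None:
        self.req_ids = req_ids


class Runner:
    # Simplified stand-in for GPUModelRunner.
    def __init__(self, req_ids: list[str]) -> None:
        self.input_batch = InputBatch(req_ids)

    def before(self, sampled: list[list[int]]) -> list[str]:
        # Old shape: the attribute chain is re-resolved on every iteration.
        out: list[str] = []
        for i, ids in enumerate(sampled):
            if not ids:
                continue
            out.append(self.input_batch.req_ids[i])
        return out

    def after(self, sampled: list[list[int]]) -> list[str]:
        # New shape: bind the list to a local once, then index the local.
        req_ids = self.input_batch.req_ids
        out: list[str] = []
        for i, ids in enumerate(sampled):
            if not ids:
                continue
            out.append(req_ids[i])
        return out

The diff below applies exactly this hoist at three sites, replacing the in-loop self.input_batch.req_ids[...] reads with req_ids[...].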
@@ -1748,6 +1748,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # NOTE(woosuk): As an exception, when using PP, the scheduler sends
             # the sampled tokens back, because there's no direct communication
             # between the first-stage worker and the last-stage worker.
+            req_ids = self.input_batch.req_ids
             for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
                 if not sampled_ids:
                     continue
@@ -1763,7 +1764,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     start_idx:end_idx] = sampled_ids
                 self.input_batch.num_tokens_no_spec[req_idx] = end_idx
                 self.input_batch.num_tokens[req_idx] = end_idx
-                req_id = self.input_batch.req_ids[req_idx]
+                req_id = req_ids[req_idx]
                 req_state = self.requests[req_id]
                 req_state.output_token_ids.extend(sampled_ids)
@@ -1843,6 +1844,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         elif self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.
+            req_ids = self.input_batch.req_ids
             next_token_ids: list[int] = []
             for i, token_ids in enumerate(sampled_token_ids):
                 if token_ids:
@@ -1851,7 +1853,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 else:
                     # Partial prefill (rare case).
                     # Get the next token id from the request state.
-                    req_id = self.input_batch.req_ids[i]
+                    req_id = req_ids[i]
                     req_state = self.requests[req_id]
                     seq_len = (req_state.num_computed_tokens +
                                scheduler_output.num_scheduled_tokens[req_id])
@@ -1914,6 +1916,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         sampled_token_ids: list[list[int]],
     ) -> list[list[int]]:
         # TODO(woosuk): Optimize.
+        req_ids = self.input_batch.req_ids
         draft_token_ids: list[list[int]] = []
         for i, sampled_ids in enumerate(sampled_token_ids):
             num_sampled_ids = len(sampled_ids)
@@ -1924,7 +1927,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
             # Skip requests that require sampling parameters that are not
             # supported with speculative decoding.
-            req_id = self.input_batch.req_ids[i]
+            req_id = req_ids[i]
             if req_id in self.input_batch.spec_decode_unsupported_reqs:
                 draft_token_ids.append([])
                 continue
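For a rough sense of scale, here is a self-contained micro-benchmark of the same pattern. The class names and sizes are illustrative, not from the vLLM codebase, and absolute timings vary by machine and CPython version; only the relative gap matters:

import timeit

SETUP = """
class Batch:
    def __init__(self):
        self.req_ids = list(range(1024))

class Runner:
    def __init__(self):
        self.input_batch = Batch()

runner = Runner()
"""

# Attribute chain re-resolved inside the loop.
IN_LOOP = """
for i in range(1024):
    _ = runner.input_batch.req_ids[i]
"""

# Lookup hoisted to a local before the loop.
HOISTED = """
req_ids = runner.input_batch.req_ids
for i in range(1024):
    _ = req_ids[i]
"""

print("in-loop :", timeit.timeit(IN_LOOP, setup=SETUP, number=5_000))
print("hoisted :", timeit.timeit(HOISTED, setup=SETUP, number=5_000))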