[BugFix] Fix handling of num_computed_tokens with connector (#18232)

Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
This commit is contained in:
Nick Hill
2025-05-19 09:03:25 -07:00
committed by GitHub
parent 43b5f61dce
commit 1b15df2546
2 changed files with 29 additions and 16 deletions

View File

@ -209,7 +209,17 @@ class NixlConnectorScheduler:
rounded_num_prompt_tokens = round_down(
len(request.prompt_token_ids), self.block_size)
count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
return count, count > 0
if count > 0:
return count, True
# NOTE: if count is 0 here, we have less than block_size
# tokens to pull after subtracting the local prefix cache hit.
# The remote only sends fully computed blocks, so there is
# nothing to transfer but we still need to notify the
# prefill worker so that the remote blocks are freed.
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):
self._reqs_need_recv[request.request_id] = (request, [])
# No remote prefill for this request.
return 0, False
@ -225,10 +235,6 @@ class NixlConnectorScheduler:
num_external_tokens, params)
if params is not None and params.get("do_remote_prefill"):
# NOTE(rob): if prompt < block_size, no remote blocks
# since the remote only sends fully computed blocks, so
# skip recving for this request. num_external_tokens
# should be 0 if there are no remote blocks.
if params.get("remote_block_ids"):
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):

View File

@ -345,32 +345,38 @@ class Scheduler(SchedulerInterface):
skipped_waiting_requests.appendleft(request)
continue
num_external_computed_tokens = 0
load_kv_async = False
# Get already-cached tokens.
if num_prealloc_computed_tokens == 0:
new_computed_blocks, num_native_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)
else:
# P/D: skip checking prefix cache if loaded from remote kvs.
new_computed_blocks = KVCacheBlocks.create_empty()
num_native_computed_tokens = 0
# Get externally-cached tokens if using a KVConnector.
if self.connector is not None:
num_external_computed_tokens, load_kv_async = (
(0, False) if self.connector is None else
self.connector.get_num_new_matched_tokens(
request, num_native_computed_tokens))
# Total computed tokens (local + external).
num_computed_tokens = (num_native_computed_tokens +
num_external_computed_tokens +
num_prealloc_computed_tokens)
num_external_computed_tokens)
else:
# P/D: skip checking prefix cache if loaded from remote kvs.
new_computed_blocks = KVCacheBlocks.create_empty()
num_native_computed_tokens = 0
# Total computed tokens (allocated in prior step).
num_computed_tokens = num_prealloc_computed_tokens
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
# P/D: loading remote KV, do not allocate for new work.
if load_kv_async:
assert num_external_computed_tokens > 0
num_new_tokens = 0
# Number of tokens to be scheduled.
else:
@ -411,7 +417,8 @@ class Scheduler(SchedulerInterface):
# KVConnector: update internal state after allocation.
# This information is used to determine if a load is
# needed for this request.
if self.connector is not None:
if num_external_computed_tokens:
assert self.connector is not None
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,