From 1b15df2546e97c409668da92954d8802c48d13af Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 19 May 2025 09:03:25 -0700 Subject: [PATCH] [BugFix] Fix handling of num_computed_tokens with connector (#18232) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Nick Hill Co-authored-by: Nicolò Lucchesi --- .../kv_connector/v1/nixl_connector.py | 16 ++++++---- vllm/v1/core/sched/scheduler.py | 29 ++++++++++++------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index e6c83a0fc5..9c2e82b29c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -209,7 +209,17 @@ class NixlConnectorScheduler: rounded_num_prompt_tokens = round_down( len(request.prompt_token_ids), self.block_size) count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) - return count, count > 0 + if count > 0: + return count, True + + # NOTE: if count is 0 here, we have less than block_size + # tokens to pull after subtracting the local prefix cache hit. + # The remote only sends fully computed blocks, so there is + # nothing to transfer but we still need to notify the + # prefill worker so that the remote blocks are freed. + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port")): + self._reqs_need_recv[request.request_id] = (request, []) # No remote prefill for this request. return 0, False @@ -225,10 +235,6 @@ class NixlConnectorScheduler: num_external_tokens, params) if params is not None and params.get("do_remote_prefill"): - # NOTE(rob): if prompt < block_size, no remote blocks - # since the remote only sends fully computed blocks, so - # skip recving for this request. num_external_tokens - # should be 0 if there are no remote blocks. if params.get("remote_block_ids"): if all(p in params for p in ("remote_engine_id", "remote_host", "remote_port")): diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5ad05485e8..d8fd67e232 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -345,32 +345,38 @@ class Scheduler(SchedulerInterface): skipped_waiting_requests.appendleft(request) continue + num_external_computed_tokens = 0 + load_kv_async = False + # Get already-cached tokens. if num_prealloc_computed_tokens == 0: new_computed_blocks, num_native_computed_tokens = \ self.kv_cache_manager.get_computed_blocks( request) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_native_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens = (num_native_computed_tokens + + num_external_computed_tokens) else: # P/D: skip checking prefix cache if loaded from remote kvs. new_computed_blocks = KVCacheBlocks.create_empty() num_native_computed_tokens = 0 - # Get externally-cached tokens if using a KVConnector. - num_external_computed_tokens, load_kv_async = ( - (0, False) if self.connector is None else - self.connector.get_num_new_matched_tokens( - request, num_native_computed_tokens)) - - # Total computed tokens (local + external). - num_computed_tokens = (num_native_computed_tokens + - num_external_computed_tokens + - num_prealloc_computed_tokens) + # Total computed tokens (allocated in prior step). + num_computed_tokens = num_prealloc_computed_tokens encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget # P/D: loading remote KV, do not allocate for new work. if load_kv_async: + assert num_external_computed_tokens > 0 num_new_tokens = 0 # Number of tokens to be scheduled. else: @@ -411,7 +417,8 @@ class Scheduler(SchedulerInterface): # KVConnector: update internal state after allocation. # This information is used to determine if a load is # needed for this request. - if self.connector is not None: + if num_external_computed_tokens: + assert self.connector is not None self.connector.update_state_after_alloc( request, new_computed_blocks + new_blocks,