Compare commits

...

1 Commit

Author               SHA1        Message                                    Date
Tyler Michael Smith  161010c384  Initial stubs for P/D scheduling changes   2025-04-18 16:42:49 -04:00
                                 Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
5 changed files with 47 additions and 5 deletions

View File

@@ -196,7 +196,9 @@ class KVConnectorBase_V1(ABC):

     @abstractmethod
     def build_connector_meta(
-            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
+            self, scheduler_output: SchedulerOutput,
+            sending_KV_req_ids: set[str],
+            waiting_KV_req_ids: set[str]) -> KVConnectorMetadata:
         """
         Build the connector metadata for this step.

@@ -205,5 +207,7 @@ class KVConnectorBase_V1(ABC):
         Args:
             scheduler_output (SchedulerOutput): the scheduler output object.
+            sending_KV_req_ids (set[str]): Request IDs to send
+            waiting_KV_req_ids (set[str]): Request IDs to receive
         """
         pass
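
Every KVConnectorBase_V1 subclass now has to accept the two ID sets, even if it ignores them. A minimal sketch of a conforming implementation, assuming a trivial metadata type; NoopConnector and NoopConnectorMetadata are illustrative names, not part of this PR:

# Minimal sketch of a connector conforming to the new signature.
# `NoopConnector` and `NoopConnectorMetadata` are illustrative names,
# not part of this PR.
from dataclasses import dataclass, field


@dataclass
class NoopConnectorMetadata:
    """Opaque per-step payload handed from the scheduler to workers."""
    send_req_ids: set[str] = field(default_factory=set)
    recv_req_ids: set[str] = field(default_factory=set)


class NoopConnector:

    def build_connector_meta(
            self, scheduler_output,  # SchedulerOutput in vLLM
            sending_KV_req_ids: set[str],
            waiting_KV_req_ids: set[str]) -> NoopConnectorMetadata:
        # Snapshot the sets so later scheduler-side mutations don't
        # leak into metadata already shipped to the workers.
        return NoopConnectorMetadata(
            send_req_ids=set(sending_KV_req_ids),
            recv_req_ids=set(waiting_KV_req_ids))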

View File

@@ -271,9 +271,9 @@ class SharedStorageConnector(KVConnectorBase_V1):
             self._requests_need_load[request.request_id] = request

     def build_connector_meta(
-        self,
-        scheduler_output: SchedulerOutput,
-    ) -> KVConnectorMetadata:
+            self, scheduler_output: SchedulerOutput,
+            sending_KV_req_ids: set[str],
+            waiting_KV_req_ids: set[str]) -> KVConnectorMetadata:
         """Build the connector metadata for this step.

         This function should NOT modify any fields in the scheduler_output.

@@ -281,6 +281,8 @@ class SharedStorageConnector(KVConnectorBase_V1):
         Args:
             scheduler_output (SchedulerOutput): the scheduler output object.
+            sending_KV_req_ids (set[str]): Request IDs to send
+            waiting_KV_req_ids (set[str]): Request IDs to receive
         """
         meta = SharedStorageConnectorMetadata()

View File

@@ -98,6 +98,10 @@ class Scheduler(SchedulerInterface):
         # This is flushed at the end of each scheduling step.
         self.finished_req_ids: set[str] = set()

+        # Requests in states for tracking KV transfers for P/D disagg
+        self.sending_KV_req_ids: set[str] = set()
+        self.waiting_KV_req_ids: set[str] = set()
+
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
         # Request id -> CachedRequestData
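
These two sets effectively add two request states that live outside the usual waiting/running queues. A self-contained simulation of the lifecycle they imply, inferred from the hunks below; the request ID and queue names are illustrative:

# Tiny simulation of the two tracking sets' lifecycle; the set names
# mirror the diff, everything else is illustrative.
sending_KV_req_ids: set[str] = set()
waiting_KV_req_ids: set[str] = set()
finished_req_ids: set[str] = set()
waiting_queue: list[str] = []

# Prefill instance: "p0" finishes prefill with do_remote_decode set.
sending_KV_req_ids.add("p0")
# ...later the connector reports the KV send completed:
sending_KV_req_ids.remove("p0")
finished_req_ids.add("p0")

# Decode instance: the connector announces incoming request "p0".
waiting_KV_req_ids.add("p0")
# ...later the KV blocks have arrived, so it becomes schedulable:
waiting_KV_req_ids.remove("p0")
waiting_queue.append("p0")

assert "p0" in finished_req_ids and waiting_queue == ["p0"]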
@@ -167,6 +171,21 @@ class Scheduler(SchedulerInterface):
         # For logging.
         scheduled_timestamp = time.monotonic()

+        # Check for new remote decode requests for P/D
+        if self.connector is not None:
+            self.waiting_KV_req_ids.update(
+                self.connector.receive_remote_decode_requests())
+
+            # Check if any P/D requests have finished sending or receiving
+            for req_id in list(self.sending_KV_req_ids):
+                if self.connector.done_sending_remote_decode_request(req_id):
+                    self.sending_KV_req_ids.remove(req_id)
+                    self.finished_req_ids.add(req_id)
+            for req_id in list(self.waiting_KV_req_ids):
+                if self.connector.done_waiting_remote_decode_request(req_id):
+                    self.waiting_KV_req_ids.remove(req_id)
+                    self.waiting.append(self.requests[req_id])
+
         # First, schedule the RUNNING requests.
         req_index = 0
         while req_index < len(self.running) and token_budget > 0:
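
The polling above calls three connector methods that this stub PR never defines (a fourth, send_remote_decode_request, appears further down). A sketch of what those hooks could look like; the signatures are guessed from the call sites, not actual vLLM API:

# Sketch of the polling hooks the scheduler loop assumes. These are
# not defined anywhere in this PR, so the signatures below are guesses
# based on the call sites, not actual vLLM API.
from abc import ABC, abstractmethod


class PDConnectorHooks(ABC):

    @abstractmethod
    def receive_remote_decode_requests(self) -> set[str]:
        """Return IDs of requests whose KV is arriving from a remote
        prefill instance; called every scheduling step, must not block."""

    @abstractmethod
    def done_sending_remote_decode_request(self, req_id: str) -> bool:
        """True once the KV blocks for req_id have been fully sent."""

    @abstractmethod
    def done_waiting_remote_decode_request(self, req_id: str) -> bool:
        """True once the KV blocks for req_id have fully arrived."""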
@@ -479,7 +498,9 @@ class Scheduler(SchedulerInterface):
         # 2. Wrap up all the KV cache load / save ops into an opaque object
         # 3. Clear the internal states of the connector
         if self.connector is not None:
-            meta = self.connector.build_connector_meta(scheduler_output)
+            meta = self.connector.build_connector_meta(scheduler_output,
+                                                       self.sending_KV_req_ids,
+                                                       self.waiting_KV_req_ids)
             scheduler_output.kv_connector_metadata = meta

         # Advance the number of computed tokens for the request AFTER
@@ -682,6 +703,7 @@ class Scheduler(SchedulerInterface):
             # Check for stop and update request state.
             # This must be called before we make the EngineCoreOutput.
+            # TODO: What if we detect we're done here when doing P/D disagg?
             stopped = check_stop(request, self.max_model_len)
             if stopped:
                 self._free_request(request)
@@ -718,6 +740,13 @@ class Scheduler(SchedulerInterface):
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors

+            if self.connector is not None and request.do_remote_decode:
+                stopped = True
+                self.sending_KV_req_ids.add(req_id)
+                self.connector.send_remote_decode_request(
+                    self.kv_cache_manager.req_to_blocks[req_id])
+                self.scheduled_req_ids.remove(req_id)
+
             if not stopped:
                 new_running.append(request)
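
From its call site, send_remote_decode_request receives only the request's KV block list, so the connector presumably captures the block IDs for an asynchronous send (a fuller version would likely need the request ID as well). A hypothetical send-side tracker; everything beyond the call-site shape is invented for illustration:

# Hypothetical send-side bookkeeping; only the method name and its
# blocks argument come from the call site above.
from dataclasses import dataclass


@dataclass
class KVBlockStub:
    block_id: int  # stand-in for vLLM's KVCacheBlock


class SendTrackerStub:

    def __init__(self) -> None:
        self._pending_sends: list[list[int]] = []

    def send_remote_decode_request(self, blocks: list[KVBlockStub]) -> None:
        # Record the physical block ids to ship to the decode instance.
        # The actual byte transfer would run in the background, with
        # done_sending_remote_decode_request() reporting completion.
        self._pending_sends.append([b.block_id for b in blocks])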

View File

@@ -61,6 +61,9 @@ class Request:
         self.num_encoder_inputs = len(self.mm_inputs)
         self.has_encoder_inputs = self.num_encoder_inputs > 0

+        # P/D disagg related
+        self.do_remote_decode = False
+
         # Sanity check
         assert len(self.mm_inputs) == len(self.mm_positions)
         if self.mm_hashes:
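
Nothing in this PR sets do_remote_decode to True, so the flag is purely a stub. One plausible (entirely hypothetical) wiring is for a disaggregation router on the prefill instance to mark requests at admission:

# Hypothetical admission-time wiring; the PR itself only adds the
# attribute with a hard-coded False default.
class RequestStub:

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self.do_remote_decode = False


def admit_prefill_only(request: RequestStub) -> RequestStub:
    # A P/D router could set the flag so the scheduler hands the KV
    # cache off after prefill instead of decoding locally.
    request.do_remote_decode = True
    return request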

View File

@@ -991,6 +991,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     ) -> Union[ModelRunnerOutput, torch.Tensor]:
         # Update KVConnector with the KVConnector metadata forward().
         if has_kv_transfer_group():
+            # Background KV cache transfers can happen here,
+            # since kv_connector_metadata has the req_ids to send/receive.
+            # Not sure I like doing it here since this does not have to do
+            # with model execution but this way we don't do a separate rpc.
             get_kv_transfer_group().bind_connector_metadata(
                 scheduler_output.kv_connector_metadata)
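
Shipping the ID sets inside the connector metadata lets the worker start transfers during execute_model() without a separate RPC, as the added comment notes. A sketch of a worker-side connector acting on the bound metadata; the threading model is an assumption, and send_req_ids comes from the metadata sketch near the top:

# Illustrative worker-side use of the bound metadata; the real
# KVConnector worker interface is not shown in this diff.
import threading


class WorkerConnectorStub:

    def bind_connector_metadata(self, meta) -> None:
        # Start background sends as soon as the step's metadata lands,
        # overlapping the KV transfer with the forward pass.
        self._meta = meta
        for req_id in getattr(meta, "send_req_ids", set()):
            threading.Thread(target=self._send_kv, args=(req_id,),
                             daemon=True).start()

    def _send_kv(self, req_id: str) -> None:
        # Placeholder for copying this request's KV blocks over to the
        # decode instance.
        print(f"sending KV for request {req_id}")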