mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Compare commits
1 Commits
v0.9.1
...
pd_schedul
Author | SHA1 | Date | |
---|---|---|---|
161010c384 |
@ -196,7 +196,9 @@ class KVConnectorBase_V1(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
self, scheduler_output: SchedulerOutput,
|
||||
sending_KV_req_ids: set[str],
|
||||
waiting_KV_req_ids: set[str]) -> KVConnectorMetadata:
|
||||
"""
|
||||
Build the connector metadata for this step.
|
||||
|
||||
@ -205,5 +207,7 @@ class KVConnectorBase_V1(ABC):
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
sending_KV_req_ids (set[str]): Request IDs to send
|
||||
waiting_KV_req_ids (set[str]): Request IDs to receive
|
||||
"""
|
||||
pass
|
||||
|
@ -271,9 +271,9 @@ class SharedStorageConnector(KVConnectorBase_V1):
|
||||
self._requests_need_load[request.request_id] = request
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
self, scheduler_output: SchedulerOutput,
|
||||
sending_KV_req_ids: set[str],
|
||||
waiting_KV_req_ids: set[str]) -> KVConnectorMetadata:
|
||||
"""Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify any fields in the scheduler_output.
|
||||
@ -281,6 +281,8 @@ class SharedStorageConnector(KVConnectorBase_V1):
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
sending_KV_req_ids (set[str]): Request IDs to send
|
||||
waiting_KV_req_ids (set[str]): Request IDs to receive
|
||||
"""
|
||||
meta = SharedStorageConnectorMetadata()
|
||||
|
||||
|
@ -98,6 +98,10 @@ class Scheduler(SchedulerInterface):
|
||||
# This is flushed at the end of each scheduling step.
|
||||
self.finished_req_ids: set[str] = set()
|
||||
|
||||
# Requests in states for tracking KV transfers for P/D disagg
|
||||
self.sending_KV_req_ids: set[str] = set()
|
||||
self.waiting_KV_req_ids: set[str] = set()
|
||||
|
||||
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
|
||||
# them at each scheduling step.
|
||||
# Request id -> CachedRequestData
|
||||
@ -167,6 +171,21 @@ class Scheduler(SchedulerInterface):
|
||||
# For logging.
|
||||
scheduled_timestamp = time.monotonic()
|
||||
|
||||
# Check for new remote decode requests for P/D
|
||||
if self.connector is not None:
|
||||
self.waiting_KV_req_ids.update(
|
||||
self.connector.receive_remote_decode_requests())
|
||||
|
||||
# Check if any P/D requests have finished sending or receiving
|
||||
for req_id in list(self.sending_KV_req_ids):
|
||||
if self.connector.done_sending_remote_decode_request(req_id):
|
||||
self.sending_KV_req_ids.remove(req_id)
|
||||
self.finished_req_ids.add(req_id)
|
||||
for req_id in list(self.waiting_KV_req_ids):
|
||||
if self.connector.done_waiting_remote_decode_request(req_id):
|
||||
self.waiting_KV_req_ids.remove(req_id)
|
||||
self.waiting.append(self.requests[req_id])
|
||||
|
||||
# First, schedule the RUNNING requests.
|
||||
req_index = 0
|
||||
while req_index < len(self.running) and token_budget > 0:
|
||||
@ -479,7 +498,9 @@ class Scheduler(SchedulerInterface):
|
||||
# 2. Wrap up all the KV cache load / save ops into an opaque object
|
||||
# 3. Clear the internal states of the connector
|
||||
if self.connector is not None:
|
||||
meta = self.connector.build_connector_meta(scheduler_output)
|
||||
meta = self.connector.build_connector_meta(scheduler_output,
|
||||
self.sending_KV_req_ids,
|
||||
self.waiting_KV_req_ids)
|
||||
scheduler_output.kv_connector_metadata = meta
|
||||
|
||||
# Advance the number of computed tokens for the request AFTER
|
||||
@ -682,6 +703,7 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
# Check for stop and update request state.
|
||||
# This must be called before we make the EngineCoreOutput.
|
||||
# TODO: What if we detect we're done here when doing P/D disagg?
|
||||
stopped = check_stop(request, self.max_model_len)
|
||||
if stopped:
|
||||
self._free_request(request)
|
||||
@ -718,6 +740,13 @@ class Scheduler(SchedulerInterface):
|
||||
# Invariant: EngineCore returns no partial prefill outputs.
|
||||
assert not prompt_logprobs_tensors
|
||||
|
||||
if self.connector is not None and request.do_remote_decode:
|
||||
stopped = True
|
||||
|
||||
self.sending_KV_req_ids.add(req_id)
|
||||
self.connector.send_remote_decode_request(
|
||||
self.kv_cache_manager.req_to_blocks[req_id])
|
||||
|
||||
self.scheduled_req_ids.remove(req_id)
|
||||
if not stopped:
|
||||
new_running.append(request)
|
||||
|
@ -61,6 +61,9 @@ class Request:
|
||||
self.num_encoder_inputs = len(self.mm_inputs)
|
||||
self.has_encoder_inputs = self.num_encoder_inputs > 0
|
||||
|
||||
# P/D disagg related
|
||||
self.do_remote_decode = False
|
||||
|
||||
# Sanity check
|
||||
assert len(self.mm_inputs) == len(self.mm_positions)
|
||||
if self.mm_hashes:
|
||||
|
@ -991,6 +991,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
) -> Union[ModelRunnerOutput, torch.Tensor]:
|
||||
# Update KVConnector with the KVConnector metadata forward().
|
||||
if has_kv_transfer_group():
|
||||
# Background KV cache transfers can happen here,
|
||||
# since kv_connector_metadata has the req_ids to send/receive.
|
||||
# Not sure I like doing it here since this does not have to do
|
||||
# with model execution but this way we don't do a separate rpc.
|
||||
get_kv_transfer_group().bind_connector_metadata(
|
||||
scheduler_output.kv_connector_metadata)
|
||||
|
||||
|
Reference in New Issue
Block a user