Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 23:03:52 +08:00

Compare commits: v0.9.1rc2...pd_schedul (1 commit, 161010c384)

The commit threads P/D (prefill/decode) disaggregation state through the v1 engine: the Scheduler gains `sending_KV_req_ids` / `waiting_KV_req_ids` tracking sets, `Request` gains a `do_remote_decode` flag, and `KVConnectorBase_V1.build_connector_meta` is widened so the two sets reach the worker through the connector metadata.
```diff
@@ -196,7 +196,9 @@ class KVConnectorBase_V1(ABC):
 
     @abstractmethod
     def build_connector_meta(
-            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
+            self, scheduler_output: SchedulerOutput,
+            sending_KV_req_ids: set[str],
+            waiting_KV_req_ids: set[str]) -> KVConnectorMetadata:
         """
         Build the connector metadata for this step.
 
@@ -205,5 +207,7 @@ class KVConnectorBase_V1(ABC):
 
         Args:
             scheduler_output (SchedulerOutput): the scheduler output object.
+            sending_KV_req_ids (set[str]): Request IDs to send
+            waiting_KV_req_ids (set[str]): Request IDs to receive
         """
         pass
```
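With the abstract signature widened, every connector now receives the scheduler's two P/D sets on each step. A minimal sketch of a connector written against the new signature; `SchedulerOutput` and `KVConnectorMetadata` are stubbed locally so the snippet runs standalone, and `ReqIdConnectorMetadata` is a hypothetical carrier, not a class from the branch:

```python
# Minimal sketch of a connector written against the new signature.
# SchedulerOutput / KVConnectorMetadata are local stand-ins for the
# vLLM classes of the same name; ReqIdConnectorMetadata is hypothetical.
from dataclasses import dataclass, field


class SchedulerOutput:
    pass


class KVConnectorMetadata:
    pass


@dataclass
class ReqIdConnectorMetadata(KVConnectorMetadata):
    # Carries the scheduler's P/D bookkeeping to the worker side.
    sending_KV_req_ids: set[str] = field(default_factory=set)
    waiting_KV_req_ids: set[str] = field(default_factory=set)


class ReqIdConnector:
    def build_connector_meta(
            self, scheduler_output: SchedulerOutput,
            sending_KV_req_ids: set[str],
            waiting_KV_req_ids: set[str]) -> KVConnectorMetadata:
        # Copy the sets: the scheduler mutates them between steps, while
        # the metadata must stay frozen for the step it was built for.
        return ReqIdConnectorMetadata(set(sending_KV_req_ids),
                                      set(waiting_KV_req_ids))


meta = ReqIdConnector().build_connector_meta(SchedulerOutput(),
                                             {"req-0"}, {"req-1"})
assert meta.sending_KV_req_ids == {"req-0"}
```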
```diff
@@ -271,9 +271,9 @@ class SharedStorageConnector(KVConnectorBase_V1):
             self._requests_need_load[request.request_id] = request
 
     def build_connector_meta(
-        self,
-        scheduler_output: SchedulerOutput,
-    ) -> KVConnectorMetadata:
+            self, scheduler_output: SchedulerOutput,
+            sending_KV_req_ids: set[str],
+            waiting_KV_req_ids: set[str]) -> KVConnectorMetadata:
         """Build the connector metadata for this step.
 
         This function should NOT modify any fields in the scheduler_output.
@@ -281,6 +281,8 @@ class SharedStorageConnector(KVConnectorBase_V1):
 
         Args:
             scheduler_output (SchedulerOutput): the scheduler output object.
+            sending_KV_req_ids (set[str]): Request IDs to send
+            waiting_KV_req_ids (set[str]): Request IDs to receive
         """
         meta = SharedStorageConnectorMetadata()
 
```
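One consequence worth noting: the wider signature is a breaking change for any out-of-tree `KVConnectorBase_V1` subclass, which is presumably why `SharedStorageConnector` is updated in the same commit. A toy reproduction of the failure mode, using stand-in classes rather than the vLLM ones:

```python
# Why the signature change breaks connector subclasses: an override
# kept at the old arity fails at call time.  Stand-in classes only.
class Base:
    def build_connector_meta(self, scheduler_output, sending_KV_req_ids,
                             waiting_KV_req_ids):
        raise NotImplementedError


class OldStyleConnector(Base):
    def build_connector_meta(self, scheduler_output):  # pre-change arity
        return None


try:
    OldStyleConnector().build_connector_meta("out", set(), set())
except TypeError as e:
    print(f"old override breaks: {e}")
```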
```diff
@@ -98,6 +98,10 @@ class Scheduler(SchedulerInterface):
         # This is flushed at the end of each scheduling step.
         self.finished_req_ids: set[str] = set()
 
+        # Requests in states for tracking KV transfers for P/D disagg
+        self.sending_KV_req_ids: set[str] = set()
+        self.waiting_KV_req_ids: set[str] = set()
+
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
         # Request id -> CachedRequestData
```
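Read together with the scheduling changes below, the two sets give a P/D request a small lifecycle: on the prefill side RUNNING → `sending_KV_req_ids` → `finished_req_ids`, and on the decode side `waiting_KV_req_ids` → `waiting`. A toy model that only names those transitions; the real Scheduler does all of this inline:

```python
# Toy model of the request lifecycle these two sets introduce; the real
# Scheduler does this inline in schedule()/update_from_output().
class PDLifecycle:
    def __init__(self) -> None:
        self.sending_KV_req_ids: set[str] = set()
        self.waiting_KV_req_ids: set[str] = set()
        self.finished_req_ids: set[str] = set()
        self.waiting: list[str] = []  # stand-in for the waiting queue

    def start_send(self, req_id: str) -> None:
        # Prefill instance: prompt done, KV blocks now in flight.
        self.sending_KV_req_ids.add(req_id)

    def finish_send(self, req_id: str) -> None:
        # KV delivered to the remote decode instance; done here.
        self.sending_KV_req_ids.remove(req_id)
        self.finished_req_ids.add(req_id)

    def start_receive(self, req_id: str) -> None:
        # Decode instance: announced remotely, KV not yet local.
        self.waiting_KV_req_ids.add(req_id)

    def finish_receive(self, req_id: str) -> None:
        # KV blocks landed locally; request may now be scheduled.
        self.waiting_KV_req_ids.remove(req_id)
        self.waiting.append(req_id)


lc = PDLifecycle()
lc.start_send("req-0")
lc.finish_send("req-0")
assert "req-0" in lc.finished_req_ids
```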
```diff
@@ -167,6 +171,21 @@ class Scheduler(SchedulerInterface):
         # For logging.
         scheduled_timestamp = time.monotonic()
 
+        # Check for new remote decode requests for P/D
+        if self.connector is not None:
+            self.waiting_KV_req_ids.update(
+                self.connector.receive_remote_decode_requests())
+
+            # Check if any P/D requests have finished sending or receiving
+            for req_id in list(self.sending_KV_req_ids):
+                if self.connector.done_sending_remote_decode_request(req_id):
+                    self.sending_KV_req_ids.remove(req_id)
+                    self.finished_req_ids.add(req_id)
+            for req_id in list(self.waiting_KV_req_ids):
+                if self.connector.done_waiting_remote_decode_request(req_id):
+                    self.waiting_KV_req_ids.remove(req_id)
+                    self.waiting.append(self.requests[req_id])
+
         # First, schedule the RUNNING requests.
         req_index = 0
         while req_index < len(self.running) and token_budget > 0:
```
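The loop above leans on three connector methods that exist only on this branch, not in the released `KVConnectorBase_V1` API. Since `schedule()` runs on every step they need to be cheap and non-blocking, and the `list(...)` copies guard against removing from a set while iterating it. A stub of the assumed contract:

```python
# Stub of the polling contract the new scheduler loop assumes.  These
# methods come from this branch, not the released KVConnectorBase_V1
# API; the bodies here are placeholders.
class PollingConnectorStub:
    def receive_remote_decode_requests(self) -> set[str]:
        # Non-blocking drain of request IDs newly announced by remote
        # prefill instances; returns an empty set when idle.
        return set()

    def done_sending_remote_decode_request(self, req_id: str) -> bool:
        # True once req_id's KV blocks have been fully sent.
        return False

    def done_waiting_remote_decode_request(self, req_id: str) -> bool:
        # True once req_id's KV blocks have fully arrived locally.
        return False
```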
```diff
@@ -479,7 +498,9 @@ class Scheduler(SchedulerInterface):
         # 2. Wrap up all the KV cache load / save ops into an opaque object
         # 3. Clear the internal states of the connector
         if self.connector is not None:
-            meta = self.connector.build_connector_meta(scheduler_output)
+            meta = self.connector.build_connector_meta(scheduler_output,
+                                                       self.sending_KV_req_ids,
+                                                       self.waiting_KV_req_ids)
             scheduler_output.kv_connector_metadata = meta
 
         # Advance the number of computed tokens for the request AFTER
```
```diff
@@ -682,6 +703,7 @@ class Scheduler(SchedulerInterface):
 
         # Check for stop and update request state.
         # This must be called before we make the EngineCoreOutput.
+        # TODO: What if we detect we're done here when doing P/D disagg?
         stopped = check_stop(request, self.max_model_len)
         if stopped:
             self._free_request(request)
```
```diff
@@ -718,6 +740,13 @@ class Scheduler(SchedulerInterface):
             # Invariant: EngineCore returns no partial prefill outputs.
             assert not prompt_logprobs_tensors
 
+            if self.connector is not None and request.do_remote_decode:
+                stopped = True
+
+                self.sending_KV_req_ids.add(req_id)
+                self.connector.send_remote_decode_request(
+                    self.kv_cache_manager.req_to_blocks[req_id])
+
         self.scheduled_req_ids.remove(req_id)
         if not stopped:
             new_running.append(request)
```
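On the prefill side, a request that finished its prompt is marked stopped, parked in `sending_KV_req_ids`, and its KV block list from `req_to_blocks` is handed to the connector. Note the call passes only the blocks, not the request ID, so the connector presumably tracks the owning request itself. A sketch of what the (branch-only) send path might do with the blocks; `KVCacheBlock` is a stand-in for the vLLM v1 block type and `_enqueue_transfer` is a hypothetical transport hand-off:

```python
# Sketch of the connector's send path (send_remote_decode_request is a
# branch-only method).  KVCacheBlock is a stand-in; _enqueue_transfer
# is a hypothetical hand-off to the transfer engine.
from dataclasses import dataclass


@dataclass
class KVCacheBlock:
    block_id: int


def _enqueue_transfer(block_ids: list[int]) -> None:
    print(f"queueing KV send for blocks {block_ids}")


def send_remote_decode_request(blocks: list[KVCacheBlock]) -> None:
    # Only physical block IDs are needed to address the paged KV
    # tensors; the worker-side engine reads them out of GPU memory.
    _enqueue_transfer([b.block_id for b in blocks])


send_remote_decode_request([KVCacheBlock(0), KVCacheBlock(7)])
```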
```diff
@@ -61,6 +61,9 @@ class Request:
         self.num_encoder_inputs = len(self.mm_inputs)
         self.has_encoder_inputs = self.num_encoder_inputs > 0
 
+        # P/D disagg related
+        self.do_remote_decode = False
+
         # Sanity check
         assert len(self.mm_inputs) == len(self.mm_positions)
         if self.mm_hashes:
```
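The commit adds the flag but never sets it to True, so the trigger is left open. Purely as illustration, one plausible wiring is a prefill-only deployment flagging every admitted request; `is_prefill_instance` is a hypothetical config bit, not anything in vLLM:

```python
# Purely illustrative: this commit never sets do_remote_decode.  One
# plausible wiring for a prefill-only instance; is_prefill_instance is
# a hypothetical configuration value.
class Request:
    def __init__(self) -> None:
        self.do_remote_decode = False  # mirrors the new field above


def admit(request: Request, is_prefill_instance: bool) -> Request:
    # On a disaggregated prefill instance, every request hands its KV
    # cache to a remote decode instance once the prompt is processed.
    request.do_remote_decode = is_prefill_instance
    return request


req = admit(Request(), is_prefill_instance=True)
assert req.do_remote_decode
```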
```diff
@@ -991,6 +991,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     ) -> Union[ModelRunnerOutput, torch.Tensor]:
         # Update KVConnector with the KVConnector metadata forward().
         if has_kv_transfer_group():
+            # Background KV cache transfers can happen here,
+            # since kv_connector_metadata has the req_ids to send/receive.
+            # Not sure I like doing it here since this does not have to do
+            # with model execution but this way we don't do a separate rpc.
             get_kv_transfer_group().bind_connector_metadata(
                 scheduler_output.kv_connector_metadata)
 
```
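On the worker, binding pairs with a clear at the end of the step: the metadata is only valid for the forward pass whose `SchedulerOutput` produced it. A minimal sketch of that lifecycle (`bind_connector_metadata` and `clear_connector_metadata` do exist on `KVConnectorBase_V1`; the surrounding class is a stand-in):

```python
# Worker-side lifecycle sketch: connector metadata is bound for exactly
# one forward pass.  bind_/clear_connector_metadata mirror the real
# KVConnectorBase_V1 methods; everything else is a stand-in.
from typing import Optional


class KVConnectorMetadata:
    pass


class ConnectorStub:
    def __init__(self) -> None:
        self._metadata: Optional[KVConnectorMetadata] = None

    def bind_connector_metadata(self, metadata: KVConnectorMetadata) -> None:
        # Valid only for the step whose SchedulerOutput produced it.
        self._metadata = metadata

    def clear_connector_metadata(self) -> None:
        self._metadata = None


connector = ConnectorStub()
connector.bind_connector_metadata(KVConnectorMetadata())
# ... forward pass runs; background KV transfers consult the metadata ...
connector.clear_connector_metadata()
```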