mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[BugFix] Fix multi async save in MultiConnector (#18246)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
@ -21,9 +22,10 @@ if TYPE_CHECKING:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...],
|
||||
KVConnectorMetadata):
|
||||
pass
|
||||
@dataclass
|
||||
class MultiKVConnectorMetadata(KVConnectorMetadata):
|
||||
metadata: tuple[KVConnectorMetadata, ...]
|
||||
extra_async_saves: Optional[dict[str, int]] = None
|
||||
|
||||
|
||||
class MultiConnector(KVConnectorBase_V1):
|
||||
@ -54,6 +56,7 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
# Keeps track of *additional* remaining async saves (beyond 1) to be
|
||||
# finished per request. Not needed for async loads since we only allow
|
||||
# a single connector to load.
|
||||
# Propagated from scheduler to worker side via the connector metadata.
|
||||
self._extra_async_saves: dict[str, int] = {}
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
@ -66,7 +69,10 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
|
||||
for c, cm in zip(self._connectors, connector_metadata):
|
||||
if connector_metadata.extra_async_saves:
|
||||
self._extra_async_saves.update(
|
||||
connector_metadata.extra_async_saves)
|
||||
for c, cm in zip(self._connectors, connector_metadata.metadata):
|
||||
c.bind_connector_metadata(cm)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
@ -152,8 +158,13 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
|
||||
return MultiKVConnectorMetadata(
|
||||
c.build_connector_meta(scheduler_output) for c in self._connectors)
|
||||
metadata = MultiKVConnectorMetadata(metadata=tuple(
|
||||
c.build_connector_meta(scheduler_output)
|
||||
for c in self._connectors))
|
||||
if self._extra_async_saves:
|
||||
metadata.extra_async_saves = self._extra_async_saves
|
||||
self._extra_async_saves = {}
|
||||
return metadata
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
|
Reference in New Issue
Block a user