[BugFix] Fix multi async save in MultiConnector (#18246)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-05-16 08:13:47 -07:00
committed by GitHub
parent d3d91b6f71
commit 1db4f47f81

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import copy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
@ -21,9 +22,10 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...],
KVConnectorMetadata):
pass
# Combined metadata for all wrapped connectors, plus optional extra async
# save counts.
@dataclass
class MultiKVConnectorMetadata(KVConnectorMetadata):
# One entry per sub-connector. bind_connector_metadata zips this against
# self._connectors, so entries are matched to connectors by position.
metadata: tuple[KVConnectorMetadata, ...]
# Extra async save counts taken from MultiConnector._extra_async_saves
# (presumably keyed by request id -- confirm). Left as None unless
# build_connector_meta has pending counts to send.
extra_async_saves: Optional[dict[str, int]] = None
class MultiConnector(KVConnectorBase_V1):
@ -54,6 +56,7 @@ class MultiConnector(KVConnectorBase_V1):
# Keeps track of *additional* remaining async saves (beyond 1) to be
# finished per request. Not needed for async loads since we only allow
# a single connector to load.
# Propagated from scheduler to worker side via the connector metadata.
self._extra_async_saves: dict[str, int] = {}
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
@ -66,7 +69,10 @@ class MultiConnector(KVConnectorBase_V1):
# Binds the combined metadata: merges any extra async save counts it carries
# into self._extra_async_saves, then forwards each entry to its sub-connector.
def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
for c, cm in zip(self._connectors, connector_metadata):
# Merge the counts carried by the metadata into the local tracking dict
# (dict.update overwrites the value for any key that is already present).
if connector_metadata.extra_async_saves:
self._extra_async_saves.update(
connector_metadata.extra_async_saves)
# Pair each sub-connector with its own metadata entry by position.
for c, cm in zip(self._connectors, connector_metadata.metadata):
c.bind_connector_metadata(cm)
def clear_connector_metadata(self) -> None:
@ -152,8 +158,13 @@ class MultiConnector(KVConnectorBase_V1):
# Builds the combined metadata by calling build_connector_meta(scheduler_output)
# on every sub-connector, and attaches any pending extra async save counts.
def build_connector_meta(
self,
scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
return MultiKVConnectorMetadata(
c.build_connector_meta(scheduler_output) for c in self._connectors)
# Entries are collected in self._connectors order, which is the order
# bind_connector_metadata relies on when it zips them back together.
metadata = MultiKVConnectorMetadata(metadata=tuple(
c.build_connector_meta(scheduler_output)
for c in self._connectors))
# Move the pending counts into the metadata and start a fresh dict, so the
# same counts are handed over once rather than repeated in later metadata.
if self._extra_async_saves:
metadata.extra_async_saves = self._extra_async_saves
self._extra_async_saves = {}
return metadata
def request_finished(
self,