[BugFix]: Properly set engine_id when using multi connector (#19487)

Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: leiyiming <leiyiming@kingsoft.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Yiming authored 2025-07-10 04:32:44 +08:00, committed by GitHub
parent 332d4cb17b
commit cd587c93ef
4 changed files with 48 additions and 31 deletions

View File

@@ -76,6 +76,9 @@ class TestSharedStorageConnector(SharedStorageConnector):
         return attr


+# This relies on "fork" multiprocessing method being used.
+# It's the default but vLLM may fall back to spawn if for example CUDA
+# is already initialized.
 KVConnectorFactory.register_connector("TestSharedStorageConnector",
                                       TestSharedStorageConnector.__module__,
                                       TestSharedStorageConnector.__name__)

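The comment added above matters because KVConnectorFactory registration lives in the test process's module state; worker processes only see it if they inherit that state by forking. A minimal standalone sketch of the difference (not vLLM code; a Linux host is assumed so the "fork" start method is available):

import multiprocessing as mp

REGISTRY: dict[str, str] = {}  # stand-in for KVConnectorFactory's registry


def _worker() -> None:
    # Under "fork" the child inherits the parent's REGISTRY contents;
    # under "spawn" this module is re-imported and REGISTRY starts empty.
    print("TestSharedStorageConnector" in REGISTRY)


if __name__ == "__main__":
    REGISTRY["TestSharedStorageConnector"] = "tests.test_multi_connector"
    ctx = mp.get_context("fork")  # with "spawn" the child prints False
    p = ctx.Process(target=_worker)
    p.start()
    p.join()  # child prints True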
View File

@@ -166,8 +166,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
         super().__init__(*args, **kwargs)
         self._hand_shake_latency = hand_shake_latency

-    def _nixl_handshake(self, host: str, port: int,
-                        remote_tp_size: int) -> dict[int, str]:
+    def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
+                        expected_engine_id: str) -> dict[int, str]:
         # Mimic slow _nixl_handshake, as well as bypass zmq communication.
         time.sleep(self._hand_shake_latency)
         # These should've been done in register_kv_caches(), called by
@@ -177,6 +177,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
         self.num_blocks = 1
         self.dst_num_blocks[self.engine_id] = self.num_blocks

+        assert expected_engine_id == self.REMOTE_ENGINE_ID
+
         remote_agent_name = self.add_remote_agent(
             NixlAgentMetadata(
                 engine_id=self.REMOTE_ENGINE_ID,

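The new assert is how the fake catches a caller that wires up the wrong engine id. As a generic sketch of the pattern (illustrative names, not the actual vLLM test API): subclass the worker, stub out the network call, and assert on the argument the fix now threads through:

class FakeWorker:
    REMOTE_ENGINE_ID = "remote-engine-0"

    def _handshake(self, host: str, port: int, remote_tp_size: int,
                   expected_engine_id: str) -> dict[int, str]:
        # No zmq traffic: just verify the caller passed the engine id it
        # expects, then return a canned rank -> agent-name mapping.
        assert expected_engine_id == self.REMOTE_ENGINE_ID
        return {0: "remote-agent-0"}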
View File

@@ -47,7 +47,10 @@ class MultiConnector(KVConnectorBase_V1):
         assert ktcs is not None
         for ktc in ktcs:
             temp_config = copy.copy(vllm_config)
-            temp_config.kv_transfer_config = KVTransferConfig(**ktc)
+            engine_id = ktc.get("engine_id",
+                                vllm_config.kv_transfer_config.engine_id)
+            temp_config.kv_transfer_config = KVTransferConfig(
+                **ktc, engine_id=engine_id)
             self._connectors.append(
                 KVConnectorFactory.create_connector_v1(temp_config, role))

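The resolution order introduced here: an explicit "engine_id" in a child connector's config wins, otherwise the child inherits the MultiConnector's own engine id, so every sub-connector handshakes under a consistent id. A minimal sketch of that fallback using plain dicts (not vLLM's config classes):

from typing import Any


def resolve_engine_id(ktc: dict[str, Any], parent_engine_id: str) -> str:
    # Per-connector override wins; otherwise inherit the parent's id.
    return ktc.get("engine_id", parent_engine_id)


assert resolve_engine_id({"kv_connector": "NixlConnector"}, "eng-0") == "eng-0"
assert resolve_engine_id({"engine_id": "eng-1"}, "eng-0") == "eng-1"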
@@ -187,7 +190,7 @@ class MultiConnector(KVConnectorBase_V1):
                 async_saves += 1
             if txfer_params is not None:
                 if kv_txfer_params is not None:
-                    #TODO we can probably change this to merge the dicts here,
+                    # TODO we can probably change this to merge the dicts here,
                     # checking for key clashes.
                     raise RuntimeError(
                         "Only one connector can produce KV transfer params")

View File

@@ -488,8 +488,13 @@ class NixlConnectorWorker:
                     "Connection listener got unexpected message %s", msg)
             sock.send_multipart((identity, b"", encoded_data))

-    def _nixl_handshake(self, host: str, port: int,
-                        remote_tp_size: int) -> dict[int, str]:
+    def _nixl_handshake(
+        self,
+        host: str,
+        port: int,
+        remote_tp_size: int,
+        expected_engine_id: str,
+    ) -> dict[int, str]:
         """Do a NIXL handshake with a remote instance."""

         start_time = time.perf_counter()
@@ -498,26 +503,6 @@ class NixlConnectorWorker:
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.

-        def handshake(path: str, rank: int) -> str:
-            # Send query for the request.
-            with zmq_ctx(zmq.REQ, path) as sock:
-                sock.send(GET_META_MSG)
-                metadata_bytes = sock.recv()
-                decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
-                metadata = decoder.decode(metadata_bytes)
-                got_metadata_time = time.perf_counter()
-
-                # Register Remote agent.
-                remote_agent_name = self.add_remote_agent(
-                    metadata, rank, remote_tp_size)
-                setup_agent_time = time.perf_counter()
-
-                logger.debug("NIXL handshake: get metadata took: %s",
-                             got_metadata_time - start_time)
-                logger.debug("NIXL handshake: add agent took: %s",
-                             setup_agent_time - got_metadata_time)
-                return remote_agent_name
-
         # Handshake only with the remote TP rank that current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
         tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
@@ -525,8 +510,32 @@
         path = make_zmq_path("tcp", host, port + p_remote_rank)
         logger.debug("Querying metadata on path: %s at remote rank %s", path,
                      p_remote_rank)
+
+        # Send query for the request.
+        with zmq_ctx(zmq.REQ, path) as sock:
+            sock.send(GET_META_MSG)
+            metadata_bytes = sock.recv()
+            decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
+            metadata = decoder.decode(metadata_bytes)
+            got_metadata_time = time.perf_counter()
+            logger.debug("NIXL handshake: get metadata took: %s",
+                         got_metadata_time - start_time)
+
+            # Ensure engine id matches.
+            if metadata.engine_id != expected_engine_id:
+                raise RuntimeError(f"Remote NIXL agent engine ID mismatch. "
+                                   f"Expected {expected_engine_id},"
+                                   f"received {metadata.engine_id}.")
+
+            # Register Remote agent.
+            remote_agent_name = self.add_remote_agent(metadata, p_remote_rank,
+                                                      remote_tp_size)
+            setup_agent_time = time.perf_counter()
+            logger.debug("NIXL handshake: add agent took: %s",
+                         setup_agent_time - got_metadata_time)
+
         # Remote rank -> agent name.
-        return {p_remote_rank: handshake(path, p_remote_rank)}
+        return {p_remote_rank: remote_agent_name}

     def _background_nixl_handshake(self, req_id: str,
                                    remote_engine_id: EngineId, meta: ReqMeta):
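The block inlined above is the requester side; the peer it queries is the connection listener seen at the top of this file, which answers GET_META_MSG over a ROUTER socket. A self-contained sketch of such a responder (field names and the message constant are illustrative, not vLLM's NixlAgentMetadata):

import msgspec
import zmq

GET_META_MSG = b"get_meta_msg"  # assumed wire constant, for illustration


class AgentMeta(msgspec.Struct):
    engine_id: str
    agent_name: str


def serve_metadata(port: int, meta: AgentMeta) -> None:
    encoded = msgspec.msgpack.Encoder().encode(meta)
    with zmq.Context() as ctx:
        sock = ctx.socket(zmq.ROUTER)
        sock.bind(f"tcp://*:{port}")
        while True:
            # REQ peers arrive as (identity, empty delimiter, payload).
            identity, _, msg = sock.recv_multipart()
            if msg != GET_META_MSG:
                continue  # mirrors the "unexpected message" log path above
            sock.send_multipart((identity, b"", encoded))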
@@ -535,7 +544,7 @@
         if fut is None:
             fut = self._handshake_initiation_executor.submit(
                 self._nixl_handshake, meta.remote_host, meta.remote_port,
-                meta.tp_size)
+                meta.tp_size, remote_engine_id)
             self._handshake_futures[remote_engine_id] = fut

         def done_callback(f: Future[dict[int, str]], eid=remote_engine_id):
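For context at this call site: handshakes run on a small executor with at most one in-flight future per remote engine id, cleaned up by a done callback. A condensed sketch of that pattern (assumed shapes, not the full worker):

from concurrent.futures import Future, ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)
handshake_futures: dict[str, Future] = {}


def do_handshake(host: str, port: int, tp_size: int,
                 expected_engine_id: str) -> dict[int, str]:
    ...  # query metadata, verify engine id, register remote agent
    return {}


def start_handshake(engine_id: str, host: str, port: int,
                    tp_size: int) -> None:
    fut = handshake_futures.get(engine_id)
    if fut is None:
        # The fix: the expected engine id rides along with the request.
        fut = executor.submit(do_handshake, host, port, tp_size, engine_id)
        handshake_futures[engine_id] = fut

        def _done(f: Future, eid: str = engine_id) -> None:
            handshake_futures.pop(eid, None)

        fut.add_done_callback(_done)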
@@ -738,10 +747,10 @@
         if remote_tp_rank in self._remote_agents.get(engine_id, {}):
             return self._remote_agents[engine_id][remote_tp_rank]
-        if engine_id in self._tp_size:
-            assert self._tp_size[engine_id] == remote_tp_size
-        else:
+        if engine_id not in self._tp_size:
             self._tp_size[engine_id] = remote_tp_size
+        else:
+            assert self._tp_size[engine_id] == remote_tp_size

         # We may eventually enable this after asserting equality in cache
         # layout and close outputs.
         assert nixl_agent_meta.attn_backend_name == self.backend_name