mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[BugFix]: Properly set engine_id when using multi connector (#19487)
Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: leiyiming <leiyiming@kingsoft.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -76,6 +76,9 @@ class TestSharedStorageConnector(SharedStorageConnector):
|
||||
return attr
|
||||
|
||||
|
||||
# This relies on "fork" multiprocessing method being used.
|
||||
# It's the default but vLLM may fall back to spawn if for example CUDA
|
||||
# is already initialized.
|
||||
KVConnectorFactory.register_connector("TestSharedStorageConnector",
|
||||
TestSharedStorageConnector.__module__,
|
||||
TestSharedStorageConnector.__name__)
|
||||
|
@ -166,8 +166,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._hand_shake_latency = hand_shake_latency
|
||||
|
||||
def _nixl_handshake(self, host: str, port: int,
|
||||
remote_tp_size: int) -> dict[int, str]:
|
||||
def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
|
||||
expected_engine_id: str) -> dict[int, str]:
|
||||
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
|
||||
time.sleep(self._hand_shake_latency)
|
||||
# These should've been done in register_kv_caches(), called by
|
||||
@ -177,6 +177,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
self.num_blocks = 1
|
||||
self.dst_num_blocks[self.engine_id] = self.num_blocks
|
||||
|
||||
assert expected_engine_id == self.REMOTE_ENGINE_ID
|
||||
|
||||
remote_agent_name = self.add_remote_agent(
|
||||
NixlAgentMetadata(
|
||||
engine_id=self.REMOTE_ENGINE_ID,
|
||||
|
@ -47,7 +47,10 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
assert ktcs is not None
|
||||
for ktc in ktcs:
|
||||
temp_config = copy.copy(vllm_config)
|
||||
temp_config.kv_transfer_config = KVTransferConfig(**ktc)
|
||||
engine_id = ktc.get("engine_id",
|
||||
vllm_config.kv_transfer_config.engine_id)
|
||||
temp_config.kv_transfer_config = KVTransferConfig(
|
||||
**ktc, engine_id=engine_id)
|
||||
self._connectors.append(
|
||||
KVConnectorFactory.create_connector_v1(temp_config, role))
|
||||
|
||||
@ -187,7 +190,7 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
async_saves += 1
|
||||
if txfer_params is not None:
|
||||
if kv_txfer_params is not None:
|
||||
#TODO we can probably change this to merge the dicts here,
|
||||
# TODO we can probably change this to merge the dicts here,
|
||||
# checking for key clashes.
|
||||
raise RuntimeError(
|
||||
"Only one connector can produce KV transfer params")
|
||||
|
@ -488,8 +488,13 @@ class NixlConnectorWorker:
|
||||
"Connection listener got unexpected message %s", msg)
|
||||
sock.send_multipart((identity, b"", encoded_data))
|
||||
|
||||
def _nixl_handshake(self, host: str, port: int,
|
||||
remote_tp_size: int) -> dict[int, str]:
|
||||
def _nixl_handshake(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
remote_tp_size: int,
|
||||
expected_engine_id: str,
|
||||
) -> dict[int, str]:
|
||||
"""Do a NIXL handshake with a remote instance."""
|
||||
|
||||
start_time = time.perf_counter()
|
||||
@ -498,26 +503,6 @@ class NixlConnectorWorker:
|
||||
# a hack to keep us moving. We will switch when moving to etcd
|
||||
# or where we have a single ZMQ socket in the scheduler.
|
||||
|
||||
def handshake(path: str, rank: int) -> str:
|
||||
# Send query for the request.
|
||||
with zmq_ctx(zmq.REQ, path) as sock:
|
||||
sock.send(GET_META_MSG)
|
||||
metadata_bytes = sock.recv()
|
||||
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||
metadata = decoder.decode(metadata_bytes)
|
||||
got_metadata_time = time.perf_counter()
|
||||
|
||||
# Register Remote agent.
|
||||
remote_agent_name = self.add_remote_agent(
|
||||
metadata, rank, remote_tp_size)
|
||||
setup_agent_time = time.perf_counter()
|
||||
|
||||
logger.debug("NIXL handshake: get metadata took: %s",
|
||||
got_metadata_time - start_time)
|
||||
logger.debug("NIXL handshake: add agent took: %s",
|
||||
setup_agent_time - got_metadata_time)
|
||||
return remote_agent_name
|
||||
|
||||
# Handshake only with the remote TP rank that current local rank will
|
||||
# pull from. With homogeneous TP it happens to be the same rank_i.
|
||||
tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
|
||||
@ -525,8 +510,32 @@ class NixlConnectorWorker:
|
||||
path = make_zmq_path("tcp", host, port + p_remote_rank)
|
||||
logger.debug("Querying metadata on path: %s at remote rank %s", path,
|
||||
p_remote_rank)
|
||||
|
||||
# Send query for the request.
|
||||
with zmq_ctx(zmq.REQ, path) as sock:
|
||||
sock.send(GET_META_MSG)
|
||||
metadata_bytes = sock.recv()
|
||||
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||
metadata = decoder.decode(metadata_bytes)
|
||||
got_metadata_time = time.perf_counter()
|
||||
logger.debug("NIXL handshake: get metadata took: %s",
|
||||
got_metadata_time - start_time)
|
||||
|
||||
# Ensure engine id matches.
|
||||
if metadata.engine_id != expected_engine_id:
|
||||
raise RuntimeError(f"Remote NIXL agent engine ID mismatch. "
|
||||
f"Expected {expected_engine_id},"
|
||||
f"received {metadata.engine_id}.")
|
||||
|
||||
# Register Remote agent.
|
||||
remote_agent_name = self.add_remote_agent(metadata, p_remote_rank,
|
||||
remote_tp_size)
|
||||
setup_agent_time = time.perf_counter()
|
||||
logger.debug("NIXL handshake: add agent took: %s",
|
||||
setup_agent_time - got_metadata_time)
|
||||
|
||||
# Remote rank -> agent name.
|
||||
return {p_remote_rank: handshake(path, p_remote_rank)}
|
||||
return {p_remote_rank: remote_agent_name}
|
||||
|
||||
def _background_nixl_handshake(self, req_id: str,
|
||||
remote_engine_id: EngineId, meta: ReqMeta):
|
||||
@ -535,7 +544,7 @@ class NixlConnectorWorker:
|
||||
if fut is None:
|
||||
fut = self._handshake_initiation_executor.submit(
|
||||
self._nixl_handshake, meta.remote_host, meta.remote_port,
|
||||
meta.tp_size)
|
||||
meta.tp_size, remote_engine_id)
|
||||
self._handshake_futures[remote_engine_id] = fut
|
||||
|
||||
def done_callback(f: Future[dict[int, str]], eid=remote_engine_id):
|
||||
@ -738,10 +747,10 @@ class NixlConnectorWorker:
|
||||
if remote_tp_rank in self._remote_agents.get(engine_id, {}):
|
||||
return self._remote_agents[engine_id][remote_tp_rank]
|
||||
|
||||
if engine_id in self._tp_size:
|
||||
assert self._tp_size[engine_id] == remote_tp_size
|
||||
else:
|
||||
if engine_id not in self._tp_size:
|
||||
self._tp_size[engine_id] = remote_tp_size
|
||||
else:
|
||||
assert self._tp_size[engine_id] == remote_tp_size
|
||||
# We may eventually enable this after asserting equality in cache
|
||||
# layout and close outputs.
|
||||
assert nixl_agent_meta.attn_backend_name == self.backend_name
|
||||
|
Reference in New Issue
Block a user