[BugFix] Fix case where collective_rpc returns None (#22006)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-07-31 13:51:37 +01:00
committed by GitHub
parent 2836dd73f1
commit 5daffe7cf6
2 changed files with 21 additions and 8 deletions

View File

@ -305,10 +305,10 @@ def echo_dc(
return_list: bool = False, return_list: bool = False,
) -> Union[MyDataclass, list[MyDataclass]]: ) -> Union[MyDataclass, list[MyDataclass]]:
print(f"echo dc util function called: {msg}") print(f"echo dc util function called: {msg}")
val = None if msg is None else MyDataclass(msg)
# Return dataclass to verify support for returning custom types # Return dataclass to verify support for returning custom types
# (for which there is special handling to make it work with msgspec). # (for which there is special handling to make it work with msgspec).
return [MyDataclass(msg) for _ in range(3)] if return_list \ return [val for _ in range(3)] if return_list else val
else MyDataclass(msg)
@pytest.mark.asyncio(loop_scope="function") @pytest.mark.asyncio(loop_scope="function")
@ -351,6 +351,15 @@ async def test_engine_core_client_util_method_custom_return(
assert isinstance(result, list) and all( assert isinstance(result, list) and all(
isinstance(r, MyDataclass) and r.message == "testarg2" isinstance(r, MyDataclass) and r.message == "testarg2"
for r in result) for r in result)
# Test returning None and list of Nones
result = await core_client.call_utility_async(
"echo_dc", None, False)
assert result is None
result = await core_client.call_utility_async(
"echo_dc", None, True)
assert isinstance(result, list) and all(r is None for r in result)
finally: finally:
client.shutdown() client.shutdown()

View File

@ -49,7 +49,10 @@ def _log_insecure_serialization_warning():
"VLLM_ALLOW_INSECURE_SERIALIZATION=1") "VLLM_ALLOW_INSECURE_SERIALIZATION=1")
def _typestr(t: type): def _typestr(val: Any) -> Optional[tuple[str, str]]:
if val is None:
return None
t = type(val)
return t.__module__, t.__qualname__ return t.__module__, t.__qualname__
@ -131,14 +134,13 @@ class MsgpackEncoder:
if isinstance(obj, UtilityResult): if isinstance(obj, UtilityResult):
result = obj.result result = obj.result
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION or result is None: if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
return None, result return None, result
# Since utility results are not strongly typed, we also encode # Since utility results are not strongly typed, we also encode
# the type (or a list of types in the case it's a list) to # the type (or a list of types in the case it's a list) to
# help with correct msgspec deserialization. # help with correct msgspec deserialization.
cls = result.__class__ return _typestr(result) if type(result) is not list else [
return _typestr(cls) if cls is not list else [ _typestr(v) for v in result
_typestr(type(v)) for v in result
], result ], result
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
@ -277,7 +279,9 @@ class MsgpackDecoder:
] ]
return UtilityResult(result) return UtilityResult(result)
def _convert_result(self, result_type: Sequence[str], result: Any): def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
if result_type is None:
return result
mod_name, name = result_type mod_name, name = result_type
mod = importlib.import_module(mod_name) mod = importlib.import_module(mod_name)
result_type = getattr(mod, name) result_type = getattr(mod, name)