From 5daffe7cf6db9765bd667d1a2cf5f18843d58fc7 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 31 Jul 2025 13:51:37 +0100 Subject: [PATCH] [BugFix] Fix case where `collective_rpc` returns `None` (#22006) Signed-off-by: Nick Hill --- tests/v1/engine/test_engine_core_client.py | 13 +++++++++++-- vllm/v1/serial_utils.py | 16 ++++++++++------ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index f648c38a63..1329ce5f69 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -305,10 +305,10 @@ def echo_dc( return_list: bool = False, ) -> Union[MyDataclass, list[MyDataclass]]: print(f"echo dc util function called: {msg}") + val = None if msg is None else MyDataclass(msg) # Return dataclass to verify support for returning custom types # (for which there is special handling to make it work with msgspec). - return [MyDataclass(msg) for _ in range(3)] if return_list \ - else MyDataclass(msg) + return [val for _ in range(3)] if return_list else val @pytest.mark.asyncio(loop_scope="function") @@ -351,6 +351,15 @@ async def test_engine_core_client_util_method_custom_return( assert isinstance(result, list) and all( isinstance(r, MyDataclass) and r.message == "testarg2" for r in result) + + # Test returning None and list of Nones + result = await core_client.call_utility_async( + "echo_dc", None, False) + assert result is None + result = await core_client.call_utility_async( + "echo_dc", None, True) + assert isinstance(result, list) and all(r is None for r in result) + finally: client.shutdown() diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 4b6a983252..809a60c196 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -49,7 +49,10 @@ def _log_insecure_serialization_warning(): "VLLM_ALLOW_INSECURE_SERIALIZATION=1") -def _typestr(t: type): +def _typestr(val: Any) -> Optional[tuple[str, str]]: + if val is None: + return None + t = type(val) return t.__module__, t.__qualname__ @@ -131,14 +134,13 @@ class MsgpackEncoder: if isinstance(obj, UtilityResult): result = obj.result - if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION or result is None: + if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: return None, result # Since utility results are not strongly typed, we also encode # the type (or a list of types in the case it's a list) to # help with correct msgspec deserialization. - cls = result.__class__ - return _typestr(cls) if cls is not list else [ - _typestr(type(v)) for v in result + return _typestr(result) if type(result) is not list else [ + _typestr(v) for v in result ], result if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: @@ -277,7 +279,9 @@ class MsgpackDecoder: ] return UtilityResult(result) - def _convert_result(self, result_type: Sequence[str], result: Any): + def _convert_result(self, result_type: Sequence[str], result: Any) -> Any: + if result_type is None: + return result mod_name, name = result_type mod = importlib.import_module(mod_name) result_type = getattr(mod, name)