mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[BugFix] Fix case where collective_rpc
returns None
(#22006)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -305,10 +305,10 @@ def echo_dc(
|
||||
return_list: bool = False,
|
||||
) -> Union[MyDataclass, list[MyDataclass]]:
|
||||
print(f"echo dc util function called: {msg}")
|
||||
val = None if msg is None else MyDataclass(msg)
|
||||
# Return dataclass to verify support for returning custom types
|
||||
# (for which there is special handling to make it work with msgspec).
|
||||
return [MyDataclass(msg) for _ in range(3)] if return_list \
|
||||
else MyDataclass(msg)
|
||||
return [val for _ in range(3)] if return_list else val
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
@ -351,6 +351,15 @@ async def test_engine_core_client_util_method_custom_return(
|
||||
assert isinstance(result, list) and all(
|
||||
isinstance(r, MyDataclass) and r.message == "testarg2"
|
||||
for r in result)
|
||||
|
||||
# Test returning None and list of Nones
|
||||
result = await core_client.call_utility_async(
|
||||
"echo_dc", None, False)
|
||||
assert result is None
|
||||
result = await core_client.call_utility_async(
|
||||
"echo_dc", None, True)
|
||||
assert isinstance(result, list) and all(r is None for r in result)
|
||||
|
||||
finally:
|
||||
client.shutdown()
|
||||
|
||||
|
@ -49,7 +49,10 @@ def _log_insecure_serialization_warning():
|
||||
"VLLM_ALLOW_INSECURE_SERIALIZATION=1")
|
||||
|
||||
|
||||
def _typestr(t: type):
|
||||
def _typestr(val: Any) -> Optional[tuple[str, str]]:
|
||||
if val is None:
|
||||
return None
|
||||
t = type(val)
|
||||
return t.__module__, t.__qualname__
|
||||
|
||||
|
||||
@ -131,14 +134,13 @@ class MsgpackEncoder:
|
||||
|
||||
if isinstance(obj, UtilityResult):
|
||||
result = obj.result
|
||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION or result is None:
|
||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
return None, result
|
||||
# Since utility results are not strongly typed, we also encode
|
||||
# the type (or a list of types in the case it's a list) to
|
||||
# help with correct msgspec deserialization.
|
||||
cls = result.__class__
|
||||
return _typestr(cls) if cls is not list else [
|
||||
_typestr(type(v)) for v in result
|
||||
return _typestr(result) if type(result) is not list else [
|
||||
_typestr(v) for v in result
|
||||
], result
|
||||
|
||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
@ -277,7 +279,9 @@ class MsgpackDecoder:
|
||||
]
|
||||
return UtilityResult(result)
|
||||
|
||||
def _convert_result(self, result_type: Sequence[str], result: Any):
|
||||
def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
|
||||
if result_type is None:
|
||||
return result
|
||||
mod_name, name = result_type
|
||||
mod = importlib.import_module(mod_name)
|
||||
result_type = getattr(mod, name)
|
||||
|
Reference in New Issue
Block a user