[V1] Add VLLM_ALLOW_INSECURE_SERIALIZATION env var (#17490)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Author: Russell Bryant
Date: 2025-05-08 01:34:02 -04:00 (committed by GitHub)
Parent: 998eea4a0e
Commit: 6930a41116
4 changed files with 170 additions and 115 deletions
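
In short: pickle/cloudpickle fallback in the V1 msgpack serializer is now
opt-in through an environment variable rather than an allow_pickle
constructor argument. A minimal sketch of the new usage, mirroring the
updated tests below (the Payload class here is hypothetical; any type the
encoder does not natively support behaves the same way):

    import os

    # Opt in explicitly; the serializer consults this env var.
    os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

    from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder

    class Payload:  # hypothetical type with no native msgpack support
        def __init__(self, value):
            self.value = value

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(Payload)

    # With the env var unset, encode() instead raises:
    #   TypeError: Object of type <class 'Payload'> is not serializable
    decoded = decoder.decode(encoder.encode(Payload("x")))
    assert decoded.value == "x"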

View File

@@ -105,8 +105,9 @@ def test_structured_output(
         max_tokens=1000,
         guided_decoding=GuidedDecodingParams(json=sample_json_schema))
     outputs = llm.generate(prompts=[
-        f"Give an example JSON for an employee profile "
-        f"that fits this schema: {sample_json_schema}"
+        (f"Give an example JSON for an employee profile that fits this "
+         f"schema. Make the response as short as possible. Schema: "
+         f"{sample_json_schema}")
     ] * 2,
                            sampling_params=sampling_params,
                            use_tqdm=True)
@@ -136,7 +137,8 @@ def test_structured_output(
     outputs = llm.generate(
         prompts=("Generate a JSON object with curly braces for a person with "
-                 "name and age fields for John Smith who is 31 years old."),
+                 "name and age fields for John Smith who is 31 years old. "
+                 "Make the response as short as possible."),
         sampling_params=sampling_params,
         use_tqdm=True)
@@ -165,19 +167,20 @@ def test_structured_output(
         with pytest.raises(ValueError,
                            match="The provided JSON schema contains features "
                            "not supported by xgrammar."):
-            llm.generate(prompts=[
-                f"Give an example JSON for an employee profile "
-                f"that fits this schema: {unsupported_json_schema}"
-            ] * 2,
-                         sampling_params=sampling_params,
-                         use_tqdm=True)
+            llm.generate(
+                prompts=[(f"Give an example JSON for an employee profile that "
+                          f"fits this schema: {unsupported_json_schema}. "
+                          f"Make the response as short as possible.")] * 2,
+                sampling_params=sampling_params,
+                use_tqdm=True)
     else:
-        outputs = llm.generate(
-            prompts=("Give an example JSON object for a grade "
-                     "that fits this schema: "
-                     f"{unsupported_json_schema}"),
-            sampling_params=sampling_params,
-            use_tqdm=True)
+        outputs = llm.generate(prompts=(
+            "Give an example JSON object for a grade "
+            "that fits this schema: "
+            f"{unsupported_json_schema}. Make the response as short as "
+            "possible."),
+                               sampling_params=sampling_params,
+                               use_tqdm=True)
     assert outputs is not None
     for output in outputs:
         assert output is not None
@@ -199,8 +202,10 @@ def test_structured_output(
         max_tokens=1000,
         guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
     outputs = llm.generate(
-        prompts=("Generate a sql statement that selects col_1 from "
-                 "table_1 where it is equal to 1"),
+        prompts=(
+            "Generate a sql statement that selects col_1 from "
+            "table_1 where it is equal to 1. Make the response as short as "
+            "possible."),
         sampling_params=sampling_params,
         use_tqdm=True,
     )
@@ -231,8 +236,10 @@ def test_structured_output(
         max_tokens=1000,
         guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
     outputs = llm.generate(
-        prompts=("Generate a sql statement that selects col_1 from "
-                 "table_1 where it is equal to 1"),
+        prompts=(
+            "Generate a sql statement that selects col_1 from "
+            "table_1 where it is equal to 1. Make the response as short as "
+            "possible."),
         sampling_params=sampling_params,
         use_tqdm=True,
     )
@@ -269,8 +276,10 @@ def test_structured_output(
         guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
     with pytest.raises(ValueError, match="Failed to convert the grammar "):
         llm.generate(
-            prompts=("Generate a sql statement that selects col_1 from "
-                     "table_1 where it is equal to 1"),
+            prompts=(
+                "Generate a sql statement that selects col_1 from "
+                "table_1 where it is equal to 1. Make the response as short "
+                "as possible."),
             sampling_params=sampling_params,
             use_tqdm=True,
         )
@@ -284,7 +293,8 @@ def test_structured_output(
         guided_decoding=GuidedDecodingParams(regex=sample_regex))
     outputs = llm.generate(
         prompts=[
-            f"Give an example IPv4 address with this regex: {sample_regex}"
+            (f"Give an example IPv4 address with this regex: {sample_regex}. "
+             f"Make the response as short as possible.")
         ] * 2,
         sampling_params=sampling_params,
         use_tqdm=True,
@@ -309,7 +319,8 @@ def test_structured_output(
         top_p=0.95,
         guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
     outputs = llm.generate(
-        prompts="The best language for type-safe systems programming is ",
+        prompts=("The best language for type-safe systems programming is "
+                 "(Make the response as short as possible.) "),
         sampling_params=sampling_params,
         use_tqdm=True)
     assert outputs is not None
@@ -331,11 +342,12 @@ def test_structured_output(
         temperature=1.0,
         max_tokens=1000,
         guided_decoding=GuidedDecodingParams(json=json_schema))
-    outputs = llm.generate(
-        prompts="Generate a JSON with the brand, model and car_type of"
-        "the most iconic car from the 90's",
-        sampling_params=sampling_params,
-        use_tqdm=True)
+    outputs = llm.generate(prompts=(
+        "Generate a JSON with the brand, model and car_type of the most "
+        "iconic car from the 90's. Make the response as short as "
+        "possible."),
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
     assert outputs is not None
@@ -373,7 +385,8 @@ def test_structured_output(
         guided_decoding=GuidedDecodingParams(json=json_schema))
     outputs = llm.generate(
-        prompts="Generate a description of a frog using 50 characters.",
+        prompts=("Generate a description of a frog using 50 characters. "
+                 "Make the response as short as possible."),
         sampling_params=sampling_params,
         use_tqdm=True)
@@ -452,7 +465,8 @@ Reminder:
 You are a helpful assistant.
 
-Given the previous instructions, what is the weather in New York City?
+Given the previous instructions, what is the weather in New York City? \
+Make the response as short as possible.
 """
 
     # Change this once other backends support structural_tag
@@ -509,9 +523,10 @@ def test_structured_output_auto_mode(
         max_tokens=1000,
         guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
-    prompts = ("Give an example JSON object for a grade "
-               "that fits this schema: "
-               f"{unsupported_json_schema}")
+    prompts = (
+        "Give an example JSON object for a grade "
+        "that fits this schema: "
+        f"{unsupported_json_schema}. Make the response as short as possible.")
     # This would fail with the default of "xgrammar", but in "auto"
     # we will handle fallback automatically.
     outputs = llm.generate(prompts=prompts,
@@ -566,7 +581,8 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
     prompt = (
         "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
         "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
-        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
+        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20. "
+        "Make the response as short as possible."
         "<|im_end|>\n<|im_start|>assistant\n")
 
     def generate_with_backend(backend):

View File

@@ -9,8 +9,8 @@ import pytest
 import torch
 
 from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargs,
-                                    MultiModalKwargsItem,
+                                    MultiModalFieldElem, MultiModalFlatField,
+                                    MultiModalKwargs, MultiModalKwargsItem,
                                     MultiModalSharedField, NestedTensors)
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@@ -36,59 +36,62 @@ class MyType:
     empty_tensor: torch.Tensor
 
 
-def test_encode_decode():
+def test_encode_decode(monkeypatch: pytest.MonkeyPatch):
     """Test encode/decode loop with zero-copy tensors."""
 
-    obj = MyType(
-        tensor1=torch.randint(low=0,
-                              high=100,
-                              size=(1024, ),
-                              dtype=torch.int32),
-        a_string="hello",
-        list_of_tensors=[
-            torch.rand((1, 10), dtype=torch.float32),
-            torch.rand((3, 5, 4000), dtype=torch.float64),
-            torch.tensor(1984),  # test scalar too
-            # Make sure to test bf16 which numpy doesn't support.
-            torch.rand((3, 5, 1000), dtype=torch.bfloat16),
-            torch.tensor([float("-inf"), float("inf")] * 1024,
-                         dtype=torch.bfloat16),
-        ],
-        numpy_array=np.arange(512),
-        unrecognized=UnrecognizedType(33),
-        small_f_contig_tensor=torch.rand(5, 4).t(),
-        large_f_contig_tensor=torch.rand(1024, 4).t(),
-        small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
-        large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
-        empty_tensor=torch.empty(0),
-    )
-
-    encoder = MsgpackEncoder(size_threshold=256)
-    decoder = MsgpackDecoder(MyType)
-
-    encoded = encoder.encode(obj)
-
-    # There should be the main buffer + 4 large tensor buffers
-    # + 1 large numpy array. "large" is <= 512 bytes.
-    # The two small tensors are encoded inline.
-    assert len(encoded) == 8
-
-    decoded: MyType = decoder.decode(encoded)
-
-    assert_equal(decoded, obj)
-
-    # Test encode_into case
-    preallocated = bytearray()
-
-    encoded2 = encoder.encode_into(obj, preallocated)
-
-    assert len(encoded2) == 8
-    assert encoded2[0] is preallocated
-
-    decoded2: MyType = decoder.decode(encoded2)
-
-    assert_equal(decoded2, obj)
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+        encoder = MsgpackEncoder(size_threshold=256)
+        decoder = MsgpackDecoder(MyType)
+
+        obj = MyType(
+            tensor1=torch.randint(low=0,
+                                  high=100,
+                                  size=(1024, ),
+                                  dtype=torch.int32),
+            a_string="hello",
+            list_of_tensors=[
+                torch.rand((1, 10), dtype=torch.float32),
+                torch.rand((3, 5, 4000), dtype=torch.float64),
+                torch.tensor(1984),  # test scalar too
+                # Make sure to test bf16 which numpy doesn't support.
+                torch.rand((3, 5, 1000), dtype=torch.bfloat16),
+                torch.tensor([float("-inf"), float("inf")] * 1024,
+                             dtype=torch.bfloat16),
+            ],
+            numpy_array=np.arange(512),
+            unrecognized=UnrecognizedType(33),
+            small_f_contig_tensor=torch.rand(5, 4).t(),
+            large_f_contig_tensor=torch.rand(1024, 4).t(),
+            small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
+            large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
+            empty_tensor=torch.empty(0),
+        )
+
+        encoded = encoder.encode(obj)
+
+        # There should be the main buffer + 4 large tensor buffers
+        # + 1 large numpy array. "large" is <= 512 bytes.
+        # The two small tensors are encoded inline.
+        assert len(encoded) == 8
+
+        decoded: MyType = decoder.decode(encoded)
+
+        assert_equal(decoded, obj)
+
+        # Test encode_into case
+        preallocated = bytearray()
+
+        encoded2 = encoder.encode_into(obj, preallocated)
+
+        assert len(encoded2) == 8
+        assert encoded2[0] is preallocated
+
+        decoded2: MyType = decoder.decode(encoded2)
+
+        assert_equal(decoded2, obj)
 
 
 class MyRequest(msgspec.Struct):
@@ -122,7 +125,7 @@ def test_multimodal_kwargs():
     total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
 
     # expected total encoding length, should be 44559, +-20 for minor changes
-    assert total_len >= 44539 and total_len <= 44579
+    assert 44539 <= total_len <= 44579
 
     decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
     assert all(nested_equal(d[k], decoded[k]) for k in d)
@@ -135,14 +138,15 @@ def test_multimodal_items_by_modality():
         "video",
         "v0",
         [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
-        MultiModalBatchedField(),
+        MultiModalFlatField(
+            [[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0),
     )
     e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000,
                                                         dtype=torch.int32),
                              MultiModalSharedField(4))
-    e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000,
-                                                        dtype=torch.int32),
-                             MultiModalBatchedField())
+    e4 = MultiModalFieldElem(
+        "image", "i1", torch.zeros(1000, dtype=torch.int32),
+        MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2))
     audio = MultiModalKwargsItem.from_elems([e1])
     video = MultiModalKwargsItem.from_elems([e2])
     image = MultiModalKwargsItem.from_elems([e3, e4])
@@ -161,7 +165,7 @@ def test_multimodal_items_by_modality():
     total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
 
     # expected total encoding length, should be 14255, +-20 for minor changes
-    assert total_len >= 14235 and total_len <= 14275
+    assert 14250 <= total_len <= 14300
 
     decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
@@ -178,8 +182,7 @@ def test_multimodal_items_by_modality():
 def nested_equal(a: NestedTensors, b: NestedTensors):
     if isinstance(a, torch.Tensor):
         return torch.equal(a, b)
-    else:
-        return all(nested_equal(x, y) for x, y in zip(a, b))
+    return all(nested_equal(x, y) for x, y in zip(a, b))
 
 
 def assert_equal(obj1: MyType, obj2: MyType):
@@ -199,11 +202,10 @@ def assert_equal(obj1: MyType, obj2: MyType):
     assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
 
 
-@pytest.mark.parametrize("allow_pickle", [True, False])
-def test_dict_serialization(allow_pickle: bool):
+def test_dict_serialization():
     """Test encoding and decoding of a generic Python object using pickle."""
-    encoder = MsgpackEncoder(allow_pickle=allow_pickle)
-    decoder = MsgpackDecoder(allow_pickle=allow_pickle)
+    encoder = MsgpackEncoder()
+    decoder = MsgpackDecoder()
 
     # Create a sample Python object
     obj = {"key": "value", "number": 42}
@@ -218,11 +220,10 @@ def test_dict_serialization(allow_pickle: bool):
     assert obj == decoded, "Decoded object does not match the original object."
 
 
-@pytest.mark.parametrize("allow_pickle", [True, False])
-def test_tensor_serialization(allow_pickle: bool):
+def test_tensor_serialization():
     """Test encoding and decoding of a torch.Tensor."""
-    encoder = MsgpackEncoder(allow_pickle=allow_pickle)
-    decoder = MsgpackDecoder(torch.Tensor, allow_pickle=allow_pickle)
+    encoder = MsgpackEncoder()
+    decoder = MsgpackDecoder(torch.Tensor)
 
     # Create a sample tensor
     tensor = torch.rand(10, 10)
@@ -238,11 +239,10 @@ def test_tensor_serialization(allow_pickle: bool):
         tensor, decoded), "Decoded tensor does not match the original tensor."
 
 
-@pytest.mark.parametrize("allow_pickle", [True, False])
-def test_numpy_array_serialization(allow_pickle: bool):
+def test_numpy_array_serialization():
     """Test encoding and decoding of a numpy array."""
-    encoder = MsgpackEncoder(allow_pickle=allow_pickle)
-    decoder = MsgpackDecoder(np.ndarray, allow_pickle=allow_pickle)
+    encoder = MsgpackEncoder()
+    decoder = MsgpackDecoder(np.ndarray)
 
     # Create a sample numpy array
     array = np.random.rand(10, 10)
@@ -268,26 +268,31 @@ class CustomClass:
         return isinstance(other, CustomClass) and self.value == other.value
 
 
-def test_custom_class_serialization_allowed_with_pickle():
+def test_custom_class_serialization_allowed_with_pickle(
+        monkeypatch: pytest.MonkeyPatch):
     """Test that serializing a custom class succeeds when allow_pickle=True."""
-    encoder = MsgpackEncoder(allow_pickle=True)
-    decoder = MsgpackDecoder(CustomClass, allow_pickle=True)
-
-    obj = CustomClass("test_value")
-
-    # Encode the custom class
-    encoded = encoder.encode(obj)
-
-    # Decode the custom class
-    decoded = decoder.decode(encoded)
-
-    # Verify the decoded object matches the original
-    assert obj == decoded, "Decoded object does not match the original object."
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+        encoder = MsgpackEncoder()
+        decoder = MsgpackDecoder(CustomClass)
+
+        obj = CustomClass("test_value")
+
+        # Encode the custom class
+        encoded = encoder.encode(obj)
+
+        # Decode the custom class
+        decoded = decoder.decode(encoded)
+
+        # Verify the decoded object matches the original
+        assert obj == decoded, (
+            "Decoded object does not match the original object.")
 
 
 def test_custom_class_serialization_disallowed_without_pickle():
     """Test that serializing a custom class fails when allow_pickle=False."""
-    encoder = MsgpackEncoder(allow_pickle=False)
+    encoder = MsgpackEncoder()
 
     obj = CustomClass("test_value")

View File

@@ -111,6 +111,7 @@ if TYPE_CHECKING:
     VLLM_USE_DEEP_GEMM: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
+    VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
 
 
 def get_default_cache_root():
@@ -736,6 +737,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # limit will actually be zero-copy decoded.
     "VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
     lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),
+
+    # If set, allow insecure serialization using pickle.
+    # This is useful for environments where it is deemed safe to use the
+    # insecure method and it is needed for some reason.
+    "VLLM_ALLOW_INSECURE_SERIALIZATION":
+    lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))),
 }
 
 # end-env-vars-definition
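
For reference, a standalone sketch of the parsing done by the lambda above
(the helper name is illustrative, not vLLM API): unset or "0" disables the
feature, "1" enables it, and a non-integer value such as "true" would raise
ValueError.

    import os

    def allow_insecure_serialization() -> bool:
        # Same chain as the registered lambda: str -> int -> bool.
        return bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0")))

    assert allow_insecure_serialization() is False  # unset -> default "0"
    os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
    assert allow_insecure_serialization() is True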

View File

@@ -14,6 +14,7 @@ import zmq
 from msgspec import msgpack
 
 from vllm import envs
+from vllm.logger import init_logger
 from vllm.multimodal.inputs import (BaseMultiModalField,
                                     MultiModalBatchedField,
                                     MultiModalFieldConfig, MultiModalFieldElem,
@@ -21,6 +22,8 @@ from vllm.multimodal.inputs import (BaseMultiModalField,
                                     MultiModalKwargsItem,
                                     MultiModalSharedField, NestedTensors)
 
+logger = init_logger(__name__)
+
 CUSTOM_TYPE_PICKLE = 1
 CUSTOM_TYPE_CLOUDPICKLE = 2
 CUSTOM_TYPE_RAW_VIEW = 3
@@ -47,9 +50,7 @@ class MsgpackEncoder:
     via dedicated messages. Note that this is a per-tensor limit.
     """
 
-    def __init__(self,
-                 size_threshold: Optional[int] = None,
-                 allow_pickle: bool = True):
+    def __init__(self, size_threshold: Optional[int] = None):
         if size_threshold is None:
             size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
         self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
@@ -58,7 +59,10 @@
         # pass custom data to the hook otherwise.
         self.aux_buffers: Optional[list[bytestr]] = None
         self.size_threshold = size_threshold
-        self.allow_pickle = allow_pickle
+        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
+            logger.warning(
+                "Allowing insecure serialization using pickle due to "
+                "VLLM_ALLOW_INSECURE_SERIALIZATION=1")
 
     def encode(self, obj: Any) -> Sequence[bytestr]:
         try:
@@ -89,6 +93,12 @@
         if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
             return self._encode_ndarray(obj)
 
+        if isinstance(obj, slice):
+            # We are assuming only int-based values will be used here.
+            return tuple(
+                int(v) if v is not None else None
+                for v in (obj.start, obj.stop, obj.step))
+
         if isinstance(obj, MultiModalKwargs):
             mm: MultiModalKwargs = obj
             if not mm.modalities:
@@ -108,7 +118,7 @@
                     for itemlist in mm._items_by_modality.values()
                     for item in itemlist]
 
-        if not self.allow_pickle:
+        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
             raise TypeError(f"Object of type {type(obj)} is not serializable")
 
         if isinstance(obj, FunctionType):
@@ -185,13 +195,16 @@
     not thread-safe when encoding tensors / numpy arrays.
     """
 
-    def __init__(self, t: Optional[Any] = None, allow_pickle: bool = True):
+    def __init__(self, t: Optional[Any] = None):
         args = () if t is None else (t, )
         self.decoder = msgpack.Decoder(*args,
                                        ext_hook=self.ext_hook,
                                        dec_hook=self.dec_hook)
         self.aux_buffers: Sequence[bytestr] = ()
-        self.allow_pickle = allow_pickle
+        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
+            logger.warning(
+                "Allowing insecure deserialization using pickle due to "
+                "VLLM_ALLOW_INSECURE_SERIALIZATION=1")
 
     def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
         if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
@@ -212,6 +225,8 @@
             return self._decode_ndarray(obj)
         if issubclass(t, torch.Tensor):
             return self._decode_tensor(obj)
+        if t is slice:
+            return slice(*obj)
         if issubclass(t, MultiModalKwargs):
             if isinstance(obj, list):
                 return MultiModalKwargs.from_items(
@@ -253,6 +268,12 @@
                     factory_meth_name, *field_args = v["field"]
                     factory_meth = getattr(MultiModalFieldConfig,
                                            factory_meth_name)
+                    # Special case: decode the union "slices" field of
+                    # MultiModalFlatField
+                    if factory_meth_name == "flat":
+                        field_args[0] = self._decode_nested_slices(field_args[0])
+
                     v["field"] = factory_meth(None, *field_args).field
                 elems.append(MultiModalFieldElem(**v))
             decoded_items.append(MultiModalKwargsItem.from_elems(elems))
@@ -269,11 +290,17 @@
             return self._decode_tensor(obj)
         return [self._decode_nested_tensors(x) for x in obj]
 
+    def _decode_nested_slices(self, obj: Any) -> Any:
+        assert isinstance(obj, (list, tuple))
+        if obj and not isinstance(obj[0], (list, tuple)):
+            return slice(*obj)
+        return [self._decode_nested_slices(x) for x in obj]
+
     def ext_hook(self, code: int, data: memoryview) -> Any:
         if code == CUSTOM_TYPE_RAW_VIEW:
             return data
-        if self.allow_pickle:
+        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
             if code == CUSTOM_TYPE_PICKLE:
                 return pickle.loads(data)
             if code == CUSTOM_TYPE_CLOUDPICKLE:
                 return cloudpickle.loads(data)
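
To close, a standalone sketch of the slice handling added above (the helper
names are illustrative, not vLLM API): enc_hook lowers a slice to a
(start, stop, step) tuple, dec_hook rebuilds it with slice(*obj), and the
nested "slices" lists used by MultiModalFlatField are rebuilt recursively by
_decode_nested_slices in the same way.

    def encode_slice(obj: slice) -> tuple:
        # Mirrors enc_hook: assumes int-based (or None) slice fields.
        return tuple(
            int(v) if v is not None else None
            for v in (obj.start, obj.stop, obj.step))

    def decode_slice(obj) -> slice:
        # Mirrors dec_hook: the 3-tuple unpacks straight back into slice().
        return slice(*obj)

    assert decode_slice(encode_slice(slice(1, 2, 3))) == slice(1, 2, 3)
    assert decode_slice(encode_slice(slice(None, 2))) == slice(None, 2, None)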