mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[V1] Add VLLM_ALLOW_INSECURE_SERIALIZATION env var (#17490)
Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -105,8 +105,9 @@ def test_structured_output(
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
|
||||
outputs = llm.generate(prompts=[
|
||||
f"Give an example JSON for an employee profile "
|
||||
f"that fits this schema: {sample_json_schema}"
|
||||
(f"Give an example JSON for an employee profile that fits this "
|
||||
f"schema. Make the response as short as possible. Schema: "
|
||||
f"{sample_json_schema}")
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
@ -136,7 +137,8 @@ def test_structured_output(
|
||||
|
||||
outputs = llm.generate(
|
||||
prompts=("Generate a JSON object with curly braces for a person with "
|
||||
"name and age fields for John Smith who is 31 years old."),
|
||||
"name and age fields for John Smith who is 31 years old. "
|
||||
"Make the response as short as possible."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
@ -165,19 +167,20 @@ def test_structured_output(
|
||||
with pytest.raises(ValueError,
|
||||
match="The provided JSON schema contains features "
|
||||
"not supported by xgrammar."):
|
||||
llm.generate(prompts=[
|
||||
f"Give an example JSON for an employee profile "
|
||||
f"that fits this schema: {unsupported_json_schema}"
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
llm.generate(
|
||||
prompts=[(f"Give an example JSON for an employee profile that "
|
||||
f"fits this schema: {unsupported_json_schema}. "
|
||||
f"Make the response as short as possible.")] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
else:
|
||||
outputs = llm.generate(
|
||||
prompts=("Give an example JSON object for a grade "
|
||||
"that fits this schema: "
|
||||
f"{unsupported_json_schema}"),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
outputs = llm.generate(prompts=(
|
||||
"Give an example JSON object for a grade "
|
||||
"that fits this schema: "
|
||||
f"{unsupported_json_schema}. Make the response as short as "
|
||||
"possible."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
@ -199,8 +202,10 @@ def test_structured_output(
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
|
||||
outputs = llm.generate(
|
||||
prompts=("Generate a sql statement that selects col_1 from "
|
||||
"table_1 where it is equal to 1"),
|
||||
prompts=(
|
||||
"Generate a sql statement that selects col_1 from "
|
||||
"table_1 where it is equal to 1. Make the response as short as "
|
||||
"possible."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
@ -231,8 +236,10 @@ def test_structured_output(
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
|
||||
outputs = llm.generate(
|
||||
prompts=("Generate a sql statement that selects col_1 from "
|
||||
"table_1 where it is equal to 1"),
|
||||
prompts=(
|
||||
"Generate a sql statement that selects col_1 from "
|
||||
"table_1 where it is equal to 1. Make the response as short as "
|
||||
"possible."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
@ -269,8 +276,10 @@ def test_structured_output(
|
||||
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
|
||||
with pytest.raises(ValueError, match="Failed to convert the grammar "):
|
||||
llm.generate(
|
||||
prompts=("Generate a sql statement that selects col_1 from "
|
||||
"table_1 where it is equal to 1"),
|
||||
prompts=(
|
||||
"Generate a sql statement that selects col_1 from "
|
||||
"table_1 where it is equal to 1. Make the response as short "
|
||||
"as possible."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
@ -284,7 +293,8 @@ def test_structured_output(
|
||||
guided_decoding=GuidedDecodingParams(regex=sample_regex))
|
||||
outputs = llm.generate(
|
||||
prompts=[
|
||||
f"Give an example IPv4 address with this regex: {sample_regex}"
|
||||
(f"Give an example IPv4 address with this regex: {sample_regex}. "
|
||||
f"Make the response as short as possible.")
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
@ -309,7 +319,8 @@ def test_structured_output(
|
||||
top_p=0.95,
|
||||
guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
|
||||
outputs = llm.generate(
|
||||
prompts="The best language for type-safe systems programming is ",
|
||||
prompts=("The best language for type-safe systems programming is "
|
||||
"(Make the response as short as possible.) "),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
assert outputs is not None
|
||||
@ -331,11 +342,12 @@ def test_structured_output(
|
||||
temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(json=json_schema))
|
||||
outputs = llm.generate(
|
||||
prompts="Generate a JSON with the brand, model and car_type of"
|
||||
"the most iconic car from the 90's",
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
outputs = llm.generate(prompts=(
|
||||
"Generate a JSON with the brand, model and car_type of the most "
|
||||
"iconic car from the 90's. Make the response as short as "
|
||||
"possible."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
@ -373,7 +385,8 @@ def test_structured_output(
|
||||
guided_decoding=GuidedDecodingParams(json=json_schema))
|
||||
|
||||
outputs = llm.generate(
|
||||
prompts="Generate a description of a frog using 50 characters.",
|
||||
prompts=("Generate a description of a frog using 50 characters. "
|
||||
"Make the response as short as possible."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
@ -452,7 +465,8 @@ Reminder:
|
||||
|
||||
You are a helpful assistant.
|
||||
|
||||
Given the previous instructions, what is the weather in New York City?
|
||||
Given the previous instructions, what is the weather in New York City? \
|
||||
Make the response as short as possible.
|
||||
"""
|
||||
|
||||
# Change this once other backends support structural_tag
|
||||
@ -509,9 +523,10 @@ def test_structured_output_auto_mode(
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
|
||||
|
||||
prompts = ("Give an example JSON object for a grade "
|
||||
"that fits this schema: "
|
||||
f"{unsupported_json_schema}")
|
||||
prompts = (
|
||||
"Give an example JSON object for a grade "
|
||||
"that fits this schema: "
|
||||
f"{unsupported_json_schema}. Make the response as short as possible.")
|
||||
# This would fail with the default of "xgrammar", but in "auto"
|
||||
# we will handle fallback automatically.
|
||||
outputs = llm.generate(prompts=prompts,
|
||||
@ -566,7 +581,8 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
|
||||
prompt = (
|
||||
"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
|
||||
"helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
|
||||
"large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
|
||||
"large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20. "
|
||||
"Make the response as short as possible."
|
||||
"<|im_end|>\n<|im_start|>assistant\n")
|
||||
|
||||
def generate_with_backend(backend):
|
||||
|
@ -9,8 +9,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import (MultiModalBatchedField,
|
||||
MultiModalFieldElem, MultiModalKwargs,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalFieldElem, MultiModalFlatField,
|
||||
MultiModalKwargs, MultiModalKwargsItem,
|
||||
MultiModalSharedField, NestedTensors)
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
|
||||
@ -36,59 +36,62 @@ class MyType:
|
||||
empty_tensor: torch.Tensor
|
||||
|
||||
|
||||
def test_encode_decode():
|
||||
def test_encode_decode(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test encode/decode loop with zero-copy tensors."""
|
||||
|
||||
obj = MyType(
|
||||
tensor1=torch.randint(low=0,
|
||||
high=100,
|
||||
size=(1024, ),
|
||||
dtype=torch.int32),
|
||||
a_string="hello",
|
||||
list_of_tensors=[
|
||||
torch.rand((1, 10), dtype=torch.float32),
|
||||
torch.rand((3, 5, 4000), dtype=torch.float64),
|
||||
torch.tensor(1984), # test scalar too
|
||||
# Make sure to test bf16 which numpy doesn't support.
|
||||
torch.rand((3, 5, 1000), dtype=torch.bfloat16),
|
||||
torch.tensor([float("-inf"), float("inf")] * 1024,
|
||||
dtype=torch.bfloat16),
|
||||
],
|
||||
numpy_array=np.arange(512),
|
||||
unrecognized=UnrecognizedType(33),
|
||||
small_f_contig_tensor=torch.rand(5, 4).t(),
|
||||
large_f_contig_tensor=torch.rand(1024, 4).t(),
|
||||
small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
|
||||
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
|
||||
empty_tensor=torch.empty(0),
|
||||
)
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
encoder = MsgpackEncoder(size_threshold=256)
|
||||
decoder = MsgpackDecoder(MyType)
|
||||
obj = MyType(
|
||||
tensor1=torch.randint(low=0,
|
||||
high=100,
|
||||
size=(1024, ),
|
||||
dtype=torch.int32),
|
||||
a_string="hello",
|
||||
list_of_tensors=[
|
||||
torch.rand((1, 10), dtype=torch.float32),
|
||||
torch.rand((3, 5, 4000), dtype=torch.float64),
|
||||
torch.tensor(1984), # test scalar too
|
||||
# Make sure to test bf16 which numpy doesn't support.
|
||||
torch.rand((3, 5, 1000), dtype=torch.bfloat16),
|
||||
torch.tensor([float("-inf"), float("inf")] * 1024,
|
||||
dtype=torch.bfloat16),
|
||||
],
|
||||
numpy_array=np.arange(512),
|
||||
unrecognized=UnrecognizedType(33),
|
||||
small_f_contig_tensor=torch.rand(5, 4).t(),
|
||||
large_f_contig_tensor=torch.rand(1024, 4).t(),
|
||||
small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
|
||||
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
|
||||
empty_tensor=torch.empty(0),
|
||||
)
|
||||
|
||||
encoded = encoder.encode(obj)
|
||||
encoder = MsgpackEncoder(size_threshold=256)
|
||||
decoder = MsgpackDecoder(MyType)
|
||||
|
||||
# There should be the main buffer + 4 large tensor buffers
|
||||
# + 1 large numpy array. "large" is <= 512 bytes.
|
||||
# The two small tensors are encoded inline.
|
||||
assert len(encoded) == 8
|
||||
encoded = encoder.encode(obj)
|
||||
|
||||
decoded: MyType = decoder.decode(encoded)
|
||||
# There should be the main buffer + 4 large tensor buffers
|
||||
# + 1 large numpy array. "large" is <= 512 bytes.
|
||||
# The two small tensors are encoded inline.
|
||||
assert len(encoded) == 8
|
||||
|
||||
assert_equal(decoded, obj)
|
||||
decoded: MyType = decoder.decode(encoded)
|
||||
|
||||
# Test encode_into case
|
||||
assert_equal(decoded, obj)
|
||||
|
||||
preallocated = bytearray()
|
||||
# Test encode_into case
|
||||
|
||||
encoded2 = encoder.encode_into(obj, preallocated)
|
||||
preallocated = bytearray()
|
||||
|
||||
assert len(encoded2) == 8
|
||||
assert encoded2[0] is preallocated
|
||||
encoded2 = encoder.encode_into(obj, preallocated)
|
||||
|
||||
decoded2: MyType = decoder.decode(encoded2)
|
||||
assert len(encoded2) == 8
|
||||
assert encoded2[0] is preallocated
|
||||
|
||||
assert_equal(decoded2, obj)
|
||||
decoded2: MyType = decoder.decode(encoded2)
|
||||
|
||||
assert_equal(decoded2, obj)
|
||||
|
||||
|
||||
class MyRequest(msgspec.Struct):
|
||||
@ -122,7 +125,7 @@ def test_multimodal_kwargs():
|
||||
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
|
||||
|
||||
# expected total encoding length, should be 44559, +-20 for minor changes
|
||||
assert total_len >= 44539 and total_len <= 44579
|
||||
assert 44539 <= total_len <= 44579
|
||||
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
|
||||
assert all(nested_equal(d[k], decoded[k]) for k in d)
|
||||
|
||||
@ -135,14 +138,15 @@ def test_multimodal_items_by_modality():
|
||||
"video",
|
||||
"v0",
|
||||
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
|
||||
MultiModalBatchedField(),
|
||||
MultiModalFlatField(
|
||||
[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0),
|
||||
)
|
||||
e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000,
|
||||
dtype=torch.int32),
|
||||
MultiModalSharedField(4))
|
||||
e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000,
|
||||
dtype=torch.int32),
|
||||
MultiModalBatchedField())
|
||||
e4 = MultiModalFieldElem(
|
||||
"image", "i1", torch.zeros(1000, dtype=torch.int32),
|
||||
MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2))
|
||||
audio = MultiModalKwargsItem.from_elems([e1])
|
||||
video = MultiModalKwargsItem.from_elems([e2])
|
||||
image = MultiModalKwargsItem.from_elems([e3, e4])
|
||||
@ -161,7 +165,7 @@ def test_multimodal_items_by_modality():
|
||||
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
|
||||
|
||||
# expected total encoding length, should be 14255, +-20 for minor changes
|
||||
assert total_len >= 14235 and total_len <= 14275
|
||||
assert 14250 <= total_len <= 14300
|
||||
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
|
||||
|
||||
# check all modalities were recovered and do some basic sanity checks
|
||||
@ -178,8 +182,7 @@ def test_multimodal_items_by_modality():
|
||||
def nested_equal(a: NestedTensors, b: NestedTensors):
|
||||
if isinstance(a, torch.Tensor):
|
||||
return torch.equal(a, b)
|
||||
else:
|
||||
return all(nested_equal(x, y) for x, y in zip(a, b))
|
||||
return all(nested_equal(x, y) for x, y in zip(a, b))
|
||||
|
||||
|
||||
def assert_equal(obj1: MyType, obj2: MyType):
|
||||
@ -199,11 +202,10 @@ def assert_equal(obj1: MyType, obj2: MyType):
|
||||
assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("allow_pickle", [True, False])
|
||||
def test_dict_serialization(allow_pickle: bool):
|
||||
def test_dict_serialization():
|
||||
"""Test encoding and decoding of a generic Python object using pickle."""
|
||||
encoder = MsgpackEncoder(allow_pickle=allow_pickle)
|
||||
decoder = MsgpackDecoder(allow_pickle=allow_pickle)
|
||||
encoder = MsgpackEncoder()
|
||||
decoder = MsgpackDecoder()
|
||||
|
||||
# Create a sample Python object
|
||||
obj = {"key": "value", "number": 42}
|
||||
@ -218,11 +220,10 @@ def test_dict_serialization(allow_pickle: bool):
|
||||
assert obj == decoded, "Decoded object does not match the original object."
|
||||
|
||||
|
||||
@pytest.mark.parametrize("allow_pickle", [True, False])
|
||||
def test_tensor_serialization(allow_pickle: bool):
|
||||
def test_tensor_serialization():
|
||||
"""Test encoding and decoding of a torch.Tensor."""
|
||||
encoder = MsgpackEncoder(allow_pickle=allow_pickle)
|
||||
decoder = MsgpackDecoder(torch.Tensor, allow_pickle=allow_pickle)
|
||||
encoder = MsgpackEncoder()
|
||||
decoder = MsgpackDecoder(torch.Tensor)
|
||||
|
||||
# Create a sample tensor
|
||||
tensor = torch.rand(10, 10)
|
||||
@ -238,11 +239,10 @@ def test_tensor_serialization(allow_pickle: bool):
|
||||
tensor, decoded), "Decoded tensor does not match the original tensor."
|
||||
|
||||
|
||||
@pytest.mark.parametrize("allow_pickle", [True, False])
|
||||
def test_numpy_array_serialization(allow_pickle: bool):
|
||||
def test_numpy_array_serialization():
|
||||
"""Test encoding and decoding of a numpy array."""
|
||||
encoder = MsgpackEncoder(allow_pickle=allow_pickle)
|
||||
decoder = MsgpackDecoder(np.ndarray, allow_pickle=allow_pickle)
|
||||
encoder = MsgpackEncoder()
|
||||
decoder = MsgpackDecoder(np.ndarray)
|
||||
|
||||
# Create a sample numpy array
|
||||
array = np.random.rand(10, 10)
|
||||
@ -268,26 +268,31 @@ class CustomClass:
|
||||
return isinstance(other, CustomClass) and self.value == other.value
|
||||
|
||||
|
||||
def test_custom_class_serialization_allowed_with_pickle():
|
||||
def test_custom_class_serialization_allowed_with_pickle(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that serializing a custom class succeeds when allow_pickle=True."""
|
||||
encoder = MsgpackEncoder(allow_pickle=True)
|
||||
decoder = MsgpackDecoder(CustomClass, allow_pickle=True)
|
||||
|
||||
obj = CustomClass("test_value")
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
encoder = MsgpackEncoder()
|
||||
decoder = MsgpackDecoder(CustomClass)
|
||||
|
||||
# Encode the custom class
|
||||
encoded = encoder.encode(obj)
|
||||
obj = CustomClass("test_value")
|
||||
|
||||
# Decode the custom class
|
||||
decoded = decoder.decode(encoded)
|
||||
# Encode the custom class
|
||||
encoded = encoder.encode(obj)
|
||||
|
||||
# Verify the decoded object matches the original
|
||||
assert obj == decoded, "Decoded object does not match the original object."
|
||||
# Decode the custom class
|
||||
decoded = decoder.decode(encoded)
|
||||
|
||||
# Verify the decoded object matches the original
|
||||
assert obj == decoded, (
|
||||
"Decoded object does not match the original object.")
|
||||
|
||||
|
||||
def test_custom_class_serialization_disallowed_without_pickle():
|
||||
"""Test that serializing a custom class fails when allow_pickle=False."""
|
||||
encoder = MsgpackEncoder(allow_pickle=False)
|
||||
encoder = MsgpackEncoder()
|
||||
|
||||
obj = CustomClass("test_value")
|
||||
|
||||
|
@ -111,6 +111,7 @@ if TYPE_CHECKING:
|
||||
VLLM_USE_DEEP_GEMM: bool = False
|
||||
VLLM_XGRAMMAR_CACHE_MB: int = 0
|
||||
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
|
||||
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -736,6 +737,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# limit will actually be zero-copy decoded.
|
||||
"VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
|
||||
lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),
|
||||
|
||||
# If set to a non-zero integer (e.g. 1), allow insecure serialization and
|
||||
# deserialization using pickle. Unpickling can execute arbitrary code, so
|
||||
# enable this only where all peers are trusted and pickle is actually needed.
|
||||
"VLLM_ALLOW_INSECURE_SERIALIZATION":
|
||||
lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
@ -14,6 +14,7 @@ import zmq
|
||||
from msgspec import msgpack
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal.inputs import (BaseMultiModalField,
|
||||
MultiModalBatchedField,
|
||||
MultiModalFieldConfig, MultiModalFieldElem,
|
||||
@ -21,6 +22,8 @@ from vllm.multimodal.inputs import (BaseMultiModalField,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalSharedField, NestedTensors)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
CUSTOM_TYPE_PICKLE = 1
|
||||
CUSTOM_TYPE_CLOUDPICKLE = 2
|
||||
CUSTOM_TYPE_RAW_VIEW = 3
|
||||
@ -47,9 +50,7 @@ class MsgpackEncoder:
|
||||
via dedicated messages. Note that this is a per-tensor limit.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
size_threshold: Optional[int] = None,
|
||||
allow_pickle: bool = True):
|
||||
def __init__(self, size_threshold: Optional[int] = None):
|
||||
if size_threshold is None:
|
||||
size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
|
||||
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
|
||||
@ -58,7 +59,10 @@ class MsgpackEncoder:
|
||||
# pass custom data to the hook otherwise.
|
||||
self.aux_buffers: Optional[list[bytestr]] = None
|
||||
self.size_threshold = size_threshold
|
||||
self.allow_pickle = allow_pickle
|
||||
if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
logger.warning(
|
||||
"Allowing insecure serialization using pickle due to "
|
||||
"VLLM_ALLOW_INSECURE_SERIALIZATION=1")
|
||||
|
||||
def encode(self, obj: Any) -> Sequence[bytestr]:
|
||||
try:
|
||||
@ -89,6 +93,12 @@ class MsgpackEncoder:
|
||||
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
|
||||
return self._encode_ndarray(obj)
|
||||
|
||||
if isinstance(obj, slice):
|
||||
# We are assuming only int-based values will be used here.
|
||||
return tuple(
|
||||
int(v) if v is not None else None
|
||||
for v in (obj.start, obj.stop, obj.step))
|
||||
|
||||
if isinstance(obj, MultiModalKwargs):
|
||||
mm: MultiModalKwargs = obj
|
||||
if not mm.modalities:
|
||||
@ -108,7 +118,7 @@ class MsgpackEncoder:
|
||||
for itemlist in mm._items_by_modality.values()
|
||||
for item in itemlist]
|
||||
|
||||
if not self.allow_pickle:
|
||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
raise TypeError(f"Object of type {type(obj)} is not serializable")
|
||||
|
||||
if isinstance(obj, FunctionType):
|
||||
@ -185,13 +195,16 @@ class MsgpackDecoder:
|
||||
not thread-safe when encoding tensors / numpy arrays.
|
||||
"""
|
||||
|
||||
def __init__(self, t: Optional[Any] = None, allow_pickle: bool = True):
|
||||
def __init__(self, t: Optional[Any] = None):
|
||||
args = () if t is None else (t, )
|
||||
self.decoder = msgpack.Decoder(*args,
|
||||
ext_hook=self.ext_hook,
|
||||
dec_hook=self.dec_hook)
|
||||
self.aux_buffers: Sequence[bytestr] = ()
|
||||
self.allow_pickle = allow_pickle
|
||||
if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
logger.warning(
|
||||
"Allowing insecure deserialization using pickle due to "
|
||||
"VLLM_ALLOW_INSECURE_SERIALIZATION=1")
|
||||
|
||||
def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
|
||||
if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
|
||||
@ -212,6 +225,8 @@ class MsgpackDecoder:
|
||||
return self._decode_ndarray(obj)
|
||||
if issubclass(t, torch.Tensor):
|
||||
return self._decode_tensor(obj)
|
||||
if t is slice:
|
||||
return slice(*obj)
|
||||
if issubclass(t, MultiModalKwargs):
|
||||
if isinstance(obj, list):
|
||||
return MultiModalKwargs.from_items(
|
||||
@ -253,6 +268,12 @@ class MsgpackDecoder:
|
||||
factory_meth_name, *field_args = v["field"]
|
||||
factory_meth = getattr(MultiModalFieldConfig,
|
||||
factory_meth_name)
|
||||
|
||||
# Special case: decode the union "slices" field of
|
||||
# MultiModalFlatField
|
||||
if factory_meth_name == "flat":
|
||||
field_args[0] = self._decode_nested_slices(field_args[0])
|
||||
|
||||
v["field"] = factory_meth(None, *field_args).field
|
||||
elems.append(MultiModalFieldElem(**v))
|
||||
decoded_items.append(MultiModalKwargsItem.from_elems(elems))
|
||||
@ -269,11 +290,17 @@ class MsgpackDecoder:
|
||||
return self._decode_tensor(obj)
|
||||
return [self._decode_nested_tensors(x) for x in obj]
|
||||
|
||||
def _decode_nested_slices(self, obj: Any) -> Any:
|
||||
assert isinstance(obj, (list, tuple))
|
||||
if obj and not isinstance(obj[0], (list, tuple)):
|
||||
return slice(*obj)
|
||||
return [self._decode_nested_slices(x) for x in obj]
|
||||
|
||||
def ext_hook(self, code: int, data: memoryview) -> Any:
|
||||
if code == CUSTOM_TYPE_RAW_VIEW:
|
||||
return data
|
||||
|
||||
if self.allow_pickle:
|
||||
if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
if code == CUSTOM_TYPE_PICKLE:
|
||||
return pickle.loads(data)
|
||||
if code == CUSTOM_TYPE_CLOUDPICKLE:
|
||||
|
Reference in New Issue
Block a user