Adds json_count_leaves utility function (#23899)

Signed-off-by: aditchawdhary <aditxy@hotmail.com>
This commit is contained in:
Adit Chawdhary
2025-08-29 17:58:13 +05:30
committed by GitHub
parent 67c14906aa
commit 4f7cde7272
3 changed files with 72 additions and 10 deletions

View File

@ -379,9 +379,9 @@ def test_duplicate_dict_args(caplog_vllm, parser):
def test_supports_kw(callable,kw_name,requires_kw_only,
allow_var_kwargs,is_supported):
assert supports_kw(
callable=callable,
kw_name=kw_name,
requires_kw_only=requires_kw_only,
callable=callable,
kw_name=kw_name,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs
) == is_supported
@ -948,6 +948,36 @@ def test_join_host_port():
assert join_host_port("::1", 5555) == "[::1]:5555"
def test_json_count_leaves():
    """Check json_count_leaves on scalar, empty, flat and nested inputs."""
    from vllm.utils.jsontree import json_count_leaves

    # Each entry pairs an input tree with the leaf count it must produce;
    # entries are checked in the order they are listed.
    cases = [
        # A bare scalar is itself one leaf.
        (42, 1),
        ("hello", 1),
        (None, 1),
        # Containers holding nothing contribute no leaves.
        ([], 0),
        ({}, 0),
        ((), 0),
        # One level deep: every element is a leaf.
        ([1, 2, 3], 3),
        ({"a": 1, "b": 2}, 2),
        ((1, 2, 3), 3),
        # Nested containers: only the innermost values are counted.
        ({"a": 1, "b": {"c": 2, "d": 3}}, 3),
        ([1, [2, 3], 4], 4),
        ({"list": [1, 2], "dict": {"x": 3}, "value": 4}, 4),
    ]

    for tree, expected in cases:
        assert json_count_leaves(tree) == expected
def test_convert_ids_list_to_tokens():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
token_ids = tokenizer.encode("Hello, world!")

View File

@ -10,7 +10,8 @@ from typing_extensions import TypeAlias, override
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache
from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves
from vllm.utils.jsontree import (json_count_leaves, json_map_leaves,
json_reduce_leaves)
from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem,
MultiModalKwargs, MultiModalKwargsItem,
@ -127,11 +128,32 @@ class MultiModalCache:
)
if debug:
logger.debug("Calculated size of %s to be %.2f GiB", type(value),
size / GiB_bytes)
leaf_count = json_count_leaves(value)
logger.debug(
"Calculated size of %s to be %.2f GiB (%d leaves)",
type(value),
size / GiB_bytes,
leaf_count,
)
return size
@classmethod
def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
    """
    Count the leaf elements held by a multi-modal cache value.

    The count is obtained by delegating to `json_count_leaves`, and is
    intended as a rough indicator of how structurally complex a cached
    value is (e.g. when debugging cache behaviour).

    Args:
        value: The multi-modal cache value to inspect.

    Returns:
        The total number of leaves found in the nested structure.
    """
    num_leaves = json_count_leaves(value)
    return num_leaves
@classmethod
def get_lru_cache(
cls,

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Helper functions to work with nested JSON structures."""
from collections.abc import Iterable
from functools import reduce
from typing import Callable, TypeVar, Union, overload
@ -8,8 +9,12 @@ from typing import Callable, TypeVar, Union, overload
_T = TypeVar("_T")
_U = TypeVar("_U")
JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"],
tuple["JSONTree[_T]", ...], _T]
JSONTree = Union[
dict[str, "JSONTree[_T]"],
list["JSONTree[_T]"],
tuple["JSONTree[_T]", ...],
_T,
]
"""A nested JSON structure where the leaves need not be JSON-serializable."""
@ -78,3 +83,8 @@ def json_reduce_leaves(
json_iter_leaves(value),
initial,
)
def json_count_leaves(value: JSONTree[_T]) -> int:
    """Return how many leaves a nested JSON structure contains.

    Each item produced by `json_iter_leaves` counts as one leaf; when
    nothing is produced the result is zero.
    """
    total = 0
    # Consume the leaf iterator lazily; only the tally is kept.
    for _ in json_iter_leaves(value):
        total += 1
    return total