[UX] Support nested dicts in hf_overrides (#25727)
Signed-off-by: mgoin <mgoin64@gmail.com>
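In short, a dict-valued entry in hf_overrides is now merged field-by-field into the corresponding nested HF config after it is loaded, instead of being treated as a flat keyword override. A minimal sketch of the user-visible behavior, mirroring the new test below (the ModelConfig import path is an assumption and may differ across vLLM versions):

from vllm.config import ModelConfig  # import path assumed; may vary by version

config = ModelConfig(
    "Qwen/Qwen2-VL-2B-Instruct",
    hf_overrides={
        # Dict values are deep-merged into the matching sub-config after load;
        # non-dict values are still forwarded to the HF config loader as before.
        "text_config": {"hidden_size": 1024},
    },
)
assert config.hf_config.text_config.hidden_size == 1024
# Sibling fields of text_config keep their values from the checkpoint config.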
@@ -292,6 +292,37 @@ def test_rope_customization():
     assert longchat_model_config.max_model_len == 4096
 
 
+def test_nested_hf_overrides():
+    """Test that nested hf_overrides work correctly."""
+    # Test with a model that has text_config
+    model_config = ModelConfig(
+        "Qwen/Qwen2-VL-2B-Instruct",
+        hf_overrides={
+            "text_config": {
+                "hidden_size": 1024,
+            },
+        },
+    )
+    assert model_config.hf_config.text_config.hidden_size == 1024
+
+    # Test with deeply nested overrides
+    model_config = ModelConfig(
+        "Qwen/Qwen2-VL-2B-Instruct",
+        hf_overrides={
+            "text_config": {
+                "hidden_size": 2048,
+                "num_attention_heads": 16,
+            },
+            "vision_config": {
+                "hidden_size": 512,
+            },
+        },
+    )
+    assert model_config.hf_config.text_config.hidden_size == 2048
+    assert model_config.hf_config.text_config.num_attention_heads == 16
+    assert model_config.hf_config.vision_config.hidden_size == 512
+
+
 @pytest.mark.skipif(
     current_platform.is_rocm(), reason="Encoder Decoder models not supported on ROCm."
 )
@@ -367,6 +367,51 @@ class ModelConfig:
         assert_hashable(str_factors)
         return hashlib.sha256(str(factors).encode()).hexdigest()
 
+    def _update_nested(
+        self,
+        target: Union["PretrainedConfig", dict[str, Any]],
+        updates: dict[str, Any],
+    ) -> None:
+        """Recursively updates a config or dict with nested updates."""
+        for key, value in updates.items():
+            if isinstance(value, dict):
+                # Get the nested target
+                if isinstance(target, dict):
+                    nested_target = target.get(key)
+                else:
+                    nested_target = getattr(target, key, None)
+
+                # If nested target exists and can be updated recursively
+                if nested_target is not None and (
+                    isinstance(nested_target, dict)
+                    or hasattr(nested_target, "__dict__")
+                ):
+                    self._update_nested(nested_target, value)
+                    continue
+
+            # Set the value (base case)
+            if isinstance(target, dict):
+                target[key] = value
+            else:
+                setattr(target, key, value)
+
+    def _apply_dict_overrides(
+        self,
+        config: "PretrainedConfig",
+        overrides: dict[str, Any],
+    ) -> None:
+        """Apply dict overrides, handling both nested configs and dict values."""
+        from transformers import PretrainedConfig
+
+        for key, value in overrides.items():
+            attr = getattr(config, key, None)
+            if attr is not None and isinstance(attr, PretrainedConfig):
+                # It's a nested config - recursively update it
+                self._update_nested(attr, value)
+            else:
+                # It's a dict-valued parameter - set it directly
+                setattr(config, key, value)
+
     def __post_init__(
         self,
         # Multimodal config init vars
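A standalone toy sketch (not vLLM code) of the merge rule that _update_nested implements: dict-valued updates recurse into an existing nested target (a sub-config or plain dict), and everything else is written directly via item assignment or setattr. The SimpleNamespace objects and field values below are purely illustrative:

from types import SimpleNamespace
from typing import Any, Union


def update_nested(
    target: Union[SimpleNamespace, dict[str, Any]],
    updates: dict[str, Any],
) -> None:
    """Toy version of the recursive merge: recurse on dict values."""
    for key, value in updates.items():
        if isinstance(value, dict):
            if isinstance(target, dict):
                nested = target.get(key)
            else:
                nested = getattr(target, key, None)
            # Recurse only if something mergeable already exists at this key.
            if nested is not None and (
                isinstance(nested, dict) or hasattr(nested, "__dict__")
            ):
                update_nested(nested, value)
                continue
        # Base case: write the value directly.
        if isinstance(target, dict):
            target[key] = value
        else:
            setattr(target, key, value)


config = SimpleNamespace(
    text_config=SimpleNamespace(hidden_size=1536, num_attention_heads=12),
    rope_scaling={"rope_type": "default"},
)
update_nested(
    config,
    {
        "text_config": {"hidden_size": 2048},
        "rope_scaling": {"rope_type": "dynamic", "factor": 2.0},
    },
)
assert config.text_config.hidden_size == 2048
assert config.text_config.num_attention_heads == 12  # sibling field untouched
assert config.rope_scaling == {"rope_type": "dynamic", "factor": 2.0}

Note that the entry point _apply_dict_overrides above only recurses when the existing attribute is a PretrainedConfig; a dict-valued parameter that is not a nested config falls into the else branch and is set directly rather than merged.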
@@ -419,8 +464,17 @@ class ModelConfig:
         if callable(self.hf_overrides):
             hf_overrides_kw = {}
             hf_overrides_fn = self.hf_overrides
+            dict_overrides: dict[str, Any] = {}
         else:
-            hf_overrides_kw = self.hf_overrides
+            # Separate dict overrides from flat ones
+            # We'll determine how to apply dict overrides after loading the config
+            hf_overrides_kw = {}
+            dict_overrides = {}
+            for key, value in self.hf_overrides.items():
+                if isinstance(value, dict):
+                    dict_overrides[key] = value
+                else:
+                    hf_overrides_kw[key] = value
             hf_overrides_fn = None
 
         if self.rope_scaling:
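A toy walk-through of the split performed in __post_init__ above (key names are illustrative only): dict-valued entries are held back in dict_overrides and applied after the HF config is loaded, while everything else is still forwarded to the config loader as flat keyword overrides.

hf_overrides = {
    "architectures": ["MyModelForCausalLM"],  # flat value (illustrative)
    "text_config": {"hidden_size": 2048},     # nested dict
    "vision_config": {"hidden_size": 512},    # nested dict
}

hf_overrides_kw = {}
dict_overrides = {}
for key, value in hf_overrides.items():
    if isinstance(value, dict):
        dict_overrides[key] = value   # applied after the HF config is loaded
    else:
        hf_overrides_kw[key] = value  # passed to the HF config loader as before

assert hf_overrides_kw == {"architectures": ["MyModelForCausalLM"]}
assert set(dict_overrides) == {"text_config", "vision_config"}

Deferring the dict-valued entries is what makes the later _apply_dict_overrides call possible: only after the HF config is loaded can the code tell whether a key refers to a nested PretrainedConfig or to a plain dict-valued parameter.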
@@ -478,6 +532,8 @@ class ModelConfig:
         )
 
         self.hf_config = hf_config
+        if dict_overrides:
+            self._apply_dict_overrides(hf_config, dict_overrides)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.attention_chunk_size = getattr(
             self.hf_text_config, "attention_chunk_size", None