[UX] Support nested dicts in hf_overrides (#25727)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2025-10-06 23:19:16 -04:00
committed by GitHub
parent 2111b4643c
commit c6873c4e6d
2 changed files with 88 additions and 1 deletions

View File

@ -292,6 +292,37 @@ def test_rope_customization():
assert longchat_model_config.max_model_len == 4096
def test_nested_hf_overrides():
    """Verify that dict-valued hf_overrides are merged into nested sub-configs."""
    # A single nested override targeting text_config.
    single_cfg = ModelConfig(
        "Qwen/Qwen2-VL-2B-Instruct",
        hf_overrides={"text_config": {"hidden_size": 1024}},
    )
    assert single_cfg.hf_config.text_config.hidden_size == 1024

    # Several overrides at once, spread across two nested sub-configs.
    multi_cfg = ModelConfig(
        "Qwen/Qwen2-VL-2B-Instruct",
        hf_overrides={
            "text_config": {"hidden_size": 2048, "num_attention_heads": 16},
            "vision_config": {"hidden_size": 512},
        },
    )
    assert multi_cfg.hf_config.text_config.hidden_size == 2048
    assert multi_cfg.hf_config.text_config.num_attention_heads == 16
    assert multi_cfg.hf_config.vision_config.hidden_size == 512
@pytest.mark.skipif(
current_platform.is_rocm(), reason="Encoder Decoder models not supported on ROCm."
)

View File

@ -367,6 +367,51 @@ class ModelConfig:
assert_hashable(str_factors)
return hashlib.sha256(str(factors).encode()).hexdigest()
def _update_nested(
self,
target: Union["PretrainedConfig", dict[str, Any]],
updates: dict[str, Any],
) -> None:
"""Recursively updates a config or dict with nested updates."""
for key, value in updates.items():
if isinstance(value, dict):
# Get the nested target
if isinstance(target, dict):
nested_target = target.get(key)
else:
nested_target = getattr(target, key, None)
# If nested target exists and can be updated recursively
if nested_target is not None and (
isinstance(nested_target, dict)
or hasattr(nested_target, "__dict__")
):
self._update_nested(nested_target, value)
continue
# Set the value (base case)
if isinstance(target, dict):
target[key] = value
else:
setattr(target, key, value)
def _apply_dict_overrides(
    self,
    config: "PretrainedConfig",
    overrides: dict[str, Any],
) -> None:
    """Apply dict overrides, handling both nested configs and dict values."""
    from transformers import PretrainedConfig

    for key, value in overrides.items():
        current = getattr(config, key, None)
        if isinstance(current, PretrainedConfig):
            # The attribute is a nested sub-config: merge recursively.
            self._update_nested(current, value)
        else:
            # Otherwise the override is a dict-valued parameter:
            # replace the attribute wholesale.
            setattr(config, key, value)
def __post_init__(
self,
# Multimodal config init vars
@ -419,8 +464,17 @@ class ModelConfig:
if callable(self.hf_overrides):
hf_overrides_kw = {}
hf_overrides_fn = self.hf_overrides
dict_overrides: dict[str, Any] = {}
else:
hf_overrides_kw = self.hf_overrides
# Separate dict overrides from flat ones
# We'll determine how to apply dict overrides after loading the config
hf_overrides_kw = {}
dict_overrides = {}
for key, value in self.hf_overrides.items():
if isinstance(value, dict):
dict_overrides[key] = value
else:
hf_overrides_kw[key] = value
hf_overrides_fn = None
if self.rope_scaling:
@ -478,6 +532,8 @@ class ModelConfig:
)
self.hf_config = hf_config
if dict_overrides:
self._apply_dict_overrides(hf_config, dict_overrides)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.attention_chunk_size = getattr(
self.hf_text_config, "attention_chunk_size", None