[Bugfix] Fix: add patch_rope_scaling after hf override (#20857)

Signed-off-by: Wang Siyuan <wsy0227@sjtu.edu.cn>
Signed-off-by: Wang Siyuan <sywang0227@gmail.com>
This commit is contained in:
Wang Siyuan
2025-07-13 15:13:25 +08:00
committed by GitHub
parent bd4c1e6fdb
commit 247102f07f
2 changed files with 17 additions and 11 deletions

View File

@ -532,16 +532,12 @@ class ModelConfig:
self.config_format = ConfigFormat(self.config_format)
hf_config = get_config(self.hf_config_path or self.model,
self.trust_remote_code, self.revision,
self.code_revision, self.config_format)
if hf_overrides_kw:
logger.debug("Overriding HF config with %s", hf_overrides_kw)
hf_config.update(hf_overrides_kw)
if hf_overrides_fn:
logger.debug("Overriding HF config with %s", hf_overrides_fn)
hf_config = hf_overrides_fn(hf_config)
self.trust_remote_code,
self.revision,
self.code_revision,
self.config_format,
hf_overrides_kw=hf_overrides_kw,
hf_overrides_fn=hf_overrides_fn)
self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(self.hf_config)
@ -5052,4 +5048,4 @@ class SpeechToTextConfig:
@property
def allow_audio_chunking(self) -> bool:
return self.min_energy_split_window_size is not None
return self.min_energy_split_window_size is not None

View File

@ -305,6 +305,9 @@ def get_config(
revision: Optional[str] = None,
code_revision: Optional[str] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides_kw: Optional[dict[str, Any]] = None,
hf_overrides_fn: Optional[Callable[[PretrainedConfig],
PretrainedConfig]] = None,
**kwargs,
) -> PretrainedConfig:
# Separate model folder from file path for GGUF models
@ -423,6 +426,13 @@ def get_config(
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
config.update({"architectures": [model_type]})
if hf_overrides_kw:
logger.debug("Overriding HF config with %s", hf_overrides_kw)
config.update(hf_overrides_kw)
if hf_overrides_fn:
logger.debug("Overriding HF config with %s", hf_overrides_fn)
config = hf_overrides_fn(config)
patch_rope_scaling(config)
if trust_remote_code: