diff --git a/examples/omniinfer/omni/adaptors/vllm/patches/pangu_patch.py b/examples/omniinfer/omni/adaptors/vllm/patches/pangu_patch.py
index 23e09d0d9..95b306952 100644
--- a/examples/omniinfer/omni/adaptors/vllm/patches/pangu_patch.py
+++ b/examples/omniinfer/omni/adaptors/vllm/patches/pangu_patch.py
@@ -11,13 +11,13 @@ def patch_pangu():
         if not hasattr(self.hf_text_config, "model_type"):
             return False
         elif self.hf_text_config.model_type in \
-                ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'pangu_ultra_moe'):
+                ('deepseek_v2', 'deepseek_v3', 'deepseek_v32', 'deepseek_mtp', 'pangu_ultra_moe'):
             return kv_lora_dim is not None
         elif self.hf_text_config.model_type == 'eagle':
             # if the model is an EAGLE module, check for the
             # underlying architecture
             return self.hf_text_config.model.model_type in \
-                ('deepseek_v2', 'deepseek_v3', 'pangu_ultra_moe') \
+                ('deepseek_v2', 'deepseek_v3', 'deepseek_v32', 'pangu_ultra_moe') \
                 and kv_lora_dim is not None
         return False
@@ -60,7 +60,7 @@ def patch_pangu():
     @staticmethod
     def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
-        if hf_config.model_type == "deepseek_v3":
+        if hf_config.model_type in ["deepseek_v3", "deepseek_v32"]:
             hf_config.model_type = "deepseek_mtp"
         if hf_config.model_type == "deepseek_mtp":
             n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
@@ -105,7 +105,7 @@ def patch_pangu():
         # mtp acceleration for more models besides deepseek_v3
         if self.target_model_config and \
             (self.target_model_config.hf_text_config.model_type \
-                == "deepseek_v3" or
+                in ["deepseek_v3", "deepseek_v32"] or
             self.target_model_config.hf_text_config.model_type \
                 == "mimo" or
             self.target_model_config.hf_text_config.model_type \
diff --git a/examples/omniinfer/omni/models/__init__.py b/examples/omniinfer/omni/models/__init__.py
index 2ecbc77f1..3e6ed54df 100644
--- a/examples/omniinfer/omni/models/__init__.py
+++ b/examples/omniinfer/omni/models/__init__.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
 
+from transformers import AutoConfig
 from omni.adaptors.vllm.patches import model_patch
 from vllm import ModelRegistry
 import os
@@ -10,6 +11,10 @@
 if os.getenv("PROFILING_NAMELIST", None):
     print("<<
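
Note: the hunks above all extend the same pattern: every place that dispatched on the "deepseek_v3" model_type now also accepts "deepseek_v32", so V3.2 checkpoints take the MLA and MTP paths. A minimal, self-contained sketch of that dispatch is below; it mirrors the membership checks in the patched is_deepseek_mla logic, but the helper name, the SimpleNamespace configs, and the kv_lora_rank attribute are illustrative assumptions, not the repo's actual objects.

    # Sketch only: reproduces the model_type membership checks from the diff.
    from types import SimpleNamespace

    MLA_TYPES = ('deepseek_v2', 'deepseek_v3', 'deepseek_v32',
                 'deepseek_mtp', 'pangu_ultra_moe')
    EAGLE_TARGET_TYPES = ('deepseek_v2', 'deepseek_v3', 'deepseek_v32',
                          'pangu_ultra_moe')

    def is_mla(cfg) -> bool:
        """True when the config describes a Multi-head Latent Attention model."""
        if not hasattr(cfg, "model_type"):
            return False
        kv_lora_dim = getattr(cfg, "kv_lora_rank", None)  # assumed attr name
        if cfg.model_type in MLA_TYPES:
            return kv_lora_dim is not None
        if cfg.model_type == 'eagle':
            # EAGLE draft modules wrap a target model; check the inner type.
            return cfg.model.model_type in EAGLE_TARGET_TYPES \
                and kv_lora_dim is not None
        return False

    # Both V3 and the newly added V3.2 resolve to MLA when kv_lora_rank is set.
    v32 = SimpleNamespace(model_type="deepseek_v32", kv_lora_rank=512)
    dense = SimpleNamespace(model_type="llama", kv_lora_rank=None)
    assert is_mla(v32) and not is_mla(dense)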