!239 Fix device_map patch exception under transformers 4.47.1

Merge pull request !239 from 幽若/master-0613
2025-06-13 07:03:41 +00:00
committed by i-robot
parent f0a48b8676
commit c996d59825
4 changed files with 21 additions and 7 deletions

View File

@@ -32,9 +32,12 @@ if is_transformers_available() and is_torch_available():
     from openmind.integrations.transformers.logging import patch_transformers_logging
     from openmind.integrations.transformers.bitsandbytes import patch_bnb
-    from openmind.integrations.transformers.modeling_utils import patch_modeling_utils
-    if version.check_package_version("torch>=2.1.0, <2.1.1"):
+    if version.check_package_version("torch>=2.1.0, <2.1.1") and version.check_package_version(
+        "transformers>=4.51.1, <=4.51.3"
+    ):
+        from openmind.integrations.transformers.modeling_utils import patch_modeling_utils
         patch_modeling_utils()
     patch_transformers_logging()
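For context, the new gate applies the modeling_utils patch only for one specific torch/transformers pairing. check_package_version itself is not part of this diff; a rough stand-in built on importlib.metadata and packaging (an illustration, not openMind's implementation) would behave like this:

# Illustration only: approximates what a pip-style specifier check does.
from importlib.metadata import version as installed_version
from packaging.requirements import Requirement

def satisfies(spec: str) -> bool:
    # e.g. spec = "torch>=2.1.0, <2.1.1" -> name "torch", specifier ">=2.1.0,<2.1.1"
    req = Requirement(spec)
    return installed_version(req.name) in req.specifier

if satisfies("torch>=2.1.0, <2.1.1") and satisfies("transformers>=4.51.1, <=4.51.3"):
    print("environment matches the range the patch targets")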

View File

@@ -32,6 +32,8 @@ from transformers.modeling_utils import (
 )
 from transformers.utils.quantization_config import QuantizationMethod
+from openmind.utils.version import check_package_version
 @torch.no_grad()
 def _load_state_dict_into_meta_model_patch(
@@ -58,7 +60,7 @@ def _load_state_dict_into_meta_model_patch(
     """
     # in npu environment, set the device_map like {"": "npu:0"}, in other case, keep the original {"": 0}
-    if is_npu_available():
+    if is_npu_available() and device_map is not None:
         for k, v in device_map.items():
             if "npu" not in str(device_map.get(k)):
                 device_map[k] = f"npu:{v}"
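Reviewer note: the added `device_map is not None` guard addresses the case where from_pretrained is called without a device_map, so the loop below would otherwise crash on None. A toy illustration of the normalization (values invented for the example):

def normalize_device_map(device_map):
    # No device_map supplied: leave it alone (iterating None raised an AttributeError before).
    if device_map is None:
        return device_map
    # Rewrite bare device ids into explicit NPU device strings.
    for k, v in device_map.items():
        if "npu" not in str(v):
            device_map[k] = f"npu:{v}"
    return device_map

print(normalize_device_map({"": 0}))        # {'': 'npu:0'}
print(normalize_device_map({"": "npu:0"}))  # already normalized, unchanged
print(normalize_device_map(None))           # None, no error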
@@ -175,4 +177,5 @@ def _load_state_dict_into_meta_model_patch(
 def patch_modeling_utils():
-    transformers.modeling_utils._load_state_dict_into_meta_model = _load_state_dict_into_meta_model_patch
+    if check_package_version("transformers>=4.51.1, <=4.51.3"):
+        transformers.modeling_utils._load_state_dict_into_meta_model = _load_state_dict_into_meta_model_patch
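With this change patch_modeling_utils() becomes a no-op outside the pinned transformers range. A quick way to confirm what got patched (assuming both names are importable from openmind.integrations.transformers.modeling_utils, as the first file's import suggests):

import transformers.modeling_utils as mu
from openmind.integrations.transformers.modeling_utils import (
    patch_modeling_utils,
    _load_state_dict_into_meta_model_patch,
)

before = getattr(mu, "_load_state_dict_into_meta_model", None)
patch_modeling_utils()
after = getattr(mu, "_load_state_dict_into_meta_model", None)

# Replaced only on transformers 4.51.1-4.51.3; otherwise the original is untouched.
print(after is _load_state_dict_into_meta_model_patch, after is before)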

View File

@@ -16,9 +16,12 @@
 from typing import Optional, Tuple
 import torch
-from transformers.integrations.sdpa_attention import repeat_kv
 from openmind.utils.import_utils import is_torch_npu_available
+from openmind.utils.version import check_package_version
+if check_package_version("transformers>=4.51.1, <=4.51.3"):
+    from transformers.integrations.sdpa_attention import repeat_kv
 def sdpa_attention_forward(
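repeat_kv is now imported only when the pinned transformers range provides transformers.integrations.sdpa_attention. For readers unfamiliar with it: it broadcasts grouped key/value heads up to the number of query heads. A sketch of the equivalent behaviour (not the transformers source):

import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # hidden_states: (batch, num_key_value_heads, seq_len, head_dim)
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

kv = torch.randn(2, 4, 16, 64)          # 4 KV heads
print(repeat_kv_sketch(kv, 8).shape)    # torch.Size([2, 32, 16, 64])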

View File

@@ -17,8 +17,8 @@ from types import ModuleType
 from transformers.models.qwen2 import modeling_qwen2
 from transformers.models.llama import modeling_llama
 from transformers.models.mistral import modeling_mistral
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface
-from transformers.integrations import sdpa_attention
+from openmind.utils.version import check_package_version
 from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attenions
 from openmind.integrations.transformers.npu_fused_ops import dynamic_module_utils
@@ -35,6 +35,11 @@ def _patch_sdpa_forward():
     The purpose of this patch is to enable the native SDPA forward function of transformers to adapt to the
     SDPA interface of NPU. If not, calling the SDPA interface still runs in eager mode.
     """
+    if not check_package_version("transformers>=4.51.1, <=4.51.3"):
+        return
+    from transformers.integrations import sdpa_attention
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface
     sdpa_attention.sdpa_attention_forward = attenions.sdpa_attention.sdpa_attention_forward
     AttentionInterface._global_mapping["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward
     ALL_ATTENTION_FUNCTIONS["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward
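Once _patch_sdpa_forward() has run on a supported transformers version, the "sdpa" key in both registries resolves to the NPU-aware forward. A small hedged check using only names that appear in this diff:

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from openmind.integrations.transformers.npu_fused_ops import attenions

# True only after the patch has been applied (transformers 4.51.1-4.51.3).
print(ALL_ATTENTION_FUNCTIONS["sdpa"] is attenions.sdpa_attention.sdpa_attention_forward)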