!239 Fix the device_map patch error under transformers 4.47.1
Merge pull request !239 from 幽若/master-0613
@@ -32,9 +32,12 @@ if is_transformers_available() and is_torch_available():
    from openmind.integrations.transformers.logging import patch_transformers_logging
    from openmind.integrations.transformers.bitsandbytes import patch_bnb
    from openmind.integrations.transformers.modeling_utils import patch_modeling_utils

    if version.check_package_version("torch>=2.1.0, <2.1.1"):
    if version.check_package_version("torch>=2.1.0, <2.1.1") and version.check_package_version(
        "transformers>=4.51.1, <=4.51.3"
    ):
        from openmind.integrations.transformers.modeling_utils import patch_modeling_utils

        patch_modeling_utils()

    patch_transformers_logging()
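For readers skimming the diff: the effect of this hunk is that the meta-model loading patch is no longer applied on every torch 2.1.0 install, only when the transformers version is also in the tested range. A minimal sketch of the new import-time gate, assuming check_package_version takes a pip-style specifier string and returns a bool, as it is used throughout this change (the wrapper name below is invented for illustration):

    from openmind.utils import version

    def _maybe_patch_modeling_utils():
        # Only monkey-patch transformers on the torch/transformers combination the
        # patch was written against; on other releases (e.g. transformers 4.47.1)
        # the stock loading path is left untouched.
        if version.check_package_version("torch>=2.1.0, <2.1.1") and version.check_package_version(
            "transformers>=4.51.1, <=4.51.3"
        ):
            from openmind.integrations.transformers.modeling_utils import patch_modeling_utils

            patch_modeling_utils()

    _maybe_patch_modeling_utils()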
@@ -32,6 +32,8 @@ from transformers.modeling_utils import (
)
from transformers.utils.quantization_config import QuantizationMethod

from openmind.utils.version import check_package_version


@torch.no_grad()
def _load_state_dict_into_meta_model_patch(
@@ -58,7 +60,7 @@ def _load_state_dict_into_meta_model_patch(
    """

    # in npu environment, set the device_map like {"": "npu:0"}, in other case, keep the original {"": 0}
    if is_npu_available():
    if is_npu_available() and device_map is not None:
        for k, v in device_map.items():
            if "npu" not in str(device_map.get(k)):
                device_map[k] = f"npu:{v}"
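To make the guarded behavior concrete, here is a standalone sketch of the device_map normalization shown above. The helper name and the bare boolean parameter are illustrative only; in the real patch the NPU check comes from is_npu_available() and device_map arrives as a loader argument:

    def _normalize_device_map(device_map, npu_available):
        # On an NPU machine, rewrite integer / non-NPU entries such as {"": 0}
        # into explicit NPU devices such as {"": "npu:0"}.
        # The added "device_map is not None" guard means a missing device_map is
        # passed through unchanged instead of being iterated, which is presumably
        # where the reported failure came from.
        if npu_available and device_map is not None:
            for k, v in device_map.items():
                if "npu" not in str(device_map.get(k)):
                    device_map[k] = f"npu:{v}"
        return device_map

    print(_normalize_device_map({"": 0}, npu_available=True))   # {'': 'npu:0'}
    print(_normalize_device_map(None, npu_available=True))      # None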
@@ -175,4 +177,5 @@ def _load_state_dict_into_meta_model_patch(


def patch_modeling_utils():
    transformers.modeling_utils._load_state_dict_into_meta_model = _load_state_dict_into_meta_model_patch
    if check_package_version("transformers>=4.51.1, <=4.51.3"):
        transformers.modeling_utils._load_state_dict_into_meta_model = _load_state_dict_into_meta_model_patch
@@ -16,9 +16,12 @@
from typing import Optional, Tuple

import torch
from transformers.integrations.sdpa_attention import repeat_kv

from openmind.utils.import_utils import is_torch_npu_available
from openmind.utils.version import check_package_version

if check_package_version("transformers>=4.51.1, <=4.51.3"):
    from transformers.integrations.sdpa_attention import repeat_kv


def sdpa_attention_forward(
@@ -17,8 +17,8 @@ from types import ModuleType
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.llama import modeling_llama
from transformers.models.mistral import modeling_mistral
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface
from transformers.integrations import sdpa_attention

from openmind.utils.version import check_package_version
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attenions
from openmind.integrations.transformers.npu_fused_ops import dynamic_module_utils
@@ -35,6 +35,11 @@ def _patch_sdpa_forward():
    The purpose of this patch is to enable the native SDPA forward function of transformers to adapt to the
    SDPA interface of NPU. If not, calling the SDPA interface is still in the eager mode
    """
    if not check_package_version("transformers>=4.51.1, <=4.51.3"):
        return
    from transformers.integrations import sdpa_attention
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface

    sdpa_attention.sdpa_attention_forward = attenions.sdpa_attention.sdpa_attention_forward
    AttentionInterface._global_mapping["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward
    ALL_ATTENTION_FUNCTIONS["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward