!224 Fused operator backend migration

Merge pull request !224 from 幽若/update-fused
2025-05-30 06:13:41 +00:00
committed by i-robot
parent a82627330b
commit 9b24e1b73a
7 changed files with 349 additions and 91 deletions

View File

@ -313,6 +313,7 @@ def get_model():
patch_config(config)
if config.architectures and config.architectures[0] in SUPPORTED_FUSED_MODELS and args.do_train:
logger.warning_rank0(f"Unsupported model architecture for npu fused options: {config.architectures[0]}")
map_fused_kernel_to_model(
config.architectures[0],
use_npu_fusion_attention=args.use_npu_fusion_attention,

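Taken together with the SUPPORTED_FUSED_MODELS dispatch shown at the end of the sdk.py hunk further down, the change above amounts to the caller-side wiring sketched below. This is a hedged illustration, not the full body of get_model(): only args.use_npu_fusion_attention is visible in this hunk, so the remaining flag names on args are assumptions that follow the same pattern, and the import path assumes both symbols live in the sdk module, as the new test files below suggest.

from openmind.integrations.transformers.npu_fused_ops.sdk import (
    SUPPORTED_FUSED_MODELS,
    map_fused_kernel_to_model,
)

def maybe_map_fused_kernels(config, args):
    # Dispatch only for architectures with a registered fused-kernel mapping,
    # and only when training is requested.
    if config.architectures and config.architectures[0] in SUPPORTED_FUSED_MODELS and args.do_train:
        map_fused_kernel_to_model(
            config.architectures[0],
            use_npu_fusion_attention=args.use_npu_fusion_attention,
            use_fused_rms_norm=args.use_fused_rms_norm,  # assumed flag name
            use_fused_rope=args.use_fused_rope,          # assumed flag name
            use_fused_swiglu=args.use_fused_swiglu,      # assumed flag name
            config=config,
        )

map_fused_kernel_to_model then looks the architecture up in SUPPORTED_FUSED_MODELS and calls the matching apply_fused_kernel_* function with inner=True, as shown in the sdk.py hunk.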
View File

@ -28,14 +28,12 @@ from pathlib import Path
import hashlib
import functools
import torch
import transformers
from transformers.dynamic_module_utils import get_relative_import_files
from transformers.utils.hub import HF_MODULES_CACHE
from transformers import PretrainedConfig
from openmind.utils import logging
from openmind.integrations.transformers.npu_fused_ops.attenions import internlm2
from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
from openmind.integrations.transformers.npu_fused_ops.rope import rope
from openmind.integrations.transformers.npu_fused_ops.swiglu import swiglu
@ -49,7 +47,6 @@ DYNAMIC_MODELS = {}
@dataclasses.dataclass
class Pattern:
attention: str = "ATTENTION_CLASSES"
rmsnorm: str = "RMSNorm"
rope: str = "apply_rotary_pos_emb"
swiglu: str = "MLP"
@ -73,7 +70,6 @@ def register_dynamic_model(model_name: str, /, **kwargs):
register_dynamic_model(
"InternLM2ForCausalLM",
npu_fusion_attention=internlm2,
rms_norm=rms_norm.NpuRMSNorm,
rope=rope.apply_rotary_pos_emb,
swiglu=swiglu.NpuIntern2SwiGlu,
@ -122,23 +118,9 @@ def _raw_get_dynamic_module(
def _dynamic_patch_flash_attention(model_name: str, module: ModuleType, **kwargs):
if model_name not in DYNAMIC_MODELS:
return
if torch.__version__ == "2.1.0":
pattern = re.compile(Pattern.attention)
attention_classes_attr = [attr for attr in dir(module) if pattern.search(attr)]
attention_classes = getattr(module, attention_classes_attr[0])
if DYNAMIC_MODELS[model_name].get("npu_fusion_attention"):
npu_attention_class = type(
"NPUFusionAttention",
(attention_classes["eager"],),
{"forward": DYNAMIC_MODELS[model_name].get("npu_fusion_attention").forward},
)
attention_classes.update({k: npu_attention_class for k in attention_classes})
elif torch.__version__ >= "2.6.0":
config = kwargs.get("config")
setattr(config, "_attn_implementation", "sdpa")
else:
config = kwargs.get("config")
setattr(config, "_attn_implementation", "eager")
setattr(config, "_attn_implementation", "sdpa")
def _dynamic_patch_rms_norm(model_name: str, module: ModuleType):
@ -157,12 +139,6 @@ def _dynamic_patch_rope(model_name, module):
rope_attr = [attr for attr in dir(module) if pattern.search(attr)]
if DYNAMIC_MODELS[model_name].get("rope"):
setattr(module, rope_attr[0], DYNAMIC_MODELS[model_name].get("rope"))
if DYNAMIC_MODELS[model_name].get("npu_fusion_attention"):
setattr(
DYNAMIC_MODELS[model_name].get("npu_fusion_attention"),
rope_attr[0],
DYNAMIC_MODELS[model_name].get("rope"),
)
def _dynamic_patch_swiglu(model_name, module):

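The net effect of this hunk on the dynamic (trust_remote_code) path: _dynamic_patch_flash_attention no longer builds an NPUFusionAttention subclass from the model's ATTENTION_CLASSES or branches on torch.__version__; it simply switches the config to SDPA for registered models. A minimal sketch reconstructed from the kept lines above (the real DYNAMIC_MODELS registry lives in dynamic_module_utils and is populated by register_dynamic_model):

from types import ModuleType

DYNAMIC_MODELS = {}  # stand-in for the registry filled by register_dynamic_model()

def _dynamic_patch_flash_attention(model_name: str, module: ModuleType, **kwargs):
    # Only models registered in DYNAMIC_MODELS are patched.
    if model_name not in DYNAMIC_MODELS:
        return
    # No per-torch-version branching anymore: attention is routed through
    # transformers' SDPA implementation unconditionally.
    config = kwargs.get("config")
    setattr(config, "_attn_implementation", "sdpa")

The rope patch is simplified in the same way: it no longer mirrors apply_rotary_pos_emb onto the removed npu_fusion_attention module.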
View File

@ -13,34 +13,27 @@
import dataclasses
import re
from types import ModuleType
from typing import Dict, Type
import torch
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.llama import modeling_llama
from transformers.models.mistral import modeling_mistral
from openmind.integrations.transformers.npu_fused_ops import attenions, rms_norm, rope, swiglu
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu
from openmind.integrations.transformers.npu_fused_ops import dynamic_module_utils
@dataclasses.dataclass
class Pattern:
attention: str = "ATTENTION_CLASSES"
rmsnorm: str = "RMSNorm"
rope: str = "apply_rotary_pos_emb"
swiglu: str = "MLP"
def _builtin_patch_flash_attention(RAW_ATTENTION_CLASSES: Dict, NEW_ATTENTION_CLASS: Type):
def _builtin_patch_flash_attention(config):
"""
Patch the FA for transformers built-in models. Call this method before model instantiation completes;
once the model has already been instantiated, this method has no effect.
"""
RAW_ATTENTION_CLASSES.update({k: NEW_ATTENTION_CLASS for k in RAW_ATTENTION_CLASSES})
def __builtin_patch_flash_attention_v2(config):
setattr(config, "_attn_implementation", "sdpa")
@ -64,16 +57,8 @@ def _builtin_patch_swiglu(module: ModuleType, class_name: str):
def _apply_fused_kernel_base(module: ModuleType, **kwargs):
if kwargs.get("use_npu_fusion_attention", False):
if torch.__version__ == "2.1.0":
attention = kwargs.get("attention")
pattern = re.compile(Pattern.attention)
attention_classes_attr = [attr for attr in dir(module) if pattern.search(attr)][0]
_builtin_patch_flash_attention(getattr(module, attention_classes_attr), attention)
elif torch.__version__ >= "2.6.0":
config = kwargs.get("config")
__builtin_patch_flash_attention_v2(config)
else:
pass
config = kwargs.get("config")
_builtin_patch_flash_attention(config)
else:
# if the FA fused option is not enabled, enforce eager mode.
config = kwargs.get("config")
@ -96,15 +81,15 @@ def _apply_fused_kernel_base(module: ModuleType, **kwargs):
def apply_fused_kernel_qwen2(**kwargs):
_apply_fused_kernel_base(modeling_qwen2, attention=attenions.qwen2.Qwen2NPUAttention, **kwargs)
_apply_fused_kernel_base(modeling_qwen2, **kwargs)
def apply_fused_kernel_llama(**kwargs):
_apply_fused_kernel_base(modeling_llama, attention=attenions.llama.LlamaNpuFusionAttention, **kwargs)
_apply_fused_kernel_base(modeling_llama, **kwargs)
def apply_fused_kernel_mistral(**kwargs):
_apply_fused_kernel_base(modeling_mistral, attention=attenions.mistral.MistralNpuFlashAttention, **kwargs)
_apply_fused_kernel_base(modeling_mistral, **kwargs)
def apply_fused_kernel_internlm2(**kwargs):
@ -116,7 +101,6 @@ def apply_fused_kernel_internlm2(**kwargs):
if "InternLM2ForCausalLM" not in dynamic_module_utils.DYNAMIC_MODELS:
dynamic_module_utils.register_dynamic_model(
"InternLM2ForCausalLM",
npu_fusion_attention=attenions.internlm2,
rms_norm=rms_norm.rms_norm.NpuRMSNorm,
rope=rope.rope.apply_rotary_pos_emb,
swiglu=swiglu.swiglu.NpuIntern2SwiGlu,
@ -139,7 +123,6 @@ def apply_fused_kernel_internlm3(**kwargs):
if "InternLM3ForCausalLM" not in dynamic_module_utils.DYNAMIC_MODELS:
dynamic_module_utils.register_dynamic_model(
"InternLM3ForCausalLM",
npu_fusion_attention=attenions.internlm3,
rms_norm=rms_norm.rms_norm.NpuRMSNorm,
rope=rope.rope.apply_rotary_pos_emb,
swiglu=swiglu.swiglu.NpuSwiGlu,

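The built-in path in kernel.py follows the same pattern: _apply_fused_kernel_base drops the attention class lookup via the Pattern.attention regex and the per-model NPU attention classes, leaving only a flag flip on the config. A condensed sketch of that branch under an illustrative helper name (_patch_attention_flag is not in the diff); the rms_norm, rope, and swiglu patching around it is unchanged:

def _patch_attention_flag(config, use_npu_fusion_attention: bool = False) -> None:
    if use_npu_fusion_attention:
        # Fused attention now means: let transformers use its SDPA implementation.
        setattr(config, "_attn_implementation", "sdpa")
    else:
        # If the FA fused option is not enabled, enforce eager attention.
        setattr(config, "_attn_implementation", "eager")

This is also why apply_fused_kernel_qwen2/llama/mistral no longer pass an attention= class: the attention decision is carried entirely by the config.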
View File

@ -30,7 +30,6 @@ logger = logging.get_logger()
@lru_cache
def check_use_fused_kernel(inner=False) -> bool:
"""
Args:
inner: When this method is called internally within openmind-cli, this parameter is True,
and the openmind-related args will be verified. When an outside user tries to call `apply_fused_kernel()`
@ -39,13 +38,15 @@ def check_use_fused_kernel(inner=False) -> bool:
Returns: True if the environment and the parameters allow using the fused options, otherwise False
"""
if is_torch_available():
state = PartialState()
device_module = getattr(torch, state.device.type.lower(), None)
if not is_torch_available():
return False
# not in npu environment
if "npu" not in str(device_module):
return False
state = PartialState()
device_module = getattr(torch, state.device.type.lower(), None)
# not in npu environment
if "npu" not in str(device_module):
return False
# torch npu is not available
if not is_torch_npu_available():
@ -62,16 +63,15 @@ def check_use_fused_kernel(inner=False) -> bool:
return False
# the installed versions of transformers and torch are not compatible with npu fused options
if torch.__version__ == "2.1.0" and version.check_package_version("transformers<=4.47.1, >=4.45.0"):
return True
elif torch.__version__ >= "2.6.0" and version.check_package_version("transformers>=4.51.1"):
if version.check_package_version("transformers>=4.51.1, <=4.51.3") and (
version.check_package_version("torch>=2.1.0, <2.1.1") or version.check_package_version("torch>=2.6.0, <2.6.1")
):
return True
else:
logger.warning_rank0(
f"RuntimeWarning: The npu fused options is not available under the transformers v{transformers.__version__} "
f"and the torch v{torch.__version__}. To use npu fused options, if torch version >= 2.6.0, the version of "
f"transformers is required at least v4.51.1; if torch version == 2.1.0, the version of transformers is "
f"required >= v4.45.0, and <= 4.47.1; In other cases, the npu fused options will not be available. "
f"RuntimeWarning: The npu fused options is not available under the transformers v{transformers.__version__}. "
f"To use npu fused options, you need torch == 2.1.0 or 2.6.0, and transformers >= 4.51.1, <=4.51.3 . "
f"In other cases, the npu fused options will not available. "
)
return False
@ -118,6 +118,8 @@ def apply_fused_kernel(**kwargs):
specified fusion operator:
`use_npu_fusion_attention: bool = True`, defaults to True; set it to `False` to disable npu fusion attention.
`use_fused_rms_norm: bool = True`, defaults to True; set it to `False` to disable npu RMSNorm.
`use_fused_rope: bool = True`, defaults to True; set it to `False` to disable npu apply_rotary_pos_emb.
`use_fused_swiglu: bool = True`, defaults to True; set it to `False` to disable npu swiglu.
"""
_apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen2, **kwargs)
_apply_fused_kernel_generic(kernel.apply_fused_kernel_llama, **kwargs)
@ -198,6 +200,5 @@ SUPPORTED_FUSED_MODELS = {
def map_fused_kernel_to_model(architecture, **kwargs):
if architecture not in SUPPORTED_FUSED_MODELS:
logger.warning_rank0(f"Unsupported fused model architecture: {architecture}")
return
SUPPORTED_FUSED_MODELS.get(architecture)(inner=True, **kwargs)

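The compatibility gate in check_use_fused_kernel is also reshaped: instead of pairing torch 2.1.0 with transformers 4.45.0–4.47.1 and torch >= 2.6.0 with transformers >= 4.51.1, both supported torch lines now share one transformers window. Written as a standalone predicate for readability (a sketch based on the hunk above; the import of version from openmind.utils is an assumption, matching how version.check_package_version is used in the file):

from openmind.utils import version  # assumed import path

def _versions_support_fused_ops() -> bool:
    # transformers must fall in the 4.51.1–4.51.3 window, and torch must be on
    # exactly the 2.1.0 or 2.6.0 patch line.
    return version.check_package_version("transformers>=4.51.1, <=4.51.3") and (
        version.check_package_version("torch>=2.1.0, <2.1.1")
        or version.check_package_version("torch>=2.6.0, <2.6.1")
    )

Pinning to exact patch lines keeps the check conservative: a future torch 2.6.1, for example, would fall outside the range and disable the fused path until the bound is widened.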
View File

@ -29,7 +29,6 @@ from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils impor
_dynamic_patch_swiglu,
patch_dynamic_fused_ops,
)
from openmind.integrations.transformers.npu_fused_ops.attenions import internlm2
from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
@ -38,7 +37,6 @@ class TestDynamicModelsRegistration(unittest.TestCase):
"""Test dynamic model registration"""
self.assertIn("InternLM2ForCausalLM", DYNAMIC_MODELS)
model_config = DYNAMIC_MODELS["InternLM2ForCausalLM"]
self.assertIsNotNone(model_config.get("npu_fusion_attention"))
self.assertIsNotNone(model_config.get("rms_norm"))
self.assertIsNotNone(model_config.get("rope"))
self.assertIsNotNone(model_config.get("swiglu"))
@ -86,8 +84,10 @@ class TestDynamicPatching(unittest.TestCase):
mock_module = ModuleType("mock_module")
mock_module.ATTENTION_CLASSES = {"eager": MockAttentionBase}
_dynamic_patch_flash_attention("InternLM2ForCausalLM", mock_module, config=Config())
self.assertIsInstance(mock_module.ATTENTION_CLASSES["eager"].forward, internlm2.forward.__class__)
mock_config = Config()
_dynamic_patch_flash_attention("InternLM2ForCausalLM", mock_module, config=mock_config)
self.assertEqual(mock_config._attn_implementation, "sdpa")
@patch("torch.__version__", "2.6.0")
def test_torch_260_sets_sdpa(self):
@ -102,19 +102,6 @@ class TestDynamicPatching(unittest.TestCase):
self.assertEqual(mock_config._attn_implementation, "sdpa")
@patch("torch.__version__", "2.5.0")
def test_torch_other_sets_sdpa(self):
model_name = "test_model_260"
DYNAMIC_MODELS[model_name] = {}
class Config:
_attn_implementation = "sdpa"
mock_config = Config()
_dynamic_patch_flash_attention(model_name, MagicMock(), config=mock_config)
self.assertEqual(mock_config._attn_implementation, "eager")
@patch("importlib.util.spec_from_file_location")
@patch("importlib.util.module_from_spec")
def test_rms_norm_patching(self, _, __):
@ -137,11 +124,6 @@ class TestDynamicPatching(unittest.TestCase):
DYNAMIC_MODELS["InternLM2ForCausalLM"]["rope"],
)
self.assertEqual(
getattr(DYNAMIC_MODELS["InternLM2ForCausalLM"]["npu_fusion_attention"], "apply_rotary_pos_emb"),
DYNAMIC_MODELS["InternLM2ForCausalLM"]["rope"],
)
@patch("importlib.util.spec_from_file_location")
@patch("importlib.util.module_from_spec")
def test_swiglu_patching(self, _, __):
@ -155,16 +137,11 @@ class TestDynamicPatching(unittest.TestCase):
@patch("importlib.util.spec_from_file_location")
@patch("importlib.util.module_from_spec")
def test_rope_patching_without_attention(self, _, __):
original_attention = DYNAMIC_MODELS["InternLM2ForCausalLM"].pop("npu_fusion_attention")
mock_module = ModuleType("mock_module")
mock_module.apply_rotary_pos_emb = MagicMock()
_dynamic_patch_rope("InternLM2ForCausalLM", mock_module)
DYNAMIC_MODELS["InternLM2ForCausalLM"]["npu_fusion_attention"] = original_attention
self.assertEqual(mock_module.apply_rotary_pos_emb, DYNAMIC_MODELS["InternLM2ForCausalLM"]["rope"])
@patch("importlib.util.spec_from_file_location")

View File

@ -0,0 +1,117 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
# http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
import transformers
from openmind.integrations.transformers.npu_fused_ops import kernel
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu
class TestFusedKernel(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.mock_cache = Path(self.temp_dir.name)
self.patcher = patch(
"openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils.HF_MODULES_CACHE", self.mock_cache
)
self.patcher.start()
self.original_rope = transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb
self.original_rmsnorm = transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
self.original_swiglu = transformers.models.qwen2.modeling_qwen2.Qwen2MLP
def tearDown(self):
self.temp_dir.cleanup()
self.patcher.stop()
transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb = self.original_rope
transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm = self.original_rmsnorm
transformers.models.qwen2.modeling_qwen2.Qwen2MLP = self.original_swiglu
def test_builtin_patch_flash_attention(self):
class Config:
_attn_implementation = "eager"
mock_config = Config()
kernel._builtin_patch_flash_attention(mock_config)
self.assertEqual(mock_config._attn_implementation, "sdpa")
def test_builtin_patch_rmsnorm(self):
kernel._builtin_patch_rmsnorm(transformers.models.qwen2.modeling_qwen2, "Qwen2RMSNorm")
self.assertEqual(rms_norm.rms_norm.NpuRMSNorm, transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm)
transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm = self.original_rmsnorm
def test_builtin_patch_rope(self):
self.assertNotEqual(
rope.rope.apply_rotary_pos_emb, transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb
)
kernel._builtin_patch_rope(transformers.models.qwen2.modeling_qwen2, "apply_rotary_pos_emb")
self.assertEqual(rope.rope.apply_rotary_pos_emb, transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb)
transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb = self.original_rope
def test_builtin_patch_swiglu(self):
self.assertNotEqual(swiglu.swiglu.NpuSwiGlu, transformers.models.qwen2.modeling_qwen2.Qwen2MLP)
kernel._builtin_patch_swiglu(transformers.models.qwen2.modeling_qwen2, "Qwen2MLP")
self.assertEqual(swiglu.swiglu.NpuSwiGlu, transformers.models.qwen2.modeling_qwen2.Qwen2MLP)
transformers.models.qwen2.modeling_qwen2.Qwen2MLP = self.original_swiglu
def test_apply_fused_kernel_base(self):
class Config:
_attn_implementation = "eager"
mock_config = Config()
kwargs = {
"use_npu_fusion_attention": True,
"use_fused_rms_norm": True,
"use_fused_rope": True,
"use_fused_swiglu": True,
"config": mock_config,
}
module = transformers.models.qwen2.modeling_qwen2
kernel._apply_fused_kernel_base(module, **kwargs)
self.assertEqual(mock_config._attn_implementation, "sdpa")
self.assertEqual(module.Qwen2RMSNorm, rms_norm.rms_norm.NpuRMSNorm)
self.assertEqual(module.apply_rotary_pos_emb, rope.rope.apply_rotary_pos_emb)
self.assertEqual(module.Qwen2MLP, swiglu.swiglu.NpuSwiGlu)
kwargs["use_npu_fusion_attention"] = False
kernel._apply_fused_kernel_base(module, **kwargs)
self.assertEqual(mock_config._attn_implementation, "eager")
def test_apply_fused_kernel_internlm2(self):
class Config:
_attn_implementation = "eager"
mock_config = Config()
kwargs = {
"use_npu_fusion_attention": True,
"use_fused_rms_norm": True,
"use_fused_rope": True,
"use_fused_swiglu": True,
"config": mock_config,
}
module_path = self.mock_cache / "test_module.py"
module_path.write_text(
"class InternLM2ForCausalLM:\n pass\nclass InternLM2RMSNorm:\n pass\nclass InternLM2MLP:\n pass"
"\ndef apply_rotary_pos_emb():\n pass\n"
)
original_utils = transformers.dynamic_module_utils.get_class_in_module
kernel.apply_fused_kernel_internlm2(**kwargs)
self.assertNotEqual(transformers.dynamic_module_utils.get_class_in_module, original_utils)
transformers.dynamic_module_utils.get_class_in_module(
class_name="InternLM2ForCausalLM", module_path="test_module.py"
)
self.assertEqual(mock_config._attn_implementation, "sdpa")

View File

@ -0,0 +1,203 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
# http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
import unittest
from unittest.mock import patch, MagicMock
from openmind.integrations.transformers.npu_fused_ops.sdk import check_use_fused_kernel
from openmind.integrations.transformers.npu_fused_ops import sdk
class TestCheckUseFusedKernel:
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_available", return_value=False)
def test_torch_unavailable(self, mock_torch_avail):
check_use_fused_kernel.cache_clear()
ret1 = check_use_fused_kernel(inner=False)
check_use_fused_kernel.cache_clear()
ret2 = check_use_fused_kernel(inner=True)
assert not ret1 and not ret2
@patch("transformers.__version__", "4.51.1")
@patch("torch.__version__", "2.1.0")
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.PartialState", autospec=True)
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_available", return_value=True)
@patch("torch.npu", create=True)
def test_not_npu_environment(self, mock_torch_npu, mock_torch_avail, mock_partial_state):
mock_state = MagicMock()
mock_state.device.type = "cuda"
mock_partial_state.return_value = mock_state
check_use_fused_kernel.cache_clear()
ret1 = check_use_fused_kernel(inner=False)
check_use_fused_kernel.cache_clear()
ret2 = check_use_fused_kernel(inner=True)
assert not ret1 and not ret2
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.arguments")
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.PartialState", autospec=True)
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_npu_available", return_value=False)
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_available", return_value=True)
@patch("torch.npu", create=True)
def test_torch_npu_not_available(
self, mock_torch_npu, mock_torch_avail, mock_npu_avail, mock_partial_state, mock_arguments
):
mock_state = MagicMock()
mock_state.device.type = "npu"
mock_partial_state.return_value = mock_state
mock_args = MagicMock()
mock_args.get_args().disable_fused_options = False
mock_arguments.return_value = mock_args
check_use_fused_kernel.cache_clear()
ret1 = check_use_fused_kernel(inner=False)
check_use_fused_kernel.cache_clear()
ret2 = check_use_fused_kernel(inner=True)
assert not ret1 and not ret2
@patch("transformers.__version__", "4.51.1")
@patch("torch.__version__", "2.1.0")
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.arguments")
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.PartialState", autospec=True)
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_npu_available", return_value=True)
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_available", return_value=True)
@patch("torch.npu", create=True)
def test_inner_call_disable_fused(
self, mock_torch_npu, mock_torch_avail, mock_npu_avail, mock_partial_state, mock_arguments
):
mock_state = MagicMock()
mock_state.device.type = "npu"
mock_partial_state.return_value = mock_state
mock_args = MagicMock()
mock_args.get_args().disable_fused_options = True
mock_arguments.return_value = mock_args
check_use_fused_kernel.cache_clear()
ret1 = check_use_fused_kernel(inner=False)
check_use_fused_kernel.cache_clear()
ret2 = check_use_fused_kernel(inner=True)
assert ret1 and not ret2
@patch("transformers.__version__", "4.51.1")
@patch("torch.__version__", "2.1.0")
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.arguments")
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.PartialState", autospec=True)
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_npu_available", return_value=True)
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_available", return_value=True)
@patch("openmind.utils.version.check_package_version", return_value=True)
@patch("torch.npu", create=True)
def test_compatible_versions(
self, mock_torch_npu, mock_check_version, mock_torch_avail, mock_npu_avail, mock_partial_state, mock_arguments
):
mock_state = MagicMock()
mock_state.device.type = "npu"
mock_partial_state.return_value = mock_state
mock_args = MagicMock()
mock_args.get_args().disable_fused_options = False
mock_arguments.return_value = mock_args
check_use_fused_kernel.cache_clear()
ret1 = check_use_fused_kernel(inner=False)
check_use_fused_kernel.cache_clear()
ret2 = check_use_fused_kernel(inner=False)
assert ret1 and ret2
class TestParseParams:
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.check_use_fused_kernel", return_value=True)
def test_parse_params_(self, _):
kwargs = {
"use_npu_fusion_attention": True,
"use_fused_rms_norm": True,
"use_fused_rope": True,
"use_fused_swiglu": True,
}
ret = sdk._parse_params(**kwargs)
assert all(ret.values())
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.check_use_fused_kernel", return_value=False)
def test_parse_params_check_failed(self, _):
kwargs = {
"use_npu_fusion_attention": True,
"use_fused_rms_norm": True,
"use_fused_rope": True,
"use_fused_swiglu": True,
}
sdk._parse_params.cache_clear()
ret = sdk._parse_params(**kwargs)
assert not all(ret.values())
class TestApplyLog:
def test_apply_log(self):
kwargs = {
"use_npu_fusion_attention": True,
"use_fused_rms_norm": True,
"use_fused_rope": True,
"use_fused_swiglu": True,
}
model_type = None
sdk._apply_log(model_type, **kwargs)
class TestApplyFusedKernelGeneric(unittest.TestCase):
def test_apply_fused_kernel_generic(self):
apply_func = MagicMock()
kwargs = {"config": None}
sdk._apply_fused_kernel_generic(apply_func, **kwargs)
apply_func.assert_called()
class TestMapFusedKernel:
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.check_use_fused_kernel", return_value=True)
def test_map_fused_kernel_to_model_builtin(self, _):
class Config:
_attn_implementation = "eager"
config = Config()
sdk.map_fused_kernel_to_model(
architecture="Qwen2ForCausalLM",
use_npu_fusion_attention=True,
use_fused_rms_norm=True,
use_fused_rope=True,
use_fused_swiglu=True,
config=config,
)
assert config._attn_implementation == "sdpa"
@patch("openmind.integrations.transformers.npu_fused_ops.sdk.check_use_fused_kernel", return_value=True)
@patch("openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils._raw_get_dynamic_module")
def test_map_fused_kernel_to_model_dynamic(self, mock_raw_get_dynamic_module, mock_check_use_fused_kernel):
import transformers
class Config:
_attn_implementation = "eager"
class Model:
def __init__(self, config):
self.config = config
config = Config()
mock_raw_get_dynamic_module = MagicMock()
mock_raw_get_dynamic_module.model_name = "InternLM2ForCausalLM"
mock_raw_get_dynamic_module.return_value = Model(config)
sdk.map_fused_kernel_to_model(
architecture="InternLM2ForCausalLM",
use_npu_fusion_attention=True,
use_fused_rms_norm=False,
use_fused_rope=False,
use_fused_swiglu=False,
config=config,
)
transformers.dynamic_module_utils.get_class_in_module(class_name="InternLM2ForCausalLM", module_path="/")
assert config._attn_implementation == "sdpa"