@@ -313,6 +313,7 @@ def get_model():
    patch_config(config)
    if config.architectures and config.architectures[0] in SUPPORTED_FUSED_MODELS and args.do_train:
        logger.warning_rank0(f"Unsupported model architecture for npu fused options: {config.architectures[0]}")
        map_fused_kernel_to_model(
            config.architectures[0],
            use_npu_fusion_attention=args.use_npu_fusion_attention,
@@ -28,14 +28,12 @@ from pathlib import Path
import hashlib
import functools

import torch
import transformers
from transformers.dynamic_module_utils import get_relative_import_files
from transformers.utils.hub import HF_MODULES_CACHE
from transformers import PretrainedConfig

from openmind.utils import logging
from openmind.integrations.transformers.npu_fused_ops.attenions import internlm2
from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
from openmind.integrations.transformers.npu_fused_ops.rope import rope
from openmind.integrations.transformers.npu_fused_ops.swiglu import swiglu
@@ -49,7 +47,6 @@ DYNAMIC_MODELS = {}

@dataclasses.dataclass
class Pattern:
    attention: str = "ATTENTION_CLASSES"
    rmsnorm: str = "RMSNorm"
    rope: str = "apply_rotary_pos_emb"
    swiglu: str = "MLP"
@@ -73,7 +70,6 @@ def register_dynamic_model(model_name: str, /, **kwargs):

register_dynamic_model(
    "InternLM2ForCausalLM",
    npu_fusion_attention=internlm2,
    rms_norm=rms_norm.NpuRMSNorm,
    rope=rope.apply_rotary_pos_emb,
    swiglu=swiglu.NpuIntern2SwiGlu,
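
For comparison, registering another trust-remote-code architecture follows the same shape. The sketch below is illustrative only: "MyModelForCausalLM" and my_model_attention are hypothetical placeholders (not part of this change), and it assumes the module scope of dynamic_module_utils shown in the import hunk above.

# Hypothetical registration sketch; my_model_attention would be a module exposing
# a fused forward() compatible with the model's eager attention class.
register_dynamic_model(
    "MyModelForCausalLM",
    npu_fusion_attention=my_model_attention,
    rms_norm=rms_norm.NpuRMSNorm,
    rope=rope.apply_rotary_pos_emb,
    swiglu=swiglu.NpuSwiGlu,
)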
@@ -122,23 +118,9 @@ def _raw_get_dynamic_module(
def _dynamic_patch_flash_attention(model_name: str, module: ModuleType, **kwargs):
    if model_name not in DYNAMIC_MODELS:
        return
    if torch.__version__ == "2.1.0":
        pattern = re.compile(Pattern.attention)
        attention_classes_attr = [attr for attr in dir(module) if pattern.search(attr)]
        attention_classes = getattr(module, attention_classes_attr[0])
        if DYNAMIC_MODELS[model_name].get("npu_fusion_attention"):
            npu_attention_class = type(
                "NPUFusionAttention",
                (attention_classes["eager"],),
                {"forward": DYNAMIC_MODELS[model_name].get("npu_fusion_attention").forward},
            )
            attention_classes.update({k: npu_attention_class for k in attention_classes})
    elif torch.__version__ >= "2.6.0":
        config = kwargs.get("config")
        setattr(config, "_attn_implementation", "sdpa")
    else:
        config = kwargs.get("config")
        setattr(config, "_attn_implementation", "eager")
    setattr(config, "_attn_implementation", "sdpa")


def _dynamic_patch_rms_norm(model_name: str, module: ModuleType):
@@ -157,12 +139,6 @@ def _dynamic_patch_rope(model_name, module):
    rope_attr = [attr for attr in dir(module) if pattern.search(attr)]
    if DYNAMIC_MODELS[model_name].get("rope"):
        setattr(module, rope_attr[0], DYNAMIC_MODELS[model_name].get("rope"))
    if DYNAMIC_MODELS[model_name].get("npu_fusion_attention"):
        setattr(
            DYNAMIC_MODELS[model_name].get("npu_fusion_attention"),
            rope_attr[0],
            DYNAMIC_MODELS[model_name].get("rope"),
        )


def _dynamic_patch_swiglu(model_name, module):
@@ -13,34 +13,27 @@
import dataclasses
import re
from types import ModuleType
from typing import Dict, Type

import torch
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.llama import modeling_llama
from transformers.models.mistral import modeling_mistral

from openmind.integrations.transformers.npu_fused_ops import attenions, rms_norm, rope, swiglu
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu
from openmind.integrations.transformers.npu_fused_ops import dynamic_module_utils


@dataclasses.dataclass
class Pattern:
    attention: str = "ATTENTION_CLASSES"
    rmsnorm: str = "RMSNorm"
    rope: str = "apply_rotary_pos_emb"
    swiglu: str = "MLP"


def _builtin_patch_flash_attention(RAW_ATTENTION_CLASSES: Dict, NEW_ATTENTION_CLASS: Type):
def _builtin_patch_flash_attention(config):
    """
    Patch the FA for transformers built-in models. Call this method before model instantiation is
    completed; once the model has already been instantiated, this method is not effective.
    """
    RAW_ATTENTION_CLASSES.update({k: NEW_ATTENTION_CLASS for k in RAW_ATTENTION_CLASSES})


def __builtin_patch_flash_attention_v2(config):
    setattr(config, "_attn_implementation", "sdpa")
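
For orientation, the new config-based form of _builtin_patch_flash_attention is exercised by the test_builtin_patch_flash_attention test added later in this change; a minimal sketch of that call, assuming the config-based signature above and the `kernel` import used by the tests, is:

# Minimal sketch: the patch only flips the attention implementation on the config,
# so it must run before the model is instantiated.
from openmind.integrations.transformers.npu_fused_ops import kernel

class _Cfg:
    _attn_implementation = "eager"

cfg = _Cfg()
kernel._builtin_patch_flash_attention(cfg)
assert cfg._attn_implementation == "sdpa"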
@@ -64,16 +57,8 @@ def _builtin_patch_swiglu(module: ModuleType, class_name: str):

def _apply_fused_kernel_base(module: ModuleType, **kwargs):
    if kwargs.get("use_npu_fusion_attention", False):
        if torch.__version__ == "2.1.0":
            attention = kwargs.get("attention")
            pattern = re.compile(Pattern.attention)
            attention_classes_attr = [attr for attr in dir(module) if pattern.search(attr)][0]
            _builtin_patch_flash_attention(getattr(module, attention_classes_attr), attention)
        elif torch.__version__ >= "2.6.0":
            config = kwargs.get("config")
            __builtin_patch_flash_attention_v2(config)
        else:
            pass
        config = kwargs.get("config")
        _builtin_patch_flash_attention(config)
    else:
        # if the FA fused option is not enabled, enforce eager mode.
        config = kwargs.get("config")
@@ -96,15 +81,15 @@ def _apply_fused_kernel_base(module: ModuleType, **kwargs):


def apply_fused_kernel_qwen2(**kwargs):
    _apply_fused_kernel_base(modeling_qwen2, attention=attenions.qwen2.Qwen2NPUAttention, **kwargs)
    _apply_fused_kernel_base(modeling_qwen2, **kwargs)


def apply_fused_kernel_llama(**kwargs):
    _apply_fused_kernel_base(modeling_llama, attention=attenions.llama.LlamaNpuFusionAttention, **kwargs)
    _apply_fused_kernel_base(modeling_llama, **kwargs)


def apply_fused_kernel_mistral(**kwargs):
    _apply_fused_kernel_base(modeling_mistral, attention=attenions.mistral.MistralNpuFlashAttention, **kwargs)
    _apply_fused_kernel_base(modeling_mistral, **kwargs)


def apply_fused_kernel_internlm2(**kwargs):
@@ -116,7 +101,6 @@ def apply_fused_kernel_internlm2(**kwargs):
    if "InternLM2ForCausalLM" not in dynamic_module_utils.DYNAMIC_MODELS:
        dynamic_module_utils.register_dynamic_model(
            "InternLM2ForCausalLM",
            npu_fusion_attention=attenions.internlm2,
            rms_norm=rms_norm.rms_norm.NpuRMSNorm,
            rope=rope.rope.apply_rotary_pos_emb,
            swiglu=swiglu.swiglu.NpuIntern2SwiGlu,
@@ -139,7 +123,6 @@ def apply_fused_kernel_internlm3(**kwargs):
    if "InternLM3ForCausalLM" not in dynamic_module_utils.DYNAMIC_MODELS:
        dynamic_module_utils.register_dynamic_model(
            "InternLM3ForCausalLM",
            npu_fusion_attention=attenions.internlm3,
            rms_norm=rms_norm.rms_norm.NpuRMSNorm,
            rope=rope.rope.apply_rotary_pos_emb,
            swiglu=swiglu.swiglu.NpuSwiGlu,
@@ -30,7 +30,6 @@ logger = logging.get_logger()
@lru_cache
def check_use_fused_kernel(inner=False) -> bool:
    """

    Args:
        inner: When this method is called internally within openmind-cli, the value of this parameter is True,
            and the openmind related args will be verified. When the outside user tries to call `apply_fused_kernel()`
@@ -39,13 +38,15 @@ def check_use_fused_kernel(inner=False) -> bool:
    Returns: True if the environment and the parameters allow the use of fused options, otherwise False.

    """
    if is_torch_available():
        state = PartialState()
        device_module = getattr(torch, state.device.type.lower(), None)
    if not is_torch_available():
        return False

    # not in npu environment
    if "npu" not in str(device_module):
        return False
    state = PartialState()
    device_module = getattr(torch, state.device.type.lower(), None)

    # not in npu environment
    if "npu" not in str(device_module):
        return False

    # torch npu is not available
    if not is_torch_npu_available():
@@ -62,16 +63,15 @@ def check_use_fused_kernel(inner=False) -> bool:
        return False

    # installed version of transformers and torch is not compatible for npu fused options
    if torch.__version__ == "2.1.0" and version.check_package_version("transformers<=4.47.1, >=4.45.0"):
        return True
    elif torch.__version__ >= "2.6.0" and version.check_package_version("transformers>=4.51.1"):
    if version.check_package_version("transformers>=4.51.1, <=4.51.3") and (
        version.check_package_version("torch>=2.1.0, <2.1.1") or version.check_package_version("torch>=2.6.0, <2.6.1")
    ):
        return True
    else:
        logger.warning_rank0(
            f"RuntimeWarning: The npu fused options is not available under the transformers v{transformers.__version__} "
            f"and the torch v{torch.__version__}. To use npu fused options, if torch version >= 2.6.0, the version of "
            f"transformers is required at least v4.51.1; if torch version == 2.1.0, the version of transformers is "
            f"required >= v4.45.0, and <= 4.47.1; In other cases, the npu fused options will not be available. "
            f"RuntimeWarning: The npu fused options are not available under transformers v{transformers.__version__}. "
            f"To use npu fused options, you need torch == 2.1.0 or 2.6.0, and transformers >= 4.51.1, <= 4.51.3. "
            f"In other cases, the npu fused options will not be available. "
        )
        return False
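
For readers without openmind.utils.version at hand, the new gate can be restated with the packaging library. This is an illustrative equivalent of the check above, not project code; the helper name is hypothetical.

# Equivalent sketch of the new compatibility gate using packaging.version:
# transformers must be 4.51.1 - 4.51.3 and torch must be a 2.1.0.x or 2.6.0.x build.
from packaging import version as pkg_version
import torch
import transformers

def _versions_allow_fused_ops() -> bool:
    tf_v = pkg_version.Version(transformers.__version__)
    torch_v = pkg_version.Version(torch.__version__)
    tf_ok = pkg_version.Version("4.51.1") <= tf_v <= pkg_version.Version("4.51.3")
    torch_21 = pkg_version.Version("2.1.0") <= torch_v < pkg_version.Version("2.1.1")
    torch_26 = pkg_version.Version("2.6.0") <= torch_v < pkg_version.Version("2.6.1")
    return tf_ok and (torch_21 or torch_26)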
@@ -118,6 +118,8 @@ def apply_fused_kernel(**kwargs):
    specified fusion operator:
        `use_npu_fusion_attention: bool`, default is True; set it to `False` to disable npu fusion attention.
        `use_fused_rms_norm: bool`, default is True; set it to `False` to disable npu RMSNorm.
        `use_fused_rope: bool`, default is True; set it to `False` to disable npu apply_rotary_pos_emb.
        `use_fused_swiglu: bool`, default is True; set it to `False` to disable npu swiglu.
    """
    _apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen2, **kwargs)
    _apply_fused_kernel_generic(kernel.apply_fused_kernel_llama, **kwargs)
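
For reference, a minimal usage sketch of the public entry point documented above. The import path is inferred from the surrounding sdk hunks; treat this as a sketch rather than canonical usage.

# Enable all fused kernels except swiglu before the model is built.
from openmind.integrations.transformers.npu_fused_ops.sdk import apply_fused_kernel

apply_fused_kernel(
    use_npu_fusion_attention=True,
    use_fused_rms_norm=True,
    use_fused_rope=True,
    use_fused_swiglu=False,
)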
@@ -198,6 +200,5 @@ SUPPORTED_FUSED_MODELS = {

def map_fused_kernel_to_model(architecture, **kwargs):
    if architecture not in SUPPORTED_FUSED_MODELS:
        logger.warning_rank0(f"Unsupported fused model architecture: {architecture}")
        return
    SUPPORTED_FUSED_MODELS.get(architecture)(inner=True, **kwargs)
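
The dispatcher above is what get_model() calls during training; a minimal call sketch, mirroring the new test_map_fused_kernel_to_model_builtin test below (Config is only a stand-in for the model's PretrainedConfig):

# Sketch based on the unit test; with a compatible torch/transformers stack this
# leaves config._attn_implementation == "sdpa" and patches the qwen2 module.
from openmind.integrations.transformers.npu_fused_ops import sdk

class Config:
    _attn_implementation = "eager"

config = Config()
sdk.map_fused_kernel_to_model(
    architecture="Qwen2ForCausalLM",
    use_npu_fusion_attention=True,
    use_fused_rms_norm=True,
    use_fused_rope=True,
    use_fused_swiglu=True,
    config=config,
)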
@@ -29,7 +29,6 @@ from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils impor
    _dynamic_patch_swiglu,
    patch_dynamic_fused_ops,
)
from openmind.integrations.transformers.npu_fused_ops.attenions import internlm2
from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
@@ -38,7 +37,6 @@ class TestDynamicModelsRegistration(unittest.TestCase):
        """Test dynamic model registration"""
        self.assertIn("InternLM2ForCausalLM", DYNAMIC_MODELS)
        model_config = DYNAMIC_MODELS["InternLM2ForCausalLM"]
        self.assertIsNotNone(model_config.get("npu_fusion_attention"))
        self.assertIsNotNone(model_config.get("rms_norm"))
        self.assertIsNotNone(model_config.get("rope"))
        self.assertIsNotNone(model_config.get("swiglu"))
@@ -86,8 +84,10 @@ class TestDynamicPatching(unittest.TestCase):
        mock_module = ModuleType("mock_module")
        mock_module.ATTENTION_CLASSES = {"eager": MockAttentionBase}

        _dynamic_patch_flash_attention("InternLM2ForCausalLM", mock_module, config=Config())
        self.assertIsInstance(mock_module.ATTENTION_CLASSES["eager"].forward, internlm2.forward.__class__)
        mock_config = Config()

        _dynamic_patch_flash_attention("InternLM2ForCausalLM", mock_module, config=mock_config)
        self.assertEqual(mock_config._attn_implementation, "sdpa")

    @patch("torch.__version__", "2.6.0")
    def test_torch_260_sets_sdpa(self):
@@ -102,19 +102,6 @@

        self.assertEqual(mock_config._attn_implementation, "sdpa")

    @patch("torch.__version__", "2.5.0")
    def test_torch_other_sets_sdpa(self):
        model_name = "test_model_260"
        DYNAMIC_MODELS[model_name] = {}

        class Config:
            _attn_implementation = "sdpa"

        mock_config = Config()
        _dynamic_patch_flash_attention(model_name, MagicMock(), config=mock_config)

        self.assertEqual(mock_config._attn_implementation, "eager")

    @patch("importlib.util.spec_from_file_location")
    @patch("importlib.util.module_from_spec")
    def test_rms_norm_patching(self, _, __):
@@ -137,11 +124,6 @@ class TestDynamicPatching(unittest.TestCase):
            DYNAMIC_MODELS["InternLM2ForCausalLM"]["rope"],
        )

        self.assertEqual(
            getattr(DYNAMIC_MODELS["InternLM2ForCausalLM"]["npu_fusion_attention"], "apply_rotary_pos_emb"),
            DYNAMIC_MODELS["InternLM2ForCausalLM"]["rope"],
        )

    @patch("importlib.util.spec_from_file_location")
    @patch("importlib.util.module_from_spec")
    def test_swiglu_patching(self, _, __):
@@ -155,16 +137,11 @@ class TestDynamicPatching(unittest.TestCase):
    @patch("importlib.util.spec_from_file_location")
    @patch("importlib.util.module_from_spec")
    def test_rope_patching_without_attention(self, _, __):

        original_attention = DYNAMIC_MODELS["InternLM2ForCausalLM"].pop("npu_fusion_attention")

        mock_module = ModuleType("mock_module")
        mock_module.apply_rotary_pos_emb = MagicMock()

        _dynamic_patch_rope("InternLM2ForCausalLM", mock_module)

        DYNAMIC_MODELS["InternLM2ForCausalLM"]["npu_fusion_attention"] = original_attention

        self.assertEqual(mock_module.apply_rotary_pos_emb, DYNAMIC_MODELS["InternLM2ForCausalLM"]["rope"])

    @patch("importlib.util.spec_from_file_location")
@@ -0,0 +1,117 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
# http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

import transformers

from openmind.integrations.transformers.npu_fused_ops import kernel
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu


class TestFusedKernel(unittest.TestCase):
    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()
        self.mock_cache = Path(self.temp_dir.name)
        self.patcher = patch(
            "openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils.HF_MODULES_CACHE", self.mock_cache
        )
        self.patcher.start()
        self.original_rope = transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb
        self.original_rmsnorm = transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
        self.original_swiglu = transformers.models.qwen2.modeling_qwen2.Qwen2MLP

    def tearDown(self):
        self.temp_dir.cleanup()
        self.patcher.stop()
        transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb = self.original_rope
        transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm = self.original_rmsnorm
        transformers.models.qwen2.modeling_qwen2.Qwen2MLP = self.original_swiglu

    def test_builtin_patch_flash_attention(self):
        class Config:
            _attn_implementation = "eager"

        mock_config = Config()
        kernel._builtin_patch_flash_attention(mock_config)
        self.assertEqual(mock_config._attn_implementation, "sdpa")

    def test_builtin_patch_rmsnorm(self):
        kernel._builtin_patch_rmsnorm(transformers.models.qwen2.modeling_qwen2, "Qwen2RMSNorm")
        self.assertEqual(rms_norm.rms_norm.NpuRMSNorm, transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm)
        transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm = self.original_rmsnorm

    def test_builtin_patch_rope(self):
        self.assertNotEqual(
            rope.rope.apply_rotary_pos_emb, transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb
        )
        kernel._builtin_patch_rope(transformers.models.qwen2.modeling_qwen2, "apply_rotary_pos_emb")
        self.assertEqual(rope.rope.apply_rotary_pos_emb, transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb)
        transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb = self.original_rope

    def test_builtin_patch_swiglu(self):
        self.assertNotEqual(swiglu.swiglu.NpuSwiGlu, transformers.models.qwen2.modeling_qwen2.Qwen2MLP)
        kernel._builtin_patch_swiglu(transformers.models.qwen2.modeling_qwen2, "Qwen2MLP")
        self.assertEqual(swiglu.swiglu.NpuSwiGlu, transformers.models.qwen2.modeling_qwen2.Qwen2MLP)
        transformers.models.qwen2.modeling_qwen2.Qwen2MLP = self.original_swiglu

    def test_apply_fused_kernel_base(self):
        class Config:
            _attn_implementation = "eager"

        mock_config = Config()
        kwargs = {
            "use_npu_fusion_attention": True,
            "use_fused_rms_norm": True,
            "use_fused_rope": True,
            "use_fused_swiglu": True,
            "config": mock_config,
        }
        module = transformers.models.qwen2.modeling_qwen2
        kernel._apply_fused_kernel_base(module, **kwargs)
        self.assertEqual(mock_config._attn_implementation, "sdpa")
        self.assertEqual(module.Qwen2RMSNorm, rms_norm.rms_norm.NpuRMSNorm)
        self.assertEqual(module.apply_rotary_pos_emb, rope.rope.apply_rotary_pos_emb)
        self.assertEqual(module.Qwen2MLP, swiglu.swiglu.NpuSwiGlu)

        kwargs["use_npu_fusion_attention"] = False
        kernel._apply_fused_kernel_base(module, **kwargs)
        self.assertEqual(mock_config._attn_implementation, "eager")

    def test_apply_fused_kernel_internlm2(self):
        class Config:
            _attn_implementation = "eager"

        mock_config = Config()
        kwargs = {
            "use_npu_fusion_attention": True,
            "use_fused_rms_norm": True,
            "use_fused_rope": True,
            "use_fused_swiglu": True,
            "config": mock_config,
        }
        module_path = self.mock_cache / "test_module.py"
        module_path.write_text(
            "class InternLM2ForCausalLM:\n pass\nclass InternLM2RMSNorm:\n pass\nclass InternLM2MLP:\n pass"
            "\ndef apply_rotary_pos_emb():\n pass\n"
        )
        original_utils = transformers.dynamic_module_utils.get_class_in_module
        kernel.apply_fused_kernel_internlm2(**kwargs)
        self.assertNotEqual(transformers.dynamic_module_utils.get_class_in_module, original_utils)
        transformers.dynamic_module_utils.get_class_in_module(
            class_name="InternLM2ForCausalLM", module_path="test_module.py"
        )
        self.assertEqual(mock_config._attn_implementation, "sdpa")
tests/unit/integrations/transformers/npu_fused_ops/test_sdk.py (new file, 203 lines)
@@ -0,0 +1,203 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
# http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
import unittest
from unittest.mock import patch, MagicMock

from openmind.integrations.transformers.npu_fused_ops.sdk import check_use_fused_kernel
from openmind.integrations.transformers.npu_fused_ops import sdk


class TestCheckUseFusedKernel:

    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_available", return_value=False)
    def test_torch_unavailable(self, mock_torch_avail):
        check_use_fused_kernel.cache_clear()
        ret1 = check_use_fused_kernel(inner=False)
        check_use_fused_kernel.cache_clear()
        ret2 = check_use_fused_kernel(inner=True)
        assert not ret1 and not ret2

    @patch("transformers.__version__", "4.51.1")
    @patch("torch.__version__", "2.1.0")
    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.PartialState", autospec=True)
    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_available", return_value=True)
    @patch("torch.npu", create=True)
    def test_not_npu_environment(self, mock_torch_npu, mock_torch_avail, mock_partial_state):
        mock_state = MagicMock()
        mock_state.device.type = "cuda"
        mock_partial_state.return_value = mock_state
        check_use_fused_kernel.cache_clear()
        ret1 = check_use_fused_kernel(inner=False)
        check_use_fused_kernel.cache_clear()
        ret2 = check_use_fused_kernel(inner=True)
        assert not ret1 and not ret2

    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.arguments")
    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.PartialState", autospec=True)
    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_npu_available", return_value=False)
    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_available", return_value=True)
    @patch("torch.npu", create=True)
    def test_torch_npu_not_available(
        self, mock_torch_npu, mock_torch_avail, mock_npu_avail, mock_partial_state, mock_arguments
    ):
        mock_state = MagicMock()
        mock_state.device.type = "npu"
        mock_partial_state.return_value = mock_state
        mock_args = MagicMock()
        mock_args.get_args().disable_fused_options = False
        mock_arguments.return_value = mock_args
        check_use_fused_kernel.cache_clear()
        ret1 = check_use_fused_kernel(inner=False)
        check_use_fused_kernel.cache_clear()
        ret2 = check_use_fused_kernel(inner=True)
        assert not ret1 and not ret2

    @patch("transformers.__version__", "4.51.1")
    @patch("torch.__version__", "2.1.0")
    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.arguments")
    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.PartialState", autospec=True)
    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_npu_available", return_value=True)
    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_available", return_value=True)
    @patch("torch.npu", create=True)
    def test_inner_call_disable_fused(
        self, mock_torch_npu, mock_torch_avail, mock_npu_avail, mock_partial_state, mock_arguments
    ):
        mock_state = MagicMock()
        mock_state.device.type = "npu"
        mock_partial_state.return_value = mock_state
        mock_args = MagicMock()
        mock_args.get_args().disable_fused_options = True
        mock_arguments.return_value = mock_args
        check_use_fused_kernel.cache_clear()
        ret1 = check_use_fused_kernel(inner=False)
        check_use_fused_kernel.cache_clear()
        ret2 = check_use_fused_kernel(inner=True)
        assert ret1 and not ret2

    @patch("transformers.__version__", "4.51.1")
    @patch("torch.__version__", "2.1.0")
    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.arguments")
    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.PartialState", autospec=True)
    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_npu_available", return_value=True)
    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.is_torch_available", return_value=True)
    @patch("openmind.utils.version.check_package_version", return_value=True)
    @patch("torch.npu", create=True)
    def test_compatible_versions(
        self, mock_torch_npu, mock_check_version, mock_torch_avail, mock_npu_avail, mock_partial_state, mock_arguments
    ):
        mock_state = MagicMock()
        mock_state.device.type = "npu"
        mock_partial_state.return_value = mock_state

        mock_args = MagicMock()
        mock_args.get_args().disable_fused_options = False
        mock_arguments.return_value = mock_args

        check_use_fused_kernel.cache_clear()
        ret1 = check_use_fused_kernel(inner=False)
        check_use_fused_kernel.cache_clear()
        ret2 = check_use_fused_kernel(inner=False)
        assert ret1 and ret2


class TestParseParams:

    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.check_use_fused_kernel", return_value=True)
    def test_parse_params_(self, _):
        kwargs = {
            "use_npu_fusion_attention": True,
            "use_fused_rms_norm": True,
            "use_fused_rope": True,
            "use_fused_swiglu": True,
        }
        ret = sdk._parse_params(**kwargs)
        assert all(ret.values())

    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.check_use_fused_kernel", return_value=False)
    def test_parse_params_check_failed(self, _):
        kwargs = {
            "use_npu_fusion_attention": True,
            "use_fused_rms_norm": True,
            "use_fused_rope": True,
            "use_fused_swiglu": True,
        }
        sdk._parse_params.cache_clear()
        ret = sdk._parse_params(**kwargs)
        assert not all(ret.values())


class TestApplyLog:
    def test_apply_log(self):
        kwargs = {
            "use_npu_fusion_attention": True,
            "use_fused_rms_norm": True,
            "use_fused_rope": True,
            "use_fused_swiglu": True,
        }
        model_type = None
        sdk._apply_log(model_type, **kwargs)


class TestApplyFusedKernelGeneric(unittest.TestCase):
    def test_apply_fused_kernel_generic(self):
        apply_func = MagicMock()
        kwargs = {"config": None}
        sdk._apply_fused_kernel_generic(apply_func, **kwargs)
        apply_func.assert_called()


class TestMapFusedKernel:

    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.check_use_fused_kernel", return_value=True)
    def test_map_fused_kernel_to_model_builtin(self, _):
        class Config:
            _attn_implementation = "eager"

        config = Config()
        sdk.map_fused_kernel_to_model(
            architecture="Qwen2ForCausalLM",
            use_npu_fusion_attention=True,
            use_fused_rms_norm=True,
            use_fused_rope=True,
            use_fused_swiglu=True,
            config=config,
        )
        assert config._attn_implementation == "sdpa"

    @patch("openmind.integrations.transformers.npu_fused_ops.sdk.check_use_fused_kernel", return_value=True)
    @patch("openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils._raw_get_dynamic_module")
    def test_map_fused_kernel_to_model_dynamic(self, mock_raw_get_dynamic_module, mock_check_use_fused_kernel):
        import transformers

        class Config:
            _attn_implementation = "eager"

        class Model:
            def __init__(self, config):
                self.config = config

        config = Config()
        mock_raw_get_dynamic_module = MagicMock()
        mock_raw_get_dynamic_module.model_name = "InternLM2ForCausalLM"
        mock_raw_get_dynamic_module.return_value = Model(config)
        sdk.map_fused_kernel_to_model(
            architecture="InternLM2ForCausalLM",
            use_npu_fusion_attention=True,
            use_fused_rms_norm=False,
            use_fused_rope=False,
            use_fused_swiglu=False,
            config=config,
        )

        transformers.dynamic_module_utils.get_class_in_module(class_name="InternLM2ForCausalLM", module_path="/")
        assert config._attn_implementation == "sdpa"