!248 Fix the SDPA interface of the InternLM series models not being adapted correctly; add fused-operator support for Qwen3 models (multimodal variants not yet included)
Merge pull request !248 from 幽若/master-0625
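
For reference, the new Qwen3 entry point mirrors the existing Qwen2 helper: calling it enables every supported fused operator by default, and the keyword flags documented in the diff below switch individual operators off. A minimal usage sketch, assuming the helper is exported from the npu_fused_ops integration package shown in this diff (the exact import path is an assumption):

    # Usage sketch; the import path is assumed from the package layout in this diff.
    from openmind.integrations.transformers.npu_fused_ops import apply_fused_kernel_to_qwen3

    # Either enable all supported fused operators for Qwen3 (the default) ...
    apply_fused_kernel_to_qwen3()
    # ... or keep fused RMSNorm/RoPE/SwiGLU while falling back to the stock attention path.
    apply_fused_kernel_to_qwen3(use_npu_fusion_attention=False)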
@@ -11,4 +11,4 @@
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

from . import sdpa_attention
from . import sdpa_attention, internlm2, internlm3
@@ -16,7 +16,8 @@
# limitations under the License.
"""PyTorch InternLM2 model."""

import math
# Note: when ascend npu give the equal implementation of the sdpa backend, this adapter code will be removed

from typing import Optional, Tuple

import torch
@@ -25,11 +26,6 @@ from transformers.cache_utils import Cache

from openmind.utils import logging, is_torch_npu_available

if is_torch_npu_available():
    import torch_npu
else:
    pass


logger = logging.get_logger()

@@ -60,6 +56,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): #
    Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
    """
    if is_torch_npu_available():
        from openmind.integrations.transformers.npu_fused_ops.rope.rope import (
            apply_rotary_pos_emb as fused_rotary_pos_emb,
        )

        return fused_rotary_pos_emb(q, k, cos, sin, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
@@ -89,9 +91,7 @@ def forward(
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    self.scale_value = 1.0 / math.sqrt(self.head_dim)
    if output_attentions:
        # Improve this warning with e.g. `model.config.attn_implementation = "manual"`
        # once this is implemented.
        logger.warning_once(
            "InternLM2Model uses InternLM2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` "
@@ -144,22 +144,29 @@ def forward(
    if attention_mask is not None:
        causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

    attn_output = torch_npu.npu_fusion_attention(
    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
    if query_states.device.type == "cuda" and causal_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of
    # an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph
    # options. An inline conditional prevents dynamic shapes from compiling.
    is_causal = bool(causal_mask is None and q_len > 1)

    if is_torch_npu_available():
        is_causal = True
        causal_mask = None

    attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102
        query_states,
        key_states,
        value_states,
        self.num_heads,
        input_layout="BNSD",
        pse=None,
        atten_mask=causal_mask.bool(),
        scale=self.scale_value,
        # pre_tockens and next_tockens are used for sparse computation,
        # `2147483647` is the default value for the npu_fusion_attention interface.
        pre_tockens=2147483647,
        next_tockens=2147483647,
        keep_prob=1,
        inner_precise=0,
    )[0]
        attn_mask=causal_mask,
        dropout_p=0.0,
        is_causal=is_causal,
    )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(bsz, q_len, self.hidden_size)
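
The hunk above replaces the hard-coded torch_npu.npu_fusion_attention call with torch's built-in SDPA and moves the device handling into the dispatch logic: inputs are made contiguous on CUDA when a mask is present, and on Ascend NPU the explicit mask is dropped in favour of is_causal=True. A condensed, self-contained sketch of that pattern, with the class context and mask construction omitted (the helper name and signature are illustrative, not part of the patch):

    import torch
    from openmind.utils import is_torch_npu_available

    def sdpa_dispatch(query_states, key_states, value_states, causal_mask, q_len):
        # Tensors are laid out [batch, num_heads, seq_len, head_dim].
        # SDPA's memory-efficient backend (torch==2.1.2) mishandles non-contiguous inputs with a custom mask on CUDA.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # Computed as a plain bool rather than an inline conditional inside the SDPA call,
        # so torch.compile's dynamic shapes keep working.
        is_causal = bool(causal_mask is None and q_len > 1)

        # On Ascend NPU the adapter relies on the causal flag instead of an explicit mask.
        if is_torch_npu_available():
            is_causal = True
            causal_mask = None

        # SDPA applies the default 1/sqrt(head_dim) scaling, matching the scale_value set in the patched forward.
        return torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=0.0,
            is_causal=is_causal,
        )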
@@ -1,6 +1,7 @@
# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
# 2025.01.16 - Adapt to openmind.
# Huawei Technologies Co., Ltd.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# This code is based on transformers/src/transformers/models/llama/modeling_llama.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,7 +17,9 @@
# limitations under the License.
"""PyTorch InternLM3 model."""

import math
# Note: when ascend npu give the equal implementation of the sdpa backend, this adapter code will be removed


from typing import Optional, Tuple

import torch
@@ -24,11 +27,6 @@ from transformers.cache_utils import Cache

from openmind.utils import logging, is_torch_npu_available

if is_torch_npu_available():
    import torch_npu
else:
    pass


logger = logging.get_logger()

@@ -60,6 +58,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): #
    Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
    """
    if is_torch_npu_available():
        from openmind.integrations.transformers.npu_fused_ops.rope.rope import (
            apply_rotary_pos_emb as fused_rotary_pos_emb,
        )

        return fused_rotary_pos_emb(q, k, cos, sin, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
@@ -91,7 +95,22 @@ def forward(
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    self.scale_value = 1.0 / math.sqrt(self.head_dim)
    if output_attentions:
        logger.warning_once(
            "InternLM3Model is using InternLM3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
            'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
        )
        return super().forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
@@ -127,30 +146,28 @@ def forward(
    if attention_mask is not None:
        causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

    attn_output = torch_npu.npu_fusion_attention(
    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
    if query_states.device.type == "cuda" and causal_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    is_causal = True if causal_mask is None and q_len > 1 else False

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        self.num_heads,
        input_layout="BNSD",
        pse=None,
        atten_mask=causal_mask.bool(),
        scale=self.scale_value,
        # pre_tockens and next_tockens are used for sparse computation,
        # `2147483647` is the default value for the npu_fusion_attention interface.
        pre_tockens=2147483647,
        next_tockens=2147483647,
        keep_prob=1,
        inner_precise=0,
    )[0]
        attn_mask=causal_mask,
        dropout_p=self.attention_dropout if self.training else 0.0,
        is_causal=is_causal,
    )

    attn_output = attn_output.transpose(1, 2).contiguous()

    attn_output = attn_output.reshape(bsz, q_len, -1)
    attn_output = attn_output.view(bsz, q_len, -1)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
    return attn_output, None, past_key_value
@@ -38,6 +38,7 @@ from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
from openmind.integrations.transformers.npu_fused_ops.rope import rope
from openmind.integrations.transformers.npu_fused_ops.swiglu import swiglu
from openmind.integrations.transformers.npu_fused_ops import kernel
from openmind.integrations.transformers.npu_fused_ops.attentions import internlm2, internlm3

logger = logging.get_logger()

@@ -59,7 +60,8 @@ def register_dynamic_model(model_name: str, /, **kwargs):
        model_name: Autoclass name, such as InternLM2ForCausalLM.
        **kwargs: supported npu fused options, all kwargs can be None.
            kwargs include the follow params:
            npu_fusion_attention: the adapter module of npu fused attention.
            sdpa_cls_name: the sdpa attention class name of the module
            sdpa_forward: the forward function of npu fused attention.
            rms_norm: the class of npu fused rms norm.
            rope: the function of npu fused rotary position embedding
            swiglu: the class of npu fused SwiGLU.
@@ -74,6 +76,17 @@ register_dynamic_model(
    rms_norm=rms_norm.NpuRMSNorm,
    rope=rope.apply_rotary_pos_emb,
    swiglu=swiglu.NpuIntern2SwiGlu,
    sdpa_cls_name="InternLM2SdpaAttention",
    sdpa_forward=internlm2.forward,
)

register_dynamic_model(
    "InternLM3ForCausalLM",
    rms_norm=rms_norm.NpuRMSNorm,
    rope=rope.apply_rotary_pos_emb,
    swiglu=swiglu.NpuSwiGlu,
    sdpa_cls_name="InternLM3SdpaAttention",
    sdpa_forward=internlm3.forward,
)


@@ -124,6 +137,17 @@ def _dynamic_patch_flash_attention(model_name: str, module: ModuleType, **kwargs
    kernel._patch_sdpa_forward()
    if config is not None:
        setattr(config, "_attn_implementation", "sdpa")
        setattr(config, "attn_implementation", "sdpa")
    if not DYNAMIC_MODELS[model_name].get("sdpa_cls_name", None) or not DYNAMIC_MODELS[model_name].get(
        "sdpa_forward", None
    ):
        logger.warning_rank0(
            "When execute register_dynamic_model function, `sdpa_cls_name` or `sdpa_forward` not given, "
            "it may not work with SDPA Attention, we suggest you to set it manually or use the eager mode."
        )
        return
    sdpa_attention_cls = getattr(module, DYNAMIC_MODELS[model_name].get("sdpa_cls_name"))
    setattr(sdpa_attention_cls, "forward", DYNAMIC_MODELS[model_name].get("sdpa_forward"))


def _dynamic_patch_rms_norm(model_name: str, module: ModuleType):
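
The registrations above are the extension point this change builds on: sdpa_cls_name names the SDPA attention class inside the model's dynamic module, and sdpa_forward is the replacement bound onto that class by _dynamic_patch_flash_attention. A hypothetical registration for a model outside this PR might look like the following sketch; every MyModel* name is a placeholder, and a real model would ship its own adapter forward instead of borrowing internlm3's:

    # Hypothetical registration sketch; MyModel* identifiers are placeholders.
    from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu
    from openmind.integrations.transformers.npu_fused_ops.attentions import internlm3
    from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils import register_dynamic_model

    register_dynamic_model(
        "MyModelForCausalLM",                  # Autoclass name, as the docstring above describes
        rms_norm=rms_norm.NpuRMSNorm,          # fused RMSNorm replacement
        rope=rope.apply_rotary_pos_emb,        # fused rotary position embedding
        swiglu=swiglu.NpuSwiGlu,               # fused SwiGLU MLP
        sdpa_cls_name="MyModelSdpaAttention",  # SDPA class to patch inside the dynamic module
        sdpa_forward=internlm3.forward,        # placeholder stand-in for a model-specific adapter forward
    )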
@@ -15,11 +15,12 @@ import re
from types import ModuleType

from transformers.models.qwen2 import modeling_qwen2
from transformers.models.qwen3 import modeling_qwen3
from transformers.models.llama import modeling_llama
from transformers.models.mistral import modeling_mistral

from openmind.utils.version import check_package_version
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attenions
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attentions
from openmind.integrations.transformers.npu_fused_ops import dynamic_module_utils


@@ -40,9 +41,9 @@ def _patch_sdpa_forward():
    from transformers.integrations import sdpa_attention
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface

    sdpa_attention.sdpa_attention_forward = attenions.sdpa_attention.sdpa_attention_forward
    AttentionInterface._global_mapping["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward
    ALL_ATTENTION_FUNCTIONS["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward
    sdpa_attention.sdpa_attention_forward = attentions.sdpa_attention.sdpa_attention_forward
    AttentionInterface._global_mapping["sdpa"] = attentions.sdpa_attention.sdpa_attention_forward
    ALL_ATTENTION_FUNCTIONS["sdpa"] = attentions.sdpa_attention.sdpa_attention_forward


def _builtin_patch_flash_attention(config=None):
@@ -105,6 +106,10 @@ def apply_fused_kernel_qwen2(**kwargs):
    _apply_fused_kernel_base(modeling_qwen2, **kwargs)


def apply_fused_kernel_qwen3(**kwargs):
    _apply_fused_kernel_base(modeling_qwen3, **kwargs)


def apply_fused_kernel_llama(**kwargs):
    _apply_fused_kernel_base(modeling_llama, **kwargs)
@@ -134,20 +134,28 @@ def apply_fused_kernel_to_qwen2(**kwargs):
    Apply npu fused operators for Qwen2 series models, when call this function, all supported
    fusion operators will be enabled by default. You can set the following parameters to disable the
    specified fused operator:
    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
    `use_npu_fusion_attention: bool`, default is True, set it to `False` to disable npu fusion attention.
    `use_fused_rms_norm: bool`, default is True, set it to `False` to disable npu RMSNorm.
    `use_fused_rope: bool`, default is True, set it to `False` to disable npu fused RoPE
    `use_fused_swiglu: bool`, default is True, set it to `False` to disable npu fused SwiGLU.
    """
    _apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen2, **kwargs)
    _apply_log(model_type="qwen2", **kwargs)


def apply_fused_kernel_to_qwen3(**kwargs):
    """
    Apply npu fused operators for Qwen3 series models, when call this function, all supported
    fusion operators will be enabled by default.
    """
    _apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen3, **kwargs)
    _apply_log(model_type="qwen3", **kwargs)


def apply_fused_kernel_to_internlm2(**kwargs):
    """
    Apply npu fused operators for Internlm2 series models, when call this function, all supported
    fusion operators will be enabled by default. You can set the following parameters to disable the
    specified fused operator:
    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
    fusion operators will be enabled by default.
    """
    _apply_fused_kernel_generic(kernel.apply_fused_kernel_internlm2, **kwargs)
    _apply_log(model_type="internlm2", **kwargs)
@@ -156,10 +164,7 @@ def apply_fused_kernel_to_internlm2(**kwargs):
def apply_fused_kernel_to_internlm3(**kwargs):
    """
    Apply npu fused operators for Internlm2 series models, when call this function, all supported
    fusion operators will be enabled by default. You can set the following parameters to disable the
    specified fused operator:
    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
    fusion operators will be enabled by default.
    """
    _apply_fused_kernel_generic(kernel.apply_fused_kernel_internlm3, **kwargs)
    _apply_log(model_type="internlm3", **kwargs)
@@ -168,10 +173,7 @@ def apply_fused_kernel_to_internlm3(**kwargs):
def apply_fused_kernel_to_llama(**kwargs):
    """
    Apply npu fused operators for Llama series models, when call this function, all supported
    fusion operators will be enabled by default. You can set the following parameters to disable the
    specified fused operator:
    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
    fusion operators will be enabled by default.
    """
    _apply_fused_kernel_generic(kernel.apply_fused_kernel_llama, **kwargs)
    _apply_log(model_type="llama", **kwargs)
@@ -180,10 +182,7 @@ def apply_fused_kernel_to_llama(**kwargs):
def apply_fused_kernel_to_mistral(**kwargs):
    """
    Apply npu fused operators for Mistral series models, when call this function, all supported
    fusion operators will be enabled by default. You can set the following parameters to disable the
    specified fused operator:
    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
    fusion operators will be enabled by default.
    """
    _apply_fused_kernel_generic(kernel.apply_fused_kernel_mistral, **kwargs)
    _apply_log(model_type="mistral", **kwargs)
@@ -191,6 +190,7 @@ def apply_fused_kernel_to_mistral(**kwargs):

SUPPORTED_FUSED_MODELS = {
    "Qwen2ForCausalLM": apply_fused_kernel_to_qwen2,
    "Qwen3ForCausalLM": apply_fused_kernel_to_qwen3,
    "LlamaForCausalLM": apply_fused_kernel_to_llama,
    "MistralForCausalLM": apply_fused_kernel_to_mistral,
    "InternLM2ForCausalLM": apply_fused_kernel_to_internlm2,
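
SUPPORTED_FUSED_MODELS now also maps Qwen3ForCausalLM to its helper. The call site that consumes this table is not part of the diff, so the following dispatch sketch is only one plausible way to use it, keyed on the architecture names recorded in a Hugging Face config (maybe_apply_fused_kernel is a made-up helper name):

    # Plausible dispatch sketch; the real consumer of SUPPORTED_FUSED_MODELS is not shown in this diff.
    from transformers import AutoConfig

    def maybe_apply_fused_kernel(model_name_or_path: str, **kwargs) -> bool:
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        for arch in getattr(config, "architectures", None) or []:
            apply_fn = SUPPORTED_FUSED_MODELS.get(arch)
            if apply_fn is not None:
                apply_fn(**kwargs)  # e.g. use_npu_fusion_attention=False to skip fused attention
                return True
        return False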
@@ -31,7 +31,8 @@ from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils impor
    patch_dynamic_fused_ops,
)
from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
from openmind.integrations.transformers.npu_fused_ops import attenions
from openmind.integrations.transformers.npu_fused_ops import attentions
from openmind.integrations.transformers.npu_fused_ops.attentions import internlm2


class TestDynamicModelsRegistration(unittest.TestCase):
@@ -78,25 +79,27 @@ class TestDynamicPatching(unittest.TestCase):
    @patch("torch.__version__", "2.1.0")
    def test_attention_patching(self, _, __):

        class MockAttentionBase:
            def forward(self):
                pass

        class Config:
            _attn_implementation = "eager"

        mock_module = ModuleType("mock_module")
        mock_module.ATTENTION_CLASSES = {"eager": MockAttentionBase}
        class MockInternLM2SdpaAttention:
            pass

        mock_module = ModuleType("mock_module")
        setattr(mock_module, "InternLM2SdpaAttention", MockInternLM2SdpaAttention)
        mock_config = Config()

        _dynamic_patch_flash_attention("InternLM2ForCausalLM", mock_module, config=mock_config)
        self.assertEqual(mock_config._attn_implementation, "sdpa")
        self.assertEqual(mock_module.InternLM2SdpaAttention.forward, internlm2.forward)

    @patch("torch.__version__", "2.6.0")
    def test_torch_260_sets_sdpa(self):
        model_name = "test_model_260"
        DYNAMIC_MODELS[model_name] = {}
        DYNAMIC_MODELS[model_name] = {
            "sdpa_cls_name": "InternLM2SdpaAttention",
            "sdpa_forward": internlm2.forward,
        }

        class Config:
            _attn_implementation = "eager"
@@ -107,7 +110,7 @@ class TestDynamicPatching(unittest.TestCase):
        self.assertEqual(mock_config._attn_implementation, "sdpa")
        self.assertEqual(
            transformers.integrations.sdpa_attention.sdpa_attention_forward,
            attenions.sdpa_attention.sdpa_attention_forward,
            attentions.sdpa_attention.sdpa_attention_forward,
        )

    @patch("importlib.util.spec_from_file_location")
@@ -20,7 +20,7 @@ import transformers
from transformers.integrations import sdpa_attention

from openmind.integrations.transformers.npu_fused_ops import kernel
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attenions
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attentions


class TestFusedKernel(unittest.TestCase):
@@ -61,7 +61,7 @@ class TestFusedKernel(unittest.TestCase):
        self.assertEqual(mock_config._attn_implementation, "sdpa")
        self.assertEqual(
            transformers.integrations.sdpa_attention.sdpa_attention_forward,
            attenions.sdpa_attention.sdpa_attention_forward,
            attentions.sdpa_attention.sdpa_attention_forward,
        )

    def test_builtin_patch_rmsnorm(self):
@@ -121,7 +121,7 @@ class TestFusedKernel(unittest.TestCase):
        module_path = self.mock_cache / "test_module.py"
        module_path.write_text(
            "class InternLM2ForCausalLM:\n    pass\nclass InternLM2RMSNorm:\n    pass\nclass InternLM2MLP:\n    pass"
            "\ndef apply_rotary_pos_emb():\n    pass\n"
            "\ndef apply_rotary_pos_emb():\n    pass\nclass InternLM2SdpaAttention:\n    pass\n"
        )
        original_utils = transformers.dynamic_module_utils.get_class_in_module
        kernel.apply_fused_kernel_internlm2(**kwargs)