diff --git a/src/openmind/integrations/transformers/npu_fused_ops/attenions/__init__.py b/src/openmind/integrations/transformers/npu_fused_ops/attentions/__init__.py
similarity index 91%
rename from src/openmind/integrations/transformers/npu_fused_ops/attenions/__init__.py
rename to src/openmind/integrations/transformers/npu_fused_ops/attentions/__init__.py
index e8041c3..59a9016 100644
--- a/src/openmind/integrations/transformers/npu_fused_ops/attenions/__init__.py
+++ b/src/openmind/integrations/transformers/npu_fused_ops/attentions/__init__.py
@@ -11,4 +11,4 @@
 # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 # See the Mulan PSL v2 for more details.
 
-from . import sdpa_attention
+from . import sdpa_attention, internlm2, internlm3
diff --git a/src/openmind/integrations/transformers/npu_fused_ops/attenions/internlm2.py b/src/openmind/integrations/transformers/npu_fused_ops/attentions/internlm2.py
similarity index 82%
rename from src/openmind/integrations/transformers/npu_fused_ops/attenions/internlm2.py
rename to src/openmind/integrations/transformers/npu_fused_ops/attentions/internlm2.py
index e0fb43b..84c72da 100644
--- a/src/openmind/integrations/transformers/npu_fused_ops/attenions/internlm2.py
+++ b/src/openmind/integrations/transformers/npu_fused_ops/attentions/internlm2.py
@@ -16,7 +16,8 @@
 # limitations under the License.
 """PyTorch InternLM2 model."""
 
-import math
+# Note: when the Ascend NPU provides an equivalent implementation of the SDPA backend, this adapter code will be removed.
+
 from typing import Optional, Tuple
 
 import torch
@@ -25,11 +26,6 @@ from transformers.cache_utils import Cache
 
 from openmind.utils import logging, is_torch_npu_available
 
-if is_torch_npu_available():
-    import torch_npu
-else:
-    pass
-
 
 logger = logging.get_logger()
 
@@ -60,6 +56,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 
     # Returns: `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
     """
+    if is_torch_npu_available():
+        from openmind.integrations.transformers.npu_fused_ops.rope.rope import (
+            apply_rotary_pos_emb as fused_rotary_pos_emb,
+        )
+
+        return fused_rotary_pos_emb(q, k, cos, sin, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
@@ -89,9 +91,7 @@ def forward(
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        self.scale_value = 1.0 / math.sqrt(self.head_dim)
         if output_attentions:
-            # Improve this warning with e.g. `model.config.attn_implementation = "manual"`
            # once this is implemented.
             logger.warning_once(
                 "InternLM2Model uses InternLM2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` "
@@ -144,22 +144,29 @@ def forward(
         if attention_mask is not None:
             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
 
-        attn_output = torch_npu.npu_fusion_attention(
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        if query_states.device.type == "cuda" and causal_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of
+        # an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph
+        # options. An inline conditional prevents dynamic shapes from compiling.
+        is_causal = bool(causal_mask is None and q_len > 1)
+
+        if is_torch_npu_available():
+            is_causal = True
+            causal_mask = None
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(  # pylint: disable=E1102
             query_states,
             key_states,
             value_states,
-            self.num_heads,
-            input_layout="BNSD",
-            pse=None,
-            atten_mask=causal_mask.bool(),
-            scale=self.scale_value,
-            # pre_tockens and next_tockens are used for sparse computation,
-            # `2147483647` is the default value for the npu_fusion_attention interface.
-            pre_tockens=2147483647,
-            next_tockens=2147483647,
-            keep_prob=1,
-            inner_precise=0,
-        )[0]
+            attn_mask=causal_mask,
+            dropout_p=0.0,
+            is_causal=is_causal,
+        )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
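For reference, a minimal self-contained sketch of the dispatch pattern the rewritten InternLM2 forward adopts. The function name `sdpa_dispatch`, the `npu_available` flag, and the tensor shapes are illustrative, not part of the patch:

    import torch

    def sdpa_dispatch(query, key, value, causal_mask=None, npu_available=False):
        # query/key/value are assumed to be [batch, num_heads, seq_len, head_dim] (BNSD),
        # the same layout previously passed to npu_fusion_attention.
        q_len = query.shape[-2]

        # Use SDPA's built-in causal handling only when no explicit mask is supplied and the
        # query covers more than one position (single-token decode steps need no mask).
        is_causal = causal_mask is None and q_len > 1

        # The patched path always takes the built-in causal handling on Ascend NPU and drops
        # the explicit mask, mirroring the `is_torch_npu_available()` branch above.
        if npu_available:
            is_causal = True
            causal_mask = None

        # SDPA applies the default 1/sqrt(head_dim) scaling, so the old explicit
        # `scale=self.scale_value` argument is no longer needed.
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=causal_mask, dropout_p=0.0, is_causal=is_causal
        )

    q = k = v = torch.randn(1, 8, 16, 64)
    print(sdpa_dispatch(q, k, v).shape)  # torch.Size([1, 8, 16, 64])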
""" + if is_torch_npu_available(): + from openmind.integrations.transformers.npu_fused_ops.rope.rope import ( + apply_rotary_pos_emb as fused_rotary_pos_emb, + ) + + return fused_rotary_pos_emb(q, k, cos, sin, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim) cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) @@ -91,7 +95,22 @@ def forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - self.scale_value = 1.0 / math.sqrt(self.head_dim) + if output_attentions: + logger.warning_once( + "InternLM3Model is using InternLM3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -127,30 +146,28 @@ def forward( if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - attn_output = torch_npu.npu_fusion_attention( + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, - self.num_heads, - input_layout="BNSD", - pse=None, - atten_mask=causal_mask.bool(), - scale=self.scale_value, - # pre_tockens and next_tockens are used for sparse computation, - # `2147483647` is the default value for the npu_fusion_attention interface. 
-            pre_tockens=2147483647,
-            next_tockens=2147483647,
-            keep_prob=1,
-            inner_precise=0,
-        )[0]
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-
-        attn_output = attn_output.reshape(bsz, q_len, -1)
+        attn_output = attn_output.view(bsz, q_len, -1)
 
         attn_output = self.o_proj(attn_output)
 
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
+        return attn_output, None, past_key_value
diff --git a/src/openmind/integrations/transformers/npu_fused_ops/attenions/llama.py b/src/openmind/integrations/transformers/npu_fused_ops/attentions/llama.py
similarity index 100%
rename from src/openmind/integrations/transformers/npu_fused_ops/attenions/llama.py
rename to src/openmind/integrations/transformers/npu_fused_ops/attentions/llama.py
diff --git a/src/openmind/integrations/transformers/npu_fused_ops/attenions/mistral.py b/src/openmind/integrations/transformers/npu_fused_ops/attentions/mistral.py
similarity index 100%
rename from src/openmind/integrations/transformers/npu_fused_ops/attenions/mistral.py
rename to src/openmind/integrations/transformers/npu_fused_ops/attentions/mistral.py
diff --git a/src/openmind/integrations/transformers/npu_fused_ops/attenions/qwen2.py b/src/openmind/integrations/transformers/npu_fused_ops/attentions/qwen2.py
similarity index 100%
rename from src/openmind/integrations/transformers/npu_fused_ops/attenions/qwen2.py
rename to src/openmind/integrations/transformers/npu_fused_ops/attentions/qwen2.py
diff --git a/src/openmind/integrations/transformers/npu_fused_ops/attenions/sdpa_attention.py b/src/openmind/integrations/transformers/npu_fused_ops/attentions/sdpa_attention.py
similarity index 100%
rename from src/openmind/integrations/transformers/npu_fused_ops/attenions/sdpa_attention.py
rename to src/openmind/integrations/transformers/npu_fused_ops/attentions/sdpa_attention.py
diff --git a/src/openmind/integrations/transformers/npu_fused_ops/dynamic_module_utils.py b/src/openmind/integrations/transformers/npu_fused_ops/dynamic_module_utils.py
index 8a23e69..ba99059 100644
--- a/src/openmind/integrations/transformers/npu_fused_ops/dynamic_module_utils.py
+++ b/src/openmind/integrations/transformers/npu_fused_ops/dynamic_module_utils.py
@@ -38,6 +38,7 @@ from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
 from openmind.integrations.transformers.npu_fused_ops.rope import rope
 from openmind.integrations.transformers.npu_fused_ops.swiglu import swiglu
 from openmind.integrations.transformers.npu_fused_ops import kernel
+from openmind.integrations.transformers.npu_fused_ops.attentions import internlm2, internlm3
 
 
 logger = logging.get_logger()
@@ -59,7 +60,8 @@ def register_dynamic_model(model_name: str, /, **kwargs):
         model_name: Autoclass name, such as InternLM2ForCausalLM.
         **kwargs: supported npu fused options, all kwargs can be None.
             kwargs include the follow params:
-            npu_fusion_attention: the adapter module of npu fused attention.
+            sdpa_cls_name: the SDPA attention class name of the module.
+            sdpa_forward: the forward function of npu fused attention.
             rms_norm: the class of npu fused rms norm.
             rope: the function of npu fused rotary position embedding
             swiglu: the class of npu fused SwiGLU.
@@ -74,6 +76,17 @@ register_dynamic_model(
     rms_norm=rms_norm.NpuRMSNorm,
     rope=rope.apply_rotary_pos_emb,
     swiglu=swiglu.NpuIntern2SwiGlu,
+    sdpa_cls_name="InternLM2SdpaAttention",
+    sdpa_forward=internlm2.forward,
+)
+
+register_dynamic_model(
+    "InternLM3ForCausalLM",
+    rms_norm=rms_norm.NpuRMSNorm,
+    rope=rope.apply_rotary_pos_emb,
+    swiglu=swiglu.NpuSwiGlu,
+    sdpa_cls_name="InternLM3SdpaAttention",
+    sdpa_forward=internlm3.forward,
 )
 
 
@@ -124,6 +137,17 @@ def _dynamic_patch_flash_attention(model_name: str, module: ModuleType, **kwargs
         kernel._patch_sdpa_forward()
         if config is not None:
             setattr(config, "_attn_implementation", "sdpa")
+            setattr(config, "attn_implementation", "sdpa")
+        if not DYNAMIC_MODELS[model_name].get("sdpa_cls_name", None) or not DYNAMIC_MODELS[model_name].get(
+            "sdpa_forward", None
+        ):
+            logger.warning_rank0(
+                "`sdpa_cls_name` or `sdpa_forward` was not given when calling `register_dynamic_model`, "
+                "so this model may not work with SDPA attention; set them manually or use the eager mode."
+            )
+            return
+        sdpa_attention_cls = getattr(module, DYNAMIC_MODELS[model_name].get("sdpa_cls_name"))
+        setattr(sdpa_attention_cls, "forward", DYNAMIC_MODELS[model_name].get("sdpa_forward"))
 
 
 def _dynamic_patch_rms_norm(model_name: str, module: ModuleType):
diff --git a/src/openmind/integrations/transformers/npu_fused_ops/kernel.py b/src/openmind/integrations/transformers/npu_fused_ops/kernel.py
index ef28ec1..0d9a281 100644
--- a/src/openmind/integrations/transformers/npu_fused_ops/kernel.py
+++ b/src/openmind/integrations/transformers/npu_fused_ops/kernel.py
@@ -15,11 +15,12 @@ import re
 from types import ModuleType
 
 from transformers.models.qwen2 import modeling_qwen2
+from transformers.models.qwen3 import modeling_qwen3
 from transformers.models.llama import modeling_llama
 from transformers.models.mistral import modeling_mistral
 
 from openmind.utils.version import check_package_version
-from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attenions
+from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attentions
 from openmind.integrations.transformers.npu_fused_ops import dynamic_module_utils
 
 
@@ -40,9 +41,9 @@ def _patch_sdpa_forward():
     from transformers.integrations import sdpa_attention
     from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface
 
-    sdpa_attention.sdpa_attention_forward = attenions.sdpa_attention.sdpa_attention_forward
-    AttentionInterface._global_mapping["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward
-    ALL_ATTENTION_FUNCTIONS["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward
+    sdpa_attention.sdpa_attention_forward = attentions.sdpa_attention.sdpa_attention_forward
+    AttentionInterface._global_mapping["sdpa"] = attentions.sdpa_attention.sdpa_attention_forward
+    ALL_ATTENTION_FUNCTIONS["sdpa"] = attentions.sdpa_attention.sdpa_attention_forward
 
 
 def _builtin_patch_flash_attention(config=None):
@@ -105,6 +106,10 @@ def apply_fused_kernel_qwen2(**kwargs):
     _apply_fused_kernel_base(modeling_qwen2, **kwargs)
 
 
+def apply_fused_kernel_qwen3(**kwargs):
+    _apply_fused_kernel_base(modeling_qwen3, **kwargs)
+
+
 def apply_fused_kernel_llama(**kwargs):
     _apply_fused_kernel_base(modeling_llama, **kwargs)
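For context, an illustrative registration call for a hypothetical remote-code model, combining the registry keys added above. The model name and attention class name are placeholders, and only helpers already imported in dynamic_module_utils.py are used:

    from openmind.integrations.transformers.npu_fused_ops.attentions import internlm2
    from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils import register_dynamic_model
    from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
    from openmind.integrations.transformers.npu_fused_ops.rope import rope
    from openmind.integrations.transformers.npu_fused_ops.swiglu import swiglu

    register_dynamic_model(
        "MyRemoteCodeForCausalLM",        # placeholder auto-class name of a trust_remote_code model
        rms_norm=rms_norm.NpuRMSNorm,     # fused RMSNorm replacement class
        rope=rope.apply_rotary_pos_emb,   # fused rotary position embedding function
        swiglu=swiglu.NpuSwiGlu,          # fused SwiGLU replacement class
        sdpa_cls_name="MySdpaAttention",  # placeholder SDPA attention class defined in the remote module
        sdpa_forward=internlm2.forward,   # forward that _dynamic_patch_flash_attention binds onto that class
    )

If `sdpa_cls_name` or `sdpa_forward` is left out, the patched `_dynamic_patch_flash_attention` now logs a warning and skips the forward replacement instead of failing.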
diff --git a/src/openmind/integrations/transformers/npu_fused_ops/sdk.py b/src/openmind/integrations/transformers/npu_fused_ops/sdk.py
index e7a188e..b79299d 100644
--- a/src/openmind/integrations/transformers/npu_fused_ops/sdk.py
+++ b/src/openmind/integrations/transformers/npu_fused_ops/sdk.py
@@ -134,20 +134,28 @@ def apply_fused_kernel_to_qwen2(**kwargs):
     """
     Apply npu fused operators for Qwen2 series models, when call this function, all supported
     fusion operators will be enabled by default. You can set the following parameters to disable the
     specified fused operator:
-    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
-    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
+    `use_npu_fusion_attention: bool`, default is True, set it to `False` to disable npu fusion attention.
+    `use_fused_rms_norm: bool`, default is True, set it to `False` to disable npu RMSNorm.
+    `use_fused_rope: bool`, default is True, set it to `False` to disable npu fused RoPE.
+    `use_fused_swiglu: bool`, default is True, set it to `False` to disable npu fused SwiGLU.
     """
     _apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen2, **kwargs)
     _apply_log(model_type="qwen2", **kwargs)
 
 
+def apply_fused_kernel_to_qwen3(**kwargs):
+    """
+    Apply npu fused operators for Qwen3 series models; when this function is called, all supported
+    fusion operators will be enabled by default.
+    """
+    _apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen3, **kwargs)
+    _apply_log(model_type="qwen3", **kwargs)
+
+
 def apply_fused_kernel_to_internlm2(**kwargs):
     """
     Apply npu fused operators for Internlm2 series models, when call this function, all supported
-    fusion operators will be enabled by default. You can set the following parameters to disable the
-    specified fused operator:
-    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
-    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
+    fusion operators will be enabled by default.
     """
     _apply_fused_kernel_generic(kernel.apply_fused_kernel_internlm2, **kwargs)
     _apply_log(model_type="internlm2", **kwargs)
@@ -156,10 +164,7 @@ def apply_fused_kernel_to_internlm3(**kwargs):
     """
     Apply npu fused operators for Internlm2 series models, when call this function, all supported
-    fusion operators will be enabled by default. You can set the following parameters to disable the
-    specified fused operator:
-    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
-    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
+    fusion operators will be enabled by default.
     """
     _apply_fused_kernel_generic(kernel.apply_fused_kernel_internlm3, **kwargs)
     _apply_log(model_type="internlm3", **kwargs)
@@ -168,10 +173,7 @@ def apply_fused_kernel_to_llama(**kwargs):
     """
     Apply npu fused operators for Llama series models, when call this function, all supported
-    fusion operators will be enabled by default. You can set the following parameters to disable the
-    specified fused operator:
-    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
-    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
+    fusion operators will be enabled by default.
""" _apply_fused_kernel_generic(kernel.apply_fused_kernel_llama, **kwargs) _apply_log(model_type="llama", **kwargs) @@ -180,10 +182,7 @@ def apply_fused_kernel_to_llama(**kwargs): def apply_fused_kernel_to_mistral(**kwargs): """ Apply npu fused operators for Mistral series models, when call this function, all supported - fusion operators will be enabled by default. You can set the following parameters to disable the - specified fused operator: - `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention. - `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm. + fusion operators will be enabled by default. """ _apply_fused_kernel_generic(kernel.apply_fused_kernel_mistral, **kwargs) _apply_log(model_type="mistral", **kwargs) @@ -191,6 +190,7 @@ def apply_fused_kernel_to_mistral(**kwargs): SUPPORTED_FUSED_MODELS = { "Qwen2ForCausalLM": apply_fused_kernel_to_qwen2, + "Qwen3ForCausalLM": apply_fused_kernel_to_qwen3, "LlamaForCausalLM": apply_fused_kernel_to_llama, "MistralForCausalLM": apply_fused_kernel_to_mistral, "InternLM2ForCausalLM": apply_fused_kernel_to_internlm2, diff --git a/tests/unit/integrations/transformers/npu_fused_ops/attenions/__init__.py b/tests/unit/integrations/transformers/npu_fused_ops/attentions/__init__.py similarity index 100% rename from tests/unit/integrations/transformers/npu_fused_ops/attenions/__init__.py rename to tests/unit/integrations/transformers/npu_fused_ops/attentions/__init__.py diff --git a/tests/unit/integrations/transformers/npu_fused_ops/test_dynamic_module_utils.py b/tests/unit/integrations/transformers/npu_fused_ops/test_dynamic_module_utils.py index 726b286..9eb0fb9 100644 --- a/tests/unit/integrations/transformers/npu_fused_ops/test_dynamic_module_utils.py +++ b/tests/unit/integrations/transformers/npu_fused_ops/test_dynamic_module_utils.py @@ -31,7 +31,8 @@ from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils impor patch_dynamic_fused_ops, ) from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm -from openmind.integrations.transformers.npu_fused_ops import attenions +from openmind.integrations.transformers.npu_fused_ops import attentions +from openmind.integrations.transformers.npu_fused_ops.attentions import internlm2 class TestDynamicModelsRegistration(unittest.TestCase): @@ -78,25 +79,27 @@ class TestDynamicPatching(unittest.TestCase): @patch("torch.__version__", "2.1.0") def test_attention_patching(self, _, __): - class MockAttentionBase: - def forward(self): - pass - class Config: _attn_implementation = "eager" - mock_module = ModuleType("mock_module") - mock_module.ATTENTION_CLASSES = {"eager": MockAttentionBase} + class MockInternLM2SdpaAttention: + pass + mock_module = ModuleType("mock_module") + setattr(mock_module, "InternLM2SdpaAttention", MockInternLM2SdpaAttention) mock_config = Config() _dynamic_patch_flash_attention("InternLM2ForCausalLM", mock_module, config=mock_config) self.assertEqual(mock_config._attn_implementation, "sdpa") + self.assertEqual(mock_module.InternLM2SdpaAttention.forward, internlm2.forward) @patch("torch.__version__", "2.6.0") def test_torch_260_sets_sdpa(self): model_name = "test_model_260" - DYNAMIC_MODELS[model_name] = {} + DYNAMIC_MODELS[model_name] = { + "sdpa_cls_name": "InternLM2SdpaAttention", + "sdpa_forward": internlm2.forward, + } class Config: _attn_implementation = "eager" @@ -107,7 +110,7 @@ class TestDynamicPatching(unittest.TestCase): 
         self.assertEqual(mock_config._attn_implementation, "sdpa")
         self.assertEqual(
             transformers.integrations.sdpa_attention.sdpa_attention_forward,
-            attenions.sdpa_attention.sdpa_attention_forward,
+            attentions.sdpa_attention.sdpa_attention_forward,
         )
 
     @patch("importlib.util.spec_from_file_location")
diff --git a/tests/unit/integrations/transformers/npu_fused_ops/test_kernel.py b/tests/unit/integrations/transformers/npu_fused_ops/test_kernel.py
index 6d8dc1a..4f9daa1 100644
--- a/tests/unit/integrations/transformers/npu_fused_ops/test_kernel.py
+++ b/tests/unit/integrations/transformers/npu_fused_ops/test_kernel.py
@@ -20,7 +20,7 @@ import transformers
 from transformers.integrations import sdpa_attention
 
 from openmind.integrations.transformers.npu_fused_ops import kernel
-from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attenions
+from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attentions
 
 
 class TestFusedKernel(unittest.TestCase):
@@ -61,7 +61,7 @@ class TestFusedKernel(unittest.TestCase):
         self.assertEqual(mock_config._attn_implementation, "sdpa")
         self.assertEqual(
             transformers.integrations.sdpa_attention.sdpa_attention_forward,
-            attenions.sdpa_attention.sdpa_attention_forward,
+            attentions.sdpa_attention.sdpa_attention_forward,
         )
 
     def test_builtin_patch_rmsnorm(self):
@@ -121,7 +121,7 @@ class TestFusedKernel(unittest.TestCase):
         module_path = self.mock_cache / "test_module.py"
         module_path.write_text(
             "class InternLM2ForCausalLM:\n    pass\nclass InternLM2RMSNorm:\n    pass\nclass InternLM2MLP:\n    pass"
-            "\ndef apply_rotary_pos_emb():\n    pass\n"
+            "\ndef apply_rotary_pos_emb():\n    pass\nclass InternLM2SdpaAttention:\n    pass\n"
         )
         original_utils = transformers.dynamic_module_utils.get_class_in_module
         kernel.apply_fused_kernel_internlm2(**kwargs)
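As a usage note, one way the new Qwen3 entry point is expected to be exercised; the checkpoint name is illustrative, and an Ascend NPU environment with openmind and torch_npu installed is assumed:

    from transformers import AutoModelForCausalLM

    from openmind.integrations.transformers.npu_fused_ops import sdk

    # Enable the fused operators for the built-in Qwen3 modeling module, typically before the
    # model is instantiated. The keyword switches documented for the other model types
    # (e.g. use_fused_rms_norm=False) are expected to apply here as well.
    sdk.apply_fused_kernel_to_qwen3()

    # The same entry point is reachable through the registry extended by this patch.
    assert sdk.SUPPORTED_FUSED_MODELS["Qwen3ForCausalLM"] is sdk.apply_fused_kernel_to_qwen3

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B")  # illustrative checkpoint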