!248 Fix the InternLM-series models' SDPA interface not being adapted correctly, and add fused-operator support for Qwen3 models (multimodal variants not yet included)

Merge pull request !248 from 幽若/master-0625
Committed by i-robot on 2025-06-28 17:22:01 +00:00
parent b8186e58a8
commit 34b63d55f7
13 changed files with 139 additions and 83 deletions
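For orientation, a minimal usage sketch of the new Qwen3 fused-operator path added by this merge. Only the kernel module import and the apply_fused_kernel_qwen3 helper come from the diffs below; the execution environment (an Ascend NPU host with torch_npu and a transformers release that ships Qwen3) is an assumption, not something this page confirms.

# Hedged sketch: patch the stock Qwen3 modeling code with the npu fused operators wired up in this PR.
# Assumes an Ascend NPU environment with torch_npu installed; run this before instantiating the model.
from openmind.integrations.transformers.npu_fused_ops import kernel

# apply_fused_kernel_qwen3 (added below) patches transformers.models.qwen3.modeling_qwen3 in place.
kernel.apply_fused_kernel_qwen3()

The SUPPORTED_FUSED_MODELS change further down suggests Qwen3ForCausalLM is also picked up automatically by the higher-level apply path; multimodal Qwen3 variants are explicitly out of scope per the commit title.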

View File

@@ -11,4 +11,4 @@
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
from . import sdpa_attention
from . import sdpa_attention, internlm2, internlm3

View File

@@ -16,7 +16,8 @@
# limitations under the License.
"""PyTorch InternLM2 model."""
import math
# Note: once Ascend NPU provides an equivalent implementation of the SDPA backend, this adapter code will be removed
from typing import Optional, Tuple
import torch
@@ -25,11 +26,6 @@ from transformers.cache_utils import Cache
from openmind.utils import logging, is_torch_npu_available
if is_torch_npu_available():
import torch_npu
else:
pass
logger = logging.get_logger()
@@ -60,6 +56,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): #
Returns:
`tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
"""
if is_torch_npu_available():
from openmind.integrations.transformers.npu_fused_ops.rope.rope import (
apply_rotary_pos_emb as fused_rotary_pos_emb,
)
return fused_rotary_pos_emb(q, k, cos, sin, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
@@ -89,9 +91,7 @@ def forward(
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
self.scale_value = 1.0 / math.sqrt(self.head_dim)
if output_attentions:
# Improve this warning with e.g. `model.config.attn_implementation = "manual"`
# once this is implemented.
logger.warning_once(
"InternLM2Model uses InternLM2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` "
@@ -144,22 +144,29 @@ def forward(
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
attn_output = torch_npu.npu_fusion_attention(
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of
# an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph
# options. An inline conditional prevents dynamic shapes from compiling.
is_causal = bool(causal_mask is None and q_len > 1)
if is_torch_npu_available():
is_causal = True
causal_mask = None
attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102
query_states,
key_states,
value_states,
self.num_heads,
input_layout="BNSD",
pse=None,
atten_mask=causal_mask.bool(),
scale=self.scale_value,
# pre_tockens and next_tockens are used for sparse computation,
# `2147483647` is the default value for the npu_fusion_attention interface.
pre_tockens=2147483647,
next_tockens=2147483647,
keep_prob=1,
inner_precise=0,
)[0]
attn_mask=causal_mask,
dropout_p=0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
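The hunk above swaps the torch_npu.npu_fusion_attention call for torch's built-in scaled_dot_product_attention. A condensed, self-contained sketch of that dispatch pattern follows; the helper name, tensor shapes, and the npu_available flag are illustrative only and not part of the adapter.

# Hedged sketch of the adapter's dispatch pattern: compute is_causal as a plain bool (friendlier to
# torch.compile dynamic shapes than an inline conditional inside the call), and on Ascend NPU rely on
# SDPA's causal flag instead of an explicit attention mask.
import torch
import torch.nn.functional as F

def sdpa_dispatch(query, key, value, causal_mask=None, npu_available=False):
    # query/key/value are laid out as (batch, num_heads, seq_len, head_dim), matching the adapter.
    is_causal = bool(causal_mask is None and query.shape[-2] > 1)
    if npu_available:
        # The adapter drops the explicit mask on NPU and lets SDPA apply causal masking itself.
        is_causal = True
        causal_mask = None
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=causal_mask, dropout_p=0.0, is_causal=is_causal
    )

q = k = v = torch.randn(1, 8, 16, 64)  # illustrative shapes
out = sdpa_dispatch(q, k, v)  # off-NPU this takes the plain causal SDPA path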

View File

@@ -1,6 +1,7 @@
# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
# 2025.01.16 - Adapt to openmind.
# Huawei Technologies Co., Ltd.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# This code is based on transformers/src/transformers/models/llama/modeling_llama.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,7 +17,9 @@
# limitations under the License.
"""PyTorch InternLM3 model."""
import math
# Note: once Ascend NPU provides an equivalent implementation of the SDPA backend, this adapter code will be removed
from typing import Optional, Tuple
import torch
@@ -24,11 +27,6 @@ from transformers.cache_utils import Cache
from openmind.utils import logging, is_torch_npu_available
if is_torch_npu_available():
import torch_npu
else:
pass
logger = logging.get_logger()
@@ -60,6 +58,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): #
Returns:
`tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
"""
if is_torch_npu_available():
from openmind.integrations.transformers.npu_fused_ops.rope.rope import (
apply_rotary_pos_emb as fused_rotary_pos_emb,
)
return fused_rotary_pos_emb(q, k, cos, sin, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
@@ -91,7 +95,22 @@ def forward(
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
self.scale_value = 1.0 / math.sqrt(self.head_dim)
if output_attentions:
logger.warning_once(
"InternLM3Model is using InternLM3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
@@ -127,30 +146,28 @@ def forward(
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
attn_output = torch_npu.npu_fusion_attention(
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
self.num_heads,
input_layout="BNSD",
pse=None,
atten_mask=causal_mask.bool(),
scale=self.scale_value,
# pre_tockens and next_tockens are used for sparse computation,
# `2147483647` is the default value for the npu_fusion_attention interface.
pre_tockens=2147483647,
next_tockens=2147483647,
keep_prob=1,
inner_precise=0,
)[0]
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
return attn_output, None, past_key_value

View File

@@ -38,6 +38,7 @@ from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
from openmind.integrations.transformers.npu_fused_ops.rope import rope
from openmind.integrations.transformers.npu_fused_ops.swiglu import swiglu
from openmind.integrations.transformers.npu_fused_ops import kernel
from openmind.integrations.transformers.npu_fused_ops.attentions import internlm2, internlm3
logger = logging.get_logger()
@@ -59,7 +60,8 @@ def register_dynamic_model(model_name: str, /, **kwargs):
model_name: Autoclass name, such as InternLM2ForCausalLM.
**kwargs: supported npu fused options, all kwargs can be None.
kwargs include the following params:
npu_fusion_attention: the adapter module of npu fused attention.
sdpa_cls_name: the name of the module's SDPA attention class.
sdpa_forward: the forward function of npu fused attention.
rms_norm: the class of npu fused rms norm.
rope: the function of npu fused rotary position embedding
swiglu: the class of npu fused SwiGLU.
@@ -74,6 +76,17 @@ register_dynamic_model(
rms_norm=rms_norm.NpuRMSNorm,
rope=rope.apply_rotary_pos_emb,
swiglu=swiglu.NpuIntern2SwiGlu,
sdpa_cls_name="InternLM2SdpaAttention",
sdpa_forward=internlm2.forward,
)
register_dynamic_model(
"InternLM3ForCausalLM",
rms_norm=rms_norm.NpuRMSNorm,
rope=rope.apply_rotary_pos_emb,
swiglu=swiglu.NpuSwiGlu,
sdpa_cls_name="InternLM3SdpaAttention",
sdpa_forward=internlm3.forward,
)
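To illustrate the two new keyword arguments documented above, a hedged registration sketch for a hypothetical trust-remote-code model follows; MyModelForCausalLM and MyModelSdpaAttention are made-up names, and reusing the InternLM2 forward adapter is purely illustrative.

# Hedged sketch: registering fused-operator adapters for a hypothetical dynamic model.
# The dynamic_module_utils import path is inferred from the test imports elsewhere in this PR.
from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
from openmind.integrations.transformers.npu_fused_ops.rope import rope
from openmind.integrations.transformers.npu_fused_ops.attentions import internlm2
from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils import register_dynamic_model

register_dynamic_model(
    "MyModelForCausalLM",                  # hypothetical Autoclass name
    rms_norm=rms_norm.NpuRMSNorm,
    rope=rope.apply_rotary_pos_emb,
    sdpa_cls_name="MyModelSdpaAttention",  # hypothetical SDPA class inside the remote module
    sdpa_forward=internlm2.forward,        # illustrative; a real model would ship its own adapter
)
# If sdpa_cls_name or sdpa_forward is omitted, _dynamic_patch_flash_attention (changed below)
# now logs a warning and skips the SDPA patch instead of failing.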
@@ -124,6 +137,17 @@ def _dynamic_patch_flash_attention(model_name: str, module: ModuleType, **kwargs
kernel._patch_sdpa_forward()
if config is not None:
setattr(config, "_attn_implementation", "sdpa")
setattr(config, "attn_implementation", "sdpa")
if not DYNAMIC_MODELS[model_name].get("sdpa_cls_name", None) or not DYNAMIC_MODELS[model_name].get(
"sdpa_forward", None
):
logger.warning_rank0(
"When execute register_dynamic_model function, `sdpa_cls_name` or `sdpa_forward` not given, "
"it may not work with SDPA Attention, we suggest you to set it manually or use the eager mode."
)
return
sdpa_attention_cls = getattr(module, DYNAMIC_MODELS[model_name].get("sdpa_cls_name"))
setattr(sdpa_attention_cls, "forward", DYNAMIC_MODELS[model_name].get("sdpa_forward"))
def _dynamic_patch_rms_norm(model_name: str, module: ModuleType):

View File

@@ -15,11 +15,12 @@ import re
from types import ModuleType
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.qwen3 import modeling_qwen3
from transformers.models.llama import modeling_llama
from transformers.models.mistral import modeling_mistral
from openmind.utils.version import check_package_version
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attenions
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attentions
from openmind.integrations.transformers.npu_fused_ops import dynamic_module_utils
@@ -40,9 +41,9 @@ def _patch_sdpa_forward():
from transformers.integrations import sdpa_attention
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface
sdpa_attention.sdpa_attention_forward = attenions.sdpa_attention.sdpa_attention_forward
AttentionInterface._global_mapping["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward
ALL_ATTENTION_FUNCTIONS["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward
sdpa_attention.sdpa_attention_forward = attentions.sdpa_attention.sdpa_attention_forward
AttentionInterface._global_mapping["sdpa"] = attentions.sdpa_attention.sdpa_attention_forward
ALL_ATTENTION_FUNCTIONS["sdpa"] = attentions.sdpa_attention.sdpa_attention_forward
def _builtin_patch_flash_attention(config=None):
@@ -105,6 +106,10 @@ def apply_fused_kernel_qwen2(**kwargs):
_apply_fused_kernel_base(modeling_qwen2, **kwargs)
def apply_fused_kernel_qwen3(**kwargs):
_apply_fused_kernel_base(modeling_qwen3, **kwargs)
def apply_fused_kernel_llama(**kwargs):
_apply_fused_kernel_base(modeling_llama, **kwargs)

View File

@@ -134,20 +134,28 @@ def apply_fused_kernel_to_qwen2(**kwargs):
Apply npu fused operators for Qwen2 series models. When this function is called, all supported
fusion operators will be enabled by default. You can set the following parameters to disable the
specified fused operator:
`use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
`use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
`use_npu_fusion_attention: bool`, default is True, set it to `False` to disable npu fusion attention.
`use_fused_rms_norm: bool`, default is True, set it to `False` to disable npu RMSNorm.
`use_fused_rope: bool`, default is True, set it to `False` to disable npu fused RoPE.
`use_fused_swiglu: bool`, default is True, set it to `False` to disable npu fused SwiGLU.
"""
_apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen2, **kwargs)
_apply_log(model_type="qwen2", **kwargs)
def apply_fused_kernel_to_qwen3(**kwargs):
"""
Apply npu fused operators for Qwen3 series models. When this function is called, all supported
fusion operators will be enabled by default.
"""
_apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen3, **kwargs)
_apply_log(model_type="qwen3", **kwargs)
def apply_fused_kernel_to_internlm2(**kwargs):
"""
Apply npu fused operators for Internlm2 series models. When this function is called, all supported
fusion operators will be enabled by default. You can set the following parameters to disable the
specified fused operator:
`use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
`use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
fusion operators will be enabled by default.
"""
_apply_fused_kernel_generic(kernel.apply_fused_kernel_internlm2, **kwargs)
_apply_log(model_type="internlm2", **kwargs)
@@ -156,10 +164,7 @@ def apply_fused_kernel_to_internlm2(**kwargs):
def apply_fused_kernel_to_internlm3(**kwargs):
"""
Apply npu fused operators for Internlm3 series models. When this function is called, all supported
fusion operators will be enabled by default. You can set the following parameters to disable the
specified fused operator:
`use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
`use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
fusion operators will be enabled by default.
"""
_apply_fused_kernel_generic(kernel.apply_fused_kernel_internlm3, **kwargs)
_apply_log(model_type="internlm3", **kwargs)
@@ -168,10 +173,7 @@ def apply_fused_kernel_to_internlm3(**kwargs):
def apply_fused_kernel_to_llama(**kwargs):
"""
Apply npu fused operators for Llama series models. When this function is called, all supported
fusion operators will be enabled by default. You can set the following parameters to disable the
specified fused operator:
`use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
`use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
fusion operators will be enabled by default.
"""
_apply_fused_kernel_generic(kernel.apply_fused_kernel_llama, **kwargs)
_apply_log(model_type="llama", **kwargs)
@@ -180,10 +182,7 @@ def apply_fused_kernel_to_llama(**kwargs):
def apply_fused_kernel_to_mistral(**kwargs):
"""
Apply npu fused operators for Mistral series models. When this function is called, all supported
fusion operators will be enabled by default. You can set the following parameters to disable the
specified fused operator:
`use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
`use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
fusion operators will be enabled by default.
"""
_apply_fused_kernel_generic(kernel.apply_fused_kernel_mistral, **kwargs)
_apply_log(model_type="mistral", **kwargs)
@@ -191,6 +190,7 @@ def apply_fused_kernel_to_mistral(**kwargs):
SUPPORTED_FUSED_MODELS = {
"Qwen2ForCausalLM": apply_fused_kernel_to_qwen2,
"Qwen3ForCausalLM": apply_fused_kernel_to_qwen3,
"LlamaForCausalLM": apply_fused_kernel_to_llama,
"MistralForCausalLM": apply_fused_kernel_to_mistral,
"InternLM2ForCausalLM": apply_fused_kernel_to_internlm2,

View File

@@ -31,7 +31,8 @@ from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils impor
patch_dynamic_fused_ops,
)
from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
from openmind.integrations.transformers.npu_fused_ops import attenions
from openmind.integrations.transformers.npu_fused_ops import attentions
from openmind.integrations.transformers.npu_fused_ops.attentions import internlm2
class TestDynamicModelsRegistration(unittest.TestCase):
@@ -78,25 +79,27 @@ class TestDynamicPatching(unittest.TestCase):
@patch("torch.__version__", "2.1.0")
def test_attention_patching(self, _, __):
class MockAttentionBase:
def forward(self):
pass
class Config:
_attn_implementation = "eager"
mock_module = ModuleType("mock_module")
mock_module.ATTENTION_CLASSES = {"eager": MockAttentionBase}
class MockInternLM2SdpaAttention:
pass
mock_module = ModuleType("mock_module")
setattr(mock_module, "InternLM2SdpaAttention", MockInternLM2SdpaAttention)
mock_config = Config()
_dynamic_patch_flash_attention("InternLM2ForCausalLM", mock_module, config=mock_config)
self.assertEqual(mock_config._attn_implementation, "sdpa")
self.assertEqual(mock_module.InternLM2SdpaAttention.forward, internlm2.forward)
@patch("torch.__version__", "2.6.0")
def test_torch_260_sets_sdpa(self):
model_name = "test_model_260"
DYNAMIC_MODELS[model_name] = {}
DYNAMIC_MODELS[model_name] = {
"sdpa_cls_name": "InternLM2SdpaAttention",
"sdpa_forward": internlm2.forward,
}
class Config:
_attn_implementation = "eager"
@@ -107,7 +110,7 @@ class TestDynamicPatching(unittest.TestCase):
self.assertEqual(mock_config._attn_implementation, "sdpa")
self.assertEqual(
transformers.integrations.sdpa_attention.sdpa_attention_forward,
attenions.sdpa_attention.sdpa_attention_forward,
attentions.sdpa_attention.sdpa_attention_forward,
)
@patch("importlib.util.spec_from_file_location")

View File

@@ -20,7 +20,7 @@ import transformers
from transformers.integrations import sdpa_attention
from openmind.integrations.transformers.npu_fused_ops import kernel
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attenions
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attentions
class TestFusedKernel(unittest.TestCase):
@@ -61,7 +61,7 @@ class TestFusedKernel(unittest.TestCase):
self.assertEqual(mock_config._attn_implementation, "sdpa")
self.assertEqual(
transformers.integrations.sdpa_attention.sdpa_attention_forward,
attenions.sdpa_attention.sdpa_attention_forward,
attentions.sdpa_attention.sdpa_attention_forward,
)
def test_builtin_patch_rmsnorm(self):
@@ -121,7 +121,7 @@ class TestFusedKernel(unittest.TestCase):
module_path = self.mock_cache / "test_module.py"
module_path.write_text(
"class InternLM2ForCausalLM:\n pass\nclass InternLM2RMSNorm:\n pass\nclass InternLM2MLP:\n pass"
"\ndef apply_rotary_pos_emb():\n pass\n"
"\ndef apply_rotary_pos_emb():\n pass\nclass InternLM2SdpaAttention:\n pass\n"
)
original_utils = transformers.dynamic_module_utils.get_class_in_module
kernel.apply_fused_kernel_internlm2(**kwargs)