!233 Fix the SDPA performance degradation issue

Merge pull request !233 from 幽若/master-0610
2025-06-12 08:14:03 +00:00
committed by i-robot
parent db0363fbba
commit 6f12fb0e8a
8 changed files with 133 additions and 13 deletions

View File

@ -27,6 +27,8 @@ $$
Users can find the detailed documentation of this fused operator [here](https://www.hiascend.com/document/detail/zh/Pytorch/60RC2/apiref/apilist/ptaoplist_000142.html). In fixed-shape scenarios it can deliver a substantial performance improvement.
For the FA fused operator, openmind currently invokes it uniformly through torch's native SDPA interface. For adapted models, when the feature is enabled, SDPA dispatches to the FA fused operator; when it is not enabled, the eager-mode implementation from transformers is used. For models that have not been adapted, the default behavior applies: the current version goes through the SDPA interface by default, but the SDPA backend is a composition of small operators, so performance is not guaranteed.
## RMSNorm
The RmsNorm operator is a normalization operation commonly used in large models. Compared with the LayerNorm operator, it removes the mean-subtraction step. Its formula is:
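For reference, a standard formulation of RMSNorm (symbols follow the common convention of a learnable scale $\gamma$ and a small $\epsilon$ for numerical stability; the notation in the surrounding document may differ slightly):

$$
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}} \cdot \gamma_i
$$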
@ -135,3 +137,5 @@ print(output)
#
# 3. Get enough sleep: Sleep is essential for good health. Aim for 7-9 hours of sleep each night. Establish a regular sleep schedule and create a relaxing bedtime routine to help you fall asleep more easily. Avoid using electronic devices before bed, as the blue light emitted by screens can interfere with your sleep.
```
Since transformers defaults to SDPA, the SDPA interface is called externally whether or not `apply_fused_kernel` is enabled. However, once it is enabled, openmind adapts the transformers SDPA attention so that the SDPA backend dispatches to the NPU FA fused operator; without the adaptation it falls back to a composition of small operators.
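The difference between the two backends can be sketched with plain torch calls (illustrative shapes only; this mirrors the patched forward added later in this diff, where the NPU path drops the explicit mask and relies on `is_causal=True` so the call can dispatch to the fused FA kernel, while the unadapted path keeps the mask and runs the generic small-operator composition):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: (batch, num_heads, seq_len, head_dim)
query = key = value = torch.randn(1, 8, 128, 64)
causal_mask = torch.tril(torch.ones(128, 128, dtype=torch.bool))

# Adapted path (what the patched sdpa_attention_forward does on NPU):
# no explicit mask, causality expressed via is_causal, eligible for the fused FA kernel.
fused_style = F.scaled_dot_product_attention(query, key, value, attn_mask=None, is_causal=True)

# Unadapted path: an explicit mask keeps SDPA on the generic backend,
# which on NPU runs as a composition of small operators.
generic_style = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask)

print(fused_style.shape, generic_style.shape)  # both torch.Size([1, 8, 128, 64])
```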

View File

@ -193,7 +193,7 @@ generator = pipeline(task="text-to-image",
image = generator(prompt="masterpiece, best quality, Cute dragon creature, pokemon style, night, moonlight, dim lighting",)
```
The corresponding versions of silicondiff_npu and PyTorch are listed below (currently silicondiff_npu only supports PyTorch 2.1.0 and Python 3.10):
| PyTorch version | silicondiff_npu version |
|-----------------|-------------------------|

View File

@ -11,4 +11,4 @@
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
from . import sdpa_attention

View File

@ -0,0 +1,74 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
#
# Copyright 2018- The Hugging Face team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import torch
from transformers.integrations.sdpa_attention import repeat_kv

from openmind.utils.import_utils import is_torch_npu_available


def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    if hasattr(module, "num_key_value_groups"):
        key = repeat_kv(key, module.num_key_value_groups)
        value = repeat_kv(value, module.num_key_value_groups)

    causal_mask = attention_mask
    if not is_torch_npu_available() and attention_mask is not None and causal_mask.ndim == 4:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
    if is_causal is None:
        is_causal = query.shape[2] > 1 and causal_mask is None

    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
    # We convert it to a bool for the SDPA kernel that only accepts bools.
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    if is_torch_npu_available():
        # On NPU, drop the explicit mask and rely on is_causal so SDPA dispatches to the FA fused operator (per the docs change in this PR).
        is_causal = True
        causal_mask = None

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None
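A minimal smoke-test sketch of the forward above (the module path follows the imports elsewhere in this diff; the shapes and the dummy attention module are illustrative assumptions, and on a machine without an NPU the non-NPU branch runs):

```python
import torch

from openmind.integrations.transformers.npu_fused_ops.attenions.sdpa_attention import (
    sdpa_attention_forward,
)


# Dummy attention module; num_key_value_groups triggers the repeat_kv path (grouped-query attention).
class DummyAttention(torch.nn.Module):
    num_key_value_groups = 2


batch, heads, kv_heads, seq, head_dim = 2, 8, 4, 16, 64
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

out, _ = sdpa_attention_forward(DummyAttention(), query, key, value, attention_mask=None)
# The forward transposes back to (batch, seq, heads, head_dim) before returning.
print(out.shape)  # torch.Size([2, 16, 8, 64])
```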

View File

@ -37,6 +37,7 @@ from openmind.utils import logging
from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
from openmind.integrations.transformers.npu_fused_ops.rope import rope
from openmind.integrations.transformers.npu_fused_ops.swiglu import swiglu
from openmind.integrations.transformers.npu_fused_ops import kernel

logger = logging.get_logger()
@ -119,8 +120,10 @@ def _dynamic_patch_flash_attention(model_name: str, module: ModuleType, **kwargs
    if model_name not in DYNAMIC_MODELS:
        return
    else:
        config = kwargs.get("config", None)
        kernel._patch_sdpa_forward()
        if config is not None:
            setattr(config, "_attn_implementation", "sdpa")


def _dynamic_patch_rms_norm(model_name: str, module: ModuleType):

View File

@ -17,8 +17,9 @@ from types import ModuleType
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.llama import modeling_llama
from transformers.models.mistral import modeling_mistral
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface
from transformers.integrations import sdpa_attention

from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attenions
from openmind.integrations.transformers.npu_fused_ops import dynamic_module_utils
@ -29,12 +30,24 @@ class Pattern:
    swiglu: str = "MLP"


def _patch_sdpa_forward():
    """
    The purpose of this patch is to adapt the native SDPA forward function of transformers to the
    SDPA interface of NPU. Without it, calling the SDPA interface still runs in eager mode.
    """
    sdpa_attention.sdpa_attention_forward = attenions.sdpa_attention.sdpa_attention_forward
    AttentionInterface._global_mapping["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward
    ALL_ATTENTION_FUNCTIONS["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward


def _builtin_patch_flash_attention(config=None):
    """
    Patch the FA for transformers built-in models; call this method before the model instantiation is completed.
    When the model has already been instantiated, this method is not effective.
    """
    _patch_sdpa_forward()
    if config is not None:
        setattr(config, "_attn_implementation", "sdpa")


def _builtin_patch_rmsnorm(module: ModuleType, class_name: str):
@ -56,13 +69,16 @@ def _builtin_patch_swiglu(module: ModuleType, class_name: str):
def _apply_fused_kernel_base(module: ModuleType, **kwargs):
    config = kwargs.get("config", None)
    if kwargs.get("use_npu_fusion_attention", False):
        _builtin_patch_flash_attention(config)
    else:
        # If the FA fused option is not open, enforce eager mode.
        # Note: if the config is None, the default value of `_attn_implementation` is `sdpa`, but because of the npu sdpa
        # implementation, it will also run as eager mode; furthermore, the performance of this case is worse than the
        # transformers-native implementation of eager mode.
        if config is not None:
            setattr(config, "_attn_implementation", "eager")

    if kwargs.get("use_fused_rms_norm", False):
        pattern = re.compile(Pattern.rmsnorm)
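A short sketch of the resulting behavior when the FA option is enabled (a toy config object stands in for a transformers `PretrainedConfig`; this mirrors what the unit tests below verify):

```python
import transformers

from openmind.integrations.transformers.npu_fused_ops import attenions, kernel


class ToyConfig:
    _attn_implementation = "eager"


config = ToyConfig()

# FA fused option enabled: the sdpa backend is swapped for the NPU-aware forward
# and the config is switched to "sdpa".
kernel._builtin_patch_flash_attention(config)
assert config._attn_implementation == "sdpa"
assert (
    transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"]
    is attenions.sdpa_attention.sdpa_attention_forward
)

# FA fused option disabled: _apply_fused_kernel_base forces "_attn_implementation" back to
# "eager" (see the else branch above), so the transformers eager attention is used.
```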

View File

@ -19,6 +19,7 @@ from types import ModuleType
from unittest.mock import patch, MagicMock

import transformers
from transformers.integrations import sdpa_attention

from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils import (
    DYNAMIC_MODELS,
@ -30,6 +31,7 @@ from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils impor
    patch_dynamic_fused_ops,
)
from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
from openmind.integrations.transformers.npu_fused_ops import attenions


class TestDynamicModelsRegistration(unittest.TestCase):
@ -50,10 +52,12 @@ class TestDynamicModuleLoading(unittest.TestCase):
"openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils.HF_MODULES_CACHE", self.mock_cache "openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils.HF_MODULES_CACHE", self.mock_cache
) )
self.patcher.start() self.patcher.start()
self.original_sdpa_attention = sdpa_attention.sdpa_attention_forward
def tearDown(self): def tearDown(self):
self.temp_dir.cleanup() self.temp_dir.cleanup()
self.patcher.stop() self.patcher.stop()
sdpa_attention.sdpa_attention_forward = self.original_sdpa_attention
def test_raw_get_dynamic_module(self): def test_raw_get_dynamic_module(self):
module_path = self.mock_cache / "test_module.py" module_path = self.mock_cache / "test_module.py"
@ -101,6 +105,10 @@ class TestDynamicPatching(unittest.TestCase):
        _dynamic_patch_flash_attention(model_name, MagicMock(), config=mock_config)
        self.assertEqual(mock_config._attn_implementation, "sdpa")
        self.assertEqual(
            transformers.integrations.sdpa_attention.sdpa_attention_forward,
            attenions.sdpa_attention.sdpa_attention_forward,
        )

    @patch("importlib.util.spec_from_file_location")
    @patch("importlib.util.module_from_spec")

View File

@ -17,9 +17,10 @@ from pathlib import Path
from unittest.mock import patch

import transformers
from transformers.integrations import sdpa_attention

from openmind.integrations.transformers.npu_fused_ops import kernel
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attenions


class TestFusedKernel(unittest.TestCase):
@ -33,6 +34,8 @@ class TestFusedKernel(unittest.TestCase):
        self.original_rope = transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb
        self.original_rmsnorm = transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
        self.original_swiglu = transformers.models.qwen2.modeling_qwen2.Qwen2MLP
        self.original_sdpa_attention = sdpa_attention.sdpa_attention_forward
        self.origin_atten_func = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS

    def tearDown(self):
self.temp_dir.cleanup() self.temp_dir.cleanup()
@ -40,6 +43,14 @@ class TestFusedKernel(unittest.TestCase):
        transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb = self.original_rope
        transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm = self.original_rmsnorm
        transformers.models.qwen2.modeling_qwen2.Qwen2MLP = self.original_swiglu
        sdpa_attention.sdpa_attention_forward = self.original_sdpa_attention
        transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS = self.origin_atten_func

    def test_patch_sdpa_forward(self):
        kernel._patch_sdpa_forward()
        self.assertEqual(
            transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"], sdpa_attention.sdpa_attention_forward
        )

    def test_builtin_patch_flash_attention(self):
        class Config:
@ -48,6 +59,10 @@ class TestFusedKernel(unittest.TestCase):
        mock_config = Config()
        kernel._builtin_patch_flash_attention(mock_config)
        self.assertEqual(mock_config._attn_implementation, "sdpa")
        self.assertEqual(
            transformers.integrations.sdpa_attention.sdpa_attention_forward,
            attenions.sdpa_attention.sdpa_attention_forward,
        )

    def test_builtin_patch_rmsnorm(self):
        kernel._builtin_patch_rmsnorm(transformers.models.qwen2.modeling_qwen2, "Qwen2RMSNorm")