@@ -27,6 +27,8 @@ $$

Users can find the detailed documentation for this fused operator [here](https://www.hiascend.com/document/detail/zh/Pytorch/60RC2/apiref/apilist/ptaoplist_000142.html); in fixed-shape scenarios it can deliver a substantial performance improvement.

For the FA (flash attention) fused operator, openmind currently routes all calls through torch's native sdpa interface. For adapted models, enabling the feature makes sdpa dispatch to the FA fused operator, while leaving it disabled falls back to the eager-mode implementation in transformers. For models that have not been adapted, the default behavior applies: the current version defaults to the sdpa interface, but its backend is a composition of small operators, so performance is not guaranteed.
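Both code paths go through the same torch-native entry point; as a minimal sketch (toy tensors only, not the openmind API), the call that ends up on either the fused kernel or the small-operator backend looks like this:

```python
import torch
import torch.nn.functional as F

# Toy tensors shaped (batch, num_heads, seq_len, head_dim).
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Both the fused-FA path and the small-operator fallback are reached through this
# single interface; only the backend selected underneath differs.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```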
## RMSNorm

The RMSNorm operator is a normalization operation commonly used in large models. Compared with LayerNorm, it drops the mean-subtraction step. Its formula is:
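For reference, a standard form of the RMSNorm computation (a sketch; the exact notation used elsewhere in this document may differ) is:

$$
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^{2} + \epsilon}} \, \gamma_i
$$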
@@ -135,3 +137,5 @@ print(output)
#
# 3. Get enough sleep: Sleep is essential for good health. Aim for 7-9 hours of sleep each night. Establish a regular sleep schedule and create a relaxing bedtime routine to help you fall asleep more easily. Avoid using electronic devices before bed, as the blue light emitted by screens can interfere with your sleep.
```
Note: because transformers defaults to sdpa, the sdpa interface is called regardless of whether `apply_fused_kernel` is enabled externally. Once enabled, however, openmind adapts transformers' sdpa attention so that the sdpa backend dispatches to the NPU FA fused operator; without the adaptation it falls back to a composition of small operators.
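One way to observe the effect of the adaptation is to inspect the attention registry after the patch runs. The sketch below uses the internal, underscore-prefixed helpers introduced by this change; the public entry point may differ:

```python
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from openmind.integrations.transformers.npu_fused_ops import kernel, attenions

# Apply the sdpa patch; afterwards the "sdpa" entry resolves to the NPU-aware forward.
kernel._patch_sdpa_forward()
assert ALL_ATTENTION_FUNCTIONS["sdpa"] is attenions.sdpa_attention.sdpa_attention_forward
```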
@@ -193,7 +193,7 @@ generator = pipeline(task="text-to-image",
image = generator(prompt="masterpiece, best quality, Cute dragon creature, pokemon style, night, moonlight, dim lighting",)
```

The corresponding versions of silicondiff_npu and PyTorch are listed below; currently silicondiff_npu only supports PyTorch 2.1.0:
The corresponding versions of silicondiff_npu and PyTorch are listed below; currently silicondiff_npu only supports PyTorch 2.1.0 and Python 3.10:

| PyTorch version | silicondiff_npu version |
|-----------------|-------------------------|
@@ -11,4 +11,4 @@
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

from . import qwen2, llama, mistral, internlm2, internlm3
from . import sdpa_attention
@@ -0,0 +1,74 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
#
# Copyright 2018- The Hugging Face team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import torch
from transformers.integrations.sdpa_attention import repeat_kv

from openmind.utils.import_utils import is_torch_npu_available


def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    if hasattr(module, "num_key_value_groups"):
        key = repeat_kv(key, module.num_key_value_groups)
        value = repeat_kv(value, module.num_key_value_groups)

    causal_mask = attention_mask
    if not is_torch_npu_available() and attention_mask is not None and causal_mask.ndim == 4:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
    if is_causal is None:
        is_causal = query.shape[2] > 1 and causal_mask is None

    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
    # We convert it to a bool for the SDPA kernel that only accepts bools.
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()
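    # On NPU, rely on sdpa's built-in causal masking instead of an explicit mask so that the call
    # can be dispatched to the NPU fused flash-attention backend (see the FA note in the README).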
    if is_torch_npu_available():
        is_causal = True
        causal_mask = None

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None
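# Shape contract for reference: query/key/value come in as (batch, heads, seq_len, head_dim);
# the returned attn_output is transposed to (batch, seq_len, heads, head_dim), and the second
# element of the tuple (attention weights) is always None for the sdpa path.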
@@ -37,6 +37,7 @@ from openmind.utils import logging
from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
from openmind.integrations.transformers.npu_fused_ops.rope import rope
from openmind.integrations.transformers.npu_fused_ops.swiglu import swiglu
from openmind.integrations.transformers.npu_fused_ops import kernel

logger = logging.get_logger()
@@ -119,7 +120,9 @@ def _dynamic_patch_flash_attention(model_name: str, module: ModuleType, **kwargs
    if model_name not in DYNAMIC_MODELS:
        return
    else:
        config = kwargs.get("config")
        config = kwargs.get("config", None)
        kernel._patch_sdpa_forward()
        if config is not None:
            setattr(config, "_attn_implementation", "sdpa")
@@ -17,8 +17,9 @@ from types import ModuleType
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.llama import modeling_llama
from transformers.models.mistral import modeling_mistral

from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface
from transformers.integrations import sdpa_attention
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attenions
from openmind.integrations.transformers.npu_fused_ops import dynamic_module_utils
@@ -29,11 +30,23 @@ class Pattern:
    swiglu: str = "MLP"


def _builtin_patch_flash_attention(config):
def _patch_sdpa_forward():
    """
    The purpose of this patch is to adapt the native sdpa forward function of transformers to the
    sdpa interface of the NPU. Without it, calls through the sdpa interface still run in eager mode.
    """
    sdpa_attention.sdpa_attention_forward = attenions.sdpa_attention.sdpa_attention_forward
    AttentionInterface._global_mapping["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward
    ALL_ATTENTION_FUNCTIONS["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward


def _builtin_patch_flash_attention(config=None):
    """
    Patch FA for transformers built-in models. Call this method before model instantiation completes;
    once the model has already been instantiated, this method has no effect.
    """
    _patch_sdpa_forward()
    if config is not None:
        setattr(config, "_attn_implementation", "sdpa")
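# Usage sketch (hypothetical caller, not part of the library API): the patch must run before the
# model is built, for example
#     config = AutoConfig.from_pretrained("org/model")   # placeholder checkpoint id
#     _builtin_patch_flash_attention(config)             # sdpa now maps to the NPU-aware forward
#     model = AutoModelForCausalLM.from_config(config)   # instantiation picks up the patched sdpa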
@@ -56,12 +69,15 @@ def _builtin_patch_swiglu(module: ModuleType, class_name: str):


def _apply_fused_kernel_base(module: ModuleType, **kwargs):
    config = kwargs.get("config", None)
    if kwargs.get("use_npu_fusion_attention", False):
        config = kwargs.get("config")
        _builtin_patch_flash_attention(config)
    else:
        # if the FA fused option is not open, enforce eager mode.
        config = kwargs.get("config")
        # If the FA fused option is not enabled, enforce eager mode.
        # Note: if the config is None, `_attn_implementation` defaults to `sdpa`, but because of the NPU sdpa
        # implementation it still runs as eager mode; moreover, the performance in that case is worse than the
        # transformers native eager-mode implementation.
        if config is not None:
            setattr(config, "_attn_implementation", "eager")

    if kwargs.get("use_fused_rms_norm", False):
@@ -19,6 +19,7 @@ from types import ModuleType
from unittest.mock import patch, MagicMock

import transformers
from transformers.integrations import sdpa_attention

from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils import (
    DYNAMIC_MODELS,
@@ -30,6 +31,7 @@ from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils impor
    patch_dynamic_fused_ops,
)
from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
from openmind.integrations.transformers.npu_fused_ops import attenions


class TestDynamicModelsRegistration(unittest.TestCase):
@@ -50,10 +52,12 @@ class TestDynamicModuleLoading(unittest.TestCase):
            "openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils.HF_MODULES_CACHE", self.mock_cache
        )
        self.patcher.start()
        self.original_sdpa_attention = sdpa_attention.sdpa_attention_forward

    def tearDown(self):
        self.temp_dir.cleanup()
        self.patcher.stop()
        sdpa_attention.sdpa_attention_forward = self.original_sdpa_attention

    def test_raw_get_dynamic_module(self):
        module_path = self.mock_cache / "test_module.py"
@@ -101,6 +105,10 @@ class TestDynamicPatching(unittest.TestCase):
        _dynamic_patch_flash_attention(model_name, MagicMock(), config=mock_config)

        self.assertEqual(mock_config._attn_implementation, "sdpa")
        self.assertEqual(
            transformers.integrations.sdpa_attention.sdpa_attention_forward,
            attenions.sdpa_attention.sdpa_attention_forward,
        )

    @patch("importlib.util.spec_from_file_location")
    @patch("importlib.util.module_from_spec")
@@ -17,9 +17,10 @@ from pathlib import Path
from unittest.mock import patch

import transformers
from transformers.integrations import sdpa_attention

from openmind.integrations.transformers.npu_fused_ops import kernel
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu
from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attenions


class TestFusedKernel(unittest.TestCase):
@@ -33,6 +34,8 @@ class TestFusedKernel(unittest.TestCase):
        self.original_rope = transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb
        self.original_rmsnorm = transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
        self.original_swiglu = transformers.models.qwen2.modeling_qwen2.Qwen2MLP
        self.original_sdpa_attention = sdpa_attention.sdpa_attention_forward
        self.origin_atten_func = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS

    def tearDown(self):
        self.temp_dir.cleanup()
@@ -40,6 +43,14 @@ class TestFusedKernel(unittest.TestCase):
        transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb = self.original_rope
        transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm = self.original_rmsnorm
        transformers.models.qwen2.modeling_qwen2.Qwen2MLP = self.original_swiglu
        sdpa_attention.sdpa_attention_forward = self.original_sdpa_attention
        transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS = self.origin_atten_func

    def test_patch_sdpa_forward(self):
        kernel._patch_sdpa_forward()
        self.assertEqual(
            transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"], sdpa_attention.sdpa_attention_forward
        )

    def test_builtin_patch_flash_attention(self):
        class Config:
@@ -48,6 +59,10 @@ class TestFusedKernel(unittest.TestCase):
        mock_config = Config()
        kernel._builtin_patch_flash_attention(mock_config)
        self.assertEqual(mock_config._attn_implementation, "sdpa")
        self.assertEqual(
            transformers.integrations.sdpa_attention.sdpa_attention_forward,
            attenions.sdpa_attention.sdpa_attention_forward,
        )

    def test_builtin_patch_rmsnorm(self):
        kernel._builtin_patch_rmsnorm(transformers.models.qwen2.modeling_qwen2, "Qwen2RMSNorm")