!197 Support routing the FA (flash attention) fused operator through the SDPA interface

Merge pull request !197 from 幽若/master-0429
2025-05-14 03:47:47 +00:00
committed by i-robot
parent 6c27012844
commit 8b9abdf814
6 changed files with 117 additions and 27 deletions

View File

@ -119,8 +119,9 @@ The dataset used in the SFT stage is processed from `OpenR1-Math-220k`: [o
2. Update the fine-tuning configuration
- The fine-tuning configuration is `examples/qwen2.5/train_sft_qwen2_5_7b_openr1.yaml`
- If the model is stored locally, change `model_id` to `model_name_or_path` and set its value to the local model path
- If the model is stored locally, change `model_id` to `model_name_or_path` and set its value to the local model path; also add a `template` field to the YAML file, whose value can be set as described [here](../../../docs/zh/basic_tutorial/train/train_params.md#模型数据配置模板)
- The fine-tuned model is saved under `output_dir`.
- To save checkpoints by step, add the parameter `save_strategy: steps` to the YAML file
3. Launch fine-tuning
```shell

View File

@ -318,6 +318,7 @@ def get_model():
use_fused_rms_norm=args.use_fused_rms_norm,
use_fused_rope=args.use_fused_rope,
use_fused_swiglu=args.use_fused_swiglu,
config=config,
)
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # assume built-in models

View File

@ -28,9 +28,11 @@ from pathlib import Path
import hashlib
import functools
import torch
import transformers
from transformers.dynamic_module_utils import get_relative_import_files
from transformers.utils.hub import HF_MODULES_CACHE
from transformers import PretrainedConfig
from openmind.utils import logging
from openmind.integrations.transformers.npu_fused_ops.attenions import internlm2
@ -117,19 +119,26 @@ def _raw_get_dynamic_module(
return module
def _dynamic_patch_flash_attention(model_name: str, module: ModuleType):
def _dynamic_patch_flash_attention(model_name: str, module: ModuleType, **kwargs):
if model_name not in DYNAMIC_MODELS:
return
pattern = re.compile(Pattern.attention)
attention_classes_attr = [attr for attr in dir(module) if pattern.search(attr)]
attention_classes = getattr(module, attention_classes_attr[0])
if DYNAMIC_MODELS[model_name].get("npu_fusion_attention"):
npu_attention_class = type(
"NPUFusionAttention",
(attention_classes["eager"],),
{"forward": DYNAMIC_MODELS[model_name].get("npu_fusion_attention").forward},
)
attention_classes.update({k: npu_attention_class for k in attention_classes})
if torch.__version__ == "2.1.0":
pattern = re.compile(Pattern.attention)
attention_classes_attr = [attr for attr in dir(module) if pattern.search(attr)]
attention_classes = getattr(module, attention_classes_attr[0])
if DYNAMIC_MODELS[model_name].get("npu_fusion_attention"):
npu_attention_class = type(
"NPUFusionAttention",
(attention_classes["eager"],),
{"forward": DYNAMIC_MODELS[model_name].get("npu_fusion_attention").forward},
)
attention_classes.update({k: npu_attention_class for k in attention_classes})
elif torch.__version__ >= "2.6.0":
config = kwargs.get("config")
setattr(config, "_attn_implementation", "sdpa")
else:
config = kwargs.get("config")
setattr(config, "_attn_implementation", "eager")
def _dynamic_patch_rms_norm(model_name: str, module: ModuleType):
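The new `config` keyword can be exercised in isolation. Below is a minimal sketch mirroring the unit test added at the end of this commit; the model key and the bare `Config` class are placeholders, not repository code:

```python
import torch
from unittest.mock import MagicMock, patch

from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils import (
    DYNAMIC_MODELS,
    _dynamic_patch_flash_attention,
)


class Config:
    # stand-in for a transformers PretrainedConfig
    _attn_implementation = "eager"


DYNAMIC_MODELS["some_dynamic_model"] = {}  # placeholder registration
config = Config()

with patch("torch.__version__", "2.6.0"):
    # On torch >= 2.6.0 the patch does not touch the module's attention
    # classes; it only flips the attention implementation on the config.
    _dynamic_patch_flash_attention("some_dynamic_model", MagicMock(), config=config)

assert config._attn_implementation == "sdpa"
DYNAMIC_MODELS.pop("some_dynamic_model", None)  # clean up the placeholder
```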
@ -163,7 +172,7 @@ def _dynamic_patch_swiglu(model_name, module):
setattr(module, swiglu_attr[0], DYNAMIC_MODELS[model_name].get("swiglu"))
def dynamic_operator_decorator(operator: typing.Callable, enable: bool = True):
def dynamic_operator_decorator(operator: typing.Callable, enable: bool = True, **kwargs):
def decorator(_get_dynamic_module):
if not enable:
return _get_dynamic_module
@ -171,7 +180,7 @@ def dynamic_operator_decorator(operator: typing.Callable, enable: bool = True):
@functools.wraps(_get_dynamic_module)
def wrapper(class_name: str, module_path: Union[str, os.PathLike], *, force_reload: bool = False):
module = _get_dynamic_module(class_name, module_path, force_reload=force_reload)
operator(class_name, module)
operator(class_name, module, **kwargs)
return module
return wrapper
@ -184,8 +193,9 @@ def patch_dynamic_fused_ops(
use_fused_rms_norm: bool = True,
use_fused_rope: bool = True,
use_fused_swiglu: bool = True,
config: PretrainedConfig = None,
):
@dynamic_operator_decorator(operator=_dynamic_patch_flash_attention, enable=use_npu_fusion_attention)
@dynamic_operator_decorator(operator=_dynamic_patch_flash_attention, enable=use_npu_fusion_attention, config=config)
@dynamic_operator_decorator(operator=_dynamic_patch_rms_norm, enable=use_fused_rms_norm)
@dynamic_operator_decorator(operator=_dynamic_patch_rope, enable=use_fused_rope)
@dynamic_operator_decorator(operator=_dynamic_patch_swiglu, enable=use_fused_swiglu)
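At the entry point, callers now pass the model config so the decorator chain can hand it to the attention patch. A minimal usage sketch mirroring the new `patch_dynamic_fused_ops` test further below; the model id is a placeholder, and loading the config via `AutoConfig` is only one way to obtain it:

```python
from transformers import AutoConfig

from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils import (
    patch_dynamic_fused_ops,
)

# Placeholder id: any trust_remote_code model registered in DYNAMIC_MODELS applies.
config = AutoConfig.from_pretrained("org/remote-code-model", trust_remote_code=True)

patch_dynamic_fused_ops(
    use_npu_fusion_attention=True,  # class patch on torch 2.1.0, SDPA on torch >= 2.6.0
    use_fused_rms_norm=True,
    use_fused_rope=True,
    use_fused_swiglu=True,
    config=config,                  # new parameter introduced by this commit
)
```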

View File

@ -15,6 +15,7 @@ import re
from types import ModuleType
from typing import Dict, Type
import torch
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.llama import modeling_llama
from transformers.models.mistral import modeling_mistral
@ -39,6 +40,10 @@ def _builtin_patch_flash_attention(RAW_ATTENTION_CLASSES: Dict, NEW_ATTENTION_CL
RAW_ATTENTION_CLASSES.update({k: NEW_ATTENTION_CLASS for k in RAW_ATTENTION_CLASSES})
def __builtin_patch_flash_attention_v2(config):
setattr(config, "_attn_implementation", "sdpa")
def _builtin_patch_rmsnorm(module: ModuleType, class_name: str):
"""
Patch the RMSNorm for transformers built-in models, call this method before the model instantiation is completed,
@ -59,10 +64,20 @@ def _builtin_patch_swiglu(module: ModuleType, class_name: str):
def _apply_fused_kernel_base(module: ModuleType, **kwargs):
if kwargs.get("use_npu_fusion_attention", False):
attention = kwargs.get("attention")
pattern = re.compile(Pattern.attention)
attention_classes_attr = [attr for attr in dir(module) if pattern.search(attr)][0]
_builtin_patch_flash_attention(getattr(module, attention_classes_attr), attention)
if torch.__version__ == "2.1.0":
attention = kwargs.get("attention")
pattern = re.compile(Pattern.attention)
attention_classes_attr = [attr for attr in dir(module) if pattern.search(attr)][0]
_builtin_patch_flash_attention(getattr(module, attention_classes_attr), attention)
elif torch.__version__ >= "2.6.0":
config = kwargs.get("config")
__builtin_patch_flash_attention_v2(config)
else:
pass
else:
# if the FA fused option is not enabled, enforce eager mode.
config = kwargs.get("config")
setattr(config, "_attn_implementation", "eager")
if kwargs.get("use_fused_rms_norm", False):
pattern = re.compile(Pattern.rmsnorm)
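On torch >= 2.6.0 the built-in path no longer swaps attention classes; it records `"sdpa"` on the config so attention is dispatched through transformers' SDPA backend (`torch.nn.functional.scaled_dot_product_attention`), which is how this PR reaches the FA fused operator. A minimal sketch of the equivalent user-facing call; the model id and dtype are placeholders, not part of this commit:

```python
import torch
from transformers import AutoModelForCausalLM

# Equivalent end-user effect of __builtin_patch_flash_attention_v2:
# the model is instantiated with the SDPA attention backend.
model = AutoModelForCausalLM.from_pretrained(
    "org/builtin-model",         # placeholder model id
    attn_implementation="sdpa",  # same as config._attn_implementation = "sdpa"
    torch_dtype=torch.bfloat16,
)
```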
@ -97,6 +112,7 @@ def apply_fused_kernel_internlm2(**kwargs):
use_fused_rms_norm = kwargs.get("use_fused_rms_norm", False)
use_fused_rope = kwargs.get("use_fused_rope", False)
use_fused_swiglu = kwargs.get("use_fused_swiglu", False)
config = kwargs.get("config", None)
if "InternLM2ForCausalLM" not in dynamic_module_utils.DYNAMIC_MODELS:
dynamic_module_utils.register_dynamic_model(
"InternLM2ForCausalLM",
@ -110,6 +126,7 @@ def apply_fused_kernel_internlm2(**kwargs):
use_fused_rms_norm=use_fused_rms_norm,
use_fused_rope=use_fused_rope,
use_fused_swiglu=use_fused_swiglu,
config=config,
)
@ -118,6 +135,7 @@ def apply_fused_kernel_internlm3(**kwargs):
use_fused_rms_norm = kwargs.get("use_fused_rms_norm", False)
use_fused_rope = kwargs.get("use_fused_rope", False)
use_fused_swiglu = kwargs.get("use_fused_swiglu", False)
config = kwargs.get("config", None)
if "InternLM3ForCausalLM" not in dynamic_module_utils.DYNAMIC_MODELS:
dynamic_module_utils.register_dynamic_model(
"InternLM3ForCausalLM",
@ -131,4 +149,5 @@ def apply_fused_kernel_internlm3(**kwargs):
use_fused_rms_norm=use_fused_rms_norm,
use_fused_rope=use_fused_rope,
use_fused_swiglu=use_fused_swiglu,
config=config,
)

View File

@ -61,15 +61,21 @@ def check_use_fused_kernel(inner=False) -> bool:
if args.disable_fused_options:
return False
# installed version of transformers is not compatible for npu fused options
# installed version of transformers and torch is not compatible for npu fused options
try:
version.require_version("transformers<=4.47.1")
version.require_version("transformers>=4.39.2")
if torch.__version__ == "2.1.0":
version.require_version("transformers<=4.47.1")
version.require_version("transformers>=4.45.0")
elif torch.__version__ >= "2.6.0":
version.require_version("transformers>=4.51.1")
else:
return False
except ImportError:
logger.warning_rank0(
f"RuntimeWarning: The npu fused options is not available under the transformers "
f"v{transformers.__version__}. To use npu fused options, the version of transformers "
f"is required at least v4.39.2 and no more than v4.47.1."
f"RuntimeWarning: The npu fused options is not available under the transformers v{transformers.__version__} "
f"and the torch v{torch.__version__}. To use npu fused options, if torch version >= 2.6.0, the version of "
f"transformers is required at least v4.51.1; if torch version == 2.1.0, the version of transformers is "
f"required >= v4.45.0, and <= 4.47.1; In other cases, the npu fused options will not be available. "
)
return False
# check pass
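The widened compatibility gate can be summarised in one place. Below is a minimal sketch of the same check, not repository code; transformers' `require_version` helper is used as a stand-in for the project's own version utility:

```python
# Compatibility matrix enforced by check_use_fused_kernel after this change:
#   torch == 2.1.0  -> 4.45.0 <= transformers <= 4.47.1  (attention-class patching path)
#   torch >= 2.6.0  -> transformers >= 4.51.1            (SDPA path)
#   anything else   -> NPU fused options disabled
import torch
from transformers.utils.versions import require_version


def fused_options_supported() -> bool:
    try:
        if torch.__version__ == "2.1.0":
            require_version("transformers>=4.45.0")
            require_version("transformers<=4.47.1")
        elif torch.__version__ >= "2.6.0":
            require_version("transformers>=4.51.1")
        else:
            return False
    except ImportError:
        # raised by require_version when the installed version does not match
        return False
    return True
```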
@ -93,6 +99,7 @@ def _parse_params(**kwargs):
def _apply_log(model_type: str = None, **kwargs):
model = model_type if model_type else "supported"
kwargs.pop("config", None)
kwargs = _parse_params(**kwargs)
if kwargs.get("use_npu_fusion_attention", False):
logger.info_rank0(f"The {model} model will load with npu fused attention.")
@ -105,8 +112,9 @@ def _apply_log(model_type: str = None, **kwargs):
def _apply_fused_kernel_generic(apply_func: typing.Callable, **kwargs):
config = kwargs.pop("config", None)
params = _parse_params(**kwargs)
apply_func(**params)
apply_func(config=config, **params)
def apply_fused_kernel(**kwargs):

View File

@ -18,6 +18,8 @@ from pathlib import Path
from types import ModuleType
from unittest.mock import patch, MagicMock
import transformers
from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils import (
DYNAMIC_MODELS,
_raw_get_dynamic_module,
@ -25,6 +27,7 @@ from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils impor
_dynamic_patch_rms_norm,
_dynamic_patch_rope,
_dynamic_patch_swiglu,
patch_dynamic_fused_ops,
)
from openmind.integrations.transformers.npu_fused_ops.attenions import internlm2
from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
@ -70,18 +73,48 @@ class TestDynamicModuleLoading(unittest.TestCase):
class TestDynamicPatching(unittest.TestCase):
@patch("importlib.util.spec_from_file_location")
@patch("importlib.util.module_from_spec")
@patch("torch.__version__", "2.1.0")
def test_attention_patching(self, _, __):
class MockAttentionBase:
def forward(self):
pass
class Config:
_attn_implementation = "eager"
mock_module = ModuleType("mock_module")
mock_module.ATTENTION_CLASSES = {"eager": MockAttentionBase}
_dynamic_patch_flash_attention("InternLM2ForCausalLM", mock_module)
_dynamic_patch_flash_attention("InternLM2ForCausalLM", mock_module, config=Config())
self.assertIsInstance(mock_module.ATTENTION_CLASSES["eager"].forward, internlm2.forward.__class__)
@patch("torch.__version__", "2.6.0")
def test_torch_260_sets_sdpa(self):
model_name = "test_model_260"
DYNAMIC_MODELS[model_name] = {}
class Config:
_attn_implementation = "eager"
mock_config = Config()
_dynamic_patch_flash_attention(model_name, MagicMock(), config=mock_config)
self.assertEqual(mock_config._attn_implementation, "sdpa")
@patch("torch.__version__", "2.5.0")
def test_torch_other_sets_sdpa(self):
model_name = "test_model_260"
DYNAMIC_MODELS[model_name] = {}
class Config:
_attn_implementation = "sdpa"
mock_config = Config()
_dynamic_patch_flash_attention(model_name, MagicMock(), config=mock_config)
self.assertEqual(mock_config._attn_implementation, "eager")
@patch("importlib.util.spec_from_file_location")
@patch("importlib.util.module_from_spec")
def test_rms_norm_patching(self, _, __):
@ -133,3 +166,21 @@ class TestDynamicPatching(unittest.TestCase):
DYNAMIC_MODELS["InternLM2ForCausalLM"]["npu_fusion_attention"] = original_attention
self.assertEqual(mock_module.apply_rotary_pos_emb, DYNAMIC_MODELS["InternLM2ForCausalLM"]["rope"])
@patch("importlib.util.spec_from_file_location")
@patch("importlib.util.module_from_spec")
@patch("torch.__version__", "2.6.0")
def test_patch_dynamic_fused_ops(self, _, __):
class Config:
_attn_implementation = "eager"
config = Config()
raw_get_class_in_module = transformers.dynamic_module_utils.get_class_in_module
patch_dynamic_fused_ops(
use_npu_fusion_attention=True,
use_fused_rms_norm=True,
use_fused_rope=True,
use_fused_swiglu=True,
config=config,
)
self.assertNotEqual(transformers.dynamic_module_utils.get_class_in_module, raw_get_class_in_module)