@@ -119,8 +119,9 @@ The dataset used in the SFT stage is processed from `OpenR1-Math-220k`: [o
2. Update the fine-tuning configuration

- The fine-tuning configuration file is `examples/qwen2.5/train_sft_qwen2_5_7b_openr1.yaml`.
- If the model is stored locally, change `model_id` to `model_name_or_path` and set its value to the local model path.
- If the model is stored locally, change `model_id` to `model_name_or_path` and set its value to the local model path; also add a `template` field to the yaml file, choosing its value as described [here](../../../docs/zh/basic_tutorial/train/train_params.md#模型数据配置模板).
- The fine-tuned model is saved under `output_dir`.
- To save checkpoints by step, add the parameter `save_strategy: steps` to the yaml file (a minimal yaml sketch follows this list).
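Putting these notes together, a minimal sketch of the relevant fields might look as follows; the local model path, the `qwen` template value, and the output directory are illustrative placeholders, not values taken from the shipped yaml:

```yaml
# Hypothetical excerpt — only the fields discussed above, with placeholder values.
model_name_or_path: /path/to/Qwen2.5-7B    # replaces model_id when the model is stored locally
template: qwen                             # add together with model_name_or_path; see train_params.md
output_dir: ./saves/qwen2_5_7b_openr1_sft  # the fine-tuned model is saved here
save_strategy: steps                       # optional: save checkpoints by step
```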

3. Launch fine-tuning
```shell
@@ -318,6 +318,7 @@ def get_model():
use_fused_rms_norm=args.use_fused_rms_norm,
use_fused_rope=args.use_fused_rope,
use_fused_swiglu=args.use_fused_swiglu,
config=config,
)

if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # assume built-in models
@@ -28,9 +28,11 @@ from pathlib import Path
import hashlib
import functools

import torch
import transformers
from transformers.dynamic_module_utils import get_relative_import_files
from transformers.utils.hub import HF_MODULES_CACHE
from transformers import PretrainedConfig

from openmind.utils import logging
from openmind.integrations.transformers.npu_fused_ops.attenions import internlm2
@@ -117,19 +119,26 @@ def _raw_get_dynamic_module(
return module


def _dynamic_patch_flash_attention(model_name: str, module: ModuleType):
def _dynamic_patch_flash_attention(model_name: str, module: ModuleType, **kwargs):
if model_name not in DYNAMIC_MODELS:
return
pattern = re.compile(Pattern.attention)
attention_classes_attr = [attr for attr in dir(module) if pattern.search(attr)]
attention_classes = getattr(module, attention_classes_attr[0])
if DYNAMIC_MODELS[model_name].get("npu_fusion_attention"):
npu_attention_class = type(
"NPUFusionAttention",
(attention_classes["eager"],),
{"forward": DYNAMIC_MODELS[model_name].get("npu_fusion_attention").forward},
)
attention_classes.update({k: npu_attention_class for k in attention_classes})
if torch.__version__ == "2.1.0":
pattern = re.compile(Pattern.attention)
attention_classes_attr = [attr for attr in dir(module) if pattern.search(attr)]
attention_classes = getattr(module, attention_classes_attr[0])
if DYNAMIC_MODELS[model_name].get("npu_fusion_attention"):
npu_attention_class = type(
"NPUFusionAttention",
(attention_classes["eager"],),
{"forward": DYNAMIC_MODELS[model_name].get("npu_fusion_attention").forward},
)
attention_classes.update({k: npu_attention_class for k in attention_classes})
elif torch.__version__ >= "2.6.0":
config = kwargs.get("config")
setattr(config, "_attn_implementation", "sdpa")
else:
config = kwargs.get("config")
setattr(config, "_attn_implementation", "eager")


def _dynamic_patch_rms_norm(model_name: str, module: ModuleType):
@@ -163,7 +172,7 @@ def _dynamic_patch_swiglu(model_name, module):
setattr(module, swiglu_attr[0], DYNAMIC_MODELS[model_name].get("swiglu"))


def dynamic_operator_decorator(operator: typing.Callable, enable: bool = True):
def dynamic_operator_decorator(operator: typing.Callable, enable: bool = True, **kwargs):
def decorator(_get_dynamic_module):
if not enable:
return _get_dynamic_module
@@ -171,7 +180,7 @@ def dynamic_operator_decorator(operator: typing.Callable, enable: bool = True):
@functools.wraps(_get_dynamic_module)
def wrapper(class_name: str, module_path: Union[str, os.PathLike], *, force_reload: bool = False):
module = _get_dynamic_module(class_name, module_path, force_reload=force_reload)
operator(class_name, module)
operator(class_name, module, **kwargs)
return module

return wrapper
@@ -184,8 +193,9 @@ def patch_dynamic_fused_ops(
use_fused_rms_norm: bool = True,
use_fused_rope: bool = True,
use_fused_swiglu: bool = True,
config: PretrainedConfig = None,
):
@dynamic_operator_decorator(operator=_dynamic_patch_flash_attention, enable=use_npu_fusion_attention)
@dynamic_operator_decorator(operator=_dynamic_patch_flash_attention, enable=use_npu_fusion_attention, config=config)
@dynamic_operator_decorator(operator=_dynamic_patch_rms_norm, enable=use_fused_rms_norm)
@dynamic_operator_decorator(operator=_dynamic_patch_rope, enable=use_fused_rope)
@dynamic_operator_decorator(operator=_dynamic_patch_swiglu, enable=use_fused_swiglu)
@@ -15,6 +15,7 @@ import re
from types import ModuleType
from typing import Dict, Type

import torch
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.llama import modeling_llama
from transformers.models.mistral import modeling_mistral
@@ -39,6 +40,10 @@ def _builtin_patch_flash_attention(RAW_ATTENTION_CLASSES: Dict, NEW_ATTENTION_CL
RAW_ATTENTION_CLASSES.update({k: NEW_ATTENTION_CLASS for k in RAW_ATTENTION_CLASSES})


def __builtin_patch_flash_attention_v2(config):
setattr(config, "_attn_implementation", "sdpa")


def _builtin_patch_rmsnorm(module: ModuleType, class_name: str):
"""
Patch the RMSNorm for transformers built-in models, call this method before the model instantiation is completed,
@@ -59,10 +64,20 @@ def _builtin_patch_swiglu(module: ModuleType, class_name: str):

def _apply_fused_kernel_base(module: ModuleType, **kwargs):
if kwargs.get("use_npu_fusion_attention", False):
attention = kwargs.get("attention")
pattern = re.compile(Pattern.attention)
attention_classes_attr = [attr for attr in dir(module) if pattern.search(attr)][0]
_builtin_patch_flash_attention(getattr(module, attention_classes_attr), attention)
if torch.__version__ == "2.1.0":
attention = kwargs.get("attention")
pattern = re.compile(Pattern.attention)
attention_classes_attr = [attr for attr in dir(module) if pattern.search(attr)][0]
_builtin_patch_flash_attention(getattr(module, attention_classes_attr), attention)
elif torch.__version__ >= "2.6.0":
config = kwargs.get("config")
__builtin_patch_flash_attention_v2(config)
else:
pass
else:
# if the FA fused option is not enabled, enforce eager mode.
config = kwargs.get("config")
setattr(config, "_attn_implementation", "eager")

if kwargs.get("use_fused_rms_norm", False):
pattern = re.compile(Pattern.rmsnorm)
@@ -97,6 +112,7 @@ def apply_fused_kernel_internlm2(**kwargs):
use_fused_rms_norm = kwargs.get("use_fused_rms_norm", False)
use_fused_rope = kwargs.get("use_fused_rope", False)
use_fused_swiglu = kwargs.get("use_fused_swiglu", False)
config = kwargs.get("config", None)
if "InternLM2ForCausalLM" not in dynamic_module_utils.DYNAMIC_MODELS:
dynamic_module_utils.register_dynamic_model(
"InternLM2ForCausalLM",
@@ -110,6 +126,7 @@ def apply_fused_kernel_internlm2(**kwargs):
use_fused_rms_norm=use_fused_rms_norm,
use_fused_rope=use_fused_rope,
use_fused_swiglu=use_fused_swiglu,
config=config,
)
@@ -118,6 +135,7 @@ def apply_fused_kernel_internlm3(**kwargs):
use_fused_rms_norm = kwargs.get("use_fused_rms_norm", False)
use_fused_rope = kwargs.get("use_fused_rope", False)
use_fused_swiglu = kwargs.get("use_fused_swiglu", False)
config = kwargs.get("config", None)
if "InternLM3ForCausalLM" not in dynamic_module_utils.DYNAMIC_MODELS:
dynamic_module_utils.register_dynamic_model(
"InternLM3ForCausalLM",
@@ -131,4 +149,5 @@ def apply_fused_kernel_internlm3(**kwargs):
use_fused_rms_norm=use_fused_rms_norm,
use_fused_rope=use_fused_rope,
use_fused_swiglu=use_fused_swiglu,
config=config,
)
@@ -61,15 +61,21 @@ def check_use_fused_kernel(inner=False) -> bool:
if args.disable_fused_options:
return False

# installed version of transformers is not compatible for npu fused options
# installed versions of transformers and torch are not compatible with npu fused options
try:
version.require_version("transformers<=4.47.1")
version.require_version("transformers>=4.39.2")
if torch.__version__ == "2.1.0":
version.require_version("transformers<=4.47.1")
version.require_version("transformers>=4.45.0")
elif torch.__version__ >= "2.6.0":
version.require_version("transformers>=4.51.1")
else:
return False
except ImportError:
logger.warning_rank0(
f"RuntimeWarning: The npu fused options is not available under the transformers "
f"v{transformers.__version__}. To use npu fused options, the version of transformers "
f"is required at least v4.39.2 and no more than v4.47.1."
f"RuntimeWarning: The npu fused options are not available with transformers v{transformers.__version__} "
f"and torch v{torch.__version__}. To use npu fused options, transformers must be at least v4.51.1 when "
f"torch >= 2.6.0, and between v4.45.0 and v4.47.1 when torch == 2.1.0. "
f"In other cases, the npu fused options are not available."
)
return False
# check pass
@@ -93,6 +99,7 @@ def _parse_params(**kwargs):

def _apply_log(model_type: str = None, **kwargs):
model = model_type if model_type else "supported"
kwargs.pop("config", None)
kwargs = _parse_params(**kwargs)
if kwargs.get("use_npu_fusion_attention", False):
logger.info_rank0(f"The {model} model will load with npu fused attention.")
@@ -105,8 +112,9 @@ def _apply_log(model_type: str = None, **kwargs):


def _apply_fused_kernel_generic(apply_func: typing.Callable, **kwargs):
config = kwargs.pop("config", None)
params = _parse_params(**kwargs)
apply_func(**params)
apply_func(config=config, **params)


def apply_fused_kernel(**kwargs):
@@ -18,6 +18,8 @@ from pathlib import Path
from types import ModuleType
from unittest.mock import patch, MagicMock

import transformers

from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils import (
DYNAMIC_MODELS,
_raw_get_dynamic_module,
@@ -25,6 +27,7 @@ from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils impor
_dynamic_patch_rms_norm,
_dynamic_patch_rope,
_dynamic_patch_swiglu,
patch_dynamic_fused_ops,
)
from openmind.integrations.transformers.npu_fused_ops.attenions import internlm2
from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
@@ -70,18 +73,48 @@ class TestDynamicModuleLoading(unittest.TestCase):
class TestDynamicPatching(unittest.TestCase):
@patch("importlib.util.spec_from_file_location")
@patch("importlib.util.module_from_spec")
@patch("torch.__version__", "2.1.0")
def test_attention_patching(self, _, __):

class MockAttentionBase:
def forward(self):
pass

class Config:
_attn_implementation = "eager"

mock_module = ModuleType("mock_module")
mock_module.ATTENTION_CLASSES = {"eager": MockAttentionBase}

_dynamic_patch_flash_attention("InternLM2ForCausalLM", mock_module)
_dynamic_patch_flash_attention("InternLM2ForCausalLM", mock_module, config=Config())
self.assertIsInstance(mock_module.ATTENTION_CLASSES["eager"].forward, internlm2.forward.__class__)

@patch("torch.__version__", "2.6.0")
def test_torch_260_sets_sdpa(self):
model_name = "test_model_260"
DYNAMIC_MODELS[model_name] = {}

class Config:
_attn_implementation = "eager"

mock_config = Config()
_dynamic_patch_flash_attention(model_name, MagicMock(), config=mock_config)

self.assertEqual(mock_config._attn_implementation, "sdpa")
@patch("torch.__version__", "2.5.0")
|
||||
def test_torch_other_sets_sdpa(self):
|
||||
model_name = "test_model_260"
|
||||
DYNAMIC_MODELS[model_name] = {}

class Config:
_attn_implementation = "sdpa"

mock_config = Config()
_dynamic_patch_flash_attention(model_name, MagicMock(), config=mock_config)

self.assertEqual(mock_config._attn_implementation, "eager")

@patch("importlib.util.spec_from_file_location")
@patch("importlib.util.module_from_spec")
def test_rms_norm_patching(self, _, __):
@@ -133,3 +166,21 @@ class TestDynamicPatching(unittest.TestCase):
DYNAMIC_MODELS["InternLM2ForCausalLM"]["npu_fusion_attention"] = original_attention

self.assertEqual(mock_module.apply_rotary_pos_emb, DYNAMIC_MODELS["InternLM2ForCausalLM"]["rope"])

@patch("importlib.util.spec_from_file_location")
@patch("importlib.util.module_from_spec")
@patch("torch.__version__", "2.6.0")
def test_patch_dynamic_fused_ops(self, _, __):
class Config:
_attn_implementation = "eager"

config = Config()
raw_get_class_in_module = transformers.dynamic_module_utils.get_class_in_module
patch_dynamic_fused_ops(
use_npu_fusion_attention=True,
use_fused_rms_norm=True,
use_fused_rope=True,
use_fused_swiglu=True,
config=config,
)
self.assertNotEqual(transformers.dynamic_module_utils.get_class_in_module, raw_get_class_in_module)