mirror of https://github.com/frozenleaves/LLaMA-Factory.git
synced 2025-10-20 16:23:46 +08:00
dev npu fused options
@@ -1,8 +1,31 @@
from types import ModuleType
import dataclasses
import os
import re
import sys
import threading
import typing
import importlib
from typing import Optional, Union, List
from types import ModuleType
from pathlib import Path
import hashlib
import functools

import transformers
from transformers.dynamic_module_utils import get_relative_import_files
from transformers.utils.hub import HF_MODULES_CACHE
from transformers import PretrainedConfig
from transformers.utils import is_torch_npu_available

from . import sdpa_attention as npu_sdpa_attention
from . import rms_norm, rope, swiglu

from ...extras import logging

logger = logging.get_logger()

_HF_REMOTE_CODE_LOCK = threading.Lock()

def _patch_sdpa_forward():
    """
@@ -18,7 +41,6 @@ def _patch_sdpa_forward():


def _patch_rmsnorm(module: ModuleType, class_name: str):
    setattr(module, class_name, rms_norm.NpuRMSNorm)


@@ -31,18 +53,109 @@ def _patch_swiglu(module: ModuleType, class_name: str):


def apply_fused_options(disable: bool = False):
-    if disable:
+    if disable or not is_torch_npu_available():
        logger.warning_rank0("NPU fused options are disabled, or the torch NPU backend is not available.")
        return
    from transformers.models.qwen2 import modeling_qwen2
    from transformers.models.qwen2_moe import modeling_qwen2_moe
    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
    from transformers.models.qwen3 import modeling_qwen3
    from transformers.models.qwen3_moe import modeling_qwen3_moe

    _patch_sdpa_forward()

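    # Each Qwen family below gets the same three drop-in replacements: its RMSNorm class is
    # swapped for rms_norm.NpuRMSNorm, its module-level apply_rotary_pos_emb helper for the
    # NPU implementation in rope, and its MLP class for swiglu.NpuSwiGlu.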
    _patch_rmsnorm(modeling_qwen2, "Qwen2RMSNorm")
    _patch_rope(modeling_qwen2, "apply_rotary_pos_emb")
    _patch_swiglu(modeling_qwen2, "Qwen2MLP")

    _patch_rmsnorm(modeling_qwen2_moe, "Qwen2MoeRMSNorm")
    _patch_rope(modeling_qwen2_moe, "apply_rotary_pos_emb")
    _patch_swiglu(modeling_qwen2_moe, "Qwen2MoeMLP")

    _patch_rmsnorm(modeling_qwen3, "Qwen3RMSNorm")
    _patch_rope(modeling_qwen3, "apply_rotary_pos_emb")
    _patch_swiglu(modeling_qwen3, "Qwen3MLP")

    _patch_rmsnorm(modeling_qwen3_moe, "Qwen3MoeRMSNorm")
    _patch_rope(modeling_qwen3_moe, "apply_rotary_pos_emb")
    _patch_swiglu(modeling_qwen3_moe, "Qwen3MoeMLP")


def _original_get_dynamic_module(
    class_name: str,
    module_path: Union[str, os.PathLike],
    *,
    force_reload: bool = False,
):
    """
    Get dynamic module from py file, copied from transformers.dynamic_module_utils.get_class_in_module.
    """

    name = os.path.normpath(module_path)
    if name.endswith(".py"):
        name = name[:-3]
    name = name.replace(os.path.sep, ".")
    module_file: Path = Path(HF_MODULES_CACHE) / module_path
    with _HF_REMOTE_CODE_LOCK:
        if force_reload:
            sys.modules.pop(name, None)
            importlib.invalidate_caches()
        cached_module: Optional[ModuleType] = sys.modules.get(name)
        module_spec = importlib.util.spec_from_file_location(name, location=module_file)

        # Hash the module file and all its relative imports to check if we need to reload it
        module_files: List[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
        module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()

        module: ModuleType
        if cached_module is None:
            module = importlib.util.module_from_spec(module_spec)
            # insert it into sys.modules before any loading begins
            sys.modules[name] = module
        else:
            module = cached_module
        if getattr(module, "__transformers_module_hash__", "") != module_hash:
            module_spec.loader.exec_module(module)
            module.__transformers_module_hash__ = module_hash
        return module


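# The _dynamic_patch_* helpers below mirror the static _patch_* functions above, but take the
# target attribute name as a string together with the dynamically loaded remote-code module;
# the extra **kwargs presumably keep a uniform signature so the helpers can be dispatched
# generically for models loaded with trust_remote_code=True.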
def _dynamic_patch_flash_attention(sdpa_attention_cls: str, module: ModuleType, **kwargs):
    _patch_sdpa_forward()
    # `sdpa_attention_cls` is the attention class *name* inside `module`; resolve it before
    # overriding its forward with the NPU SDPA implementation.
    setattr(getattr(module, sdpa_attention_cls), "forward", npu_sdpa_attention.sdpa_attention_forward)


def _dynamic_patch_rmsnorm(rmsnorm_cls: str, module: ModuleType, **kwargs):
    setattr(module, rmsnorm_cls, rms_norm.NpuRMSNorm)


def _dynamic_patch_rope(rope_cls: str, module: ModuleType, **kwargs):
    setattr(module, rope_cls, rope.apply_rotary_pos_emb)


def _dynamic_patch_swiglu(swiglu_cls: str, module: ModuleType, **kwargs):
    setattr(module, swiglu_cls, swiglu.NpuSwiGlu)


def patch_dynamic_fused_ops():

    def _get_dynamic_module(
        class_name: str,
        module_path: Union[str, os.PathLike],
        *,
        force_reload: bool = False,
    ):
        module = _original_get_dynamic_module(class_name, module_path, force_reload=force_reload)
        if module.__name__ == "modeling_internlm3":
            # Development hook: presumably where the _dynamic_patch_* helpers are meant to be
            # applied to the freshly loaded remote-code module.
            breakpoint()
        return module

    def _get_class_in_module(
        class_name: str,
        module_path: Union[str, os.PathLike],
        *,
        force_reload: bool = False,
    ) -> typing.Type:
        module = _get_dynamic_module(class_name=class_name, module_path=module_path, force_reload=force_reload)
        return getattr(module, class_name)

    transformers.dynamic_module_utils.get_class_in_module = _get_class_in_module
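# A minimal usage sketch of how the two entry points are presumably wired together; the call
# site and the import path are not part of this diff, so the names below are illustrative only:
#
#     from transformers import AutoModelForCausalLM
#     from .npu_fused_ops import apply_fused_options, patch_dynamic_fused_ops  # hypothetical path
#
#     apply_fused_options()          # patch the built-in Qwen2 / Qwen2-MoE / Qwen3 / Qwen3-MoE modules
#     patch_dynamic_fused_ops()      # hook get_class_in_module so trust_remote_code models can be patched too
#     model = AutoModelForCausalLM.from_pretrained("some/model", trust_remote_code=True)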
@@ -19,10 +19,6 @@ from transformers.utils import is_torch_npu_available
if is_torch_npu_available():
    import torch_npu

from ...extras import logging

logger = logging.get_logger()


class NpuRMSNorm(nn.Module):
    """
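# A minimal sketch of a fused RMSNorm forward built on torch_npu.npu_rms_norm, assuming the
# usual (hidden_size, eps) constructor; this is illustrative and not necessarily the body of
# the (truncated) NpuRMSNorm class above:
#
#     class NpuRMSNorm(nn.Module):
#         def __init__(self, hidden_size, eps: float = 1e-6):
#             super().__init__()
#             self.weight = nn.Parameter(torch.ones(hidden_size))
#             self.variance_epsilon = eps
#
#         def forward(self, hidden_states):
#             # npu_rms_norm returns (normalized_output, rstd); only the output is needed
#             return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]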
@@ -9,6 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple

import torch
@@ -16,6 +17,9 @@ import torch
from transformers.utils import is_torch_npu_available
from transformers.integrations.sdpa_attention import repeat_kv

if is_torch_npu_available():
    import torch_npu


def sdpa_attention_forward(
    module: torch.nn.Module,
@@ -33,7 +37,7 @@ def sdpa_attention_forward(
        value = repeat_kv(value, module.num_key_value_groups)

    causal_mask = attention_mask
-    if not is_torch_npu_available() and attention_mask is not None and causal_mask.ndim == 4:
+    if attention_mask is not None and causal_mask.ndim == 4:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
@@ -53,9 +57,36 @@ def sdpa_attention_forward(
        is_causal = is_causal.item()

    if is_torch_npu_available():
        is_causal = True
        causal_mask = None

        if is_causal:
            atten_mask_npu = torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(query.device) if causal_mask is not None else None
            sparse_mode = 2 if causal_mask is not None else 0
            head_num = query.shape[1]
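            # Per Ascend's npu_fusion_attention semantics: input_layout="BNSD" is
            # (batch, num_heads, seq_len, head_dim), pse=None disables any positional bias,
            # sparse_mode=2 interprets the shared 2048x2048 upper-triangular atten_mask as a
            # causal band, scale is the usual 1/sqrt(head_dim), the huge pre_tockens/next_tockens
            # values disable any sliding-window limit, keep_prob=1 means no attention dropout,
            # and the first element of the returned tuple is the attention output.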
            attn_output = torch_npu.npu_fusion_attention(
                query, key, value, head_num, input_layout="BNSD",
                pse=None,
                atten_mask=atten_mask_npu,
                sparse_mode=sparse_mode,
                scale=1.0 / math.sqrt(query.shape[-1]),
                pre_tockens=2147483647,
                next_tockens=2147483647,
                keep_prob=1,
            )[0]
        else:
            if causal_mask.dtype == torch.bool:
                atten_mask_npu = torch.logical_not(causal_mask.bool()).to(query.device)
            else:
                atten_mask_npu = causal_mask.bool().to(query.device)
            head_num = query.shape[1]
            attn_output = torch_npu.npu_fusion_attention(
                query, key, value, head_num, input_layout="BNSD",
                pse=None,
                atten_mask=atten_mask_npu,
                scale=1.0 / math.sqrt(query.shape[-1]),
                pre_tockens=2147483647,
                next_tockens=2147483647,
                keep_prob=1
            )[0]
    else:
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,