dev npu fused options

This commit is contained in:
frozenleaves
2025-09-03 16:22:21 +08:00
parent 1614a15329
commit 715834cc16
3 changed files with 159 additions and 19 deletions

View File

@@ -1,8 +1,31 @@
from types import ModuleType
import dataclasses
import os
import re
import sys
import threading
import typing
import importlib
from typing import Optional, Union, List
from pathlib import Path
import hashlib
import functools
import transformers
from transformers.dynamic_module_utils import get_relative_import_files
from transformers.utils.hub import HF_MODULES_CACHE
from transformers import PretrainedConfig
from transformers.utils import is_torch_npu_available
from . import sdpa_attention as npu_sdpa_attention
from . import rms_norm, rope, swiglu
from ...extras import logging
logger = logging.get_logger()
_HF_REMOTE_CODE_LOCK = threading.Lock()
def _patch_sdpa_forward():
"""
@@ -18,7 +41,6 @@ def _patch_sdpa_forward():
def _patch_rmsnorm(module: ModuleType, class_name: str):
setattr(module, class_name, rms_norm.NpuRMSNorm)
@@ -31,18 +53,109 @@ def _patch_swiglu(module: ModuleType, class_name: str):
def apply_fused_options(disable: bool=False):
if disable or not is_torch_npu_available():
logger.warning_rank0("NPU fused options are disabled or the torch NPU backend is not available.")
return
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.qwen2_moe import modeling_qwen2_moe
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
from transformers.models.qwen3 import modeling_qwen3
from transformers.models.qwen3_moe import modeling_qwen3_moe
_patch_sdpa_forward()
_patch_rmsnorm(modeling_qwen2, "Qwen2RMSNorm")
_patch_rope(modeling_qwen2, "apply_rotary_pos_emb")
_patch_swiglu(modeling_qwen2, "Qwen2MLP")
_patch_rmsnorm(modeling_qwen2_moe, "Qwen2MoeRMSNorm")
_patch_rope(modeling_qwen2_moe, "apply_rotary_pos_emb")
_patch_swiglu(modeling_qwen2_moe, "Qwen2MoeMLP")
_patch_rmsnorm(modeling_qwen3, "Qwen3RMSNorm")
_patch_rope(modeling_qwen3, "apply_rotary_pos_emb")
_patch_swiglu(modeling_qwen3, "Qwen3MLP")
_patch_rmsnorm(modeling_qwen3_moe, "Qwen3MoeRMSNorm")
_patch_rope(modeling_qwen3_moe, "apply_rotary_pos_emb")
_patch_swiglu(modeling_qwen3_moe, "Qwen3MoeMLP")
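Not part of the diff: a minimal usage sketch of the static patching path. The import path npu_fused_ops below is a placeholder for wherever this module lives in the package; the only requirement is that apply_fused_options() runs before the model is built, so the Qwen modeling modules pick up the patched SDPA forward, NpuRMSNorm, rotary embedding, and SwiGLU implementations.

# Usage sketch, not part of the diff; the import path is a placeholder.
from transformers import AutoModelForCausalLM

from npu_fused_ops import apply_fused_options  # hypothetical module name

apply_fused_options()  # patch Qwen2 / Qwen2-MoE / Qwen3 / Qwen3-MoE classes in place
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B")  # any of the patched architectures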
def _original_get_dynamic_module(
class_name: str,
module_path: Union[str, os.PathLike],
*,
force_reload: bool = False,
):
"""
Load a dynamic module from a .py file. Adapted from transformers.dynamic_module_utils.get_class_in_module, but returns the module object instead of the class.
"""
name = os.path.normpath(module_path)
if name.endswith(".py"):
name = name[:-3]
name = name.replace(os.path.sep, ".")
module_file: Path = Path(HF_MODULES_CACHE) / module_path
with _HF_REMOTE_CODE_LOCK:
if force_reload:
sys.modules.pop(name, None)
importlib.invalidate_caches()
cached_module: Optional[ModuleType] = sys.modules.get(name)
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
# Hash the module file and all its relative imports to check if we need to reload it
module_files: List[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
module: ModuleType
if cached_module is None:
module = importlib.util.module_from_spec(module_spec)
# insert it into sys.modules before any loading begins
sys.modules[name] = module
else:
module = cached_module
if getattr(module, "__transformers_module_hash__", "") != module_hash:
module_spec.loader.exec_module(module)
module.__transformers_module_hash__ = module_hash
return module
def _dynamic_patch_flash_attention(sdpa_attention_cls: str, module: ModuleType, **kwargs):
_patch_sdpa_forward()
# resolve the attention class by name from the dynamically loaded module, then swap its forward
setattr(getattr(module, sdpa_attention_cls), "forward", npu_sdpa_attention.sdpa_attention_forward)
def _dynamic_patch_rmsnorm(rmsnorm_cls: str, module: ModuleType, **kwargs):
setattr(module, rmsnorm_cls, rms_norm.NpuRMSNorm)
def _dynamic_patch_rope(rope_cls: str, module: ModuleType, **kwargs):
setattr(module, rope_cls, rope.apply_rotary_pos_emb)
def _dynamic_patch_swiglu(swiglu_cls: str, module: ModuleType, **kwargs):
setattr(module, swiglu_cls, swiglu.NpuSwiGlu)
def patch_dynamic_fused_ops():
def _get_dynamic_module(
class_name: str,
module_path: Union[str, os.PathLike],
*,
force_reload: bool = False,
):
module = _original_get_dynamic_module(class_name, module_path, force_reload=force_reload)
# remote-code modules (e.g. modeling_internlm3) can be patched here via the _dynamic_patch_* helpers
return module
def _get_class_in_module(
class_name: str,
module_path: Union[str, os.PathLike],
*,
force_reload: bool = False,
) -> typing.Type:
module = _get_dynamic_module(class_name=class_name, module_path=module_path, force_reload=force_reload)
return getattr(module, class_name)
transformers.dynamic_module_utils.get_class_in_module = _get_class_in_module
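Not part of the diff: a hedged sketch of how the dynamic hook would be used. patch_dynamic_fused_ops() replaces transformers' get_class_in_module, so it has to run before any from_pretrained call that pulls remote modeling code with trust_remote_code=True; the repo id and import path below are placeholders.

# Usage sketch, not part of the diff; repo id and import path are placeholders.
from transformers import AutoModelForCausalLM

from npu_fused_ops import patch_dynamic_fused_ops  # hypothetical module name

patch_dynamic_fused_ops()  # install the get_class_in_module hook first
model = AutoModelForCausalLM.from_pretrained(
    "org/model-with-remote-code",  # placeholder: any checkpoint that ships its own modeling_*.py
    trust_remote_code=True,
)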

View File

@@ -19,10 +19,6 @@ from transformers.utils import is_torch_npu_available
if is_torch_npu_available():
import torch_npu
from ...extras import logging
logger = logging.get_logger()
class NpuRMSNorm(nn.Module):
"""

View File

@@ -9,6 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
import torch
@@ -16,6 +17,9 @@ import torch
from transformers.utils import is_torch_npu_available
from transformers.integrations.sdpa_attention import repeat_kv
if is_torch_npu_available():
import torch_npu
def sdpa_attention_forward(
module: torch.nn.Module,
@@ -33,7 +37,7 @@ def sdpa_attention_forward(
value = repeat_kv(value, module.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None and causal_mask.ndim == 4:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
@@ -53,9 +57,36 @@ def sdpa_attention_forward(
is_causal = is_causal.item()
if is_torch_npu_available():
if is_causal:
# build the compressed 2048x2048 causal mask (sparse_mode=2) only when an attention mask is present
atten_mask_npu = torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(query.device) if causal_mask is not None else None
sparse_mode = 2 if causal_mask is not None else 0
head_num = query.shape[1]
attn_output = torch_npu.npu_fusion_attention(
query, key, value, head_num, input_layout="BNSD",
pse=None,
atten_mask=atten_mask_npu,
sparse_mode=sparse_mode,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=2147483647,
next_tockens=2147483647,
keep_prob=1,
)[0]
else:
if causal_mask.dtype == torch.bool:
atten_mask_npu = torch.logical_not(causal_mask.bool()).to(query.device)
else:
atten_mask_npu = causal_mask.bool().to(query.device)
head_num = query.shape[1]
attn_output = torch_npu.npu_fusion_attention(
query, key, value, head_num, input_layout="BNSD",
pse=None,
atten_mask=atten_mask_npu,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=2147483647,
next_tockens=2147483647,
keep_prob=1
)[0]
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,