dev npu fused options

Author: frozenleaves
Date:   2025-09-03 16:33:17 +08:00
Parent: 715834cc16
Commit: cc6ccdb626


@@ -52,35 +52,6 @@ def _patch_swiglu(module: ModuleType, class_name: str):
     setattr(module, class_name, swiglu.NpuSwiGlu)
 
 
-def apply_fused_options(disable: bool = False):
-    if disable or not is_torch_npu_available():
-        logger.warning_rank0("NPU fused options is disabled, or the torch NPU backend is not available.")
-        return
-
-    from transformers.models.qwen2 import modeling_qwen2
-    from transformers.models.qwen2_moe import modeling_qwen2_moe
-    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
-    from transformers.models.qwen3 import modeling_qwen3
-    from transformers.models.qwen3_moe import modeling_qwen3_moe
-
-    _patch_sdpa_forward()
-    _patch_rmsnorm(modeling_qwen2, "Qwen2RMSNorm")
-    _patch_rope(modeling_qwen2, "apply_rotary_pos_emb")
-    _patch_swiglu(modeling_qwen2, "Qwen2MLP")
-    _patch_rmsnorm(modeling_qwen2_moe, "Qwen2MoeRMSNorm")
-    _patch_rope(modeling_qwen2_moe, "apply_rotary_pos_emb")
-    _patch_swiglu(modeling_qwen2_moe, "Qwen2MoeMLP")
-    _patch_rmsnorm(modeling_qwen3, "Qwen3RMSNorm")
-    _patch_rope(modeling_qwen3, "apply_rotary_pos_emb")
-    _patch_swiglu(modeling_qwen3, "Qwen3MLP")
-    _patch_rmsnorm(modeling_qwen3_moe, "Qwen3MoeRMSNorm")
-    _patch_rope(modeling_qwen3_moe, "apply_rotary_pos_emb")
-    _patch_swiglu(modeling_qwen3_moe, "Qwen3MoeMLP")
-
-
 def _original_get_dynamic_module(
     class_name: str,
     module_path: Union[str, os.PathLike],
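For reference, each _patch_* helper simply rebinds a module-level name, so any later construction of, say, Qwen2RMSNorm picks up the fused class. Below is a minimal sketch of such a drop-in class, assuming torch_npu's npu_rms_norm kernel; the repo's actual NpuRMSNorm and swiglu.NpuSwiGlu implementations live elsewhere and may differ.

import torch
import torch_npu

class SketchNpuRMSNorm(torch.nn.Module):
    # Hypothetical fused RMSNorm; it mirrors Qwen2RMSNorm's constructor so that
    # setattr(module, "Qwen2RMSNorm", SketchNpuRMSNorm) is a clean swap.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # npu_rms_norm fuses the variance, normalization, and scaling steps into
        # one kernel and returns (output, rstd); only the output is needed here.
        return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]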
@@ -137,7 +108,7 @@ def _dynamic_patch_swiglu(swiglu_cls: str, module: ModuleType, **kwargs):
     setattr(module, swiglu_cls, swiglu.NpuSwiGlu)
 
 
-def patch_dynamic_fused_ops():
+def _patch_dynamic_fused_ops():
     def _get_dynamic_module(
         class_name: str,
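The function renamed here follows a standard interception pattern: keep a handle to the original transformers loader, patch the dynamically imported module, and hand back the (now fused) class, so models loaded with trust_remote_code are covered too. A hedged sketch of that pattern follows; the name-suffix dispatch is an assumption standing in for the repo's actual _dynamic_patch_* matching logic.

import sys
import transformers.dynamic_module_utils

_orig_get_class_in_module = transformers.dynamic_module_utils.get_class_in_module

def _get_class_in_module(class_name, module_path, *, force_reload=False):
    # Load the remote-code class with the original loader first.
    cls = _orig_get_class_in_module(class_name, module_path, force_reload=force_reload)
    module = sys.modules[cls.__module__]
    # Hypothetical dispatch: remote-code models typically name their norm
    # "...RMSNorm"; rebind it to the fused sketch class from above.
    if class_name.endswith("RMSNorm"):
        setattr(module, class_name, SketchNpuRMSNorm)
    return getattr(module, class_name)

transformers.dynamic_module_utils.get_class_in_module = _get_class_in_module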
@@ -158,4 +129,34 @@ def patch_dynamic_fused_ops():
         module = _get_dynamic_module(class_name=class_name, module_path=module_path, force_reload=force_reload)
         return getattr(module, class_name)
 
     transformers.dynamic_module_utils.get_class_in_module = _get_class_in_module
+
+
+def apply_fused_options(disable: bool = False):
+    if disable or not is_torch_npu_available():
+        logger.warning_rank0("NPU fused options is disabled, or the torch NPU backend is not available.")
+        return
+
+    from transformers.models.qwen2 import modeling_qwen2
+    from transformers.models.qwen2_moe import modeling_qwen2_moe
+    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+    from transformers.models.qwen3 import modeling_qwen3
+    from transformers.models.qwen3_moe import modeling_qwen3_moe
+
+    _patch_sdpa_forward()
+    _patch_dynamic_fused_ops()
+    _patch_rmsnorm(modeling_qwen2, "Qwen2RMSNorm")
+    _patch_rope(modeling_qwen2, "apply_rotary_pos_emb")
+    _patch_swiglu(modeling_qwen2, "Qwen2MLP")
+    _patch_rmsnorm(modeling_qwen2_moe, "Qwen2MoeRMSNorm")
+    _patch_rope(modeling_qwen2_moe, "apply_rotary_pos_emb")
+    _patch_swiglu(modeling_qwen2_moe, "Qwen2MoeMLP")
+    _patch_rmsnorm(modeling_qwen3, "Qwen3RMSNorm")
+    _patch_rope(modeling_qwen3, "apply_rotary_pos_emb")
+    _patch_swiglu(modeling_qwen3, "Qwen3MLP")
+    _patch_rmsnorm(modeling_qwen3_moe, "Qwen3MoeRMSNorm")
+    _patch_rope(modeling_qwen3_moe, "apply_rotary_pos_emb")
+    _patch_swiglu(modeling_qwen3_moe, "Qwen3MoeMLP")
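Usage sketch: apply_fused_options() must run once before the model is constructed, so that from_pretrained resolves the already-patched classes. The import path of apply_fused_options is not shown in this diff, and the model id below is only an example.

from transformers import AutoModelForCausalLM

apply_fused_options()  # logs a warning and returns if disabled or off-NPU
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct").to("npu")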