mirror of
https://github.com/frozenleaves/LLaMA-Factory.git
synced 2025-10-20 16:23:46 +08:00
dev npu fused options
This commit is contained in:
@ -52,35 +52,6 @@ def _patch_swiglu(module: ModuleType, class_name: str):
|
||||
setattr(module, class_name, swiglu.NpuSwiGlu)
|
||||
|
||||
|
||||
def apply_fused_options(disable: bool=False):
    """Patch the Qwen modeling modules in transformers with the NPU fused ops.

    Does nothing except emit a rank-0 warning when ``disable`` is true or
    ``is_torch_npu_available()`` reports that no NPU backend is present.
    Otherwise the SDPA forward patch is applied once, and then the RMSNorm,
    rotary-embedding and SwiGLU patches are applied, in that order, to each
    of the qwen2, qwen2_moe, qwen3 and qwen3_moe modeling modules.

    Args:
        disable: When true, skip all patching.
    """
    npu_usable = is_torch_npu_available()
    if disable or not npu_usable:
        logger.warning_rank0("NPU fused options is disabled, or the torch NPU backend is not available.")
        return

    # Imported lazily so transformers' model code is only loaded when patching
    # actually happens.
    from transformers.models.qwen2 import modeling_qwen2
    from transformers.models.qwen2_moe import modeling_qwen2_moe
    # NOTE(review): qwen2_5_vl is imported but nothing below patches it.
    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
    from transformers.models.qwen3 import modeling_qwen3
    from transformers.models.qwen3_moe import modeling_qwen3_moe

    _patch_sdpa_forward()

    # (modeling module, RMSNorm class name, MLP class name) for every patched family.
    patch_targets = (
        (modeling_qwen2, "Qwen2RMSNorm", "Qwen2MLP"),
        (modeling_qwen2_moe, "Qwen2MoeRMSNorm", "Qwen2MoeMLP"),
        (modeling_qwen3, "Qwen3RMSNorm", "Qwen3MLP"),
        (modeling_qwen3_moe, "Qwen3MoeRMSNorm", "Qwen3MoeMLP"),
    )
    for modeling_module, rmsnorm_name, mlp_name in patch_targets:
        _patch_rmsnorm(modeling_module, rmsnorm_name)
        _patch_rope(modeling_module, "apply_rotary_pos_emb")
        _patch_swiglu(modeling_module, mlp_name)
|
||||
|
||||
|
||||
def _original_get_dynamic_module(
|
||||
class_name: str,
|
||||
module_path: Union[str, os.PathLike],
|
||||
@ -137,7 +108,7 @@ def _dynamic_patch_swiglu(swiglu_cls: str, module: ModuleType, **kwargs):
|
||||
setattr(module, swiglu_cls, swiglu.NpuSwiGlu)
|
||||
|
||||
|
||||
def patch_dynamic_fused_ops():
|
||||
def _patch_dynamic_fused_ops():
|
||||
|
||||
def _get_dynamic_module(
|
||||
class_name: str,
|
||||
@ -158,4 +129,34 @@ def patch_dynamic_fused_ops():
|
||||
module = _get_dynamic_module(class_name=class_name, module_path=module_path, force_reload=force_reload)
|
||||
return getattr(module, class_name)
|
||||
|
||||
transformers.dynamic_module_utils.get_class_in_module = _get_class_in_module
|
||||
transformers.dynamic_module_utils.get_class_in_module = _get_class_in_module
|
||||
|
||||
|
||||
def apply_fused_options(disable: bool=False):
    """Patch transformers with the NPU fused ops for the Qwen model families.

    Does nothing except emit a rank-0 warning when ``disable`` is true or
    ``is_torch_npu_available()`` reports that no NPU backend is present.
    Otherwise it applies, in order: the SDPA forward patch, the dynamic
    fused-ops patch, and then the RMSNorm, rotary-embedding and SwiGLU
    patches for each of the qwen2, qwen2_moe, qwen3 and qwen3_moe modeling
    modules.

    Args:
        disable: When true, skip all patching.
    """
    npu_usable = is_torch_npu_available()
    if disable or not npu_usable:
        logger.warning_rank0("NPU fused options is disabled, or the torch NPU backend is not available.")
        return

    # Imported lazily so transformers' model code is only loaded when patching
    # actually happens.
    from transformers.models.qwen2 import modeling_qwen2
    from transformers.models.qwen2_moe import modeling_qwen2_moe
    # NOTE(review): qwen2_5_vl is imported but nothing below patches it.
    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
    from transformers.models.qwen3 import modeling_qwen3
    from transformers.models.qwen3_moe import modeling_qwen3_moe

    # Global patches first, then the per-module ones.
    _patch_sdpa_forward()
    _patch_dynamic_fused_ops()

    # (modeling module, RMSNorm class name, MLP class name) for every patched family.
    patch_targets = (
        (modeling_qwen2, "Qwen2RMSNorm", "Qwen2MLP"),
        (modeling_qwen2_moe, "Qwen2MoeRMSNorm", "Qwen2MoeMLP"),
        (modeling_qwen3, "Qwen3RMSNorm", "Qwen3MLP"),
        (modeling_qwen3_moe, "Qwen3MoeRMSNorm", "Qwen3MoeMLP"),
    )
    for modeling_module, rmsnorm_name, mlp_name in patch_targets:
        _patch_rmsnorm(modeling_module, rmsnorm_name)
        _patch_rope(modeling_module, "apply_rotary_pos_emb")
        _patch_swiglu(modeling_module, mlp_name)
|
Reference in New Issue
Block a user