From 330a992693082b85d7d23cb289bd7fd96550c35f Mon Sep 17 00:00:00 2001
From: frozenleaves
Date: Fri, 17 Oct 2025 10:16:42 +0800
Subject: [PATCH] fix comments

---
 src/llamafactory/v1/extras/types.py           | 17 +----
 .../model_plugins/kernels/constants.py        | 30 +++++++++
 .../mlp/{npu_moe.py => npu_fused_moe.py}      |  0
 .../model_plugins/kernels/mlp/npu_swiglu.py   |  7 ++-
 .../plugins/model_plugins/kernels/registry.py |  4 +-
 .../kernels/rms_norm/npu_rms_norm.py          |  6 +-
 .../model_plugins/kernels/rope/npu_rope.py    |  7 ++-
 .../model_plugins/test_kernel_plugin.py       | 63 ++-----------------
 8 files changed, 51 insertions(+), 83 deletions(-)
 create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/constants.py
 rename src/llamafactory/v1/plugins/model_plugins/kernels/mlp/{npu_moe.py => npu_fused_moe.py} (100%)

diff --git a/src/llamafactory/v1/extras/types.py b/src/llamafactory/v1/extras/types.py
index cbbf1555..9447133d 100644
--- a/src/llamafactory/v1/extras/types.py
+++ b/src/llamafactory/v1/extras/types.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from enum import Enum
+
 from typing import TYPE_CHECKING, Literal, TypedDict, Union
 
 
@@ -97,18 +97,3 @@ class Model(TypedDict):
     """HF model."""
     dist_model: DistModel
     """Distributed model."""
-
-
-class KernelType(str, Enum):
-    RMSNORM = "rmsnorm"
-    SWIGLU = "swiglu"
-    FLASH_ATTENTION = "flash_attention"
-    ROPE = "rope"
-    MOE = "moe"
-
-
-class DeviceType(str, Enum):
-    CPU = 'cpu'
-    CUDA = 'cuda'
-    NPU = 'npu'
-    XPU = 'xpu'
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py b/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py
new file mode 100644
index 00000000..063ebb44
--- /dev/null
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py
@@ -0,0 +1,30 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+
+
+class KernelType(str, Enum):
+    RMSNORM = "rmsnorm"
+    SWIGLU = "swiglu"
+    FLASH_ATTENTION = "flash_attention"
+    ROPE = "rope"
+    MOE = "moe"
+
+
+class DeviceType(str, Enum):
+    CPU = 'cpu'
+    CUDA = 'cuda'
+    NPU = 'npu'
+    XPU = 'xpu'
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_moe.py b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py
similarity index 100%
rename from src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_moe.py
rename to src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py
index 146c076f..be331dec 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py
@@ -11,14 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib.util
+
 import re
 import types
 
 import torch
 
-from .....extras.types import DeviceType, HFModel, KernelType
+from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
+from ..constants import DeviceType, KernelType
 from ..registry import KERNEL_REGISTRY, MetaSwiGluKernel
 
 
@@ -43,7 +44,7 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
 
     @classmethod
     def apply(cls, model, **kwargs) -> 'HFModel':
-        if not (is_torch_npu_available() and importlib.util.find_spec("torch_npu")):
+        if not is_torch_npu_available():
             return model
 
         swiglu_pattern = re.compile("MLP", re.IGNORECASE)
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
index 3192d955..33597c48 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Optional
 
-from ....extras.types import DeviceType, HFModel, KernelType
+from ....extras.types import HFModel
 from ...trainer_plugins.distributed.accelerate import get_available_accelerator
+from .constants import DeviceType, KernelType
 
 
 class KernelRegistry:
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py
index ce9943de..018758ee 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py
@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib.util
 import re
 import types
 
-from .....extras.types import DeviceType, HFModel, KernelType
+from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
+from ..constants import DeviceType, KernelType
 from ..registry import KERNEL_REGISTRY, MetaRMSNormKernel
 
 
@@ -57,7 +57,7 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
            - Do not modify weights, hyperparameters, or module structure to ensure numerical behavior and
              interface consistency.
         """
-        if not (is_torch_npu_available() and importlib.util.find_spec("torch_npu")):
+        if not is_torch_npu_available():
             return model
 
         rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py
index 50b0bea4..a1d41dd4 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py
@@ -11,13 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib.util
+
 import sys
 
 import torch
 
-from .....extras.types import DeviceType, HFModel, KernelType
+from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
+from ..constants import DeviceType, KernelType
 from ..registry import KERNEL_REGISTRY, MetaRoPEKernel
 
 
@@ -66,7 +67,7 @@ class NpuRoPEKernel(MetaRoPEKernel):
         `apply_rotary_pos_emb` function in that module's namespace with the
         NPU-accelerated version from this file.
         """
-        if not (is_torch_npu_available() and importlib.util.find_spec("torch_npu")):
+        if not is_torch_npu_available():
             return model
 
         _modules = set()
diff --git a/tests_v1/plugins/model_plugins/test_kernel_plugin.py b/tests_v1/plugins/model_plugins/test_kernel_plugin.py
index e5cb0976..a89b8bd7 100644
--- a/tests_v1/plugins/model_plugins/test_kernel_plugin.py
+++ b/tests_v1/plugins/model_plugins/test_kernel_plugin.py
@@ -15,58 +15,7 @@
 import unittest
 from unittest.mock import MagicMock, patch
 
-import torch
-from torch import nn
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = q * sin
-    k_embed = k * cos
-    return q_embed, k_embed
-
-
-class TinyRMSNorm(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def forward(self, x):
-        return x * self.weight
-
-
-class TinyMLP(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.gate_proj = nn.Linear(10, 10)
-        self.up_proj = nn.Linear(10, 10)
-        self.down_proj = nn.Linear(10, 10)
-
-    def forward(self, x):
-        return self.gate_proj(x) * self.up_proj(x) + self.down_proj(x)
-
-
-class TinyAttention(nn.Module):
-    def forward(self, q, k, v, cos, sin, position_ids=None, unsqueeze_dim=1):
-        q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim)
-        return q_embed, k_embed
-
-
-class TinyModel(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.linear = nn.Linear(10, 10)
-        self.norm = TinyRMSNorm(10)
-        self.mlp = TinyMLP()
-        self.attn = TinyAttention()
-        self.attn_implementation = 'default'
-
-    def set_attn_implementation(self, attn_implementation):
-        self.attn_implementation = attn_implementation
-
-    def forward(self, x):
-        return self.mlp(self.norm(self.linear(x)))
+from transformers import AutoModelForCausalLM
 
 
 class TestKernelPlugin(unittest.TestCase):
@@ -77,10 +26,10 @@ class TestKernelPlugin(unittest.TestCase):
         mock_device.type = 'npu'
         mock_get_accelerator.return_value = mock_device
 
-        model = TinyModel()
+        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
 
-        original_rmsnorm_forward = model.norm.forward
-        original_swiglu_forward = model.mlp.forward
+        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+        original_swiglu_forward = model.model.layers[0].mlp.forward
 
         from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
 
@@ -91,7 +40,7 @@ class TestKernelPlugin(unittest.TestCase):
         apply_kernel(model, npu_rope.NpuRoPEKernel)
 
         model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
-        assert model.norm.forward is not original_rmsnorm_forward
+        assert model.model.layers[0].input_layernorm.forward is not original_rmsnorm_forward
 
         model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
-        assert model.mlp.forward is not original_swiglu_forward
+        assert model.model.layers[0].mlp.forward is not original_swiglu_forward
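-- 
Usage sketch (illustrative only, placed after the signature delimiter so it is ignored by git am):
a minimal example of how the relocated enums and an NPU kernel could be exercised after this
patch, assuming the dotted import path for kernels/constants.py resolves as the file layout
suggests. The checkpoint name, the npu_swiglu import, and NpuSwiGluKernel.apply() are taken from
the hunks above; everything else is ordinary transformers usage.

    from transformers import AutoModelForCausalLM

    # The kernel/device enums now live under the kernels package instead of extras/types.py.
    from llamafactory.v1.plugins.model_plugins.kernels.constants import DeviceType, KernelType
    from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu

    # Same tiny checkpoint the updated test loads.
    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

    # The enums are str-valued, so existing string comparisons keep working after the move.
    assert KernelType.SWIGLU.value == "swiglu"
    assert DeviceType.NPU.value == "npu"

    # apply() returns the model unchanged when is_torch_npu_available() is false, so this is
    # safe on CPU-only machines; on an NPU it patches modules whose class name matches "MLP".
    model = npu_swiglu.NpuSwiGluKernel.apply(model)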