fix comments

frozenleaves
2025-10-17 10:16:42 +08:00
parent 3f0d2c433e
commit 003ccfa3fb
9 changed files with 55 additions and 89 deletions

View File

@@ -11,14 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from enum import Enum
from typing import TYPE_CHECKING, Literal, TypedDict, Union
-try:
-    from typing import NotRequired  # Python 3.11+
-except ImportError:
-    from typing_extensions import NotRequired  # Python < 3.11
+from typing_extensions import NotRequired
if TYPE_CHECKING:
@@ -97,18 +93,3 @@ class Model(TypedDict):
    """HF model."""
    dist_model: DistModel
    """Distributed model."""
-
-
-class KernelType(str, Enum):
-    RMSNORM = "rmsnorm"
-    SWIGLU = "swiglu"
-    FLASH_ATTENTION = "flash_attention"
-    ROPE = "rope"
-    MOE = "moe"
-
-
-class DeviceType(str, Enum):
-    CPU = 'cpu'
-    CUDA = 'cuda'
-    NPU = 'npu'
-    XPU = 'xpu'
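Note on the import change in the first hunk: typing_extensions re-exports NotRequired on every supported Python version (on 3.11+ it is the same object as typing.NotRequired), so the unconditional import behaves exactly like the removed try/except. A minimal sketch of how the type reads in a TypedDict; the field names below are illustrative and not taken from this file:

from typing import TypedDict

from typing_extensions import NotRequired  # same symbol as typing.NotRequired on Python 3.11+


class ChatSample(TypedDict):
    prompt: str               # required key
    system: NotRequired[str]  # optional key: may be absent from the dict entirely


sample: ChatSample = {"prompt": "hello"}  # valid even though "system" is missing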

View File

@@ -13,7 +13,9 @@
# limitations under the License.
-from typing import Callable, NotRequired, TypedDict
+from typing import Callable, TypedDict
+from typing_extensions import NotRequired
from ...extras.types import Sample, SFTSample

View File

@@ -0,0 +1,30 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum


class KernelType(str, Enum):
    RMSNORM = "rmsnorm"
    SWIGLU = "swiglu"
    FLASH_ATTENTION = "flash_attention"
    ROPE = "rope"
    MOE = "moe"


class DeviceType(str, Enum):
    CPU = 'cpu'
    CUDA = 'cuda'
    NPU = 'npu'
    XPU = 'xpu'
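Because both enums also subclass str, their members compare equal to plain strings and can be constructed directly from configuration values, which keeps call sites simple when device or kernel names arrive as text. A small usage sketch, not part of this commit:

from enum import Enum


class DeviceType(str, Enum):
    CPU = 'cpu'
    CUDA = 'cuda'
    NPU = 'npu'
    XPU = 'xpu'


# str-valued members compare equal to their raw string values ...
assert DeviceType.NPU == 'npu'
# ... and can be built from strings read out of a config file.
assert DeviceType('cuda') is DeviceType.CUDA
print(DeviceType.NPU.value)  # -> npu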

View File

@@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import importlib.util
import re
import types
import torch
-from .....extras.types import DeviceType, HFModel, KernelType
+from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
+from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaSwiGluKernel
@@ -43,7 +44,7 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
    @classmethod
    def apply(cls, model, **kwargs) -> 'HFModel':
-        if not (is_torch_npu_available() and importlib.util.find_spec("torch_npu")):
+        if not is_torch_npu_available():
            return model

        swiglu_pattern = re.compile("MLP", re.IGNORECASE)
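The simplified guard assumes that is_torch_npu_available() already returns False whenever torch_npu cannot be imported, which makes the separate importlib.util.find_spec("torch_npu") probe redundant; the same simplification is applied to the RMSNorm and RoPE kernels below. A rough, self-contained sketch of the early-return pattern these apply() methods follow (everything except is_torch_npu_available is illustrative, and the helper is taken here from accelerate rather than from the project's wrapper):

import re

from accelerate.utils import is_torch_npu_available


def apply_npu_swiglu(model):
    # Without an Ascend NPU stack, hand the model back untouched; the availability
    # check already implies that torch_npu is importable.
    if not is_torch_npu_available():
        return model

    swiglu_pattern = re.compile("MLP", re.IGNORECASE)
    for _name, module in model.named_modules():
        if swiglu_pattern.search(type(module).__name__):
            pass  # patch module.forward with the fused NPU SwiGLU kernel here
    return model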

View File

@@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional
-from ....extras.types import DeviceType, HFModel, KernelType
+from ....extras.types import HFModel
from ...trainer_plugins.distributed.accelerate import get_available_accelerator
+from .constants import DeviceType, KernelType
class KernelRegistry:
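The registry now takes DeviceType and KernelType from the local constants module rather than from extras.types. As a rough mental model only (the body of the project's KernelRegistry is not shown in this hunk), such a registry can be pictured as a mapping from a (kernel type, device type) pair to the kernel that patches the model:

from enum import Enum
from typing import Callable, Dict, Optional, Tuple


class KernelType(str, Enum):
    RMSNORM = 'rmsnorm'
    SWIGLU = 'swiglu'


class DeviceType(str, Enum):
    CUDA = 'cuda'
    NPU = 'npu'


class SimpleKernelRegistry:
    """Illustrative stand-in, not the project's implementation."""

    def __init__(self) -> None:
        self._kernels: Dict[Tuple[KernelType, DeviceType], Callable] = {}

    def register(self, kernel: KernelType, device: DeviceType, apply_fn: Callable) -> None:
        self._kernels[(kernel, device)] = apply_fn

    def get(self, kernel: KernelType, device: DeviceType) -> Optional[Callable]:
        return self._kernels.get((kernel, device))


registry = SimpleKernelRegistry()
registry.register(KernelType.SWIGLU, DeviceType.NPU, lambda model: model)
assert registry.get(KernelType.SWIGLU, DeviceType.NPU) is not None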

View File

@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import importlib.util
import re
import types
-from .....extras.types import DeviceType, HFModel, KernelType
+from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
+from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaRMSNormKernel
@@ -57,7 +57,7 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
        - Do not modify weights, hyperparameters, or module structure to ensure
          numerical behavior and interface consistency.
        """
-        if not (is_torch_npu_available() and importlib.util.find_spec("torch_npu")):
+        if not is_torch_npu_available():
            return model

        rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)

View File

@@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import importlib.util
import sys
import torch
-from .....extras.types import DeviceType, HFModel, KernelType
+from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
+from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaRoPEKernel
@@ -66,7 +67,7 @@ class NpuRoPEKernel(MetaRoPEKernel):
        `apply_rotary_pos_emb` function in that module's namespace with the
        NPU-accelerated version from this file.
        """
-        if not (is_torch_npu_available() and importlib.util.find_spec("torch_npu")):
+        if not is_torch_npu_available():
            return model

        _modules = set()

View File

@@ -15,58 +15,7 @@
import unittest
from unittest.mock import MagicMock, patch
import torch
-from torch import nn
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = q * sin
-    k_embed = k * cos
-    return q_embed, k_embed
-
-
-class TinyRMSNorm(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def forward(self, x):
-        return x * self.weight
-
-
-class TinyMLP(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.gate_proj = nn.Linear(10, 10)
-        self.up_proj = nn.Linear(10, 10)
-        self.down_proj = nn.Linear(10, 10)
-
-    def forward(self, x):
-        return self.gate_proj(x) * self.up_proj(x) + self.down_proj(x)
-
-
-class TinyAttention(nn.Module):
-    def forward(self, q, k, v, cos, sin, position_ids=None, unsqueeze_dim=1):
-        q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim)
-        return q_embed, k_embed
-
-
-class TinyModel(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.linear = nn.Linear(10, 10)
-        self.norm = TinyRMSNorm(10)
-        self.mlp = TinyMLP()
-        self.attn = TinyAttention()
-        self.attn_implementation = 'default'
-
-    def set_attn_implementation(self, attn_implementation):
-        self.attn_implementation = attn_implementation
-
-    def forward(self, x):
-        return self.mlp(self.norm(self.linear(x)))
+from transformers import AutoModelForCausalLM
class TestKernelPlugin(unittest.TestCase):
@@ -77,10 +26,10 @@ class TestKernelPlugin(unittest.TestCase):
        mock_device.type = 'npu'
        mock_get_accelerator.return_value = mock_device
-        model = TinyModel()
+        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
-        original_rmsnorm_forward = model.norm.forward
-        original_swiglu_forward = model.mlp.forward
+        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+        original_swiglu_forward = model.model.layers[0].mlp.forward

        from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
@@ -91,7 +40,7 @@ class TestKernelPlugin(unittest.TestCase):
        apply_kernel(model, npu_rope.NpuRoPEKernel)

        model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
-        assert model.norm.forward is not original_rmsnorm_forward
+        assert model.model.layers[0].input_layernorm.forward is not original_rmsnorm_forward

        model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
-        assert model.mlp.forward is not original_swiglu_forward
+        assert model.model.layers[0].mlp.forward is not original_swiglu_forward
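For readers skimming the updated test: the mock_get_accelerator / mock_device pair above is what lets the NPU kernels activate on machines without Ascend hardware. A reduced, self-contained sketch of that mocking pattern follows; the function being patched is a local stand-in, since the real patch target sits outside the lines shown in this hunk:

import unittest
from unittest.mock import MagicMock, patch


def get_available_accelerator():
    """Local stand-in for the project's accelerator probe (illustrative only)."""
    device = MagicMock()
    device.type = 'cpu'
    return device


class TestAcceleratorMockSketch(unittest.TestCase):
    @patch(f'{__name__}.get_available_accelerator')
    def test_pretends_to_run_on_npu(self, mock_get_accelerator):
        mock_device = MagicMock()
        mock_device.type = 'npu'  # fake an Ascend NPU device
        mock_get_accelerator.return_value = mock_device
        self.assertEqual(get_available_accelerator().type, 'npu')


if __name__ == '__main__':
    unittest.main()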