mirror of https://github.com/frozenleaves/LLaMA-Factory.git (synced 2025-10-20 16:23:46 +08:00)
Commit: fix comments
@@ -11,14 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from enum import Enum
 
 from typing import TYPE_CHECKING, Literal, TypedDict, Union
 
-try:
-    from typing import NotRequired  # Python 3.11+
-except ImportError:
-    from typing_extensions import NotRequired  # Python < 3.11
+from typing_extensions import NotRequired
 
 
 if TYPE_CHECKING:
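Note on the import change above: typing_extensions provides NotRequired on every Python version the project supports, so the runtime try/except fallback can be replaced by a single unconditional import. A minimal sketch of how NotRequired is used in a TypedDict (field names here are illustrative, not taken from this file):

from typing import TypedDict

from typing_extensions import NotRequired


class ExampleRecord(TypedDict):
    prompt: str                 # required key
    system: NotRequired[str]    # this key may be omitted entirely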
@@ -97,18 +93,3 @@ class Model(TypedDict):
     """HF model."""
     dist_model: DistModel
     """Distributed model."""
-
-
-class KernelType(str, Enum):
-    RMSNORM = "rmsnorm"
-    SWIGLU = "swiglu"
-    FLASH_ATTENTION = "flash_attention"
-    ROPE = "rope"
-    MOE = "moe"
-
-
-class DeviceType(str, Enum):
-    CPU = 'cpu'
-    CUDA = 'cuda'
-    NPU = 'npu'
-    XPU = 'xpu'
@@ -13,7 +13,9 @@
 # limitations under the License.
 
-from typing import Callable, NotRequired, TypedDict
+from typing import Callable, TypedDict
+
+from typing_extensions import NotRequired
 
 from ...extras.types import Sample, SFTSample
 
 
@@ -0,0 +1,30 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+
+
+class KernelType(str, Enum):
+    RMSNORM = "rmsnorm"
+    SWIGLU = "swiglu"
+    FLASH_ATTENTION = "flash_attention"
+    ROPE = "rope"
+    MOE = "moe"
+
+
+class DeviceType(str, Enum):
+    CPU = 'cpu'
+    CUDA = 'cuda'
+    NPU = 'npu'
+    XPU = 'xpu'
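The new module above centralizes the KernelType and DeviceType enums that the kernel plugins now pull in via "from ..constants import DeviceType, KernelType". Because both subclass str, their members compare equal to their raw string values. A hedged sketch of that kind of lookup (the table and function are illustrative; the project's actual KERNEL_REGISTRY structure is not shown in this diff, and the import path is inferred from the relative imports elsewhere in the excerpt):

# Import path is an assumption inferred from "from ..constants import ..."; verify locally.
from llamafactory.v1.plugins.model_plugins.kernels.constants import DeviceType, KernelType

# Hypothetical mapping from (device, kernel family) to an implementation name.
_EXAMPLE_KERNELS = {
    (DeviceType.NPU, KernelType.RMSNORM): "NpuRMSNormKernel",
    (DeviceType.NPU, KernelType.SWIGLU): "NpuSwiGluKernel",
}


def pick_kernel(device: str, kernel: str):
    # str-valued enums accept their raw values, e.g. DeviceType("npu") is DeviceType.NPU.
    return _EXAMPLE_KERNELS.get((DeviceType(device), KernelType(kernel)))


print(pick_kernel("npu", "rmsnorm"))  # -> "NpuRMSNormKernel"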
@@ -11,14 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib.util
+
 import re
 import types
 
 import torch
 
-from .....extras.types import DeviceType, HFModel, KernelType
+from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
+from ..constants import DeviceType, KernelType
 from ..registry import KERNEL_REGISTRY, MetaSwiGluKernel
 
 
@@ -43,7 +44,7 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
 
     @classmethod
     def apply(cls, model, **kwargs) -> 'HFModel':
-        if not (is_torch_npu_available() and importlib.util.find_spec("torch_npu")):
+        if not is_torch_npu_available():
             return model
 
         swiglu_pattern = re.compile("MLP", re.IGNORECASE)
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Optional
 
-from ....extras.types import DeviceType, HFModel, KernelType
+from ....extras.types import HFModel
 from ...trainer_plugins.distributed.accelerate import get_available_accelerator
+from .constants import DeviceType, KernelType
 
 
 class KernelRegistry:
@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib.util
 import re
 import types
 
-from .....extras.types import DeviceType, HFModel, KernelType
+from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
+from ..constants import DeviceType, KernelType
 from ..registry import KERNEL_REGISTRY, MetaRMSNormKernel
 
 
@@ -57,7 +57,7 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
         - Do not modify weights, hyperparameters, or module structure to ensure
           numerical behavior and interface consistency.
         """
-        if not (is_torch_npu_available() and importlib.util.find_spec("torch_npu")):
+        if not is_torch_npu_available():
             return model
 
         rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
@@ -11,13 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib.util
+
 import sys
 
 import torch
 
-from .....extras.types import DeviceType, HFModel, KernelType
+from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
+from ..constants import DeviceType, KernelType
 from ..registry import KERNEL_REGISTRY, MetaRoPEKernel
 
 
@@ -66,7 +67,7 @@ class NpuRoPEKernel(MetaRoPEKernel):
         `apply_rotary_pos_emb` function in that module's namespace with the
         NPU-accelerated version from this file.
         """
-        if not (is_torch_npu_available() and importlib.util.find_spec("torch_npu")):
+        if not is_torch_npu_available():
             return model
 
         _modules = set()
@@ -15,58 +15,7 @@
 import unittest
 from unittest.mock import MagicMock, patch
 
-import torch
-from torch import nn
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = q * sin
-    k_embed = k * cos
-    return q_embed, k_embed
-
-
-class TinyRMSNorm(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def forward(self, x):
-        return x * self.weight
-
-
-class TinyMLP(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.gate_proj = nn.Linear(10, 10)
-        self.up_proj = nn.Linear(10, 10)
-        self.down_proj = nn.Linear(10, 10)
-
-    def forward(self, x):
-        return self.gate_proj(x) * self.up_proj(x) + self.down_proj(x)
-
-
-class TinyAttention(nn.Module):
-    def forward(self, q, k, v, cos, sin, position_ids=None, unsqueeze_dim=1):
-        q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim)
-        return q_embed, k_embed
-
-
-class TinyModel(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.linear = nn.Linear(10, 10)
-        self.norm = TinyRMSNorm(10)
-        self.mlp = TinyMLP()
-        self.attn = TinyAttention()
-        self.attn_implementation = 'default'
-
-    def set_attn_implementation(self, attn_implementation):
-        self.attn_implementation = attn_implementation
-
-    def forward(self, x):
-        return self.mlp(self.norm(self.linear(x)))
+from transformers import AutoModelForCausalLM
 
 
 class TestKernelPlugin(unittest.TestCase):
@@ -77,10 +26,10 @@ class TestKernelPlugin(unittest.TestCase):
         mock_device.type = 'npu'
         mock_get_accelerator.return_value = mock_device
 
-        model = TinyModel()
+        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
 
-        original_rmsnorm_forward = model.norm.forward
-        original_swiglu_forward = model.mlp.forward
+        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+        original_swiglu_forward = model.model.layers[0].mlp.forward
 
 
         from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
@@ -91,7 +40,7 @@ class TestKernelPlugin(unittest.TestCase):
         apply_kernel(model, npu_rope.NpuRoPEKernel)
 
         model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
-        assert model.norm.forward is not original_rmsnorm_forward
+        assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
 
         model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
-        assert model.mlp.forward is not original_swiglu_forward
+        assert model.model.layers[0].mlp.forward is not original_swiglu_forward