[v1] kernel plugin (#9274)
Co-authored-by: frozenleaves <frozen@Mac.local>
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
+from typing import TYPE_CHECKING, Literal, TypedDict, Union
+
+from typing_extensions import NotRequired
 
 
 if TYPE_CHECKING:
@@ -13,7 +13,9 @@
 # limitations under the License.
 
 
-from typing import Callable, NotRequired, TypedDict
+from typing import Callable, TypedDict
+
+from typing_extensions import NotRequired
 
 from ...extras.types import Sample, SFTSample
 
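Both hunks make the same fix: typing.NotRequired only exists on Python 3.11+, so it is imported from typing_extensions to keep older interpreters working. A minimal sketch of what NotRequired does in a TypedDict (the ConfigDict name is illustrative, not from this codebase):

from typing import TypedDict

from typing_extensions import NotRequired  # backport; works on Python < 3.11


class ConfigDict(TypedDict):
    name: str  # required key
    device: NotRequired[str]  # key may be omitted entirely


cfg: ConfigDict = {"name": "qwen"}  # valid even though "device" is absent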
@@ -0,0 +1,30 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum


class KernelType(str, Enum):
    RMSNORM = "rmsnorm"
    SWIGLU = "swiglu"
    FLASH_ATTENTION = "flash_attention"
    ROPE = "rope"
    MOE = "moe"


class DeviceType(str, Enum):
    CPU = "cpu"
    CUDA = "cuda"
    NPU = "npu"
    XPU = "xpu"
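Both enums subclass str, so members compare equal to raw strings; this is what lets apply_kernel later match kernel.device against torch.device(...).type. A quick illustrative check:

from llamafactory.v1.plugins.model_plugins.kernels.constants import DeviceType, KernelType

assert DeviceType.NPU == "npu"  # str-enum members equal their raw string values
assert KernelType.RMSNORM.value == "rmsnorm"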
@@ -0,0 +1,13 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,59 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import types

import torch

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaSwiGluKernel


def _npu_swiglu_forward(self, hidden_state):
    import torch_npu

    return self.down_proj(
        torch_npu.npu_swiglu(
            torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1
        )
    )


class NpuSwiGluKernel(MetaSwiGluKernel):
    device = DeviceType.NPU
    kernel = _npu_swiglu_forward

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.SWIGLU, device_type=DeviceType.NPU):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        if not is_torch_npu_available():
            return model

        swiglu_pattern = re.compile("MLP", re.IGNORECASE)
        for name, module in model.named_modules():
            # Match any module whose class name contains "MLP"
            if re.search(swiglu_pattern, module.__class__.__name__):
                # Bind the function as an instance method to preserve `self` semantics
                # and replace the original forward
                module.forward = types.MethodType(cls.kernel, module)

        return model
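For reference, torch_npu.npu_swiglu over the concatenated gate/up projections fuses the activation and elementwise product that the stock MLP computes step by step. A plain-PyTorch sketch of the baseline being replaced (assuming the usual gate_proj/up_proj/down_proj attributes):

import torch.nn.functional as F


def _eager_swiglu_forward(self, hidden_state):
    # Baseline: down(silu(gate(x)) * up(x)); the NPU kernel computes the same
    # thing in one fused call on the concatenated projections.
    return self.down_proj(F.silu(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))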
src/llamafactory/v1/plugins/model_plugins/kernels/registry.py (new file, 148 lines)
@@ -0,0 +1,148 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Callable, Optional

from ....extras.types import HFModel
from ...trainer_plugins.distributed.accelerate import get_available_accelerator
from .constants import DeviceType, KernelType


class KernelRegistry:
    _instance: Optional["KernelRegistry"] = None
    _initialized: bool = False

    def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        if self._initialized:
            return

        self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
        self._initialized = True

    def register(
        self,
        kernel_type: KernelType,
        device_type: DeviceType,
        kernel_impl: Optional[Callable[..., Any]],
    ) -> None:
        """Register a kernel implementation.

        Args:
            kernel_type: the type of the kernel (e.g., KernelType.FLASH_ATTENTION).
            device_type: the device type the kernel is adapted to (e.g., DeviceType.CUDA).
            kernel_impl: the actual kernel function or class.
        """
        if kernel_type not in self._registry:
            self._registry[kernel_type] = {}

        if device_type in self._registry[kernel_type]:
            print(f"Warning: Overwriting kernel for {kernel_type.name} on {device_type.name}.")

        self._registry[kernel_type][device_type] = kernel_impl
        print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")

    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]:
        return self._registry.get(kernel_type, {}).get(device_type)


KERNEL_REGISTRY = KernelRegistry()


class MetaKernel(ABC):
    type: Optional[KernelType] = None
    device: Optional[DeviceType] = None
    kernel: Optional[Callable] = None

    @classmethod
    def register_kernel(cls, kernel_type: KernelType, device_type: DeviceType):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    @abstractmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaFlashAttentionKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaRMSNormKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaSwiGluKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaRoPEKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaMoEKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


def discover_kernels(model: HFModel) -> list[MetaKernel]:
    """Discover and construct MetaKernel instances for the current model/device.

    This is a placeholder to be implemented: it should inspect the runtime
    environment (device type, available extensions, model architecture) and
    return an ordered list of MetaKernel instances to be applied. Each returned
    MetaKernel must encapsulate its own replacement logic in `apply`.
    """
    # TODO: Implement auto discovery logic based on registry and device capabilities.
    return []


def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
    """Call the MetaKernel's `apply` to perform the replacement.

    Each kernel maintains its own replacement logic; the only requirement is
    that `apply` returns the replaced model.

    Example:
        from transformers import AutoModelForCausalLM

        from .rms_norm.npu_rms_norm import NpuRMSNormKernel

        model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
        model = apply_kernel(model, NpuRMSNormKernel)
    """
    if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
        return kernel.apply(model, **kwargs)

    raise ValueError(
        f"{kernel} must be a MetaKernel subclass whose device matches the current accelerator; "
        f"got {kernel.device} and {get_available_accelerator().type} instead."
    )
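A short sketch of how registration and lookup compose (illustrative; only apply_kernel is exercised by the test at the end of this commit):

from llamafactory.v1.plugins.model_plugins.kernels.constants import DeviceType, KernelType
from llamafactory.v1.plugins.model_plugins.kernels.registry import KERNEL_REGISTRY
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm.npu_rms_norm import NpuRMSNormKernel

NpuRMSNormKernel.register_kernel()  # defaults to (KernelType.RMSNORM, DeviceType.NPU)
kernel_cls = KERNEL_REGISTRY.get_kernel(KernelType.RMSNORM, DeviceType.NPU)
assert kernel_cls is NpuRMSNormKernel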
@@ -0,0 +1,73 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import types

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaRMSNormKernel


def _npu_rms_forward(self, hidden_states):
    """NPU forward implementation for RMSNorm.

    Args:
        self: RMSNorm module instance with `weight` and `variance_epsilon`.
        hidden_states: Input hidden states tensor, same shape as the baseline.

    Returns:
        Normalized tensor consistent with the baseline RMSNorm behavior.
    """
    import torch_npu

    return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]


class NpuRMSNormKernel(MetaRMSNormKernel):
    """NPU kernel wrapper for RMSNorm that applies the replacement within a model."""

    device = DeviceType.NPU
    kernel = _npu_rms_forward

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.RMSNORM, device_type=DeviceType.NPU):
        """Register the NPU RMSNorm forward implementation to the global registry."""
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> HFModel:
        """Iterate the model and apply the NPU-optimized forward to matched RMSNorm modules.

        Key points:
        - Match modules whose class name contains "RMSNorm" (case-insensitive).
        - Bind `_npu_rms_forward` as an instance method via `types.MethodType` to
          replace the original `forward`.
        - Do not modify weights, hyperparameters, or module structure, to ensure
          numerical behavior and interface consistency.
        """
        if not is_torch_npu_available():
            return model

        rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)

        for name, module in model.named_modules():
            # Match any module whose class name contains "RMSNorm"
            if re.search(rms_norm_pattern, module.__class__.__name__):
                # Bind the function as an instance method to preserve `self` semantics
                # and replace the original forward
                module.forward = types.MethodType(cls.kernel, module)

        return model
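For comparison, a sketch of the eager RMSNorm forward being swapped out, modeled on the usual Hugging Face LlamaRMSNorm pattern (dtype handling simplified):

import torch


def _eager_rms_forward(self, hidden_states):
    # Baseline: x / sqrt(mean(x^2) + eps) * weight, accumulated in float32.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return self.weight * hidden_states.to(input_dtype)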
@@ -0,0 +1,121 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import torch

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaRoPEKernel


def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors."""
    import torch_npu

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed


def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Applies Rotary Position Embedding with multimodal sections (Qwen2-VL)."""
    import torch_npu

    mrope_section = mrope_section * 2
    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )
    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )

    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed


class NpuRoPEKernel(MetaRoPEKernel):
    device = DeviceType.NPU
    kernel = _apply_rotary_pos_emb

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.

        This function iterates through the model's modules to find attention layers,
        identifies the module where they are defined, and replaces the original
        `apply_rotary_pos_emb` function in that module's namespace with the
        NPU-accelerated version from this file.
        """
        if not is_torch_npu_available():
            return model

        _modules = set()
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                module_name = module.__class__.__module__
                if module_name in _modules:
                    continue

                try:
                    target_module = sys.modules[module_name]
                    if hasattr(target_module, "apply_rotary_pos_emb"):
                        if getattr(target_module, "apply_rotary_pos_emb") is not cls.kernel:
                            setattr(target_module, "apply_rotary_pos_emb", cls.kernel)
                        _modules.add(module_name)
                except Exception:
                    pass

        return model


class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
    device = DeviceType.NPU
    kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        """Apply RoPE acceleration by monkey-patching `apply_multimodal_rotary_pos_emb`.

        This function iterates through the model's modules to find attention layers,
        identifies the module where they are defined, and replaces the original
        `apply_multimodal_rotary_pos_emb` function in that module's namespace with
        the NPU-accelerated version from this file.
        """
        if not is_torch_npu_available():
            return model

        _modules = set()
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                module_name = module.__class__.__module__
                if module_name in _modules:
                    continue

                try:
                    target_module = sys.modules[module_name]
                    if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
                        if getattr(target_module, "apply_multimodal_rotary_pos_emb") is not cls.kernel:
                            setattr(target_module, "apply_multimodal_rotary_pos_emb", cls.kernel)
                        _modules.add(module_name)
                except Exception:
                    pass

        return model
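The function being patched out is the standard rotate-half formulation; a sketch matching the usual Hugging Face apply_rotary_pos_emb semantics:

import torch


def _rotate_half(x):
    # Negate the second half of the last dim and swap the halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def _eager_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (_rotate_half(q) * sin)
    k_embed = (k * cos) + (_rotate_half(k) * sin)
    return q_embed, k_embed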
@@ -0,0 +1,47 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import lru_cache

import torch


def get_available_accelerator():
    """Get the available accelerator in the current environment.

    Note: this API requires torch>=2.7.0; on torch 2.6 or lower it raises an
    AttributeError or RuntimeError.
    """
    accelerator = torch.accelerator.current_accelerator()
    if accelerator is None:
        return torch.device("cpu")

    return accelerator


@lru_cache
def is_torch_npu_available():
    return get_available_accelerator().type == "npu"


@lru_cache
def is_torch_cuda_available():
    return get_available_accelerator().type == "cuda"


@lru_cache
def is_torch_xpu_available():
    return get_available_accelerator().type == "xpu"


@lru_cache
def is_torch_mps_available():
    return get_available_accelerator().type == "mps"
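Usage sketch for these helpers (illustrative; note the @lru_cache means each check is computed once per process, so the accelerator is assumed not to change at runtime):

from llamafactory.v1.plugins.trainer_plugins.distributed.accelerate import (
    get_available_accelerator,
    is_torch_npu_available,
)

device = get_available_accelerator()  # e.g. torch.device("npu"), or the CPU fallback
print(device.type, is_torch_npu_available())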
tests_v1/plugins/model_plugins/test_kernel_plugin.py (new file, 46 lines)
@@ -0,0 +1,46 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest.mock import MagicMock, patch

from transformers import AutoModelForCausalLM


class TestKernelPlugin(unittest.TestCase):
    @patch("torch.accelerator.current_accelerator")
    def test_apply_kernel(self, mock_get_accelerator):
        mock_device = MagicMock()
        mock_device.type = "npu"
        mock_get_accelerator.return_value = mock_device

        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
        original_swiglu_forward = model.model.layers[0].mlp.forward

        from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
        from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
        from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
        from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope

        apply_kernel(model, npu_rope.NpuRoPEKernel)

        model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
        assert model.model.layers[0].input_layernorm.forward is not original_rmsnorm_forward

        model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
        assert model.model.layers[0].mlp.forward is not original_swiglu_forward