From 2c6aded5d4f4ff23aa1887d16972afb3c2543ac3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AD=85=E5=BD=B1?= <46097299+frozenleaves@users.noreply.github.com> Date: Sat, 18 Oct 2025 18:02:14 +0800 Subject: [PATCH] [v1] kernel plugin (#9274) Co-authored-by: frozenleaves --- src/llamafactory/v1/extras/types.py | 4 +- .../v1/plugins/data_plugins/converter.py | 4 +- .../plugins/model_plugins/kernels/__init__.py | 0 .../model_plugins/kernels/constants.py | 30 ++++ .../model_plugins/kernels/fa/__init__.py | 0 .../model_plugins/kernels/mlp/__init__.py | 0 .../kernels/mlp/npu_fused_moe.py | 13 ++ .../model_plugins/kernels/mlp/npu_swiglu.py | 59 +++++++ .../plugins/model_plugins/kernels/registry.py | 148 ++++++++++++++++++ .../kernels/rms_norm/__init__.py | 0 .../kernels/rms_norm/npu_rms_norm.py | 73 +++++++++ .../model_plugins/kernels/rope/__init__.py | 0 .../model_plugins/kernels/rope/npu_rope.py | 121 ++++++++++++++ .../trainer_plugins/distributed/accelerate.py | 47 ++++++ .../model_plugins/test_kernel_plugin.py | 46 ++++++ 15 files changed, 543 insertions(+), 2 deletions(-) create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/__init__.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/constants.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/fa/__init__.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/mlp/__init__.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/registry.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/__init__.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/rope/__init__.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py create mode 100644 tests_v1/plugins/model_plugins/test_kernel_plugin.py diff --git a/src/llamafactory/v1/extras/types.py b/src/llamafactory/v1/extras/types.py index ac3251a6..9539931a 100644 --- a/src/llamafactory/v1/extras/types.py +++ b/src/llamafactory/v1/extras/types.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union +from typing import TYPE_CHECKING, Literal, TypedDict, Union + +from typing_extensions import NotRequired if TYPE_CHECKING: diff --git a/src/llamafactory/v1/plugins/data_plugins/converter.py b/src/llamafactory/v1/plugins/data_plugins/converter.py index 496a5505..7d197b9b 100644 --- a/src/llamafactory/v1/plugins/data_plugins/converter.py +++ b/src/llamafactory/v1/plugins/data_plugins/converter.py @@ -13,7 +13,9 @@ # limitations under the License. 
-from typing import Callable, NotRequired, TypedDict +from typing import Callable, TypedDict + +from typing_extensions import NotRequired from ...extras.types import Sample, SFTSample diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/__init__.py b/src/llamafactory/v1/plugins/model_plugins/kernels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py b/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py new file mode 100644 index 00000000..063ebb44 --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py @@ -0,0 +1,30 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class KernelType(str, Enum): + RMSNORM = "rmsnorm" + SWIGLU = "swiglu" + FLASH_ATTENTION = "flash_attention" + ROPE = "rope" + MOE = "moe" + + +class DeviceType(str, Enum): + CPU = 'cpu' + CUDA = 'cuda' + NPU = 'npu' + XPU = 'xpu' diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/fa/__init__.py b/src/llamafactory/v1/plugins/model_plugins/kernels/fa/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/__init__.py b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py new file mode 100644 index 00000000..ec0d6255 --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py @@ -0,0 +1,13 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py new file mode 100644 index 00000000..be331dec --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py @@ -0,0 +1,59 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import types
+
+import torch
+
+from .....extras.types import HFModel
+from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
+from ..constants import DeviceType, KernelType
+from ..registry import KERNEL_REGISTRY, MetaSwiGluKernel
+
+
+def _npu_swiglu_forward(self, hidden_state):
+    import torch_npu
+
+    return self.down_proj(
+        torch_npu.npu_swiglu(
+            torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1
+        )
+    )
+
+
+class NpuSwiGluKernel(MetaSwiGluKernel):
+
+    device = DeviceType.NPU
+    kernel = _npu_swiglu_forward
+
+    @classmethod
+    def register_kernel(cls, kernel_type=KernelType.SWIGLU, device_type=DeviceType.NPU):
+        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+
+    @classmethod
+    def apply(cls, model, **kwargs) -> 'HFModel':
+        if not is_torch_npu_available():
+            return model
+
+        swiglu_pattern = re.compile("MLP", re.IGNORECASE)
+        for name, module in model.named_modules():
+            # Match any module whose class name contains "MLP"
+            if re.search(swiglu_pattern, module.__class__.__name__):
+
+                # Bind function as an instance method to preserve `self` semantics
+                # and replace the original forward
+                module.forward = types.MethodType(cls.kernel, module)
+
+        return model
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
new file mode 100644
index 00000000..33597c48
--- /dev/null
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
@@ -0,0 +1,148 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Optional
+
+from ....extras.types import HFModel
+from ...trainer_plugins.distributed.accelerate import get_available_accelerator
+from .constants import DeviceType, KernelType
+
+
+class KernelRegistry:
+    _instance: Optional['KernelRegistry'] = None
+    _initialized: bool = False
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> 'KernelRegistry':
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self) -> None:
+        if self._initialized:
+            return
+        self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
+        self._initialized = True
+
+    def register(
+        self,
+        kernel_type: KernelType,
+        device_type: DeviceType,
+        kernel_impl: Optional[Callable[..., Any]]
+    ) -> None:
+        """Register a kernel implementation.
+
+        Args:
+            kernel_type: the type of the kernel (e.g., KernelType.FLASH_ATTENTION).
+            device_type: the device type the kernel is adapted to (e.g., DeviceType.CUDA).
+            kernel_impl: the actual kernel function or class.
+        """
+        if kernel_type not in self._registry:
+            self._registry[kernel_type] = {}
+
+        if device_type in self._registry[kernel_type]:
+            print(f"Warning: Overwriting kernel for {kernel_type.name} on {device_type.name}.")
+
+        self._registry[kernel_type][device_type] = kernel_impl
+        print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")
+
+    def get_kernel(
+        self,
+        kernel_type: KernelType,
+        device_type: DeviceType
+    ) -> Optional[Callable[..., Any]]:
+        return self._registry.get(kernel_type, {}).get(device_type)
+
+
+KERNEL_REGISTRY = KernelRegistry()
+
+
+class MetaKernel(ABC):
+    type: Optional[KernelType] = None
+    device: Optional[DeviceType] = None
+    kernel: Optional[Callable] = None
+
+    @classmethod
+    def register_kernel(cls, kernel_type: KernelType, device_type: DeviceType):
+        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+
+    @classmethod
+    @abstractmethod
+    def apply(cls, model: HFModel, **kwargs) -> HFModel:
+        raise NotImplementedError
+
+
+class MetaFlashAttentionKernel(MetaKernel):
+
+    @classmethod
+    def apply(cls, model: HFModel, **kwargs) -> HFModel:
+        raise NotImplementedError
+
+
+class MetaRMSNormKernel(MetaKernel):
+
+    @classmethod
+    def apply(cls, model: HFModel, **kwargs) -> HFModel:
+        raise NotImplementedError
+
+
+class MetaSwiGluKernel(MetaKernel):
+
+    @classmethod
+    def apply(cls, model: HFModel, **kwargs) -> HFModel:
+        raise NotImplementedError
+
+
+class MetaRoPEKernel(MetaKernel):
+
+    @classmethod
+    def apply(cls, model: HFModel, **kwargs) -> HFModel:
+        raise NotImplementedError
+
+
+class MetaMoEKernel(MetaKernel):
+
+    @classmethod
+    def apply(cls, model: HFModel, **kwargs) -> HFModel:
+        raise NotImplementedError
+
+
+def discover_kernels(model: HFModel) -> list[MetaKernel]:
+    """Discover and construct MetaKernel instances for the current model/device.
+
+    This is a placeholder to be implemented: it should inspect the runtime
+    environment (device type, available extensions, model architecture) and
+    return an ordered list of MetaKernel instances to be applied. Each returned
+    MetaKernel must encapsulate its own replacement logic in `apply`.
+    """
+    # TODO: Implement auto discovery logic based on registry and device capabilities.
+    return []
+
+
+def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> 'HFModel':
+    """Call the MetaKernel's `apply` to perform the replacement.
+
+    Corresponding replacement logic is maintained inside each kernel; the only
+    requirement is that `apply` returns the replaced model.
+
+    Example:
+        from transformers import AutoModelForCausalLM
+        from .rms_norm.npu_rms_norm import NpuRMSNormKernel
+        model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
+        model = apply_kernel(model, NpuRMSNormKernel)
+    """
+    if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
+        return kernel.apply(model, **kwargs)
+
+    raise ValueError(f"{kernel} must be a MetaKernel subclass, or the kernel doesn't match the device type. 
got {kernel.device} and {get_available_accelerator().type} instead.") diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/__init__.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py new file mode 100644 index 00000000..018758ee --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py @@ -0,0 +1,73 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +import types + +from .....extras.types import HFModel +from ....trainer_plugins.distributed.accelerate import is_torch_npu_available +from ..constants import DeviceType, KernelType +from ..registry import KERNEL_REGISTRY, MetaRMSNormKernel + + +def _npu_rms_forward(self, hidden_states): + """NPU forward implementation for RMSNorm. + + Args: + self: RMSNorm module instance with `weight` and `variance_epsilon`. + hidden_states: Input hidden states tensor, same shape as the baseline. + + Returns: + Normalized tensor consistent with the baseline RMSNorm behavior. + """ + import torch_npu + + return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0] + + +class NpuRMSNormKernel(MetaRMSNormKernel): + """NPU kernel wrapper for RMSNorm that applies the replacement within a model.""" + + device = DeviceType.NPU + kernel = _npu_rms_forward + + @classmethod + def register_kernel(cls, kernel_type=KernelType.RMSNORM, device_type=DeviceType.NPU): + """Register the NPU RMSNorm forward implementation to the global registry.""" + KERNEL_REGISTRY.register(kernel_type, device_type, cls) + + @classmethod + def apply(cls, model, **kwargs) -> HFModel: + """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules. + + Key points: + - Match modules whose class name contains "RMSNorm" (case-insensitive). + - Bind `_npu_rms_forward` as an instance method via `types.MethodType` to + replace the original `forward`. + - Do not modify weights, hyperparameters, or module structure to ensure + numerical behavior and interface consistency. 
+ """ + if not is_torch_npu_available(): + return model + + rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE) + + for name, module in model.named_modules(): + # Match any module whose class name contains "RMSNorm" + if re.search(rms_norm_pattern, module.__class__.__name__): + + # Bind function as an instance method to preserve `self` semantics + # and replace the original forward + module.forward = types.MethodType(cls.kernel, module) + + return model diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/__init__.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rope/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py new file mode 100644 index 00000000..a1d41dd4 --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py @@ -0,0 +1,121 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import torch + +from .....extras.types import HFModel +from ....trainer_plugins.distributed.accelerate import is_torch_npu_available +from ..constants import DeviceType, KernelType +from ..registry import KERNEL_REGISTRY, MetaRoPEKernel + + +def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + import torch_npu + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = torch_npu.npu_rotary_mul(q, cos, sin) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin) + return q_embed, k_embed + + +def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with multimodal sections (Qwen2-VL).""" + import torch_npu + + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = torch_npu.npu_rotary_mul(q, cos, sin) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin) + return q_embed, k_embed + + +class NpuRoPEKernel(MetaRoPEKernel): + device = DeviceType.NPU + kernel = _apply_rotary_pos_emb + + @classmethod + def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU): + KERNEL_REGISTRY.register(kernel_type, device_type, cls) + + @classmethod + def apply(cls, model, **kwargs) -> 'HFModel': + """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`. + + This function iterates through the model's modules to find attention layers, + identifies the module where they are defined, and replaces the original + `apply_rotary_pos_emb` function in that module's namespace with the + NPU-accelerated version from this file. 
+ """ + if not is_torch_npu_available(): + return model + + _modules = set() + for module in model.modules(): + if "Attention" in module.__class__.__name__: + module_name = module.__class__.__module__ + if module_name in _modules: + continue + try: + target_module = sys.modules[module_name] + if hasattr(target_module, "apply_rotary_pos_emb"): + if getattr(target_module, "apply_rotary_pos_emb") is not cls.kernel: + setattr(target_module, "apply_rotary_pos_emb", cls.kernel) + _modules.add(module_name) + except Exception: + pass + return model + + +class NpuQwen2VLRoPEKernel(MetaRoPEKernel): + device = DeviceType.NPU + kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl + + @classmethod + def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU): + KERNEL_REGISTRY.register(kernel_type, device_type, cls) + + @classmethod + def apply(cls, model, **kwargs) -> 'HFModel': + """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`. + + This function iterates through the model's modules to find attention layers, + identifies the module where they are defined, and replaces the original + `apply_rotary_pos_emb` function in that module's namespace with the + NPU-accelerated version from this file. + """ + _modules = set() + for module in model.modules(): + if "Attention" in module.__class__.__name__: + module_name = module.__class__.__module__ + if module_name in _modules: + continue + try: + target_module = sys.modules[module_name] + if hasattr(target_module, "apply_multimodal_rotary_pos_emb"): + if getattr(target_module, "apply_multimodal_rotary_pos_emb") is not cls.kernel: + setattr(target_module, "apply_multimodal_rotary_pos_emb", cls.kernel) + _modules.add(module_name) + except Exception: + pass + return model diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py index e69de29b..4f090e7d 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py @@ -0,0 +1,47 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import lru_cache + +import torch + + +def get_available_accelerator(): + """Get available accelerator in current environment. 
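+
+    Example (illustrative only; the exact device depends on the runtime):
+        >>> get_available_accelerator().type in {"cpu", "cuda", "mps", "npu", "xpu"}
+        True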
+
+    Note: this API requires torch>=2.7.0; torch 2.6 or lower will raise an AttributeError or RuntimeError.
+    """
+    accelerator = torch.accelerator.current_accelerator()
+    if accelerator is None:
+        return torch.device('cpu')
+    return accelerator
+
+
+@lru_cache
+def is_torch_npu_available():
+    return get_available_accelerator().type == 'npu'
+
+
+@lru_cache
+def is_torch_cuda_available():
+    return get_available_accelerator().type == 'cuda'
+
+
+@lru_cache
+def is_torch_xpu_available():
+    return get_available_accelerator().type == 'xpu'
+
+
+@lru_cache
+def is_torch_mps_available():
+    return get_available_accelerator().type == 'mps'
diff --git a/tests_v1/plugins/model_plugins/test_kernel_plugin.py b/tests_v1/plugins/model_plugins/test_kernel_plugin.py
new file mode 100644
index 00000000..a89b8bd7
--- /dev/null
+++ b/tests_v1/plugins/model_plugins/test_kernel_plugin.py
@@ -0,0 +1,46 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import MagicMock, patch
+
+from transformers import AutoModelForCausalLM
+
+
+class TestKernelPlugin(unittest.TestCase):
+
+    @patch('torch.accelerator.current_accelerator')
+    def test_apply_kernel(self, mock_get_accelerator):
+        mock_device = MagicMock()
+        mock_device.type = 'npu'
+        mock_get_accelerator.return_value = mock_device
+
+        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
+
+        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+        original_swiglu_forward = model.model.layers[0].mlp.forward
+
+
+        from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
+        from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
+        from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
+        from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope
+
+        apply_kernel(model, npu_rope.NpuRoPEKernel)
+
+        model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
+        assert model.model.layers[0].input_layernorm.forward is not original_rmsnorm_forward
+
+        model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
+        assert model.model.layers[0].mlp.forward is not original_swiglu_forward
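
Usage sketch (illustrative, not part of the patch itself): the snippet below shows how the pieces introduced above are meant to fit together, assuming an Ascend NPU runtime with torch_npu installed and torch>=2.7.0 for the accelerator helper; on any other device, apply_kernel raises ValueError because the kernel's device does not match the current accelerator. The checkpoint id is the tiny model used by the test above.

from transformers import AutoModelForCausalLM

from llamafactory.v1.plugins.model_plugins.kernels.constants import DeviceType, KernelType
from llamafactory.v1.plugins.model_plugins.kernels.mlp.npu_swiglu import NpuSwiGluKernel
from llamafactory.v1.plugins.model_plugins.kernels.registry import KERNEL_REGISTRY, apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm.npu_rms_norm import NpuRMSNormKernel
from llamafactory.v1.plugins.model_plugins.kernels.rope.npu_rope import NpuRoPEKernel

model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

# Optional: record each kernel in the global registry so it can later be looked up
# by (kernel type, device type), e.g. by a future discover_kernels() implementation.
NpuRMSNormKernel.register_kernel()
NpuSwiGluKernel.register_kernel()
NpuRoPEKernel.register_kernel()
assert KERNEL_REGISTRY.get_kernel(KernelType.RMSNORM, DeviceType.NPU) is NpuRMSNormKernel

# Each apply() patches matching modules in place and returns the same model object:
# RMSNorm.forward -> torch_npu.npu_rms_norm, MLP.forward -> fused torch_npu.npu_swiglu,
# and apply_rotary_pos_emb -> torch_npu.npu_rotary_mul in the attention modules' namespaces.
model = apply_kernel(model, NpuRMSNormKernel)
model = apply_kernel(model, NpuSwiGluKernel)
model = apply_kernel(model, NpuRoPEKernel)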