[v1] kernel plugin (#9274)

Co-authored-by: frozenleaves <frozen@Mac.local>
Author: 魅影
Date: 2025-10-18 18:02:14 +08:00 (committed by GitHub)
parent d9d67ba62d
commit 2c6aded5d4
15 changed files with 543 additions and 2 deletions

View File

@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
+from typing import TYPE_CHECKING, Literal, TypedDict, Union
+from typing_extensions import NotRequired
if TYPE_CHECKING:

View File

@@ -13,7 +13,9 @@
# limitations under the License.
-from typing import Callable, NotRequired, TypedDict
+from typing import Callable, TypedDict
+from typing_extensions import NotRequired
from ...extras.types import Sample, SFTSample
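Both hunks above make the same portability fix: `NotRequired` only joined the standard `typing` module in Python 3.11, so importing it from `typing_extensions` keeps these modules importable on older interpreters. A minimal compat sketch (the `Sample` fields here are illustrative, not the project's actual schema):

# Hypothetical example: NotRequired marks a TypedDict key as optional.
# On Python >= 3.11 it could come from `typing`; importing it from
# `typing_extensions` works on every supported version, matching the diff.
from typing import TypedDict

from typing_extensions import NotRequired


class Sample(TypedDict):
    text: str
    score: NotRequired[float]  # this key may be omitted


ok: Sample = {"text": "hello"}  # valid even though `score` is absent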

View File

@@ -0,0 +1,30 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum


class KernelType(str, Enum):
    RMSNORM = "rmsnorm"
    SWIGLU = "swiglu"
    FLASH_ATTENTION = "flash_attention"
    ROPE = "rope"
    MOE = "moe"


class DeviceType(str, Enum):
    CPU = "cpu"
    CUDA = "cuda"
    NPU = "npu"
    XPU = "xpu"
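Since both enums subclass `str`, members compare equal to their raw string values, which is what lets `apply_kernel` later in this diff compare `kernel.device` directly against `get_available_accelerator().type`. A small sketch of that property (import path inferred from the test file's absolute imports):

# Sketch: str-backed enum members behave like their raw string values.
import torch

from llamafactory.v1.plugins.model_plugins.kernels.constants import DeviceType

assert DeviceType.NPU == "npu"                      # plain string equality
assert DeviceType("cuda") is DeviceType.CUDA        # lookup by value
assert torch.device("cpu").type == DeviceType.CPU   # interoperates with torch device types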

View File

@@ -0,0 +1,13 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,59 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import types

import torch

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaSwiGluKernel


def _npu_swiglu_forward(self, hidden_state):
    import torch_npu

    return self.down_proj(
        torch_npu.npu_swiglu(
            torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1
        )
    )


class NpuSwiGluKernel(MetaSwiGluKernel):
    device = DeviceType.NPU
    kernel = _npu_swiglu_forward

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.SWIGLU, device_type=DeviceType.NPU):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        if not is_torch_npu_available():
            return model

        swiglu_pattern = re.compile("MLP", re.IGNORECASE)
        for _name, module in model.named_modules():
            # Match any module whose class name contains "MLP"
            if re.search(swiglu_pattern, module.__class__.__name__):
                # Bind function as an instance method to preserve `self` semantics
                # and replace the original forward
                module.forward = types.MethodType(cls.kernel, module)

        return model
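For intuition, the fused call folds the usual `down(silu(gate(x)) * up(x))` MLP path into one NPU op over the concatenated gate/up projections. A pure-torch equivalence sketch, assuming `torch_npu.npu_swiglu(x, dim=-1)` computes `silu(gate) * up` over the two halves of `x` split along `dim`:

# Reference sketch (no NPU required): what the fused op is assumed to compute.
import torch
import torch.nn.functional as F


def reference_swiglu(gate_up: torch.Tensor) -> torch.Tensor:
    gate, up = gate_up.chunk(2, dim=-1)  # undo the torch.cat from the kernel
    return F.silu(gate) * up


gate_out, up_out = torch.randn(4, 16), torch.randn(4, 16)
fused_input = torch.cat((gate_out, up_out), dim=-1)  # mirrors the kernel's cat
assert reference_swiglu(fused_input).shape == (4, 16)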

View File

@@ -0,0 +1,148 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional

from ....extras.types import HFModel
from ...trainer_plugins.distributed.accelerate import get_available_accelerator
from .constants import DeviceType, KernelType


class KernelRegistry:
    _instance: Optional["KernelRegistry"] = None
    _initialized: bool = False

    def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        if self._initialized:
            return
        self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
        self._initialized = True

    def register(
        self,
        kernel_type: KernelType,
        device_type: DeviceType,
        kernel_impl: Optional[Callable[..., Any]],
    ) -> None:
        """Register a kernel implementation.

        Args:
            kernel_type: the type of the kernel (e.g., KernelType.FLASH_ATTENTION).
            device_type: the device type the kernel is adapted to (e.g., DeviceType.CUDA).
            kernel_impl: the actual kernel function or class.
        """
        if kernel_type not in self._registry:
            self._registry[kernel_type] = {}

        if device_type in self._registry[kernel_type]:
            print(f"Warning: Overwriting kernel for {kernel_type.name} on {device_type.name}.")

        self._registry[kernel_type][device_type] = kernel_impl
        print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")

    def get_kernel(
        self,
        kernel_type: KernelType,
        device_type: DeviceType,
    ) -> Optional[Callable[..., Any]]:
        return self._registry.get(kernel_type, {}).get(device_type)


KERNEL_REGISTRY = KernelRegistry()


class MetaKernel(ABC):
    type: Optional[KernelType] = None
    device: Optional[DeviceType] = None
    kernel: Optional[Callable] = None

    @classmethod
    def register_kernel(cls, kernel_type: KernelType, device_type: DeviceType):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    @abstractmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaFlashAttentionKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaRMSNormKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaSwiGluKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaRoPEKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaMoEKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


def discover_kernels(model: HFModel) -> list[MetaKernel]:
    """Discover and construct MetaKernel instances for the current model/device.

    This is a placeholder to be implemented: it should inspect the runtime
    environment (device type, available extensions, model architecture) and
    return an ordered list of MetaKernel instances to be applied. Each returned
    MetaKernel must encapsulate its own replacement logic in `apply`.
    """
    # TODO: Implement auto discovery logic based on registry and device capabilities.
    return []


def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
    """Call the MetaKernel's `apply` to perform the replacement.

    Corresponding replacement logic is maintained inside each kernel; the only
    requirement is that `apply` returns the replaced model.

    Example:
        from transformers import AutoModelForCausalLM
        from .rms_norm.npu_rms_norm import NpuRMSNormKernel

        model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
        model = apply_kernel(model, NpuRMSNormKernel)
    """
    if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
        return kernel.apply(model, **kwargs)

    raise ValueError(
        f"{kernel} must be a MetaKernel subclass whose device matches the current accelerator; "
        f"got {kernel.device} and {get_available_accelerator().type} instead."
    )
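A short usage sketch of the registry pieces above, using this module's own names; `dummy_swiglu_forward` is a placeholder standing in for a real kernel implementation:

# Sketch: the registry is a process-wide singleton keyed by (kernel, device).
registry_a = KernelRegistry()
registry_b = KernelRegistry()
assert registry_a is registry_b  # __new__ always returns the cached instance


def dummy_swiglu_forward(self, hidden_state):  # placeholder implementation
    return hidden_state


registry_a.register(KernelType.SWIGLU, DeviceType.NPU, dummy_swiglu_forward)
assert KERNEL_REGISTRY.get_kernel(KernelType.SWIGLU, DeviceType.NPU) is dummy_swiglu_forward
assert KERNEL_REGISTRY.get_kernel(KernelType.SWIGLU, DeviceType.CPU) is None  # not registered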

View File

@@ -0,0 +1,73 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import types

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaRMSNormKernel


def _npu_rms_forward(self, hidden_states):
    """NPU forward implementation for RMSNorm.

    Args:
        self: RMSNorm module instance with `weight` and `variance_epsilon`.
        hidden_states: Input hidden states tensor, same shape as the baseline.

    Returns:
        Normalized tensor consistent with the baseline RMSNorm behavior.
    """
    import torch_npu

    return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]


class NpuRMSNormKernel(MetaRMSNormKernel):
    """NPU kernel wrapper for RMSNorm that applies the replacement within a model."""

    device = DeviceType.NPU
    kernel = _npu_rms_forward

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.RMSNORM, device_type=DeviceType.NPU):
        """Register the NPU RMSNorm forward implementation to the global registry."""
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> HFModel:
        """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.

        Key points:
        - Match modules whose class name contains "RMSNorm" (case-insensitive).
        - Bind `_npu_rms_forward` as an instance method via `types.MethodType` to
          replace the original `forward`.
        - Do not modify weights, hyperparameters, or module structure, to ensure
          numerical behavior and interface consistency.
        """
        if not is_torch_npu_available():
            return model

        rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
        for _name, module in model.named_modules():
            # Match any module whose class name contains "RMSNorm"
            if re.search(rms_norm_pattern, module.__class__.__name__):
                # Bind function as an instance method to preserve `self` semantics
                # and replace the original forward
                module.forward = types.MethodType(cls.kernel, module)

        return model
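The replacement relies on `types.MethodType` so the patched function receives the module as `self`, exactly as the docstring notes. A standalone sketch of that binding mechanism (no NPU or transformers required; `Box` is a made-up class):

# Sketch: binding a plain function with types.MethodType makes it behave
# like a real instance method, so attribute lookups on `self` work as usual.
import types


class Box:
    def __init__(self) -> None:
        self.value = 41

    def forward(self) -> int:
        return self.value


def patched_forward(self) -> int:  # plain function; `self` is bound below
    return self.value + 1


box = Box()
box.forward = types.MethodType(patched_forward, box)  # instance-level override
print(box.forward())  # 42; other Box instances keep the original forward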

View File

@@ -0,0 +1,121 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import torch

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaRoPEKernel


def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors."""
    import torch_npu

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed


def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Applies Rotary Position Embedding with multimodal sections (Qwen2.5-VL)."""
    import torch_npu

    mrope_section = mrope_section * 2
    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )
    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )
    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed


class NpuRoPEKernel(MetaRoPEKernel):
    device = DeviceType.NPU
    kernel = _apply_rotary_pos_emb

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.

        This function iterates through the model's modules to find attention layers,
        identifies the module where they are defined, and replaces the original
        `apply_rotary_pos_emb` function in that module's namespace with the
        NPU-accelerated version from this file.
        """
        if not is_torch_npu_available():
            return model

        _modules = set()
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                module_name = module.__class__.__module__
                if module_name in _modules:
                    continue

                try:
                    target_module = sys.modules[module_name]
                    if hasattr(target_module, "apply_rotary_pos_emb"):
                        if getattr(target_module, "apply_rotary_pos_emb") is not cls.kernel:
                            setattr(target_module, "apply_rotary_pos_emb", cls.kernel)
                    _modules.add(module_name)
                except Exception:
                    pass

        return model


class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
    device = DeviceType.NPU
    kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        """Apply RoPE acceleration by monkey-patching `apply_multimodal_rotary_pos_emb`.

        This function iterates through the model's modules to find attention layers,
        identifies the module where they are defined, and replaces the original
        `apply_multimodal_rotary_pos_emb` function in that module's namespace with
        the NPU-accelerated version from this file.
        """
        if not is_torch_npu_available():
            return model

        _modules = set()
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                module_name = module.__class__.__module__
                if module_name in _modules:
                    continue

                try:
                    target_module = sys.modules[module_name]
                    if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
                        if getattr(target_module, "apply_multimodal_rotary_pos_emb") is not cls.kernel:
                            setattr(target_module, "apply_multimodal_rotary_pos_emb", cls.kernel)
                    _modules.add(module_name)
                except Exception:
                    pass

        return model
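The RoPE patch differs from the forward-binding used by the RMSNorm and SwiGLU kernels: `apply_rotary_pos_emb` is a free function that attention classes resolve through their defining module's globals, so the kernel rewrites that module attribute via `sys.modules`. A standalone sketch of the mechanism, with a hypothetical module name standing in for e.g. a transformers modeling file:

# Sketch: namespace-level patching reroutes every caller in that module.
import sys
import types

mod = types.ModuleType("fake_modeling")  # stands in for a modeling module
mod.apply_rotary_pos_emb = lambda q, k: ("slow", q, k)
sys.modules["fake_modeling"] = mod


def fast_rope(q, k):
    return ("fast", q, k)


target = sys.modules["fake_modeling"]
if getattr(target, "apply_rotary_pos_emb") is not fast_rope:
    setattr(target, "apply_rotary_pos_emb", fast_rope)

print(mod.apply_rotary_pos_emb(1, 2)[0])  # "fast": callers now hit the patch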

View File

@@ -0,0 +1,47 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import lru_cache

import torch


def get_available_accelerator():
    """Get the available accelerator in the current environment.

    Note: this API requires torch>=2.7.0; on 2.6 or lower it may raise an
    AttributeError or RuntimeError.
    """
    accelerator = torch.accelerator.current_accelerator()
    if accelerator is None:
        return torch.device("cpu")

    return accelerator


@lru_cache
def is_torch_npu_available():
    return get_available_accelerator().type == "npu"


@lru_cache
def is_torch_cuda_available():
    return get_available_accelerator().type == "cuda"


@lru_cache
def is_torch_xpu_available():
    return get_available_accelerator().type == "xpu"


@lru_cache
def is_torch_mps_available():
    return get_available_accelerator().type == "mps"
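A usage sketch of these helpers, assuming the absolute module path implied by this diff's relative imports and a torch build that ships `torch.accelerator`:

# Sketch: pick a kernel backend from the detected accelerator.
from llamafactory.v1.plugins.trainer_plugins.distributed.accelerate import (
    get_available_accelerator,
    is_torch_npu_available,
)

device = get_available_accelerator()
print(device.type)  # e.g. "npu", "cuda", or "cpu" on machines without one

if is_torch_npu_available():
    print("would apply the NPU kernels from this plugin")
else:
    print("falling back to the stock forward implementations")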

View File

@@ -0,0 +1,46 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import MagicMock, patch

from transformers import AutoModelForCausalLM


class TestKernelPlugin(unittest.TestCase):
    @patch("torch.accelerator.current_accelerator")
    def test_apply_kernel(self, mock_get_accelerator):
        mock_device = MagicMock()
        mock_device.type = "npu"
        mock_get_accelerator.return_value = mock_device

        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
        original_swiglu_forward = model.model.layers[0].mlp.forward

        from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
        from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
        from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
        from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope

        apply_kernel(model, npu_rope.NpuRoPEKernel)

        model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
        assert model.model.layers[0].input_layernorm.forward is not original_rmsnorm_forward

        model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
        assert model.model.layers[0].mlp.forward is not original_swiglu_forward