mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-20 21:14:17 +08:00
[v1] kernel plugin (#9274)
Co-authored-by: frozenleaves <frozen@Mac.local>
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
+from typing import TYPE_CHECKING, Literal, TypedDict, Union
 
+from typing_extensions import NotRequired
+
 
 if TYPE_CHECKING:
@@ -13,7 +13,9 @@
 # limitations under the License.
 
 
-from typing import Callable, NotRequired, TypedDict
+from typing import Callable, TypedDict
 
+from typing_extensions import NotRequired
+
 from ...extras.types import Sample, SFTSample
 
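Both hunks make the same change: `NotRequired` is now imported from `typing_extensions` instead of `typing`, since `typing.NotRequired` only exists on Python 3.11+. A minimal sketch of the pattern (the TypedDict below is hypothetical, for illustration only):

    from typing import TypedDict

    from typing_extensions import NotRequired  # backport; typing.NotRequired needs Python >= 3.11


    class ExampleSample(TypedDict):  # hypothetical, not from this commit
        prompt: str
        system: NotRequired[str]  # this key may be omitted from conforming dicts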
src/llamafactory/v1/plugins/model_plugins/kernels/constants.py
@@ -0,0 +1,30 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum


class KernelType(str, Enum):
    RMSNORM = "rmsnorm"
    SWIGLU = "swiglu"
    FLASH_ATTENTION = "flash_attention"
    ROPE = "rope"
    MOE = "moe"


class DeviceType(str, Enum):
    CPU = "cpu"
    CUDA = "cuda"
    NPU = "npu"
    XPU = "xpu"
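Because both enums subclass `str`, their members compare equal to plain strings, which is what lets `registry.py` below check `kernel.device == get_available_accelerator().type` against a raw device string:

    assert DeviceType.NPU == "npu"             # str subclass: equal to the raw string
    assert KernelType.SWIGLU.value == "swiglu"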
@@ -0,0 +1,13 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py
@@ -0,0 +1,59 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import types

import torch

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaSwiGluKernel


def _npu_swiglu_forward(self, hidden_state):
    import torch_npu

    return self.down_proj(
        torch_npu.npu_swiglu(
            torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1
        )
    )


class NpuSwiGluKernel(MetaSwiGluKernel):
    device = DeviceType.NPU
    kernel = _npu_swiglu_forward

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.SWIGLU, device_type=DeviceType.NPU):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        if not is_torch_npu_available():
            return model

        swiglu_pattern = re.compile("MLP", re.IGNORECASE)
        for name, module in model.named_modules():
            # Match any module whose class name contains "MLP"
            if re.search(swiglu_pattern, module.__class__.__name__):
                # Bind function as an instance method to preserve `self` semantics
                # and replace the original forward
                module.forward = types.MethodType(cls.kernel, module)

        return model
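For context, the fused `torch_npu.npu_swiglu` call stands in for the usual LLaMA-style eager MLP forward; a sketch of the unfused baseline it replaces (assuming the standard `gate_proj`/`up_proj`/`down_proj`/`act_fn` attributes):

    def _eager_swiglu_forward(self, hidden_state):
        # silu(gate) * up, followed by the down-projection
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))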
src/llamafactory/v1/plugins/model_plugins/kernels/registry.py (new file, 148 lines)
@@ -0,0 +1,148 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Callable, Optional

from ....extras.types import HFModel
from ...trainer_plugins.distributed.accelerate import get_available_accelerator
from .constants import DeviceType, KernelType


class KernelRegistry:
    _instance: Optional["KernelRegistry"] = None
    _initialized: bool = False

    def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        if self._initialized:
            return
        self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
        self._initialized = True

    def register(
        self,
        kernel_type: KernelType,
        device_type: DeviceType,
        kernel_impl: Optional[Callable[..., Any]],
    ) -> None:
        """Register a kernel implementation.

        Args:
            kernel_type: the type of the kernel (e.g., KernelType.FLASH_ATTENTION).
            device_type: the device type the kernel is adapted to (e.g., DeviceType.CUDA).
            kernel_impl: the actual kernel function or class.
        """
        if kernel_type not in self._registry:
            self._registry[kernel_type] = {}

        if device_type in self._registry[kernel_type]:
            print(f"Warning: Overwriting kernel for {kernel_type.name} on {device_type.name}.")

        self._registry[kernel_type][device_type] = kernel_impl
        print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")

    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]:
        return self._registry.get(kernel_type, {}).get(device_type)


KERNEL_REGISTRY = KernelRegistry()


class MetaKernel(ABC):
    type: Optional[KernelType] = None
    device: Optional[DeviceType] = None
    kernel: Optional[Callable] = None

    @classmethod
    def register_kernel(cls, kernel_type: KernelType, device_type: DeviceType):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    @abstractmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaFlashAttentionKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaRMSNormKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaSwiGluKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaRoPEKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaMoEKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


def discover_kernels(model: HFModel) -> list[MetaKernel]:
    """Discover and construct MetaKernel instances for the current model/device.

    This is a placeholder to be implemented: it should inspect the runtime
    environment (device type, available extensions, model architecture) and
    return an ordered list of MetaKernel instances to be applied. Each returned
    MetaKernel must encapsulate its own replacement logic in `apply`.
    """
    # TODO: Implement auto discovery logic based on registry and device capabilities.
    return []


def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
    """Call the MetaKernel's `apply` to perform the replacement.

    The replacement logic is maintained inside each kernel; the only
    requirement is that `apply` returns the replaced model.

    Example:
        from transformers import AutoModelForCausalLM

        from .rms_norm.npu_rms_norm import NpuRMSNormKernel

        model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
        model = apply_kernel(model, NpuRMSNormKernel)
    """
    if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
        return kernel.apply(model, **kwargs)

    raise ValueError(
        f"{kernel} must be a MetaKernel subclass whose device matches the current accelerator; "
        f"got {kernel.device} and {get_available_accelerator().type} instead."
    )
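Taken together, the registry and `MetaKernel` give a two-step workflow: subclass the matching `Meta*Kernel`, then register it and/or pass it to `apply_kernel`. A minimal sketch of a custom kernel (the class below is hypothetical, not part of this commit):

    class MyCudaRMSNormKernel(MetaRMSNormKernel):  # hypothetical example
        device = DeviceType.CUDA

        @classmethod
        def apply(cls, model, **kwargs):
            # patch the matching modules here, then hand the model back
            return model


    MyCudaRMSNormKernel.register_kernel(KernelType.RMSNORM, DeviceType.CUDA)
    impl = KERNEL_REGISTRY.get_kernel(KernelType.RMSNORM, DeviceType.CUDA)  # -> MyCudaRMSNormKernel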
src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py
@@ -0,0 +1,73 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import types

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaRMSNormKernel


def _npu_rms_forward(self, hidden_states):
    """NPU forward implementation for RMSNorm.

    Args:
        self: RMSNorm module instance with `weight` and `variance_epsilon`.
        hidden_states: Input hidden states tensor, same shape as the baseline.

    Returns:
        Normalized tensor consistent with the baseline RMSNorm behavior.
    """
    import torch_npu

    return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]


class NpuRMSNormKernel(MetaRMSNormKernel):
    """NPU kernel wrapper for RMSNorm that applies the replacement within a model."""

    device = DeviceType.NPU
    kernel = _npu_rms_forward

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.RMSNORM, device_type=DeviceType.NPU):
        """Register the NPU RMSNorm forward implementation to the global registry."""
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> HFModel:
        """Iterate the model and apply the NPU-optimized forward to matched RMSNorm modules.

        Key points:
        - Match modules whose class name contains "RMSNorm" (case-insensitive).
        - Bind `_npu_rms_forward` as an instance method via `types.MethodType` to
          replace the original `forward`.
        - Do not modify weights, hyperparameters, or module structure, to ensure
          numerical behavior and interface consistency.
        """
        if not is_torch_npu_available():
            return model

        rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)

        for name, module in model.named_modules():
            # Match any module whose class name contains "RMSNorm"
            if re.search(rms_norm_pattern, module.__class__.__name__):
                # Bind function as an instance method to preserve `self` semantics
                # and replace the original forward
                module.forward = types.MethodType(cls.kernel, module)

        return model
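For comparison, `torch_npu.npu_rms_norm` fuses what the stock Hugging Face RMSNorm forward does eagerly; a sketch of that baseline (assuming the same `weight` and `variance_epsilon` attributes):

    import torch

    def _eager_rms_forward(self, hidden_states):
        # upcast, normalize by the root mean square, then rescale by the learned weight
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)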
src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py
@@ -0,0 +1,121 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import torch

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaRoPEKernel


def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors."""
    import torch_npu

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed


def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Applies Rotary Position Embedding with multimodal sections (Qwen2-VL)."""
    import torch_npu

    mrope_section = mrope_section * 2
    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )
    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )

    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed


class NpuRoPEKernel(MetaRoPEKernel):
    device = DeviceType.NPU
    kernel = _apply_rotary_pos_emb

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.

        This method iterates through the model's modules to find attention layers,
        identifies the module where they are defined, and replaces the original
        `apply_rotary_pos_emb` function in that module's namespace with the
        NPU-accelerated version from this file.
        """
        if not is_torch_npu_available():
            return model

        _modules = set()
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                module_name = module.__class__.__module__
                if module_name in _modules:
                    continue
                try:
                    target_module = sys.modules[module_name]
                    if hasattr(target_module, "apply_rotary_pos_emb"):
                        if getattr(target_module, "apply_rotary_pos_emb") is not cls.kernel:
                            setattr(target_module, "apply_rotary_pos_emb", cls.kernel)
                    _modules.add(module_name)
                except Exception:
                    pass
        return model


class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
    device = DeviceType.NPU
    kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        """Apply multimodal RoPE acceleration by monkey-patching `apply_multimodal_rotary_pos_emb`.

        This method iterates through the model's modules to find attention layers,
        identifies the module where they are defined, and replaces the original
        `apply_multimodal_rotary_pos_emb` function in that module's namespace with
        the NPU-accelerated version from this file.
        """
        if not is_torch_npu_available():
            return model

        _modules = set()
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                module_name = module.__class__.__module__
                if module_name in _modules:
                    continue
                try:
                    target_module = sys.modules[module_name]
                    if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
                        if getattr(target_module, "apply_multimodal_rotary_pos_emb") is not cls.kernel:
                            setattr(target_module, "apply_multimodal_rotary_pos_emb", cls.kernel)
                    _modules.add(module_name)
                except Exception:
                    pass
        return model
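For reference, `torch_npu.npu_rotary_mul(q, cos, sin)` fuses the eager `(q * cos) + (rotate_half(q) * sin)` computation used by the stock implementations; a sketch of that eager path:

    import torch

    def rotate_half(x):
        # split the last dimension in half and swap the halves with a sign flip
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def eager_rotary_pos_emb(q, k, cos, sin):
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)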
src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py
@@ -0,0 +1,47 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import lru_cache

import torch


def get_available_accelerator():
    """Get the available accelerator in the current environment.

    Note: this API requires torch>=2.7.0; torch 2.6 or lower will raise an
    AttributeError or RuntimeError.
    """
    accelerator = torch.accelerator.current_accelerator()
    if accelerator is None:
        return torch.device("cpu")
    return accelerator


@lru_cache
def is_torch_npu_available():
    return get_available_accelerator().type == "npu"


@lru_cache
def is_torch_cuda_available():
    return get_available_accelerator().type == "cuda"


@lru_cache
def is_torch_xpu_available():
    return get_available_accelerator().type == "xpu"


@lru_cache
def is_torch_mps_available():
    return get_available_accelerator().type == "mps"
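Since `torch.accelerator` is absent on older PyTorch builds, any caller pinned below the required version would need a guard; one possible fallback, sketched here as an assumption rather than part of this commit:

    import torch

    def get_available_accelerator_compat():
        # hypothetical fallback for torch builds without torch.accelerator
        if hasattr(torch, "accelerator"):
            return torch.accelerator.current_accelerator() or torch.device("cpu")
        return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")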
tests_v1/plugins/model_plugins/test_kernel_plugin.py (new file, 46 lines)
@@ -0,0 +1,46 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest.mock import MagicMock, patch

from transformers import AutoModelForCausalLM


class TestKernelPlugin(unittest.TestCase):
    @patch("torch.accelerator.current_accelerator")
    def test_apply_kernel(self, mock_get_accelerator):
        mock_device = MagicMock()
        mock_device.type = "npu"
        mock_get_accelerator.return_value = mock_device

        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
        original_swiglu_forward = model.model.layers[0].mlp.forward

        from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
        from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
        from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
        from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope

        apply_kernel(model, npu_rope.NpuRoPEKernel)

        model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
        # Compare the rebound forward, not the module itself, against the saved original
        assert model.model.layers[0].input_layernorm.forward is not original_rmsnorm_forward

        model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
        assert model.model.layers[0].mlp.forward is not original_swiglu_forward
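Patching `torch.accelerator.current_accelerator` is what satisfies the device check in `apply_kernel` (`kernel.device == get_available_accelerator().type`) on hosts without an NPU; the patched forwards still require `torch_npu` at actual call time, so the test only verifies that the module forwards were rebound.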