mirror of
https://github.com/frozenleaves/LLaMA-Factory.git
synced 2025-10-20 16:23:46 +08:00
dev npu fused options
This commit is contained in:
@ -514,6 +514,11 @@ class FinetuningArguments(
|
||||
metadata={"help": "Whether or not to compute effective tokens per second."},
|
||||
)
|
||||
|
||||
use_npu_fused_option: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether use NPU fused option or not. "},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
def split_arg(arg):
|
||||
if isinstance(arg, str):
|
||||
|
@ -406,6 +406,7 @@ class ModelArguments(
|
||||
metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."},
|
||||
)
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
BaseModelArguments.__post_init__(self)
|
||||
ProcessorArguments.__post_init__(self)
|
||||
|
@ -29,6 +29,7 @@ from transformers import (
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import AttentionFunction
|
||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
|
||||
from .adapter import init_adapter
|
||||
from .model_utils.liger_kernel import apply_liger_kernel
|
||||
@ -38,6 +39,8 @@ from .model_utils.unsloth import load_unsloth_pretrained_model
|
||||
from .model_utils.valuehead import load_valuehead_params
|
||||
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
|
||||
|
||||
from ..third_party.npu_fused_options.npu_fused_patcher import apply_fused_options
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
@ -138,6 +141,8 @@ def load_model(
|
||||
r"""Load pretrained model."""
|
||||
init_kwargs = _get_init_kwargs(model_args)
|
||||
config = load_config(model_args)
|
||||
if model_args.flash_attn == AttentionFunction.SDPA:
|
||||
apply_fused_options()
|
||||
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
|
||||
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
|
||||
|
||||
|
0
src/llamafactory/third_party/npu_fused_options/__init__.py
vendored
Normal file
0
src/llamafactory/third_party/npu_fused_options/__init__.py
vendored
Normal file
48
src/llamafactory/third_party/npu_fused_options/npu_fused_patcher.py
vendored
Normal file
48
src/llamafactory/third_party/npu_fused_options/npu_fused_patcher.py
vendored
Normal file
@ -0,0 +1,48 @@
|
||||
from types import ModuleType
|
||||
|
||||
from . import sdpa_attention as npu_sdpa_attention
|
||||
from . import rms_norm, rope, swiglu
|
||||
|
||||
|
||||
def _patch_sdpa_forward():
|
||||
"""
|
||||
The purpose of this patch is to enable the native SDPA forward function of transformers to adapt to the
|
||||
SDPA interface of NPU. If not, calling the SDPA interface is still in the eagle mode
|
||||
"""
|
||||
from transformers.integrations import sdpa_attention
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface
|
||||
|
||||
sdpa_attention.sdpa_attention_forward = npu_sdpa_attention.sdpa_attention_forward
|
||||
AttentionInterface._global_mapping["sdpa"] = npu_sdpa_attention.sdpa_attention_forward
|
||||
ALL_ATTENTION_FUNCTIONS["sdpa"] = npu_sdpa_attention.sdpa_attention_forward
|
||||
|
||||
|
||||
def _patch_rmsnorm(module: ModuleType, class_name: str):
|
||||
|
||||
setattr(module, class_name, rms_norm.NpuRMSNorm)
|
||||
|
||||
|
||||
def _patch_rope(module: ModuleType, func_name: str):
|
||||
setattr(module, func_name, rope.apply_rotary_pos_emb)
|
||||
|
||||
|
||||
def _patch_swiglu(module: ModuleType, class_name: str):
|
||||
setattr(module, class_name, swiglu.NpuSwiGlu)
|
||||
|
||||
|
||||
def apply_fused_options(disable: bool=False):
|
||||
if disable:
|
||||
return
|
||||
from transformers.models.qwen2 import modeling_qwen2
|
||||
from transformers.models.qwen2_moe import modeling_qwen2_moe
|
||||
from transformers.models.qwen3 import modeling_qwen3
|
||||
from transformers.models.qwen3_moe import modeling_qwen3_moe
|
||||
|
||||
_patch_sdpa_forward()
|
||||
_patch_rmsnorm(modeling_qwen2, "Qwen2RMSNorm")
|
||||
_patch_rope(modeling_qwen2, "apply_rotary_pos_emb")
|
||||
_patch_swiglu(modeling_qwen2, "Qwen2MLP")
|
||||
|
||||
|
||||
|
||||
|
44
src/llamafactory/third_party/npu_fused_options/rms_norm.py
vendored
Normal file
44
src/llamafactory/third_party/npu_fused_options/rms_norm.py
vendored
Normal file
@ -0,0 +1,44 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.utils import is_torch_npu_available
|
||||
|
||||
if is_torch_npu_available():
|
||||
import torch_npu
|
||||
|
||||
from ...extras import logging
|
||||
|
||||
logger = logging.get_logger()
|
||||
|
||||
|
||||
class NpuRMSNorm(nn.Module):
|
||||
"""
|
||||
RMSNorm operator adapted for NPU. When NPU is available and the user chooses to enable NpuRMSNorm, it will be opened.
|
||||
Otherwise, the native implementation will still be used
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
||||
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
27
src/llamafactory/third_party/npu_fused_options/rope.py
vendored
Normal file
27
src/llamafactory/third_party/npu_fused_options/rope.py
vendored
Normal file
@ -0,0 +1,27 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from transformers.utils import is_torch_npu_available
|
||||
|
||||
if is_torch_npu_available():
|
||||
import torch_npu
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""
|
||||
Applies Rotary Position Embedding to the query and key tensors.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
return q_embed, k_embed
|
70
src/llamafactory/third_party/npu_fused_options/sdpa_attention.py
vendored
Normal file
70
src/llamafactory/third_party/npu_fused_options/sdpa_attention.py
vendored
Normal file
@ -0,0 +1,70 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from transformers.utils import is_torch_npu_available
|
||||
from transformers.integrations.sdpa_attention import repeat_kv
|
||||
|
||||
|
||||
def sdpa_attention_forward(
|
||||
module: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
dropout: float = 0.0,
|
||||
scaling: Optional[float] = None,
|
||||
is_causal: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, None]:
|
||||
if hasattr(module, "num_key_value_groups"):
|
||||
key = repeat_kv(key, module.num_key_value_groups)
|
||||
value = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if not is_torch_npu_available() and attention_mask is not None and causal_mask.ndim == 4:
|
||||
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
|
||||
if is_causal is None:
|
||||
is_causal = query.shape[2] > 1 and causal_mask is None
|
||||
|
||||
# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
|
||||
# We convert it to a bool for the SDPA kernel that only accepts bools.
|
||||
if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
|
||||
is_causal = is_causal.item()
|
||||
|
||||
if is_torch_npu_available():
|
||||
is_causal = True
|
||||
causal_mask = None
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=dropout,
|
||||
scale=scaling,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, None
|
46
src/llamafactory/third_party/npu_fused_options/swiglu.py
vendored
Normal file
46
src/llamafactory/third_party/npu_fused_options/swiglu.py
vendored
Normal file
@ -0,0 +1,46 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.utils import is_torch_npu_available
|
||||
|
||||
if is_torch_npu_available():
|
||||
import torch_npu
|
||||
|
||||
|
||||
class NpuSwiGlu(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
return self.down_proj(
|
||||
torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
|
||||
)
|
||||
|
||||
|
||||
class NpuIntern2SwiGlu(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
return self.w2(torch_npu.npu_swiglu(torch.cat((self.w1(hidden_state), self.w3(hidden_state)), dim=-1), dim=-1))
|
Reference in New Issue
Block a user