dev npu fused options

frozenleaves
2025-08-25 16:31:35 +08:00
parent d3791e8ee1
commit 1614a15329
9 changed files with 246 additions and 0 deletions

View File

@@ -514,6 +514,11 @@ class FinetuningArguments(
        metadata={"help": "Whether or not to compute effective tokens per second."},
    )
    use_npu_fused_option: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the NPU fused options."},
    )

    def __post_init__(self):
        def split_arg(arg):
            if isinstance(arg, str):

View File

@@ -406,6 +406,7 @@ class ModelArguments(
        metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."},
    )

    def __post_init__(self):
        BaseModelArguments.__post_init__(self)
        ProcessorArguments.__post_init__(self)

View File

@@ -29,6 +29,7 @@ from transformers import (
from trl import AutoModelForCausalLMWithValueHead

from ..extras import logging
from ..extras.constants import AttentionFunction
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
@@ -38,6 +39,8 @@ from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
from ..third_party.npu_fused_options.npu_fused_patcher import apply_fused_options

if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
@@ -138,6 +141,8 @@ def load_model(
    r"""Load pretrained model."""
    init_kwargs = _get_init_kwargs(model_args)
    config = load_config(model_args)
    if model_args.flash_attn == AttentionFunction.SDPA:
        apply_fused_options()
    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
    apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))

View File

@@ -0,0 +1,48 @@
from types import ModuleType
from . import sdpa_attention as npu_sdpa_attention
from . import rms_norm, rope, swiglu


def _patch_sdpa_forward():
    """Route the native transformers SDPA forward through the NPU SDPA interface.

    Without this patch, selecting the SDPA attention implementation still runs in eager mode on NPU.
    """
    from transformers.integrations import sdpa_attention
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface

    sdpa_attention.sdpa_attention_forward = npu_sdpa_attention.sdpa_attention_forward
    AttentionInterface._global_mapping["sdpa"] = npu_sdpa_attention.sdpa_attention_forward
    ALL_ATTENTION_FUNCTIONS["sdpa"] = npu_sdpa_attention.sdpa_attention_forward


def _patch_rmsnorm(module: ModuleType, class_name: str):
    setattr(module, class_name, rms_norm.NpuRMSNorm)


def _patch_rope(module: ModuleType, func_name: str):
    setattr(module, func_name, rope.apply_rotary_pos_emb)


def _patch_swiglu(module: ModuleType, class_name: str):
    setattr(module, class_name, swiglu.NpuSwiGlu)


def apply_fused_options(disable: bool = False):
    if disable:
        return

    from transformers.models.qwen2 import modeling_qwen2
    from transformers.models.qwen2_moe import modeling_qwen2_moe
    from transformers.models.qwen3 import modeling_qwen3
    from transformers.models.qwen3_moe import modeling_qwen3_moe

    _patch_sdpa_forward()
    _patch_rmsnorm(modeling_qwen2, "Qwen2RMSNorm")
    _patch_rope(modeling_qwen2, "apply_rotary_pos_emb")
    _patch_swiglu(modeling_qwen2, "Qwen2MLP")
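
A minimal usage sketch, not part of the commit: after apply_fused_options(), transformers' "sdpa" attention entry should resolve to the NPU forward. The llamafactory package path is assumed from the relative import in loader.py, and ALL_ATTENTION_FUNCTIONS is assumed to support dict-style lookup.

# Hypothetical check that the "sdpa" registry entry was swapped.
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from llamafactory.third_party.npu_fused_options import sdpa_attention as npu_sdpa_attention
from llamafactory.third_party.npu_fused_options.npu_fused_patcher import apply_fused_options

apply_fused_options()
assert ALL_ATTENTION_FUNCTIONS["sdpa"] is npu_sdpa_attention.sdpa_attention_forward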

View File

@@ -0,0 +1,44 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from transformers.utils import is_torch_npu_available
if is_torch_npu_available():
    import torch_npu
from ...extras import logging
logger = logging.get_logger()


class NpuRMSNorm(nn.Module):
    """RMSNorm operator adapted for NPU.

    Used when an NPU is available and the user enables the NPU fused options; otherwise the native
    implementation is kept.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # npu_rms_norm returns a tuple; the normalized output is the first element.
        return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

View File

@@ -0,0 +1,27 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.utils import is_torch_npu_available
if is_torch_npu_available():
    import torch_npu


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors."""
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed
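
For comparison, a sketch of the unfused rotary embedding that npu_rotary_mul fuses, following the rotate_half formulation used by transformers (cos and sin are assumed to be already unsqueezed as above). Illustrative only, not part of the commit.

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Concatenate (-x2, x1): negate the second half of the last dimension and move it in front of the first half.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def eager_apply_rotary_pos_emb(q, k, cos, sin):
    # Unfused equivalent of the two npu_rotary_mul calls above.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed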

View File

@@ -0,0 +1,70 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from transformers.utils import is_torch_npu_available
from transformers.integrations.sdpa_attention import repeat_kv
def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    if hasattr(module, "num_key_value_groups"):
        key = repeat_kv(key, module.num_key_value_groups)
        value = repeat_kv(value, module.num_key_value_groups)

    causal_mask = attention_mask
    if not is_torch_npu_available() and attention_mask is not None and causal_mask.ndim == 4:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
    if is_causal is None:
        is_causal = query.shape[2] > 1 and causal_mask is None

    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
    # We convert it to a bool for the SDPA kernel that only accepts bools.
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()
    # On NPU, drop the explicit mask and always take the causal path of scaled_dot_product_attention.
    if is_torch_npu_available():
        is_causal = True
        causal_mask = None

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
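
A minimal smoke-test sketch for the forward defined above; illustrative, not part of the commit. On a machine without torch_npu it exercises the non-NPU branch, on NPU it takes the forced-causal path. The SimpleNamespace object is a hypothetical stand-in for an attention module.

# Call sdpa_attention_forward (defined above) with toy shapes: (batch=1, heads=8, seq=16, head_dim=64).
from types import SimpleNamespace

import torch

module = SimpleNamespace(num_key_value_groups=1)  # hypothetical stand-in for an attention module
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
out, _ = sdpa_attention_forward(module, q, k, v, attention_mask=None)
print(out.shape)  # torch.Size([1, 16, 8, 64]) after the final transpose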

View File

@@ -0,0 +1,46 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from transformers.utils import is_torch_npu_available
if is_torch_npu_available():
    import torch_npu


class NpuSwiGlu(nn.Module):
    """Fused SwiGLU MLP for NPU, used as a drop-in replacement for `Qwen2MLP` when patching."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, hidden_state):
        return self.down_proj(
            torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
        )


class NpuIntern2SwiGlu(nn.Module):
    """Fused SwiGLU MLP variant for models whose MLP uses `w1`/`w3`/`w2` projection names."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, hidden_state):
        return self.w2(torch_npu.npu_swiglu(torch.cat((self.w1(hidden_state), self.w3(hidden_state)), dim=-1), dim=-1))
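
For reference, a sketch of the unfused MLP that NpuSwiGlu replaces, assuming torch_npu.npu_swiglu splits its input in half along the given dimension and computes silu(first_half) * second_half, matching the SwiGLU used by Qwen2MLP. Illustrative only, not part of the commit.

import torch
from torch import nn


class EagerSwiGluMLP(nn.Module):
    # Unfused reference: down_proj(silu(gate_proj(x)) * up_proj(x)).
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, hidden_state):
        return self.down_proj(nn.functional.silu(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))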