dev npu fused options

This commit is contained in:
frozenleaves
2025-09-11 09:37:42 +08:00
parent 59f2bf1ea3
commit f630701e88
8 changed files with 593 additions and 0 deletions

View File

@ -514,6 +514,11 @@ class FinetuningArguments(
metadata={"help": "Whether or not to compute effective tokens per second."}, metadata={"help": "Whether or not to compute effective tokens per second."},
) )
enable_npu_fused_ops: bool = field(
default=False,
metadata={"help": "Whether enable NPU fused operators or not. "},
)
def __post_init__(self):
def split_arg(arg):
if isinstance(arg, str):
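For context, toggling the new option might look like the sketch below. The import path is illustrative (the class is the FinetuningArguments dataclass patched above); at the command line the flag corresponds to --enable_npu_fused_ops true, or enable_npu_fused_ops: true in a YAML config.

from llamafactory.hparams import FinetuningArguments  # illustrative import path

finetuning_args = FinetuningArguments(enable_npu_fused_ops=True)  # the field defaults to False
assert finetuning_args.enable_npu_fused_ops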

View File

@ -29,6 +29,7 @@ from transformers import (
from trl import AutoModelForCausalLMWithValueHead
from ..extras import logging
from ..extras.constants import AttentionFunction
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
@ -38,6 +39,8 @@ from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
from ..third_party.npu_fused_ops.npu_fused_patcher import apply_fused_ops
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
@ -138,6 +141,10 @@ def load_model(
r"""Load pretrained model.""" r"""Load pretrained model."""
init_kwargs = _get_init_kwargs(model_args) init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args) config = load_config(model_args)
# Currently, the NPU fused operators can only be enabled in training mode with the SDPA attention
# implementation (flash_attn == "sdpa"); other scenarios are not yet supported.
if model_args.flash_attn == AttentionFunction.SDPA and is_trainable:
apply_fused_ops(config, enable=finetuning_args.enable_npu_fused_ops)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
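Restating the guard above as a standalone predicate may help when reading the call site. The helper below is purely illustrative (it does not exist in the codebase) and uses the names already imported in this module.

def _should_apply_npu_fused_ops(model_args, finetuning_args, is_trainable: bool) -> bool:
    # apply_fused_ops() additionally returns early unless an NPU is available and the flag is set,
    # so the effective condition is: training run + SDPA attention + NPU host + flag enabled.
    return is_trainable and model_args.flash_attn == AttentionFunction.SDPA and finetuning_args.enable_npu_fused_ops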

View File

View File

@ -0,0 +1,183 @@
from types import ModuleType
import dataclasses
import os
import re
import sys
import threading
import typing
import importlib
from typing import Optional, Union, List
from pathlib import Path
import hashlib
import functools
import transformers
from transformers.dynamic_module_utils import get_relative_import_files
from transformers.utils.hub import HF_MODULES_CACHE
from transformers import PretrainedConfig
from transformers.utils import is_torch_npu_available
from . import sdpa_attention as npu_sdpa_attention
from . import rms_norm, rope, swiglu
from ...extras import logging
logger = logging.get_logger()
_HF_REMOTE_CODE_LOCK = threading.Lock()
def _patch_sdpa_forward():
"""
The purpose of this patch is to enable the native SDPA forward function of transformers to adapt to the
SDPA interface of NPU. If not, calling the SDPA interface is still in the eagle mode
"""
from transformers.integrations import sdpa_attention
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface
sdpa_attention.sdpa_attention_forward = npu_sdpa_attention.sdpa_attention_forward
AttentionInterface._global_mapping["sdpa"] = npu_sdpa_attention.sdpa_attention_forward
ALL_ATTENTION_FUNCTIONS["sdpa"] = npu_sdpa_attention.sdpa_attention_forward
def _patch_rmsnorm(module: ModuleType, class_name: str):
setattr(module, class_name, rms_norm.NpuRMSNorm)
def _patch_rope(module: ModuleType, func_name: str):
setattr(module, func_name, rope.apply_rotary_pos_emb)
def _patch_swiglu(module: ModuleType, class_name: str):
setattr(getattr(module, class_name), "forward", swiglu.npu_swiglu_forward)
def _original_get_dynamic_module(
class_name: str,
module_path: Union[str, os.PathLike],
*,
force_reload: bool = False,
):
"""
Get dynamic module from py file, copied from transformers.dynamic_module_utils.get_class_in_module.
"""
name = os.path.normpath(module_path)
if name.endswith(".py"):
name = name[:-3]
name = name.replace(os.path.sep, ".")
module_file: Path = Path(HF_MODULES_CACHE) / module_path
with _HF_REMOTE_CODE_LOCK:
if force_reload:
sys.modules.pop(name, None)
importlib.invalidate_caches()
cached_module: Optional[ModuleType] = sys.modules.get(name)
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
# Hash the module file and all its relative imports to check if we need to reload it
module_files: List[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
module: ModuleType
if cached_module is None:
module = importlib.util.module_from_spec(module_spec)
# insert it into sys.modules before any loading begins
sys.modules[name] = module
else:
module = cached_module
if getattr(module, "__transformers_module_hash__", "") != module_hash:
module_spec.loader.exec_module(module)
module.__transformers_module_hash__ = module_hash
return module
def _dynamic_patch_flash_attention(sdpa_attention_cls: str, module: ModuleType, forward, **kwargs):
_patch_sdpa_forward()
setattr(getattr(module, sdpa_attention_cls), "forward", forward)
def _dynamic_patch_rmsnorm(rmsnorm_cls: str, module: ModuleType, **kwargs):
setattr(module, rmsnorm_cls, rms_norm.NpuRMSNorm)
def _dynamic_patch_rope(rope_cls: str, module: ModuleType, **kwargs):
setattr(module, rope_cls, rope.apply_rotary_pos_emb)
def _dynamic_patch_swiglu(swiglu_cls: str, npu_swiglu_forward, module: ModuleType, **kwargs):
setattr(getattr(module, swiglu_cls), "forward", npu_swiglu_forward)
def _patch_dynamic_fused_ops():
def _get_dynamic_module(
class_name: str,
module_path: Union[str, os.PathLike],
*,
force_reload: bool = False,
):
module = _original_get_dynamic_module(class_name, module_path, force_reload=force_reload)
if module.__name__.endswith("modeling_internlm3"):
_dynamic_patch_flash_attention("InternLM3SdpaAttention", module, npu_sdpa_attention.internlm3_sdpa_forward)
_dynamic_patch_rmsnorm("InternLM3RMSNorm", module)
_dynamic_patch_rope("apply_rotary_pos_emb", module)
_dynamic_patch_swiglu("InternLM3MLP", swiglu.npu_swiglu_forward, module)
if module.__name__.endswith("modeling_internlm2"):
_dynamic_patch_flash_attention("InternLM2SdpaAttention", module, npu_sdpa_attention.internlm2_sdpa_forward)
_dynamic_patch_rmsnorm("InternLM2RMSNorm", module)
_dynamic_patch_rope("apply_rotary_pos_emb", module)
_dynamic_patch_swiglu("InternLM2MLP", swiglu.npu_internlm2_swiglu_forward, module)
return module
def _get_class_in_module(
class_name: str,
module_path: Union[str, os.PathLike],
*,
force_reload: bool = False,
) -> typing.Type:
module = _get_dynamic_module(class_name=class_name, module_path=module_path, force_reload=force_reload)
return getattr(module, class_name)
transformers.dynamic_module_utils.get_class_in_module = _get_class_in_module
def apply_fused_ops(config, enable: bool = False):
if not enable or not is_torch_npu_available():
return
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.qwen2_moe import modeling_qwen2_moe
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
from transformers.models.qwen3 import modeling_qwen3
from transformers.models.qwen3_moe import modeling_qwen3_moe
_patch_dynamic_fused_ops()
if "Qwen2ForCausalLM" in getattr(config, "architectures", []):
_patch_sdpa_forward()
_patch_rmsnorm(modeling_qwen2, "Qwen2RMSNorm")
_patch_rope(modeling_qwen2, "apply_rotary_pos_emb")
_patch_swiglu(modeling_qwen2, "Qwen2MLP")
if "Qwen2MoeForCausalLM" in getattr(config, "architectures", []):
_patch_sdpa_forward()
_patch_rmsnorm(modeling_qwen2_moe, "Qwen2MoeRMSNorm")
_patch_rope(modeling_qwen2_moe, "apply_rotary_pos_emb")
_patch_swiglu(modeling_qwen2_moe, "Qwen2MoeMLP")
if "Qwen3ForCausalLM" in getattr(config, "architectures", []):
_patch_sdpa_forward()
_patch_rmsnorm(modeling_qwen3, "Qwen3RMSNorm")
_patch_rope(modeling_qwen3, "apply_rotary_pos_emb")
_patch_swiglu(modeling_qwen3, "Qwen3MLP")
if "Qwen3MoeForCausalLM" in getattr(config, "architectures", []):
_patch_sdpa_forward()
_patch_rmsnorm(modeling_qwen3_moe, "Qwen3MoeRMSNorm")
_patch_rope(modeling_qwen3_moe, "apply_rotary_pos_emb")
_patch_swiglu(modeling_qwen3_moe, "Qwen3MoeMLP")
if "Qwen2_5_VLForConditionalGeneration" in getattr(config, "architectures", []):
_patch_sdpa_forward()
_patch_rmsnorm(modeling_qwen2_5_vl, "Qwen2RMSNorm")
_patch_swiglu(modeling_qwen2_5_vl, "Qwen2MLP")
_patch_swiglu(modeling_qwen2_5_vl, "Qwen2_5_VLMLP")
setattr(modeling_qwen2_5_vl, "apply_multimodal_rotary_pos_emb", rope.apply_multimodal_rotary_pos_emb_qwen25_vl)
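A usage sketch for apply_fused_ops, assuming an Ascend host with torch_npu installed and that this module is importable; the checkpoint name and keyword values are only examples.

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")  # example checkpoint
apply_fused_ops(config, enable=True)  # silently a no-op unless is_torch_npu_available() is True

# Build the model only after patching, so the modeling module picks up the patched
# RMSNorm / RoPE / SwiGLU symbols and the NPU-aware SDPA forward when constructing its layers.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct", config=config, attn_implementation="sdpa")

In this codebase the call is made from load_model, guarded by the training-mode and SDPA checks shown earlier.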

View File

@ -0,0 +1,38 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from transformers.utils import is_torch_npu_available
if is_torch_npu_available():
    import torch_npu


class NpuRMSNorm(nn.Module):
    """RMSNorm adapted for NPU.

    Used when an NPU is available and the user enables the NPU fused operators;
    otherwise the model keeps its native RMSNorm implementation.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # torch_npu.npu_rms_norm returns (normalized_output, rstd); only the output is needed.
        return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

View File

@ -0,0 +1,73 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers.utils import is_torch_npu_available
if is_torch_npu_available():
    import torch_npu


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors via the fused NPU kernel."""
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed


def apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors
    (https://qwenlm.github.io/blog/qwen2-vl/).

    Explanation:
        Multimodal 3D rotary position embedding is an extension of 1D rotary position embedding. The input
        embedding sequence contains vision (image/video) embeddings and text embeddings, or only text embeddings.
        For the vision part, rotary position embedding is applied separately to the temporal, height, and width
        dimensions: the channel dimension is split into three chunks, one per dimension. For the text part, plain
        1D rotary position embedding is applied; since the three position indices (temporal, height, and width)
        of a text token are always identical, the result matches that of standard LLMs.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        mrope_section (`List[int]`):
            Channel widths of the temporal, height, and width sections used in the rope calculation.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which to unsqueeze cos and sin so that they broadcast against q and k. For
            example, if cos and sin have the shape [batch_size, seq_len, head_dim] and q and k have the shape
            [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes them broadcastable; if q and k
            have the shape [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
    """
    mrope_section = mrope_section * 2
    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )
    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )
    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed
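To make the mrope_section bookkeeping above concrete, a shape-only sketch with illustrative values (it never touches the fused kernel, so it runs on any device).

import torch

mrope_section = [16, 24, 24]       # temporal / height / width widths; head_dim = sum * 2 = 128 here
cos = torch.randn(3, 1, 10, 128)   # [3 (t, h, w), batch, seq_len, head_dim], as produced by the 3D rotary embedding

chunks = cos.split(mrope_section * 2, dim=-1)  # six chunks of widths 16, 24, 24, 16, 24, 24
merged = torch.cat([m[i % 3] for i, m in enumerate(chunks)], dim=-1)
print(merged.shape)                # torch.Size([1, 10, 128]): one cos table interleaving the t/h/w sections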

View File

@ -0,0 +1,260 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from einops import rearrange
from transformers.utils import is_torch_npu_available
from transformers.cache_utils import Cache
from transformers.integrations.sdpa_attention import repeat_kv
if is_torch_npu_available():
import torch_npu
from .rope import apply_rotary_pos_emb
def sdpa_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
dropout: float = 0.0,
scaling: Optional[float] = None,
is_causal: Optional[bool] = None,
**kwargs,
) -> Tuple[torch.Tensor, None]:
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None and causal_mask.ndim == 4:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
if is_causal is None:
is_causal = query.shape[2] > 1 and causal_mask is None
# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
# We convert it to a bool for the SDPA kernel that only accepts bools.
if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
is_causal = is_causal.item()
if is_causal:
attention_mask = None
else:
attention_mask = torch.logical_not(causal_mask.bool()).to(query.device)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=dropout,
scale=scaling,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
def internlm2_sdpa_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# SDPA does not support output_attentions=True, so fall back to the parent (eager) attention implementation.
# Zero-argument super() would fail here because this forward is attached to the class dynamically.
return super(self.__class__, self).forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
qkv_states = self.wqkv(hidden_states)
qkv_states = rearrange(
qkv_states,
"b q (h gs d) -> b q h gs d",
gs=2 + self.num_key_value_groups,
d=self.head_dim,
)
query_states = qkv_states[..., : self.num_key_value_groups, :]
query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
key_states = qkv_states[..., -2, :]
value_states = qkv_states[..., -1, :]
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with the memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs and a custom attn_mask, so make the tensors contiguous.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of
# an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph
# options. An inline conditional prevents dynamic shapes from compiling.
is_causal = bool(causal_mask is None and q_len > 1)
# if is_torch_npu_available():
# is_causal = True
# attention_mask = None
if is_causal:
attention_mask = None
else:
attention_mask = torch.logical_not(causal_mask.bool()).to(query_states.device)
attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.wo(attn_output)
return attn_output, None, past_key_value
def internlm3_sdpa_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# SDPA does not support output_attentions=True; fall back to the parent (eager) implementation.
# Zero-argument super() would fail here because this forward is attached to the class dynamically.
return super(self.__class__, self).forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with the memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs and a custom attn_mask, so make the tensors contiguous.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
# if is_torch_npu_available():
# is_causal = True
# attention_mask = None
if is_causal:
attention_mask = None
else:
attention_mask = torch.logical_not(causal_mask.bool()).to(query_states.device)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
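A small registry check that the patcher rewires the transformers SDPA entry points to the sdpa_attention_forward defined above; a sketch that assumes apply_fused_ops(..., enable=True) has already run on an NPU host.

from transformers.integrations import sdpa_attention as hf_sdpa
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

# Both registries should now point at the NPU-aware forward instead of the stock implementation.
assert ALL_ATTENTION_FUNCTIONS["sdpa"] is sdpa_attention_forward
assert hf_sdpa.sdpa_attention_forward is sdpa_attention_forward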

View File

@ -0,0 +1,27 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers.utils import is_torch_npu_available
if is_torch_npu_available():
    import torch_npu


def npu_swiglu_forward(self, hidden_state):
    # Fused SwiGLU: concatenate the gate and up projections, then apply the NPU fused kernel.
    return self.down_proj(
        torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
    )


def npu_internlm2_swiglu_forward(self, hidden_state):
    # InternLM2 names its projections w1 (gate), w3 (up), and w2 (down).
    return self.w2(torch_npu.npu_swiglu(torch.cat((self.w1(hidden_state), self.w3(hidden_state)), dim=-1), dim=-1))
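The fused call above is meant to reproduce the eager SwiGLU, down(silu(gate(x)) * up(x)); below is a sketch of an equivalence check that assumes the helpers above, an NPU device, and an illustrative module layout.

import torch
from torch import nn
from transformers.utils import is_torch_npu_available

class TinyMLP(nn.Module):
    # Mirrors the gate_proj / up_proj / down_proj layout the patch expects.
    def __init__(self, hidden=32, inter=64):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, inter, bias=False)
        self.up_proj = nn.Linear(hidden, inter, bias=False)
        self.down_proj = nn.Linear(inter, hidden, bias=False)

    def forward(self, x):  # eager reference
        return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))

if is_torch_npu_available():
    mlp = TinyMLP().to("npu")
    x = torch.randn(2, 8, 32, device="npu")
    torch.testing.assert_close(npu_swiglu_forward(mlp, x), mlp(x), rtol=1e-4, atol=1e-4)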