diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 3130f86e..d75a2fae 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -514,6 +514,11 @@ class FinetuningArguments( metadata={"help": "Whether or not to compute effective tokens per second."}, ) + enable_npu_fused_ops: bool = field( + default=False, + metadata={"help": "Whether enable NPU fused operators or not. "}, + ) + def __post_init__(self): def split_arg(arg): if isinstance(arg, str): diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 8793135f..49892045 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -26,9 +26,11 @@ from transformers import ( AutoProcessor, AutoTokenizer, ) +from transformers.utils import is_torch_npu_available from trl import AutoModelForCausalLMWithValueHead from ..extras import logging +from ..extras.constants import AttentionFunction from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub from .adapter import init_adapter from .model_utils.liger_kernel import apply_liger_kernel @@ -44,7 +46,6 @@ if TYPE_CHECKING: from ..hparams import FinetuningArguments, ModelArguments - logger = logging.get_logger(__name__) @@ -138,6 +139,11 @@ def load_model( r"""Load pretrained model.""" init_kwargs = _get_init_kwargs(model_args) config = load_config(model_args) + # Currently, the npu fused operators can only be enabled in training mode and when flash-attn==sdpa. + # Other scenarios are not yet supported. + if is_torch_npu_available() and finetuning_args.enable_npu_fused_ops and model_args.flash_attn == AttentionFunction.SDPA and is_trainable: + from ..third_party.npu_fused_ops.npu_fused_patcher import apply_fused_ops + apply_fused_ops(config) patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"])) diff --git a/src/llamafactory/third_party/npu_fused_ops/__init__.py b/src/llamafactory/third_party/npu_fused_ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/third_party/npu_fused_ops/npu_fused_patcher.py b/src/llamafactory/third_party/npu_fused_ops/npu_fused_patcher.py new file mode 100644 index 00000000..35d3c673 --- /dev/null +++ b/src/llamafactory/third_party/npu_fused_ops/npu_fused_patcher.py @@ -0,0 +1,171 @@ +import hashlib +import importlib +import os +import sys +import threading +from pathlib import Path +from types import ModuleType +from typing import Optional, Union + +import transformers +from transformers.dynamic_module_utils import get_relative_import_files +from transformers.utils.hub import HF_MODULES_CACHE + +from ...extras import logging +from . import rms_norm, rope, swiglu +from . import sdpa_attention as npu_sdpa_attention + + +logger = logging.get_logger() + +_HF_REMOTE_CODE_LOCK = threading.Lock() + + +def _patch_sdpa_forward(): + r"""The purpose of this patch is to enable the native SDPA forward function of transformers to adapt to the SDPA interface of NPU. + + If not, calling the SDPA interface is still in the eagle mode. + """ + from transformers.integrations import sdpa_attention + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface + + sdpa_attention.sdpa_attention_forward = npu_sdpa_attention.sdpa_attention_forward + AttentionInterface._global_mapping["sdpa"] = npu_sdpa_attention.sdpa_attention_forward + ALL_ATTENTION_FUNCTIONS["sdpa"] = npu_sdpa_attention.sdpa_attention_forward + + +def _patch_rmsnorm(module: ModuleType, class_name: str): + setattr(module, class_name, rms_norm.NpuRMSNorm) + + +def _patch_rope(module: ModuleType, func_name: str): + setattr(module, func_name, rope.apply_rotary_pos_emb) + + +def _patch_swiglu(module: ModuleType, class_name: str): + setattr(getattr(module, class_name), "forward", swiglu.npu_swiglu_forward) + + +def _original_get_dynamic_module( + class_name: str, + module_path: Union[str, os.PathLike], + *, + force_reload: bool = False, +): + """Get dynamic module from py file, copied from transformers.dynamic_module_utils.get_class_in_module.""" + name = os.path.normpath(module_path) + if name.endswith(".py"): + name = name[:-3] + name = name.replace(os.path.sep, ".") + module_file: Path = Path(HF_MODULES_CACHE) / module_path + with _HF_REMOTE_CODE_LOCK: + if force_reload: + sys.modules.pop(name, None) + importlib.invalidate_caches() + cached_module: Optional[ModuleType] = sys.modules.get(name) + module_spec = importlib.util.spec_from_file_location(name, location=module_file) + + # Hash the module file and all its relative imports to check if we need to reload it + module_files: list[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file))) + module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest() + + module: ModuleType + if cached_module is None: + module = importlib.util.module_from_spec(module_spec) + # insert it into sys.modules before any loading begins + sys.modules[name] = module + else: + module = cached_module + if getattr(module, "__transformers_module_hash__", "") != module_hash: + module_spec.loader.exec_module(module) + module.__transformers_module_hash__ = module_hash + return module + + +def _dynamic_patch_flash_attention(sdpa_attention_cls: str, module: ModuleType, forward, **kwargs): + _patch_sdpa_forward() + setattr(getattr(module, sdpa_attention_cls), "forward", forward) + + +def _dynamic_patch_rmsnorm(rmsnorm_cls: str, module: ModuleType, **kwargs): + setattr(module, rmsnorm_cls, rms_norm.NpuRMSNorm) + + +def _dynamic_patch_rope(rope_cls: str, module: ModuleType, **kwargs): + setattr(module, rope_cls, rope.apply_rotary_pos_emb) + + +def _dynamic_patch_swiglu(swiglu_cls: str, npu_swiglu_forward, module: ModuleType, **kwargs): + setattr(getattr(module, swiglu_cls), "forward", npu_swiglu_forward) + + +def _patch_dynamic_fused_ops(): + def _get_dynamic_module( + class_name: str, + module_path: Union[str, os.PathLike], + *, + force_reload: bool = False, + ): + module = _original_get_dynamic_module(class_name, module_path, force_reload=force_reload) + if module.__name__.endswith("modeling_internlm3"): + _dynamic_patch_flash_attention("InternLM3SdpaAttention", module, npu_sdpa_attention.internlm3_sdpa_forward) + _dynamic_patch_rmsnorm("InternLM3RMSNorm", module) + _dynamic_patch_rope("apply_rotary_pos_emb", module) + _dynamic_patch_swiglu("InternLM3MLP", swiglu.npu_swiglu_forward, module) + if module.__name__.endswith("modeling_internlm2"): + _dynamic_patch_flash_attention("InternLM2SdpaAttention", module, npu_sdpa_attention.internlm2_sdpa_forward) + _dynamic_patch_rmsnorm("InternLM2RMSNorm", module) + _dynamic_patch_rope("apply_rotary_pos_emb", module) + _dynamic_patch_swiglu("InternLM2MLP", swiglu.npu_internlm2_swiglu_forward, module) + return module + + def _get_class_in_module( + class_name: str, + module_path: Union[str, os.PathLike], + *, + force_reload: bool = False, + ): + module = _get_dynamic_module(class_name=class_name, module_path=module_path, force_reload=force_reload) + return getattr(module, class_name) + + transformers.dynamic_module_utils.get_class_in_module = _get_class_in_module + + +def apply_fused_ops(config): + from transformers.models.qwen2 import modeling_qwen2 + from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl + from transformers.models.qwen2_moe import modeling_qwen2_moe + from transformers.models.qwen3 import modeling_qwen3 + from transformers.models.qwen3_moe import modeling_qwen3_moe + + _patch_dynamic_fused_ops() + if "Qwen2ForCausalLM" in getattr(config, "architectures", []): + _patch_sdpa_forward() + _patch_rmsnorm(modeling_qwen2, "Qwen2RMSNorm") + _patch_rope(modeling_qwen2, "apply_rotary_pos_emb") + _patch_swiglu(modeling_qwen2, "Qwen2MLP") + + if "Qwen2MoeForCausalLM" in getattr(config, "architectures", []): + _patch_sdpa_forward() + _patch_rmsnorm(modeling_qwen2_moe, "Qwen2MoeRMSNorm") + _patch_rope(modeling_qwen2_moe, "apply_rotary_pos_emb") + _patch_swiglu(modeling_qwen2_moe, "Qwen2MoeMLP") + + if "Qwen3ForCausalLM" in getattr(config, "architectures", []): + _patch_sdpa_forward() + _patch_rmsnorm(modeling_qwen3, "Qwen3RMSNorm") + _patch_rope(modeling_qwen3, "apply_rotary_pos_emb") + _patch_swiglu(modeling_qwen3, "Qwen3MLP") + + if "Qwen3MoeForCausalLM" in getattr(config, "architectures", []): + _patch_sdpa_forward() + _patch_rmsnorm(modeling_qwen3_moe, "Qwen3MoeRMSNorm") + _patch_rope(modeling_qwen3_moe, "apply_rotary_pos_emb") + _patch_swiglu(modeling_qwen3_moe, "Qwen3MoeMLP") + + if "Qwen2_5_VLForConditionalGeneration" in getattr(config, "architectures", []): + _patch_sdpa_forward() + _patch_rmsnorm(modeling_qwen2_5_vl, "Qwen2RMSNorm") + _patch_swiglu(modeling_qwen2_5_vl, "Qwen2MLP") + _patch_swiglu(modeling_qwen2_5_vl, "Qwen2_5_VLMLP") + setattr(modeling_qwen2_5_vl, "apply_multimodal_rotary_pos_emb", rope.apply_multimodal_rotary_pos_emb_qwen25_vl) diff --git a/src/llamafactory/third_party/npu_fused_ops/rms_norm.py b/src/llamafactory/third_party/npu_fused_ops/rms_norm.py new file mode 100644 index 00000000..f8c8fc5d --- /dev/null +++ b/src/llamafactory/third_party/npu_fused_ops/rms_norm.py @@ -0,0 +1,34 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from torch import nn +from transformers.utils import is_torch_npu_available + + +if is_torch_npu_available(): + import torch_npu + + +class NpuRMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0] + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/llamafactory/third_party/npu_fused_ops/rope.py b/src/llamafactory/third_party/npu_fused_ops/rope.py new file mode 100644 index 00000000..84e1b3b6 --- /dev/null +++ b/src/llamafactory/third_party/npu_fused_ops/rope.py @@ -0,0 +1,73 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from transformers.utils import is_torch_npu_available + + +if is_torch_npu_available(): + import torch_npu + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = torch_npu.npu_rotary_mul(q, cos, sin) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin) + return q_embed, k_embed + + +def apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = torch_npu.npu_rotary_mul(q, cos, sin) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin) + return q_embed, k_embed diff --git a/src/llamafactory/third_party/npu_fused_ops/sdpa_attention.py b/src/llamafactory/third_party/npu_fused_ops/sdpa_attention.py new file mode 100644 index 00000000..3cfe055f --- /dev/null +++ b/src/llamafactory/third_party/npu_fused_ops/sdpa_attention.py @@ -0,0 +1,256 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from einops import rearrange +from transformers.cache_utils import Cache +from transformers.integrations.sdpa_attention import repeat_kv +from transformers.utils import is_torch_npu_available + + +_is_torch_npu_available = is_torch_npu_available() + + +if _is_torch_npu_available: + from .rope import apply_rotary_pos_emb + + +def sdpa_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + is_causal: Optional[bool] = None, + **kwargs, +) -> tuple[torch.Tensor, None]: + if hasattr(module, "num_key_value_groups"): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None and causal_mask.ndim == 4: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool` + if is_causal is None: + is_causal = query.shape[2] > 1 and causal_mask is None + + # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. + # We convert it to a bool for the SDPA kernel that only accepts bools. + if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor): + is_causal = is_causal.item() + + + if attention_mask.dtype != torch.bool: + attention_mask = torch.logical_not(attention_mask.bool()).to(query.device) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=dropout, + scale=scaling, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, None + + +def internlm2_sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if output_attentions: + # once this is implemented. + + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + "b q (h gs d) -> b q h gs d", + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., : self.num_key_value_groups, :] + query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d") + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with + if query_states.device.type == "npu" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of + # an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph + # options. An inline conditional prevents dynamic shapes from compiling. + is_causal = bool(causal_mask is None and q_len > 1) + + if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor): + is_causal = is_causal.item() + + if attention_mask.dtype != torch.bool: + attention_mask = torch.logical_not(attention_mask.bool()).to(query_states.device) + + attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102 + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.wo(attn_output) + + return attn_output, None, past_key_value + + +def internlm3_sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if output_attentions: + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + if query_states.device.type == "npu" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor): + is_causal = is_causal.item() + + if attention_mask.dtype != torch.bool: + attention_mask = torch.logical_not(attention_mask.bool()).to(query_states.device) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value diff --git a/src/llamafactory/third_party/npu_fused_ops/swiglu.py b/src/llamafactory/third_party/npu_fused_ops/swiglu.py new file mode 100644 index 00000000..da062f1b --- /dev/null +++ b/src/llamafactory/third_party/npu_fused_ops/swiglu.py @@ -0,0 +1,28 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from transformers.utils import is_torch_npu_available + + +if is_torch_npu_available(): + import torch_npu + + + +def npu_swiglu_forward(self, hidden_state): + return self.down_proj( + torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1) + ) + + +def npu_internlm2_swiglu_forward(self, hidden_state): + return self.w2(torch_npu.npu_swiglu(torch.cat((self.w1(hidden_state), self.w3(hidden_state)), dim=-1), dim=-1))