mirror of
https://github.com/frozenleaves/LLaMA-Factory.git
synced 2025-10-20 16:23:46 +08:00
dev npu fused options
This commit is contained in:
@ -514,6 +514,11 @@ class FinetuningArguments(
|
||||
metadata={"help": "Whether or not to compute effective tokens per second."},
|
||||
)
|
||||
|
||||
enable_npu_fused_ops: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether enable NPU fused operators or not. "},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
def split_arg(arg):
|
||||
if isinstance(arg, str):
|
||||
|
@ -16,6 +16,7 @@ import os
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict
|
||||
|
||||
import torch
|
||||
from transformers.utils import is_torch_npu_available
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
@ -29,6 +30,7 @@ from transformers import (
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import AttentionFunction
|
||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
|
||||
from .adapter import init_adapter
|
||||
from .model_utils.liger_kernel import apply_liger_kernel
|
||||
@ -38,13 +40,11 @@ from .model_utils.unsloth import load_unsloth_pretrained_model
|
||||
from .model_utils.valuehead import load_valuehead_params
|
||||
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -138,6 +138,11 @@ def load_model(
|
||||
r"""Load pretrained model."""
|
||||
init_kwargs = _get_init_kwargs(model_args)
|
||||
config = load_config(model_args)
|
||||
# Currently, the npu fused operators can only be enabled in training mode and when flash-attn==sdpa.
|
||||
# Other scenarios are not yet supported.
|
||||
if is_torch_npu_available() and finetuning_args.enable_npu_fused_ops and model_args.flash_attn == AttentionFunction.SDPA and is_trainable:
|
||||
from ..third_party.npu_fused_ops.npu_fused_patcher import apply_fused_ops
|
||||
apply_fused_ops(config)
|
||||
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
|
||||
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
|
||||
|
||||
|
0
src/llamafactory/third_party/npu_fused_ops/__init__.py
vendored
Normal file
0
src/llamafactory/third_party/npu_fused_ops/__init__.py
vendored
Normal file
177
src/llamafactory/third_party/npu_fused_ops/npu_fused_patcher.py
vendored
Normal file
177
src/llamafactory/third_party/npu_fused_ops/npu_fused_patcher.py
vendored
Normal file
@ -0,0 +1,177 @@
|
||||
from types import ModuleType
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import typing
|
||||
import importlib
|
||||
from typing import Optional, Union, List
|
||||
from types import ModuleType
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
|
||||
import transformers
|
||||
from transformers.dynamic_module_utils import get_relative_import_files
|
||||
from transformers.utils.hub import HF_MODULES_CACHE
|
||||
from transformers.utils import is_torch_npu_available
|
||||
|
||||
from . import sdpa_attention as npu_sdpa_attention
|
||||
from . import rms_norm, rope, swiglu
|
||||
|
||||
from ...extras import logging
|
||||
|
||||
logger = logging.get_logger()
|
||||
|
||||
_HF_REMOTE_CODE_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _patch_sdpa_forward():
|
||||
"""
|
||||
The purpose of this patch is to enable the native SDPA forward function of transformers to adapt to the
|
||||
SDPA interface of NPU. If not, calling the SDPA interface is still in the eagle mode
|
||||
"""
|
||||
from transformers.integrations import sdpa_attention
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface
|
||||
|
||||
sdpa_attention.sdpa_attention_forward = npu_sdpa_attention.sdpa_attention_forward
|
||||
AttentionInterface._global_mapping["sdpa"] = npu_sdpa_attention.sdpa_attention_forward
|
||||
ALL_ATTENTION_FUNCTIONS["sdpa"] = npu_sdpa_attention.sdpa_attention_forward
|
||||
|
||||
|
||||
def _patch_rmsnorm(module: ModuleType, class_name: str):
|
||||
setattr(module, class_name, rms_norm.NpuRMSNorm)
|
||||
|
||||
|
||||
def _patch_rope(module: ModuleType, func_name: str):
|
||||
setattr(module, func_name, rope.apply_rotary_pos_emb)
|
||||
|
||||
|
||||
def _patch_swiglu(module: ModuleType, class_name: str):
|
||||
setattr(getattr(module, class_name), "forward", swiglu.npu_swiglu_forward)
|
||||
|
||||
|
||||
def _original_get_dynamic_module(
|
||||
class_name: str,
|
||||
module_path: Union[str, os.PathLike],
|
||||
*,
|
||||
force_reload: bool = False,
|
||||
):
|
||||
"""
|
||||
Get dynamic module from py file, copied from transformers.dynamic_module_utils.get_class_in_module.
|
||||
"""
|
||||
|
||||
name = os.path.normpath(module_path)
|
||||
if name.endswith(".py"):
|
||||
name = name[:-3]
|
||||
name = name.replace(os.path.sep, ".")
|
||||
module_file: Path = Path(HF_MODULES_CACHE) / module_path
|
||||
with _HF_REMOTE_CODE_LOCK:
|
||||
if force_reload:
|
||||
sys.modules.pop(name, None)
|
||||
importlib.invalidate_caches()
|
||||
cached_module: Optional[ModuleType] = sys.modules.get(name)
|
||||
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
|
||||
|
||||
# Hash the module file and all its relative imports to check if we need to reload it
|
||||
module_files: List[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
|
||||
module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
|
||||
|
||||
module: ModuleType
|
||||
if cached_module is None:
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
# insert it into sys.modules before any loading begins
|
||||
sys.modules[name] = module
|
||||
else:
|
||||
module = cached_module
|
||||
if getattr(module, "__transformers_module_hash__", "") != module_hash:
|
||||
module_spec.loader.exec_module(module)
|
||||
module.__transformers_module_hash__ = module_hash
|
||||
return module
|
||||
|
||||
|
||||
def _dynamic_patch_flash_attention(sdpa_attention_cls: str, module: ModuleType, forward, **kwargs):
|
||||
_patch_sdpa_forward()
|
||||
setattr(getattr(module, sdpa_attention_cls), "forward", forward)
|
||||
|
||||
|
||||
def _dynamic_patch_rmsnorm(rmsnorm_cls: str, module: ModuleType, **kwargs):
|
||||
setattr(module, rmsnorm_cls, rms_norm.NpuRMSNorm)
|
||||
|
||||
|
||||
def _dynamic_patch_rope(rope_cls: str, module: ModuleType, **kwargs):
|
||||
setattr(module, rope_cls, rope.apply_rotary_pos_emb)
|
||||
|
||||
|
||||
def _dynamic_patch_swiglu(swiglu_cls: str, npu_swiglu_forward, module: ModuleType, **kwargs):
|
||||
setattr(getattr(module, swiglu_cls), "forward", npu_swiglu_forward)
|
||||
|
||||
|
||||
def _patch_dynamic_fused_ops():
|
||||
def _get_dynamic_module(
|
||||
class_name: str,
|
||||
module_path: Union[str, os.PathLike],
|
||||
*,
|
||||
force_reload: bool = False,
|
||||
):
|
||||
module = _original_get_dynamic_module(class_name, module_path, force_reload=force_reload)
|
||||
if module.__name__.endswith("modeling_internlm3"):
|
||||
_dynamic_patch_flash_attention("InternLM3SdpaAttention", module, npu_sdpa_attention.internlm3_sdpa_forward)
|
||||
_dynamic_patch_rmsnorm("InternLM3RMSNorm", module)
|
||||
_dynamic_patch_rope("apply_rotary_pos_emb", module)
|
||||
_dynamic_patch_swiglu("InternLM3MLP", swiglu.npu_swiglu_forward, module)
|
||||
if module.__name__.endswith("modeling_internlm2"):
|
||||
_dynamic_patch_flash_attention("InternLM2SdpaAttention", module, npu_sdpa_attention.internlm2_sdpa_forward)
|
||||
_dynamic_patch_rmsnorm("InternLM2RMSNorm", module)
|
||||
_dynamic_patch_rope("apply_rotary_pos_emb", module)
|
||||
_dynamic_patch_swiglu("InternLM2MLP", swiglu.npu_internlm2_swiglu_forward, module)
|
||||
return module
|
||||
|
||||
def _get_class_in_module(
|
||||
class_name: str,
|
||||
module_path: Union[str, os.PathLike],
|
||||
*,
|
||||
force_reload: bool = False,
|
||||
) -> typing.Type:
|
||||
module = _get_dynamic_module(class_name=class_name, module_path=module_path, force_reload=force_reload)
|
||||
return getattr(module, class_name)
|
||||
|
||||
transformers.dynamic_module_utils.get_class_in_module = _get_class_in_module
|
||||
|
||||
|
||||
def apply_fused_ops(config):
|
||||
from transformers.models.qwen2 import modeling_qwen2
|
||||
from transformers.models.qwen2_moe import modeling_qwen2_moe
|
||||
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
|
||||
from transformers.models.qwen3 import modeling_qwen3
|
||||
from transformers.models.qwen3_moe import modeling_qwen3_moe
|
||||
|
||||
_patch_dynamic_fused_ops()
|
||||
if "Qwen2ForCausalLM" in getattr(config, "architectures", []):
|
||||
_patch_sdpa_forward()
|
||||
_patch_rmsnorm(modeling_qwen2, "Qwen2RMSNorm")
|
||||
_patch_rope(modeling_qwen2, "apply_rotary_pos_emb")
|
||||
_patch_swiglu(modeling_qwen2, "Qwen2MLP")
|
||||
|
||||
if "Qwen2MoeForCausalLM" in getattr(config, "architectures", []):
|
||||
_patch_sdpa_forward()
|
||||
_patch_rmsnorm(modeling_qwen2_moe, "Qwen2MoeRMSNorm")
|
||||
_patch_rope(modeling_qwen2_moe, "apply_rotary_pos_emb")
|
||||
_patch_swiglu(modeling_qwen2_moe, "Qwen2MoeMLP")
|
||||
|
||||
if "Qwen3ForCausalLM" in getattr(config, "architectures", []):
|
||||
_patch_sdpa_forward()
|
||||
_patch_rmsnorm(modeling_qwen3, "Qwen3RMSNorm")
|
||||
_patch_rope(modeling_qwen3, "apply_rotary_pos_emb")
|
||||
_patch_swiglu(modeling_qwen3, "Qwen3MLP")
|
||||
|
||||
if "Qwen3MoeForCausalLM" in getattr(config, "architectures", []):
|
||||
_patch_sdpa_forward()
|
||||
_patch_rmsnorm(modeling_qwen3_moe, "Qwen3MoeRMSNorm")
|
||||
_patch_rope(modeling_qwen3_moe, "apply_rotary_pos_emb")
|
||||
_patch_swiglu(modeling_qwen3_moe, "Qwen3MoeMLP")
|
||||
|
||||
if "Qwen2_5_VLForConditionalGeneration" in getattr(config, "architectures", []):
|
||||
_patch_sdpa_forward()
|
||||
_patch_rmsnorm(modeling_qwen2_5_vl, "Qwen2RMSNorm")
|
||||
_patch_swiglu(modeling_qwen2_5_vl, "Qwen2MLP")
|
||||
_patch_swiglu(modeling_qwen2_5_vl, "Qwen2_5_VLMLP")
|
||||
setattr(modeling_qwen2_5_vl, "apply_multimodal_rotary_pos_emb", rope.apply_multimodal_rotary_pos_emb_qwen25_vl)
|
38
src/llamafactory/third_party/npu_fused_ops/rms_norm.py
vendored
Normal file
38
src/llamafactory/third_party/npu_fused_ops/rms_norm.py
vendored
Normal file
@ -0,0 +1,38 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.utils import is_torch_npu_available
|
||||
|
||||
if is_torch_npu_available():
|
||||
import torch_npu
|
||||
|
||||
|
||||
class NpuRMSNorm(nn.Module):
|
||||
"""
|
||||
RMSNorm operator adapted for NPU. When NPU is available and the user chooses to enable NpuRMSNorm, it will be opened.
|
||||
Otherwise, the native implementation will still be used
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
73
src/llamafactory/third_party/npu_fused_ops/rope.py
vendored
Normal file
73
src/llamafactory/third_party/npu_fused_ops/rope.py
vendored
Normal file
@ -0,0 +1,73 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from transformers.utils import is_torch_npu_available
|
||||
|
||||
if is_torch_npu_available():
|
||||
import torch_npu
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""
|
||||
Applies Rotary Position Embedding to the query and key tensors.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
|
||||
|
||||
Explanation:
|
||||
Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
|
||||
sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
|
||||
vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
|
||||
Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
|
||||
For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
|
||||
height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
|
||||
difference with modern LLMs.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
mrope_section(`List(int)`):
|
||||
Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
mrope_section = mrope_section * 2
|
||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||
unsqueeze_dim
|
||||
)
|
||||
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||
unsqueeze_dim
|
||||
)
|
||||
|
||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
return q_embed, k_embed
|
260
src/llamafactory/third_party/npu_fused_ops/sdpa_attention.py
vendored
Normal file
260
src/llamafactory/third_party/npu_fused_ops/sdpa_attention.py
vendored
Normal file
@ -0,0 +1,260 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from transformers.utils import is_torch_npu_available
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.integrations.sdpa_attention import repeat_kv
|
||||
|
||||
if is_torch_npu_available():
|
||||
import torch_npu
|
||||
from .rope import apply_rotary_pos_emb
|
||||
|
||||
|
||||
def sdpa_attention_forward(
|
||||
module: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
dropout: float = 0.0,
|
||||
scaling: Optional[float] = None,
|
||||
is_causal: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, None]:
|
||||
if hasattr(module, "num_key_value_groups"):
|
||||
key = repeat_kv(key, module.num_key_value_groups)
|
||||
value = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None and causal_mask.ndim == 4:
|
||||
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
|
||||
if is_causal is None:
|
||||
is_causal = query.shape[2] > 1 and causal_mask is None
|
||||
|
||||
# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
|
||||
# We convert it to a bool for the SDPA kernel that only accepts bools.
|
||||
if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
|
||||
is_causal = is_causal.item()
|
||||
|
||||
if is_causal:
|
||||
attention_mask = None
|
||||
else:
|
||||
attention_mask = torch.logical_not(causal_mask.bool()).to(query.device)
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=dropout,
|
||||
scale=scaling,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
def internlm2_sdpa_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# once this is implemented.
|
||||
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
qkv_states = self.wqkv(hidden_states)
|
||||
|
||||
qkv_states = rearrange(
|
||||
qkv_states,
|
||||
"b q (h gs d) -> b q h gs d",
|
||||
gs=2 + self.num_key_value_groups,
|
||||
d=self.head_dim,
|
||||
)
|
||||
|
||||
query_states = qkv_states[..., : self.num_key_value_groups, :]
|
||||
query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
|
||||
key_states = qkv_states[..., -2, :]
|
||||
value_states = qkv_states[..., -1, :]
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of
|
||||
# an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph
|
||||
# options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = bool(causal_mask is None and q_len > 1)
|
||||
|
||||
# if is_torch_npu_available():
|
||||
# is_causal = True
|
||||
# attention_mask = None
|
||||
|
||||
if is_causal:
|
||||
attention_mask = None
|
||||
else:
|
||||
attention_mask = torch.logical_not(causal_mask.bool()).to(query_states.device)
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.wo(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
def internlm3_sdpa_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
|
||||
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is None:
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
# if is_torch_npu_available():
|
||||
# is_causal = True
|
||||
# attention_mask = None
|
||||
|
||||
if is_causal:
|
||||
attention_mask = None
|
||||
else:
|
||||
attention_mask = torch.logical_not(causal_mask.bool()).to(query_states.device)
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
27
src/llamafactory/third_party/npu_fused_ops/swiglu.py
vendored
Normal file
27
src/llamafactory/third_party/npu_fused_ops/swiglu.py
vendored
Normal file
@ -0,0 +1,27 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from transformers.utils import is_torch_npu_available
|
||||
|
||||
if is_torch_npu_available():
|
||||
import torch_npu
|
||||
|
||||
|
||||
|
||||
def npu_swiglu_forward(self, hidden_state):
|
||||
return self.down_proj(
|
||||
torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
|
||||
)
|
||||
|
||||
|
||||
def npu_internlm2_swiglu_forward(self, hidden_state):
|
||||
return self.w2(torch_npu.npu_swiglu(torch.cat((self.w1(hidden_state), self.w3(hidden_state)), dim=-1), dim=-1))
|
Reference in New Issue
Block a user