!248 Fix the InternLM-series models' SDPA interface not being adapted correctly; add fused-operator support for Qwen3 models (multimodal variants not yet included)
Merge pull request !248 from 幽若/master-0625
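Reviewer note (not part of the diff): with this change, Qwen3 causal LM checkpoints can opt into the NPU fused operators the same way Qwen2 already does. A minimal, hedged sketch follows; `apply_fused_kernel_to_qwen3` is the helper added further down in this commit, while the public import path is an assumption and may differ from the actual package layout.

# Hedged sketch: enable NPU fused operators for Qwen3 before instantiating the model.
# Only the function name comes from this commit; the import path is assumed.
from openmind.integrations.transformers.npu_fused_ops import apply_fused_kernel_to_qwen3

apply_fused_kernel_to_qwen3()  # all supported fused operators are enabled by default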
@@ -11,4 +11,4 @@
 # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 # See the Mulan PSL v2 for more details.
 
-from . import sdpa_attention
+from . import sdpa_attention, internlm2, internlm3
@@ -16,7 +16,8 @@
 # limitations under the License.
 """PyTorch InternLM2 model."""
 
-import math
+# Note: when ascend npu give the equal implementation of the sdpa backend, this adapter code will be removed
+
 from typing import Optional, Tuple
 
 import torch
@@ -25,11 +26,6 @@ from transformers.cache_utils import Cache
 
 from openmind.utils import logging, is_torch_npu_available
 
-if is_torch_npu_available():
-    import torch_npu
-else:
-    pass
-
 
 logger = logging.get_logger()
 
@@ -60,6 +56,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): #
     Returns:
         `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
     """
+    if is_torch_npu_available():
+        from openmind.integrations.transformers.npu_fused_ops.rope.rope import (
+            apply_rotary_pos_emb as fused_rotary_pos_emb,
+        )
+
+        return fused_rotary_pos_emb(q, k, cos, sin, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
@@ -89,9 +91,7 @@ def forward(
     use_cache: bool = False,
     cache_position: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    self.scale_value = 1.0 / math.sqrt(self.head_dim)
     if output_attentions:
-        # Improve this warning with e.g. `model.config.attn_implementation = "manual"`
         # once this is implemented.
         logger.warning_once(
             "InternLM2Model uses InternLM2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` "
@@ -144,22 +144,29 @@ def forward(
     if attention_mask is not None:
         causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
 
-    attn_output = torch_npu.npu_fusion_attention(
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
+    if query_states.device.type == "cuda" and causal_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of
+    # an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph
+    # options. An inline conditional prevents dynamic shapes from compiling.
+    is_causal = bool(causal_mask is None and q_len > 1)
+
+    if is_torch_npu_available():
+        is_causal = True
+        causal_mask = None
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(  # pylint: disable=E1102
         query_states,
         key_states,
         value_states,
-        self.num_heads,
-        input_layout="BNSD",
-        pse=None,
-        atten_mask=causal_mask.bool(),
-        scale=self.scale_value,
-        # pre_tockens and next_tockens are used for sparse computation,
-        # `2147483647` is the default value for the npu_fusion_attention interface.
-        pre_tockens=2147483647,
-        next_tockens=2147483647,
-        keep_prob=1,
-        inner_precise=0,
-    )[0]
+        attn_mask=causal_mask,
+        dropout_p=0.0,
+        is_causal=is_causal,
+    )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.view(bsz, q_len, self.hidden_size)
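Reviewer note (not part of the diff): the InternLM2 adapter above now routes through plain `torch.nn.functional.scaled_dot_product_attention` instead of `torch_npu.npu_fusion_attention`. The self-contained sketch below mirrors that dispatch; the tensor shapes, random inputs, and the `npu` device check standing in for `is_torch_npu_available()` are illustrative assumptions, not code from this commit.

import torch
import torch.nn.functional as F

# Shapes follow the (batch, num_heads, seq_len, head_dim) layout used before the final transpose.
bsz, n_heads, q_len, head_dim = 1, 8, 16, 64
query = torch.randn(bsz, n_heads, q_len, head_dim)
key = torch.randn(bsz, n_heads, q_len, head_dim)
value = torch.randn(bsz, n_heads, q_len, head_dim)
causal_mask = None  # the real forward slices an additive mask here; None keeps the sketch short

# bool(...) rather than an inline conditional, so torch.compile's dynamic shapes still compile.
is_causal = bool(causal_mask is None and q_len > 1)

# On Ascend NPU the adapter relies on SDPA's built-in causal masking and drops the explicit mask.
if query.device.type == "npu":  # assumption: stand-in for is_torch_npu_available()
    is_causal, causal_mask = True, None

attn_output = F.scaled_dot_product_attention(
    query, key, value, attn_mask=causal_mask, dropout_p=0.0, is_causal=is_causal
)
# Back to (batch, seq_len, hidden_size) before the output projection.
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, n_heads * head_dim)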
@@ -1,6 +1,7 @@
 # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
 # 2025.01.16 - Adapt to openmind.
 # Huawei Technologies Co., Ltd.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,7 +17,9 @@
 # limitations under the License.
 """PyTorch InternLM3 model."""
 
-import math
+# Note: when ascend npu give the equal implementation of the sdpa backend, this adapter code will be removed
+
+
 from typing import Optional, Tuple
 
 import torch
@@ -24,11 +27,6 @@ from transformers.cache_utils import Cache
 
 from openmind.utils import logging, is_torch_npu_available
 
-if is_torch_npu_available():
-    import torch_npu
-else:
-    pass
-
 
 logger = logging.get_logger()
 
@@ -60,6 +58,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): #
     Returns:
         `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
     """
+    if is_torch_npu_available():
+        from openmind.integrations.transformers.npu_fused_ops.rope.rope import (
+            apply_rotary_pos_emb as fused_rotary_pos_emb,
+        )
+
+        return fused_rotary_pos_emb(q, k, cos, sin, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
@@ -91,7 +95,22 @@ def forward(
     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    self.scale_value = 1.0 / math.sqrt(self.head_dim)
+    if output_attentions:
+        logger.warning_once(
+            "InternLM3Model is using InternLM3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+            'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+        )
+        return super().forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+        )
+
     bsz, q_len, _ = hidden_states.size()
 
     query_states = self.q_proj(hidden_states)
@@ -127,30 +146,28 @@ def forward(
     if attention_mask is not None:
         causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
 
-    attn_output = torch_npu.npu_fusion_attention(
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    if query_states.device.type == "cuda" and causal_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+    is_causal = True if causal_mask is None and q_len > 1 else False
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
         query_states,
         key_states,
         value_states,
-        self.num_heads,
-        input_layout="BNSD",
-        pse=None,
-        atten_mask=causal_mask.bool(),
-        scale=self.scale_value,
-        # pre_tockens and next_tockens are used for sparse computation,
-        # `2147483647` is the default value for the npu_fusion_attention interface.
-        pre_tockens=2147483647,
-        next_tockens=2147483647,
-        keep_prob=1,
-        inner_precise=0,
-    )[0]
+        attn_mask=causal_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        is_causal=is_causal,
+    )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
-    attn_output = attn_output.reshape(bsz, q_len, -1)
+    attn_output = attn_output.view(bsz, q_len, -1)
 
     attn_output = self.o_proj(attn_output)
 
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output, attn_weights, past_key_value
+    return attn_output, None, past_key_value
@@ -38,6 +38,7 @@ from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
 from openmind.integrations.transformers.npu_fused_ops.rope import rope
 from openmind.integrations.transformers.npu_fused_ops.swiglu import swiglu
 from openmind.integrations.transformers.npu_fused_ops import kernel
+from openmind.integrations.transformers.npu_fused_ops.attentions import internlm2, internlm3
 
 logger = logging.get_logger()
 
@@ -59,7 +60,8 @@ def register_dynamic_model(model_name: str, /, **kwargs):
         model_name: Autoclass name, such as InternLM2ForCausalLM.
         **kwargs: supported npu fused options, all kwargs can be None.
             kwargs include the follow params:
-                npu_fusion_attention: the adapter module of npu fused attention.
+                sdpa_cls_name: the sdpa attention class name of the module
+                sdpa_forward: the forward function of npu fused attention.
                 rms_norm: the class of npu fused rms norm.
                 rope: the function of npu fused rotary position embedding
                 swiglu: the class of npu fused SwiGLU.
@@ -74,6 +76,17 @@ register_dynamic_model(
     rms_norm=rms_norm.NpuRMSNorm,
     rope=rope.apply_rotary_pos_emb,
     swiglu=swiglu.NpuIntern2SwiGlu,
+    sdpa_cls_name="InternLM2SdpaAttention",
+    sdpa_forward=internlm2.forward,
+)
+
+register_dynamic_model(
+    "InternLM3ForCausalLM",
+    rms_norm=rms_norm.NpuRMSNorm,
+    rope=rope.apply_rotary_pos_emb,
+    swiglu=swiglu.NpuSwiGlu,
+    sdpa_cls_name="InternLM3SdpaAttention",
+    sdpa_forward=internlm3.forward,
 )
 
 
@@ -124,6 +137,17 @@ def _dynamic_patch_flash_attention(model_name: str, module: ModuleType, **kwargs
     kernel._patch_sdpa_forward()
     if config is not None:
         setattr(config, "_attn_implementation", "sdpa")
+        setattr(config, "attn_implementation", "sdpa")
+    if not DYNAMIC_MODELS[model_name].get("sdpa_cls_name", None) or not DYNAMIC_MODELS[model_name].get(
+        "sdpa_forward", None
+    ):
+        logger.warning_rank0(
+            "When execute register_dynamic_model function, `sdpa_cls_name` or `sdpa_forward` not given, "
+            "it may not work with SDPA Attention, we suggest you to set it manually or use the eager mode."
+        )
+        return
+    sdpa_attention_cls = getattr(module, DYNAMIC_MODELS[model_name].get("sdpa_cls_name"))
+    setattr(sdpa_attention_cls, "forward", DYNAMIC_MODELS[model_name].get("sdpa_forward"))
 
 
 def _dynamic_patch_rms_norm(model_name: str, module: ModuleType):
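Reviewer note (not part of the diff): a hedged sketch of how the extended `register_dynamic_model` signature would be used for a trust_remote_code model. The InternLM2/InternLM3 registrations above are the real ones; the model and attention class names below are hypothetical placeholders.

# Hedged sketch: registering a hypothetical remote-code model so that its SDPA attention
# class gets an NPU-adapted forward. "MyModelForCausalLM" and "MyModelSdpaAttention" are
# made-up names for illustration only.
from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils import register_dynamic_model
from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
from openmind.integrations.transformers.npu_fused_ops.rope import rope
from openmind.integrations.transformers.npu_fused_ops.swiglu import swiglu
from openmind.integrations.transformers.npu_fused_ops.attentions import internlm2

register_dynamic_model(
    "MyModelForCausalLM",
    rms_norm=rms_norm.NpuRMSNorm,
    rope=rope.apply_rotary_pos_emb,
    swiglu=swiglu.NpuSwiGlu,
    sdpa_cls_name="MyModelSdpaAttention",  # class looked up on the dynamically loaded module
    sdpa_forward=internlm2.forward,        # forward monkey-patched onto that class
)
# If sdpa_cls_name or sdpa_forward is omitted, _dynamic_patch_flash_attention now logs a
# warning and leaves the attention class untouched instead of failing.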
@@ -15,11 +15,12 @@ import re
 from types import ModuleType
 
 from transformers.models.qwen2 import modeling_qwen2
+from transformers.models.qwen3 import modeling_qwen3
 from transformers.models.llama import modeling_llama
 from transformers.models.mistral import modeling_mistral
 
 from openmind.utils.version import check_package_version
-from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attenions
+from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attentions
 from openmind.integrations.transformers.npu_fused_ops import dynamic_module_utils
 
 
@@ -40,9 +41,9 @@ def _patch_sdpa_forward():
     from transformers.integrations import sdpa_attention
     from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface
 
-    sdpa_attention.sdpa_attention_forward = attenions.sdpa_attention.sdpa_attention_forward
-    AttentionInterface._global_mapping["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward
-    ALL_ATTENTION_FUNCTIONS["sdpa"] = attenions.sdpa_attention.sdpa_attention_forward
+    sdpa_attention.sdpa_attention_forward = attentions.sdpa_attention.sdpa_attention_forward
+    AttentionInterface._global_mapping["sdpa"] = attentions.sdpa_attention.sdpa_attention_forward
+    ALL_ATTENTION_FUNCTIONS["sdpa"] = attentions.sdpa_attention.sdpa_attention_forward
 
 
 def _builtin_patch_flash_attention(config=None):
@@ -105,6 +106,10 @@ def apply_fused_kernel_qwen2(**kwargs):
     _apply_fused_kernel_base(modeling_qwen2, **kwargs)
 
 
+def apply_fused_kernel_qwen3(**kwargs):
+    _apply_fused_kernel_base(modeling_qwen3, **kwargs)
+
+
 def apply_fused_kernel_llama(**kwargs):
     _apply_fused_kernel_base(modeling_llama, **kwargs)
 
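Reviewer note (not part of the diff): a hedged sketch of how one might verify the effect of `_patch_sdpa_forward` after the `attenions` → `attentions` typo fix, e.g. in a quick session on an NPU machine. The assertion targets are taken from the hunk above; treat this as illustrative rather than an official check.

# After kernel._patch_sdpa_forward(), every "sdpa" lookup in transformers should resolve
# to the openmind NPU-adapted forward.
from transformers.integrations import sdpa_attention
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from openmind.integrations.transformers.npu_fused_ops import attentions, kernel

kernel._patch_sdpa_forward()
assert sdpa_attention.sdpa_attention_forward is attentions.sdpa_attention.sdpa_attention_forward
assert ALL_ATTENTION_FUNCTIONS["sdpa"] is attentions.sdpa_attention.sdpa_attention_forward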
@@ -134,20 +134,28 @@ def apply_fused_kernel_to_qwen2(**kwargs):
     Apply npu fused operators for Qwen2 series models, when call this function, all supported
     fusion operators will be enabled by default. You can set the following parameters to disable the
     specified fused operator:
-    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
-    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
+    `use_npu_fusion_attention: bool`, default is True, set it to `False` to disable npu fusion attention.
+    `use_fused_rms_norm: bool`, default is True, set it to `False` to disable npu RMSNorm.
+    `use_fused_rope: bool`, default is True, set it to `False` to disable npu fused RoPE
+    `use_fused_swiglu: bool`, default is True, set it to `False` to disable npu fused SwiGLU.
     """
     _apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen2, **kwargs)
     _apply_log(model_type="qwen2", **kwargs)
 
 
+def apply_fused_kernel_to_qwen3(**kwargs):
+    """
+    Apply npu fused operators for Qwen3 series models, when call this function, all supported
+    fusion operators will be enabled by default.
+    """
+    _apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen3, **kwargs)
+    _apply_log(model_type="qwen3", **kwargs)
+
+
 def apply_fused_kernel_to_internlm2(**kwargs):
     """
     Apply npu fused operators for Internlm2 series models, when call this function, all supported
-    fusion operators will be enabled by default. You can set the following parameters to disable the
-    specified fused operator:
-    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
-    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
+    fusion operators will be enabled by default.
     """
     _apply_fused_kernel_generic(kernel.apply_fused_kernel_internlm2, **kwargs)
     _apply_log(model_type="internlm2", **kwargs)
@@ -156,10 +164,7 @@ def apply_fused_kernel_to_internlm2(**kwargs):
 def apply_fused_kernel_to_internlm3(**kwargs):
     """
     Apply npu fused operators for Internlm2 series models, when call this function, all supported
-    fusion operators will be enabled by default. You can set the following parameters to disable the
-    specified fused operator:
-    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
-    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
+    fusion operators will be enabled by default.
     """
     _apply_fused_kernel_generic(kernel.apply_fused_kernel_internlm3, **kwargs)
     _apply_log(model_type="internlm3", **kwargs)
@@ -168,10 +173,7 @@ def apply_fused_kernel_to_internlm3(**kwargs):
 def apply_fused_kernel_to_llama(**kwargs):
     """
     Apply npu fused operators for Llama series models, when call this function, all supported
-    fusion operators will be enabled by default. You can set the following parameters to disable the
-    specified fused operator:
-    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
-    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
+    fusion operators will be enabled by default.
     """
     _apply_fused_kernel_generic(kernel.apply_fused_kernel_llama, **kwargs)
     _apply_log(model_type="llama", **kwargs)
@@ -180,10 +182,7 @@ def apply_fused_kernel_to_llama(**kwargs):
 def apply_fused_kernel_to_mistral(**kwargs):
     """
     Apply npu fused operators for Mistral series models, when call this function, all supported
-    fusion operators will be enabled by default. You can set the following parameters to disable the
-    specified fused operator:
-    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
-    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
+    fusion operators will be enabled by default.
     """
     _apply_fused_kernel_generic(kernel.apply_fused_kernel_mistral, **kwargs)
     _apply_log(model_type="mistral", **kwargs)
@@ -191,6 +190,7 @@ def apply_fused_kernel_to_mistral(**kwargs):
 
 SUPPORTED_FUSED_MODELS = {
     "Qwen2ForCausalLM": apply_fused_kernel_to_qwen2,
+    "Qwen3ForCausalLM": apply_fused_kernel_to_qwen3,
     "LlamaForCausalLM": apply_fused_kernel_to_llama,
     "MistralForCausalLM": apply_fused_kernel_to_mistral,
     "InternLM2ForCausalLM": apply_fused_kernel_to_internlm2,
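Reviewer note (not part of the diff): the updated Qwen2 docstring above now documents all four toggles, and the new `SUPPORTED_FUSED_MODELS` entry wires Qwen3 into the same mechanism. A hedged usage sketch follows; the kwargs are taken from that docstring, and it is assumed they pass through unchanged for every `apply_fused_kernel_to_*` helper. Imports are omitted because the public path of this module is not visible in the diff.

# Keep fused attention but disable the other fused operators for Qwen2.
apply_fused_kernel_to_qwen2(
    use_npu_fusion_attention=True,
    use_fused_rms_norm=False,
    use_fused_rope=False,
    use_fused_swiglu=False,
)

# Dispatch by autoclass name via the registry extended in this commit.
apply_fn = SUPPORTED_FUSED_MODELS.get("Qwen3ForCausalLM")
if apply_fn is not None:
    apply_fn()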
@@ -31,7 +31,8 @@ from openmind.integrations.transformers.npu_fused_ops.dynamic_module_utils impor
     patch_dynamic_fused_ops,
 )
 from openmind.integrations.transformers.npu_fused_ops.rms_norm import rms_norm
-from openmind.integrations.transformers.npu_fused_ops import attenions
+from openmind.integrations.transformers.npu_fused_ops import attentions
+from openmind.integrations.transformers.npu_fused_ops.attentions import internlm2
 
 
 class TestDynamicModelsRegistration(unittest.TestCase):
@@ -78,25 +79,27 @@ class TestDynamicPatching(unittest.TestCase):
     @patch("torch.__version__", "2.1.0")
     def test_attention_patching(self, _, __):
-        class MockAttentionBase:
-            def forward(self):
-                pass
-
         class Config:
             _attn_implementation = "eager"
 
+        class MockInternLM2SdpaAttention:
+            pass
+
         mock_module = ModuleType("mock_module")
-        mock_module.ATTENTION_CLASSES = {"eager": MockAttentionBase}
+        setattr(mock_module, "InternLM2SdpaAttention", MockInternLM2SdpaAttention)
         mock_config = Config()
 
         _dynamic_patch_flash_attention("InternLM2ForCausalLM", mock_module, config=mock_config)
         self.assertEqual(mock_config._attn_implementation, "sdpa")
+        self.assertEqual(mock_module.InternLM2SdpaAttention.forward, internlm2.forward)
 
     @patch("torch.__version__", "2.6.0")
     def test_torch_260_sets_sdpa(self):
         model_name = "test_model_260"
-        DYNAMIC_MODELS[model_name] = {}
+        DYNAMIC_MODELS[model_name] = {
+            "sdpa_cls_name": "InternLM2SdpaAttention",
+            "sdpa_forward": internlm2.forward,
+        }
 
         class Config:
             _attn_implementation = "eager"
@@ -107,7 +110,7 @@ class TestDynamicPatching(unittest.TestCase):
         self.assertEqual(mock_config._attn_implementation, "sdpa")
         self.assertEqual(
             transformers.integrations.sdpa_attention.sdpa_attention_forward,
-            attenions.sdpa_attention.sdpa_attention_forward,
+            attentions.sdpa_attention.sdpa_attention_forward,
         )
 
     @patch("importlib.util.spec_from_file_location")
@@ -20,7 +20,7 @@ import transformers
 from transformers.integrations import sdpa_attention
 
 from openmind.integrations.transformers.npu_fused_ops import kernel
-from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attenions
+from openmind.integrations.transformers.npu_fused_ops import rms_norm, rope, swiglu, attentions
 
 
 class TestFusedKernel(unittest.TestCase):
@@ -61,7 +61,7 @@ class TestFusedKernel(unittest.TestCase):
         self.assertEqual(mock_config._attn_implementation, "sdpa")
         self.assertEqual(
             transformers.integrations.sdpa_attention.sdpa_attention_forward,
-            attenions.sdpa_attention.sdpa_attention_forward,
+            attentions.sdpa_attention.sdpa_attention_forward,
         )
 
     def test_builtin_patch_rmsnorm(self):
@@ -121,7 +121,7 @@ class TestFusedKernel(unittest.TestCase):
         module_path = self.mock_cache / "test_module.py"
         module_path.write_text(
             "class InternLM2ForCausalLM:\n pass\nclass InternLM2RMSNorm:\n pass\nclass InternLM2MLP:\n pass"
-            "\ndef apply_rotary_pos_emb():\n pass\n"
+            "\ndef apply_rotary_pos_emb():\n pass\nclass InternLM2SdpaAttention:\n pass\n"
         )
         original_utils = transformers.dynamic_module_utils.get_class_in_module
         kernel.apply_fused_kernel_internlm2(**kwargs)