Mirror of https://github.com/frozenleaves/LLaMA-Factory.git (synced 2025-10-20 16:23:46 +08:00)
dev npu fused options
@@ -143,9 +143,7 @@ def _patch_dynamic_fused_ops():
 
 
 def apply_fused_options(config, enable: bool=False):
     breakpoint()
     if not enable or not is_torch_npu_available():
         logger.warning_rank0("NPU fused options is disabled, or the torch NPU backend is not available.")
         return
     from transformers.models.qwen2 import modeling_qwen2
     from transformers.models.qwen2_moe import modeling_qwen2_moe
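For orientation, a gate-and-patch function of this shape typically checks the user flag and the NPU backend once, then monkey-patches selected transformers modules with fused kernels. The sketch below is not taken from this commit: is_torch_npu_available and torch_npu.npu_rms_norm are real APIs used on the assumption they behave as documented, while npu_rms_norm_forward and apply_fused_options_sketch are hypothetical names introduced for illustration.

import logging

import torch
from transformers.utils import is_torch_npu_available

logger = logging.getLogger(__name__)


def npu_rms_norm_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # Hypothetical replacement forward: call the fused NPU RMSNorm kernel.
    # torch_npu.npu_rms_norm returns a tuple; index 0 is the normalized output.
    import torch_npu

    return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]


def apply_fused_options_sketch(config, enable: bool = False) -> None:
    # config mirrors the patched signature but is unused in this sketch.
    # Do nothing unless the user opted in and an Ascend NPU backend is present.
    if not enable or not is_torch_npu_available():
        logger.warning("NPU fused options are disabled, or the torch NPU backend is not available.")
        return

    # Imports are deferred so CPU/GPU-only environments never touch these modules.
    from transformers.models.qwen2 import modeling_qwen2
    from transformers.models.qwen2_moe import modeling_qwen2_moe

    # Monkey-patch the RMSNorm forwards with the fused kernel.
    modeling_qwen2.Qwen2RMSNorm.forward = npu_rms_norm_forward
    modeling_qwen2_moe.Qwen2MoeRMSNorm.forward = npu_rms_norm_forward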
@@ -89,16 +89,6 @@ def sdpa_attention_forward(
         keep_prob=1
     )[0]
 
-    # else:
-    #     attn_output = torch.nn.functional.scaled_dot_product_attention(
-    #         query,
-    #         key,
-    #         value,
-    #         attn_mask=causal_mask,
-    #         dropout_p=dropout,
-    #         scale=scaling,
-    #         is_causal=is_causal,
-    #     )
     attn_output = attn_output.transpose(1, 2).contiguous()
 
     return attn_output, None
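The deleted block is the stock torch.nn.functional.scaled_dot_product_attention call; the surviving keep_prob=1 / )[0] lines indicate a fused NPU attention kernel that returns a tuple. A minimal sketch of that correspondence follows, assuming torch_npu.npu_fusion_attention with the keyword names from Ascend's public documentation (head_num, input_layout, atten_mask, scale, keep_prob); fused_sdpa_sketch is a hypothetical helper, and the real patch may also need to convert the mask to the boolean form the NPU kernel expects.

import torch
from transformers.utils import is_torch_npu_available


def fused_sdpa_sketch(query, key, value, causal_mask, dropout: float, scaling: float):
    # query/key/value are laid out as (batch, num_heads, seq_len, head_dim),
    # the same layout the stock SDPA path uses ("BNSD" in Ascend terms).
    if is_torch_npu_available():
        import torch_npu

        # keep_prob is the complement of SDPA's dropout_p; the commit passes
        # keep_prob=1, i.e. no attention dropout. npu_fusion_attention returns
        # a tuple whose first element is the attention output.
        attn_output = torch_npu.npu_fusion_attention(
            query,
            key,
            value,
            head_num=query.shape[1],
            input_layout="BNSD",
            atten_mask=causal_mask,
            scale=scaling,
            keep_prob=1 - dropout,
        )[0]
    else:
        # Fallback mirrors the block the commit deletes.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=causal_mask,
            dropout_p=dropout,
            scale=scaling,
        )
    return attn_output.transpose(1, 2).contiguous()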
@@ -162,7 +152,7 @@ def internlm2_sdpa_forward(
 
     # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
     # custom attn_mask, Reference: https://github.com/pytorch/pytorch/issues/112577.
-    if query_states.device.type == "cuda" and causal_mask is not None:
+    if query_states.device.type == "npu" and causal_mask is not None:
         query_states = query_states.contiguous()
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
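This hunk retargets the torch 2.1.2 contiguity workaround from CUDA to NPU devices. A tiny illustrative helper (hypothetical name, not from the commit) shows how the check could cover both backends rather than replacing one with the other:

def ensure_contiguous_qkv(query_states, key_states, value_states, causal_mask):
    # Workaround for non-contiguous inputs with a custom attn_mask
    # (https://github.com/pytorch/pytorch/issues/112577): force contiguous
    # tensors on accelerators whose fused kernels require it. Checking both
    # device types keeps the original CUDA behaviour while adding the NPU
    # case; the commit as shown switches the check to "npu" only.
    if query_states.device.type in ("cuda", "npu") and causal_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()
    return query_states, key_states, value_states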
@@ -202,14 +192,6 @@ def internlm2_sdpa_forward(
         keep_prob=1
     )[0]
 
-    # attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102
-    #     query_states,
-    #     key_states,
-    #     value_states,
-    #     attn_mask=causal_mask,
-    #     dropout_p=0.0,
-    #     is_causal=is_causal,
-    # )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.view(bsz, q_len, self.hidden_size)
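Both forwards end with the same layout fix-up: the fused kernel produces (batch, num_heads, seq_len, head_dim), and the transpose/contiguous/view sequence folds the heads back into hidden_size. A self-contained example of that reshape (shapes made up for illustration):

import torch

# Shapes as used in the forward: (batch, num_heads, seq_len, head_dim).
bsz, num_heads, q_len, head_dim = 2, 8, 16, 64
hidden_size = num_heads * head_dim

attn_output = torch.randn(bsz, num_heads, q_len, head_dim)

# transpose(1, 2) -> (batch, seq_len, num_heads, head_dim); .contiguous() is
# required before .view() can merge the last two dimensions into hidden_size.
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, hidden_size)

assert attn_output.shape == (bsz, q_len, hidden_size)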