dev npu fused options

author frozenleaves
date 2025-09-11 09:23:46 +08:00
parent eafa225c97
commit 321444ef78
2 changed files with 1 addition and 21 deletions

File 1 of 2

@@ -143,9 +143,7 @@ def _patch_dynamic_fused_ops():
 def apply_fused_options(config, enable: bool=False):
-    breakpoint()
     if not enable or not is_torch_npu_available():
-        logger.warning_rank0("NPU fused options is disabled, or the torch NPU backend is not available.")
         return
     from transformers.models.qwen2 import modeling_qwen2
     from transformers.models.qwen2_moe import modeling_qwen2_moe
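
The hunk only shows the guard and the imports; the body of the patch sits outside the diff context. For orientation, a minimal sketch of what such a routine could look like, assuming the fused forward is the one defined in the second changed file and using Qwen2Attention.forward / Qwen2MoeAttention.forward as hypothetical patch points (the repo may wire this differently):

from transformers.utils import is_torch_npu_available


def _npu_fused_forward_placeholder(*args, **kwargs):
    # Stand-in for the NPU-fused attention forward defined in the second changed file.
    raise NotImplementedError


def apply_fused_options(config, enable: bool=False):
    # No-op when fusion is disabled or no Ascend NPU backend is present,
    # mirroring the early return kept by this commit.
    if not enable or not is_torch_npu_available():
        return

    from transformers.models.qwen2 import modeling_qwen2
    from transformers.models.qwen2_moe import modeling_qwen2_moe

    # Hypothetical patch points: route both model families through the
    # NPU-fused attention forward instead of the stock SDPA path.
    modeling_qwen2.Qwen2Attention.forward = _npu_fused_forward_placeholder
    modeling_qwen2_moe.Qwen2MoeAttention.forward = _npu_fused_forward_placeholder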

File 2 of 2

@@ -89,16 +89,6 @@ def sdpa_attention_forward(
         keep_prob=1
     )[0]
-    # else:
-    #     attn_output = torch.nn.functional.scaled_dot_product_attention(
-    #         query,
-    #         key,
-    #         value,
-    #         attn_mask=causal_mask,
-    #         dropout_p=dropout,
-    #         scale=scaling,
-    #         is_causal=is_causal,
-    #     )
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, None
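
The lines kept by this hunk (keep_prob=1, )[0], and the transpose) are consistent with a call to torch_npu.npu_fusion_attention, which returns a tuple whose first element is the attention output. A self-contained sketch of that NPU path, with the mask handling kept illustrative rather than a copy of the repo's exact logic:

import torch
import torch_npu  # Ascend NPU extension; provides npu_fusion_attention


def npu_fused_sdpa(query, key, value, atten_mask=None, scaling=None):
    # query/key/value are expected as [batch, num_heads, seq_len, head_dim] ("BNSD").
    head_num = query.shape[1]
    scale = scaling if scaling is not None else query.shape[-1] ** -0.5
    attn_output = torch_npu.npu_fusion_attention(
        query, key, value, head_num,
        input_layout="BNSD",
        atten_mask=atten_mask,  # boolean mask in torch_npu's convention, not SDPA's attn_mask
        scale=scale,
        keep_prob=1,            # dropout disabled, as in the hunk above
    )[0]
    # Return in the [batch, seq_len, num_heads, head_dim] layout the caller reshapes from.
    return attn_output.transpose(1, 2).contiguous()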
@@ -162,7 +152,7 @@ def internlm2_sdpa_forward(
     # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
     # custom attn_mask, Reference: https://github.com/pytorch/pytorch/issues/112577.
-    if query_states.device.type == "cuda" and causal_mask is not None:
+    if query_states.device.type == "npu" and causal_mask is not None:
         query_states = query_states.contiguous()
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
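
The single-line change in this hunk retargets the contiguity workaround: on Ascend hardware tensors report device.type == "npu", so the old "cuda" check never fired there and the inputs were left non-contiguous. A quick illustration, assuming torch_npu is installed so the "npu" device type is registered:

import torch
import torch_npu  # importing this registers the "npu" device type with PyTorch

q = torch.randn(1, 8, 128, 64, device="npu")
print(q.device.type)  # prints "npu", not "cuda", hence the updated guard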
@@ -202,14 +192,6 @@ def internlm2_sdpa_forward(
             keep_prob=1
         )[0]
-        # attn_output = torch.nn.functional.scaled_dot_product_attention(  # pylint: disable=E1102
-        #     query_states,
-        #     key_states,
-        #     value_states,
-        #     attn_mask=causal_mask,
-        #     dropout_p=0.0,
-        #     is_causal=is_causal,
-        # )
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
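
For context, a hypothetical call site for the patch from the first file; the model id and config type are illustrative, and the call silently no-ops anywhere other than an Ascend NPU because of the guard shown above:

from transformers import AutoConfig
# from the module shown in the first changed file import apply_fused_options  (path not shown in this diff)

config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")  # illustrative model id
apply_fused_options(config, enable=True)  # apply the NPU-fused patches before building the model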