diff --git a/src/llamafactory/third_party/npu_fused_options/npu_fused_patcher.py b/src/llamafactory/third_party/npu_fused_options/npu_fused_patcher.py
index 9eb07e06..356c77b1 100644
--- a/src/llamafactory/third_party/npu_fused_options/npu_fused_patcher.py
+++ b/src/llamafactory/third_party/npu_fused_options/npu_fused_patcher.py
@@ -143,9 +143,7 @@ def _patch_dynamic_fused_ops():
 
 
 def apply_fused_options(config, enable: bool=False):
-    breakpoint()
     if not enable or not is_torch_npu_available():
-        logger.warning_rank0("NPU fused options is disabled, or the torch NPU backend is not available.")
         return
     from transformers.models.qwen2 import modeling_qwen2
     from transformers.models.qwen2_moe import modeling_qwen2_moe
diff --git a/src/llamafactory/third_party/npu_fused_options/sdpa_attention.py b/src/llamafactory/third_party/npu_fused_options/sdpa_attention.py
index dd0faa1e..d07a20ee 100644
--- a/src/llamafactory/third_party/npu_fused_options/sdpa_attention.py
+++ b/src/llamafactory/third_party/npu_fused_options/sdpa_attention.py
@@ -89,16 +89,6 @@ def sdpa_attention_forward(
             keep_prob=1
         )[0]
 
-        # else:
-        #     attn_output = torch.nn.functional.scaled_dot_product_attention(
-        #         query,
-        #         key,
-        #         value,
-        #         attn_mask=causal_mask,
-        #         dropout_p=dropout,
-        #         scale=scaling,
-        #         is_causal=is_causal,
-        #     )
     attn_output = attn_output.transpose(1, 2).contiguous()
 
     return attn_output, None
@@ -162,7 +152,7 @@ def internlm2_sdpa_forward(
 
     # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
     # custom attn_mask, Reference: https://github.com/pytorch/pytorch/issues/112577.
-    if query_states.device.type == "cuda" and causal_mask is not None:
+    if query_states.device.type == "npu" and causal_mask is not None:
         query_states = query_states.contiguous()
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
@@ -202,14 +192,6 @@ def internlm2_sdpa_forward(
             keep_prob=1
         )[0]
 
-    # attn_output = torch.nn.functional.scaled_dot_product_attention(  # pylint: disable=E1102
-    #     query_states,
-    #     key_states,
-    #     value_states,
-    #     attn_mask=causal_mask,
-    #     dropout_p=0.0,
-    #     is_causal=is_causal,
-    # )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.view(bsz, q_len, self.hidden_size)
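
Note: a minimal sketch of the gating that remains in apply_fused_options after the first hunk, assuming is_torch_npu_available is the helper exported by transformers.utils (the actual patching body is elided); the function now returns silently when fused options are disabled or no NPU backend is present, rather than stopping at a breakpoint or emitting a warning.

    # Sketch only; mirrors the hunk above, not the full patcher module.
    from transformers.utils import is_torch_npu_available  # assumed import path

    def apply_fused_options(config, enable: bool = False):
        # Skip patching silently when fused ops are off or no Ascend NPU backend exists.
        if not enable or not is_torch_npu_available():
            return
        # Import the modeling modules lazily, only when patching will actually run.
        from transformers.models.qwen2 import modeling_qwen2
        from transformers.models.qwen2_moe import modeling_qwen2_moe
        # ... fused-op replacements would be applied to these modules here ...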