From a442fa90ad4b567990ec511e7e774f074ef479e4 Mon Sep 17 00:00:00 2001
From: Yaowei Zheng
Date: Fri, 17 Oct 2025 10:54:30 +0800
Subject: [PATCH] [misc] fix import error (#9296)

---
 src/llamafactory/model/model_utils/attention.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index fb86a163..d901a9a8 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -14,8 +14,6 @@
 
 from typing import TYPE_CHECKING
 
-from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
-
 from ...extras import logging
 from ...extras.constants import AttentionFunction
 
@@ -30,6 +28,8 @@ logger = logging.get_logger(__name__)
 
 
 def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    from transformers.utils import is_flash_attn_2_available
+
     if getattr(config, "model_type", None) == "gemma2":
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
@@ -51,6 +51,8 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
         requested_attn_implementation = "eager"
 
     elif model_args.flash_attn == AttentionFunction.SDPA:
+        from transformers.utils import is_torch_sdpa_available
+
         if not is_torch_sdpa_available():
             logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return
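
Note: the patch moves the `transformers.utils` imports from module scope into the functions that use them, so merely importing `attention.py` no longer fails when the installed `transformers` version does not expose these helpers; the lookup only happens when attention is actually configured. Below is a minimal sketch of the same deferred-import pattern, not code from the patch; the `sdpa_supported` helper name and the try/except fallback are illustrative assumptions.

    def sdpa_supported() -> bool:
        """Report whether SDPA attention can be used, without importing
        transformers.utils at module load time."""
        try:
            # Deferred import: resolved only when this function is called,
            # so a transformers version that lacks this symbol does not
            # break importing the module that defines sdpa_supported().
            from transformers.utils import is_torch_sdpa_available
        except ImportError:
            return False
        return is_torch_sdpa_available()

The same structure applies to `is_flash_attn_2_available`: keeping the import inside the branch that needs it limits the failure to the code path that actually requires the feature check.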