[misc] fix import error (#9296)

Author: Yaowei Zheng
Date: 2025-10-17 10:54:30 +08:00
Committed by: GitHub
Parent: 8c341cbaae
Commit: a442fa90ad


@@ -14,8 +14,6 @@
 from typing import TYPE_CHECKING
 
-from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
-
 from ...extras import logging
 from ...extras.constants import AttentionFunction
@@ -30,6 +28,8 @@ logger = logging.get_logger(__name__)
 def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    from transformers.utils import is_flash_attn_2_available
+
     if getattr(config, "model_type", None) == "gemma2":
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
@@ -51,6 +51,8 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
         requested_attn_implementation = "eager"
     elif model_args.flash_attn == AttentionFunction.SDPA:
+        from transformers.utils import is_torch_sdpa_available
+
         if not is_torch_sdpa_available():
             logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return
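
For context, the fix moves the `transformers.utils` imports from module level into the function body, so simply importing this module no longer fails when the installed transformers version does not expose those helpers; the import is only attempted when `configure_attn_implementation` actually runs. Below is a minimal self-contained sketch of the same lazy-import pattern; `check_sdpa_support` is a hypothetical function name for illustration, while `transformers.utils.is_torch_sdpa_available` is the real helper used in the patch.

    def check_sdpa_support() -> bool:
        # Importing inside the function body defers the import to call
        # time, so loading this module never triggers an ImportError.
        try:
            from transformers.utils import is_torch_sdpa_available
        except ImportError:
            # A transformers version without this helper degrades
            # gracefully instead of crashing at import time.
            return False

        return is_torch_sdpa_available()

The trade-off of this pattern is that a missing dependency surfaces at call time rather than at startup, which is usually acceptable for optional capability checks like attention backends.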