Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-20 12:54:18 +08:00)
[misc] fix import error (#9296)
@@ -14,8 +14,6 @@
 from typing import TYPE_CHECKING
 
-from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
-
 from ...extras import logging
 from ...extras.constants import AttentionFunction
 
 
@@ -30,6 +28,8 @@ logger = logging.get_logger(__name__)
 
 
 def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    from transformers.utils import is_flash_attn_2_available
+
     if getattr(config, "model_type", None) == "gemma2":
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
@@ -51,6 +51,8 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
         requested_attn_implementation = "eager"
 
     elif model_args.flash_attn == AttentionFunction.SDPA:
+        from transformers.utils import is_torch_sdpa_available
+
         if not is_torch_sdpa_available():
             logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return
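Note: the change appears to follow the usual deferred-import pattern. The transformers availability helpers are imported inside configure_attn_implementation (and inside the SDPA branch) instead of at module level, so the module itself always imports cleanly and a missing helper only surfaces on the code path that actually needs it. Below is a minimal sketch of that pattern, not the project's code; the helper name pick_attn_implementation is hypothetical, and it assumes an installed transformers build that exposes these utilities.

def pick_attn_implementation(requested: str) -> str:
    # Deferred imports: nothing from transformers.utils is touched unless
    # the corresponding branch actually runs.
    if requested == "fa2":
        from transformers.utils import is_flash_attn_2_available

        return "flash_attention_2" if is_flash_attn_2_available() else "eager"

    if requested == "sdpa":
        from transformers.utils import is_torch_sdpa_available

        return "sdpa" if is_torch_sdpa_available() else "eager"

    return "eager"

With this layout, importing the module never fails on an older transformers build; only calling the FA2 or SDPA paths can raise, which matches the intent of the "fix import error" commit title.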