Support LLaVA-NeXT in Vision SFT (#1959)

* Support LLaVA-NeXT

* Mention the required transformers version for LLaVA-NeXT

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Author: Quentin Gallouédec
Date: 2024-08-23 11:37:40 +02:00
Committed by: GitHub
Parent: 6cea2ef964
Commit: 4788e5cda5

diff --git a/examples/scripts/vsft_llava.py b/examples/scripts/vsft_llava.py

@@ -27,6 +27,9 @@ python examples/scripts/vsft_llava.py \
     --use_peft \
     --dataloader_num_workers 32 \
     --lora_target_modules=all-linear
+
+For LLaVA-NeXT, use: (requires transformers>=4.45)
+    --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf
 """
 import logging
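
Taken together with the existing docstring, the hunk above implies an invocation like the following for LLaVA-NeXT. This is a sketch: only the flags visible in the diff are shown, and whatever other arguments the script requires (dataset name, output directory, and so on) are omitted.

# Sketch of the LLaVA-NeXT run described by the updated docstring;
# requires transformers>=4.45. Flags beyond --model_name_or_path are
# the ones already shown for LLaVA 1.5 in the hunk context.
python examples/scripts/vsft_llava.py \
    --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
    --use_peft \
    --dataloader_num_workers 32 \
    --lora_target_modules=all-linear
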
@@ -50,7 +53,7 @@ from accelerate import Accelerator
 from datasets import load_dataset
 from tqdm.rich import tqdm
-from transformers import AutoProcessor, LlavaForConditionalGeneration
+from transformers import AutoModelForVision2Seq, AutoProcessor
 from trl import (
     ModelConfig,
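
The import swap above is the core of the change: AutoModelForVision2Seq picks the concrete class from the checkpoint's config, so the script no longer hard-codes LlavaForConditionalGeneration. Below is a minimal sketch of the dispatch, assuming Hub access; the LLaVA 1.5 checkpoint id is not part of this diff and is used only for illustration.

from transformers import AutoModelForVision2Seq

# The auto class reads the checkpoint's config.json and instantiates the
# matching architecture, so one code path serves both model families.
llava = AutoModelForVision2Seq.from_pretrained(
    "llava-hf/llava-1.5-7b-hf"  # resolves to LlavaForConditionalGeneration
)
llava_next = AutoModelForVision2Seq.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf"  # resolves to LlavaNextForConditionalGeneration
)
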
@@ -100,7 +103,7 @@ if __name__ == "__main__":
         model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
     )
-    model = LlavaForConditionalGeneration.from_pretrained(
+    model = AutoModelForVision2Seq.from_pretrained(
         model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
     )
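
For context, the loading path after this hunk reduces to the sketch below. The contents of model_kwargs are not visible in the diff, so the torch_dtype shown here is an assumption for illustration only.

import torch
from transformers import AutoModelForVision2Seq, AutoProcessor

model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
model_kwargs = {"torch_dtype": torch.bfloat16}  # assumption: not shown in this hunk

# Mirrors the script: processor and model are both resolved from the same id.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id, **model_kwargs)

Because the auto classes defer to the checkpoint config, no LLaVA-specific branch is needed anywhere else in the script.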