Support LLaVA-NeXT in Vision SFT (#1959)
* support llava next
* mention version for llava-next

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
commit 4788e5cda5 (parent 6cea2ef964), committed via GitHub
@@ -27,6 +27,9 @@ python examples/scripts/vsft_llava.py \
     --use_peft \
     --dataloader_num_workers 32 \
     --lora_target_modules=all-linear
+
+For LLaVA-NeXT, use: (requires transformers>=4.45)
+    --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf
 """
 
 import logging
@@ -50,7 +53,7 @@ from accelerate import Accelerator
 from datasets import load_dataset
 
 from tqdm.rich import tqdm
-from transformers import AutoProcessor, LlavaForConditionalGeneration
+from transformers import AutoModelForVision2Seq, AutoProcessor
 
 from trl import (
     ModelConfig,
@@ -100,7 +103,7 @@ if __name__ == "__main__":
         model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
     )
 
-    model = LlavaForConditionalGeneration.from_pretrained(
+    model = AutoModelForVision2Seq.from_pretrained(
         model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
     )
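Below is a minimal sketch (not part of the commit) of what the switch to AutoModelForVision2Seq buys: the auto class reads the checkpoint's config and resolves the concrete architecture, so the same loading code covers both the original LLaVA checkpoints and LLaVA-NeXT ones such as llava-hf/llava-v1.6-mistral-7b-hf (transformers>=4.45 assumed, per the docstring note); the dtype argument is illustrative.

# Sketch only: AutoModelForVision2Seq picks the concrete model class
# (LlavaForConditionalGeneration, LlavaNextForConditionalGeneration, ...)
# from the checkpoint's config, so one code path handles both families.
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor

model_id = "llava-hf/llava-v1.6-mistral-7b-hf"  # LLaVA-NeXT checkpoint named in the docstring

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # illustrative; the script assembles this via model_kwargs
)
print(type(model).__name__)  # e.g. LlavaNextForConditionalGeneration for a v1.6 checkpoint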