mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 10:03:51 +08:00
59 lines
1.5 KiB
Bash
#!/bin/bash
# This script runs a DPO example end-to-end on a tiny model using different possible configurations,
# but defaults to QLoRA + PEFT.
OUTPUT_DIR="test_dpo/"
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128

# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
    --load_in_4bit
"""

# Number of GPUs to use.
NUM_GPUS=2

if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
    EXTRA_ACCELERATE_ARGS=""
else
    EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
    # For DeepSpeed configs we need to pass the `--fp16` flag to comply with the configs exposed
    # in `examples/accelerate_configs`, since our runners do not support bf16 mixed precision training.
    if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
        EXTRA_TRAINING_ARGS="--fp16"
    else
        echo "Keeping QLoRA + PEFT"
    fi
fi

CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
    --num_processes $NUM_GPUS \
    --mixed_precision 'fp16' \
    `pwd`/trl/scripts/dpo.py \
    --model_name_or_path $MODEL_NAME \
    --dataset_name $DATASET_NAME \
    --output_dir $OUTPUT_DIR \
    --max_steps $MAX_STEPS \
    --per_device_train_batch_size $BATCH_SIZE \
    --max_length $SEQ_LEN \
    $EXTRA_TRAINING_ARGS
"""

echo "Starting program..."

{ # try
    echo $CMD
    eval "$CMD"
} || { # catch
    # save log for exception
    echo "Operation Failed!"
    exit 1
}
exit 0
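As a side note, the `TRL_ACCELERATE_CONFIG` handling in the script can be exercised in isolation; a minimal sketch, where the function name and the yaml path are illustrative and not part of the script itself:

```shell
#!/bin/bash
# Sketch of the config-selection logic used by the script above:
# an empty config keeps the accelerate defaults, a non-empty one
# is forwarded via `--config_file`.
# `select_accelerate_args` is a hypothetical helper name.
select_accelerate_args() {
    local config="$1"
    if [[ "$config" == "" ]]; then
        echo ""
    else
        echo "--config_file $config"
    fi
}

select_accelerate_args ""   # empty: accelerate launch uses its defaults
select_accelerate_args "examples/accelerate_configs/deepspeed_zero2.yaml"
```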