mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Compare commits
126 Commits
Author | SHA1 | Date | |
---|---|---|---|
f2c71771cc | |||
631c33cbb3 | |||
3f7ff60528 | |||
1705aebeba | |||
4e622a9033 | |||
eb2d5b2972 | |||
f976c6d234 | |||
abc7301bab | |||
6cfa5cfc81 | |||
a2aa0f0b09 | |||
304e208f77 | |||
4fe8b027f6 | |||
fb6ebb1e11 | |||
66078c7c01 | |||
58c0888996 | |||
486e7a4071 | |||
7630f877f9 | |||
4d862da181 | |||
22b4f548f4 | |||
4219cbfedc | |||
3bd02380c7 | |||
067db7553a | |||
93e85ed808 | |||
14e0d78807 | |||
b32656f726 | |||
9399bc113b | |||
11f122ad49 | |||
009c9a610b | |||
7712d42f8c | |||
7c2213b9e5 | |||
ddeebce176 | |||
cf68d871cf | |||
2a2676e7ec | |||
ca90cba351 | |||
4f97fb4a74 | |||
a46cd84a64 | |||
1f56bffdf8 | |||
1bfe0b8fcb | |||
0f13e51efa | |||
1e77d8aeb2 | |||
3b1911c2a9 | |||
851e7fe556 | |||
31b02d0cd0 | |||
9bc478ecbb | |||
29f162b86c | |||
6852097169 | |||
f12a1da74b | |||
ae87b3aefa | |||
3f7cee7643 | |||
ae8431bd50 | |||
66a976c6bd | |||
814930377c | |||
88685f2cd4 | |||
6f40f20233 | |||
036213bd85 | |||
6042596705 | |||
070c75ec54 | |||
b415224a4a | |||
9186710671 | |||
aa35fec099 | |||
737d771941 | |||
ef441ea028 | |||
af623aeba6 | |||
3843cfc32f | |||
9a71e67be9 | |||
09ca565b24 | |||
4edc688311 | |||
29d439a204 | |||
5760e5d3db | |||
a3c5b7178a | |||
222d275b8a | |||
09ca7607d5 | |||
1e68753216 | |||
1f59eeb9bb | |||
928d14445e | |||
3319993bd1 | |||
4fb3d0c860 | |||
bcccdeb6f9 | |||
ef209e311f | |||
341f6a6787 | |||
97b9fa212a | |||
a7d796c9a2 | |||
fa074e6a15 | |||
776939dcc4 | |||
163ca9f059 | |||
2eeb7b04cf | |||
9f8d0e48ad | |||
c9b7145c75 | |||
baf3c1c293 | |||
b181e401a7 | |||
26da9e80cb | |||
d6cc88ab2c | |||
7a95cc8696 | |||
d1715514de | |||
d116887ed4 | |||
a236c5750f | |||
4ae35afdd6 | |||
b21ed0ddbc | |||
384b868fe6 | |||
3267be0fcd | |||
dbcb2f0021 | |||
d5910b0ff5 | |||
104a02d207 | |||
ad597dbcb3 | |||
d57d0f9ca4 | |||
ec3d41b879 | |||
be32d304db | |||
dc53b8c6b0 | |||
20428c48ba | |||
6614b8aa6b | |||
df7b770da8 | |||
18a33ffcd3 | |||
911d3658e2 | |||
95ec8577df | |||
3539f3e3cd | |||
e451298b50 | |||
3efb484694 | |||
8f5b4923c8 | |||
e0dec27272 | |||
6ef785a6fb | |||
950ee2187d | |||
c1bb1f39f6 | |||
54babd9508 | |||
0c4edb750e | |||
17ec68d980 | |||
9be5680039 |
127
.github/workflows/docker-build.yml
vendored
Normal file
127
.github/workflows/docker-build.yml
vendored
Normal file
@ -0,0 +1,127 @@
|
||||
name: Build Docker images (scheduled)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
schedule:
|
||||
- cron: "0 1 * * *"
|
||||
|
||||
concurrency:
|
||||
group: docker-image-builds
|
||||
cancel-in-progress: false
|
||||
|
||||
env:
|
||||
CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
|
||||
|
||||
jobs:
|
||||
trl-latest:
|
||||
name: "Latest TRL GPU"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo ls -l /usr/local/lib/
|
||||
sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push GPU
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: ./docker/trl-latest-gpu
|
||||
push: true
|
||||
tags: huggingface/trl-latest-gpu
|
||||
|
||||
- name: Post to a Slack channel
|
||||
id: slack
|
||||
#uses: slackapi/slack-github-action@v1.25.0
|
||||
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
|
||||
with:
|
||||
# Slack channel id, channel name, or user id to post message.
|
||||
# See also: https://api.slack.com/methods/chat.postMessage#channels
|
||||
channel-id: ${{ env.CI_SLACK_CHANNEL }}
|
||||
# For posting a rich message using Block Kit
|
||||
payload: |
|
||||
{
|
||||
"text": "trl-latest-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
|
||||
"blocks": [
|
||||
{
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "mrkdwn",
|
||||
"text": "trl-latest-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
env:
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
trl-source:
|
||||
name: "Latest TRL + HF ecosystem from source"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo ls -l /usr/local/lib/
|
||||
sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push GPU
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: ./docker/trl-source-gpu
|
||||
push: true
|
||||
tags: huggingface/trl-source-gpu
|
||||
|
||||
- name: Post to a Slack channel
|
||||
id: slack
|
||||
#uses: slackapi/slack-github-action@v1.25.0
|
||||
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
|
||||
with:
|
||||
# Slack channel id, channel name, or user id to post message.
|
||||
# See also: https://api.slack.com/methods/chat.postMessage#channels
|
||||
channel-id: ${{ env.CI_SLACK_CHANNEL }}
|
||||
# For posting a rich message using Block Kit
|
||||
payload: |
|
||||
{
|
||||
"text": "trl-source-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
|
||||
"blocks": [
|
||||
{
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "mrkdwn",
|
||||
"text": "trl-source-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
env:
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
96
.github/workflows/slow-tests.yml
vendored
Normal file
96
.github/workflows/slow-tests.yml
vendored
Normal file
@ -0,0 +1,96 @@
|
||||
name: Slow tests (on push)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
# Run only when python files are modified
|
||||
- "trl/**.py"
|
||||
- "examples/**.py"
|
||||
env:
|
||||
RUN_SLOW: "yes"
|
||||
IS_GITHUB_CI: "1"
|
||||
SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
|
||||
jobs:
|
||||
run_all_tests_single_gpu:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
|
||||
runs-on: [self-hosted, single-gpu, nvidia-gpu, t4, ci]
|
||||
env:
|
||||
CUDA_VISIBLE_DEVICES: "0"
|
||||
TEST_TYPE: "single_gpu_${{ matrix.docker-image-name }}"
|
||||
container:
|
||||
image: ${{ matrix.docker-image-name }}
|
||||
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Pip install
|
||||
run: |
|
||||
source activate trl
|
||||
pip install -e ".[test]" --no-deps
|
||||
pip install pytest-reportlog parameterized
|
||||
|
||||
- name: Run slow SFT tests on single GPU
|
||||
if: always()
|
||||
run: |
|
||||
source activate trl
|
||||
make slow_tests
|
||||
|
||||
- name: Generate Report
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
|
||||
run_all_tests_multi_gpu:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
|
||||
runs-on: [self-hosted, multi-gpu, nvidia-gpu, t4, ci]
|
||||
env:
|
||||
CUDA_VISIBLE_DEVICES: "0,1"
|
||||
TEST_TYPE: "multi_gpu_${{ matrix.docker-image-name }}"
|
||||
container:
|
||||
image: ${{ matrix.docker-image-name }}
|
||||
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Pip install
|
||||
run: |
|
||||
source activate trl
|
||||
pip install -e ".[test]" --no-deps
|
||||
pip install pytest-reportlog parameterized
|
||||
|
||||
- name: Run slow SFT tests on Multi GPU
|
||||
if: always()
|
||||
run: |
|
||||
source activate trl
|
||||
make slow_tests
|
||||
|
||||
- name: Run end-to-end examples tests on multi GPU
|
||||
if: always()
|
||||
run: |
|
||||
source activate trl
|
||||
pip install deepspeed
|
||||
make test_examples
|
||||
|
||||
- name: Generate Reports
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
python scripts/log_example_reports.py --text_file_name temp_results_sft_tests.txt >> $GITHUB_STEP_SUMMARY
|
||||
python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt >> $GITHUB_STEP_SUMMARY
|
||||
rm *.txt
|
63
.github/workflows/tests-main.yml
vendored
Normal file
63
.github/workflows/tests-main.yml
vendored
Normal file
@ -0,0 +1,63 @@
|
||||
name: tests on transformers PEFT main
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
|
||||
env:
|
||||
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.9', '3.10', '3.11']
|
||||
os: ['ubuntu-latest', 'windows-latest']
|
||||
fail-fast: false
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
setup.py
|
||||
requirements.txt
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# install PEFT & transformers from source
|
||||
pip install -U git+https://github.com/huggingface/peft.git
|
||||
pip install -U git+https://github.com/huggingface/transformers.git
|
||||
# cpu version of pytorch
|
||||
pip install ".[test, diffusers]"
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
- name: Post to a Slack channel
|
||||
if: always()
|
||||
id: slack
|
||||
#uses: slackapi/slack-github-action@v1.25.0
|
||||
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
|
||||
with:
|
||||
# Slack channel id, channel name, or user id to post message.
|
||||
# See also: https://api.slack.com/methods/chat.postMessage#channels
|
||||
channel-id: ${{ env.CI_SLACK_CHANNEL }}
|
||||
# For posting a rich message using Block Kit
|
||||
payload: |
|
||||
{
|
||||
"text": "TRL CI on transformers/PEFT main: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
|
||||
"blocks": [
|
||||
{
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "mrkdwn",
|
||||
"text": "TRL CI on transformers/PEFT main: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
env:
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
11
.github/workflows/tests.yml
vendored
11
.github/workflows/tests.yml
vendored
@ -5,6 +5,13 @@ on:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
# Run only when relevant files are modified
|
||||
- "trl/**.py"
|
||||
- "examples/**.py"
|
||||
- "scripts/**.py"
|
||||
- ".github/**.yml"
|
||||
- "tests/**.py"
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
@ -47,7 +54,7 @@ jobs:
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# cpu version of pytorch
|
||||
pip install -e ".[test, peft, diffusers]"
|
||||
pip install ".[test, peft, diffusers]"
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
@ -72,4 +79,4 @@ jobs:
|
||||
pip install .[test]
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
make test
|
||||
|
@ -1,37 +1,10 @@
|
||||
repos:
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.12.0
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.2.0
|
||||
hooks:
|
||||
- id: isort
|
||||
args:
|
||||
- --profile=black
|
||||
- --skip-glob=wandb/**/*
|
||||
- --thirdparty=wandb
|
||||
- repo: https://github.com/myint/autoflake
|
||||
rev: v1.4
|
||||
hooks:
|
||||
- id: autoflake
|
||||
args:
|
||||
- -r
|
||||
- --exclude=wandb,__init__.py
|
||||
- --in-place
|
||||
- --remove-unused-variables
|
||||
- --remove-all-unused-imports
|
||||
- repo: https://github.com/python/black
|
||||
rev: 22.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
args:
|
||||
- --line-length=119
|
||||
- --target-version=py38
|
||||
- --exclude=wandb
|
||||
- repo: https://github.com/pycqa/flake8
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
args:
|
||||
- --ignore=E203,E501,W503,E128
|
||||
- --max-line-length=119
|
||||
- id: ruff
|
||||
args: [ --fix ]
|
||||
- id: ruff-format
|
||||
|
||||
# - repo: https://github.com/codespell-project/codespell
|
||||
# rev: v2.1.0
|
||||
|
@ -5,7 +5,7 @@
|
||||
Before you start contributing make sure you installed all the dev tools:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
make dev
|
||||
```
|
||||
|
||||
## Did you find a bug?
|
||||
|
@ -2,4 +2,4 @@ include settings.ini
|
||||
include LICENSE
|
||||
include CONTRIBUTING.md
|
||||
include README.md
|
||||
recursive-exclude * __pycache__
|
||||
recursive-exclude * __pycache__
|
30
Makefile
30
Makefile
@ -1,7 +1,16 @@
|
||||
.PHONY: test precommit benchmark_core benchmark_aux
|
||||
.PHONY: test precommit benchmark_core benchmark_aux common_tests slow_tests test_examples tests_gpu
|
||||
|
||||
check_dirs := examples tests trl
|
||||
|
||||
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
|
||||
COMMAND_FILES_PATH = `pwd`/commands
|
||||
|
||||
|
||||
dev:
|
||||
[ -L "$(pwd)/trl/commands/scripts" ] && unlink "$(pwd)/trl/commands/scripts" || true
|
||||
pip install -e ".[dev]"
|
||||
ln -s `pwd`/examples/scripts/ `pwd`/trl/commands
|
||||
|
||||
test:
|
||||
python -m pytest -n auto --dist=loadfile -s -v ./tests/
|
||||
|
||||
@ -13,3 +22,22 @@ benchmark_core:
|
||||
|
||||
benchmark_aux:
|
||||
bash ./benchmark/benchmark_aux.sh
|
||||
|
||||
tests_gpu:
|
||||
python -m pytest tests/test_* $(if $(IS_GITHUB_CI),--report-log "common_tests.log",)
|
||||
|
||||
slow_tests:
|
||||
python -m pytest tests/slow/test_* $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
|
||||
|
||||
test_examples:
|
||||
touch temp_results_sft_tests.txt
|
||||
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
|
||||
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \
|
||||
echo $$?','$${file} >> temp_results_sft_tests.txt; \
|
||||
done
|
||||
|
||||
touch temp_results_dpo_tests.txt
|
||||
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
|
||||
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \
|
||||
echo $$?','$${file} >> temp_results_dpo_tests.txt; \
|
||||
done
|
122
README.md
122
README.md
@ -3,7 +3,7 @@
|
||||
</div>
|
||||
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
> Full stack transformer language models with reinforcement learning.
|
||||
> Full stack library to fine-tune and align large language models.
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/huggingface/trl/blob/main/LICENSE">
|
||||
@ -20,58 +20,70 @@
|
||||
|
||||
## What is it?
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
|
||||
</div>
|
||||
The `trl` library is a full stack tool to fine-tune and align transformer language and diffusion models using methods such as Supervised Fine-tuning step (SFT), Reward Modeling (RM) and the Proximal Policy Optimization (PPO) as well as Direct Preference Optimization (DPO).
|
||||
|
||||
`trl` is a full stack library where we provide a set of tools to train transformer language models and stable diffusion models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point, most of decoder architectures and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for example code snippets and how to run these tools.
|
||||
|
||||
**Highlights:**
|
||||
|
||||
- [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer): A light and friendly wrapper around `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
|
||||
- [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): A light wrapper around `transformers` Trainer to easily fine-tune language models for human preferences (Reward Modeling).
|
||||
- [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer): A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
|
||||
- [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead): A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
|
||||
- [Examples](https://github.com/huggingface/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [Stack-Llama example](https://huggingface.co/blog/stackllama), etc.
|
||||
|
||||
## How PPO works
|
||||
Fine-tuning a language model via PPO consists of roughly three steps:
|
||||
|
||||
1. **Rollout**: The language model generates a response or continuation based on query which could be the start of a sentence.
|
||||
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
|
||||
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
|
||||
|
||||
This process is illustrated in the sketch below:
|
||||
The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library and thus allows to use any model architecture available there.
|
||||
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
|
||||
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
|
||||
</div>
|
||||
## Highlights
|
||||
|
||||
- **`Efficient and scalable`**:
|
||||
- [`accelerate`](https://github.com/huggingface/accelerate) is the backbone of `trl` which allows to scale model training from a single GPU to a large scale multi-node cluster with methods such as DDP and DeepSpeed.
|
||||
- [`PEFT`](https://github.com/huggingface/peft) is fully integrated and allows to train even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA.
|
||||
- [`unsloth`](https://github.com/unslothai/unsloth) is also integrated and allows to significantly speed up training with dedicated kernels.
|
||||
- **`CLI`**: With the [CLI](https://huggingface.co/docs/trl/clis) you can fine-tune and chat with LLMs without writing any code using a single command and a flexible config system.
|
||||
- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), and [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer).
|
||||
- **`AutoModels`**: The [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead) classes add an additional value head to the model which allows to train them with RL algorithms such as PPO.
|
||||
- **`Examples`**: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [StackLlama example](https://huggingface.co/blog/stackllama), etc. following the [examples](https://github.com/huggingface/trl/tree/main/examples).
|
||||
|
||||
## Installation
|
||||
|
||||
### Python package
|
||||
Install the library with pip:
|
||||
Install the library with `pip`:
|
||||
```bash
|
||||
pip install trl
|
||||
```
|
||||
|
||||
### From source
|
||||
If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:
|
||||
If you want to use the latest features before an official release you can install from source:
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/trl.git
|
||||
```
|
||||
|
||||
### Repository
|
||||
If you want to use the examples you can clone the repository with the following command:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
cd trl/
|
||||
pip install .
|
||||
```
|
||||
|
||||
If you wish to develop TRL, you should install in editable mode:
|
||||
## Command Line Interface (CLI)
|
||||
|
||||
You can use TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT), Direct Preference Optimization (DPO) and test your aligned model with the chat CLI:
|
||||
|
||||
**SFT:**
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
|
||||
```
|
||||
|
||||
**DPO:**
|
||||
|
||||
```bash
|
||||
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --output_dir opt-sft-hh-rlhf
|
||||
```
|
||||
|
||||
**Chat:**
|
||||
|
||||
```bash
|
||||
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
|
||||
```
|
||||
|
||||
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
|
||||
|
||||
## How to use
|
||||
|
||||
For more flexibility and control over the training you can use the dedicated trainer classes to fine-tune the model in Python.
|
||||
|
||||
### `SFTTrainer`
|
||||
|
||||
This is a basic example on how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
|
||||
@ -138,11 +150,10 @@ model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
model_ref = create_reference_model(model)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# initialize trainer
|
||||
ppo_config = PPOConfig(
|
||||
batch_size=1,
|
||||
)
|
||||
ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)
|
||||
|
||||
# encode a query
|
||||
query_txt = "This morning I went to the "
|
||||
@ -162,13 +173,50 @@ reward = [torch.tensor(1.0)]
|
||||
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
|
||||
```
|
||||
|
||||
### `DPOTrainer`
|
||||
|
||||
`DPOTrainer` is a trainer that uses [Direct Preference Optimization algorithm](https://arxiv.org/abs/2305.18290). This is a basic example on how to use the `DPOTrainer` from the library. The `DPOTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
|
||||
|
||||
```python
|
||||
# imports
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOTrainer
|
||||
|
||||
# load model and dataset - dataset needs to be in a specific format
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
...
|
||||
|
||||
# load trainer
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
# train
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
If you want to contribute to `trl` or customizing it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
cd trl/
|
||||
make dev
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
### Proximal Policy Optimisation
|
||||
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
|
||||
|
||||
### Language models
|
||||
The language models utilize the `transformers` library by 🤗 Hugging Face.
|
||||
### Direct Preference Optimization
|
||||
DPO is based on the original implementation of **"Direct Preference Optimization: Your Language Model is Secretly a Reward Model"** by E. Mitchell et al. \[[paper](), [code](https://github.com/eric-mitchell/direct-preference-optimization)]
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
|
@ -1,20 +1,5 @@
|
||||
#### Step 1: create a work directory:
|
||||
# this is necessary because another github action job will remove
|
||||
# the entire directory, which slurm depends on.
|
||||
# https://stackoverflow.com/questions/4632028/how-to-create-a-temporary-directory
|
||||
MY_SLURM_TMP_DIR=/fsx/costa/slurm_tmpdir
|
||||
mkdir -p $MY_SLURM_TMP_DIR
|
||||
WORK_DIR=`mktemp -d -p "$MY_SLURM_TMP_DIR"`
|
||||
cp -r "$PWD" "$WORK_DIR"
|
||||
cd "$WORK_DIR/$(basename "$PWD")"
|
||||
echo WORK_DIR: $WORK_DIR
|
||||
|
||||
#### Step 2: actual work starts:
|
||||
echo PATH is $PATH
|
||||
echo PYTHONPATH is $PYTHONPATH
|
||||
echo whcih python is $(which python)
|
||||
|
||||
export WANDB_ENTITY=huggingface
|
||||
export WANDB_PROJECT=trl
|
||||
bash $BENCHMARK_SCRIPT > output.txt
|
||||
|
||||
# Extract Job IDs into an array
|
||||
|
@ -1,6 +1,39 @@
|
||||
# hello world experiment
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --ppo_config.log_with wandb" \
|
||||
--command "python examples/scripts/ppo.py --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/dpo.py --model_name_or_path=gpt2 --per_device_train_batch_size 4 --max_steps 1000 --learning_rate 1e-3 --gradient_accumulation_steps 1 --logging_steps 10 --eval_steps 500 --output_dir="dpo_anthropic_hh" --optim adamw_torch --warmup_steps 150 --report_to wandb --bf16 --logging_first_step --no_remove_unused_columns" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/sft.py --model_name_or_path="facebook/opt-350m" --report_to="wandb" --learning_rate=1.41e-5 --per_device_train_batch_size=64 --gradient_accumulation_steps=16 --output_dir="sft_openassistant-guanaco" --logging_steps=1 --num_train_epochs=3 --max_steps=-1 --push_to_hub --gradient_checkpointing" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/reward_modeling.py --model_name_or_path=facebook/opt-350m --output_dir="reward_modeling_anthropic_hh" --per_device_train_batch_size=64 --num_train_epochs=1 --gradient_accumulation_steps=16 --gradient_checkpointing=True --learning_rate=1.41e-5 --report_to="wandb" --remove_unused_columns=False --optim="adamw_torch" --logging_steps=10 --evaluation_strategy="steps" --max_length=512" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
|
@ -9,7 +9,37 @@ python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/hello_world \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/ppo \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/rewards/accuracies&metrics=train/loss' \
|
||||
"gpt2$TAGS_STRING" \
|
||||
--env-ids dpo_anthropic_hh \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/dpo \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/loss&metrics=eval/accuracy&metrics=eval/loss' \
|
||||
"facebook/opt-350m$TAGS_STRING" \
|
||||
--env-ids reward_modeling_anthropic_hh \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/reward_modeling \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/loss' \
|
||||
"facebook/opt-350m$TAGS_STRING" \
|
||||
--env-ids sft_openassistant-guanaco \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/sft \
|
||||
--scan-history
|
||||
|
||||
python benchmark/upload_benchmark.py \
|
||||
|
@ -1,6 +1,6 @@
|
||||
# compound experiments: gpt2xl + grad_accu
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_gpt2xl_grad_accu --model_name gpt2-xl --mini_batch_size 16 --gradient_accumulation_steps 8 --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
@ -12,7 +12,7 @@ python benchmark/benchmark.py \
|
||||
|
||||
# compound experiments: Cerebras-GPT-6.7B + deepspeed zero2 + grad_accu
|
||||
python benchmark/benchmark.py \
|
||||
--command "accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/ppo.py --ppo_config.exp_name ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2 --ppo_config.batch_size 32 --ppo_config.mini_batch_size 32 --ppo_config.log_with wandb --ppo_config.model_name cerebras/Cerebras-GPT-6.7B --ppo_config.reward_model sentiment-analysis:cerebras/Cerebras-GPT-6.7B" \
|
||||
--command "accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/ppo.py --exp_name ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2 --batch_size 32 --mini_batch_size 32 --log_with wandb --model_name cerebras/Cerebras-GPT-6.7B --reward_model sentiment-analysis:cerebras/Cerebras-GPT-6.7B" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
|
@ -1,6 +1,6 @@
|
||||
## w/ and w/o gradient accumulation
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_step_grad_accu --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 128 --ppo_config.log_with wandb" \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_step_grad_accu --mini_batch_size 1 --gradient_accumulation_steps 128 --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
@ -12,7 +12,7 @@ python benchmark/benchmark.py \
|
||||
|
||||
## w/ different models (gpt2, gpt2-xl, falcon, llama2)
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_gpt2 --ppo_config.log_with wandb" \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_gpt2 --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
@ -22,7 +22,7 @@ python benchmark/benchmark.py \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_falcon_rw_1b --ppo_config.model_name tiiuae/falcon-rw-1b --ppo_config.log_with wandb" \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_falcon_rw_1b --model_name tiiuae/falcon-rw-1b --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
@ -35,7 +35,7 @@ python benchmark/benchmark.py \
|
||||
|
||||
## w/ and w/o PEFT
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_peft --use_peft --ppo_config.log_with wandb" \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_peft --use_peft --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=trl
|
||||
#SBATCH --partition=production-cluster
|
||||
#SBATCH --partition=hopper-cpu
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --output=slurm/logs/%x_%j.out
|
||||
|
||||
|
3
benchmark/regression_test.sh
Normal file
3
benchmark/regression_test.sh
Normal file
@ -0,0 +1,3 @@
|
||||
BENCHMARK_SCRIPT="benchmark/benchmark_level1.sh" \
|
||||
BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level1_plot.sh" \
|
||||
bash benchmark/benchmark_and_report.sh
|
@ -1,16 +1,19 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=trl
|
||||
#SBATCH --partition=production-cluster
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --gpus-per-task={{gpus_per_task}}
|
||||
#SBATCH --cpus-per-gpu={{cpus_per_gpu}}
|
||||
#SBATCH --ntasks={{ntasks}}
|
||||
#SBATCH --output=slurm/logs/%x_%j.out
|
||||
#SBATCH --array={{array}}
|
||||
#SBATCH --exclude=ip-26-0-156-239,ip-26-0-148-151,ip-26-0-146-212,ip-26-0-145-137,ip-26-0-146-249,ip-26-0-146-149,ip-26-0-147-233,ip-26-0-145-154,ip-26-0-144-35,ip-26-0-144-189,ip-26-0-146-183,ip-26-0-147-120,ip-26-0-144-95,ip-26-0-145-193
|
||||
##SBATCH --exclude=ip-26-0-149-199
|
||||
|
||||
module load cuda/12.1
|
||||
|
||||
{{nodes}}
|
||||
|
||||
seeds={{seeds}}
|
||||
seed=${seeds[$SLURM_ARRAY_TASK_ID % {{len_seeds}}]}
|
||||
|
||||
echo "Running task $SLURM_ARRAY_TASK_ID with seed: $seed"
|
||||
srun {{command}} --ppo_config.seed $seed
|
||||
srun {{command}} --seed $seed
|
||||
|
58
commands/run_dpo.sh
Normal file
58
commands/run_dpo.sh
Normal file
@ -0,0 +1,58 @@
|
||||
#!/bin/bash
|
||||
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
|
||||
# but defaults to QLoRA + PEFT
|
||||
OUTPUT_DIR="test_dpo/"
|
||||
MODEL_NAME="HuggingFaceM4/tiny-random-LlamaForCausalLM"
|
||||
DATASET_NAME="trl-internal-testing/Anthropic-hh-rlhf-processed"
|
||||
MAX_STEPS=5
|
||||
BATCH_SIZE=2
|
||||
SEQ_LEN=128
|
||||
|
||||
# Handle extra arguments in case one passes accelerate configs.
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
EXTRA_TRAINING_ARGS="""--use_peft \
|
||||
--load_in_4bit
|
||||
"""
|
||||
|
||||
# This is a hack to get the number of available GPUs
|
||||
NUM_GPUS=2
|
||||
|
||||
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
else
|
||||
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
|
||||
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
|
||||
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
|
||||
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
|
||||
EXTRA_TRAINING_ARGS="--fp16"
|
||||
else
|
||||
echo "Keeping QLoRA + PEFT"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
CMD="""
|
||||
accelerate launch $EXTRA_ACCELERATE_ARGS \
|
||||
--num_processes $NUM_GPUS \
|
||||
--mixed_precision 'fp16' \
|
||||
`pwd`/examples/scripts/dpo.py \
|
||||
--model_name_or_path $MODEL_NAME \
|
||||
--dataset_name $DATASET_NAME \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_steps $MAX_STEPS \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--max_length $SEQ_LEN \
|
||||
$EXTRA_TRAINING_ARGS
|
||||
"""
|
||||
|
||||
echo "Starting program..."
|
||||
|
||||
{ # try
|
||||
echo $CMD
|
||||
eval "$CMD"
|
||||
} || { # catch
|
||||
# save log for exception
|
||||
echo "Operation Failed!"
|
||||
exit 1
|
||||
}
|
||||
exit 0
|
59
commands/run_sft.sh
Normal file
59
commands/run_sft.sh
Normal file
@ -0,0 +1,59 @@
|
||||
#!/bin/bash
|
||||
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
|
||||
# but defaults to QLoRA + PEFT
|
||||
OUTPUT_DIR="test_sft/"
|
||||
MODEL_NAME="HuggingFaceM4/tiny-random-LlamaForCausalLM"
|
||||
DATASET_NAME="imdb"
|
||||
MAX_STEPS=5
|
||||
BATCH_SIZE=2
|
||||
SEQ_LEN=128
|
||||
|
||||
|
||||
# Handle extra arguments in case one passes accelerate configs.
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
EXTRA_TRAINING_ARGS="""--use_peft \
|
||||
--load_in_4bit
|
||||
"""
|
||||
|
||||
# Set your number of GPUs here
|
||||
NUM_GPUS=2
|
||||
|
||||
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
else
|
||||
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
|
||||
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
|
||||
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
|
||||
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
|
||||
EXTRA_TRAINING_ARGS="--fp16"
|
||||
else
|
||||
echo "Keeping QLoRA + PEFT"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
CMD="""
|
||||
accelerate launch $EXTRA_ACCELERATE_ARGS \
|
||||
--num_processes $NUM_GPUS \
|
||||
--mixed_precision 'fp16' \
|
||||
`pwd`/examples/scripts/sft.py \
|
||||
--model_name $MODEL_NAME \
|
||||
--dataset_name $DATASET_NAME \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_steps $MAX_STEPS \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--max_seq_length $SEQ_LEN \
|
||||
$EXTRA_TRAINING_ARGS
|
||||
"""
|
||||
|
||||
echo "Starting program..."
|
||||
|
||||
{ # try
|
||||
echo $CMD
|
||||
eval "$CMD"
|
||||
} || { # catch
|
||||
# save log for exception
|
||||
echo "Operation Failed!"
|
||||
exit 1
|
||||
}
|
||||
exit 0
|
66
docker/trl-latest-gpu/Dockerfile
Normal file
66
docker/trl-latest-gpu/Dockerfile
Normal file
@ -0,0 +1,66 @@
|
||||
# Builds GPU docker image of PyTorch
|
||||
# Uses multi-staged approach to reduce size
|
||||
# Stage 1
|
||||
# Use base conda image to reduce time
|
||||
FROM continuumio/miniconda3:latest AS compile-image
|
||||
# Specify py version
|
||||
ENV PYTHON_VERSION=3.10
|
||||
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget software-properties-common git-lfs && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Install audio-related libraries
|
||||
RUN apt-get update && \
|
||||
apt install -y ffmpeg
|
||||
|
||||
RUN apt install -y libsndfile1-dev
|
||||
RUN git lfs install
|
||||
|
||||
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip
|
||||
|
||||
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
# We don't install pytorch here yet since CUDA isn't available
|
||||
# instead we use the direct torch wheel
|
||||
ENV PATH /opt/conda/envs/trl/bin:$PATH
|
||||
# Activate our bash shell
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
# Stage 2
|
||||
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
|
||||
COPY --from=compile-image /opt/conda /opt/conda
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
|
||||
|
||||
# Install apt libs
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Activate the conda env and install transformers + accelerate from source
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install -U --no-cache-dir \
|
||||
librosa \
|
||||
"soundfile>=0.12.1" \
|
||||
scipy \
|
||||
transformers \
|
||||
accelerate \
|
||||
peft \
|
||||
trl[test]@git+https://github.com/huggingface/trl
|
||||
|
||||
RUN source activate trl && \
|
||||
pip freeze | grep trl
|
||||
|
||||
RUN echo "source activate trl" >> ~/.profile
|
||||
|
||||
# Activate the virtualenv
|
||||
CMD ["/bin/bash"]
|
66
docker/trl-source-gpu/Dockerfile
Normal file
66
docker/trl-source-gpu/Dockerfile
Normal file
@ -0,0 +1,66 @@
|
||||
# Builds GPU docker image of PyTorch
|
||||
# Uses multi-staged approach to reduce size
|
||||
# Stage 1
|
||||
# Use base conda image to reduce time
|
||||
FROM continuumio/miniconda3:latest AS compile-image
|
||||
# Specify py version
|
||||
ENV PYTHON_VERSION=3.10
|
||||
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget software-properties-common git-lfs && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Install audio-related libraries
|
||||
RUN apt-get update && \
|
||||
apt install -y ffmpeg
|
||||
|
||||
RUN apt install -y libsndfile1-dev
|
||||
RUN git lfs install
|
||||
|
||||
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip
|
||||
|
||||
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
# We don't install pytorch here yet since CUDA isn't available
|
||||
# instead we use the direct torch wheel
|
||||
ENV PATH /opt/conda/envs/trl/bin:$PATH
|
||||
# Activate our bash shell
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
# Stage 2
|
||||
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
|
||||
COPY --from=compile-image /opt/conda /opt/conda
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
|
||||
|
||||
# Install apt libs
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Activate the conda env and install transformers + accelerate from source
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install -U --no-cache-dir \
|
||||
librosa \
|
||||
"soundfile>=0.12.1" \
|
||||
scipy \
|
||||
git+https://github.com/huggingface/transformers \
|
||||
git+https://github.com/huggingface/accelerate \
|
||||
git+https://github.com/huggingface/peft \
|
||||
trl[test]@git+https://github.com/huggingface/trl
|
||||
|
||||
RUN source activate trl && \
|
||||
pip freeze | grep transformers
|
||||
|
||||
RUN echo "source activate trl" >> ~/.profile
|
||||
|
||||
# Activate the virtualenv
|
||||
CMD ["/bin/bash"]
|
@ -5,6 +5,8 @@
|
||||
title: Quickstart
|
||||
- local: installation
|
||||
title: Installation
|
||||
- local: clis
|
||||
title: Get started with Command Line Interfaces (CLIs)
|
||||
- local: how_to_train
|
||||
title: PPO Training FAQ
|
||||
- local: use_model
|
||||
@ -29,6 +31,8 @@
|
||||
title: Best of N Sampling
|
||||
- local: dpo_trainer
|
||||
title: DPO Trainer
|
||||
- local: kto_trainer
|
||||
title: KTO Trainer
|
||||
- local: ddpo_trainer
|
||||
title: Denoising Diffusion Policy Optimization
|
||||
- local: iterative_sft_trainer
|
||||
|
109
docs/source/clis.mdx
Normal file
109
docs/source/clis.mdx
Normal file
@ -0,0 +1,109 @@
|
||||
# Command Line Interfaces (CLIs)
|
||||
|
||||
You can use TRL to fine-tune your Language Model with Supervised Fine-Tuning (SFT) or Direct Policy Optimization (DPO) or even chat with your model using the TRL CLIs.
|
||||
|
||||
Currently supported CLIs are:
|
||||
|
||||
- `trl sft`: fine-tune a LLM on a text/instruction dataset
|
||||
- `trl dpo`: fine-tune a LLM with DPO on a preference dataset
|
||||
- `trl chat`: quickly spin up a LLM fine-tuned for chatting
|
||||
|
||||
## Fine-tuning with the CLI
|
||||
|
||||
Before getting started, pick up a Language Model from Hugging Face Hub. Supported models can be found with the filter "text-generation" within models. Also make sure to pick up a relevant dataset for your task.
|
||||
|
||||
Before using the `sft` or `dpo` commands make sure to run:
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
and pick up the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.
|
||||
|
||||
We also recommend you passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with `trl sft` command.
|
||||
|
||||
```yaml
|
||||
model_name_or_path:
|
||||
HuggingFaceM4/tiny-random-LlamaForCausalLM
|
||||
dataset_name:
|
||||
imdb
|
||||
dataset_text_field:
|
||||
text
|
||||
report_to:
|
||||
none
|
||||
learning_rate:
|
||||
0.0001
|
||||
lr_scheduler_type:
|
||||
cosine
|
||||
```
|
||||
|
||||
Save that config in a `.yaml` and get directly started ! Note you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g.:
|
||||
|
||||
```bash
|
||||
trl sft --config example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
|
||||
```
|
||||
|
||||
Will force-use `cosine_with_restarts` for `lr_scheduler_type`.
|
||||
|
||||
### Supported Arguments
|
||||
|
||||
We do support all arguments from `transformers.TrainingArguments`, for loading your model, we support all arguments from `~trl.ModelConfig`:
|
||||
|
||||
[[autodoc]] ModelConfig
|
||||
|
||||
You can pass any of these arguments either to the CLI or the YAML file.
|
||||
|
||||
### Supervised Fine-tuning (SFT)
|
||||
|
||||
Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
|
||||
|
||||
```bash
|
||||
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
|
||||
```
|
||||
|
||||
The SFT CLI is based on the `examples/scripts/sft.py` script.
|
||||
|
||||
### Direct Policy Optimization (DPO)
|
||||
|
||||
First, follow the basic instructions above and run `trl dpo --output_dir <output_dir> <*args>`. Make sure to process your DPO dataset in the TRL format as follows:
|
||||
|
||||
1- Make sure to pre-tokenize the dataset using chat templates:
|
||||
|
||||
```bash
|
||||
python examples/datasets/tokenize_ds.py --model gpt2 --dataset yourdataset
|
||||
```
|
||||
|
||||
You might need to adapt the `examples/datasets/tokenize_ds.py` to use yout chat template
|
||||
|
||||
2- Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
|
||||
|
||||
```bash
|
||||
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
|
||||
```
|
||||
|
||||
Once your dataset being pushed, run the dpo CLI as follows:
|
||||
|
||||
```bash
|
||||
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --output_dir opt-sft-hh-rlhf
|
||||
```
|
||||
|
||||
The SFT CLI is based on the `examples/scripts/dpo.py` script.
|
||||
|
||||
## Chat interface
|
||||
|
||||
The chat CLI lets you quickly load the model and talk to it. Simply run the following:
|
||||
|
||||
```bash
|
||||
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
|
||||
```
|
||||
|
||||
Note that the chat interface relies on the chat template of the tokenizer to format the inputs for the model. Make sure your tokenizer has a chat template defined.
|
||||
|
||||
Besides talking to the model there are a few commands you can use:
|
||||
|
||||
- **clear**: clears the current conversation and start a new one
|
||||
- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
|
||||
- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
|
||||
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
|
||||
- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
|
||||
- **exit**: closes the interface
|
||||
|
||||
The default examples are defined in `examples/scripts/config/default_chat_config.yaml` but you can pass your own with `--config CONIG_FILE` where you can also specify the default generation parameters.
|
@ -2,9 +2,24 @@
|
||||
|
||||
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py).
|
||||
|
||||
|
||||
The first step as always is to train your SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
|
||||
|
||||
## How DPO works
|
||||
|
||||
Fine-tuning a language model via DPO consists of two steps and is easier than PPO:
|
||||
|
||||
1. **Data collection**: Gather a preference dataset with positive and negative selected pairs of generation, given a prompt.
|
||||
2. **Optimization**: Maximize the log-likelihood of the DPO loss directly.
|
||||
|
||||
DPO-compatible datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo).
|
||||
|
||||
This process is illustrated in the sketch below (from [figure 1 of the original paper](https://arxiv.org/pdf/2305.18290.pdf)):
|
||||
|
||||
<img width="835" alt="Screenshot 2024-03-19 at 12 39 41" src="https://github.com/huggingface/trl/assets/49240599/9150fac6-3d88-4ca2-8ec6-2a6f3473216d">
|
||||
|
||||
Read more about DPO algorithm in the [original paper](https://arxiv.org/pdf/2305.18290.pdf).
|
||||
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The DPO trainer expects a very specific format for the dataset. Since the model will be trained to directly optimize the preference of which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
|
||||
@ -63,7 +78,7 @@ The DPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that
|
||||
For a detailed example have a look at the `examples/scripts/dpo.py` script. At a high level we need to initialize the `DPOTrainer` with a `model` we wish to train, a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response, the `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
|
||||
|
||||
```py
|
||||
dpo_trainer = DPOTrainer(
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
@ -86,11 +101,11 @@ Given the preference data, we can fit a binary classifier according to the Bradl
|
||||
|
||||
The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `DPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin.
|
||||
|
||||
The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer.
|
||||
The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only).
|
||||
|
||||
The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability that can be passed to the `DPOTrainer` via `label_smoothing` argument (between 0 and 0.5) and then a conservative DPO loss is used. Use the `loss_type="cdpo"` argument to the trainer to use it.
|
||||
|
||||
The [KTO](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf) loss is derived to directly maximize the utility of LLM generations instead of the log-likelihood of preferences. Thus the dataset are not necessarily preferences but rather desirable vs undesirable completions. For paired preference data as required by the `DPOTrainer`, use the `loss_type="kto_pair"` argument to the trainer to utilize this loss, while for the more general case of desired and undesirable data, use the as of yet unimplemented `KTOTrainer`.
|
||||
The [KTO](https://arxiv.org/abs/2402.01306) authors directly maximize the utility of LLM generations instead of the log-likelihood of preferences. To use preference data with KTO, we recommend breaking up the n preferences into 2n examples and using [`KTOTrainer`](kto_trainer) (i.e., treating the data like an unpaired feedback dataset). Although it is possible to pass in `loss_type="kto_pair"` into DPOTrainer, this is a highly simplified version of KTO that we *do not recommend* in most cases. Please use [`KTOTrainer`](kto_trainer) when possible.
|
||||
|
||||
## Logging
|
||||
|
||||
@ -101,6 +116,120 @@ While training and evaluating we record the following reward metrics:
|
||||
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
|
||||
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
|
||||
## Accelerate DPO fine-tuning using `unsloth`
|
||||
|
||||
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks for DPO listed below:
|
||||
|
||||
| GPU | Model | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|
||||
|----------|-----------------|-----------|------|------------------------|-----------------|----------------|
|
||||
| A100 40G | Zephyr 7b | Ultra Chat| 1x | 1.24x | **1.88x** | -11.6% |
|
||||
| Tesla T4 | Zephyr 7b | Ultra Chat| 1x | 1.09x | **1.55x** | -18.6% |
|
||||
|
||||
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import TrainingArguments
|
||||
from trl import DPOTrainer
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number.
|
||||
|
||||
# Load model
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name = "unsloth/zephyr-sft",
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
|
||||
load_in_4bit = True, # Use 4bit quantization to reduce memory usage. Can be False.
|
||||
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
|
||||
)
|
||||
|
||||
# Do model patching and add fast LoRA weights
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r = 16,
|
||||
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",],
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0, # Dropout = 0 is currently optimized
|
||||
bias = "none", # Bias = "none" is currently optimized
|
||||
use_gradient_checkpointing = True,
|
||||
random_state = 3407,
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(output_dir="./output")
|
||||
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
model_ref=None,
|
||||
args=training_args,
|
||||
beta=0.1,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
dpo_trainer.train()
|
||||
```
|
||||
|
||||
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
|
||||
|
||||
## Reference model considerations with PEFT
|
||||
|
||||
You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA.
|
||||
|
||||
1. Simply create two instances of the model, each loading your adapter - works fine but is very inefficient.
|
||||
2. Merge the adapter into the base model, create another adapter on top, then leave the `model_ref` param null, in which case DPOTrainer will unload the adapter for reference inference - efficient, but has potential downsides discussed below.
|
||||
3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter - slightly less efficient compared to 2 (~adapter size VRAM overhead), but avoids the pitfalls.
|
||||
|
||||
### Downsides to merging QLoRA before DPO (approach 2)
|
||||
|
||||
As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py).
|
||||
|
||||
However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand).
|
||||
|
||||
### Using option 3 - load the adapter twice
|
||||
|
||||
To avoid the downsides with option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in DPOTrainer.
|
||||
|
||||
For example:
|
||||
```python
|
||||
# Load the base model.
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
llm_int8_threshold=6.0,
|
||||
llm_int8_has_fp16_weight=False,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"mistralai/mixtral-8x7b-v0.1",
|
||||
load_in_4bit=True,
|
||||
quantization_config=bnb_config,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
# Load the adapter.
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
"/path/to/peft",
|
||||
is_trainable=True,
|
||||
adapter_name="train",
|
||||
)
|
||||
# Load the adapter a second time, with a different name, which will be our reference model.
|
||||
model.load_adapter("/path/to/peft", adapter_name="reference")
|
||||
|
||||
# Initialize the trainer, without a ref_model param.
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
...
|
||||
model_adapter_name="train",
|
||||
ref_adapter_name="reference",
|
||||
)
|
||||
```
|
||||
|
||||
## DPOTrainer
|
||||
|
||||
[[autodoc]] DPOTrainer
|
||||
[[autodoc]] DPOTrainer
|
||||
|
@ -30,7 +30,6 @@ If you generate text by purely sampling from the model distribution things work
|
||||
|
||||
- **top-k sampling**: the model can smooth out the probability distribution causing the top-k tokens having a smaller probability than those of the reference model but they still are selected
|
||||
- **min_length**: this ignores the EOS token until `min_length` is reached. thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached
|
||||
- **min_length**: this ignores the EOS token until `min_length` is reached, thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached
|
||||
|
||||
These are just a few examples. Why is negative KL an issue? The total reward `R` is computed `R = r - beta * KL` so if the model can learn how to drive KL-divergence negative it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than actually learning the reward function. In addition the KL can become arbitrarily small thus the actual reward can be very small compared to it.
|
||||
|
||||
|
93
docs/source/kto_trainer.mdx
Normal file
93
docs/source/kto_trainer.mdx
Normal file
@ -0,0 +1,93 @@
|
||||
# KTO Trainer
|
||||
|
||||
TRL supports the Kahneman-Tversky Optimization (KTO) Trainer for aligning language models with binary feedback data (e.g., upvote/downvote), as described in the [paper](https://arxiv.org/abs/2402.01306) by Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, and Douwe Kiela.
|
||||
For a full example have a look at [`examples/scripts/kto.py`].
|
||||
|
||||
Depending on how good your base model is, you may or may not need to do SFT before KTO.
|
||||
This is different from standard RLHF and DPO, which always require SFT.
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The KTO trainer expects a very specific format for the dataset as it does not require pairwise preferences. Since the model will be trained to directly optimize examples that consist of a prompt, model completion, and a label to indicate whether the completion is "good" or "bad", we expect a dataset with the following columns:
|
||||
|
||||
- `prompt`
|
||||
- `completion`
|
||||
- `label`
|
||||
|
||||
for example:
|
||||
|
||||
```
|
||||
kto_dataset_dict = {
|
||||
"prompt": [
|
||||
"Hey, hello",
|
||||
"How are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"completion": [
|
||||
"hi nice to meet you",
|
||||
"leave me alone",
|
||||
"I don't have a name",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"C++",
|
||||
"Java",
|
||||
],
|
||||
"label": [
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
|
||||
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
|
||||
|
||||
## Expected model format
|
||||
The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `KTOTrainer`
|
||||
|
||||
For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
|
||||
|
||||
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
|
||||
|
||||
The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
|
||||
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` * number of positives) to (`undesirable_weight` * number of negatives) is in the range 1:1 to 4:3.
|
||||
|
||||
```py
|
||||
training_args = KTOConfig(
|
||||
beta=0.1,
|
||||
desirable_weight=1.0,
|
||||
undesirable_weight=1.0,
|
||||
)
|
||||
|
||||
kto_trainer = KTOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
```
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
kto_trainer.train()
|
||||
```
|
||||
|
||||
## KTOTrainer
|
||||
|
||||
[[autodoc]] KTOTrainer
|
||||
|
||||
## KTOConfig
|
||||
|
||||
[[autodoc]] KTOConfig
|
@ -71,7 +71,7 @@ The `trl` library is powered by `accelerate`. As such it is best to configure an
|
||||
|
||||
```bash
|
||||
accelerate config # will prompt you to define the training configuration
|
||||
accelerate launch scripts/gpt2-sentiment_peft.py # launches training
|
||||
accelerate launch examples/scripts/ppo.py --use_peft # launch`es training
|
||||
```
|
||||
|
||||
## Using `trl` + `peft` and Data Parallelism
|
||||
@ -140,5 +140,5 @@ python PATH_TO_SCRIPT
|
||||
You can easily fine-tune Llama2 model using `SFTTrainer` and the official script! For example to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB):
|
||||
|
||||
```bash
|
||||
python examples/scripts/sft.py --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --batch_size 4 --gradient_accumulation_steps 2
|
||||
python examples/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2
|
||||
```
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Multi Adapter RL (MARL) - a single base model for everything
|
||||
|
||||
Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not tested the convergence of the approach. We encourage the community to let us know if they potentially face into any issue.
|
||||
Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not test the convergence of the approach. We encourage the community to let us know if they potentially face issues.
|
||||
|
||||
## Requirements
|
||||
|
||||
@ -48,7 +48,7 @@ trainer = PPOTrainer(
|
||||
|
||||
...
|
||||
```
|
||||
Then inside your PPO training loop, call the `compute_reward_score` method by accessing to the `model` attribute from `PPOTrainer`.
|
||||
Then inside your PPO training loop, call the `compute_reward_score` method by accessing the `model` attribute from `PPOTrainer`.
|
||||
|
||||
```python
|
||||
rewards = trainer.model.compute_reward_score(**inputs)
|
||||
@ -58,8 +58,8 @@ rewards = trainer.model.compute_reward_score(**inputs)
|
||||
|
||||
### Control on the adapter name
|
||||
|
||||
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is to train multiple adapters on the same base model to fine-tune on different policies.
|
||||
In this case, you want to have a control on the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`.
|
||||
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is train multiple adapters on the same base model to fine-tune on different policies.
|
||||
In this case, you want to be able to control the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`.
|
||||
|
||||
```python
|
||||
adapter_name_policy_1 = "policy_1"
|
||||
@ -97,4 +97,4 @@ trainer = PPOTrainer(
|
||||
...
|
||||
)
|
||||
...
|
||||
```
|
||||
```
|
||||
|
@ -4,6 +4,21 @@ TRL supports the [PPO](https://arxiv.org/abs/1707.06347) Trainer for training la
|
||||
|
||||
The first step is to train your SFT model (see the [SFTTrainer](sft_trainer)), to ensure the data we train on is in-distribution for the PPO algorithm. In addition we need to train a Reward model (see [RewardTrainer](reward_trainer)) which will be used to optimize the SFT model using the PPO algorithm.
|
||||
|
||||
## How PPO works
|
||||
|
||||
Fine-tuning a language model via PPO consists of roughly three steps:
|
||||
|
||||
1. **Rollout**: The language model generates a response or continuation based on query which could be the start of a sentence.
|
||||
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
|
||||
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
|
||||
|
||||
This process is illustrated in the sketch below:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
|
||||
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
|
||||
</div>
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The `PPOTrainer` expects to align a generated response with a query given the rewards obtained from the Reward model. During each step of the PPO algorithm we sample a batch of prompts from the dataset, we then use these prompts to generate the a responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated response. Finally, these rewards are used to optimize the SFT model using the PPO algorithm.
|
||||
@ -90,7 +105,7 @@ from trl import PPOTrainer
|
||||
ppo_trainer = PPOTrainer(
|
||||
model=model,
|
||||
config=config,
|
||||
train_dataset=train_dataset,
|
||||
dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
```
|
||||
@ -115,22 +130,22 @@ We can then loop over all examples in the dataset and generate a response for ea
|
||||
|
||||
```py
|
||||
from tqdm import tqdm
|
||||
|
||||
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
#### Get response from SFTModel
|
||||
response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
|
||||
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
|
||||
|
||||
#### Compute reward score
|
||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
||||
pipe_outputs = reward_model(texts)
|
||||
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
|
||||
|
||||
#### Run PPO step
|
||||
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
|
||||
ppo_trainer.log_stats(stats, batch, rewards)
|
||||
for epoch in tqdm(range(ppo_trainer.config.ppo_epochs), "epoch: "):
|
||||
for batch in tqdm(ppo_trainer.dataloader):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
#### Get response from SFTModel
|
||||
response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
|
||||
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
|
||||
|
||||
#### Compute reward score
|
||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
||||
pipe_outputs = reward_model(texts)
|
||||
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
|
||||
|
||||
#### Run PPO step
|
||||
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
|
||||
ppo_trainer.log_stats(stats, batch, rewards)
|
||||
|
||||
#### Save model
|
||||
ppo_trainer.save_model("my_ppo_model")
|
||||
@ -148,4 +163,4 @@ While training and evaluating we log the following metrics:
|
||||
|
||||
[[autodoc]] PPOTrainer
|
||||
|
||||
[[autodoc]] PPOConfig
|
||||
[[autodoc]] PPOConfig
|
||||
|
@ -30,7 +30,7 @@ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# 2. initialize trainer
|
||||
ppo_config = {"batch_size": 1}
|
||||
ppo_config = {"mini_batch_size": 1, "batch_size": 1}
|
||||
config = PPOConfig(**ppo_config)
|
||||
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
|
||||
|
||||
|
@ -25,7 +25,7 @@ accelerate launch examples/scripts/ppo.py # launches training
|
||||
# 3. get help text and documentation
|
||||
python examples/scripts/ppo.py --help
|
||||
# 4. configure logging with wandb and, say, mini_batch_size=1 and gradient_accumulation_steps=16
|
||||
python examples/scripts/ppo.py --ppo_config.log_with wandb --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 16
|
||||
python examples/scripts/ppo.py --log_with wandb --mini_batch_size 1 --gradient_accumulation_steps 16
|
||||
```
|
||||
|
||||
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
|
||||
@ -42,7 +42,7 @@ Below are some benchmark results for `examples/scripts/ppo.py`. To reproduce loc
|
||||
|
||||
```bash
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --ppo_config.log_with wandb" \
|
||||
--command "python examples/scripts/ppo.py --log_with wandb" \
|
||||
--num-seeds 5 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
@ -61,7 +61,7 @@ python benchmark/benchmark.py \
|
||||
|
||||
```bash
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_step_grad_accu --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 128 --ppo_config.log_with wandb" \
|
||||
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_step_grad_accu --mini_batch_size 1 --gradient_accumulation_steps 128 --log_with wandb" \
|
||||
--num-seeds 5 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
@ -79,7 +79,7 @@ python benchmark/benchmark.py \
|
||||
|
||||
```bash
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_gpt2 --ppo_config.log_with wandb" \
|
||||
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_gpt2 --log_with wandb" \
|
||||
--num-seeds 5 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
@ -89,7 +89,7 @@ python benchmark/benchmark.py \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \
|
||||
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_gpt2xl_grad_accu --model_name gpt2-xl --mini_batch_size 16 --gradient_accumulation_steps 8 --log_with wandb" \
|
||||
--num-seeds 5 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
@ -99,7 +99,7 @@ python benchmark/benchmark.py \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_falcon_rw_1b --ppo_config.model_name tiiuae/falcon-rw-1b --ppo_config.log_with wandb" \
|
||||
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_falcon_rw_1b --model_name tiiuae/falcon-rw-1b --log_with wandb" \
|
||||
--num-seeds 5 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
@ -116,7 +116,7 @@ python benchmark/benchmark.py \
|
||||
|
||||
```
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_peft --use_peft --ppo_config.log_with wandb" \
|
||||
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_peft --use_peft --log_with wandb" \
|
||||
--num-seeds 5 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Supervised Fine-tuning Trainer
|
||||
# Supervised Fine-tuning Trainer
|
||||
|
||||
Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with few lines of code on your dataset.
|
||||
|
||||
@ -6,7 +6,7 @@ Check out a complete flexible example at [`examples/scripts/sft.py`](https://git
|
||||
|
||||
## Quickstart
|
||||
|
||||
If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using [`SFTTrainer`] from TRL. Let us assume your dataset is `imdb`, the text you want to predict is inside the `text` field of the dataset, and you want to fine-tune the `facebook/opt-350m` model.
|
||||
If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using [`SFTTrainer`] from TRL. Let us assume your dataset is `imdb`, the text you want to predict is inside the `text` field of the dataset, and you want to fine-tune the `facebook/opt-350m` model.
|
||||
The following code-snippet takes care of all the data pre-processing and training for you:
|
||||
|
||||
```python
|
||||
@ -50,7 +50,7 @@ The above snippets will use the default training arguments from the [`transforme
|
||||
|
||||
## Advanced usage
|
||||
|
||||
### Train on completions only
|
||||
### Train on completions only
|
||||
|
||||
You can use the `DataCollatorForCompletionOnlyLM` to train your model on the generated prompts only. Note that this works only in the case when `packing=False`.
|
||||
To instantiate that collator for instruction data, pass a response template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on completions only on the CodeAlpaca dataset:
|
||||
@ -82,7 +82,7 @@ trainer = SFTTrainer(
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
To instantiate that collator for assistant style conversation data, pass a response template, an instruction template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on assistant completions only on the Open Assistant Guanaco dataset:
|
||||
@ -108,15 +108,15 @@ trainer = SFTTrainer(
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Make sure to have a `pad_token_id` which is different from `eos_token_id` which can result in the model not properly predicting EOS (End of Sentence) tokens during generation.
|
||||
Make sure to have a `pad_token_id` which is different from `eos_token_id` which can result in the model not properly predicting EOS (End of Sentence) tokens during generation.
|
||||
|
||||
#### Using token_ids directly for `response_template`
|
||||
|
||||
Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending whether they have context or not. For example:
|
||||
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
@ -134,14 +134,14 @@ print_tokens_with_ids(response_template) # [('▁###', 835), ('▁Ass', 4007),
|
||||
```
|
||||
|
||||
In this case, and due to lack of context in `response_template`, the same string ("### Assistant:") is tokenized differently:
|
||||
|
||||
|
||||
- Text (with context): `[2277, 29937, 4007, 22137, 29901]`
|
||||
- `response_template` (without context): `[835, 4007, 22137, 29901]`
|
||||
|
||||
This will lead to an error when the `DataCollatorForCompletionOnlyLM` does not find the `response_template` in the dataset example text:
|
||||
|
||||
```
|
||||
RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...])
|
||||
RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...])
|
||||
```
|
||||
|
||||
|
||||
@ -154,9 +154,75 @@ response_template_ids = tokenizer.encode(response_template_with_context, add_spe
|
||||
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
|
||||
```
|
||||
|
||||
### Add Special Tokens for Chat Format
|
||||
|
||||
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
|
||||
The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
|
||||
- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
|
||||
- Resizes the model’s embedding layer to accommodate the new tokens.
|
||||
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
|
||||
- _optionally_ you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of the `resize_to_multiple_of` argument, e.g. 64. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# Load model and tokenizer
|
||||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
|
||||
# Set up the chat format with default 'chatml' format
|
||||
model, tokenizer = setup_chat_format(model, tokenizer)
|
||||
|
||||
```
|
||||
|
||||
With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
|
||||
|
||||
### Dataset format support
|
||||
|
||||
The [`SFTTrainer`] supports popular dataset formats. This allows you to pass the dataset to the trainer without any pre-processing directly. The following formats are supported:
|
||||
* conversational format
|
||||
```json
|
||||
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]}
|
||||
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]}
|
||||
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "..."}]}
|
||||
```
|
||||
* instruction format
|
||||
```json
|
||||
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
|
||||
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
|
||||
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
|
||||
```
|
||||
|
||||
If your dataset uses one of the above formats, you can directly pass it to the trainer without pre-processing. The [`SFTTrainer`] will then format the dataset for you using the defined format from the model's tokenizer with the [apply_chat_template](https://huggingface.co/docs/transformers/main/en/chat_templating#templates-for-chat-models) method.
|
||||
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
|
||||
...
|
||||
|
||||
# load jsonl dataset
|
||||
dataset = load_dataset("json", data_files="path/to/dataset.jsonl", split="train")
|
||||
# load dataset from the HuggingFace Hub
|
||||
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
|
||||
|
||||
...
|
||||
|
||||
trainer = SFTTrainer(
|
||||
"facebook/opt-350m",
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
packing=True,
|
||||
)
|
||||
```
|
||||
|
||||
If the dataset is not in one those format you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
|
||||
|
||||
|
||||
### Format your input prompts
|
||||
|
||||
For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
|
||||
For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
|
||||
This allows people to format examples like [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca) did as follows:
|
||||
```bash
|
||||
Below is an instruction ...
|
||||
@ -185,7 +251,7 @@ trainer = SFTTrainer(
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
To preperly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
|
||||
To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
|
||||
|
||||
### Packing dataset ([`ConstantLengthDataset`])
|
||||
|
||||
@ -204,7 +270,8 @@ trainer = SFTTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
|
||||
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
|
||||
If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTTrainer` init method.
|
||||
|
||||
#### Customize your prompts using packed dataset
|
||||
|
||||
@ -228,7 +295,7 @@ You can also customize the [`ConstantLengthDataset`] much more by directly passi
|
||||
|
||||
### Control over the pretrained model
|
||||
|
||||
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTTrainer`]. For example, if you want to load a model in a different precision, analogous to
|
||||
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTTrainer`]. For example, if you want to load a model in a different precision, analogous to
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
|
||||
@ -248,7 +315,7 @@ trainer = SFTTrainer(
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
Note that all keyword arguments of `from_pretrained()` are supported.
|
||||
Note that all keyword arguments of `from_pretrained()` are supported.
|
||||
|
||||
### Training adapters
|
||||
|
||||
@ -281,7 +348,7 @@ trainer.train()
|
||||
|
||||
You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without the `peft_config` argument being passed.
|
||||
|
||||
### Training adapters with base 8 bit models
|
||||
### Training adapters with base 8 bit models
|
||||
|
||||
For that you need to first load your 8bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example:
|
||||
|
||||
@ -314,7 +381,7 @@ trainer.train()
|
||||
|
||||
## Using Flash Attention and Flash Attention 2
|
||||
|
||||
You can benefit from Flash Attention 1 & 2 using SFTTrainer out of the box with minimal changes of code.
|
||||
You can benefit from Flash Attention 1 & 2 using SFTTrainer out of the box with minimal changes of code.
|
||||
First, to make sure you have all the latest features from transformers, install transformers from source
|
||||
|
||||
```bash
|
||||
@ -346,11 +413,11 @@ Note that you cannot train your model using Flash Attention 1 on an arbitrary da
|
||||
Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB.
|
||||
|
||||
| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step |
|
||||
|----------------|-------------------|-------------|------------|------------------------|
|
||||
| x | facebook/opt-350m | 2048 | 8 | ~59.1s |
|
||||
| | facebook/opt-350m | 2048 | 8 | **OOM** |
|
||||
| x | facebook/opt-350m | 2048 | 4 | ~30.3s |
|
||||
| | facebook/opt-350m | 2048 | 4 | ~148.9s |
|
||||
| ---------------- | ----------------- | ----------- | ---------- | ---------------------- |
|
||||
| x | facebook/opt-350m | 2048 | 8 | ~59.1s |
|
||||
| | facebook/opt-350m | 2048 | 8 | **OOM** |
|
||||
| x | facebook/opt-350m | 2048 | 4 | ~30.3s |
|
||||
| | facebook/opt-350m | 2048 | 4 | ~148.9s |
|
||||
|
||||
### Using Flash Attention-2
|
||||
|
||||
@ -360,13 +427,13 @@ To use Flash Attention 2, first install the latest `flash-attn` package:
|
||||
pip install -U flash-attn
|
||||
```
|
||||
|
||||
And add `use_flash_attention_2=True` when calling `from_pretrained`:
|
||||
And add `attn_implementation="flash_attention_2"` when calling `from_pretrained`:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
load_in_4bit=True,
|
||||
use_flash_attention_2=True
|
||||
attn_implementation="flash_attention_2"
|
||||
)
|
||||
```
|
||||
|
||||
@ -375,6 +442,45 @@ After loading your model, you can either train it as it is, or attach adapters a
|
||||
|
||||
In contrary to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
|
||||
|
||||
|
||||
### Using model creation utility
|
||||
|
||||
We included a utility function to create your model.
|
||||
|
||||
[[autodoc]] ModelConfig
|
||||
|
||||
```python
|
||||
from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
|
||||
model_config = ModelConfig(
|
||||
model_name_or_path="facebook/opt-350m"
|
||||
attn_implementation=None, # or "flash_attention_2"
|
||||
)
|
||||
torch_dtype = (
|
||||
model_config.torch_dtype
|
||||
if model_config.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_config.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_config)
|
||||
model_kwargs = dict(
|
||||
revision=model_config.model_revision,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
attn_implementation=model_config.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
|
||||
trainer = SFTTrainer(
|
||||
...,
|
||||
model=model_config.model_name_or_path,
|
||||
peft_config=get_peft_config(model_config),
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
### Enhance model's performances using NEFTune
|
||||
|
||||
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://arxiv.org/abs/2310.05914) from Jain et al. it consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
|
||||
@ -413,44 +519,48 @@ Note however, that the amount of performance gain is _dataset dependent_ and in
|
||||
|
||||
### Accelerate fine-tuning 2x using `unsloth`
|
||||
|
||||
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) and even full-finetuning (1.1x faster) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama as well) and Mistral architectures.
|
||||
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth#installation-instructions---conda). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLlamaModel` or `FastMistralModel` as follows:
|
||||
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks on 1x A100 listed below:
|
||||
|
||||
| 1 A100 40GB | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|
||||
|-----------------|-----------|-----|-------------------------|-----------------|----------------|
|
||||
| Code Llama 34b | Slim Orca | 1x | 1.01x | **1.94x** | -22.7% |
|
||||
| Llama-2 7b | Slim Orca | 1x | 0.96x | **1.87x** | -39.3% |
|
||||
| Mistral 7b | Slim Orca | 1x | 1.17x | **1.88x** | -65.9% |
|
||||
| Tiny Llama 1.1b | Alpaca | 1x | 1.55x | **2.74x** | -57.8% |
|
||||
|
||||
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from transformers import TrainingArguments
|
||||
from trl import SFTTrainer
|
||||
from unsloth import FastLlamaModel, FastMistralModel
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number.
|
||||
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
|
||||
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
|
||||
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number
|
||||
|
||||
# Load Llama model
|
||||
model, tokenizer = FastLlamaModel.from_pretrained(
|
||||
model_name = "unsloth/llama-2-7b", # Supports any llama model eg meta-llama/Llama-2-7b-hf
|
||||
# Load model
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name = "unsloth/mistral-7b",
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = dtype,
|
||||
load_in_4bit = load_in_4bit,
|
||||
dtype = None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
|
||||
load_in_4bit = True, # Use 4bit quantization to reduce memory usage. Can be False
|
||||
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
|
||||
)
|
||||
|
||||
# Do model patching and add fast LoRA weights
|
||||
model = FastLlamaModel.get_peft_model(
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r = 16,
|
||||
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",],
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0, # Currently only supports dropout = 0
|
||||
bias = "none", # Currently only supports bias = "none"
|
||||
lora_dropout = 0, # Dropout = 0 is currently optimized
|
||||
bias = "none", # Bias = "none" is currently optimized
|
||||
use_gradient_checkpointing = True,
|
||||
random_state = 3407,
|
||||
max_seq_length = max_seq_length,
|
||||
)
|
||||
|
||||
args = TrainingArguments(output_dir="./output")
|
||||
args = TrainingArguments(output_dir = "./output")
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model = model,
|
||||
@ -459,7 +569,6 @@ trainer = SFTTrainer(
|
||||
dataset_text_field = "text",
|
||||
max_seq_length = max_seq_length,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
@ -471,9 +580,27 @@ Pay attention to the following best practices when training a model with that tr
|
||||
|
||||
- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
|
||||
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
|
||||
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
|
||||
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
|
||||
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work you must also check the following:
|
||||
- If you're using gradient_checkpointing, add the following to the TrainingArguments: `gradient_checkpointing_kwargs={'use_reentrant':False}` (more info [here](https://github.com/huggingface/transformers/issues/26969)
|
||||
- Ensure that the model is placed on the correct device:
|
||||
```python
|
||||
from accelerate import PartialState
|
||||
device_string = PartialState().process_index
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
...
|
||||
device_map={'':device_string}
|
||||
)
|
||||
```
|
||||
|
||||
## GPTQ Conversion
|
||||
|
||||
You may experience some issues with GPTQ Quantization after completing training. Lowering `gradient_accumulation_steps` to `4` will resolve most issues during the quantization process to GPTQ format.
|
||||
|
||||
## SFTTrainer
|
||||
|
||||
[[autodoc]] SFTTrainer
|
||||
|
20
example_config.yaml
Normal file
20
example_config.yaml
Normal file
@ -0,0 +1,20 @@
|
||||
# This is an example configuration file of TRL CLI, you can use it for
|
||||
# SFT like that: `trl sft --config config.yaml --output_dir test-sft`
|
||||
# The YAML file supports environment variables by adding an `env` field
|
||||
# as below
|
||||
|
||||
# env:
|
||||
# CUDA_VISIBLE_DEVICES: 0
|
||||
|
||||
model_name_or_path:
|
||||
HuggingFaceM4/tiny-random-LlamaForCausalLM
|
||||
dataset_name:
|
||||
imdb
|
||||
dataset_text_field:
|
||||
text
|
||||
report_to:
|
||||
none
|
||||
learning_rate:
|
||||
0.0001
|
||||
lr_scheduler_type:
|
||||
cosine
|
16
examples/accelerate_configs/single_gpu.yaml
Normal file
16
examples/accelerate_configs/single_gpu.yaml
Normal file
@ -0,0 +1,16 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: "NO"
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'bf16'
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
122
examples/datasets/anthropic_hh.py
Normal file
122
examples/datasets/anthropic_hh.py
Normal file
@ -0,0 +1,122 @@
|
||||
import multiprocessing
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
"""
|
||||
# debug
|
||||
python -i examples/datasets/anthropic_hh.py --debug --push_to_hub
|
||||
# actual push
|
||||
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity trl-internal-testing
|
||||
"""
|
||||
|
||||
|
||||
api = HfApi()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
|
||||
hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
|
||||
hf_repo_id: Optional[str] = field(default="hh-rlhf-trl-style", metadata={"help": "The Hugging Face repository ID"})
|
||||
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
|
||||
update_main_revision: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Update the main revision of the repository"}
|
||||
)
|
||||
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
|
||||
|
||||
|
||||
# GPT-4 generated 😄 Define a function to process the input and extract the dialogue into structured format
|
||||
def extract_dialogue(input_text):
|
||||
# Split the input by lines and initialize variables
|
||||
lines = input_text.strip().split("\n\n")
|
||||
dialogue_list = []
|
||||
|
||||
# Iterate through each line and extract the dialogue
|
||||
for line in lines:
|
||||
# Check if the line starts with "Human" or "Assistant" and split accordingly
|
||||
if line.startswith("Human:"):
|
||||
role = "user"
|
||||
content = line.replace("Human: ", "").strip()
|
||||
elif line.startswith("Assistant:"):
|
||||
role = "assistant"
|
||||
content = line.replace("Assistant: ", "").strip()
|
||||
else:
|
||||
# If the line doesn't start with "Human" or "Assistant", it's part of the previous message's content
|
||||
# Append it to the last message's content
|
||||
dialogue_list[-1]["content"] += "\n\n" + line.strip()
|
||||
continue
|
||||
|
||||
# Append the extracted dialogue piece to the list
|
||||
dialogue_list.append({"role": role, "content": content})
|
||||
|
||||
return dialogue_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
|
||||
if args.hf_entity is None:
|
||||
args.hf_entity = api.whoami()["name"]
|
||||
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
|
||||
ds = load_dataset("Anthropic/hh-rlhf")
|
||||
if args.debug:
|
||||
for key in ds:
|
||||
ds[key] = ds[key].select(range(50))
|
||||
|
||||
def process(row):
|
||||
row["chosen"] = extract_dialogue(row["chosen"])
|
||||
row["rejected"] = extract_dialogue(row["rejected"])
|
||||
row["prompt"] = row["chosen"][0]["content"]
|
||||
return row
|
||||
|
||||
ds = ds.map(
|
||||
process,
|
||||
num_proc=1 if args.debug else multiprocessing.cpu_count(),
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
if args.push_to_hub:
|
||||
revisions = ["main"] if args.update_main_revision else []
|
||||
revisions.append(args.revision)
|
||||
|
||||
# get the commnad used to run the script
|
||||
run_command = " ".join(["python"] + sys.argv)
|
||||
|
||||
for revision in revisions:
|
||||
ds.push_to_hub(full_repo_id, revision=revision)
|
||||
repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}/tree/{revision}"
|
||||
|
||||
# get the name of the current file
|
||||
file_name = __file__.split("/")[-1]
|
||||
api.upload_file(
|
||||
path_or_fileobj=__file__,
|
||||
path_in_repo=file_name,
|
||||
revision=revision,
|
||||
repo_id=full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
sft_card = RepoCard.load(
|
||||
full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
sft_card.text = f"""\
|
||||
# TRL's Anthropic HH Dataset
|
||||
|
||||
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
|
||||
|
||||
|
||||
## Reproduce this dataset
|
||||
|
||||
1. Download the `{file_name}` from the {repo_full_url}.
|
||||
2. Run `{run_command}`
|
||||
"""
|
||||
sft_card.push_to_hub(
|
||||
full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
113
examples/datasets/tldr_preference.py
Normal file
113
examples/datasets/tldr_preference.py
Normal file
@ -0,0 +1,113 @@
|
||||
import multiprocessing
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
"""
|
||||
# debug
|
||||
python -i examples/datasets/tldr_preference.py --debug --push_to_hub
|
||||
# actual push
|
||||
python examples/datasets/tldr_preference.py --push_to_hub --hf_entity trl-internal-testing
|
||||
"""
|
||||
|
||||
|
||||
api = HfApi()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
|
||||
hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
|
||||
hf_repo_id: Optional[str] = field(
|
||||
default="tldr-preference-trl-style", metadata={"help": "The Hugging Face repository ID"}
|
||||
)
|
||||
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
|
||||
update_main_revision: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Update the main revision of the repository"}
|
||||
)
|
||||
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
|
||||
if args.hf_entity is None:
|
||||
args.hf_entity = api.whoami()["name"]
|
||||
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
|
||||
|
||||
ds = load_dataset("openai/summarize_from_feedback", "comparisons")
|
||||
if args.debug:
|
||||
for key in ds:
|
||||
ds[key] = ds[key].select(range(50))
|
||||
cnndm_batches = ["batch0_cnndm", "cnndm0", "cnndm2"]
|
||||
if not args.debug:
|
||||
ds["validation_cnndm"] = ds["validation"].filter(lambda x: x["batch"] in cnndm_batches)
|
||||
ds["validation"] = ds["validation"].filter(lambda x: x["batch"] not in cnndm_batches)
|
||||
|
||||
tldr_format_str = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"
|
||||
cnndm_format_str = "Article:\n{article}\n\nTL;DR:"
|
||||
|
||||
def process(row):
|
||||
format_str = cnndm_format_str if row["batch"] in cnndm_batches else tldr_format_str
|
||||
row["prompt"] = format_str.format(**row["info"])
|
||||
choice = row["choice"]
|
||||
chosen = row["summaries"][choice]["text"]
|
||||
rejected = row["summaries"][1 - choice]["text"]
|
||||
row["chosen"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": chosen}]
|
||||
row["rejected"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": rejected}]
|
||||
return row
|
||||
|
||||
ds = ds.map(
|
||||
process,
|
||||
num_proc=1 if args.debug else multiprocessing.cpu_count(),
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
for key in ds: # reorder columns
|
||||
ds[key] = ds[key].select_columns(
|
||||
["prompt", "chosen", "rejected", "info", "summaries", "choice", "worker", "batch", "split", "extra"]
|
||||
)
|
||||
if args.push_to_hub:
|
||||
revisions = ["main"] if args.update_main_revision else []
|
||||
revisions.append(args.revision)
|
||||
|
||||
# get the commnad used to run the script
|
||||
run_command = " ".join(["python"] + sys.argv)
|
||||
|
||||
for revision in revisions:
|
||||
ds.push_to_hub(full_repo_id, revision=revision)
|
||||
repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}/tree/{revision}"
|
||||
|
||||
# get the name of the current file
|
||||
file_name = __file__.split("/")[-1]
|
||||
api.upload_file(
|
||||
path_or_fileobj=__file__,
|
||||
path_in_repo=file_name,
|
||||
revision=revision,
|
||||
repo_id=full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
sft_card = RepoCard.load(
|
||||
full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
sft_card.text = f"""\
|
||||
# TRL's TL;DR Preference Dataset
|
||||
|
||||
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
|
||||
|
||||
|
||||
## Reproduce this dataset
|
||||
|
||||
1. Download the `{file_name}` from the {repo_full_url}.
|
||||
2. Run `{run_command}`
|
||||
"""
|
||||
sft_card.push_to_hub(
|
||||
full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
42
examples/datasets/tokenize_ds.py
Normal file
42
examples/datasets/tokenize_ds.py
Normal file
@ -0,0 +1,42 @@
|
||||
import multiprocessing
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
"""
|
||||
python -i examples/datasets/tokenize_ds.py --debug --model HuggingFaceH4/zephyr-7b-beta
|
||||
python -i examples/datasets/tokenize_ds.py --debug --model gpt2
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
|
||||
dataset: str = field(default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The dataset to load"})
|
||||
model: str = field(default="gpt2", metadata={"help": "The model to use for tokenization"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
|
||||
ds = load_dataset(args.dataset)
|
||||
if args.debug:
|
||||
for key in ds:
|
||||
ds[key] = ds[key].select(range(50))
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
|
||||
|
||||
def process(row):
|
||||
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
|
||||
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
|
||||
return row
|
||||
|
||||
ds = ds.map(
|
||||
process,
|
||||
num_proc=1 if args.debug else multiprocessing.cpu_count(),
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
print(ds["train"][0]["chosen"])
|
@ -12,7 +12,7 @@ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# 2. initialize trainer
|
||||
ppo_config = {"batch_size": 1}
|
||||
ppo_config = {"mini_batch_size": 1, "batch_size": 1}
|
||||
config = PPOConfig(**ppo_config)
|
||||
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
|
||||
|
||||
@ -29,7 +29,7 @@ generation_kwargs = {
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
"max_new_tokens": 20,
|
||||
}
|
||||
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
|
||||
response_tensor = ppo_trainer.generate(list(query_tensor), return_prompt=False, **generation_kwargs)
|
||||
response_txt = tokenizer.decode(response_tensor[0])
|
||||
|
||||
# 5. define a reward for response
|
||||
|
@ -1,7 +1,7 @@
|
||||
# RLHF pipeline for the creation of StackLLaMa: a Stack exchange llama-7b model.
|
||||
There were three main steps to the training process:
|
||||
1. Supervised fine-tuning of the base llama-7b model to create llama-7b-se:
|
||||
- `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path=<LLAMA_MODEL_PATH> --streaming --no_gradient_checkpointing --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se`
|
||||
- `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path=<LLAMA_MODEL_PATH> --streaming --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se`
|
||||
2. Reward modeling using dialog pairs from the SE dataset using the llama-7b-se to create llama-7b-se-rm:
|
||||
- `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/reward_modeling.py --model_name=<LLAMA_SE_MODEL>`
|
||||
3. RL fine-tuning of llama-7b-se with the llama-7b-se-rm reward model:
|
||||
|
@ -15,6 +15,7 @@ from transformers import (
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
@ -89,11 +90,14 @@ class ScriptArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether to run eval after the first step"},
|
||||
)
|
||||
seed: Optional[int] = field(
|
||||
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
|
||||
)
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
set_seed(script_args.seed)
|
||||
# Load the human stack-exchange-paired dataset for tuning the reward model.
|
||||
train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/reward", split="train")
|
||||
if script_args.train_subset > 0:
|
||||
@ -129,7 +133,10 @@ training_args = TrainingArguments(
|
||||
logging_steps=10,
|
||||
optim=script_args.optim,
|
||||
lr_scheduler_type=script_args.lr_scheduler_type,
|
||||
seed=script_args.seed,
|
||||
)
|
||||
|
||||
|
||||
# Load the value-head model and tokenizer.
|
||||
tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True)
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -32,7 +31,7 @@ tqdm.pandas()
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The name of the Casual LM model we wish to fine with PPO
|
||||
The name of the Casual LM model we wish to fine-tune with PPO
|
||||
"""
|
||||
|
||||
# NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
|
||||
@ -67,6 +66,7 @@ class ScriptArguments:
|
||||
)
|
||||
|
||||
adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
|
||||
load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 8bit"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
@ -163,7 +163,7 @@ dataset = build_dataset(tokenizer)
|
||||
|
||||
|
||||
def collator(data):
|
||||
return dict((key, [d[key] for d in data]) for key in data[0])
|
||||
return {key: [d[key] for d in data] for key in data[0]}
|
||||
|
||||
|
||||
# set seed before initializing value head for deterministic eval
|
||||
@ -181,7 +181,7 @@ lora_config = LoraConfig(
|
||||
)
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
load_in_8bit=True,
|
||||
load_in_8bit=script_args.load_in_8bit,
|
||||
device_map={"": current_device},
|
||||
peft_config=lora_config,
|
||||
)
|
||||
@ -216,11 +216,13 @@ sentiment_pipe = pipeline(
|
||||
"sentiment-analysis",
|
||||
model=reward_model_name,
|
||||
device_map={"": current_device},
|
||||
model_kwargs={"load_in_8bit": True},
|
||||
model_kwargs={"load_in_8bit": script_args.load_in_8bit},
|
||||
tokenizer=tokenizer,
|
||||
return_token_type_ids=False,
|
||||
)
|
||||
|
||||
if sentiment_pipe.model.config.pad_token_id is None:
|
||||
sentiment_pipe.model.config.pad_token_id = sentiment_pipe.model.config.eos_token_id
|
||||
# We then define the arguments to pass to the `generate` function. These arguments
|
||||
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
|
||||
# the `generate` function of the trained model.
|
||||
|
@ -38,9 +38,9 @@ def get_args():
|
||||
parser.add_argument("--weight_decay", type=float, default=0.05)
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=0)
|
||||
parser.add_argument("--no_fp16", action="store_false")
|
||||
parser.add_argument("--fp16", action="store_true", default=False)
|
||||
parser.add_argument("--bf16", action="store_true", default=False)
|
||||
parser.add_argument("--no_gradient_checkpointing", action="store_false", default=False)
|
||||
parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--num_workers", type=int, default=None)
|
||||
parser.add_argument("--output_dir", type=str, default="./checkpoints")
|
||||
@ -159,8 +159,8 @@ def run_training(args, train_data, val_data):
|
||||
lr_scheduler_type=args.lr_scheduler_type,
|
||||
warmup_steps=args.num_warmup_steps,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
gradient_checkpointing=not args.no_gradient_checkpointing,
|
||||
fp16=not args.no_fp16,
|
||||
gradient_checkpointing=args.gradient_checkpointing,
|
||||
fp16=args.fp16,
|
||||
bf16=args.bf16,
|
||||
weight_decay=args.weight_decay,
|
||||
run_name="llama-7b-finetuned",
|
||||
|
@ -4,9 +4,10 @@ from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from datasets import Dataset, load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
|
||||
|
||||
from trl import DPOTrainer
|
||||
|
||||
@ -41,6 +42,10 @@ class ScriptArguments:
|
||||
default=True, metadata={"help": "whether to use gradient checkpointing"}
|
||||
)
|
||||
|
||||
gradient_checkpointing_use_reentrant: Optional[bool] = field(
|
||||
default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
|
||||
)
|
||||
|
||||
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
|
||||
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
|
||||
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
|
||||
@ -54,6 +59,10 @@ class ScriptArguments:
|
||||
|
||||
output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
|
||||
log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
|
||||
load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
|
||||
model_dtype: Optional[str] = field(
|
||||
default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
|
||||
)
|
||||
|
||||
# instrumentation
|
||||
sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
|
||||
@ -73,12 +82,15 @@ class ScriptArguments:
|
||||
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
|
||||
},
|
||||
)
|
||||
seed: Optional[int] = field(
|
||||
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
|
||||
)
|
||||
|
||||
|
||||
def get_stack_exchange_paired(
|
||||
data_dir: str = "data/rl",
|
||||
sanity_check: bool = False,
|
||||
cache_dir: str = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
num_proc=24,
|
||||
) -> Dataset:
|
||||
"""Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.
|
||||
@ -123,12 +135,21 @@ if __name__ == "__main__":
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
set_seed(script_args.seed)
|
||||
|
||||
# 1. load a pretrained model
|
||||
torch_dtype = torch.float
|
||||
if script_args.model_dtype == "float16":
|
||||
torch_dtype = torch.float16
|
||||
elif script_args.model_dtype == "bfloat16":
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
script_args.model_name_or_path,
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16,
|
||||
load_in_4bit=True,
|
||||
torch_dtype=torch_dtype,
|
||||
load_in_4bit=script_args.load_in_4bit,
|
||||
device_map={"": Accelerator().local_process_index},
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
@ -138,12 +159,6 @@ if __name__ == "__main__":
|
||||
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
|
||||
]
|
||||
|
||||
model_ref = AutoModelForCausalLM.from_pretrained(
|
||||
script_args.model_name_or_path,
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16,
|
||||
load_in_4bit=True,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
@ -181,6 +196,8 @@ if __name__ == "__main__":
|
||||
bf16=True,
|
||||
remove_unused_columns=False,
|
||||
run_name="dpo_llama2",
|
||||
gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
|
||||
seed=script_args.seed,
|
||||
)
|
||||
|
||||
peft_config = LoraConfig(
|
||||
@ -203,7 +220,7 @@ if __name__ == "__main__":
|
||||
# 5. initialize the DPO trainer
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
beta=script_args.beta,
|
||||
train_dataset=train_dataset,
|
||||
|
@ -8,7 +8,14 @@ from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from peft import AutoPeftModelForCausalLM, LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
HfArgumentParser,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
from trl import SFTTrainer
|
||||
from trl.import_utils import is_npu_available, is_xpu_available
|
||||
@ -27,6 +34,7 @@ class ScriptArguments:
|
||||
seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
|
||||
num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
|
||||
packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})
|
||||
use_bnb: Optional[bool] = field(default=True, metadata={"help": "whether to use BitsAndBytes"})
|
||||
|
||||
# LoraConfig
|
||||
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
|
||||
@ -53,6 +61,8 @@ if training_args.group_by_length and script_args.packing:
|
||||
if training_args.gradient_checkpointing:
|
||||
raise ValueError("gradient_checkpointing not supported")
|
||||
|
||||
set_seed(training_args.seed)
|
||||
|
||||
|
||||
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
|
||||
"""
|
||||
@ -91,7 +101,7 @@ def prepare_sample_text(example):
|
||||
return text
|
||||
|
||||
|
||||
def create_datasets(tokenizer, args):
|
||||
def create_datasets(tokenizer, args, seed=None):
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
data_dir=args.subset,
|
||||
@ -104,9 +114,9 @@ def create_datasets(tokenizer, args):
|
||||
print("Loading the dataset in streaming mode")
|
||||
valid_data = dataset.take(args.size_valid_set)
|
||||
train_data = dataset.skip(args.size_valid_set)
|
||||
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=None)
|
||||
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=seed)
|
||||
else:
|
||||
dataset = dataset.train_test_split(test_size=0.005, seed=None)
|
||||
dataset = dataset.train_test_split(test_size=0.005, seed=seed)
|
||||
train_data = dataset["train"]
|
||||
valid_data = dataset["test"]
|
||||
print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
|
||||
@ -133,11 +143,13 @@ def create_datasets(tokenizer, args):
|
||||
return train_dataset, valid_dataset
|
||||
|
||||
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
)
|
||||
bnb_config = None
|
||||
if script_args.use_bnb:
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
script_args.model_name,
|
||||
@ -153,7 +165,7 @@ tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_c
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
|
||||
|
||||
train_dataset, eval_dataset = create_datasets(tokenizer, script_args)
|
||||
train_dataset, eval_dataset = create_datasets(tokenizer, script_args, seed=training_args.seed)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=base_model,
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -107,7 +106,7 @@ text_env = TextEnvironment(
|
||||
)
|
||||
|
||||
# main training loop
|
||||
for step in range(100):
|
||||
for _step in range(100):
|
||||
tasks, answers = generate_data(ppo_config.batch_size)
|
||||
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
|
||||
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -61,9 +60,9 @@ def exact_match_reward(responses, answers=None):
|
||||
if match_pattern:
|
||||
predicted_number = float(match_pattern[0])
|
||||
if predicted_number is not None:
|
||||
if np.abs((predicted_number - float(answer))) < 0.1:
|
||||
if np.abs(predicted_number - float(answer)) < 0.1:
|
||||
reward += 1.0
|
||||
except: # noqa
|
||||
except Exception:
|
||||
pass
|
||||
rewards.append(torch.tensor(reward))
|
||||
return rewards
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -114,7 +113,7 @@ dataset = dataset.shuffle(local_seed)
|
||||
|
||||
def data_generator():
|
||||
for i in range(len(dataset)):
|
||||
yield dataset[i]["question"], [item for item in dataset[i]["answer"]["normalized_aliases"]]
|
||||
yield dataset[i]["question"], list(dataset[i]["answer"]["normalized_aliases"])
|
||||
|
||||
|
||||
gen = data_generator()
|
||||
@ -123,7 +122,7 @@ gen = iter(gen)
|
||||
|
||||
def generate_data(n):
|
||||
tasks, answers = [], []
|
||||
for i in range(n):
|
||||
for _i in range(n):
|
||||
q, a = next(gen)
|
||||
tasks.append(q)
|
||||
answers.append(a)
|
||||
@ -143,10 +142,14 @@ def exact_match_reward(responses, answers=None):
|
||||
return rewards
|
||||
|
||||
|
||||
def tool_fn(x):
|
||||
# limit the amount of tokens
|
||||
return tool(x).split("\n")[1][:600]
|
||||
|
||||
|
||||
# text env
|
||||
tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
|
||||
# limit the amount if tokens
|
||||
tool_fn = lambda x: tool(x).split("\n")[1][:600] # noqa
|
||||
|
||||
text_env = TextEnvironment(
|
||||
model,
|
||||
tokenizer,
|
||||
@ -184,8 +187,6 @@ for i in range(args.iterations):
|
||||
"answer": [", ".join(item) for item in answers],
|
||||
}
|
||||
all_rewards = ppo_trainer.accelerator.gather(torch.tensor(rewards, device=ppo_trainer.accelerator.device))
|
||||
ppo_trainer.log_stats(
|
||||
train_stats, texts, [item for item in all_rewards], columns_to_log=["query", "response", "answer"]
|
||||
)
|
||||
ppo_trainer.log_stats(train_stats, texts, list(all_rewards), columns_to_log=["query", "response", "answer"])
|
||||
if i % 100 == 0:
|
||||
ppo_trainer.save_pretrained(f"models/{args.model_name}_{args.seed}_{i}_triviaqa")
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -59,7 +58,7 @@ tqdm.pandas()
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The name of the Casual LM model we wish to fine with PPO
|
||||
The name of the Casual LM model we wish to fine-tune with PPO
|
||||
"""
|
||||
|
||||
# NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
|
||||
@ -146,7 +145,7 @@ dataset = build_dataset(config, input_min_text_length=min_input_length, input_ma
|
||||
|
||||
|
||||
def collator(data):
|
||||
return dict((key, [d[key] for d in data]) for key in data[0])
|
||||
return {key: [d[key] for d in data] for key in data[0]}
|
||||
|
||||
|
||||
# set seed before initializing value head for deterministic eval
|
||||
@ -218,7 +217,7 @@ for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
response_tensors.append(response.squeeze()[-gen_len:])
|
||||
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
|
||||
|
||||
# Compute sentiment score # noqa
|
||||
# Compute sentiment score
|
||||
texts = batch["response"]
|
||||
toxicity_inputs = toxicity_tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(
|
||||
ppo_trainer.accelerator.device
|
||||
|
338
examples/scripts/chat.py
Normal file
338
examples/scripts/chat.py
Normal file
@ -0,0 +1,338 @@
|
||||
# flake8: noqa
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from trl.commands.cli_utils import init_zero_verbose
|
||||
|
||||
init_zero_verbose()
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import pwd
|
||||
import re
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
import torch
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.markdown import Markdown
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||||
|
||||
from trl.commands.cli_utils import ChatArguments, TrlParser, init_zero_verbose
|
||||
from trl.trainer.utils import get_kbit_device_map, get_quantization_config
|
||||
|
||||
|
||||
HELP_STRING = """\
|
||||
|
||||
**TRL CHAT INTERFACE**
|
||||
|
||||
The chat interface is a simple tool to try out a chat model.
|
||||
|
||||
Besides talking to the model there are several commands:
|
||||
- **clear**: clears the current conversation and start a new one
|
||||
- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
|
||||
- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
|
||||
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
|
||||
- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
|
||||
- **exit**: closes the interface
|
||||
"""
|
||||
|
||||
SUPPORTED_GENERATION_KWARGS = [
|
||||
"max_new_tokens",
|
||||
"do_sample",
|
||||
"num_beams",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"repetition_penalty",
|
||||
]
|
||||
|
||||
SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$"
|
||||
|
||||
|
||||
class RichInterface:
|
||||
def __init__(self, model_name=None, user_name=None):
|
||||
self._console = Console()
|
||||
if model_name is None:
|
||||
self.model_name = "assistant"
|
||||
else:
|
||||
self.model_name = model_name
|
||||
if user_name is None:
|
||||
self.user_name = "user"
|
||||
else:
|
||||
self.user_name = user_name
|
||||
|
||||
def stream_output(self, output_stream):
|
||||
"""Stream output from a role."""
|
||||
# This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
|
||||
# Create a Live context for updating the console output
|
||||
text = ""
|
||||
self._console.print(f"[bold blue]<{self.model_name}>:")
|
||||
with Live(console=self._console, refresh_per_second=4) as live:
|
||||
# Read lines from the stream
|
||||
for i, outputs in enumerate(output_stream):
|
||||
if not outputs or i == 0:
|
||||
continue
|
||||
text += outputs
|
||||
# Render the accumulated text as Markdown
|
||||
# NOTE: this is a workaround for the rendering "unstandard markdown"
|
||||
# in rich. The chatbots output treat "\n" as a new line for
|
||||
# better compatibility with real-world text. However, rendering
|
||||
# in markdown would break the format. It is because standard markdown
|
||||
# treat a single "\n" in normal text as a space.
|
||||
# Our workaround is adding two spaces at the end of each line.
|
||||
# This is not a perfect solution, as it would
|
||||
# introduce trailing spaces (only) in code block, but it works well
|
||||
# especially for console output, because in general the console does not
|
||||
# care about trailing spaces.
|
||||
lines = []
|
||||
for line in text.splitlines():
|
||||
lines.append(line)
|
||||
if line.startswith("```"):
|
||||
# Code block marker - do not add trailing spaces, as it would
|
||||
# break the syntax highlighting
|
||||
lines.append("\n")
|
||||
else:
|
||||
lines.append(" \n")
|
||||
markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
|
||||
# Update the Live console output
|
||||
live.update(markdown)
|
||||
self._console.print()
|
||||
return text
|
||||
|
||||
def input(self):
|
||||
input = self._console.input(f"[bold red]<{self.user_name}>:\n")
|
||||
self._console.print()
|
||||
return input
|
||||
|
||||
def clear(self):
|
||||
self._console.clear()
|
||||
|
||||
def print_user_message(self, text):
|
||||
self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}")
|
||||
self._console.print()
|
||||
|
||||
def print_green(self, text):
|
||||
self._console.print(f"[bold green]{text}")
|
||||
self._console.print()
|
||||
|
||||
def print_red(self, text):
|
||||
self._console.print(f"[bold red]{text}")
|
||||
self._console.print()
|
||||
|
||||
def print_help(self):
|
||||
self._console.print(Markdown(HELP_STRING))
|
||||
self._console.print()
|
||||
|
||||
|
||||
def get_username():
|
||||
return pwd.getpwuid(os.getuid())[0]
|
||||
|
||||
|
||||
def create_default_filename(model_name):
|
||||
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
return f"{model_name}/chat_{time_str}.json"
|
||||
|
||||
|
||||
def save_chat(chat, args, filename):
|
||||
output_dict = {}
|
||||
output_dict["settings"] = vars(args)
|
||||
output_dict["chat_history"] = chat
|
||||
|
||||
folder = args.save_folder
|
||||
|
||||
if filename is None:
|
||||
filename = create_default_filename(args.model_name_or_path)
|
||||
filename = os.path.join(folder, filename)
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
|
||||
with open(filename, "w") as f:
|
||||
json.dump(output_dict, f, indent=4)
|
||||
return os.path.abspath(filename)
|
||||
|
||||
|
||||
def clear_chat_history(system_prompt):
|
||||
if system_prompt is None:
|
||||
chat = []
|
||||
else:
|
||||
chat = [{"role": "system", "content": system_prompt}]
|
||||
return chat
|
||||
|
||||
|
||||
def parse_settings(user_input, current_args, interface):
|
||||
settings = user_input[4:].strip().split(";")
|
||||
settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
|
||||
settings = dict(settings)
|
||||
error = False
|
||||
|
||||
for name in settings:
|
||||
if hasattr(current_args, name):
|
||||
try:
|
||||
if isinstance(getattr(current_args, name), bool):
|
||||
if settings[name] == "True":
|
||||
settings[name] = True
|
||||
elif settings[name] == "False":
|
||||
settings[name] = False
|
||||
else:
|
||||
raise ValueError
|
||||
else:
|
||||
settings[name] = type(getattr(current_args, name))(settings[name])
|
||||
except ValueError:
|
||||
interface.print_red(
|
||||
f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}."
|
||||
)
|
||||
else:
|
||||
interface.print_red(f"There is no '{name}' setting.")
|
||||
|
||||
if error:
|
||||
interface.print_red("There was an issue parsing the settings. No settings have been changed.")
|
||||
return current_args, False
|
||||
else:
|
||||
for name in settings:
|
||||
setattr(current_args, name, settings[name])
|
||||
interface.print_green(f"Set {name} to {settings[name]}.")
|
||||
|
||||
time.sleep(1.5) # so the user has time to read the changes
|
||||
return current_args, True
|
||||
|
||||
|
||||
def load_model_and_tokenizer(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
||||
|
||||
torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
|
||||
quantization_config = get_quantization_config(args)
|
||||
model_kwargs = dict(
|
||||
revision=args.model_revision,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
attn_implementation=args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)
|
||||
|
||||
if getattr(model, "hf_device_map", None) is None:
|
||||
model = model.to(args.device)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def chat_cli():
|
||||
parser = TrlParser(ChatArguments)
|
||||
args = parser.parse_args_into_dataclasses()[0]
|
||||
if args.config == "default":
|
||||
args.config = os.path.join(os.path.dirname(__file__), "config/default_chat_config.yaml")
|
||||
if args.config.lower() == "none":
|
||||
args.config = None
|
||||
args = parser.update_dataclasses_with_config([args])[0]
|
||||
if args.examples is None:
|
||||
args.examples = {}
|
||||
|
||||
current_args = copy.deepcopy(args)
|
||||
|
||||
if args.user is None:
|
||||
user = get_username()
|
||||
else:
|
||||
user = args.user
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer(args)
|
||||
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
|
||||
|
||||
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
|
||||
interface.clear()
|
||||
chat = clear_chat_history(current_args.system_prompt)
|
||||
while True:
|
||||
try:
|
||||
user_input = interface.input()
|
||||
|
||||
if user_input == "clear":
|
||||
chat = clear_chat_history(current_args.system_prompt)
|
||||
interface.clear()
|
||||
continue
|
||||
|
||||
if user_input == "help":
|
||||
interface.print_help()
|
||||
continue
|
||||
|
||||
if user_input == "exit":
|
||||
break
|
||||
|
||||
if user_input == "reset":
|
||||
interface.clear()
|
||||
current_args = copy.deepcopy(args)
|
||||
chat = clear_chat_history(current_args.system_prompt)
|
||||
continue
|
||||
|
||||
if user_input.startswith("save") and len(user_input.split()) < 2:
|
||||
split_input = user_input.split()
|
||||
|
||||
if len(split_input) == 2:
|
||||
filename = split_input[1]
|
||||
else:
|
||||
filename = None
|
||||
filename = save_chat(chat, current_args, filename)
|
||||
interface.print_green(f"Chat saved in {filename}!")
|
||||
continue
|
||||
|
||||
if re.match(SETTING_RE, user_input):
|
||||
current_args, success = parse_settings(user_input, current_args, interface)
|
||||
if success:
|
||||
chat = []
|
||||
interface.clear()
|
||||
continue
|
||||
|
||||
if user_input.startswith("example") and len(user_input.split()) == 2:
|
||||
example_name = user_input.split()[1]
|
||||
if example_name in current_args.examples:
|
||||
interface.clear()
|
||||
chat = []
|
||||
interface.print_user_message(current_args.examples[example_name]["text"])
|
||||
user_input = current_args.examples[example_name]["text"]
|
||||
else:
|
||||
interface.print_red(
|
||||
f"Example {example_name} not found in list of available examples: {list(current_args.examples.keys())}."
|
||||
)
|
||||
continue
|
||||
|
||||
chat.append({"role": "user", "content": user_input})
|
||||
|
||||
generation_kwargs = dict(
|
||||
inputs=tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
|
||||
model.device
|
||||
),
|
||||
streamer=generation_streamer,
|
||||
max_new_tokens=current_args.max_new_tokens,
|
||||
do_sample=current_args.do_sample,
|
||||
num_beams=current_args.num_beams,
|
||||
temperature=current_args.temperature,
|
||||
top_k=current_args.top_k,
|
||||
top_p=current_args.top_p,
|
||||
repetition_penalty=current_args.repetition_penalty,
|
||||
)
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
model_output = interface.stream_output(generation_streamer)
|
||||
thread.join()
|
||||
chat.append({"role": "assistant", "content": model_output})
|
||||
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
chat_cli()
|
13
examples/scripts/config/default_chat_config.yaml
Normal file
13
examples/scripts/config/default_chat_config.yaml
Normal file
@ -0,0 +1,13 @@
|
||||
examples:
|
||||
llama:
|
||||
text: There is a Llama in my lawn, how can I get rid of it?
|
||||
code:
|
||||
text: Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end].
|
||||
helicopter:
|
||||
text: How many helicopters can a human eat in one sitting?
|
||||
numbers:
|
||||
text: Count to 10 but skip every number ending with an 'e'
|
||||
birds:
|
||||
text: Why aren't birds real?
|
||||
socks:
|
||||
text: Why is it important to eat socks after meditating?
|
@ -11,18 +11,28 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
python examples/scripts/ddpo.py \
|
||||
--num_epochs=200 \
|
||||
--train_gradient_accumulation_steps=1 \
|
||||
--sample_num_steps=50 \
|
||||
--sample_batch_size=6 \
|
||||
--train_batch_size=3 \
|
||||
--sample_num_batches_per_epoch=4 \
|
||||
--per_prompt_stat_tracking=True \
|
||||
--per_prompt_stat_tracking_buffer_size=32 \
|
||||
--tracker_project_name="stable_diffusion_training" \
|
||||
--log_with="wandb"
|
||||
"""
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import tyro
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import EntryNotFoundError
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
from transformers import CLIPModel, CLIPProcessor, HfArgumentParser
|
||||
|
||||
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
|
||||
from trl.import_utils import is_npu_available, is_xpu_available
|
||||
@ -30,38 +40,22 @@ from trl.import_utils import is_npu_available, is_xpu_available
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
hf_user_access_token: str
|
||||
pretrained_model: str = "runwayml/stable-diffusion-v1-5"
|
||||
"""the pretrained model to use"""
|
||||
pretrained_revision: str = "main"
|
||||
"""the pretrained model revision to use"""
|
||||
hf_hub_model_id: str = "ddpo-finetuned-stable-diffusion"
|
||||
"""HuggingFace repo to save model weights to"""
|
||||
hf_hub_aesthetic_model_id: str = "trl-lib/ddpo-aesthetic-predictor"
|
||||
"""HuggingFace model ID for aesthetic scorer model weights"""
|
||||
hf_hub_aesthetic_model_filename: str = "aesthetic-model.pth"
|
||||
"""HuggingFace model filename for aesthetic scorer model weights"""
|
||||
|
||||
ddpo_config: DDPOConfig = field(
|
||||
default_factory=lambda: DDPOConfig(
|
||||
num_epochs=200,
|
||||
train_gradient_accumulation_steps=1,
|
||||
sample_num_steps=50,
|
||||
sample_batch_size=6,
|
||||
train_batch_size=3,
|
||||
sample_num_batches_per_epoch=4,
|
||||
per_prompt_stat_tracking=True,
|
||||
per_prompt_stat_tracking_buffer_size=32,
|
||||
tracker_project_name="stable_diffusion_training",
|
||||
log_with="wandb",
|
||||
project_kwargs={
|
||||
"logging_dir": "./logs",
|
||||
"automatic_checkpoint_naming": True,
|
||||
"total_limit": 5,
|
||||
"project_dir": "./save",
|
||||
},
|
||||
)
|
||||
pretrained_model: str = field(
|
||||
default="runwayml/stable-diffusion-v1-5", metadata={"help": "the pretrained model to use"}
|
||||
)
|
||||
pretrained_revision: str = field(default="main", metadata={"help": "the pretrained model revision to use"})
|
||||
hf_hub_model_id: str = field(
|
||||
default="ddpo-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to"}
|
||||
)
|
||||
hf_hub_aesthetic_model_id: str = field(
|
||||
default="trl-lib/ddpo-aesthetic-predictor",
|
||||
metadata={"help": "HuggingFace model ID for aesthetic scorer model weights"},
|
||||
)
|
||||
hf_hub_aesthetic_model_filename: str = field(
|
||||
default="aesthetic-model.pth",
|
||||
metadata={"help": "HuggingFace model filename for aesthetic scorer model weights"},
|
||||
)
|
||||
use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."})
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
@ -99,7 +93,7 @@ class AestheticScorer(torch.nn.Module):
|
||||
cached_path = hf_hub_download(model_id, model_filename)
|
||||
except EntryNotFoundError:
|
||||
cached_path = os.path.join(model_id, model_filename)
|
||||
state_dict = torch.load(cached_path)
|
||||
state_dict = torch.load(cached_path, map_location=torch.device("cpu"))
|
||||
self.mlp.load_state_dict(state_dict)
|
||||
self.dtype = dtype
|
||||
self.eval()
|
||||
@ -181,7 +175,7 @@ def image_outputs_logger(image_data, global_step, accelerate_logger):
|
||||
for i, image in enumerate(images):
|
||||
prompt = prompts[i]
|
||||
reward = rewards[i].item()
|
||||
result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0)
|
||||
result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0).float()
|
||||
|
||||
accelerate_logger.log_images(
|
||||
result,
|
||||
@ -190,14 +184,21 @@ def image_outputs_logger(image_data, global_step, accelerate_logger):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = tyro.cli(ScriptArguments)
|
||||
parser = HfArgumentParser((ScriptArguments, DDPOConfig))
|
||||
args, ddpo_config = parser.parse_args_into_dataclasses()
|
||||
ddpo_config.project_kwargs = {
|
||||
"logging_dir": "./logs",
|
||||
"automatic_checkpoint_naming": True,
|
||||
"total_limit": 5,
|
||||
"project_dir": "./save",
|
||||
}
|
||||
|
||||
pipeline = DefaultDDPOStableDiffusionPipeline(
|
||||
args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=True
|
||||
args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=args.use_lora
|
||||
)
|
||||
|
||||
trainer = DDPOTrainer(
|
||||
args.ddpo_config,
|
||||
ddpo_config,
|
||||
aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename),
|
||||
prompt_fn,
|
||||
pipeline,
|
||||
@ -206,4 +207,4 @@ if __name__ == "__main__":
|
||||
|
||||
trainer.train()
|
||||
|
||||
trainer.push_to_hub(args.hf_hub_model_id, token=args.hf_user_access_token)
|
||||
trainer.push_to_hub(args.hf_hub_model_id)
|
||||
|
@ -1,4 +1,4 @@
|
||||
# coding=utf-8
|
||||
# flake8: noqa
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -12,187 +12,155 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
# regular:
|
||||
python examples/scripts/dpo.py \
|
||||
--model_name_or_path=gpt2 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--max_steps 1000 \
|
||||
--learning_rate 1e-3 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 10 \
|
||||
--eval_steps 500 \
|
||||
--output_dir="dpo_anthropic_hh" \
|
||||
--warmup_steps 150 \
|
||||
--report_to wandb \
|
||||
--bf16 \
|
||||
--logging_first_step \
|
||||
--no_remove_unused_columns
|
||||
|
||||
# Note: you need to install transformers from main to run this script. See https://huggingface.co/docs/transformers/installation#install-from-source
|
||||
# TODO: bump transformers version in requirements at next release.
|
||||
# peft:
|
||||
python examples/scripts/dpo.py \
|
||||
--model_name_or_path=gpt2 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--max_steps 1000 \
|
||||
--learning_rate 1e-3 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 10 \
|
||||
--eval_steps 500 \
|
||||
--output_dir="dpo_anthropic_hh" \
|
||||
--optim rmsprop \
|
||||
--warmup_steps 150 \
|
||||
--report_to wandb \
|
||||
--bf16 \
|
||||
--logging_first_step \
|
||||
--no_remove_unused_columns \
|
||||
--use_peft \
|
||||
--lora_r=16 \
|
||||
--lora_alpha=16
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
|
||||
# 0. imports
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)
|
||||
|
||||
from trl.commands.cli_utils import DpoScriptArguments, init_zero_verbose, TrlParser
|
||||
|
||||
if TRL_USE_RICH:
|
||||
init_zero_verbose()
|
||||
FORMAT = "%(message)s"
|
||||
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
|
||||
|
||||
from trl import DPOTrainer
|
||||
from trl import (
|
||||
DPOTrainer,
|
||||
ModelConfig,
|
||||
RichProgressCallback,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
|
||||
|
||||
# Define and parse arguments.
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The arguments for the DPO training script.
|
||||
"""
|
||||
|
||||
# data parameters
|
||||
beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
|
||||
|
||||
# training parameters
|
||||
model_name_or_path: Optional[str] = field(default="gpt2", metadata={"help": "the model name"})
|
||||
learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
|
||||
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "batch size per device"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=1, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})
|
||||
max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
|
||||
max_target_length: Optional[int] = field(
|
||||
default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
|
||||
)
|
||||
label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non response tokens"})
|
||||
max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
|
||||
# lora parameters
|
||||
use_peft: Optional[bool] = field(default=True, metadata={"help": "Wether to use PEFT or not to train adapters"})
|
||||
peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
|
||||
peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
|
||||
# instrumentation
|
||||
sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
|
||||
report_to: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
|
||||
'`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
|
||||
'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
|
||||
},
|
||||
)
|
||||
# debug argument for distributed training
|
||||
ignore_bias_buffers: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
|
||||
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
|
||||
},
|
||||
)
|
||||
gradient_checkpointing: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to use gradient checkpointing or no"}
|
||||
)
|
||||
gradient_checkpointing_kwargs: Optional[dict] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def extract_anthropic_prompt(prompt_and_response):
|
||||
"""Extract the anthropic prompt from a prompt and response pair."""
|
||||
search_term = "\n\nAssistant:"
|
||||
search_term_idx = prompt_and_response.rfind(search_term)
|
||||
assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
|
||||
return prompt_and_response[: search_term_idx + len(search_term)]
|
||||
|
||||
|
||||
def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
|
||||
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.
|
||||
|
||||
The dataset is converted to a dictionary with the following structure:
|
||||
{
|
||||
'prompt': List[str],
|
||||
'chosen': List[str],
|
||||
'rejected': List[str],
|
||||
}
|
||||
|
||||
Prompts should be structured as follows:
|
||||
\n\nHuman: <prompt>\n\nAssistant:
|
||||
Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
|
||||
"""
|
||||
dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
|
||||
if sanity_check:
|
||||
dataset = dataset.select(range(min(len(dataset), 1000)))
|
||||
|
||||
def split_prompt_and_responses(sample) -> Dict[str, str]:
|
||||
prompt = extract_anthropic_prompt(sample["chosen"])
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"chosen": sample["chosen"][len(prompt) :],
|
||||
"rejected": sample["rejected"][len(prompt) :],
|
||||
}
|
||||
|
||||
return dataset.map(split_prompt_and_responses)
|
||||
if TRL_USE_RICH:
|
||||
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
parser = TrlParser((DpoScriptArguments, TrainingArguments, ModelConfig))
|
||||
args, training_args, model_config = parser.parse_args_and_config()
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
|
||||
# Force use our print callback
|
||||
if TRL_USE_RICH:
|
||||
training_args.disable_tqdm = True
|
||||
console = Console()
|
||||
|
||||
if script_args.ignore_bias_buffers:
|
||||
################
|
||||
# Model & Tokenizer
|
||||
################
|
||||
torch_dtype = (
|
||||
model_config.torch_dtype
|
||||
if model_config.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_config.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_config)
|
||||
model_kwargs = dict(
|
||||
revision=model_config.model_revision,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
attn_implementation=model_config.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
|
||||
peft_config = get_peft_config(model_config)
|
||||
if peft_config is None:
|
||||
model_ref = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
|
||||
else:
|
||||
model_ref = None
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
if args.ignore_bias_buffers:
|
||||
# torch distributed hack
|
||||
model._ddp_params_and_buffers_to_ignore = [
|
||||
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
|
||||
]
|
||||
|
||||
model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# 2. Load the Anthropic Helpful-Harmless dataset
|
||||
train_dataset = get_hh("train", sanity_check=script_args.sanity_check)
|
||||
|
||||
# 3. Load evaluation dataset
|
||||
eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)
|
||||
|
||||
# 4. initialize training arguments:
|
||||
training_args = TrainingArguments(
|
||||
per_device_train_batch_size=script_args.per_device_train_batch_size,
|
||||
max_steps=script_args.max_steps,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
learning_rate=script_args.learning_rate,
|
||||
evaluation_strategy="steps",
|
||||
logging_first_step=True,
|
||||
logging_steps=10, # match results in blog post
|
||||
eval_steps=500,
|
||||
output_dir="./test",
|
||||
optim="rmsprop",
|
||||
warmup_steps=150,
|
||||
report_to=script_args.report_to,
|
||||
bf16=True,
|
||||
gradient_checkpointing=script_args.gradient_checkpointing,
|
||||
# TODO: uncomment that on the next transformers release
|
||||
# gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
|
||||
################
|
||||
# Optional rich context managers
|
||||
###############
|
||||
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the DPOTrainer...")
|
||||
save_context = (
|
||||
nullcontext()
|
||||
if not TRL_USE_RICH
|
||||
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
|
||||
)
|
||||
|
||||
if script_args.use_peft:
|
||||
peft_config = LoraConfig(
|
||||
r=script_args.peft_lora_r,
|
||||
lora_alpha=script_args.peft_lora_alpha,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
train_dataset = load_dataset(args.dataset_name, split="train")
|
||||
eval_dataset = load_dataset(args.dataset_name, split="test")
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
with init_context:
|
||||
trainer = DPOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
beta=args.beta,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
max_length=args.max_length,
|
||||
max_target_length=args.max_target_length,
|
||||
max_prompt_length=args.max_prompt_length,
|
||||
generate_during_eval=args.generate_during_eval,
|
||||
peft_config=get_peft_config(model_config),
|
||||
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
|
||||
)
|
||||
else:
|
||||
peft_config = None
|
||||
|
||||
# 5. initialize the DPO trainer
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
beta=script_args.beta,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
max_length=script_args.max_length,
|
||||
max_target_length=script_args.max_target_length,
|
||||
max_prompt_length=script_args.max_prompt_length,
|
||||
generate_during_eval=True,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# 6. train
|
||||
dpo_trainer.train()
|
||||
with save_context:
|
||||
trainer.save_model(training_args.output_dir)
|
||||
|
152
examples/scripts/kto.py
Normal file
152
examples/scripts/kto.py
Normal file
@ -0,0 +1,152 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Run the KTO training script with the following command with some example arguments.
|
||||
In general, the optimal configuration for KTO will be similar to that of DPO:
|
||||
|
||||
# regular:
|
||||
python examples/scripts/kto.py \
|
||||
--model_name_or_path=gpt2 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--max_steps 1000 \
|
||||
--learning_rate 1e-3 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 10 \
|
||||
--eval_steps 500 \
|
||||
--output_dir="kto_anthropic_hh" \
|
||||
--warmup_steps 150 \
|
||||
--report_to wandb \
|
||||
--bf16 \
|
||||
--logging_first_step \
|
||||
--no_remove_unused_columns
|
||||
|
||||
# peft:
|
||||
python examples/scripts/kto.py \
|
||||
--model_name_or_path=gpt2 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--max_steps 1000 \
|
||||
--learning_rate 1e-3 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 10 \
|
||||
--eval_steps 500 \
|
||||
--output_dir="kto_anthropic_hh" \
|
||||
--optim rmsprop \
|
||||
--warmup_steps 150 \
|
||||
--report_to wandb \
|
||||
--bf16 \
|
||||
--logging_first_step \
|
||||
--no_remove_unused_columns \
|
||||
--use_peft \
|
||||
--lora_r=16 \
|
||||
--lora_alpha=16
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
|
||||
|
||||
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config
|
||||
|
||||
|
||||
# Define and parse arguments.
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The arguments for the KTO training script.
|
||||
"""
|
||||
|
||||
# debugging
|
||||
sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
|
||||
|
||||
|
||||
def extract_anthropic_prompt(prompt_and_response):
|
||||
"""Extract the anthropic prompt from a prompt and response pair."""
|
||||
search_term = "\n\nAssistant:"
|
||||
search_term_idx = prompt_and_response.rfind(search_term)
|
||||
|
||||
if search_term_idx == -1:
|
||||
raise ValueError(f"Prompt and response does not contain '{search_term}'")
|
||||
|
||||
return prompt_and_response[: search_term_idx + len(search_term)]
|
||||
|
||||
|
||||
def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
|
||||
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.
|
||||
|
||||
The dataset is converted to a dictionary with the following structure:
|
||||
{
|
||||
'prompt': List[str],
|
||||
'completion': List[str],
|
||||
'label': List[bool],
|
||||
}
|
||||
|
||||
Prompts should be structured as follows:
|
||||
\n\nHuman: <prompt>\n\nAssistant:
|
||||
Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
|
||||
"""
|
||||
dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
|
||||
if sanity_check:
|
||||
dataset = dataset.select(range(min(len(dataset), 1000)))
|
||||
|
||||
flat_data = {
|
||||
"prompt": [],
|
||||
"completion": [],
|
||||
"label": [],
|
||||
}
|
||||
for sample in dataset:
|
||||
prompt = extract_anthropic_prompt(sample["chosen"])
|
||||
flat_data["prompt"].append(prompt)
|
||||
flat_data["completion"].append(sample["chosen"][len(prompt) :])
|
||||
flat_data["label"].append(True)
|
||||
flat_data["prompt"].append(prompt)
|
||||
flat_data["completion"].append(sample["rejected"][len(prompt) :])
|
||||
flat_data["label"].append(False)
|
||||
|
||||
return dataset.from_dict(flat_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
|
||||
script_args, kto_args, model_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
|
||||
model_ref = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# 2. Load the Anthropic Helpful-Harmless dataset
|
||||
train_dataset = get_hh("train", sanity_check=script_args.sanity_check)
|
||||
|
||||
# 3. Load evaluation dataset
|
||||
eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)
|
||||
|
||||
# 4. initialize the KTO trainer
|
||||
kto_trainer = KTOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=kto_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
|
||||
# 5. train
|
||||
kto_trainer.train()
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -12,16 +11,19 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
python examples/scripts/ppo.py \
|
||||
--log_with=wandb
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import tyro
|
||||
from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer, pipeline
|
||||
from transformers import AutoTokenizer, HfArgumentParser, pipeline
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed
|
||||
from trl.core import LengthSampler
|
||||
@ -33,42 +35,17 @@ tqdm.pandas()
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
ppo_config: PPOConfig = field(
|
||||
default_factory=lambda: PPOConfig(
|
||||
model_name="lvwerra/gpt2-imdb",
|
||||
query_dataset="imdb",
|
||||
reward_model="sentiment-analysis:lvwerra/distilbert-imdb",
|
||||
learning_rate=1.41e-5,
|
||||
log_with=None,
|
||||
mini_batch_size=128,
|
||||
batch_size=128,
|
||||
gradient_accumulation_steps=1,
|
||||
early_stopping=False,
|
||||
target_kl=6.0,
|
||||
kl_penalty="kl",
|
||||
seed=0,
|
||||
use_score_scaling=False,
|
||||
use_score_norm=False,
|
||||
score_clip=None,
|
||||
)
|
||||
)
|
||||
use_seq2seq: bool = False
|
||||
"""whether to use seq2seq models"""
|
||||
use_peft: bool = False
|
||||
"""whether to use peft"""
|
||||
peft_config: Optional[LoraConfig] = field(
|
||||
default_factory=lambda: LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=16,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
),
|
||||
)
|
||||
use_seq2seq: bool = field(default=False, metadata={"help": "whether to use seq2seq"})
|
||||
trust_remote_code: bool = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
|
||||
|
||||
# LoraConfig
|
||||
use_peft: bool = field(default=False, metadata={"help": "whether to use peft"})
|
||||
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
|
||||
lora_r: Optional[int] = field(default=16, metadata={"help": "the lora r parameter"})
|
||||
|
||||
args = tyro.cli(ScriptArguments)
|
||||
|
||||
parser = HfArgumentParser((ScriptArguments, PPOConfig))
|
||||
args, ppo_config = parser.parse_args_into_dataclasses()
|
||||
|
||||
# We then define the arguments to pass to the sentiment analysis pipeline.
|
||||
# We set `return_all_scores` to True to get the sentiment score for each token.
|
||||
@ -113,42 +90,47 @@ def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text
|
||||
|
||||
|
||||
# We retrieve the dataloader by calling the `build_dataset` function.
|
||||
dataset = build_dataset(args.ppo_config, args.ppo_config.query_dataset)
|
||||
dataset = build_dataset(ppo_config, ppo_config.query_dataset)
|
||||
|
||||
|
||||
def collator(data):
|
||||
return dict((key, [d[key] for d in data]) for key in data[0])
|
||||
return {key: [d[key] for d in data] for key in data[0]}
|
||||
|
||||
|
||||
# set seed before initializing value head for deterministic eval
|
||||
set_seed(args.ppo_config.seed)
|
||||
set_seed(ppo_config.seed)
|
||||
|
||||
# Now let's build the model, the reference model, and the tokenizer.
|
||||
if not args.use_peft:
|
||||
ref_model = trl_model_class.from_pretrained(args.ppo_config.model_name, trust_remote_code=args.trust_remote_code)
|
||||
ref_model = trl_model_class.from_pretrained(ppo_config.model_name, trust_remote_code=args.trust_remote_code)
|
||||
device_map = None
|
||||
peft_config = None
|
||||
else:
|
||||
peft_config = args.peft_config
|
||||
peft_config = LoraConfig(
|
||||
r=args.lora_r,
|
||||
lora_alpha=args.lora_alpha,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
ref_model = None
|
||||
# Copy the model to each device
|
||||
device_map = {"": Accelerator().local_process_index}
|
||||
|
||||
model = trl_model_class.from_pretrained(
|
||||
args.ppo_config.model_name,
|
||||
ppo_config.model_name,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
device_map=device_map,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.ppo_config.model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
|
||||
|
||||
# Some tokenizers like GPT-2's don't have a padding token by default, so we set one here.
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
|
||||
ppo_trainer = PPOTrainer(args.ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
|
||||
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
|
||||
|
||||
# We then build the sentiment analysis pipeline, passing the model name and the
|
||||
# sentiment analysis pipeline arguments. Let's also make sure to set the device
|
||||
@ -162,7 +144,7 @@ if ppo_trainer.accelerator.num_processes == 1:
|
||||
else:
|
||||
device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug
|
||||
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
|
||||
task, model_name = args.ppo_config.reward_model.split(":")
|
||||
task, model_name = ppo_config.reward_model.split(":")
|
||||
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
|
||||
with ds_plugin.zero3_init_context_manager(enable=False):
|
||||
sentiment_pipe = pipeline(task, model=model_name, device=device)
|
||||
@ -188,7 +170,7 @@ generation_kwargs = {
|
||||
"max_new_tokens": 32,
|
||||
}
|
||||
|
||||
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
for _epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
# Get response from gpt2
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -21,8 +20,9 @@ from peft import LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer, BitsAndBytesConfig, HfArgumentParser
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, is_xpu_available
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
|
||||
from trl.core import LengthSampler
|
||||
from trl.import_utils import is_npu_available, is_xpu_available
|
||||
|
||||
|
||||
input_min_text_length = 6
|
||||
@ -82,7 +82,7 @@ nf4_config = BitsAndBytesConfig(
|
||||
)
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
script_args.model_name,
|
||||
device_map={"": "xpu:0"} if is_xpu_available() else {"": 0},
|
||||
device_map={"": "xpu:0"} if is_xpu_available() else {"": "npu:0"} if is_npu_available else {"": 0},
|
||||
peft_config=lora_config,
|
||||
quantization_config=nf4_config,
|
||||
reward_adapter=script_args.rm_adapter,
|
||||
@ -96,7 +96,7 @@ dataset = create_and_prepare_dataset(tokenizer)
|
||||
|
||||
|
||||
def collator(data):
|
||||
return dict((key, [d[key] for d in data]) for key in data[0])
|
||||
return {key: [d[key] for d in data] for key in data[0]}
|
||||
|
||||
|
||||
config = PPOConfig(
|
||||
@ -130,7 +130,7 @@ generation_kwargs = {
|
||||
"max_new_tokens": 32,
|
||||
}
|
||||
|
||||
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
for _epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
question_tensors = batch["input_ids"]
|
||||
|
||||
response_tensors = ppo_trainer.generate(
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -12,162 +11,114 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
"""
|
||||
python examples/scripts/reward_modeling.py \
|
||||
--model_name_or_path=facebook/opt-350m \
|
||||
--output_dir="reward_modeling_anthropic_hh" \
|
||||
--per_device_train_batch_size=64 \
|
||||
--num_train_epochs=1 \
|
||||
--gradient_accumulation_steps=16 \
|
||||
--gradient_checkpointing=True \
|
||||
--learning_rate=1.41e-5 \
|
||||
--report_to="wandb" \
|
||||
--remove_unused_columns=False \
|
||||
--optim="adamw_torch" \
|
||||
--logging_steps=10 \
|
||||
--evaluation_strategy="steps" \
|
||||
--max_length=512 \
|
||||
"""
|
||||
import warnings
|
||||
|
||||
import tyro
|
||||
from accelerate import Accelerator
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
|
||||
|
||||
from trl import RewardConfig, RewardTrainer, is_xpu_available
|
||||
from trl import ModelConfig, RewardConfig, RewardTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
|
||||
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name: str = "facebook/opt-350m"
|
||||
"""the model name"""
|
||||
dataset_name: str = "Anthropic/hh-rlhf"
|
||||
"""the dataset name"""
|
||||
dataset_text_field: str = "text"
|
||||
"""the text field of the dataset"""
|
||||
eval_split: str = "none"
|
||||
"""the dataset split to evaluate on; default to 'none' (no evaluation)"""
|
||||
load_in_8bit: bool = False
|
||||
"""load the model in 8 bits precision"""
|
||||
load_in_4bit: bool = False
|
||||
"""load the model in 4 bits precision"""
|
||||
trust_remote_code: bool = True
|
||||
"""Enable `trust_remote_code`"""
|
||||
reward_config: RewardConfig = field(
|
||||
default_factory=lambda: RewardConfig(
|
||||
output_dir="output",
|
||||
per_device_train_batch_size=64,
|
||||
num_train_epochs=1,
|
||||
gradient_accumulation_steps=16,
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False},
|
||||
learning_rate=1.41e-5,
|
||||
report_to="tensorboard",
|
||||
remove_unused_columns=False,
|
||||
optim="adamw_torch",
|
||||
logging_steps=500,
|
||||
evaluation_strategy="no",
|
||||
max_length=512,
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((RewardConfig, ModelConfig))
|
||||
reward_config, model_config = parser.parse_args_into_dataclasses()
|
||||
reward_config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
|
||||
################
|
||||
# Model & Tokenizer
|
||||
################
|
||||
torch_dtype = (
|
||||
model_config.torch_dtype
|
||||
if model_config.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_config.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_config)
|
||||
model_kwargs = dict(
|
||||
revision=model_config.model_revision,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_config.model_name_or_path, num_labels=1, **model_kwargs
|
||||
)
|
||||
|
||||
if model_config.lora_task_type != "SEQ_CLS":
|
||||
warnings.warn(
|
||||
"You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs"
|
||||
" Make sure to pass --lora_task_type SEQ_CLS when using this script."
|
||||
)
|
||||
)
|
||||
use_peft: bool = False
|
||||
"""whether to use peft"""
|
||||
peft_config: Optional[LoraConfig] = field(
|
||||
default_factory=lambda: LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=16,
|
||||
bias="none",
|
||||
task_type="SEQ_CLS",
|
||||
modules_to_save=["scores"],
|
||||
),
|
||||
)
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
raw_datasets = load_dataset("Anthropic/hh-rlhf")
|
||||
# Tokenize chosen/rejected pairs of inputs
|
||||
# Adapt this section to your needs for custom datasets
|
||||
|
||||
args = tyro.cli(ScriptArguments)
|
||||
args.reward_config.evaluation_strategy = "steps" if args.eval_split != "none" else "no"
|
||||
def preprocess_function(examples):
|
||||
new_examples = {
|
||||
"input_ids_chosen": [],
|
||||
"attention_mask_chosen": [],
|
||||
"input_ids_rejected": [],
|
||||
"attention_mask_rejected": [],
|
||||
}
|
||||
for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
|
||||
tokenized_chosen = tokenizer(chosen)
|
||||
tokenized_rejected = tokenizer(rejected)
|
||||
|
||||
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
|
||||
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
|
||||
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
|
||||
new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
|
||||
|
||||
# Step 1: Load the model
|
||||
if args.load_in_8bit and args.load_in_4bit:
|
||||
raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
|
||||
elif args.load_in_8bit or args.load_in_4bit:
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit)
|
||||
# Copy the model to each device
|
||||
device_map = (
|
||||
{"": f"xpu:{Accelerator().local_process_index}"}
|
||||
if is_xpu_available()
|
||||
else {"": Accelerator().local_process_index}
|
||||
)
|
||||
else:
|
||||
device_map = None
|
||||
quantization_config = None
|
||||
return new_examples
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
args.model_name,
|
||||
quantization_config=quantization_config,
|
||||
device_map=device_map,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
num_labels=1,
|
||||
)
|
||||
|
||||
# Step 2: Load the dataset and pre-process it
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
||||
train_dataset = load_dataset(args.dataset_name, split="train")
|
||||
|
||||
|
||||
# Tokenize chosen/rejected pairs of inputs
|
||||
# Adapt this section to your needs for custom datasets
|
||||
def preprocess_function(examples):
|
||||
new_examples = {
|
||||
"input_ids_chosen": [],
|
||||
"attention_mask_chosen": [],
|
||||
"input_ids_rejected": [],
|
||||
"attention_mask_rejected": [],
|
||||
}
|
||||
for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
|
||||
tokenized_chosen = tokenizer(chosen)
|
||||
tokenized_rejected = tokenizer(rejected)
|
||||
|
||||
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
|
||||
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
|
||||
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
|
||||
new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
|
||||
|
||||
return new_examples
|
||||
|
||||
|
||||
# Preprocess the dataset and filter out examples that are longer than args.max_length
|
||||
train_dataset = train_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=4,
|
||||
)
|
||||
train_dataset = train_dataset.filter(
|
||||
lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length
|
||||
and len(x["input_ids_rejected"]) <= args.reward_config.max_length
|
||||
)
|
||||
|
||||
if args.eval_split == "none":
|
||||
eval_dataset = None
|
||||
else:
|
||||
eval_dataset = load_dataset(args.dataset_name, split=args.eval_split)
|
||||
|
||||
eval_dataset = eval_dataset.map(
|
||||
# Preprocess the dataset and filter out examples that are longer than args.max_length
|
||||
raw_datasets = raw_datasets.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=4,
|
||||
)
|
||||
eval_dataset = eval_dataset.filter(
|
||||
lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length
|
||||
and len(x["input_ids_rejected"]) <= args.reward_config.max_length
|
||||
raw_datasets = raw_datasets.filter(
|
||||
lambda x: len(x["input_ids_chosen"]) <= reward_config.max_length
|
||||
and len(x["input_ids_rejected"]) <= reward_config.max_length
|
||||
)
|
||||
train_dataset = raw_datasets["train"]
|
||||
eval_dataset = raw_datasets["test"]
|
||||
|
||||
|
||||
# Step 4: Define the LoraConfig
|
||||
if args.use_peft:
|
||||
peft_config = args.peft_config
|
||||
else:
|
||||
peft_config = None
|
||||
|
||||
# Step 5: Define the Trainer
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
args=args.reward_config,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
args=reward_config,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
peft_config=get_peft_config(model_config),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(reward_config.output_dir)
|
||||
|
@ -1,4 +1,4 @@
|
||||
# coding=utf-8
|
||||
# flake8: noqa
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -12,147 +12,140 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
"""
|
||||
# regular:
|
||||
python examples/scripts/sft.py \
|
||||
--model_name_or_path="facebook/opt-350m" \
|
||||
--report_to="wandb" \
|
||||
--learning_rate=1.41e-5 \
|
||||
--per_device_train_batch_size=64 \
|
||||
--gradient_accumulation_steps=16 \
|
||||
--output_dir="sft_openassistant-guanaco" \
|
||||
--logging_steps=1 \
|
||||
--num_train_epochs=3 \
|
||||
--max_steps=-1 \
|
||||
--push_to_hub \
|
||||
--gradient_checkpointing \
|
||||
|
||||
# peft:
|
||||
python examples/scripts/sft.py \
|
||||
--model_name_or_path="facebook/opt-350m" \
|
||||
--report_to="wandb" \
|
||||
--learning_rate=1.41e-5 \
|
||||
--per_device_train_batch_size=64 \
|
||||
--gradient_accumulation_steps=16 \
|
||||
--output_dir="sft_openassistant-guanaco" \
|
||||
--logging_steps=1 \
|
||||
--num_train_epochs=3 \
|
||||
--max_steps=-1 \
|
||||
--push_to_hub \
|
||||
--gradient_checkpointing \
|
||||
--use_peft \
|
||||
--lora_r=64 \
|
||||
--lora_alpha=16
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
|
||||
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)
|
||||
|
||||
from trl.commands.cli_utils import init_zero_verbose, SftScriptArguments, TrlParser
|
||||
|
||||
if TRL_USE_RICH:
|
||||
init_zero_verbose()
|
||||
FORMAT = "%(message)s"
|
||||
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
|
||||
|
||||
from trl import SFTTrainer, is_xpu_available
|
||||
from tqdm.rich import tqdm
|
||||
from transformers import AutoTokenizer, TrainingArguments
|
||||
|
||||
from trl import (
|
||||
ModelConfig,
|
||||
RichProgressCallback,
|
||||
SFTTrainer,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
get_kbit_device_map,
|
||||
)
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
if TRL_USE_RICH:
|
||||
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)
|
||||
|
||||
# Define and parse arguments.
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The name of the Casual LM model we wish to fine with SFTTrainer
|
||||
"""
|
||||
|
||||
model_name: Optional[str] = field(default="facebook/opt-350m", metadata={"help": "the model name"})
|
||||
dataset_name: Optional[str] = field(
|
||||
default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"}
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser((SftScriptArguments, TrainingArguments, ModelConfig))
|
||||
args, training_args, model_config = parser.parse_args_and_config()
|
||||
|
||||
# Force use our print callback
|
||||
if TRL_USE_RICH:
|
||||
training_args.disable_tqdm = True
|
||||
console = Console()
|
||||
|
||||
################
|
||||
# Model & Tokenizer
|
||||
################
|
||||
torch_dtype = (
|
||||
model_config.torch_dtype
|
||||
if model_config.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_config.torch_dtype)
|
||||
)
|
||||
dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"})
|
||||
report_to: Optional[str] = field(default="none", metadata={"help": "use 'wandb' to log with wandb"})
|
||||
learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
|
||||
batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size"})
|
||||
seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=16, metadata={"help": "the number of gradient accumulation steps"}
|
||||
quantization_config = get_quantization_config(model_config)
|
||||
model_kwargs = dict(
|
||||
revision=model_config.model_revision,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
attn_implementation=model_config.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
|
||||
load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
|
||||
use_peft: Optional[bool] = field(default=False, metadata={"help": "Wether to use PEFT or not to train adapters"})
|
||||
trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
|
||||
output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
|
||||
peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
|
||||
peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
|
||||
logging_steps: Optional[int] = field(default=1, metadata={"help": "the number of logging steps"})
|
||||
use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
|
||||
num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
|
||||
max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
|
||||
save_steps: Optional[int] = field(
|
||||
default=100, metadata={"help": "Number of updates steps before two checkpoint saves"}
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
raw_datasets = load_dataset(args.dataset_name)
|
||||
train_dataset = raw_datasets["train"]
|
||||
eval_dataset = raw_datasets["test"]
|
||||
|
||||
################
|
||||
# Optional rich context managers
|
||||
###############
|
||||
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...")
|
||||
save_context = (
|
||||
nullcontext()
|
||||
if not TRL_USE_RICH
|
||||
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
|
||||
)
|
||||
save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
|
||||
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
|
||||
gradient_checkpointing: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to use gradient checkpointing or no"}
|
||||
)
|
||||
gradient_checkpointing_kwargs: Optional[dict] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
|
||||
},
|
||||
)
|
||||
hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"})
|
||||
mixed_precision: Optional[str] = field(default="bf16", metadata={"help": "Mixed precision training"})
|
||||
target_modules: Optional[List[str]] = field(default=None, metadata={"help": "Target modules for LoRA adapters"})
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
with init_context:
|
||||
trainer = SFTTrainer(
|
||||
model=model_config.model_name_or_path,
|
||||
model_init_kwargs=model_kwargs,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
dataset_text_field=args.dataset_text_field,
|
||||
max_seq_length=args.max_seq_length,
|
||||
tokenizer=tokenizer,
|
||||
packing=args.packing,
|
||||
peft_config=get_peft_config(model_config),
|
||||
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
|
||||
)
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
trainer.train()
|
||||
|
||||
# Step 1: Load the model
|
||||
if script_args.load_in_8bit and script_args.load_in_4bit:
|
||||
raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
|
||||
elif script_args.load_in_8bit or script_args.load_in_4bit:
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
|
||||
)
|
||||
# Copy the model to each device
|
||||
device_map = (
|
||||
{"": f"xpu:{Accelerator().local_process_index}"}
|
||||
if is_xpu_available()
|
||||
else {"": Accelerator().local_process_index}
|
||||
)
|
||||
torch_dtype = torch.bfloat16
|
||||
else:
|
||||
device_map = None
|
||||
quantization_config = None
|
||||
torch_dtype = None
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
script_args.model_name,
|
||||
quantization_config=quantization_config,
|
||||
device_map=device_map,
|
||||
trust_remote_code=script_args.trust_remote_code,
|
||||
torch_dtype=torch_dtype,
|
||||
use_auth_token=script_args.use_auth_token,
|
||||
)
|
||||
|
||||
# Step 2: Load the dataset
|
||||
dataset = load_dataset(script_args.dataset_name, split="train")
|
||||
|
||||
# Step 3: Define the training arguments
|
||||
training_args = TrainingArguments(
|
||||
output_dir=script_args.output_dir,
|
||||
per_device_train_batch_size=script_args.batch_size,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
learning_rate=script_args.learning_rate,
|
||||
logging_steps=script_args.logging_steps,
|
||||
num_train_epochs=script_args.num_train_epochs,
|
||||
max_steps=script_args.max_steps,
|
||||
report_to=script_args.report_to,
|
||||
save_steps=script_args.save_steps,
|
||||
save_total_limit=script_args.save_total_limit,
|
||||
push_to_hub=script_args.push_to_hub,
|
||||
hub_model_id=script_args.hub_model_id,
|
||||
gradient_checkpointing=script_args.gradient_checkpointing,
|
||||
# TODO: uncomment that on the next release
|
||||
# gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
|
||||
)
|
||||
|
||||
# Step 4: Define the LoraConfig
|
||||
if script_args.use_peft:
|
||||
peft_config = LoraConfig(
|
||||
r=script_args.peft_lora_r,
|
||||
lora_alpha=script_args.peft_lora_alpha,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=script_args.target_modules,
|
||||
)
|
||||
else:
|
||||
peft_config = None
|
||||
|
||||
# Step 5: Define the Trainer
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
max_seq_length=script_args.seq_length,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field=script_args.dataset_text_field,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Step 6: Save the model
|
||||
trainer.save_model(script_args.output_dir)
|
||||
with save_context:
|
||||
trainer.save_model(training_args.output_dir)
|
||||
|
@ -1,16 +1,22 @@
|
||||
[tool.black]
|
||||
line-length = 119
|
||||
target-version = ['py38']
|
||||
|
||||
[tool.ruff]
|
||||
ignore = ["E501", "E741", "W605"]
|
||||
select = ["E", "F", "I", "W"]
|
||||
target-version = "py37"
|
||||
line-length = 119
|
||||
|
||||
# Ignore import violations in all `__init__.py` files.
|
||||
[tool.ruff.per-file-ignores]
|
||||
"__init__.py" = ["E402", "F401", "F403", "F811"]
|
||||
[tool.ruff.lint]
|
||||
ignore = [
|
||||
"B028", # warning without explicit stacklevel
|
||||
"C408", # dict() calls (stylistic)
|
||||
"C901", # function complexity
|
||||
"E501",
|
||||
]
|
||||
extend-select = ["E", "F", "I", "W", "UP", "B", "T", "C"]
|
||||
|
||||
[tool.ruff.isort]
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
# Allow prints in auxiliary scripts
|
||||
"benchmark/**.py" = ["T201"]
|
||||
"examples/**.py" = ["T201"]
|
||||
"scripts/**.py" = ["T201"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
lines-after-imports = 2
|
||||
known-first-party = ["trl"]
|
||||
|
140
scripts/log_example_reports.py
Normal file
140
scripts/log_example_reports.py
Normal file
@ -0,0 +1,140 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import os
|
||||
from datetime import date
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
|
||||
MAX_LEN_MESSAGE = 2900 # slack endpoint has a limit of 3001 characters
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--slack_channel_name", default="trl-push-examples-ci")
|
||||
parser.add_argument("--text_file_name", required=True)
|
||||
|
||||
|
||||
def main(text_file_name, slack_channel_name=None):
|
||||
message = ""
|
||||
|
||||
if os.path.isfile(text_file_name):
|
||||
final_results = {}
|
||||
|
||||
file = open(text_file_name)
|
||||
lines = file.readlines()
|
||||
for line in lines:
|
||||
result, config_name = line.split(",")
|
||||
config_name = config_name.split("/")[-1].split(".yaml")[0]
|
||||
final_results[config_name] = int(result)
|
||||
|
||||
no_error_payload = {
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "plain_text",
|
||||
"text": "🌞 There were no failures on the example tests!"
|
||||
if not len(final_results) == 0
|
||||
else "Something went wrong there is at least one empty file - please check GH action results.",
|
||||
"emoji": True,
|
||||
},
|
||||
}
|
||||
|
||||
total_num_failed = sum(final_results.values())
|
||||
else:
|
||||
no_error_payload = {
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "plain_text",
|
||||
"text": "🔴 Something is wrong with the workflow please check ASAP!"
|
||||
"Something went wrong there is no text file being produced. Please check ASAP.",
|
||||
"emoji": True,
|
||||
},
|
||||
}
|
||||
|
||||
total_num_failed = 0
|
||||
|
||||
test_type_name = text_file_name.replace(".txt", "").replace("temp_results_", "").replace("_", " ").title()
|
||||
|
||||
payload = [
|
||||
{
|
||||
"type": "header",
|
||||
"text": {
|
||||
"type": "plain_text",
|
||||
"text": "🤗 Results of the {} TRL {} example tests.".format(
|
||||
os.environ.get("TEST_TYPE", ""), test_type_name
|
||||
),
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
if total_num_failed > 0:
|
||||
message += f"{total_num_failed} failed tests for example tests!"
|
||||
|
||||
for test_name, failed in final_results.items():
|
||||
failed_table = tabulate(
|
||||
[[test_name, "🟢" if not failed else "🔴"]],
|
||||
headers=["Test Name", "Status"],
|
||||
showindex="always",
|
||||
tablefmt="grid",
|
||||
maxcolwidths=[12],
|
||||
)
|
||||
message += "\n```\n" + failed_table + "\n```"
|
||||
|
||||
print(f"### {message}")
|
||||
else:
|
||||
payload.append(no_error_payload)
|
||||
|
||||
if os.environ.get("TEST_TYPE", "") != "":
|
||||
from slack_sdk import WebClient
|
||||
|
||||
if len(message) > MAX_LEN_MESSAGE:
|
||||
print(f"Truncating long message from {len(message)} to {MAX_LEN_MESSAGE}")
|
||||
message = message[:MAX_LEN_MESSAGE] + "..."
|
||||
|
||||
if len(message) != 0:
|
||||
md_report = {
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": message},
|
||||
}
|
||||
payload.append(md_report)
|
||||
action_button = {
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": "*For more details:*"},
|
||||
"accessory": {
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
|
||||
"url": f"https://github.com/huggingface/trl/actions/runs/{os.environ['GITHUB_RUN_ID']}",
|
||||
},
|
||||
}
|
||||
payload.append(action_button)
|
||||
|
||||
date_report = {
|
||||
"type": "context",
|
||||
"elements": [
|
||||
{
|
||||
"type": "plain_text",
|
||||
"text": f"On Push - main {os.environ.get('TEST_TYPE')} test results for {date.today()}",
|
||||
},
|
||||
],
|
||||
}
|
||||
payload.append(date_report)
|
||||
|
||||
print(payload)
|
||||
|
||||
client = WebClient(token=os.environ.get("SLACK_API_TOKEN"))
|
||||
client.chat_postMessage(channel=f"#{slack_channel_name}", text=message, blocks=payload)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
main(args.text_file_name, args.slack_channel_name)
|
153
scripts/log_reports.py
Normal file
153
scripts/log_reports.py
Normal file
@ -0,0 +1,153 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
|
||||
MAX_LEN_MESSAGE = 2900 # slack endpoint has a limit of 3001 characters
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--slack_channel_name", default="trl-push-ci")
|
||||
|
||||
|
||||
def main(slack_channel_name=None):
|
||||
failed = []
|
||||
passed = []
|
||||
|
||||
group_info = []
|
||||
|
||||
total_num_failed = 0
|
||||
empty_file = False or len(list(Path().glob("*.log"))) == 0
|
||||
|
||||
total_empty_files = []
|
||||
|
||||
for log in Path().glob("*.log"):
|
||||
section_num_failed = 0
|
||||
i = 0
|
||||
with open(log) as f:
|
||||
for line in f:
|
||||
line = json.loads(line)
|
||||
i += 1
|
||||
if line.get("nodeid", "") != "":
|
||||
test = line["nodeid"]
|
||||
if line.get("duration", None) is not None:
|
||||
duration = f'{line["duration"]:.4f}'
|
||||
if line.get("outcome", "") == "failed":
|
||||
section_num_failed += 1
|
||||
failed.append([test, duration, log.name.split("_")[0]])
|
||||
total_num_failed += 1
|
||||
else:
|
||||
passed.append([test, duration, log.name.split("_")[0]])
|
||||
empty_file = i == 0
|
||||
group_info.append([str(log), section_num_failed, failed])
|
||||
total_empty_files.append(empty_file)
|
||||
os.remove(log)
|
||||
failed = []
|
||||
no_error_payload = {
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "plain_text",
|
||||
"text": "🌞 There were no failures!"
|
||||
if not any(total_empty_files)
|
||||
else "Something went wrong there is at least one empty file - please check GH action results.",
|
||||
"emoji": True,
|
||||
},
|
||||
}
|
||||
|
||||
message = ""
|
||||
payload = [
|
||||
{
|
||||
"type": "header",
|
||||
"text": {
|
||||
"type": "plain_text",
|
||||
"text": "🤗 Results of the {} TRL tests.".format(os.environ.get("TEST_TYPE", "")),
|
||||
},
|
||||
},
|
||||
]
|
||||
if total_num_failed > 0:
|
||||
for i, (name, num_failed, failed_tests) in enumerate(group_info):
|
||||
if num_failed > 0:
|
||||
if num_failed == 1:
|
||||
message += f"*{name}: {num_failed} failed test*\n"
|
||||
else:
|
||||
message += f"*{name}: {num_failed} failed tests*\n"
|
||||
failed_table = []
|
||||
for test in failed_tests:
|
||||
failed_report = test[0].split("::")
|
||||
# Truncate the last string as some test names might be long
|
||||
failed_report[-1] = failed_report[-1][:30] + ".."
|
||||
failed_table.append(failed_report)
|
||||
failed_table = tabulate(
|
||||
failed_table,
|
||||
headers=["Test Location", "Test Case", "Test Name"],
|
||||
showindex="always",
|
||||
tablefmt="grid",
|
||||
maxcolwidths=[12, 12, 12],
|
||||
)
|
||||
message += "\n```\n" + failed_table + "\n```"
|
||||
|
||||
if total_empty_files[i]:
|
||||
message += f"\n*{name}: Warning! Empty file - please check the GitHub action job *\n"
|
||||
print(f"### {message}")
|
||||
else:
|
||||
payload.append(no_error_payload)
|
||||
|
||||
if os.environ.get("TEST_TYPE", "") != "":
|
||||
from slack_sdk import WebClient
|
||||
|
||||
if len(message) > MAX_LEN_MESSAGE:
|
||||
message = f"There are {total_num_failed} failed tests in total ! Cannot display the entire summary - please check the action results directly"
|
||||
|
||||
if len(message) != 0:
|
||||
md_report = {
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": message},
|
||||
}
|
||||
payload.append(md_report)
|
||||
action_button = {
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": "*For more details:*"},
|
||||
"accessory": {
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
|
||||
"url": f"https://github.com/huggingface/trl/actions/runs/{os.environ['GITHUB_RUN_ID']}",
|
||||
},
|
||||
}
|
||||
payload.append(action_button)
|
||||
|
||||
date_report = {
|
||||
"type": "context",
|
||||
"elements": [
|
||||
{
|
||||
"type": "plain_text",
|
||||
"text": f"On Push main {os.environ.get('TEST_TYPE')} test results for {date.today()}",
|
||||
},
|
||||
],
|
||||
}
|
||||
payload.append(date_report)
|
||||
|
||||
print(payload)
|
||||
|
||||
client = WebClient(token=os.environ.get("SLACK_API_TOKEN"))
|
||||
client.chat_postMessage(channel=f"#{slack_channel_name}", text=message, blocks=payload)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
main(args.slack_channel_name)
|
@ -35,7 +35,7 @@ def main():
|
||||
open_issues = repo.get_issues(state="open")
|
||||
|
||||
for issue in open_issues:
|
||||
comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True)
|
||||
comments = sorted(issue.get_comments(), key=lambda i: i.created_at, reverse=True)
|
||||
last_comment = comments[0] if len(comments) > 0 else None
|
||||
if (
|
||||
last_comment is not None
|
||||
|
11
setup.cfg
11
setup.cfg
@ -1,11 +1,2 @@
|
||||
[metadata]
|
||||
license_file = LICENSE
|
||||
|
||||
[isort]
|
||||
ensure_newline_before_comments = True
|
||||
force_grid_wrap = 0
|
||||
include_trailing_comma = True
|
||||
line_length = 119
|
||||
lines_after_imports = 2
|
||||
multi_line_output = 3
|
||||
use_parentheses = True
|
||||
license_file = LICENSE
|
77
setup.py
77
setup.py
@ -53,11 +53,12 @@ To create the package for pypi.
|
||||
8. Change the version in __init__.py and setup.py to X.X.X+1.dev0 (e.g. VERSION=1.18.3 -> 1.18.4.dev0).
|
||||
Then push the change with a message 'set dev version'
|
||||
"""
|
||||
import os
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
__version__ = "0.7.5.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
__version__ = "0.8.0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
|
||||
REQUIRED_PKGS = [
|
||||
"torch>=1.4.0",
|
||||
@ -68,7 +69,7 @@ REQUIRED_PKGS = [
|
||||
"tyro>=0.5.11",
|
||||
]
|
||||
EXTRAS = {
|
||||
"test": ["parameterized", "pytest", "pytest-xdist", "accelerate"],
|
||||
"test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "pytest-cov", "pytest-xdist"],
|
||||
"peft": ["peft>=0.4.0"],
|
||||
"diffusers": ["diffusers>=0.18.0"],
|
||||
"deepspeed": ["deepspeed>=0.9.5"],
|
||||
@ -79,34 +80,44 @@ EXTRAS["dev"] = []
|
||||
for reqs in EXTRAS.values():
|
||||
EXTRAS["dev"].extend(reqs)
|
||||
|
||||
setup(
|
||||
name="trl",
|
||||
license="Apache 2.0",
|
||||
classifiers=[
|
||||
"Development Status :: 2 - Pre-Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Natural Language :: English",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
],
|
||||
url="https://github.com/huggingface/trl",
|
||||
packages=find_packages(),
|
||||
include_package_data=True,
|
||||
install_requires=REQUIRED_PKGS,
|
||||
extras_require=EXTRAS,
|
||||
python_requires=">=3.7",
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
zip_safe=False,
|
||||
version=__version__,
|
||||
description="Train transformer language models with reinforcement learning.",
|
||||
keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf",
|
||||
author="Leandro von Werra",
|
||||
author_email="leandro.vonwerra@gmail.com",
|
||||
)
|
||||
try:
|
||||
file_path = os.path.dirname(os.path.abspath(__file__))
|
||||
os.symlink(os.path.join(file_path, "examples/scripts"), os.path.join(file_path, "trl/commands/scripts"))
|
||||
|
||||
setup(
|
||||
name="trl",
|
||||
license="Apache 2.0",
|
||||
classifiers=[
|
||||
"Development Status :: 2 - Pre-Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Natural Language :: English",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
],
|
||||
url="https://github.com/huggingface/trl",
|
||||
entry_points={
|
||||
"console_scripts": ["trl=trl.commands.cli:main"],
|
||||
},
|
||||
include_package_data=True,
|
||||
package_data={"trl": ["commands/scripts/config/*", "commands/scripts/*"]},
|
||||
packages=find_packages(),
|
||||
install_requires=REQUIRED_PKGS,
|
||||
extras_require=EXTRAS,
|
||||
python_requires=">=3.7",
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
zip_safe=False,
|
||||
version=__version__,
|
||||
description="Train transformer language models with reinforcement learning.",
|
||||
keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf",
|
||||
author="Leandro von Werra",
|
||||
author_email="leandro.vonwerra@gmail.com",
|
||||
)
|
||||
finally:
|
||||
os.unlink(os.path.join(file_path, "trl/commands/scripts"))
|
||||
|
0
tests/slow/__init__.py
Normal file
0
tests/slow/__init__.py
Normal file
221
tests/slow/test_dpo_slow.py
Normal file
221
tests/slow/test_dpo_slow.py
Normal file
@ -0,0 +1,221 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import gc
|
||||
import itertools
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from accelerate.utils.memory import release_memory
|
||||
from datasets import load_dataset
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
|
||||
|
||||
from trl import DPOTrainer, is_peft_available
|
||||
|
||||
from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu
|
||||
from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import LoraConfig, PeftModel
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class DPOTrainerSlowTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.dataset = load_dataset("trl-internal-testing/mlabonne-chatml-dpo-pairs-copy", split="train[:10%]")
|
||||
cls.peft_config = LoraConfig(
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
r=8,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
cls.max_length = 128
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS)))
|
||||
def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits):
|
||||
"""
|
||||
A test that tests the simple usage of `DPOTrainer` using a bare model in full precision.
|
||||
"""
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=2,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=2,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
fp16=True,
|
||||
logging_strategy="no",
|
||||
report_to="none",
|
||||
)
|
||||
|
||||
# dpo train lora model
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
ref_model=None,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=self.dataset,
|
||||
eval_dataset=self.dataset,
|
||||
loss_type=loss_type,
|
||||
precompute_ref_log_probs=pre_compute_logits,
|
||||
max_length=self.max_length,
|
||||
)
|
||||
|
||||
# train the model
|
||||
trainer.train()
|
||||
|
||||
# save trained model or adapter
|
||||
trainer.save_model()
|
||||
|
||||
release_memory(model, trainer)
|
||||
|
||||
@parameterized.expand(
|
||||
list(
|
||||
itertools.product(
|
||||
MODELS_TO_TEST,
|
||||
DPO_LOSS_TYPES,
|
||||
DPO_PRECOMPUTE_LOGITS,
|
||||
GRADIENT_CHECKPOINTING_KWARGS,
|
||||
)
|
||||
)
|
||||
)
|
||||
@require_peft
|
||||
def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_checkpointing_kwargs):
|
||||
"""
|
||||
A test that tests the simple usage of `DPOTrainer` using a peft model in full precision + different scenarios of gradient checkpointing.
|
||||
"""
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=2,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=2,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
fp16=True,
|
||||
logging_strategy="no",
|
||||
report_to="none",
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
|
||||
)
|
||||
|
||||
# dpo train lora model
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
ref_model=None,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=self.dataset,
|
||||
eval_dataset=self.dataset,
|
||||
generate_during_eval=False,
|
||||
loss_type=loss_type,
|
||||
precompute_ref_log_probs=pre_compute_logits,
|
||||
peft_config=self.peft_config,
|
||||
max_length=self.max_length,
|
||||
)
|
||||
|
||||
assert isinstance(trainer.model, PeftModel)
|
||||
assert trainer.ref_model is None
|
||||
|
||||
# train the model
|
||||
trainer.train()
|
||||
|
||||
# save trained model or adapter
|
||||
trainer.save_model()
|
||||
|
||||
release_memory(model, trainer)
|
||||
|
||||
@parameterized.expand(
|
||||
list(
|
||||
itertools.product(
|
||||
MODELS_TO_TEST,
|
||||
DPO_LOSS_TYPES,
|
||||
DPO_PRECOMPUTE_LOGITS,
|
||||
GRADIENT_CHECKPOINTING_KWARGS,
|
||||
)
|
||||
)
|
||||
)
|
||||
@require_bitsandbytes
|
||||
@require_peft
|
||||
def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gradient_checkpointing_kwargs):
|
||||
"""
|
||||
A test that tests the simple usage of `DPOTrainer` using QLoRA + different scenarios of gradient checkpointing.
|
||||
"""
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=2,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=2,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
fp16=True,
|
||||
logging_strategy="no",
|
||||
report_to="none",
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
|
||||
)
|
||||
|
||||
# dpo train lora model
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
ref_model=None,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=self.dataset,
|
||||
eval_dataset=self.dataset,
|
||||
generate_during_eval=False,
|
||||
loss_type=loss_type,
|
||||
precompute_ref_log_probs=pre_compute_logits,
|
||||
peft_config=self.peft_config,
|
||||
max_length=self.max_length,
|
||||
)
|
||||
|
||||
assert isinstance(trainer.model, PeftModel)
|
||||
assert trainer.ref_model is None
|
||||
|
||||
# train the model
|
||||
trainer.train()
|
||||
|
||||
# save trained model or adapter
|
||||
trainer.save_model()
|
||||
|
||||
release_memory(model, trainer)
|
390
tests/slow/test_sft_slow.py
Normal file
390
tests/slow/test_sft_slow.py
Normal file
@ -0,0 +1,390 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import gc
|
||||
import itertools
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from accelerate.utils.memory import release_memory
|
||||
from datasets import load_dataset
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
|
||||
|
||||
from trl import SFTTrainer, is_peft_available
|
||||
from trl.models.utils import setup_chat_format
|
||||
|
||||
from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu, require_torch_multi_gpu
|
||||
from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import LoraConfig, PeftModel
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class SFTTrainerSlowTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.train_dataset = load_dataset("imdb", split="train[:10%]")
|
||||
cls.eval_dataset = load_dataset("imdb", split="test[:10%]")
|
||||
cls.dataset_text_field = "text"
|
||||
cls.max_seq_length = 128
|
||||
cls.peft_config = LoraConfig(
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
r=8,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
|
||||
def test_sft_trainer_str(self, model_name, packing):
|
||||
"""
|
||||
Simply tests if passing a simple str to `SFTTrainer` loads and runs the trainer
|
||||
as expected.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
logging_strategy="no",
|
||||
report_to="none",
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=10,
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model_name,
|
||||
args=args,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
packing=packing,
|
||||
dataset_text_field=self.dataset_text_field,
|
||||
max_seq_length=self.max_seq_length,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
|
||||
def test_sft_trainer_transformers(self, model_name, packing):
|
||||
"""
|
||||
Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer
|
||||
as expected.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
logging_strategy="no",
|
||||
report_to="none",
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=10,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args=args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
packing=packing,
|
||||
dataset_text_field=self.dataset_text_field,
|
||||
max_seq_length=self.max_seq_length,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
release_memory(model, trainer)
|
||||
|
||||
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
|
||||
@require_peft
|
||||
def test_sft_trainer_peft(self, model_name, packing):
|
||||
"""
|
||||
Simply tests if passing a transformers model + peft config to `SFTTrainer` loads and runs the trainer
|
||||
as expected.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
logging_strategy="no",
|
||||
report_to="none",
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=10,
|
||||
fp16=True,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args=args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
packing=packing,
|
||||
dataset_text_field=self.dataset_text_field,
|
||||
max_seq_length=self.max_seq_length,
|
||||
peft_config=self.peft_config,
|
||||
)
|
||||
|
||||
assert isinstance(trainer.model, PeftModel)
|
||||
|
||||
trainer.train()
|
||||
|
||||
release_memory(model, trainer)
|
||||
|
||||
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
|
||||
def test_sft_trainer_transformers_mp(self, model_name, packing):
|
||||
"""
|
||||
Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer
|
||||
as expected in mixed precision.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
logging_strategy="no",
|
||||
report_to="none",
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=10,
|
||||
fp16=True, # this is sufficient to enable amp
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args=args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
packing=packing,
|
||||
dataset_text_field=self.dataset_text_field,
|
||||
max_seq_length=self.max_seq_length,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
release_memory(model, trainer)
|
||||
|
||||
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS)))
|
||||
def test_sft_trainer_transformers_mp_gc(self, model_name, packing, gradient_checkpointing_kwargs):
|
||||
"""
|
||||
Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer
|
||||
as expected in mixed precision + different scenarios of gradient_checkpointing.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
logging_strategy="no",
|
||||
report_to="none",
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=10,
|
||||
fp16=True, # this is sufficient to enable amp
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args=args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
packing=packing,
|
||||
dataset_text_field=self.dataset_text_field,
|
||||
max_seq_length=self.max_seq_length,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
release_memory(model, trainer)
|
||||
|
||||
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS)))
|
||||
@require_peft
|
||||
def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient_checkpointing_kwargs):
|
||||
"""
|
||||
Simply tests if passing a transformers model + PEFT to `SFTTrainer` loads and runs the trainer
|
||||
as expected in mixed precision + different scenarios of gradient_checkpointing.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
logging_strategy="no",
|
||||
report_to="none",
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=10,
|
||||
fp16=True, # this is sufficient to enable amp
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args=args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
packing=packing,
|
||||
dataset_text_field=self.dataset_text_field,
|
||||
max_seq_length=self.max_seq_length,
|
||||
peft_config=self.peft_config,
|
||||
)
|
||||
|
||||
assert isinstance(trainer.model, PeftModel)
|
||||
|
||||
trainer.train()
|
||||
|
||||
release_memory(model, trainer)
|
||||
|
||||
@parameterized.expand(
|
||||
list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, DEVICE_MAP_OPTIONS))
|
||||
)
|
||||
@require_torch_multi_gpu
|
||||
def test_sft_trainer_transformers_mp_gc_device_map(
|
||||
self, model_name, packing, gradient_checkpointing_kwargs, device_map
|
||||
):
|
||||
"""
|
||||
Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer
|
||||
as expected in mixed precision + different scenarios of gradient_checkpointing (single, multi-gpu, etc).
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
logging_strategy="no",
|
||||
report_to="none",
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=10,
|
||||
fp16=True, # this is sufficient to enable amp
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args=args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
packing=packing,
|
||||
dataset_text_field=self.dataset_text_field,
|
||||
max_seq_length=self.max_seq_length,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
release_memory(model, trainer)
|
||||
|
||||
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS)))
|
||||
@require_peft
|
||||
@require_bitsandbytes
|
||||
def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gradient_checkpointing_kwargs):
|
||||
"""
|
||||
Simply tests if passing a transformers model + PEFT + bnb to `SFTTrainer` loads and runs the trainer
|
||||
as expected in mixed precision + different scenarios of gradient_checkpointing.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
logging_strategy="no",
|
||||
report_to="none",
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=10,
|
||||
fp16=True, # this is sufficient to enable amp
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
|
||||
)
|
||||
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args=args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
packing=packing,
|
||||
dataset_text_field=self.dataset_text_field,
|
||||
max_seq_length=self.max_seq_length,
|
||||
peft_config=self.peft_config,
|
||||
)
|
||||
|
||||
assert isinstance(trainer.model, PeftModel)
|
||||
|
||||
trainer.train()
|
||||
|
||||
release_memory(model, trainer)
|
||||
|
||||
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
|
||||
@require_peft
|
||||
@require_bitsandbytes
|
||||
def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
|
||||
"""
|
||||
Simply tests if using setup_chat_format with a transformers model + peft + bnb config to `SFTTrainer` loads and runs the trainer
|
||||
as expected.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
train_dataset = load_dataset("trl-internal-testing/dolly-chatml-sft", split="train")
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
logging_strategy="no",
|
||||
report_to="none",
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=10,
|
||||
fp16=True,
|
||||
)
|
||||
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
model, tokenizer = setup_chat_format(model, tokenizer)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args=args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=train_dataset,
|
||||
packing=packing,
|
||||
max_seq_length=self.max_seq_length,
|
||||
peft_config=self.peft_config,
|
||||
)
|
||||
|
||||
assert isinstance(trainer.model, PeftModel)
|
||||
|
||||
trainer.train()
|
||||
|
||||
release_memory(model, trainer)
|
27
tests/slow/testing_constants.py
Normal file
27
tests/slow/testing_constants.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO: push them under trl-org
|
||||
MODELS_TO_TEST = [
|
||||
"HuggingFaceM4/tiny-random-LlamaForCausalLM",
|
||||
"HuggingFaceM4/tiny-random-MistralForCausalLM",
|
||||
]
|
||||
|
||||
# We could have also not declared these variables but let's be verbose
|
||||
PACKING_OPTIONS = [True, False]
|
||||
GRADIENT_CHECKPOINTING_KWARGS = [None, {"use_reentrant": False}, {"use_reentrant": True}]
|
||||
DEVICE_MAP_OPTIONS = [{"": 0}, "auto"]
|
||||
|
||||
DPO_LOSS_TYPES = ["sigmoid", "ipo", "kto_pair"]
|
||||
DPO_PRECOMPUTE_LOGITS = [True, False]
|
@ -59,7 +59,7 @@ class BestOfNSamplerTester(unittest.TestCase):
|
||||
|
||||
for q, expected_length in various_queries_formats:
|
||||
results = best_of_n.generate(q)
|
||||
self.assertIsInstance(results, list)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == expected_length
|
||||
|
||||
def test_different_sample_sizes_and_n_candidates_values(self):
|
||||
|
40
tests/test_cli.py
Normal file
40
tests/test_cli.py
Normal file
@ -0,0 +1,40 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import subprocess
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
|
||||
def test_sft_cli():
|
||||
try:
|
||||
subprocess.run(
|
||||
"trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name imdb --learning_rate 1e-4 --lr_scheduler_type cosine",
|
||||
shell=True,
|
||||
check=True,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise AssertionError("An error occured while running the CLI, please double check") from exc
|
||||
|
||||
|
||||
@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
|
||||
def test_dpo_cli():
|
||||
try:
|
||||
subprocess.run(
|
||||
"trl dpo --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --learning_rate 1e-4 --lr_scheduler_type cosine",
|
||||
shell=True,
|
||||
check=True,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise AssertionError("An error occured while running the CLI, please double check") from exc
|
@ -30,13 +30,13 @@ class CoreTester(unittest.TestCase):
|
||||
cls.test_input_unmasked = cls.test_input[1:3]
|
||||
|
||||
def test_masked_mean(self):
|
||||
self.assertEqual(torch.mean(self.test_input_unmasked), masked_mean(self.test_input, self.test_mask))
|
||||
assert torch.mean(self.test_input_unmasked) == masked_mean(self.test_input, self.test_mask)
|
||||
|
||||
def test_masked_var(self):
|
||||
self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask))
|
||||
assert torch.var(self.test_input_unmasked) == masked_var(self.test_input, self.test_mask)
|
||||
|
||||
def test_masked_whiten(self):
|
||||
whiten_unmasked = whiten(self.test_input_unmasked)
|
||||
whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3]
|
||||
diffs = (whiten_unmasked - whiten_masked).sum()
|
||||
self.assertAlmostEqual(diffs, 0)
|
||||
assert abs(diffs.item()) < 0.00001
|
||||
|
@ -31,18 +31,21 @@ class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
|
||||
self.instruction_template = "\n### User:"
|
||||
self.response_template = "\n### Assistant:"
|
||||
|
||||
# GPT2Tokenizer: [198, 21017, 11787, 25] -> [11787, 25]
|
||||
# GPT2Tokenizer: [198, 21017, 11787, 25] -> [21017, 11787, 25]
|
||||
# Llama2Tokenizer: [29871, 13, 2277, 29937, 4911, 29901] -> [2277, 29937, 4911, 29901]
|
||||
# Note: If this test is ever switched to Llama2Tokenizer, this should be double checked,
|
||||
# and possibly switched back to [2:] instead of [1:].
|
||||
# With GPT2Tokenizer, [1:] is correct - we want the 21017 token included, which is ###.
|
||||
self.tokenized_instruction_w_context = self.tokenizer.encode(
|
||||
self.instruction_template, add_special_tokens=False
|
||||
)[2:]
|
||||
)[1:]
|
||||
|
||||
# GPT2Tokenizer: [198, 21017, 15286, 25] -> [15286, 25]
|
||||
# Llama2Tokenizer: [29871, 13, 2277, 29937, 4007, 22137, 29901] -> [2277, 29937, 4007, 22137, 29901]
|
||||
self.tokenized_response_w_context = self.tokenizer.encode(self.response_template, add_special_tokens=False)[2:]
|
||||
|
||||
# Plain check on string
|
||||
self.assertIn(self.response_template, self.instruction)
|
||||
assert self.response_template in self.instruction
|
||||
self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False)
|
||||
|
||||
# Test the fix for #598
|
||||
@ -57,6 +60,28 @@ class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
|
||||
)
|
||||
self.collator.torch_call([self.tokenized_instruction])
|
||||
|
||||
# Test for PR #1185
|
||||
# We pass in a string where the first user template is different than the rest.
|
||||
# Usually this would happen due to context-sensitive tokenization, but here we
|
||||
# explicitly change the template to test the fix.
|
||||
self.instruction = """## User: First instruction
|
||||
|
||||
### Assistant: First response
|
||||
|
||||
### User: Second instruction
|
||||
|
||||
### Assistant: Second response"""
|
||||
self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False)
|
||||
self.collator = DataCollatorForCompletionOnlyLM(
|
||||
self.tokenized_response_w_context, self.tokenized_instruction_w_context, tokenizer=self.tokenizer
|
||||
)
|
||||
collator_output = self.collator.torch_call([self.tokenized_instruction])
|
||||
collator_text = self.tokenizer.decode(
|
||||
collator_output["labels"][torch.where(collator_output["labels"] != -100)]
|
||||
)
|
||||
expected_text = " First response\n\n Second response" ""
|
||||
assert collator_text == expected_text
|
||||
|
||||
def test_data_collator_handling_of_long_sequences(self):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
|
||||
self.instruction = """### System: You are a helpful assistant.
|
||||
@ -69,7 +94,7 @@ class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
|
||||
self.collator = DataCollatorForCompletionOnlyLM(self.response_template, tokenizer=self.tokenizer)
|
||||
encoded_instance = self.collator.torch_call([self.tokenized_instruction])
|
||||
result = torch.all(encoded_instance["labels"] == -100)
|
||||
self.assertTrue(result, "Not all values in the tensor are -100.")
|
||||
assert result, "Not all values in the tensor are -100."
|
||||
|
||||
# check DataCollatorForCompletionOnlyLM using response template and instruction template
|
||||
self.instruction_template = "\n### User:"
|
||||
@ -78,4 +103,4 @@ class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
|
||||
)
|
||||
encoded_instance = self.collator.torch_call([self.tokenized_instruction])
|
||||
result = torch.all(encoded_instance["labels"] == -100)
|
||||
self.assertTrue(result, "Not all values in the tensor are -100.")
|
||||
assert result, "Not all values in the tensor are -100."
|
||||
|
142
tests/test_dataset_formatting.py
Normal file
142
tests/test_dataset_formatting.py
Normal file
@ -0,0 +1,142 @@
|
||||
import unittest
|
||||
from typing import Callable
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from trl.extras.dataset_formatting import get_formatting_func_from_dataset
|
||||
from trl.models.utils import ChatMlSpecialTokens, setup_chat_format
|
||||
|
||||
|
||||
class DatasetFormattingTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.llama_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
self.chatml_tokenizer = AutoTokenizer.from_pretrained("philschmid/gpt2-chatml-tokenizer")
|
||||
|
||||
def test_get_formatting_func_from_dataset_with_chatml_messages(self):
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"messages": [
|
||||
[
|
||||
{"role": "system", "content": "You are helpful"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi, how can I help you?"},
|
||||
]
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Llama tokenizer
|
||||
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
|
||||
assert isinstance(formatting_func, Callable)
|
||||
formatted_text = formatting_func(dataset[0])
|
||||
expected = "<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>"
|
||||
assert formatted_text == expected
|
||||
formatted_text = formatting_func(dataset[0:1])
|
||||
assert formatted_text == [expected]
|
||||
|
||||
# ChatML tokenizer
|
||||
formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
|
||||
formatted_text = formatting_func(dataset[0])
|
||||
expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
|
||||
assert formatted_text == expected
|
||||
formatted_text = formatting_func(dataset[0:1])
|
||||
assert formatted_text == [expected]
|
||||
|
||||
def test_get_formatting_func_from_dataset_with_chatml_conversations(self):
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"conversations": [
|
||||
[
|
||||
{"role": "system", "content": "You are helpful"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi, how can I help you?"},
|
||||
]
|
||||
]
|
||||
}
|
||||
)
|
||||
# Llama tokenizer
|
||||
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
|
||||
assert isinstance(formatting_func, Callable)
|
||||
formatted_text = formatting_func(dataset[0])
|
||||
expected = "<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>"
|
||||
assert formatted_text == expected
|
||||
formatted_text = formatting_func(dataset[0:1])
|
||||
assert formatted_text == [expected]
|
||||
|
||||
# ChatML tokenizer
|
||||
formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
|
||||
formatted_text = formatting_func(dataset[0])
|
||||
expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
|
||||
assert formatted_text == expected
|
||||
formatted_text = formatting_func(dataset[0:1])
|
||||
assert formatted_text == [expected]
|
||||
|
||||
def test_get_formatting_func_from_dataset_with_instruction(self):
|
||||
dataset = Dataset.from_list(
|
||||
[{"prompt": "What is 2+2?", "completion": "4"}, {"prompt": "What is 3+3?", "completion": "6"}]
|
||||
)
|
||||
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
|
||||
assert formatting_func is not None
|
||||
assert isinstance(formatting_func, Callable)
|
||||
formatted_text = formatting_func(dataset[0])
|
||||
assert formatted_text == "<s>[INST] What is 2+2? [/INST] 4 </s>"
|
||||
formatted_text = formatting_func(dataset[0:1])
|
||||
assert formatted_text == ["<s>[INST] What is 2+2? [/INST] 4 </s>"]
|
||||
|
||||
def test_get_formatting_func_from_dataset_from_hub(self):
|
||||
ds_1 = load_dataset("philschmid/trl-test-instruction", split="train")
|
||||
ds_2 = load_dataset("philschmid/dolly-15k-oai-style", split="train")
|
||||
for ds in [ds_1, ds_2]:
|
||||
formatting_func = get_formatting_func_from_dataset(ds, self.llama_tokenizer)
|
||||
assert formatting_func is not None
|
||||
assert isinstance(formatting_func, Callable)
|
||||
ds_3 = load_dataset("philschmid/guanaco-sharegpt-style", split="train")
|
||||
formatting_func = get_formatting_func_from_dataset(ds_3, self.llama_tokenizer)
|
||||
assert formatting_func is None
|
||||
|
||||
def test_get_formatting_func_from_dataset_with_unknown_format(self):
|
||||
dataset = Dataset.from_dict({"text": "test"})
|
||||
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
|
||||
assert formatting_func is None
|
||||
|
||||
|
||||
class SetupChatFormatTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
|
||||
|
||||
def test_setup_chat_format(self):
|
||||
original_tokenizer_len = len(self.tokenizer)
|
||||
modified_model, modified_tokenizer = setup_chat_format(
|
||||
self.model, self.tokenizer, format="chatml", resize_to_multiple_of=64
|
||||
)
|
||||
|
||||
_chatml = ChatMlSpecialTokens()
|
||||
# Check if special tokens are correctly set
|
||||
assert modified_tokenizer.eos_token == "<|im_end|>"
|
||||
assert modified_tokenizer.pad_token == "<|im_end|>"
|
||||
assert modified_tokenizer.bos_token == "<|im_start|>"
|
||||
assert modified_tokenizer.eos_token == _chatml.eos_token
|
||||
assert modified_tokenizer.pad_token == _chatml.pad_token
|
||||
assert modified_tokenizer.bos_token == _chatml.bos_token
|
||||
assert len(modified_tokenizer) == (original_tokenizer_len + 2)
|
||||
assert (self.model.get_input_embeddings().weight.shape[0] % 64) == 0
|
||||
assert self.model.get_input_embeddings().weight.shape[0] == (original_tokenizer_len + 64)
|
||||
|
||||
def test_example_with_setup_model(self):
|
||||
modified_model, modified_tokenizer = setup_chat_format(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi, how can I help you?"},
|
||||
]
|
||||
prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
assert (
|
||||
prompt
|
||||
== "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
|
||||
)
|
@ -16,12 +16,12 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from trl import is_diffusers_available
|
||||
from trl import is_diffusers_available, is_peft_available
|
||||
|
||||
from .testing_utils import require_diffusers
|
||||
|
||||
|
||||
if is_diffusers_available():
|
||||
if is_diffusers_available() and is_peft_available():
|
||||
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
|
||||
|
||||
|
||||
@ -68,13 +68,13 @@ class DDPOTrainerTester(unittest.TestCase):
|
||||
clip_range = 0.0001
|
||||
ratio = torch.tensor([1.0])
|
||||
loss = self.trainer.loss(advantage, clip_range, ratio)
|
||||
self.assertEqual(loss.item(), 1.0)
|
||||
assert loss.item() == 1.0
|
||||
|
||||
def test_generate_samples(self):
|
||||
samples, output_pairs = self.trainer._generate_samples(1, 2)
|
||||
self.assertEqual(len(samples), 1)
|
||||
self.assertEqual(len(output_pairs), 1)
|
||||
self.assertEqual(len(output_pairs[0][0]), 2)
|
||||
assert len(samples) == 1
|
||||
assert len(output_pairs) == 1
|
||||
assert len(output_pairs[0][0]) == 2
|
||||
|
||||
def test_calculate_loss(self):
|
||||
samples, _ = self.trainer._generate_samples(1, 2)
|
||||
@ -87,13 +87,41 @@ class DDPOTrainerTester(unittest.TestCase):
|
||||
prompt_embeds = sample["prompt_embeds"]
|
||||
advantage = torch.tensor([1.0], device=prompt_embeds.device)
|
||||
|
||||
self.assertEqual(latents.shape, (1, 4, 64, 64))
|
||||
self.assertEqual(next_latents.shape, (1, 4, 64, 64))
|
||||
self.assertEqual(log_probs.shape, (1,))
|
||||
self.assertEqual(timesteps.shape, (1,))
|
||||
self.assertEqual(prompt_embeds.shape, (2, 77, 32))
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
assert next_latents.shape == (1, 4, 64, 64)
|
||||
assert log_probs.shape == (1,)
|
||||
assert timesteps.shape == (1,)
|
||||
assert prompt_embeds.shape == (2, 77, 32)
|
||||
loss, approx_kl, clipfrac = self.trainer.calculate_loss(
|
||||
latents, timesteps, next_latents, log_probs, advantage, prompt_embeds
|
||||
)
|
||||
|
||||
self.assertTrue(torch.isfinite(loss.cpu()))
|
||||
assert torch.isfinite(loss.cpu())
|
||||
|
||||
|
||||
@require_diffusers
|
||||
class DDPOTrainerWithLoRATester(DDPOTrainerTester):
|
||||
"""
|
||||
Test the DDPOTrainer class.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.ddpo_config = DDPOConfig(
|
||||
num_epochs=2,
|
||||
train_gradient_accumulation_steps=1,
|
||||
per_prompt_stat_tracking_buffer_size=32,
|
||||
sample_num_batches_per_epoch=2,
|
||||
sample_batch_size=2,
|
||||
mixed_precision=None,
|
||||
save_freq=1000000,
|
||||
)
|
||||
pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch"
|
||||
pretrained_revision = "main"
|
||||
|
||||
pipeline = DefaultDDPOStableDiffusionPipeline(
|
||||
pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=True
|
||||
)
|
||||
|
||||
self.trainer = DDPOTrainer(self.ddpo_config, scorer_function, prompt_function, pipeline)
|
||||
|
||||
return super().setUp()
|
||||
|
@ -22,7 +22,7 @@ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokeni
|
||||
|
||||
from trl import DPOTrainer
|
||||
|
||||
from .testing_utils import require_no_wandb, require_peft
|
||||
from .testing_utils import require_bitsandbytes, require_no_wandb, require_peft
|
||||
|
||||
|
||||
class DPOTrainerTester(unittest.TestCase):
|
||||
@ -129,14 +129,14 @@ class DPOTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
self.assertFalse(torch.equal(param, new_param))
|
||||
assert not torch.equal(param, new_param)
|
||||
|
||||
def test_dpo_trainer_without_providing_ref_model(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@ -167,14 +167,14 @@ class DPOTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
self.assertFalse(torch.equal(param, new_param))
|
||||
assert not torch.equal(param, new_param)
|
||||
|
||||
@require_peft
|
||||
@mark.peft_test
|
||||
@ -218,7 +218,7 @@ class DPOTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
@ -226,7 +226,78 @@ class DPOTrainerTester(unittest.TestCase):
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
self.assertFalse(torch.equal(param, new_param))
|
||||
assert not torch.equal(param, new_param)
|
||||
|
||||
def test_dpo_trainer_padding_token_is_none(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=1,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
||||
tokenizer.pad_token = None
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
expected_regex=r"Padding is enabled, but the tokenizer is not configured with a padding token."
|
||||
r" Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\)"
|
||||
r" before calling the trainer.",
|
||||
):
|
||||
trainer = DPOTrainer(
|
||||
model=self.model,
|
||||
ref_model=None,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
def test_dpo_trainer_w_dataset_num_proc(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=1,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
||||
tokenizer.pad_token = None
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
expected_regex=r"Padding is enabled, but the tokenizer is not configured with a padding token."
|
||||
r" Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\)"
|
||||
r" before calling the trainer.",
|
||||
):
|
||||
trainer = DPOTrainer(
|
||||
model=self.model,
|
||||
ref_model=None,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
dataset_num_proc=5,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
@require_no_wandb
|
||||
def test_dpo_trainer_generate_during_eval_no_wandb(self):
|
||||
@ -313,3 +384,267 @@ class DPOTrainerTester(unittest.TestCase):
|
||||
AutoModelForCausalLM.from_pretrained(tmp_dir)
|
||||
except OSError:
|
||||
self.fail("Loading the saved peft adapter failed")
|
||||
|
||||
@require_peft
|
||||
@require_bitsandbytes
|
||||
@mark.peft_test
|
||||
def test_dpo_lora_bf16_autocast_llama(self):
|
||||
# Note this test only works on compute capability > 7 GPU devices
|
||||
from peft import LoraConfig
|
||||
|
||||
model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
# lora model
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
bf16=True,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
# dpo train lora model with a lora config
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
ref_model=None,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
peft_config=lora_config,
|
||||
generate_during_eval=True,
|
||||
)
|
||||
|
||||
# train the model
|
||||
trainer.train()
|
||||
|
||||
# save peft adapter
|
||||
trainer.save_model()
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
["gpt2", "sigmoid", False, False],
|
||||
["gpt2", "sigmoid", False, True],
|
||||
["gpt2", "sigmoid", True, False],
|
||||
["gpt2", "sigmoid", True, True],
|
||||
["gpt2", "ipo", False, False],
|
||||
["gpt2", "ipo", False, True],
|
||||
["gpt2", "ipo", True, False],
|
||||
["gpt2", "ipo", True, True],
|
||||
["gpt2", "kto_pair", False, False],
|
||||
["gpt2", "kto_pair", False, True],
|
||||
["gpt2", "kto_pair", True, False],
|
||||
["gpt2", "kto_pair", True, True],
|
||||
]
|
||||
)
|
||||
@require_bitsandbytes
|
||||
@require_peft
|
||||
@mark.peft_test
|
||||
@unittest.skip("You need a GPU with bf16 support in order to run these tests")
|
||||
def test_dpo_lora_bf16_autocast(self, name, loss_type, pre_compute, gen_during_eval):
|
||||
# Note this test only works on compute capability > 7 GPU devices
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
# lora model
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id, load_in_4bit=True)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
bf16=True,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
# dpo train lora model with a lora config
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
ref_model=None,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
peft_config=lora_config,
|
||||
generate_during_eval=gen_during_eval,
|
||||
loss_type=loss_type,
|
||||
precompute_ref_log_probs=pre_compute,
|
||||
)
|
||||
|
||||
# train the model
|
||||
trainer.train()
|
||||
|
||||
# save peft adapter
|
||||
trainer.save_model()
|
||||
|
||||
@require_peft
|
||||
def test_dpo_lora_tags(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
# lora model
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
# dpo train lora model with a lora config
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
ref_model=None,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
assert trainer.model.model_tags == trainer._tag_names
|
||||
|
||||
@require_peft
|
||||
def test_dpo_tags(self):
|
||||
model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# lora model
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
# dpo train lora model with a lora config
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
ref_model=None,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
assert trainer.model.model_tags == trainer._tag_names
|
||||
|
||||
@require_peft
|
||||
@mark.peft_test
|
||||
def test_dpo_lora_force_use_ref(self):
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
# lora model
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id)
|
||||
model_peft = get_peft_model(model, lora_config)
|
||||
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# passing a peft_model as model and ref_model should error out,
|
||||
# unless you pass `force_use_ref_model`
|
||||
trainer = DPOTrainer(
|
||||
model=model_peft,
|
||||
ref_model=ref_model,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=model_peft,
|
||||
ref_model=ref_model,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
peft_config=lora_config,
|
||||
force_use_ref_model=True,
|
||||
)
|
||||
|
||||
# train the model
|
||||
trainer.train()
|
||||
|
@ -38,12 +38,12 @@ class TextHistoryTest(unittest.TestCase):
|
||||
tokens = torch.tensor([1, 2, 3])
|
||||
|
||||
history = TextHistory(text, tokens)
|
||||
self.assertEqual(history.text, text)
|
||||
self.assertTrue(torch.equal(history.tokens, tokens))
|
||||
self.assertTrue(torch.equal(history.token_masks, torch.zeros_like(tokens)))
|
||||
assert history.text == text
|
||||
assert torch.equal(history.tokens, tokens)
|
||||
assert torch.equal(history.token_masks, torch.zeros_like(tokens))
|
||||
|
||||
history = TextHistory(text, tokens, system=False)
|
||||
self.assertTrue(torch.equal(history.token_masks, torch.ones_like(tokens)))
|
||||
assert torch.equal(history.token_masks, torch.ones_like(tokens))
|
||||
|
||||
def test_text_history_append_segment(self):
|
||||
text = "Hello there!"
|
||||
@ -51,26 +51,26 @@ class TextHistoryTest(unittest.TestCase):
|
||||
|
||||
history = TextHistory(text, tokens)
|
||||
history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False)
|
||||
self.assertEqual(history.text, text + "General Kenobi!")
|
||||
self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6])))
|
||||
self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1])))
|
||||
assert history.text == (text + "General Kenobi!")
|
||||
assert torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6]))
|
||||
assert torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1]))
|
||||
|
||||
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]))
|
||||
self.assertEqual(history.text, text + "General Kenobi!" + "You are a bold one!")
|
||||
self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])))
|
||||
self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1, 0, 0, 0])))
|
||||
assert history.text == ((text + "General Kenobi!") + "You are a bold one!")
|
||||
assert torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]))
|
||||
assert torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1, 0, 0, 0]))
|
||||
|
||||
def test_text_history_complete(self):
|
||||
text = "Hello there!"
|
||||
tokens = torch.tensor([1, 2, 3])
|
||||
history = TextHistory(text, tokens)
|
||||
history.complete()
|
||||
self.assertTrue(history.completed)
|
||||
self.assertFalse(history.truncated)
|
||||
assert history.completed
|
||||
assert not history.truncated
|
||||
|
||||
history.complete(truncated=True)
|
||||
self.assertTrue(history.completed)
|
||||
self.assertTrue(history.truncated)
|
||||
assert history.completed
|
||||
assert history.truncated
|
||||
|
||||
def test_text_history_last_segment(self):
|
||||
text = "Hello there!"
|
||||
@ -78,7 +78,7 @@ class TextHistoryTest(unittest.TestCase):
|
||||
history = TextHistory(text, tokens)
|
||||
history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]))
|
||||
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]))
|
||||
self.assertEqual(history.last_text_segment, "You are a bold one!")
|
||||
assert history.last_text_segment == "You are a bold one!"
|
||||
|
||||
def test_text_history_split_query_response(self):
|
||||
text = "Hello there!"
|
||||
@ -88,9 +88,9 @@ class TextHistoryTest(unittest.TestCase):
|
||||
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]), system=True)
|
||||
query, response, mask = history.split_query_response_tokens()
|
||||
|
||||
self.assertTrue(torch.equal(query, torch.tensor([1, 2, 3])))
|
||||
self.assertTrue(torch.equal(response, torch.tensor([4, 5, 6, 7, 8, 9])))
|
||||
self.assertTrue(torch.equal(mask, torch.tensor([1, 1, 1, 0, 0, 0])))
|
||||
assert torch.equal(query, torch.tensor([1, 2, 3]))
|
||||
assert torch.equal(response, torch.tensor([4, 5, 6, 7, 8, 9]))
|
||||
assert torch.equal(mask, torch.tensor([1, 1, 1, 0, 0, 0]))
|
||||
|
||||
|
||||
class TextEnvironmentTester(unittest.TestCase):
|
||||
@ -112,10 +112,10 @@ class TextEnvironmentTester(unittest.TestCase):
|
||||
reward_fn=lambda x: torch.tensor(1),
|
||||
prompt="I am a prompt!\n",
|
||||
)
|
||||
self.assertEqual(env.prompt, "I am a prompt!\n")
|
||||
self.assertEqual(list(env.tools.keys()), ["DummyTool"])
|
||||
self.assertTrue(isinstance(env.tools["DummyTool"], DummyTool))
|
||||
self.assertEqual(env.reward_fn("Hello there!"), 1)
|
||||
assert env.prompt == "I am a prompt!\n"
|
||||
assert list(env.tools.keys()) == ["DummyTool"]
|
||||
assert isinstance(env.tools["DummyTool"], DummyTool)
|
||||
assert env.reward_fn("Hello there!") == 1
|
||||
|
||||
def test_text_environment_generate(self):
|
||||
generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id}
|
||||
@ -138,7 +138,7 @@ class TextEnvironmentTester(unittest.TestCase):
|
||||
generations_single = [env._generate_batched([inputs], batch_size=1)[0] for inputs in model_inputs]
|
||||
generations_single = self.gpt2_tokenizer.batch_decode(generations_single)
|
||||
|
||||
self.assertEqual(generations_single, generations_batched)
|
||||
assert generations_single == generations_batched
|
||||
|
||||
def test_text_environment_tool_call_parsing(self):
|
||||
string_valid = "Something something <request><Tool1>Hello there!<call>"
|
||||
@ -155,24 +155,24 @@ class TextEnvironmentTester(unittest.TestCase):
|
||||
prompt="I am a prompt!\n",
|
||||
)
|
||||
tool, response = env.parse_tool_call(string_valid)
|
||||
self.assertEqual(tool, "Tool1")
|
||||
self.assertEqual(response, "Hello there!")
|
||||
assert tool == "Tool1"
|
||||
assert response == "Hello there!"
|
||||
|
||||
tool, response = env.parse_tool_call(string_invalid_request)
|
||||
self.assertEqual(tool, None)
|
||||
self.assertEqual(response, None)
|
||||
assert tool is None
|
||||
assert response is None
|
||||
|
||||
tool, response = env.parse_tool_call(string_invalid_call)
|
||||
self.assertEqual(tool, None)
|
||||
self.assertEqual(response, None)
|
||||
assert tool is None
|
||||
assert response is None
|
||||
|
||||
tool, response = env.parse_tool_call(string_invalid_tool)
|
||||
self.assertEqual(tool, None)
|
||||
self.assertEqual(response, None)
|
||||
assert tool is None
|
||||
assert response is None
|
||||
|
||||
tool, response = env.parse_tool_call(string_invalid_random)
|
||||
self.assertEqual(tool, None)
|
||||
self.assertEqual(response, None)
|
||||
assert tool is None
|
||||
assert response is None
|
||||
|
||||
def test_text_environment_tool_truncation(self):
|
||||
env = TextEnvironment(
|
||||
@ -185,19 +185,19 @@ class TextEnvironmentTester(unittest.TestCase):
|
||||
|
||||
env.max_tool_response = 100
|
||||
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
|
||||
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 100)
|
||||
assert (len(history.last_text_segment) - len(env.response_token)) == 100
|
||||
|
||||
env.max_tool_response = 500
|
||||
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
|
||||
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 500)
|
||||
assert (len(history.last_text_segment) - len(env.response_token)) == 500
|
||||
|
||||
env.max_tool_response = 1001
|
||||
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
|
||||
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 1000)
|
||||
assert (len(history.last_text_segment) - len(env.response_token)) == 1000
|
||||
|
||||
env.max_tool_response = 2000
|
||||
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
|
||||
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 1000)
|
||||
assert (len(history.last_text_segment) - len(env.response_token)) == 1000
|
||||
|
||||
@patch.object(TextEnvironment, "generate", side_effect=dummy_generate)
|
||||
def test_text_environment_max_calls(self, mock_generate):
|
||||
@ -211,20 +211,20 @@ class TextEnvironmentTester(unittest.TestCase):
|
||||
|
||||
env.max_turns = 1
|
||||
_, _, _, _, histories = env.run(["test"])
|
||||
self.assertEqual(
|
||||
histories[0].text, "I am a prompt!\n" + "test" + 1 * "<request><DummyTool>test<call>test<response>"
|
||||
assert histories[0].text == (
|
||||
("I am a prompt!\n" + "test") + (1 * "<request><DummyTool>test<call>test<response>")
|
||||
)
|
||||
|
||||
env.max_turns = 2
|
||||
_, _, _, _, histories = env.run(["test"])
|
||||
self.assertEqual(
|
||||
histories[0].text, "I am a prompt!\n" + "test" + 2 * "<request><DummyTool>test<call>test<response>"
|
||||
assert histories[0].text == (
|
||||
("I am a prompt!\n" + "test") + (2 * "<request><DummyTool>test<call>test<response>")
|
||||
)
|
||||
|
||||
env.max_turns = 4
|
||||
_, _, _, _, histories = env.run(["test"])
|
||||
self.assertEqual(
|
||||
histories[0].text, "I am a prompt!\n" + "test" + 4 * "<request><DummyTool>test<call>test<response>"
|
||||
assert histories[0].text == (
|
||||
("I am a prompt!\n" + "test") + (4 * "<request><DummyTool>test<call>test<response>")
|
||||
)
|
||||
|
||||
def test_text_environment_compute_rewards(self):
|
||||
@ -240,7 +240,7 @@ class TextEnvironmentTester(unittest.TestCase):
|
||||
histories = env.compute_reward(histories)
|
||||
|
||||
for i in range(8):
|
||||
self.assertEqual(histories[i].reward, i)
|
||||
assert histories[i].reward == i
|
||||
|
||||
@patch.object(TextEnvironment, "generate", side_effect=dummy_generate)
|
||||
def test_text_environment_run(self, mock_generate):
|
||||
@ -256,18 +256,20 @@ class TextEnvironmentTester(unittest.TestCase):
|
||||
task_2 = "Hello there! General Kenobi!"
|
||||
|
||||
query, response, response_mask, reward, histories = env.run([task_1, task_2])
|
||||
self.assertEqual(len(query[0]), 9)
|
||||
self.assertEqual(len(query[1]), 12)
|
||||
self.assertEqual(len(response[0]), 14)
|
||||
self.assertEqual(len(response[1]), 14)
|
||||
self.assertEqual(response_mask[0].sum(), 2 * 3) # mocked generate always adds 3 toknes
|
||||
self.assertEqual(response_mask[1].sum(), 2 * 3) # mocked generate always adds 3 toknes
|
||||
self.assertEqual(reward[0], 0)
|
||||
self.assertEqual(reward[1], 1)
|
||||
self.assertEqual(
|
||||
histories[0].text, "I am a prompt!\n" + "Hello there!" + 2 * "<request><DummyTool>test<call>test<response>"
|
||||
assert len(query[0]) == 9
|
||||
assert len(query[1]) == 12
|
||||
assert len(response[0]) == 14
|
||||
assert len(response[1]) == 14
|
||||
assert response_mask[0].sum() == (2 * 3)
|
||||
# mocked generate always adds 3 toknes
|
||||
assert response_mask[1].sum() == (2 * 3)
|
||||
# mocked generate always adds 3 toknes
|
||||
assert reward[0] == 0
|
||||
assert reward[1] == 1
|
||||
assert histories[0].text == (
|
||||
("I am a prompt!\n" + "Hello there!") + (2 * "<request><DummyTool>test<call>test<response>")
|
||||
)
|
||||
self.assertEqual(
|
||||
histories[1].text,
|
||||
"I am a prompt!\n" + "Hello there! General Kenobi!" + 2 * "<request><DummyTool>test<call>test<response>",
|
||||
assert histories[1].text == (
|
||||
("I am a prompt!\n" + "Hello there! General Kenobi!")
|
||||
+ (2 * "<request><DummyTool>test<call>test<response>")
|
||||
)
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
@ -31,15 +32,27 @@ class IterativeTrainerTester(unittest.TestCase):
|
||||
cls.tokenizer.pad_token = cls.tokenizer.eos_token
|
||||
|
||||
# get t5 as seq2seq example:
|
||||
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
|
||||
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab-calibrated"
|
||||
cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
||||
cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
def _init_tensor_dummy_dataset(self):
|
||||
dummy_dataset_dict = {
|
||||
"input_ids": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])],
|
||||
"attention_mask": [torch.tensor([1, 1]), torch.tensor([1, 1, 1]), torch.tensor([1, 1])],
|
||||
"labels": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])],
|
||||
"input_ids": [
|
||||
torch.tensor([5303, 3621, 3666, 1438, 318]),
|
||||
torch.tensor([3666, 1438, 318, 3666, 1438, 318]),
|
||||
torch.tensor([5303, 3621, 3666, 1438, 318]),
|
||||
],
|
||||
"attention_mask": [
|
||||
torch.tensor([1, 1, 1, 1, 1]),
|
||||
torch.tensor([1, 1, 1, 1, 1, 1]),
|
||||
torch.tensor([1, 1, 1, 1, 1]),
|
||||
],
|
||||
"labels": [
|
||||
torch.tensor([5303, 3621, 3666, 1438, 318]),
|
||||
torch.tensor([3666, 1438, 318, 3666, 1438, 318]),
|
||||
torch.tensor([5303, 3621, 3666, 1438, 318]),
|
||||
],
|
||||
}
|
||||
|
||||
dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
|
||||
@ -94,11 +107,10 @@ class IterativeTrainerTester(unittest.TestCase):
|
||||
tokenizer = self.t5_tokenizer
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=2,
|
||||
output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=2, learning_rate=1e-3
|
||||
)
|
||||
iterative_trainer = IterativeSFTTrainer(model=model, args=args, tokenizer=tokenizer)
|
||||
iterative_trainer.optimizer.zero_grad = partial(iterative_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
|
||||
iterative_trainer.step(**inputs)
|
||||
|
||||
|
340
tests/test_kto_trainer.py
Normal file
340
tests/test_kto_trainer.py
Normal file
@ -0,0 +1,340 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
from trl import KTOConfig, KTOTrainer
|
||||
|
||||
from .testing_utils import require_no_wandb, require_peft
|
||||
|
||||
|
||||
class KTOTrainerTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
|
||||
cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id)
|
||||
cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model_id)
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id)
|
||||
cls.tokenizer.pad_token = cls.tokenizer.eos_token
|
||||
|
||||
# get t5 as seq2seq example:
|
||||
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
|
||||
cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
||||
cls.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
||||
cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
def _init_dummy_dataset(self):
|
||||
# fmt: off
|
||||
dummy_dataset_dict = {
|
||||
"prompt": [
|
||||
"Hey, hello",
|
||||
"How are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"completion": [
|
||||
"hi nice to meet you",
|
||||
"leave me alone",
|
||||
"I don't have a name",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"C++",
|
||||
"Java",
|
||||
],
|
||||
"label": [
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
],
|
||||
}
|
||||
# fmt: on
|
||||
return Dataset.from_dict(dummy_dataset_dict)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
["gpt2", True, True],
|
||||
["gpt2", True, False],
|
||||
# ["t5", True],
|
||||
["gpt2", False, True],
|
||||
["gpt2", False, False],
|
||||
# ["t5", False],
|
||||
]
|
||||
)
|
||||
def test_kto_trainer(self, name, pre_compute, eval_dataset):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = KTOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=1,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
precompute_ref_log_probs=pre_compute,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
if name == "gpt2":
|
||||
model = self.model
|
||||
ref_model = self.ref_model
|
||||
tokenizer = self.tokenizer
|
||||
elif name == "t5":
|
||||
model = self.t5_model
|
||||
ref_model = self.t5_ref_model
|
||||
tokenizer = self.t5_tokenizer
|
||||
|
||||
trainer = KTOTrainer(
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset if eval_dataset else None,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
self.assertFalse(torch.equal(param, new_param))
|
||||
|
||||
def test_kto_trainer_tokenize_row(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = KTOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=1,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
trainer = KTOTrainer(
|
||||
model=self.model,
|
||||
ref_model=self.ref_model,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
row = dummy_dataset[0]
|
||||
|
||||
# test that the row can be tokenized
|
||||
tokenized_row = trainer.tokenize_row(row)
|
||||
|
||||
# Assert bos_token_id and eos_token_id (latter only for completion)
|
||||
assert tokenized_row["prompt_input_ids"][0] == self.tokenizer.bos_token_id
|
||||
assert tokenized_row["completion_input_ids"][0] == self.tokenizer.bos_token_id
|
||||
assert tokenized_row["prompt_input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
assert tokenized_row["completion_input_ids"][-1] == self.tokenizer.eos_token_id
|
||||
|
||||
def test_kto_trainer_without_providing_ref_model(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = KTOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
trainer = KTOTrainer(
|
||||
model=self.model,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
self.assertFalse(torch.equal(param, new_param))
|
||||
|
||||
@require_peft
|
||||
@mark.peft_test
|
||||
def test_kto_trainer_without_providing_ref_model_with_lora(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = KTOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
trainer = KTOTrainer(
|
||||
model=self.model,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
if "lora" in n:
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
self.assertFalse(torch.equal(param, new_param))
|
||||
|
||||
@require_no_wandb
|
||||
def test_kto_trainer_generate_during_eval_no_wandb(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = KTOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=1,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
generate_during_eval=True,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
expected_regex="`generate_during_eval=True` requires Weights and Biases to be installed."
|
||||
" Please install with `pip install wandb` to resolve.",
|
||||
):
|
||||
KTOTrainer(
|
||||
model=self.model,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
@require_peft
|
||||
@mark.peft_test
|
||||
def test_kto_lora_save(self):
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
# lora model
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id)
|
||||
model_peft = get_peft_model(model, lora_config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = KTOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
# kto train lora model with a lora config
|
||||
trainer = KTOTrainer(
|
||||
model=model_peft,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
# train the model
|
||||
trainer.train()
|
||||
|
||||
# save peft adapter
|
||||
trainer.save_model()
|
||||
|
||||
# assert that the model is loaded without giving OSError
|
||||
try:
|
||||
AutoModelForCausalLM.from_pretrained(tmp_dir)
|
||||
except OSError:
|
||||
self.fail("Loading the saved peft adapter failed")
|
@ -15,6 +15,7 @@ import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
||||
|
||||
@ -31,7 +32,7 @@ ALL_CAUSAL_LM_MODELS = [
|
||||
"trl-internal-testing/tiny-random-GPT2LMHeadModel",
|
||||
"trl-internal-testing/tiny-random-CodeGenForCausalLM-sharded",
|
||||
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM-safetensors-sharded",
|
||||
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM-safetensors"
|
||||
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM-safetensors",
|
||||
# "trl-internal-testing/tiny-random-LlamaForCausalLM", uncomment on the next transformers release
|
||||
]
|
||||
|
||||
@ -68,7 +69,7 @@ class VHeadModelTester:
|
||||
"""
|
||||
for model_name in self.all_model_names:
|
||||
model = self.trl_model_class.from_pretrained(model_name)
|
||||
self.assertTrue(hasattr(model, "v_head"))
|
||||
assert hasattr(model, "v_head")
|
||||
|
||||
def test_value_head_shape(self):
|
||||
r"""
|
||||
@ -76,7 +77,7 @@ class VHeadModelTester:
|
||||
"""
|
||||
for model_name in self.all_model_names:
|
||||
model = self.trl_model_class.from_pretrained(model_name)
|
||||
self.assertTrue(model.v_head.summary.weight.shape[0] == 1)
|
||||
assert model.v_head.summary.weight.shape[0] == 1
|
||||
|
||||
def test_value_head_init_random(self):
|
||||
r"""
|
||||
@ -86,7 +87,7 @@ class VHeadModelTester:
|
||||
"""
|
||||
for model_name in self.all_model_names:
|
||||
model = self.trl_model_class.from_pretrained(model_name)
|
||||
self.assertFalse(torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias)))
|
||||
assert not torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias))
|
||||
|
||||
def test_value_head_not_str(self):
|
||||
r"""
|
||||
@ -96,7 +97,7 @@ class VHeadModelTester:
|
||||
for model_name in self.all_model_names:
|
||||
pretrained_model = self.transformers_model_class.from_pretrained(model_name)
|
||||
model = self.trl_model_class.from_pretrained(pretrained_model)
|
||||
self.assertTrue(hasattr(model, "v_head"))
|
||||
assert hasattr(model, "v_head")
|
||||
|
||||
def test_from_save_trl(self):
|
||||
"""
|
||||
@ -113,7 +114,7 @@ class VHeadModelTester:
|
||||
|
||||
# Check if the weights are the same
|
||||
for key in model_from_save.state_dict():
|
||||
self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key]))
|
||||
assert torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])
|
||||
|
||||
def test_from_save_trl_sharded(self):
|
||||
"""
|
||||
@ -129,7 +130,7 @@ class VHeadModelTester:
|
||||
|
||||
# Check if the weights are the same
|
||||
for key in model_from_save.state_dict():
|
||||
self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key]))
|
||||
assert torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])
|
||||
|
||||
def test_from_save_transformers_sharded(self):
|
||||
"""
|
||||
@ -146,10 +147,8 @@ class VHeadModelTester:
|
||||
|
||||
# Check if the weights are the same
|
||||
for key in transformers_model.state_dict():
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
|
||||
)
|
||||
assert torch.allclose(
|
||||
transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
|
||||
)
|
||||
|
||||
def test_from_save_transformers(self):
|
||||
@ -168,24 +167,20 @@ class VHeadModelTester:
|
||||
|
||||
# Check if the weights are the same
|
||||
for key in transformers_model.state_dict():
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
|
||||
)
|
||||
assert torch.allclose(
|
||||
transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
|
||||
)
|
||||
|
||||
# Check if the trl model has the same keys as the transformers model
|
||||
# except the v_head
|
||||
for key in trl_model.state_dict():
|
||||
if "v_head" not in key:
|
||||
self.assertTrue(key in transformers_model.state_dict())
|
||||
assert key in transformers_model.state_dict()
|
||||
# check if the weights are the same
|
||||
self.assertTrue(torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key]))
|
||||
assert torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key])
|
||||
|
||||
# check if they have the same modules
|
||||
self.assertTrue(
|
||||
set(transformers_model_from_save.state_dict().keys()) == set(transformers_model.state_dict().keys())
|
||||
)
|
||||
assert set(transformers_model_from_save.state_dict().keys()) == set(transformers_model.state_dict().keys())
|
||||
|
||||
|
||||
class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
|
||||
@ -215,7 +210,7 @@ class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
|
||||
|
||||
# Check if the outputs are of the right size - here
|
||||
# we always output 3 values - logits, loss, and value states
|
||||
self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE)
|
||||
assert len(outputs) == EXPECTED_OUTPUT_SIZE
|
||||
|
||||
def test_dropout_config(self):
|
||||
r"""
|
||||
@ -228,7 +223,7 @@ class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
|
||||
model = self.trl_model_class.from_pretrained(pretrained_model)
|
||||
|
||||
# Check if v head of the model has the same dropout as the config
|
||||
self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob)
|
||||
assert model.v_head.dropout.p == pretrained_model.config.summary_dropout_prob
|
||||
|
||||
def test_dropout_kwargs(self):
|
||||
r"""
|
||||
@ -241,12 +236,12 @@ class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
|
||||
model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs)
|
||||
|
||||
# Check if v head of the model has the same dropout as the config
|
||||
self.assertEqual(model.v_head.dropout.p, 0.5)
|
||||
assert model.v_head.dropout.p == 0.5
|
||||
|
||||
model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5)
|
||||
|
||||
# Check if v head of the model has the same dropout as the config
|
||||
self.assertEqual(model.v_head.dropout.p, 0.5)
|
||||
assert model.v_head.dropout.p == 0.5
|
||||
|
||||
def test_generate(self):
|
||||
r"""
|
||||
@ -263,7 +258,7 @@ class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
|
||||
# Test with a model without a LM head
|
||||
model_id = "trl-internal-testing/tiny-random-GPT2Model"
|
||||
# This should raise a ValueError
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
pretrained_model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
_ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model.transformer)
|
||||
|
||||
@ -279,13 +274,11 @@ class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
|
||||
|
||||
lm_head_namings = self.trl_model_class.lm_head_namings
|
||||
|
||||
self.assertTrue(
|
||||
any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
|
||||
)
|
||||
assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
|
||||
|
||||
for lm_head_naming in lm_head_namings:
|
||||
if hasattr(trl_model.pretrained_model, lm_head_naming):
|
||||
self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16)
|
||||
assert getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16
|
||||
|
||||
dummy_input = torch.LongTensor([[0, 1, 0, 1]])
|
||||
|
||||
@ -303,13 +296,12 @@ class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
|
||||
|
||||
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(model_name + "-ppo")
|
||||
# check all keys
|
||||
self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys())
|
||||
assert model.state_dict().keys() == model_from_pretrained.state_dict().keys()
|
||||
|
||||
for name, param in model.state_dict().items():
|
||||
self.assertTrue(
|
||||
torch.allclose(param, model_from_pretrained.state_dict()[name]),
|
||||
f"Parameter {name} is not the same after push_to_hub and from_pretrained",
|
||||
)
|
||||
assert torch.allclose(
|
||||
param, model_from_pretrained.state_dict()[name]
|
||||
), f"Parameter {name} is not the same after push_to_hub and from_pretrained"
|
||||
|
||||
|
||||
class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
|
||||
@ -340,7 +332,7 @@ class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
|
||||
|
||||
# Check if the outputs are of the right size - here
|
||||
# we always output 3 values - logits, loss, and value states
|
||||
self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE)
|
||||
assert len(outputs) == EXPECTED_OUTPUT_SIZE
|
||||
|
||||
def test_dropout_config(self):
|
||||
r"""
|
||||
@ -353,7 +345,7 @@ class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
|
||||
model = self.trl_model_class.from_pretrained(pretrained_model)
|
||||
|
||||
# Check if v head of the model has the same dropout as the config
|
||||
self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob)
|
||||
assert model.v_head.dropout.p == pretrained_model.config.summary_dropout_prob
|
||||
|
||||
def test_dropout_kwargs(self):
|
||||
r"""
|
||||
@ -366,12 +358,12 @@ class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
|
||||
model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs)
|
||||
|
||||
# Check if v head of the model has the same dropout as the config
|
||||
self.assertEqual(model.v_head.dropout.p, 0.5)
|
||||
assert model.v_head.dropout.p == 0.5
|
||||
|
||||
model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5)
|
||||
|
||||
# Check if v head of the model has the same dropout as the config
|
||||
self.assertEqual(model.v_head.dropout.p, 0.5)
|
||||
assert model.v_head.dropout.p == 0.5
|
||||
|
||||
def test_generate(self):
|
||||
r"""
|
||||
@ -389,7 +381,7 @@ class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
|
||||
# Test with a model without a LM head
|
||||
model_id = "trl-internal-testing/tiny-random-T5Model"
|
||||
# This should raise a ValueError
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
pretrained_model = AutoModel.from_pretrained(model_id)
|
||||
_ = self.trl_model_class.from_pretrained(pretrained_model)
|
||||
|
||||
@ -404,13 +396,12 @@ class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
|
||||
|
||||
model_from_pretrained = self.trl_model_class.from_pretrained(model_name + "-ppo")
|
||||
# check all keys
|
||||
self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys())
|
||||
assert model.state_dict().keys() == model_from_pretrained.state_dict().keys()
|
||||
|
||||
for name, param in model.state_dict().items():
|
||||
self.assertTrue(
|
||||
torch.allclose(param, model_from_pretrained.state_dict()[name]),
|
||||
f"Parameter {name} is not the same after push_to_hub and from_pretrained",
|
||||
)
|
||||
assert torch.allclose(
|
||||
param, model_from_pretrained.state_dict()[name]
|
||||
), f"Parameter {name} is not the same after push_to_hub and from_pretrained"
|
||||
|
||||
def test_transformers_bf16_kwargs(self):
|
||||
r"""
|
||||
@ -428,13 +419,11 @@ class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
|
||||
# skip the test for FSMT as it does not support mixed-prec
|
||||
continue
|
||||
|
||||
self.assertTrue(
|
||||
any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
|
||||
)
|
||||
assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
|
||||
|
||||
for lm_head_naming in lm_head_namings:
|
||||
if hasattr(trl_model.pretrained_model, lm_head_naming):
|
||||
self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16)
|
||||
assert getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16
|
||||
|
||||
dummy_input = torch.LongTensor([[0, 1, 0, 1]])
|
||||
|
||||
@ -474,14 +463,14 @@ class ReferenceModelTest(unittest.TestCase):
|
||||
last_ref_layer_after = ref_model.get_parameter(layer_5).data.clone()
|
||||
|
||||
# before optimization ref and model are identical
|
||||
self.assertTrue((first_layer_before == first_ref_layer_before).all())
|
||||
self.assertTrue((last_layer_before == last_ref_layer_before).all())
|
||||
assert (first_layer_before == first_ref_layer_before).all()
|
||||
assert (last_layer_before == last_ref_layer_before).all()
|
||||
# ref model stays identical after optimization
|
||||
self.assertTrue((first_ref_layer_before == first_ref_layer_after).all())
|
||||
self.assertTrue((last_ref_layer_before == last_ref_layer_after).all())
|
||||
assert (first_ref_layer_before == first_ref_layer_after).all()
|
||||
assert (last_ref_layer_before == last_ref_layer_after).all()
|
||||
# optimized model changes
|
||||
self.assertTrue(not (first_layer_before == first_layer_after).all())
|
||||
self.assertTrue(not (last_layer_before == last_layer_after).all())
|
||||
assert not (first_layer_before == first_layer_after).all()
|
||||
assert not (last_layer_before == last_layer_after).all()
|
||||
|
||||
def test_shared_layers(self):
|
||||
layer_0 = self.layer_format.format(layer=0)
|
||||
@ -506,12 +495,12 @@ class ReferenceModelTest(unittest.TestCase):
|
||||
second_ref_layer_after = ref_model.get_parameter(layer_1).data.clone()
|
||||
|
||||
# before optimization ref and model are identical
|
||||
self.assertTrue((first_layer_before == first_ref_layer_before).all())
|
||||
self.assertTrue((second_layer_before == second_ref_layer_before).all())
|
||||
assert (first_layer_before == first_ref_layer_before).all()
|
||||
assert (second_layer_before == second_ref_layer_before).all()
|
||||
# ref model stays identical after optimization
|
||||
self.assertTrue((first_ref_layer_before == first_ref_layer_after).all())
|
||||
self.assertTrue((second_ref_layer_before == second_ref_layer_after).all())
|
||||
assert (first_ref_layer_before == first_ref_layer_after).all()
|
||||
assert (second_ref_layer_before == second_ref_layer_after).all()
|
||||
# first layer of optimized model stays the same
|
||||
self.assertTrue((first_layer_before == first_layer_after).all())
|
||||
assert (first_layer_before == first_layer_after).all()
|
||||
# other layers in optimized model change
|
||||
self.assertTrue(not (second_layer_before == second_layer_after).all())
|
||||
assert not (second_layer_before == second_layer_after).all()
|
||||
|
@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
import sys
|
||||
import unittest
|
||||
from functools import partial
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
@ -93,15 +95,15 @@ class TestPeftDependancy(unittest.TestCase):
|
||||
from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
|
||||
|
||||
# Check that loading a model with `peft` will raise an error
|
||||
with self.assertRaises(ModuleNotFoundError):
|
||||
import peft # noqa
|
||||
with pytest.raises(ModuleNotFoundError):
|
||||
import peft # noqa: F401
|
||||
|
||||
trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id) # noqa
|
||||
trl_seq2seq_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(self.seq_to_seq_model_id) # noqa
|
||||
_trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id)
|
||||
_trl_seq2seq_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(self.seq_to_seq_model_id)
|
||||
|
||||
def test_imports_no_peft(self):
|
||||
with patch.dict(sys.modules, {"peft": None}):
|
||||
from trl import ( # noqa
|
||||
from trl import ( # noqa: F401
|
||||
AutoModelForCausalLMWithValueHead,
|
||||
AutoModelForSeq2SeqLMWithValueHead,
|
||||
PPOConfig,
|
||||
@ -133,6 +135,7 @@ class TestPeftDependancy(unittest.TestCase):
|
||||
tokenizer=tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -140,14 +143,14 @@ class TestPeftDependancy(unittest.TestCase):
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(0.0)]
|
||||
# train model
|
||||
train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
# check gradients are not None
|
||||
for _, param in trl_model.named_parameters():
|
||||
if param.requires_grad:
|
||||
self.assertIsNotNone(param.grad)
|
||||
assert param.grad is not None
|
||||
|
||||
# check expected stats
|
||||
for stat in EXPECTED_STATS:
|
||||
self.assertIn(stat, train_stats)
|
||||
assert stat in train_stats
|
||||
|
@ -23,7 +23,7 @@ from trl import AutoModelForCausalLMWithValueHead, is_peft_available
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import get_peft_model, LoraConfig
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
from .testing_utils import require_bitsandbytes, require_peft
|
||||
|
||||
@ -60,7 +60,7 @@ class PeftModelTester(unittest.TestCase):
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model)
|
||||
|
||||
# Check that the value head has requires_grad=True
|
||||
self.assertTrue(model.v_head.summary.weight.requires_grad)
|
||||
assert model.v_head.summary.weight.requires_grad
|
||||
|
||||
def test_check_peft_model_nb_trainable_params(self):
|
||||
r"""
|
||||
@ -73,12 +73,12 @@ class PeftModelTester(unittest.TestCase):
|
||||
|
||||
# Check that the number of trainable parameters is correct
|
||||
nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
self.assertEqual(nb_trainable_params, 10273)
|
||||
assert nb_trainable_params == 10273
|
||||
|
||||
# Check that the number of trainable param for the non-peft model is correct
|
||||
non_peft_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id)
|
||||
nb_trainable_params = sum(p.numel() for p in non_peft_model.parameters() if p.requires_grad)
|
||||
self.assertEqual(nb_trainable_params, 99578)
|
||||
assert nb_trainable_params == 99578
|
||||
|
||||
def test_create_peft_model_from_config(self):
|
||||
r"""
|
||||
@ -89,13 +89,13 @@ class PeftModelTester(unittest.TestCase):
|
||||
)
|
||||
# Check that the number of trainable parameters is correct
|
||||
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
|
||||
self.assertEqual(nb_trainable_params, 10273)
|
||||
assert nb_trainable_params == 10273
|
||||
|
||||
causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id)
|
||||
trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config)
|
||||
# Check that the number of trainable parameters is correct
|
||||
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
|
||||
self.assertEqual(nb_trainable_params, 10273)
|
||||
assert nb_trainable_params == 10273
|
||||
|
||||
@require_bitsandbytes
|
||||
def test_create_bnb_peft_model_from_config(self):
|
||||
@ -109,10 +109,8 @@ class PeftModelTester(unittest.TestCase):
|
||||
)
|
||||
# Check that the number of trainable parameters is correct
|
||||
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
|
||||
self.assertEqual(nb_trainable_params, 10273)
|
||||
self.assertTrue(
|
||||
trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt
|
||||
)
|
||||
assert nb_trainable_params == 10273
|
||||
assert trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt
|
||||
|
||||
causal_lm_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.causal_lm_model_id, load_in_8bit=True, device_map="auto"
|
||||
@ -120,10 +118,8 @@ class PeftModelTester(unittest.TestCase):
|
||||
trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config)
|
||||
# Check that the number of trainable parameters is correct
|
||||
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
|
||||
self.assertEqual(nb_trainable_params, 10273)
|
||||
self.assertTrue(
|
||||
trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt
|
||||
)
|
||||
assert nb_trainable_params == 10273
|
||||
assert trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt
|
||||
|
||||
def test_save_pretrained_peft(self):
|
||||
r"""
|
||||
@ -138,31 +134,23 @@ class PeftModelTester(unittest.TestCase):
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
# check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory
|
||||
self.assertTrue(
|
||||
os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"),
|
||||
msg=f"{tmp_dir}/adapter_model.safetensors does not exist",
|
||||
)
|
||||
self.assertTrue(
|
||||
os.path.exists(f"{tmp_dir}/adapter_config.json"),
|
||||
msg=f"{tmp_dir}/adapter_config.json does not exist",
|
||||
)
|
||||
assert os.path.isfile(
|
||||
f"{tmp_dir}/adapter_model.safetensors"
|
||||
), f"{tmp_dir}/adapter_model.safetensors does not exist"
|
||||
assert os.path.exists(f"{tmp_dir}/adapter_config.json"), f"{tmp_dir}/adapter_config.json does not exist"
|
||||
# check also for `pytorch_model.bin` and make sure it only contains `v_head` weights
|
||||
self.assertTrue(
|
||||
os.path.exists(f"{tmp_dir}/pytorch_model.bin"),
|
||||
msg=f"{tmp_dir}/pytorch_model.bin does not exist",
|
||||
)
|
||||
assert os.path.exists(f"{tmp_dir}/pytorch_model.bin"), f"{tmp_dir}/pytorch_model.bin does not exist"
|
||||
maybe_v_head = torch.load(f"{tmp_dir}/pytorch_model.bin")
|
||||
# check that only keys that starts with `v_head` are in the dict
|
||||
self.assertTrue(
|
||||
all(k.startswith("v_head") for k in maybe_v_head.keys()),
|
||||
msg=f"keys in {tmp_dir}/pytorch_model.bin do not start with `v_head`",
|
||||
)
|
||||
assert all(
|
||||
k.startswith("v_head") for k in maybe_v_head.keys()
|
||||
), f"keys in {tmp_dir}/pytorch_model.bin do not start with `v_head`"
|
||||
|
||||
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir)
|
||||
|
||||
# check all the weights are the same
|
||||
for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()):
|
||||
self.assertTrue(torch.allclose(p1[1], p2[1]), msg=f"{p1[0]} != {p2[0]}")
|
||||
assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}"
|
||||
|
||||
def test_load_pretrained_peft(self):
|
||||
r"""
|
||||
@ -178,19 +166,15 @@ class PeftModelTester(unittest.TestCase):
|
||||
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir)
|
||||
|
||||
# check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory
|
||||
self.assertTrue(
|
||||
os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"),
|
||||
msg=f"{tmp_dir}/adapter_model.safetensors does not exist",
|
||||
)
|
||||
self.assertTrue(
|
||||
os.path.exists(f"{tmp_dir}/adapter_config.json"),
|
||||
msg=f"{tmp_dir}/adapter_config.json does not exist",
|
||||
)
|
||||
assert os.path.isfile(
|
||||
f"{tmp_dir}/adapter_model.safetensors"
|
||||
), f"{tmp_dir}/adapter_model.safetensors does not exist"
|
||||
assert os.path.exists(f"{tmp_dir}/adapter_config.json"), f"{tmp_dir}/adapter_config.json does not exist"
|
||||
|
||||
# check all the weights are the same
|
||||
for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()):
|
||||
if p1[0] not in ["v_head.summary.weight", "v_head.summary.bias"]:
|
||||
self.assertTrue(torch.allclose(p1[1], p2[1]), msg=f"{p1[0]} != {p2[0]}")
|
||||
assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}"
|
||||
|
||||
def test_continue_training_peft_model(self):
|
||||
r"""
|
||||
@ -205,4 +189,4 @@ class PeftModelTester(unittest.TestCase):
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir, is_trainable=True)
|
||||
# Check that the number of trainable parameters is correct
|
||||
nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
self.assertEqual(nb_trainable_params, 10273)
|
||||
assert nb_trainable_params == 10273
|
||||
|
@ -17,6 +17,7 @@ import gc
|
||||
import re
|
||||
import tempfile
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -180,7 +181,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
|
||||
self.assertEqual(len(dummy_dataloader), 0)
|
||||
assert len(dummy_dataloader) == 0
|
||||
|
||||
def test_ppo_step(self):
|
||||
# initialize dataset
|
||||
@ -193,6 +194,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -200,7 +202,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(0.0)]
|
||||
# train model
|
||||
train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
for param in ppo_trainer.model.parameters():
|
||||
@ -220,6 +222,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -230,9 +233,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
response_mask = [torch.ones_like(r) for r in response_tensor]
|
||||
|
||||
# train model
|
||||
train_stats = ppo_trainer.step(
|
||||
[q for q in query_tensor], [r for r in response_tensor], reward, response_mask
|
||||
)
|
||||
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward, response_mask)
|
||||
break
|
||||
|
||||
for param in ppo_trainer.model.parameters():
|
||||
@ -254,9 +255,10 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
|
||||
self.assertTrue(isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD))
|
||||
assert isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD)
|
||||
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -264,15 +266,15 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(0.0)]
|
||||
# train model
|
||||
train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
for name, param in ppo_trainer.model.named_parameters():
|
||||
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
|
||||
assert param.grad is not None, f"Parameter {name} has no gradient"
|
||||
|
||||
# ref model should not be trained
|
||||
for name, param in ppo_trainer.ref_model.named_parameters():
|
||||
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
|
||||
assert param.grad is None, f"Parameter {name} has a gradient"
|
||||
|
||||
# Finally check stats
|
||||
for stat in EXPECTED_STATS:
|
||||
@ -293,10 +295,11 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
dataset=dummy_dataset,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
|
||||
self.assertTrue(isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD))
|
||||
self.assertTrue(isinstance(ppo_trainer.lr_scheduler.scheduler, torch.optim.lr_scheduler.ExponentialLR))
|
||||
assert isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD)
|
||||
assert isinstance(ppo_trainer.lr_scheduler.scheduler, torch.optim.lr_scheduler.ExponentialLR)
|
||||
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -304,23 +307,23 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(0.0)]
|
||||
# train model
|
||||
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
for name, param in ppo_trainer.model.named_parameters():
|
||||
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
|
||||
assert param.grad is not None, f"Parameter {name} has no gradient"
|
||||
|
||||
# ref model should not be trained
|
||||
for name, param in ppo_trainer.ref_model.named_parameters():
|
||||
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
|
||||
assert param.grad is None, f"Parameter {name} has a gradient"
|
||||
|
||||
# Finally check stats
|
||||
for stat in EXPECTED_STATS:
|
||||
assert stat in train_stats.keys()
|
||||
|
||||
# assert that the LR has increased for exponential decay
|
||||
self.assertTrue(train_stats["ppo/learning_rate"] > self.ppo_config.learning_rate)
|
||||
assert train_stats["ppo/learning_rate"] > self.ppo_config.learning_rate
|
||||
|
||||
def test_ppo_step_with_no_ref(self):
|
||||
# initialize dataset
|
||||
@ -334,6 +337,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -341,15 +345,15 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(0.0)]
|
||||
# train model
|
||||
train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
for name, param in ppo_trainer.model.named_parameters():
|
||||
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
|
||||
assert param.grad is not None, f"Parameter {name} has no gradient"
|
||||
|
||||
# ref model should not be trained
|
||||
for name, param in ppo_trainer.ref_model.named_parameters():
|
||||
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
|
||||
assert param.grad is None, f"Parameter {name} has a gradient"
|
||||
|
||||
# initialize a new gpt2 model:
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id)
|
||||
@ -357,10 +361,9 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
if "v_head" not in name:
|
||||
name = name.replace("pretrained_model.", "")
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(param.cpu(), model.state_dict()[name].cpu()),
|
||||
f"Parameter {name} has changed from the original model",
|
||||
)
|
||||
assert torch.allclose(
|
||||
param.cpu(), model.state_dict()[name].cpu()
|
||||
), f"Parameter {name} has changed from the original model"
|
||||
|
||||
# Finally check stats
|
||||
for stat in EXPECTED_STATS:
|
||||
@ -385,6 +388,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
dataset=dummy_dataset,
|
||||
num_shared_layers=num_shared_layers,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -392,7 +396,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(0.0)]
|
||||
# train model
|
||||
train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
pattern = r".*transformer\.h\.(\d+)\..*"
|
||||
@ -402,15 +406,15 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
if re.match(pattern, name):
|
||||
layer_number = int(re.match(pattern, name).groups(0)[0])
|
||||
if layer_number < num_shared_layers:
|
||||
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
|
||||
assert param.grad is None, f"Parameter {name} has a gradient"
|
||||
else:
|
||||
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
|
||||
elif any([layer in name for layer in final_layers]):
|
||||
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
|
||||
assert param.grad is not None, f"Parameter {name} has no gradient"
|
||||
elif any(layer in name for layer in final_layers):
|
||||
assert param.grad is not None, f"Parameter {name} has no gradient"
|
||||
|
||||
# ref model should not be trained
|
||||
for name, param in ppo_trainer.ref_model.named_parameters():
|
||||
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
|
||||
assert param.grad is None, f"Parameter {name} has a gradient"
|
||||
|
||||
for stat in EXPECTED_STATS:
|
||||
assert stat in train_stats.keys()
|
||||
@ -452,6 +456,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -459,21 +464,21 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor([[1.0]]), torch.tensor([[0.0]])]
|
||||
# train model - this should raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
with pytest.raises(ValueError):
|
||||
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
|
||||
reward = [torch.tensor([1.0]), torch.tensor([0.0])]
|
||||
# train model - this should work
|
||||
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
# check if the gradients are computed for the model
|
||||
for name, param in ppo_trainer.model.named_parameters():
|
||||
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
|
||||
assert param.grad is not None, f"Parameter {name} has no gradient"
|
||||
|
||||
# ref model should not be trained
|
||||
for name, param in ppo_trainer.ref_model.named_parameters():
|
||||
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
|
||||
assert param.grad is None, f"Parameter {name} has a gradient"
|
||||
|
||||
def test_ppo_step_input_shape(self):
|
||||
"""
|
||||
@ -489,6 +494,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -499,16 +505,16 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
bs = ppo_trainer.config.batch_size
|
||||
|
||||
queries, responses, _, _ = ppo_trainer._step_safety_checker(
|
||||
bs, [q for q in query_tensor], [r for r in response_tensor], reward
|
||||
bs, list(query_tensor), list(response_tensor), reward
|
||||
)
|
||||
|
||||
self.assertTrue(isinstance(queries, list), f"queries should be a list, got {type(queries)}")
|
||||
self.assertTrue(isinstance(responses, list), f"responses should be a list, got {type(responses)}")
|
||||
assert isinstance(queries, list), f"queries should be a list, got {type(queries)}"
|
||||
assert isinstance(responses, list), f"responses should be a list, got {type(responses)}"
|
||||
|
||||
# check the shapes
|
||||
for i in range(bs):
|
||||
self.assertEqual(queries[i].shape, torch.Size([7]))
|
||||
self.assertEqual(responses[i].size(), torch.Size([7]))
|
||||
assert queries[i].shape == torch.Size([7])
|
||||
assert responses[i].size() == torch.Size([7])
|
||||
break
|
||||
|
||||
def test_ppo_step_no_dataset(self):
|
||||
@ -529,6 +535,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
ref_model=self.gpt2_model_ref,
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
# train model with ppo
|
||||
reward = [torch.tensor([1.0])]
|
||||
# train model - this should work fine
|
||||
@ -536,15 +543,15 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
|
||||
# check gradients
|
||||
for name, param in ppo_trainer.model.named_parameters():
|
||||
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
|
||||
assert param.grad is not None, f"Parameter {name} has no gradient"
|
||||
|
||||
# ref model should not be trained
|
||||
for name, param in ppo_trainer.ref_model.named_parameters():
|
||||
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
|
||||
assert param.grad is None, f"Parameter {name} has a gradient"
|
||||
|
||||
# check train stats
|
||||
for stat in EXPECTED_STATS:
|
||||
self.assertTrue(stat in train_stats, f"Train stats should contain {stat}")
|
||||
assert stat in train_stats, f"Train stats should contain {stat}"
|
||||
|
||||
def test_loss_trainer(self):
|
||||
"""
|
||||
@ -579,7 +586,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
logits = torch.exp(all_logprobs)
|
||||
vpreds = values + 0.1
|
||||
|
||||
score, non_score = ppo_trainer.compute_rewards(dummy_scores, all_logprobs, ref_logprobs, mask)
|
||||
score, non_score, kls = ppo_trainer.compute_rewards(dummy_scores, all_logprobs, ref_logprobs, mask)
|
||||
values, advantages, returns = ppo_trainer.compute_advantages(values, score, mask)
|
||||
|
||||
# just make sure a dummy loss is computed
|
||||
@ -595,8 +602,8 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
returns[idx].unsqueeze(0),
|
||||
)
|
||||
|
||||
self.assertAlmostEqual(pg_loss.item(), 2.0494, 4)
|
||||
self.assertAlmostEqual(v_loss.item(), 0.07110, 4)
|
||||
assert abs(pg_loss.item() - 2.0494) < 0.0001
|
||||
assert abs(v_loss.item() - 0.0711) < 0.0001
|
||||
|
||||
# check if we get same results with masked parts removed
|
||||
pg_loss_unmasked, v_loss_unmasked, _ = ppo_trainer.loss(
|
||||
@ -609,8 +616,8 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
apply_mask(advantages[idx], mask[idx]).unsqueeze(0),
|
||||
apply_mask(returns[idx], mask[idx]).unsqueeze(0),
|
||||
)
|
||||
self.assertAlmostEqual(pg_loss_unmasked.item(), 2.0494, 4)
|
||||
self.assertAlmostEqual(v_loss_unmasked.item(), 0.07110, 4)
|
||||
assert abs(pg_loss_unmasked.item() - 2.0494) < 0.0001
|
||||
assert abs(v_loss_unmasked.item() - 0.0711) < 0.0001
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
@ -674,11 +681,11 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
model, dummy_queries, dummy_responses, model_inputs
|
||||
)
|
||||
|
||||
self.assertLessEqual(abs_diff_masked_tensors(logprobs_1, logprobs_2, mask_1, mask_2), 1e-4)
|
||||
self.assertLessEqual(abs_diff_masked_tensors(values_1, values_2, mask_1, mask_2), 1e-4)
|
||||
assert abs_diff_masked_tensors(logprobs_1, logprobs_2, mask_1, mask_2) <= 0.0001
|
||||
assert abs_diff_masked_tensors(values_1, values_2, mask_1, mask_2) <= 0.0001
|
||||
|
||||
self.assertLessEqual(abs_diff_masked_tensors(logprobs_0, logprobs_2[:1], mask_0, mask_2[:1]), 1e-4)
|
||||
self.assertLessEqual(abs_diff_masked_tensors(values_0, values_2[:1], mask_0, mask_2[:1]), 1e-4)
|
||||
assert abs_diff_masked_tensors(logprobs_0, logprobs_2[:1], mask_0, mask_2[:1]) <= 0.0001
|
||||
assert abs_diff_masked_tensors(values_0, values_2[:1], mask_0, mask_2[:1]) <= 0.0001
|
||||
|
||||
def test_ppo_trainer_max_grad_norm(self):
|
||||
"""
|
||||
@ -695,7 +702,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
|
||||
# train model with ppo
|
||||
@ -704,16 +711,15 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(0.0)]
|
||||
# train model
|
||||
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
# check gradients
|
||||
for name, param in ppo_trainer.model.named_parameters():
|
||||
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
|
||||
self.assertTrue(
|
||||
torch.all(param.grad.abs() <= self.ppo_config.max_grad_norm),
|
||||
f"Parameter {name} has a gradient larger than max_grad_norm",
|
||||
)
|
||||
assert param.grad is not None, f"Parameter {name} has no gradient"
|
||||
assert torch.all(
|
||||
param.grad.abs() <= self.ppo_config.max_grad_norm
|
||||
), f"Parameter {name} has a gradient larger than max_grad_norm"
|
||||
|
||||
def test_ppo_trainer_kl_penalty(self):
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
@ -730,7 +736,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
)
|
||||
|
||||
expected_output = torch.Tensor([[0.1000, -0.1000, 0.1000], [-0.1000, 0.1000, -0.2000]])
|
||||
self.assertTrue(torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output))
|
||||
assert torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output)
|
||||
|
||||
self.ppo_config.kl_penalty = "abs"
|
||||
ppo_trainer = PPOTrainer(
|
||||
@ -742,7 +748,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
)
|
||||
|
||||
expected_output = torch.Tensor([[0.1000, 0.1000, 0.1000], [0.1000, 0.1000, 0.2000]])
|
||||
self.assertTrue(torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output))
|
||||
assert torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output)
|
||||
|
||||
self.ppo_config.kl_penalty = "mse"
|
||||
ppo_trainer = PPOTrainer(
|
||||
@ -754,7 +760,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
)
|
||||
|
||||
expected_output = torch.Tensor([[0.0050, 0.0050, 0.0050], [0.0050, 0.0050, 0.0200]])
|
||||
self.assertTrue(torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output))
|
||||
assert torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output)
|
||||
|
||||
def test_ppo_trainer_full_kl_penalty(self):
|
||||
# a few more extensive tests for the full kl option as it is more involved
|
||||
@ -793,8 +799,8 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
[[0.0, 0.0]],
|
||||
)
|
||||
output = ppo_trainer._kl_penalty(log_probs, ref_log_probs)
|
||||
self.assertTrue(output.shape == (1, 2))
|
||||
self.assertTrue(torch.allclose(output, expected_output))
|
||||
assert output.shape == (1, 2)
|
||||
assert torch.allclose(output, expected_output)
|
||||
|
||||
# test for when the two dists are almost not overlapping
|
||||
log_probs = torch.Tensor(
|
||||
@ -819,8 +825,8 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
[[4.4474, 4.4474]],
|
||||
)
|
||||
output = ppo_trainer._kl_penalty(log_probs, ref_log_probs)
|
||||
self.assertTrue(output.shape == (1, 2))
|
||||
self.assertTrue(torch.allclose(output, expected_output))
|
||||
assert output.shape == (1, 2)
|
||||
assert torch.allclose(output, expected_output)
|
||||
|
||||
# test for when the two dists are almost not overlapping
|
||||
log_probs = torch.Tensor(
|
||||
@ -845,8 +851,8 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
[[3.7361, 0.0]],
|
||||
)
|
||||
output = ppo_trainer._kl_penalty(log_probs, ref_log_probs)
|
||||
self.assertTrue(output.shape == (1, 2))
|
||||
self.assertTrue(torch.allclose(output, expected_output, atol=1e-4))
|
||||
assert output.shape == (1, 2)
|
||||
assert torch.allclose(output, expected_output, atol=0.0001)
|
||||
|
||||
@require_peft
|
||||
@mark.peft_test
|
||||
@ -883,8 +889,8 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
self.assertTrue(ppo_trainer.ref_model is None)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
assert ppo_trainer.ref_model is None
|
||||
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
|
||||
@ -894,19 +900,19 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(0.0)]
|
||||
# train model by running a step twice
|
||||
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
|
||||
ppo_trainer.model.train()
|
||||
ppo_trainer.model.gradient_checkpointing_enable()
|
||||
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
# check gradients
|
||||
for name, param in model.named_parameters():
|
||||
if "lora" in name or "v_head" in name:
|
||||
self.assertTrue(param.grad is not None, f"Parameter {name} has a no gradient")
|
||||
assert param.grad is not None, f"Parameter {name} has a no gradient"
|
||||
else:
|
||||
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
|
||||
assert param.grad is None, f"Parameter {name} has a gradient"
|
||||
|
||||
@require_peft
|
||||
@mark.peft_test
|
||||
@ -971,8 +977,8 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
self.assertTrue(ppo_trainer.ref_model is None)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
assert ppo_trainer.ref_model is None
|
||||
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
|
||||
@ -982,23 +988,23 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(0.0)]
|
||||
# train model by running a step twice
|
||||
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
|
||||
ppo_trainer.model.train()
|
||||
ppo_trainer.model.gradient_checkpointing_enable()
|
||||
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
new_logits = ppo_trainer.model.compute_reward_score(dummy_inputs)
|
||||
self.assertTrue(not torch.allclose(previous_rm_logits, new_logits[:, -1, :]))
|
||||
self.assertTrue(torch.allclose(original_rm_logits, new_logits[:, -1, :]))
|
||||
assert not torch.allclose(previous_rm_logits, new_logits[:, -1, :])
|
||||
assert torch.allclose(original_rm_logits, new_logits[:, -1, :])
|
||||
|
||||
# check gradients
|
||||
for name, param in model.named_parameters():
|
||||
if ("lora" in name or "v_head" in name) and ("reward" not in name):
|
||||
self.assertTrue(param.grad is not None, f"Parameter {name} has a no gradient")
|
||||
assert param.grad is not None, f"Parameter {name} has a no gradient"
|
||||
else:
|
||||
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
|
||||
assert param.grad is None, f"Parameter {name} has a gradient"
|
||||
|
||||
@unittest.skip("Fix by either patching `whomai()` to work in the staging endpoint or use a dummy prod user.")
|
||||
def test_push_to_hub(self):
|
||||
@ -1016,10 +1022,10 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
url = ppo_trainer.push_to_hub(repo_id=repo_id, token=self._token, api_endpoint=CI_HUB_ENDPOINT)
|
||||
# Extract repo_name from the url
|
||||
re_search = re.search(CI_HUB_ENDPOINT + r"/([^/]+/[^/]+)/", url)
|
||||
self.assertTrue(re_search is not None)
|
||||
assert re_search is not None
|
||||
hub_repo_id = re_search.groups()[0]
|
||||
# Check we created a Hub repo
|
||||
self.assertEqual(hub_repo_id, repo_id)
|
||||
assert hub_repo_id == repo_id
|
||||
# Ensure all files are present
|
||||
files = sorted(self._api.list_repo_files(hub_repo_id))
|
||||
assert all(
|
||||
@ -1057,7 +1063,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
"gpt2", device_map="balanced", max_memory={0: "500MB", 1: "500MB"}
|
||||
)
|
||||
|
||||
self.assertTrue(set(gpt2_model.hf_device_map.values()) == {0, 1})
|
||||
assert set(gpt2_model.hf_device_map.values()) == {0, 1}
|
||||
|
||||
# this line is very important
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
@ -1068,7 +1074,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
peft_model = get_peft_model(gpt2_model, lora_config)
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(peft_model)
|
||||
|
||||
self.assertTrue(model.is_sequential_parallel)
|
||||
assert model.is_sequential_parallel
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
self.ppo_config.batch_size = 2
|
||||
@ -1082,7 +1088,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
self.assertTrue(ppo_trainer.ref_model is None)
|
||||
assert ppo_trainer.ref_model is None
|
||||
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
|
||||
@ -1092,19 +1098,19 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(0.0)]
|
||||
# train model by running a step twice
|
||||
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
|
||||
ppo_trainer.model.train()
|
||||
ppo_trainer.model.gradient_checkpointing_enable()
|
||||
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
# check gradients
|
||||
for name, param in model.named_parameters():
|
||||
if "lora" in name or "v_head" in name:
|
||||
self.assertTrue(param.grad is not None, f"Parameter {name} has a no gradient")
|
||||
assert param.grad is not None, f"Parameter {name} has a no gradient"
|
||||
else:
|
||||
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
|
||||
assert param.grad is None, f"Parameter {name} has a gradient"
|
||||
|
||||
def test_generation(self):
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
@ -1134,7 +1140,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
generations_single = [ppo_trainer.generate(inputs, **generation_kwargs).squeeze() for inputs in model_inputs]
|
||||
generations_single = tokenizer.batch_decode(generations_single)
|
||||
|
||||
self.assertEqual(generations_single, generations_batched)
|
||||
assert generations_single == generations_batched
|
||||
|
||||
def test_grad_accumulation(self):
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
@ -1162,7 +1168,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(1.0)]
|
||||
# train model by running a step twice
|
||||
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
model_grad = gpt2_model.v_head.summary.weight
|
||||
@ -1186,11 +1192,11 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(1.0)]
|
||||
# train model by running a step twice
|
||||
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
model_grad_acc = gpt2_model_clone.v_head.summary.weight
|
||||
self.assertTrue(torch.allclose(model_grad_acc, model_grad, rtol=1e-3, atol=1e-3))
|
||||
assert torch.allclose(model_grad_acc, model_grad, rtol=0.001, atol=0.001)
|
||||
|
||||
@unittest.skip("Fix by either patching `whomai()` to work in the staging endpoint or use a dummy prod user.")
|
||||
def test_push_to_hub_if_best_reward(self):
|
||||
@ -1217,6 +1223,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -1224,7 +1231,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(0.0)]
|
||||
# train model
|
||||
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
def test_batch_size_check(self):
|
||||
|
@ -14,6 +14,7 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction
|
||||
@ -35,7 +36,7 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
def test_accuracy_metrics(self):
|
||||
dummy_eval_predictions = EvalPrediction(torch.FloatTensor([[0.1, 0.9], [0.9, 0.1]]), torch.LongTensor([0, 0]))
|
||||
accuracy = compute_accuracy(dummy_eval_predictions)
|
||||
self.assertEqual(accuracy["accuracy"], 0.5)
|
||||
assert accuracy["accuracy"] == 0.5
|
||||
|
||||
def test_reward_trainer(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@ -52,9 +53,9 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
# fmt: off
|
||||
dummy_dataset_dict = {
|
||||
"input_ids_chosen": [
|
||||
torch.LongTensor([0, 1, 2,]),
|
||||
torch.LongTensor([0, 1, 2]),
|
||||
torch.LongTensor([1, 2]),
|
||||
torch.LongTensor([0, 1, 2,]),
|
||||
torch.LongTensor([0, 1, 2]),
|
||||
torch.LongTensor([1, 2]),
|
||||
],
|
||||
"attention_mask_chosen": [
|
||||
@ -64,9 +65,9 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
torch.LongTensor([1, 0]),
|
||||
],
|
||||
"input_ids_rejected": [
|
||||
torch.LongTensor([0, 2,]),
|
||||
torch.LongTensor([0, 2]),
|
||||
torch.LongTensor([1, 2, 0]),
|
||||
torch.LongTensor([0, 2,]),
|
||||
torch.LongTensor([0, 2]),
|
||||
torch.LongTensor([1, 2, 0]),
|
||||
],
|
||||
"attention_mask_rejected": [
|
||||
@ -91,17 +92,17 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
self.assertFalse(torch.equal(param, new_param))
|
||||
assert not torch.equal(param, new_param)
|
||||
|
||||
preds = trainer.predict(dummy_dataset)
|
||||
self.assertEqual(preds.predictions.shape, (4, 2))
|
||||
assert preds.predictions.shape == (4, 2)
|
||||
|
||||
@require_peft
|
||||
def test_reward_trainer_peft(self):
|
||||
@ -132,9 +133,9 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
# fmt: off
|
||||
dummy_dataset_dict = {
|
||||
"input_ids_chosen": [
|
||||
torch.LongTensor([0, 1, 2,]),
|
||||
torch.LongTensor([0, 1, 2]),
|
||||
torch.LongTensor([1, 2]),
|
||||
torch.LongTensor([0, 1, 2,]),
|
||||
torch.LongTensor([0, 1, 2]),
|
||||
torch.LongTensor([1, 2]),
|
||||
],
|
||||
"attention_mask_chosen": [
|
||||
@ -144,9 +145,9 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
torch.LongTensor([1, 0]),
|
||||
],
|
||||
"input_ids_rejected": [
|
||||
torch.LongTensor([0, 2,]),
|
||||
torch.LongTensor([0, 2]),
|
||||
torch.LongTensor([1, 2, 0]),
|
||||
torch.LongTensor([0, 2,]),
|
||||
torch.LongTensor([0, 2]),
|
||||
torch.LongTensor([1, 2, 0]),
|
||||
],
|
||||
"attention_mask_rejected": [
|
||||
@ -175,27 +176,27 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
|
||||
# check gradients are not None
|
||||
for n, param in trainer.model.named_parameters():
|
||||
if any([t in n for t in trainable_params_name]):
|
||||
if any(t in n for t in trainable_params_name):
|
||||
previous_trainable_params[n] = param.clone()
|
||||
else:
|
||||
previous_non_trainable_params[n] = param.clone()
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
|
||||
assert not torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)
|
||||
|
||||
# check the non trainable params have not changed
|
||||
for n, param in previous_non_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
|
||||
assert torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)
|
||||
|
||||
preds = trainer.predict(dummy_dataset)
|
||||
self.assertEqual(preds.predictions.shape, (4, 2))
|
||||
assert preds.predictions.shape == (4, 2)
|
||||
|
||||
def test_reward_trainer_assert_value_error(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@ -206,12 +207,12 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
remove_unused_columns=False,
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
dummy_dataset_dict = {
|
||||
# fmt: off
|
||||
"input_ids_b": [
|
||||
torch.LongTensor([0, 1, 2,]),
|
||||
torch.LongTensor([0, 1, 2]),
|
||||
torch.LongTensor([1, 2]),
|
||||
torch.LongTensor([0, 1, 2,]),
|
||||
torch.LongTensor([0, 1, 2]),
|
||||
torch.LongTensor([1, 2]),
|
||||
],
|
||||
"attention_mask_c": [
|
||||
@ -221,9 +222,9 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
torch.LongTensor([1, 0]),
|
||||
],
|
||||
"input_ids_f": [
|
||||
torch.LongTensor([0, 2,]),
|
||||
torch.LongTensor([0, 2]),
|
||||
torch.LongTensor([1, 2, 0]),
|
||||
torch.LongTensor([0, 2,]),
|
||||
torch.LongTensor([0, 2]),
|
||||
torch.LongTensor([1, 2, 0]),
|
||||
],
|
||||
"attention_mask_g": [
|
||||
@ -232,8 +233,8 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
torch.LongTensor([1, 1]),
|
||||
torch.LongTensor([1, 1, 1]),
|
||||
],
|
||||
# fmt: on
|
||||
}
|
||||
# fmt: on
|
||||
dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
|
||||
|
||||
trainer = RewardTrainer(
|
||||
@ -243,7 +244,7 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
train_dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
trainer.train()
|
||||
|
||||
training_args = RewardConfig(
|
||||
@ -276,13 +277,13 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
# fmt: off
|
||||
dummy_dataset_dict = {
|
||||
"input_ids_chosen": [
|
||||
torch.LongTensor([0, 1, 2,]),
|
||||
torch.LongTensor([0, 1, 2]),
|
||||
],
|
||||
"attention_mask_chosen": [
|
||||
torch.LongTensor([1, 1, 1]),
|
||||
],
|
||||
"input_ids_rejected": [
|
||||
torch.LongTensor([0, 2,]),
|
||||
torch.LongTensor([0, 2]),
|
||||
],
|
||||
"attention_mask_rejected": [
|
||||
torch.LongTensor([1, 1]),
|
||||
@ -306,9 +307,60 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
batch = trainer.data_collator(batch)
|
||||
loss, outputs = trainer.compute_loss(trainer.model, batch, return_outputs=True)
|
||||
|
||||
self.assertAlmostEqual(
|
||||
loss,
|
||||
-torch.nn.functional.logsigmoid(
|
||||
outputs["rewards_chosen"] - outputs["rewards_rejected"] - batch["margin"]
|
||||
).mean(),
|
||||
l_val = -torch.nn.functional.logsigmoid(
|
||||
outputs["rewards_chosen"] - outputs["rewards_rejected"] - batch["margin"]
|
||||
).mean()
|
||||
|
||||
assert abs(loss - l_val) < 1e-6
|
||||
|
||||
def test_reward_trainer_tags(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = RewardConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
dummy_dataset_dict = {
|
||||
"input_ids_chosen": [
|
||||
torch.LongTensor([0, 1, 2]),
|
||||
torch.LongTensor([1, 2]),
|
||||
torch.LongTensor([0, 1, 2]),
|
||||
torch.LongTensor([1, 2]),
|
||||
],
|
||||
"attention_mask_chosen": [
|
||||
torch.LongTensor([1, 1, 1]),
|
||||
torch.LongTensor([1, 0]),
|
||||
torch.LongTensor([1, 1, 1]),
|
||||
torch.LongTensor([1, 0]),
|
||||
],
|
||||
"input_ids_rejected": [
|
||||
torch.LongTensor([0, 2]),
|
||||
torch.LongTensor([1, 2, 0]),
|
||||
torch.LongTensor([0, 2]),
|
||||
torch.LongTensor([1, 2, 0]),
|
||||
],
|
||||
"attention_mask_rejected": [
|
||||
torch.LongTensor([1, 1]),
|
||||
torch.LongTensor([1, 1, 0]),
|
||||
torch.LongTensor([1, 1]),
|
||||
torch.LongTensor([1, 1, 1]),
|
||||
],
|
||||
}
|
||||
# fmt: on
|
||||
dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
|
||||
|
||||
trainer = RewardTrainer(
|
||||
model=self.model,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
assert trainer.model.model_tags == trainer._tag_names
|
||||
|
@ -17,6 +17,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
|
||||
@ -85,6 +86,42 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
],
|
||||
}
|
||||
)
|
||||
cls.dummy_chatml_dataset = Dataset.from_dict(
|
||||
{
|
||||
"messages": [
|
||||
[
|
||||
{"role": "system", "content": "You are helpful"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi, how can I help you?"},
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "4"},
|
||||
{"role": "user", "content": "What is 3+3?"},
|
||||
{"role": "assistant", "content": "6"},
|
||||
],
|
||||
[
|
||||
{"role": "system", "content": "You are helpful"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi, how can I help you?"},
|
||||
],
|
||||
]
|
||||
}
|
||||
)
|
||||
cls.dummy_instruction_dataset = Dataset.from_list(
|
||||
[
|
||||
{"prompt": "What is 2+2?", "completion": "4"},
|
||||
{"prompt": "What is 3+3?", "completion": "6"},
|
||||
{"prompt": "What is 4+4?", "completion": "8"},
|
||||
{"prompt": "What is 2+2?", "completion": "4"},
|
||||
{"prompt": "What is 3+3?", "completion": "6"},
|
||||
{"prompt": "What is 4+4?", "completion": "8"},
|
||||
{"prompt": "What is 2+2?", "completion": "4"},
|
||||
{"prompt": "What is 3+3?", "completion": "6"},
|
||||
{"prompt": "What is 4+4?", "completion": "8"},
|
||||
{"prompt": "What is 2+2?", "completion": "4"},
|
||||
{"prompt": "What is 3+3?", "completion": "6"},
|
||||
{"prompt": "What is 4+4?", "completion": "8"},
|
||||
]
|
||||
)
|
||||
|
||||
cls.train_dataset = ConstantLengthDataset(
|
||||
cls.tokenizer,
|
||||
@ -112,18 +149,18 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
formatting_func=formatting_prompts_func,
|
||||
)
|
||||
|
||||
self.assertTrue(len(formatted_dataset) == len(self.dummy_dataset))
|
||||
self.assertTrue(len(formatted_dataset) > 0)
|
||||
assert len(formatted_dataset) == len(self.dummy_dataset)
|
||||
assert len(formatted_dataset) > 0
|
||||
|
||||
for example in formatted_dataset:
|
||||
self.assertTrue("input_ids" in example)
|
||||
self.assertTrue("labels" in example)
|
||||
assert "input_ids" in example
|
||||
assert "labels" in example
|
||||
|
||||
self.assertTrue(len(example["input_ids"]) == formatted_dataset.seq_length)
|
||||
self.assertTrue(len(example["labels"]) == formatted_dataset.seq_length)
|
||||
assert len(example["input_ids"]) == formatted_dataset.seq_length
|
||||
assert len(example["labels"]) == formatted_dataset.seq_length
|
||||
|
||||
decoded_text = self.tokenizer.decode(example["input_ids"])
|
||||
self.assertTrue(("Question" in decoded_text) and ("Answer" in decoded_text))
|
||||
assert ("Question" in decoded_text) and ("Answer" in decoded_text)
|
||||
|
||||
def test_sft_trainer(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@ -147,10 +184,10 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
assert trainer.state.log_history[0]["eval_loss"] is not None
|
||||
|
||||
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
|
||||
def test_sft_trainer_uncorrect_data(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@ -164,14 +201,42 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
per_device_train_batch_size=2,
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
_ = SFTTrainer(
|
||||
model=self.model,
|
||||
args=training_args,
|
||||
train_dataset=self.dummy_dataset,
|
||||
packing=True,
|
||||
)
|
||||
|
||||
# this should work since the dummy chatml include the correct format
|
||||
_ = SFTTrainer(
|
||||
model=self.model,
|
||||
args=training_args,
|
||||
train_dataset=self.dummy_chatml_dataset,
|
||||
max_seq_length=32, # make sure there is at least 1 packed sequence
|
||||
num_of_sequences=32,
|
||||
packing=True,
|
||||
)
|
||||
_ = SFTTrainer(
|
||||
model=self.model,
|
||||
args=training_args,
|
||||
train_dataset=self.dummy_chatml_dataset,
|
||||
packing=False,
|
||||
)
|
||||
# this should work since the dummy instruction dataset is the correct format
|
||||
_ = SFTTrainer(
|
||||
model=self.model,
|
||||
args=training_args,
|
||||
train_dataset=self.dummy_instruction_dataset,
|
||||
max_seq_length=16, # make sure there is at least 1 packed sequence
|
||||
packing=True,
|
||||
)
|
||||
_ = SFTTrainer(
|
||||
model=self.model,
|
||||
args=training_args,
|
||||
train_dataset=self.dummy_instruction_dataset,
|
||||
packing=False,
|
||||
)
|
||||
# This should work
|
||||
_ = SFTTrainer(
|
||||
model=self.model,
|
||||
@ -182,7 +247,7 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
packing=True,
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
# This should not work because not enough data for one sample
|
||||
_ = SFTTrainer(
|
||||
model=self.model,
|
||||
@ -194,7 +259,7 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
)
|
||||
|
||||
# This should not work as well
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
_ = SFTTrainer(
|
||||
model=self.model,
|
||||
args=training_args,
|
||||
@ -235,10 +300,10 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
assert trainer.state.log_history[0]["eval_loss"] is not None
|
||||
|
||||
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
@ -263,9 +328,9 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
|
||||
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
@ -288,9 +353,9 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
|
||||
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1"))
|
||||
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")
|
||||
|
||||
def test_sft_trainer_with_model(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@ -314,10 +379,10 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
assert trainer.state.log_history[0]["eval_loss"] is not None
|
||||
|
||||
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
@ -341,9 +406,9 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
|
||||
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
|
||||
# with formatting_func + packed
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@ -368,9 +433,9 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
|
||||
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
|
||||
# with formatting_func + packed
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@ -393,9 +458,9 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
|
||||
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
@ -417,9 +482,9 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
|
||||
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1"))
|
||||
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")
|
||||
|
||||
def test_sft_trainer_with_multiple_eval_datasets(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@ -446,11 +511,11 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
self.assertIsNotNone(trainer.state.log_history[0]["eval_data1_loss"])
|
||||
self.assertIsNotNone(trainer.state.log_history[1]["eval_data2_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
assert trainer.state.log_history[0]["eval_data1_loss"] is not None
|
||||
assert trainer.state.log_history[1]["eval_data2_loss"] is not None
|
||||
|
||||
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1"))
|
||||
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")
|
||||
|
||||
def test_data_collator_completion_lm(self):
|
||||
response_template = "### Response:\n"
|
||||
@ -465,7 +530,7 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
labels = batch["labels"]
|
||||
last_pad_idx = np.where(labels == -100)[1][-1]
|
||||
result_text = self.tokenizer.decode(batch["input_ids"][0, last_pad_idx + 1 :])
|
||||
self.assertEqual(result_text, "I have not been masked correctly.")
|
||||
assert result_text == "I have not been masked correctly."
|
||||
|
||||
def test_data_collator_completion_lm_with_multiple_text(self):
|
||||
tokenizer = copy.deepcopy(self.tokenizer)
|
||||
@ -488,7 +553,7 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
labels = batch["labels"][i]
|
||||
last_pad_idx = np.where(labels == -100)[0][-1]
|
||||
result_text = tokenizer.decode(batch["input_ids"][i, last_pad_idx + 1 :])
|
||||
self.assertEqual(result_text, "I have not been masked correctly.")
|
||||
assert result_text == "I have not been masked correctly."
|
||||
|
||||
def test_data_collator_chat_completion_lm(self):
|
||||
instruction_template = "### Human:"
|
||||
@ -509,7 +574,7 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
labels = batch["labels"]
|
||||
non_masked_tokens = batch["input_ids"][labels != -100]
|
||||
result_text = self.tokenizer.decode(non_masked_tokens)
|
||||
self.assertEqual(result_text, " I should not be masked. I should not be masked too.")
|
||||
assert result_text == " I should not be masked. I should not be masked too."
|
||||
|
||||
def test_data_collator_chat_completion_lm_with_multiple_text(self):
|
||||
tokenizer = copy.deepcopy(self.tokenizer)
|
||||
@ -537,11 +602,11 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
|
||||
non_masked_tokens1 = input_ids[0][labels[0] != -100]
|
||||
result_text1 = tokenizer.decode(non_masked_tokens1)
|
||||
self.assertEqual(result_text1, " I should not be masked.")
|
||||
assert result_text1 == " I should not be masked."
|
||||
|
||||
non_masked_tokens2 = input_ids[1][labels[1] != -100]
|
||||
result_text2 = tokenizer.decode(non_masked_tokens2)
|
||||
self.assertEqual(result_text2, " I should not be masked. I should not be masked too.")
|
||||
assert result_text2 == " I should not be masked. I should not be masked too."
|
||||
|
||||
def test_sft_trainer_infinite_with_model(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@ -564,15 +629,15 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
max_seq_length=500,
|
||||
)
|
||||
|
||||
self.assertTrue(trainer.train_dataset.infinite)
|
||||
assert trainer.train_dataset.infinite
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
assert trainer.state.log_history[0]["eval_loss"] is not None
|
||||
|
||||
# make sure the trainer did 5 steps
|
||||
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-5"))
|
||||
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-5")
|
||||
|
||||
def test_sft_trainer_infinite_with_model_epochs(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@ -593,14 +658,14 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
max_seq_length=500,
|
||||
)
|
||||
|
||||
self.assertFalse(trainer.train_dataset.infinite)
|
||||
assert not trainer.train_dataset.infinite
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
|
||||
# make sure the trainer did 5 steps
|
||||
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-4"))
|
||||
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-4")
|
||||
|
||||
def test_sft_trainer_with_model_neftune(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@ -634,8 +699,8 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
torch.random.manual_seed(24)
|
||||
embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))
|
||||
|
||||
self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2))
|
||||
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) > 0)
|
||||
assert not torch.allclose(embeds_neftune, embeds_neftune_2)
|
||||
assert len(trainer.model.get_input_embeddings()._forward_hooks) > 0
|
||||
|
||||
trainer.neftune_hook_handle.remove()
|
||||
|
||||
@ -643,7 +708,26 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
|
||||
# Make sure forward pass works fine
|
||||
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
|
||||
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)
|
||||
assert len(trainer.model.get_input_embeddings()._forward_hooks) == 0
|
||||
|
||||
@require_peft
|
||||
def test_peft_sft_trainer_str(self):
|
||||
peft_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
_ = SFTTrainer(
|
||||
model=self.model_id,
|
||||
args=None,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
peft_config=peft_config,
|
||||
packing=True,
|
||||
)
|
||||
|
||||
@require_peft
|
||||
def test_peft_sft_trainer(self):
|
||||
@ -675,16 +759,16 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
packing=True,
|
||||
)
|
||||
|
||||
self.assertTrue(isinstance(trainer.model, PeftModel))
|
||||
assert isinstance(trainer.model, PeftModel)
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
assert trainer.state.log_history[0]["eval_loss"] is not None
|
||||
|
||||
self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
assert "adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
assert "adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
assert "model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
|
||||
@require_peft
|
||||
def test_peft_sft_trainer_gc(self):
|
||||
@ -717,16 +801,16 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
packing=True,
|
||||
)
|
||||
|
||||
self.assertTrue(isinstance(trainer.model, PeftModel))
|
||||
assert isinstance(trainer.model, PeftModel)
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
assert trainer.state.log_history[0]["eval_loss"] is not None
|
||||
|
||||
self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
assert "adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
assert "adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
assert "model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
|
||||
@require_peft
|
||||
def test_peft_sft_trainer_neftune(self):
|
||||
@ -761,7 +845,7 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
|
||||
trainer.model = trainer._trl_activate_neftune(trainer.model)
|
||||
|
||||
self.assertTrue(isinstance(trainer.model, PeftModel))
|
||||
assert isinstance(trainer.model, PeftModel)
|
||||
|
||||
device = trainer.model.get_input_embeddings().weight.device
|
||||
trainer.model.train()
|
||||
@ -772,20 +856,127 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
torch.random.manual_seed(24)
|
||||
embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))
|
||||
|
||||
self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2))
|
||||
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) > 0)
|
||||
assert not torch.allclose(embeds_neftune, embeds_neftune_2)
|
||||
assert len(trainer.model.get_input_embeddings()._forward_hooks) > 0
|
||||
|
||||
trainer.neftune_hook_handle.remove()
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
assert trainer.state.log_history[0]["eval_loss"] is not None
|
||||
|
||||
self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))
|
||||
assert "adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
assert "adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
assert "model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
|
||||
# Make sure forward pass works fine to check if embeddings forward is not broken.
|
||||
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
|
||||
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)
|
||||
assert len(trainer.model.get_input_embeddings()._forward_hooks) == 0
|
||||
|
||||
@require_peft
|
||||
def test_peft_sft_trainer_tag(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
dataloader_drop_last=True,
|
||||
evaluation_strategy="steps",
|
||||
max_steps=4,
|
||||
eval_steps=2,
|
||||
save_steps=2,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=self.model_id,
|
||||
args=training_args,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
peft_config=peft_config,
|
||||
packing=True,
|
||||
)
|
||||
|
||||
assert trainer.model.model_tags == trainer._tag_names
|
||||
|
||||
@require_peft
|
||||
def test_sft_trainer_tag(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
dataloader_drop_last=True,
|
||||
evaluation_strategy="steps",
|
||||
max_steps=4,
|
||||
eval_steps=2,
|
||||
save_steps=2,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=self.model_id,
|
||||
args=training_args,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
packing=True,
|
||||
)
|
||||
|
||||
assert trainer.model.model_tags == trainer._tag_names
|
||||
|
||||
def test_sft_trainer_eval_packing(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
dataloader_drop_last=True,
|
||||
evaluation_strategy="steps",
|
||||
max_steps=4,
|
||||
eval_steps=2,
|
||||
save_steps=2,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=self.model_id,
|
||||
args=training_args,
|
||||
train_dataset=self.dummy_chatml_dataset,
|
||||
eval_dataset=self.dummy_chatml_dataset,
|
||||
packing=True,
|
||||
max_seq_length=32, # make sure there is at least 1 packed sequence
|
||||
eval_packing=False,
|
||||
)
|
||||
|
||||
assert len(trainer.train_dataset["input_ids"]) == 1
|
||||
assert len(trainer.eval_dataset["input_ids"]) != 1
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=self.model_id,
|
||||
args=training_args,
|
||||
train_dataset=self.dummy_chatml_dataset,
|
||||
eval_dataset=self.dummy_chatml_dataset,
|
||||
max_seq_length=32, # make sure there is at least 1 packed sequence
|
||||
packing=True,
|
||||
)
|
||||
|
||||
assert len(trainer.train_dataset["input_ids"]) == 1
|
||||
assert len(trainer.eval_dataset["input_ids"]) == 1
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=self.model_id,
|
||||
args=training_args,
|
||||
train_dataset=self.dummy_chatml_dataset,
|
||||
eval_dataset=self.dummy_chatml_dataset,
|
||||
max_seq_length=32, # make sure there is at least 1 packed sequence
|
||||
packing=False,
|
||||
)
|
||||
|
||||
assert len(trainer.train_dataset["input_ids"]) != 1
|
||||
assert len(trainer.eval_dataset["input_ids"]) != 1
|
||||
|
@ -15,7 +15,13 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from trl import is_diffusers_available, is_peft_available, is_wandb_available, is_xpu_available
|
||||
from trl import (
|
||||
is_bitsandbytes_available,
|
||||
is_diffusers_available,
|
||||
is_peft_available,
|
||||
is_wandb_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
|
||||
|
||||
def require_peft(test_case):
|
||||
@ -27,6 +33,15 @@ def require_peft(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_bitsandbytes(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires bnb. Skips the test if bnb is not available.
|
||||
"""
|
||||
if not is_bitsandbytes_available():
|
||||
test_case = unittest.skip("test requires bnb")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
def require_diffusers(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires diffusers. Skips the test if diffusers is not available.
|
||||
@ -55,17 +70,6 @@ def require_no_wandb(test_case):
|
||||
return require_wandb(test_case, required=False)
|
||||
|
||||
|
||||
def require_bitsandbytes(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires bitsandbytes. Skips the test if bitsandbytes is not available.
|
||||
"""
|
||||
try:
|
||||
import bitsandbytes # noqa: F401
|
||||
except ImportError:
|
||||
test_case = unittest.skip("test requires bitsandbytes")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
def require_torch_multi_gpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires multiple GPUs. Skips the test if there aren't enough GPUs.
|
||||
@ -75,6 +79,15 @@ def require_torch_multi_gpu(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_torch_gpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires GPUs. Skips the test if there is no GPU.
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
test_case = unittest.skip("test requires GPU")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
def require_torch_multi_xpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires multiple XPUs. Skips the test if there aren't enough XPUs.
|
||||
|
161
trl/__init__.py
161
trl/__init__.py
@ -1,40 +1,133 @@
|
||||
# flake8: noqa
|
||||
|
||||
__version__ = "0.7.5.dev0"
|
||||
__version__ = "0.8.0"
|
||||
|
||||
from .core import set_seed
|
||||
from .environment import TextEnvironment, TextHistory
|
||||
from .extras import BestOfNSampler
|
||||
from .import_utils import (
|
||||
is_diffusers_available,
|
||||
is_npu_available,
|
||||
is_peft_available,
|
||||
is_wandb_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
from .models import (
|
||||
AutoModelForCausalLMWithValueHead,
|
||||
AutoModelForSeq2SeqLMWithValueHead,
|
||||
PreTrainedModelWrapper,
|
||||
create_reference_model,
|
||||
)
|
||||
from .trainer import (
|
||||
DataCollatorForCompletionOnlyLM,
|
||||
DPOTrainer,
|
||||
IterativeSFTTrainer,
|
||||
PPOConfig,
|
||||
PPOTrainer,
|
||||
RewardConfig,
|
||||
RewardTrainer,
|
||||
SFTTrainer,
|
||||
)
|
||||
from typing import TYPE_CHECKING
|
||||
from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
|
||||
|
||||
_import_structure = {
|
||||
"core": [
|
||||
"set_seed",
|
||||
],
|
||||
"environment": [
|
||||
"TextEnvironment",
|
||||
"TextHistory",
|
||||
],
|
||||
"extras": [
|
||||
"BestOfNSampler",
|
||||
],
|
||||
"import_utils": [
|
||||
"is_bitsandbytes_available",
|
||||
"is_diffusers_available",
|
||||
"is_npu_available",
|
||||
"is_peft_available",
|
||||
"is_wandb_available",
|
||||
"is_xpu_available",
|
||||
],
|
||||
"models": [
|
||||
"AutoModelForCausalLMWithValueHead",
|
||||
"AutoModelForSeq2SeqLMWithValueHead",
|
||||
"PreTrainedModelWrapper",
|
||||
"create_reference_model",
|
||||
"setup_chat_format",
|
||||
"SUPPORTED_ARCHITECTURES",
|
||||
],
|
||||
"trainer": [
|
||||
"DataCollatorForCompletionOnlyLM",
|
||||
"DPOTrainer",
|
||||
"IterativeSFTTrainer",
|
||||
"KTOConfig",
|
||||
"KTOTrainer",
|
||||
"ModelConfig",
|
||||
"PPOConfig",
|
||||
"PPOTrainer",
|
||||
"RewardConfig",
|
||||
"RewardTrainer",
|
||||
"SFTTrainer",
|
||||
],
|
||||
"commands": [],
|
||||
"commands.utils": ["SftArgumentParser", "init_zero_verbose", "TrlParser", "DpoArgumentParser"],
|
||||
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "RichProgressCallback"],
|
||||
"multitask_prompt_tuning": [
|
||||
"MultitaskPromptEmbedding",
|
||||
"MultitaskPromptTuningConfig",
|
||||
"MultitaskPromptTuningInit",
|
||||
],
|
||||
}
|
||||
|
||||
if is_diffusers_available():
|
||||
from .models import (
|
||||
DDPOPipelineOutput,
|
||||
DDPOSchedulerOutput,
|
||||
DDPOStableDiffusionPipeline,
|
||||
DefaultDDPOStableDiffusionPipeline,
|
||||
try:
|
||||
if not is_diffusers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["models"].extend(
|
||||
[
|
||||
"DDPOPipelineOutput",
|
||||
"DDPOSchedulerOutput",
|
||||
"DDPOStableDiffusionPipeline",
|
||||
"DefaultDDPOStableDiffusionPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"])
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import set_seed
|
||||
from .environment import TextEnvironment, TextHistory
|
||||
from .extras import BestOfNSampler
|
||||
from .import_utils import (
|
||||
is_bitsandbytes_available,
|
||||
is_diffusers_available,
|
||||
is_npu_available,
|
||||
is_peft_available,
|
||||
is_wandb_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
from .models import (
|
||||
AutoModelForCausalLMWithValueHead,
|
||||
AutoModelForSeq2SeqLMWithValueHead,
|
||||
PreTrainedModelWrapper,
|
||||
create_reference_model,
|
||||
setup_chat_format,
|
||||
SUPPORTED_ARCHITECTURES,
|
||||
)
|
||||
from .trainer import (
|
||||
DataCollatorForCompletionOnlyLM,
|
||||
DPOTrainer,
|
||||
IterativeSFTTrainer,
|
||||
KTOConfig,
|
||||
KTOTrainer,
|
||||
ModelConfig,
|
||||
PPOConfig,
|
||||
PPOTrainer,
|
||||
RewardConfig,
|
||||
RewardTrainer,
|
||||
SFTTrainer,
|
||||
)
|
||||
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config, RichProgressCallback
|
||||
from .commands.utils import init_zero_verbose, SftScriptArguments, DpoScriptArguments, TrlParser
|
||||
|
||||
try:
|
||||
if not is_diffusers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .models import (
|
||||
DDPOPipelineOutput,
|
||||
DDPOSchedulerOutput,
|
||||
DDPOStableDiffusionPipeline,
|
||||
DefaultDDPOStableDiffusionPipeline,
|
||||
)
|
||||
from .trainer import DDPOConfig, DDPOTrainer
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={"__version__": __version__},
|
||||
)
|
||||
from .trainer import DDPOConfig, DDPOTrainer
|
||||
|
34
trl/commands/__init__.py
Normal file
34
trl/commands/__init__.py
Normal file
@ -0,0 +1,34 @@
|
||||
# flake8: noqa
|
||||
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# flake8: noqa
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from ..import_utils import _LazyModule, OptionalDependencyNotAvailable
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"cli_utils": ["SftArgumentParser", "init_zero_verbose", "DpoScriptArguments", "TrlParser"],
|
||||
"config_parser": ["YamlConfigParser"],
|
||||
}
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .cli_utils import SftScriptArguments, init_zero_verbose, DpoScriptArguments, TrlParser
|
||||
from .config_parser import YamlConfigParser
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
71
trl/commands/cli.py
Normal file
71
trl/commands/cli.py
Normal file
@ -0,0 +1,71 @@
|
||||
# This file is a copy of trl/examples/scripts/sft.py so that we could
|
||||
# use it together with rich and the TRL CLI in a more customizable manner.
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from subprocess import CalledProcessError
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
|
||||
SUPPORTED_COMMANDS = ["sft", "dpo", "chat"]
|
||||
|
||||
|
||||
def main():
|
||||
console = Console()
|
||||
# Make sure to import things locally to avoid verbose from third party libs.
|
||||
with console.status("[bold purple]Welcome! Initializing the TRL CLI..."):
|
||||
from trl.commands.cli_utils import init_zero_verbose
|
||||
|
||||
init_zero_verbose()
|
||||
|
||||
command_name = sys.argv[1]
|
||||
|
||||
if command_name not in SUPPORTED_COMMANDS:
|
||||
raise ValueError(
|
||||
f"Please use one of the supported commands, got {command_name} - supported commands are {SUPPORTED_COMMANDS}"
|
||||
)
|
||||
|
||||
trl_examples_dir = os.path.dirname(__file__)
|
||||
|
||||
# Force-use rich
|
||||
os.environ["TRL_USE_RICH"] = "1"
|
||||
|
||||
if command_name == "chat":
|
||||
command = f"""
|
||||
python {trl_examples_dir}/scripts/{command_name}.py {" ".join(sys.argv[2:])}
|
||||
"""
|
||||
else:
|
||||
command = f"""
|
||||
accelerate launch {trl_examples_dir}/scripts/{command_name}.py {" ".join(sys.argv[2:])}
|
||||
"""
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
command.split(),
|
||||
text=True,
|
||||
check=True,
|
||||
encoding="utf-8",
|
||||
cwd=os.getcwd(),
|
||||
env=os.environ.copy(),
|
||||
)
|
||||
except (CalledProcessError, ChildProcessError) as exc:
|
||||
console.log(f"TRL - {command_name.upper()} failed on ! See the logs above for further details.")
|
||||
raise ValueError("TRL CLI failed! Check the traceback above..") from exc
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
288
trl/commands/cli_utils.py
Normal file
288
trl/commands/cli_utils.py
Normal file
@ -0,0 +1,288 @@
|
||||
# This file is a copy of trl/examples/scripts/sft.py so that we could
|
||||
# use it together with rich and the TRL CLI in a more customizable manner.
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import Any, List
|
||||
|
||||
import yaml
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
class YamlConfigParser:
|
||||
def __init__(self, config_path: str = None, dataclasses: List[Any] = None):
|
||||
self.config = None
|
||||
|
||||
if config_path is not None:
|
||||
with open(config_path) as yaml_file:
|
||||
self.config = yaml.safe_load(yaml_file)
|
||||
else:
|
||||
self.config = {}
|
||||
|
||||
if dataclasses is None:
|
||||
dataclasses = []
|
||||
|
||||
# We create a dummy training args to compare the values before / after
|
||||
# __post_init__
|
||||
# Here we import `TrainingArguments` from the local level to not
|
||||
# break TRL lazy imports.
|
||||
from transformers import TrainingArguments
|
||||
|
||||
self._dummy_training_args = TrainingArguments(output_dir="dummy-training-args")
|
||||
|
||||
self.parse_and_set_env()
|
||||
self.merge_dataclasses(dataclasses)
|
||||
|
||||
def parse_and_set_env(self):
|
||||
if "env" in self.config:
|
||||
env_vars = self.config["env"]
|
||||
if isinstance(env_vars, dict):
|
||||
for key, value in env_vars.items():
|
||||
os.environ[key] = str(value)
|
||||
else:
|
||||
raise ValueError("`env` field should be a dict in the YAML file.")
|
||||
|
||||
def merge_dataclasses(self, dataclasses):
|
||||
from transformers import TrainingArguments
|
||||
|
||||
dataclasses_copy = [deepcopy(dataclass) for dataclass in dataclasses]
|
||||
|
||||
if len(self.config) > 0:
|
||||
for i, dataclass in enumerate(dataclasses):
|
||||
is_hf_training_args = False
|
||||
|
||||
for data_class_field in fields(dataclass):
|
||||
# Get the field here
|
||||
field_name = data_class_field.name
|
||||
field_value = getattr(dataclass, field_name)
|
||||
|
||||
if not isinstance(dataclass, TrainingArguments):
|
||||
default_value = data_class_field.default
|
||||
else:
|
||||
default_value = (
|
||||
getattr(self._dummy_training_args, field_name)
|
||||
if field_name != "output_dir"
|
||||
else field_name
|
||||
)
|
||||
is_hf_training_args = True
|
||||
|
||||
default_value_changed = field_value != default_value
|
||||
|
||||
if field_value is not None or field_name in self.config:
|
||||
if field_name in self.config:
|
||||
# In case the field value is not different from default, overwrite it
|
||||
if not default_value_changed:
|
||||
value_to_replace = self.config[field_name]
|
||||
setattr(dataclasses_copy[i], field_name, value_to_replace)
|
||||
# Otherwise do nothing
|
||||
|
||||
# Re-init `TrainingArguments` to handle all post-processing correctly
|
||||
if is_hf_training_args:
|
||||
init_signature = list(inspect.signature(TrainingArguments.__init__).parameters)
|
||||
dict_dataclass = asdict(dataclasses_copy[i])
|
||||
new_dict_dataclass = {k: v for k, v in dict_dataclass.items() if k in init_signature}
|
||||
dataclasses_copy[i] = TrainingArguments(**new_dict_dataclass)
|
||||
|
||||
return dataclasses_copy
|
||||
|
||||
def to_string(self):
|
||||
final_string = """"""
|
||||
for key, value in self.config.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
if len(value) != 0:
|
||||
value = str(value)
|
||||
value = value.replace("'", '"')
|
||||
value = f"'{value}'"
|
||||
else:
|
||||
continue
|
||||
|
||||
final_string += f"--{key} {value} "
|
||||
return final_string
|
||||
|
||||
|
||||
def init_zero_verbose():
|
||||
"""
|
||||
Perform zero verbose init - use this method on top of the CLI modules to make
|
||||
"""
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
from rich.logging import RichHandler
|
||||
|
||||
FORMAT = "%(message)s"
|
||||
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.ERROR)
|
||||
|
||||
# Custom warning handler to redirect warnings to the logging system
|
||||
def warning_handler(message, category, filename, lineno, file=None, line=None):
|
||||
logging.warning(f"{filename}:{lineno}: {category.__name__}: {message}")
|
||||
|
||||
# Add the custom warning handler - we need to do that before importing anything to make sure the loggers work well
|
||||
warnings.showwarning = warning_handler
|
||||
|
||||
|
||||
@dataclass
|
||||
class SftScriptArguments:
|
||||
dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"})
|
||||
dataset_text_field: str = field(default="text", metadata={"help": "the text field of the dataset"})
|
||||
max_seq_length: int = field(default=512, metadata={"help": "The maximum sequence length for SFT Trainer"})
|
||||
packing: bool = field(default=False, metadata={"help": "Whether to apply data packing or not during training"})
|
||||
config: str = field(default=None, metadata={"help": "Path to the optional config file"})
|
||||
gradient_checkpointing_use_reentrant: bool = field(
|
||||
default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DpoScriptArguments:
|
||||
dataset_name: str = field(default=None, metadata={"help": "the dataset name"})
|
||||
beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
|
||||
max_length: int = field(default=512, metadata={"help": "max length of each sample"})
|
||||
max_prompt_length: int = field(default=128, metadata={"help": "max length of each sample's prompt"})
|
||||
max_target_length: int = field(
|
||||
default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
|
||||
)
|
||||
sanity_check: bool = field(default=True, metadata={"help": "only train on 1000 samples"})
|
||||
ignore_bias_buffers: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "debug argument for distributed training;"
|
||||
"fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
|
||||
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
|
||||
},
|
||||
)
|
||||
generate_during_eval: bool = field(default=False, metadata={"help": "Generate during evaluation"})
|
||||
config: str = field(default=None, metadata={"help": "Path to the optional config file"})
|
||||
gradient_checkpointing_use_reentrant: bool = field(
|
||||
default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatArguments:
|
||||
# general settings
|
||||
model_name_or_path: str = field(metadata={"help": "Name of the pre-trained model"})
|
||||
user: str = field(default=None, metadata={"help": "Username to display in chat interface"})
|
||||
system_prompt: str = field(default=None, metadata={"help": "System prompt"})
|
||||
save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history"})
|
||||
device: str = field(
|
||||
default="cpu",
|
||||
metadata={"help": "device to use for inference."},
|
||||
)
|
||||
config: str = field(
|
||||
default="default",
|
||||
metadata={
|
||||
"help": "Config file used for setting the configs. If `default` uses examples/scripts/config/default_chat_config.yaml"
|
||||
},
|
||||
)
|
||||
examples: str = field(default=None, metadata={"help": "Empty placeholder needs to be set via config."})
|
||||
# generation settings
|
||||
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate"})
|
||||
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation"})
|
||||
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search"})
|
||||
temperature: float = field(default=1.0, metadata={"help": "Temperature parameter for generation"})
|
||||
top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling"})
|
||||
top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling"})
|
||||
repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty"})
|
||||
# model loading
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
torch_dtype: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
||||
"dtype will be automatically derived from the model's weights."
|
||||
),
|
||||
"choices": ["auto", "bfloat16", "float16", "float32"],
|
||||
},
|
||||
)
|
||||
trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
|
||||
attn_implementation: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`"
|
||||
)
|
||||
},
|
||||
)
|
||||
load_in_8bit: bool = field(
|
||||
default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}
|
||||
)
|
||||
load_in_4bit: bool = field(
|
||||
default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}
|
||||
)
|
||||
|
||||
bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"})
|
||||
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})
|
||||
|
||||
|
||||
class TrlParser(HfArgumentParser):
|
||||
def __init__(self, parsers):
|
||||
"""
|
||||
The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config
|
||||
parsers for users that pass a valid `config` field and merge the values that are set in the config
|
||||
with the processed parsers.
|
||||
|
||||
Args:
|
||||
parsers (`List[argparse.ArgumentParser`]):
|
||||
List of parsers.
|
||||
"""
|
||||
super().__init__(parsers)
|
||||
|
||||
def post_process_dataclasses(self, dataclasses):
|
||||
# Apply additional post-processing in case some arguments needs a special
|
||||
# care
|
||||
training_args = trl_args = None
|
||||
training_args_index = None
|
||||
|
||||
for i, dataclass_obj in enumerate(dataclasses):
|
||||
if dataclass_obj.__class__.__name__ == "TrainingArguments":
|
||||
training_args = dataclass_obj
|
||||
training_args_index = i
|
||||
elif dataclass_obj.__class__.__name__ in ("SftScriptArguments", "DpoScriptArguments"):
|
||||
trl_args = dataclass_obj
|
||||
else:
|
||||
...
|
||||
|
||||
if trl_args is not None and training_args is not None:
|
||||
training_args.gradient_checkpointing_kwargs = dict(
|
||||
use_reentrant=trl_args.gradient_checkpointing_use_reentrant
|
||||
)
|
||||
dataclasses[training_args_index] = training_args
|
||||
|
||||
return dataclasses
|
||||
|
||||
def parse_args_and_config(self):
|
||||
dataclasses = self.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||
# Pop the last element which should be the remaining strings
|
||||
dataclasses = self.update_dataclasses_with_config(dataclasses[:-1])
|
||||
return dataclasses
|
||||
|
||||
def update_dataclasses_with_config(self, dataclasses):
|
||||
self.config_parser = None
|
||||
for parser_dataclass in dataclasses:
|
||||
if hasattr(parser_dataclass, "config"):
|
||||
if self.config_parser is not None:
|
||||
raise ValueError("You passed the `config` field twice! Make sure to pass `config` only once.")
|
||||
self.config_parser = YamlConfigParser(parser_dataclass.config)
|
||||
|
||||
if self.config_parser is not None:
|
||||
dataclasses = self.config_parser.merge_dataclasses(dataclasses)
|
||||
dataclasses = self.post_process_dataclasses(dataclasses)
|
||||
return dataclasses
|
135
trl/core.py
135
trl/core.py
@ -15,13 +15,14 @@ import gc
|
||||
import random
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from transformers import top_k_top_p_filtering
|
||||
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
|
||||
|
||||
from .import_utils import is_npu_available, is_xpu_available
|
||||
|
||||
@ -29,30 +30,66 @@ from .import_utils import is_npu_available, is_xpu_available
|
||||
try:
|
||||
from collections.abc import Mapping
|
||||
except ImportError:
|
||||
from collections import Mapping
|
||||
from collections.abc import Mapping
|
||||
|
||||
|
||||
WANDB_PADDING = -1
|
||||
|
||||
|
||||
def flatten_dict(nested, sep="/"):
|
||||
def top_k_top_p_filtering(
|
||||
logits: torch.FloatTensor,
|
||||
top_k: int = 0,
|
||||
top_p: float = 1.0,
|
||||
filter_value: float = -float("Inf"),
|
||||
min_tokens_to_keep: int = 1,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
|
||||
|
||||
Args:
|
||||
logits: logits distribution shape (batch size, vocabulary size)
|
||||
top_k (`int`, *optional*, defaults to 0):
|
||||
If > 0, only keep the top k tokens with highest probability (top-k filtering)
|
||||
top_p (`float`, *optional*, defaults to 1.0):
|
||||
If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
|
||||
filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
||||
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||
Minimumber of tokens we keep per batch example in the output.
|
||||
|
||||
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
||||
"""
|
||||
|
||||
if top_k > 0:
|
||||
logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
|
||||
None, logits
|
||||
)
|
||||
|
||||
if 0 <= top_p <= 1.0:
|
||||
logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
|
||||
None, logits
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
|
||||
"""Flatten dictionary and concatenate nested keys with separator."""
|
||||
|
||||
def rec(nest, prefix, into):
|
||||
def recurse(nest: Dict, prefix: str, into: Dict) -> None:
|
||||
for k, v in nest.items():
|
||||
if sep in k:
|
||||
raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
|
||||
if isinstance(v, Mapping):
|
||||
rec(v, prefix + k + sep, into)
|
||||
recurse(v, prefix + k + sep, into)
|
||||
else:
|
||||
into[prefix + k] = v
|
||||
|
||||
flat = {}
|
||||
rec(nested, "", flat)
|
||||
recurse(nested, "", flat)
|
||||
return flat
|
||||
|
||||
|
||||
def convert_to_scalar(stats):
|
||||
def convert_to_scalar(stats: Dict) -> Dict:
|
||||
"""
|
||||
Converts the stats from a flattened dict to single scalar dicts
|
||||
"""
|
||||
@ -68,7 +105,7 @@ def convert_to_scalar(stats):
|
||||
return tensorboard_stats
|
||||
|
||||
|
||||
def stack_dicts(stats_dicts):
|
||||
def stack_dicts(stats_dicts: List[Dict]) -> Dict:
|
||||
"""Stack the values of a dict."""
|
||||
results = dict()
|
||||
for k in stats_dicts[0]:
|
||||
@ -77,12 +114,12 @@ def stack_dicts(stats_dicts):
|
||||
return results
|
||||
|
||||
|
||||
def add_suffix(input_dict, suffix):
|
||||
def add_suffix(input_dict: Dict, suffix: str) -> Dict:
|
||||
"""Add suffix to dict keys."""
|
||||
return dict((k + suffix, v) for k, v in input_dict.items())
|
||||
return {k + suffix: v for k, v in input_dict.items()}
|
||||
|
||||
|
||||
def pad_to_size(tensor, size, dim=1, padding=50256):
|
||||
def pad_to_size(tensor: torch.Tensor, size: int, dim: int = 1, padding: int = 50256) -> torch.Tensor:
|
||||
"""Pad tensor to size."""
|
||||
t_size = tensor.size()[dim]
|
||||
if t_size == size:
|
||||
@ -91,7 +128,7 @@ def pad_to_size(tensor, size, dim=1, padding=50256):
|
||||
return torch.nn.functional.pad(tensor, (0, size - t_size), "constant", padding)
|
||||
|
||||
|
||||
def logprobs_from_logits(logits, labels, gather=True):
|
||||
def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
|
||||
"""
|
||||
@ -103,7 +140,7 @@ def logprobs_from_logits(logits, labels, gather=True):
|
||||
return logpy
|
||||
|
||||
|
||||
def whiten(values, shift_mean=True):
|
||||
def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
|
||||
"""Whiten values."""
|
||||
mean, var = torch.mean(values), torch.var(values)
|
||||
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
|
||||
@ -112,7 +149,7 @@ def whiten(values, shift_mean=True):
|
||||
return whitened
|
||||
|
||||
|
||||
def masked_mean(values, mask, axis=None):
|
||||
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
|
||||
"""Compute mean of tensor with a masked values."""
|
||||
if axis is not None:
|
||||
return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
|
||||
@ -120,7 +157,7 @@ def masked_mean(values, mask, axis=None):
|
||||
return (values * mask).sum() / mask.sum()
|
||||
|
||||
|
||||
def masked_var(values, mask, unbiased=True):
|
||||
def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
|
||||
"""Compute variance of tensor with masked values."""
|
||||
mean = masked_mean(values, mask)
|
||||
centered_values = values - mean
|
||||
@ -139,7 +176,7 @@ def masked_var(values, mask, unbiased=True):
|
||||
return variance
|
||||
|
||||
|
||||
def masked_whiten(values, mask, shift_mean=True):
|
||||
def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
|
||||
"""Whiten values with masked values."""
|
||||
mean, var = masked_mean(values, mask), masked_var(values, mask)
|
||||
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
|
||||
@ -148,23 +185,23 @@ def masked_whiten(values, mask, shift_mean=True):
|
||||
return whitened
|
||||
|
||||
|
||||
def clip_by_value(x, tensor_min, tensor_max):
|
||||
def clip_by_value(x: torch.Tensor, tensor_min: float, tensor_max: float) -> torch.Tensor:
|
||||
"""
|
||||
Tensor extenstion to torch.clamp
|
||||
Tensor extension to torch.clamp
|
||||
https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
|
||||
"""
|
||||
clipped = torch.max(torch.min(x, tensor_max), tensor_min)
|
||||
return clipped
|
||||
|
||||
|
||||
def entropy_from_logits(logits):
|
||||
def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate entropy from logits."""
|
||||
pd = torch.nn.functional.softmax(logits, dim=-1)
|
||||
entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1)
|
||||
return entropy
|
||||
|
||||
|
||||
def average_torch_dicts(list_of_dicts):
|
||||
def average_torch_dicts(list_of_dicts: List[Dict]) -> Dict:
|
||||
"""Average values of a list of dicts with torch tensors."""
|
||||
average_dict = dict()
|
||||
for key in list_of_dicts[0].keys():
|
||||
@ -172,7 +209,7 @@ def average_torch_dicts(list_of_dicts):
|
||||
return average_dict
|
||||
|
||||
|
||||
def stats_to_np(stats_dict):
|
||||
def stats_to_np(stats_dict: Dict) -> Dict:
|
||||
"""Cast all torch.tensors in dict to numpy arrays."""
|
||||
new_dict = dict()
|
||||
for k, v in stats_dict.items():
|
||||
@ -188,40 +225,12 @@ def stats_to_np(stats_dict):
|
||||
return new_dict
|
||||
|
||||
|
||||
def listify_batch(tensor):
|
||||
"""Turns the first dimension of a tensor into a list."""
|
||||
return [tensor[i] for i in range(tensor.shape[0])]
|
||||
|
||||
|
||||
def build_bert_batch_from_txt(text_list, tokenizer, device):
|
||||
"""Create token id and attention mask tensors from text list for BERT classification."""
|
||||
|
||||
# tokenize
|
||||
tensors = [tokenizer.encode(txt, return_tensors="pt").to(device) for txt in text_list]
|
||||
|
||||
# find max length to pad to
|
||||
max_len = max([t.size()[1] for t in tensors])
|
||||
|
||||
# get padded tensors and attention masks
|
||||
# (attention masks make bert ignore padding)
|
||||
padded_tensors = []
|
||||
attention_masks = []
|
||||
for tensor in tensors:
|
||||
attention_mask = torch.ones(tensor.size(), device=device)
|
||||
padded_tensors.append(pad_to_size(tensor, max_len, padding=0))
|
||||
attention_masks.append(pad_to_size(attention_mask, max_len, padding=0))
|
||||
|
||||
# stack all tensors
|
||||
padded_tensors = torch.cat(padded_tensors)
|
||||
attention_masks = torch.cat(attention_masks)
|
||||
|
||||
return padded_tensors, attention_masks
|
||||
|
||||
|
||||
def respond_to_batch(model, queries, txt_len=20, top_k=0, top_p=1.0):
|
||||
def respond_to_batch(
|
||||
model: nn.Module, queries: List[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
|
||||
) -> torch.LongTensor:
|
||||
"""Sample text from language model."""
|
||||
input_ids = queries
|
||||
for i in range(txt_len):
|
||||
for _i in range(txt_len):
|
||||
# Get Logits
|
||||
outputs = model(input_ids)
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
@ -233,7 +242,7 @@ def respond_to_batch(model, queries, txt_len=20, top_k=0, top_p=1.0):
|
||||
return input_ids[:, -txt_len:]
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
def set_seed(seed: int) -> None:
|
||||
"""
|
||||
Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`.
|
||||
|
||||
@ -256,14 +265,14 @@ class LengthSampler:
|
||||
Samples a length
|
||||
"""
|
||||
|
||||
def __init__(self, min_value, max_value):
|
||||
def __init__(self, min_value: int, max_value: int):
|
||||
self.values = list(range(min_value, max_value))
|
||||
|
||||
def __call__(self):
|
||||
def __call__(self) -> int:
|
||||
return np.random.choice(self.values)
|
||||
|
||||
|
||||
class PPODecorators(object):
|
||||
class PPODecorators:
|
||||
optimize_device_cache = False
|
||||
|
||||
@classmethod
|
||||
@ -287,11 +296,11 @@ class PPODecorators(object):
|
||||
|
||||
def randn_tensor(
|
||||
shape: Union[Tuple, List],
|
||||
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
|
||||
device: Optional["torch.device"] = None,
|
||||
dtype: Optional["torch.dtype"] = None,
|
||||
layout: Optional["torch.layout"] = None,
|
||||
):
|
||||
generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
layout: Optional[torch.layout] = None,
|
||||
) -> torch.Tensor:
|
||||
"""A helper function to create random tensors on the desired `device` with the desired `dtype`. When
|
||||
passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
|
||||
is always created on the CPU.
|
||||
|
@ -1,3 +1,14 @@
|
||||
# flake8: noqa
|
||||
from typing import TYPE_CHECKING
|
||||
from ..import_utils import _LazyModule
|
||||
|
||||
from .base_environment import TextEnvironment, TextHistory
|
||||
_import_structure = {
|
||||
"base_environment": ["TextEnvironment", "TextHistory"],
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .base_environment import TextEnvironment, TextHistory
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
import re
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from accelerate.utils import extract_model_from_parallel
|
||||
@ -45,7 +46,7 @@ class StringStoppingCriteria(StoppingCriteria):
|
||||
done = []
|
||||
|
||||
for i, decoded_generation in enumerate(decoded_generations):
|
||||
sequence_complete = any([stop_string in decoded_generation for stop_string in self.stop_strings])
|
||||
sequence_complete = any(stop_string in decoded_generation for stop_string in self.stop_strings)
|
||||
done.append(sequence_complete)
|
||||
if not sequence_complete:
|
||||
self.generated_tokens[i] += 1
|
||||
@ -242,7 +243,7 @@ class TextEnvironment:
|
||||
if isinstance(tools, dict):
|
||||
self.tools = tools
|
||||
else:
|
||||
self.tools = dict([(tool.__class__.__name__, tool) for tool in tools])
|
||||
self.tools = {tool.__class__.__name__: tool for tool in tools}
|
||||
self.reward_fn = reward_fn
|
||||
self.max_length = max_length
|
||||
self.request_token = "<request>"
|
||||
@ -277,7 +278,7 @@ class TextEnvironment:
|
||||
|
||||
histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)]
|
||||
|
||||
while any([not history.completed for history in histories]) and turns < self.max_turns:
|
||||
while any(not history.completed for history in histories) and turns < self.max_turns:
|
||||
histories = self.generate(histories)
|
||||
histories = self.tasks_end_check(histories)
|
||||
# TODO: make this parallel rather than for-loop
|
||||
@ -416,7 +417,7 @@ class TextEnvironment:
|
||||
self,
|
||||
query_tensors,
|
||||
batch_size: int = 16,
|
||||
pad_to_multiple_of: int = None,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Generate responses for a list of query tensors.
|
||||
|
@ -13,4 +13,18 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .best_of_n_sampler import BestOfNSampler
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..import_utils import _LazyModule
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"best_of_n_sampler": ["BestOfNSampler"],
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .best_of_n_sampler import BestOfNSampler
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
@ -7,7 +7,7 @@ from ..core import set_seed
|
||||
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper
|
||||
|
||||
|
||||
class BestOfNSampler(object):
|
||||
class BestOfNSampler:
|
||||
def __init__(
|
||||
self,
|
||||
model: PreTrainedModelWrapper,
|
||||
|
88
trl/extras/dataset_formatting.py
Normal file
88
trl/extras/dataset_formatting.py
Normal file
@ -0,0 +1,88 @@
|
||||
import logging
|
||||
from typing import Callable, Literal, Optional, Union
|
||||
|
||||
from datasets import Dataset, Value
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from ..trainer.utils import ConstantLengthDataset
|
||||
|
||||
|
||||
FORMAT_MAPPING = {
|
||||
"chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
|
||||
"instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
|
||||
}
|
||||
|
||||
|
||||
def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]):
|
||||
r"""
|
||||
return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the tokenizer
|
||||
apply chat template to the dataset
|
||||
"""
|
||||
|
||||
def format_dataset(examples):
|
||||
if isinstance(examples[messages_field][0], list):
|
||||
output_texts = []
|
||||
for i in range(len(examples[messages_field])):
|
||||
output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False))
|
||||
return output_texts
|
||||
else:
|
||||
return tokenizer.apply_chat_template(examples[messages_field], tokenize=False)
|
||||
|
||||
return format_dataset
|
||||
|
||||
|
||||
def instructions_formatting_function(tokenizer: AutoTokenizer):
|
||||
r"""
|
||||
return a callable function that takes in an "instructions" dataset and returns a formatted dataset, based on the tokenizer
|
||||
apply chat template to the dataset
|
||||
"""
|
||||
|
||||
def format_dataset(examples):
|
||||
if isinstance(examples["prompt"], list):
|
||||
output_texts = []
|
||||
for i in range(len(examples["prompt"])):
|
||||
converted_sample = [
|
||||
{"role": "user", "content": examples["prompt"][i]},
|
||||
{"role": "assistant", "content": examples["completion"][i]},
|
||||
]
|
||||
output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False))
|
||||
return output_texts
|
||||
else:
|
||||
converted_sample = [
|
||||
{"role": "user", "content": examples["prompt"]},
|
||||
{"role": "assistant", "content": examples["completion"]},
|
||||
]
|
||||
return tokenizer.apply_chat_template(converted_sample, tokenize=False)
|
||||
|
||||
return format_dataset
|
||||
|
||||
|
||||
def get_formatting_func_from_dataset(
|
||||
dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer
|
||||
) -> Optional[Callable]:
|
||||
r"""
|
||||
Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
|
||||
- `ChatML` with [{"role": str, "content": str}]
|
||||
- `instruction` with [{"prompt": str, "completion": str}]
|
||||
|
||||
Args:
|
||||
dataset (Dataset): User dataset
|
||||
tokenizer (AutoTokenizer): Tokenizer used for formatting
|
||||
|
||||
Returns:
|
||||
Callable: Formatting function if the dataset format is supported else None
|
||||
"""
|
||||
if isinstance(dataset, Dataset):
|
||||
if "messages" in dataset.features:
|
||||
if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
|
||||
logging.info("Formatting dataset with chatml format")
|
||||
return conversations_formatting_function(tokenizer, "messages")
|
||||
if "conversations" in dataset.features:
|
||||
if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
|
||||
logging.info("Formatting dataset with chatml format")
|
||||
return conversations_formatting_function(tokenizer, "conversations")
|
||||
elif dataset.features == FORMAT_MAPPING["instruction"]:
|
||||
logging.info("Formatting dataset with instruction format")
|
||||
return instructions_formatting_function(tokenizer)
|
||||
|
||||
return None
|
@ -12,7 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from importlib.util import find_spec
|
||||
from itertools import chain
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
@ -22,7 +27,11 @@ else:
|
||||
|
||||
|
||||
def is_peft_available() -> bool:
|
||||
return importlib.util.find_spec("peft") is not None
|
||||
return find_spec("peft") is not None
|
||||
|
||||
|
||||
def is_unsloth_available() -> bool:
|
||||
return find_spec("unsloth") is not None
|
||||
|
||||
|
||||
def is_accelerate_greater_20_0() -> bool:
|
||||
@ -37,9 +46,16 @@ def is_accelerate_greater_20_0() -> bool:
|
||||
return accelerate_version >= "0.20.0"
|
||||
|
||||
|
||||
def is_transformers_greater_than(version: str) -> bool:
|
||||
_transformers_version = importlib.metadata.version("transformers")
|
||||
return _transformers_version > version
|
||||
def is_transformers_greater_than(current_version: str) -> bool:
|
||||
if _is_python_greater_3_8:
|
||||
from importlib.metadata import version
|
||||
|
||||
_transformers_version = version("transformers")
|
||||
else:
|
||||
import pkg_resources
|
||||
|
||||
_transformers_version = pkg_resources.get_distribution("transformers").version
|
||||
return _transformers_version > current_version
|
||||
|
||||
|
||||
def is_torch_greater_2_0() -> bool:
|
||||
@ -55,23 +71,26 @@ def is_torch_greater_2_0() -> bool:
|
||||
|
||||
|
||||
def is_diffusers_available() -> bool:
|
||||
return importlib.util.find_spec("diffusers") is not None
|
||||
return find_spec("diffusers") is not None
|
||||
|
||||
|
||||
def is_bitsandbytes_available() -> bool:
|
||||
return importlib.util.find_spec("bitsandbytes") is not None
|
||||
import torch
|
||||
|
||||
# bnb can be imported without GPU but is not usable.
|
||||
return find_spec("bitsandbytes") is not None and torch.cuda.is_available()
|
||||
|
||||
|
||||
def is_torchvision_available() -> bool:
|
||||
return importlib.util.find_spec("torchvision") is not None
|
||||
return find_spec("torchvision") is not None
|
||||
|
||||
|
||||
def is_rich_available() -> bool:
|
||||
return importlib.util.find_spec("rich") is not None
|
||||
return find_spec("rich") is not None
|
||||
|
||||
|
||||
def is_wandb_available() -> bool:
|
||||
return importlib.util.find_spec("wandb") is not None
|
||||
return find_spec("wandb") is not None
|
||||
|
||||
|
||||
def is_xpu_available() -> bool:
|
||||
@ -80,7 +99,7 @@ def is_xpu_available() -> bool:
|
||||
|
||||
return accelerate.utils.is_xpu_available()
|
||||
else:
|
||||
if importlib.util.find_spec("intel_extension_for_pytorch") is None:
|
||||
if find_spec("intel_extension_for_pytorch") is None:
|
||||
return False
|
||||
try:
|
||||
import torch
|
||||
@ -92,10 +111,74 @@ def is_xpu_available() -> bool:
|
||||
|
||||
def is_npu_available() -> bool:
|
||||
"""Checks if `torch_npu` is installed and potentially if a NPU is in the environment"""
|
||||
if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
|
||||
if find_spec("torch") is None or find_spec("torch_npu") is None:
|
||||
return False
|
||||
|
||||
import torch
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
return hasattr(torch, "npu") and torch.npu.is_available()
|
||||
|
||||
|
||||
class _LazyModule(ModuleType):
|
||||
"""
|
||||
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
||||
"""
|
||||
|
||||
# Very heavily inspired by optuna.integration._IntegrationModule
|
||||
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
|
||||
def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
|
||||
super().__init__(name)
|
||||
self._modules = set(import_structure.keys())
|
||||
self._class_to_module = {}
|
||||
for key, values in import_structure.items():
|
||||
for value in values:
|
||||
self._class_to_module[value] = key
|
||||
# Needed for autocompletion in an IDE
|
||||
self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
|
||||
self.__file__ = module_file
|
||||
self.__spec__ = module_spec
|
||||
self.__path__ = [os.path.dirname(module_file)]
|
||||
self._objects = {} if extra_objects is None else extra_objects
|
||||
self._name = name
|
||||
self._import_structure = import_structure
|
||||
|
||||
# Needed for autocompletion in an IDE
|
||||
def __dir__(self):
|
||||
result = super().__dir__()
|
||||
# The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
|
||||
# they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
|
||||
for attr in self.__all__:
|
||||
if attr not in result:
|
||||
result.append(attr)
|
||||
return result
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name in self._objects:
|
||||
return self._objects[name]
|
||||
if name in self._modules:
|
||||
value = self._get_module(name)
|
||||
elif name in self._class_to_module.keys():
|
||||
module = self._get_module(self._class_to_module[name])
|
||||
value = getattr(module, name)
|
||||
else:
|
||||
raise AttributeError(f"module {self.__name__} has no attribute {name}")
|
||||
|
||||
setattr(self, name, value)
|
||||
return value
|
||||
|
||||
def _get_module(self, module_name: str):
|
||||
try:
|
||||
return importlib.import_module("." + module_name, self.__name__)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
|
||||
f" traceback):\n{e}"
|
||||
) from e
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__, (self._name, self.__file__, self._import_structure))
|
||||
|
||||
|
||||
class OptionalDependencyNotAvailable(BaseException):
|
||||
"""Internally used error class for signalling an optional dependency was not found."""
|
||||
|
@ -13,22 +13,52 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .modeling_base import PreTrainedModelWrapper, create_reference_model
|
||||
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
|
||||
# flake8: noqa
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from ..import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
|
||||
|
||||
|
||||
SUPPORTED_ARCHITECTURES = (
|
||||
AutoModelForCausalLMWithValueHead,
|
||||
AutoModelForSeq2SeqLMWithValueHead,
|
||||
)
|
||||
_import_structure = {
|
||||
"modeling_base": ["PreTrainedModelWrapper", "create_reference_model"],
|
||||
"modeling_value_head": [
|
||||
"AutoModelForCausalLMWithValueHead",
|
||||
"AutoModelForSeq2SeqLMWithValueHead",
|
||||
],
|
||||
"utils": ["setup_chat_format", "SUPPORTED_ARCHITECTURES"],
|
||||
}
|
||||
|
||||
from ..import_utils import is_diffusers_available
|
||||
try:
|
||||
if not is_diffusers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_sd_base"] = [
|
||||
"DDPOPipelineOutput",
|
||||
"DDPOSchedulerOutput",
|
||||
"DDPOStableDiffusionPipeline",
|
||||
"DefaultDDPOStableDiffusionPipeline",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .modeling_base import PreTrainedModelWrapper, create_reference_model
|
||||
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
|
||||
from .utils import setup_chat_format, SUPPORTED_ARCHITECTURES
|
||||
|
||||
if is_diffusers_available():
|
||||
from .modeling_sd_base import (
|
||||
DDPOPipelineOutput,
|
||||
DDPOSchedulerOutput,
|
||||
DDPOStableDiffusionPipeline,
|
||||
DefaultDDPOStableDiffusionPipeline,
|
||||
)
|
||||
try:
|
||||
if not is_diffusers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_sd_base import (
|
||||
DDPOPipelineOutput,
|
||||
DDPOSchedulerOutput,
|
||||
DDPOStableDiffusionPipeline,
|
||||
DefaultDDPOStableDiffusionPipeline,
|
||||
)
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
@ -15,6 +15,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -70,6 +71,7 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
supported_args: (`list`)
|
||||
The list of arguments that are supported by the wrapper class.
|
||||
"""
|
||||
|
||||
transformers_parent_class = None
|
||||
supported_args = None
|
||||
supported_modules = ("v_head",)
|
||||
@ -377,12 +379,12 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
)
|
||||
# load json
|
||||
if is_resuming_training:
|
||||
with open(index_file_name, "r") as f:
|
||||
with open(index_file_name) as f:
|
||||
index = json.load(f)
|
||||
# check filename with `v_head` or any known extra module:
|
||||
files_to_download = set()
|
||||
for k, v in index["weight_map"].items():
|
||||
if any([module in k for module in cls.supported_modules]):
|
||||
if any(module in k for module in cls.supported_modules):
|
||||
files_to_download.add(v)
|
||||
is_sharded = True
|
||||
|
||||
@ -459,7 +461,7 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
"adapter_model.bin",
|
||||
token=token,
|
||||
)
|
||||
except: # noqa
|
||||
except Exception:
|
||||
filename = os.path.join(adapter_model_id, "adapter_model.safetensors")
|
||||
safe_loading = True
|
||||
if not os.path.exists(filename):
|
||||
@ -469,10 +471,11 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
"adapter_model.safetensors",
|
||||
token=token,
|
||||
)
|
||||
except: # noqa
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
"Could not find adapter model in the Hub, make sure you have the correct adapter model id."
|
||||
)
|
||||
"Could not find adapter model in the Hub, "
|
||||
"make sure you have the correct adapter model id."
|
||||
) from exc
|
||||
else:
|
||||
local_filename = filename
|
||||
else:
|
||||
@ -484,7 +487,7 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
adapter_state_dict = loading_func(local_filename, **load_kwargs)
|
||||
|
||||
for score_name_candidate in cls.supported_rm_modules:
|
||||
if any([score_name_candidate in name for name in adapter_state_dict.keys()]):
|
||||
if any(score_name_candidate in name for name in adapter_state_dict.keys()):
|
||||
score_name = score_name_candidate
|
||||
# we have found the correct head name and can break
|
||||
break
|
||||
@ -497,7 +500,7 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
score_dict[key_name] = param.to(cls._get_current_device())
|
||||
|
||||
num_labels, hidden_dim = score_dict["weight"].shape
|
||||
has_bias = any(["bias" in name for name in adapter_state_dict.keys()])
|
||||
has_bias = any("bias" in name for name in adapter_state_dict.keys())
|
||||
|
||||
score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(
|
||||
device=cls._get_current_device(),
|
||||
@ -600,7 +603,7 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
|
||||
|
||||
def create_reference_model(
|
||||
model: PreTrainedModelWrapper, num_shared_layers: int = None, pattern: str = None
|
||||
model: PreTrainedModelWrapper, num_shared_layers: Optional[int] = None, pattern: Optional[str] = None
|
||||
) -> PreTrainedModelWrapper:
|
||||
"""
|
||||
Creates a static reference copy of a model. Note that model will be in `.eval()` mode.
|
||||
@ -635,7 +638,7 @@ def create_reference_model(
|
||||
else:
|
||||
for pattern_candidate in LAYER_PATTERNS:
|
||||
pattern_candidate = pattern_candidate.format(layer=num_shared_layers)
|
||||
if any([pattern_candidate in name for name in parameter_names]):
|
||||
if any(pattern_candidate in name for name in parameter_names):
|
||||
pattern = pattern_candidate
|
||||
break
|
||||
|
||||
@ -647,7 +650,7 @@ def create_reference_model(
|
||||
unshared_param_list = []
|
||||
|
||||
shared_parameter = True
|
||||
for name, param in model.named_parameters():
|
||||
for name, _param in model.named_parameters():
|
||||
if pattern in name:
|
||||
shared_parameter = False
|
||||
if shared_parameter:
|
||||
@ -660,8 +663,7 @@ def create_reference_model(
|
||||
param = model.get_parameter(param_name)
|
||||
param.requires_grad = False
|
||||
|
||||
ref_param = ref_model.get_parameter(param_name) # noqa
|
||||
ref_param = param # noqa
|
||||
_ref_param = ref_model.get_parameter(param_name)
|
||||
|
||||
# for all other parameters just make sure they don't use gradients
|
||||
for param_name in unshared_param_list:
|
||||
|
@ -21,15 +21,20 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
|
||||
|
||||
from ..core import randn_tensor
|
||||
from ..import_utils import is_peft_available
|
||||
from .sd_utils import convert_state_dict_to_diffusers
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class DDPOPipelineOutput(object):
|
||||
class DDPOPipelineOutput:
|
||||
"""
|
||||
Output class for the diffusers pipeline to be finetuned with the DDPO trainer
|
||||
|
||||
@ -49,7 +54,7 @@ class DDPOPipelineOutput(object):
|
||||
|
||||
|
||||
@dataclass
|
||||
class DDPOSchedulerOutput(object):
|
||||
class DDPOSchedulerOutput:
|
||||
"""
|
||||
Output class for the diffusers scheduler to be finetuned with the DDPO trainer
|
||||
|
||||
@ -64,7 +69,7 @@ class DDPOSchedulerOutput(object):
|
||||
log_probs: torch.Tensor
|
||||
|
||||
|
||||
class DDPOStableDiffusionPipeline(object):
|
||||
class DDPOStableDiffusionPipeline:
|
||||
"""
|
||||
Main class for the diffusers pipeline to be finetuned with the DDPO trainer
|
||||
"""
|
||||
@ -534,7 +539,11 @@ class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline):
|
||||
self.pretrained_revision = pretrained_model_revision
|
||||
|
||||
try:
|
||||
self.sd_pipeline.unet.load_attn_procs(pretrained_model_name, revision=pretrained_model_revision)
|
||||
self.sd_pipeline.load_lora_weights(
|
||||
pretrained_model_name,
|
||||
weight_name="pytorch_lora_weights.safetensors",
|
||||
revision=pretrained_model_revision,
|
||||
)
|
||||
self.use_lora = True
|
||||
except OSError:
|
||||
if use_lora:
|
||||
@ -583,7 +592,8 @@ class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline):
|
||||
|
||||
def save_pretrained(self, output_dir):
|
||||
if self.use_lora:
|
||||
self.sd_pipeline.unet.save_attn_procs(output_dir)
|
||||
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(self.sd_pipeline.unet))
|
||||
self.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=state_dict)
|
||||
self.sd_pipeline.save_pretrained(output_dir)
|
||||
|
||||
def set_progress_bar_config(self, *args, **kwargs):
|
||||
@ -591,34 +601,29 @@ class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline):
|
||||
|
||||
def get_trainable_layers(self):
|
||||
if self.use_lora:
|
||||
# Set correct lora layers
|
||||
lora_attn_procs = {}
|
||||
for name in self.sd_pipeline.unet.attn_processors.keys():
|
||||
cross_attention_dim = (
|
||||
None if name.endswith("attn1.processor") else self.sd_pipeline.unet.config.cross_attention_dim
|
||||
)
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = self.sd_pipeline.unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(self.sd_pipeline.unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = self.sd_pipeline.unet.config.block_out_channels[block_id]
|
||||
lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
self.sd_pipeline.unet.add_adapter(lora_config)
|
||||
|
||||
lora_attn_procs[name] = LoRAAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
)
|
||||
self.sd_pipeline.unet.set_attn_processor(lora_attn_procs)
|
||||
return AttnProcsLayers(self.sd_pipeline.unet.attn_processors)
|
||||
# To avoid accelerate unscaling problems in FP16.
|
||||
for param in self.sd_pipeline.unet.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
return self.sd_pipeline.unet
|
||||
else:
|
||||
return self.sd_pipeline.unet
|
||||
|
||||
def save_checkpoint(self, models, weights, output_dir):
|
||||
if len(models) != 1:
|
||||
raise ValueError("Given how the trainable params were set, this should be of length 1")
|
||||
if self.use_lora and isinstance(models[0], AttnProcsLayers):
|
||||
self.sd_pipeline.unet.save_attn_procs(output_dir)
|
||||
if self.use_lora and hasattr(models[0], "peft_config") and getattr(models[0], "peft_config", None) is not None:
|
||||
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(models[0]))
|
||||
self.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=state_dict)
|
||||
elif not self.use_lora and isinstance(models[0], UNet2DConditionModel):
|
||||
models[0].save_pretrained(os.path.join(output_dir, "unet"))
|
||||
else:
|
||||
@ -627,15 +632,12 @@ class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline):
|
||||
def load_checkpoint(self, models, input_dir):
|
||||
if len(models) != 1:
|
||||
raise ValueError("Given how the trainable params were set, this should be of length 1")
|
||||
if self.use_lora and isinstance(models[0], AttnProcsLayers):
|
||||
tmp_unet = UNet2DConditionModel.from_pretrained(
|
||||
self.pretrained_model,
|
||||
revision=self.pretrained_revision,
|
||||
subfolder="unet",
|
||||
if self.use_lora:
|
||||
lora_state_dict, network_alphas = self.sd_pipeline.lora_state_dict(
|
||||
input_dir, weight_name="pytorch_lora_weights.safetensors"
|
||||
)
|
||||
tmp_unet.load_attn_procs(input_dir)
|
||||
models[0].load_state_dict(AttnProcsLayers(tmp_unet.attn_processors).state_dict())
|
||||
del tmp_unet
|
||||
self.sd_pipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=models[0])
|
||||
|
||||
elif not self.use_lora and isinstance(models[0], UNet2DConditionModel):
|
||||
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
|
||||
models[0].register_to_config(**load_model.config)
|
||||
|
@ -85,6 +85,7 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
|
||||
- **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution.
|
||||
|
||||
"""
|
||||
|
||||
transformers_parent_class = AutoModelForCausalLM
|
||||
lm_head_namings = ["lm_head", "embed_out"]
|
||||
supported_args = (
|
||||
@ -218,7 +219,7 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
|
||||
return pretrained_model_state_dict
|
||||
|
||||
def push_to_hub(self, *args, **kwargs):
|
||||
setattr(self.pretrained_model, "v_head", self.v_head)
|
||||
self.pretrained_model.v_head = self.v_head
|
||||
|
||||
return self.pretrained_model.push_to_hub(*args, **kwargs)
|
||||
|
||||
@ -276,6 +277,7 @@ class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
|
||||
kwargs:
|
||||
Additional keyword arguments passed along to the `ValueHead` class.
|
||||
"""
|
||||
|
||||
transformers_parent_class = AutoModelForSeq2SeqLM
|
||||
lm_head_namings = ["lm_head", "embed_out", "output_projection"]
|
||||
supported_args = (
|
||||
@ -298,7 +300,7 @@ class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
|
||||
|
||||
def _has_lm_head(self):
|
||||
# check module names of all modules inside `pretrained_model` to find the language model head
|
||||
for name, module in self.pretrained_model.named_modules():
|
||||
for name, _module in self.pretrained_model.named_modules():
|
||||
if any(attribute in name for attribute in self.lm_head_namings):
|
||||
return True
|
||||
return False
|
||||
@ -374,7 +376,7 @@ class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
|
||||
return pretrained_model_state_dict
|
||||
|
||||
def push_to_hub(self, *args, **kwargs):
|
||||
setattr(self.pretrained_model, "v_head", self.v_head)
|
||||
self.pretrained_model.v_head = self.v_head
|
||||
|
||||
return self.pretrained_model.push_to_hub(*args, **kwargs)
|
||||
|
||||
|
150
trl/models/sd_utils.py
Normal file
150
trl/models/sd_utils.py
Normal file
@ -0,0 +1,150 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
State dict utilities: utility methods for converting state dicts easily
|
||||
File copied from diffusers to avoid import issues and make TRL compatible
|
||||
with most of diffusers versions.
|
||||
"""
|
||||
import enum
|
||||
|
||||
|
||||
class StateDictType(enum.Enum):
|
||||
"""
|
||||
The mode to use when converting state dicts.
|
||||
"""
|
||||
|
||||
DIFFUSERS_OLD = "diffusers_old"
|
||||
PEFT = "peft"
|
||||
|
||||
|
||||
PEFT_TO_DIFFUSERS = {
|
||||
".q_proj.lora_B": ".q_proj.lora_linear_layer.up",
|
||||
".q_proj.lora_A": ".q_proj.lora_linear_layer.down",
|
||||
".k_proj.lora_B": ".k_proj.lora_linear_layer.up",
|
||||
".k_proj.lora_A": ".k_proj.lora_linear_layer.down",
|
||||
".v_proj.lora_B": ".v_proj.lora_linear_layer.up",
|
||||
".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
|
||||
".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
|
||||
".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
|
||||
"to_k.lora_A": "to_k.lora.down",
|
||||
"to_k.lora_B": "to_k.lora.up",
|
||||
"to_q.lora_A": "to_q.lora.down",
|
||||
"to_q.lora_B": "to_q.lora.up",
|
||||
"to_v.lora_A": "to_v.lora.down",
|
||||
"to_v.lora_B": "to_v.lora.up",
|
||||
"to_out.0.lora_A": "to_out.0.lora.down",
|
||||
"to_out.0.lora_B": "to_out.0.lora.up",
|
||||
}
|
||||
|
||||
DIFFUSERS_OLD_TO_DIFFUSERS = {
|
||||
".to_q_lora.up": ".q_proj.lora_linear_layer.up",
|
||||
".to_q_lora.down": ".q_proj.lora_linear_layer.down",
|
||||
".to_k_lora.up": ".k_proj.lora_linear_layer.up",
|
||||
".to_k_lora.down": ".k_proj.lora_linear_layer.down",
|
||||
".to_v_lora.up": ".v_proj.lora_linear_layer.up",
|
||||
".to_v_lora.down": ".v_proj.lora_linear_layer.down",
|
||||
".to_out_lora.up": ".out_proj.lora_linear_layer.up",
|
||||
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
|
||||
}
|
||||
|
||||
|
||||
DIFFUSERS_STATE_DICT_MAPPINGS = {
|
||||
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS,
|
||||
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
|
||||
}
|
||||
|
||||
|
||||
KEYS_TO_ALWAYS_REPLACE = {
|
||||
".processor.": ".",
|
||||
}
|
||||
|
||||
|
||||
def convert_state_dict(state_dict, mapping):
|
||||
r"""
|
||||
Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values.
|
||||
|
||||
Args:
|
||||
state_dict (`dict[str, torch.Tensor]`):
|
||||
The state dict to convert.
|
||||
mapping (`dict[str, str]`):
|
||||
The mapping to use for conversion, the mapping should be a dictionary with the following structure:
|
||||
- key: the pattern to replace
|
||||
- value: the pattern to replace with
|
||||
|
||||
Returns:
|
||||
converted_state_dict (`dict`)
|
||||
The converted state dict.
|
||||
"""
|
||||
converted_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
# First, filter out the keys that we always want to replace
|
||||
for pattern in KEYS_TO_ALWAYS_REPLACE.keys():
|
||||
if pattern in k:
|
||||
new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern]
|
||||
k = k.replace(pattern, new_pattern)
|
||||
|
||||
for pattern in mapping.keys():
|
||||
if pattern in k:
|
||||
new_pattern = mapping[pattern]
|
||||
k = k.replace(pattern, new_pattern)
|
||||
break
|
||||
converted_state_dict[k] = v
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
|
||||
r"""
|
||||
Converts a state dict to new diffusers format. The state dict can be from previous diffusers format
|
||||
(`OLD_DIFFUSERS`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will
|
||||
return the state dict as is.
|
||||
|
||||
The method only supports the conversion from diffusers old, PEFT to diffusers new for now.
|
||||
|
||||
Args:
|
||||
state_dict (`dict[str, torch.Tensor]`):
|
||||
The state dict to convert.
|
||||
original_type (`StateDictType`, *optional*):
|
||||
The original type of the state dict, if not provided, the method will try to infer it automatically.
|
||||
kwargs (`dict`, *args*):
|
||||
Additional arguments to pass to the method.
|
||||
|
||||
- **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
|
||||
with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
|
||||
`get_peft_model_state_dict` method:
|
||||
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
|
||||
but we add it here in case we don't want to rely on that method.
|
||||
"""
|
||||
peft_adapter_name = kwargs.pop("adapter_name", None)
|
||||
if peft_adapter_name is not None:
|
||||
peft_adapter_name = "." + peft_adapter_name
|
||||
else:
|
||||
peft_adapter_name = ""
|
||||
|
||||
if original_type is None:
|
||||
# Old diffusers to PEFT
|
||||
if any("to_out_lora" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.DIFFUSERS_OLD
|
||||
elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.PEFT
|
||||
elif any("lora_linear_layer" in k for k in state_dict.keys()):
|
||||
# nothing to do
|
||||
return state_dict
|
||||
else:
|
||||
raise ValueError("Could not automatically infer state dict type")
|
||||
|
||||
if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys():
|
||||
raise ValueError(f"Original type {original_type} is not supported")
|
||||
|
||||
mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
|
||||
return convert_state_dict(state_dict, mapping)
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user