[dev] feat: migrate from yapf & pylint to ruff based on pre-commit (#1010)

> [!WARNING]
> We are [migrating to `ruff` as the linter and formatter and `pre-commit` as the managing tool](https://github.com/volcengine/verl/pull/1010).
>
> If your branch is based on a previous commit that still uses `yapf` and `pylint`, simply merging may surface an overwhelming number of linting errors, even though **you are only expected to resolve the ones in the files related to your PR**.
>
> To resolve this, please try the following workaround so that the PR only includes the files you **really changed** (a consolidated, runnable version of these steps is sketched right after this note):
>
> 1. In your branch, fix linting and formatting with `ruff`: `ruff check --fix && ruff format`
> 2. Squash into a single commit on a new branch: `git reset --soft $(git merge-base main HEAD) && git add -A && git commit -m "feat: ..."`
> 3. Merge the latest main: `git merge origin/main`
> 4. Force-push to your branch: `git push --force`
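
Putting the steps above together, a minimal consolidated sketch of the workaround (the commit message is a placeholder; run this from your feature branch):

```bash
# 1. Fix linting and formatting with ruff
ruff check --fix
ruff format

# 2. Squash everything on this branch into a single commit
git reset --soft "$(git merge-base main HEAD)"
git add -A
git commit -m "feat: ..."

# 3. Merge the latest main
git fetch origin
git merge origin/main

# 4. Force-push to your branch
git push --force
```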

We have also added the reminder above to the documentation so that contributors know how to avoid these overwhelming linting errors.

### Motivation

Following the discussion in #896, this PR migrates from `yapf` & `pylint` to `ruff`, managed by `pre-commit`, which enables unified version control of the tooling and an automatic hook on every commit.

### Summary

The `pre-commit` hook and CI

- check staged / committed files in commits / PRs
- check all files on a weekly schedule (this run is expected to fail until all files conform to the `ruff` standard)
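
For reference, a minimal sketch of running the same checks locally (assuming `pre-commit` is installed and the repository's `.pre-commit-config.yaml` is in place):

```bash
pip install pre-commit
pre-commit install            # install the git hook; staged files are checked on every commit
pre-commit run                # check only the currently staged files
pre-commit run --all-files    # check every file, as the scheduled full CI job does
```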

### Explanation for the Failing CI Workflow `pre-commit`

For now, we only apply `ruff format` and `ruff check --fix` **without resolving all of the reported errors**, since there are too many to fix at once; this is why the CI workflow `pre-commit` currently fails.

We leave resolving the remaining errors to future commits. Specifically, the `pre-commit` hook and CI will require every commit to fix its related files with `ruff`, so all files will be brought up to standard incrementally (a sketch of fixing only the files touched by a PR is given below).
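
As an illustration (not part of this PR), one way to run `ruff` only on the Python files your branch actually changed; the base ref `origin/main` and the GNU `xargs -r` flag are assumptions here:

```bash
# Hedged sketch: lint and format only the Python files changed relative to main.
git diff --name-only --diff-filter=ACMR origin/main... -- '*.py' \
  | xargs -r ruff check --fix
git diff --name-only --diff-filter=ACMR origin/main... -- '*.py' \
  | xargs -r ruff format
```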

### Reviewing Suggestion

The commit 3d93f51ba8 is huge because it applies `ruff` to all the files. To review the main changes, please check the commits before and after it (one possible local workflow is sketched below).
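
A minimal sketch of one way to do this locally (assuming the PR branch is checked out and the commit hash resolves in your clone):

```bash
git log --oneline               # locate 3d93f51ba8 in the branch history
git show --stat 3d93f51ba8      # confirm it is the bulk ruff reformatting commit
git show 3d93f51ba8^            # review the commit just before it
git diff 3d93f51ba8 HEAD        # review everything after it as a single diff
```
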
Shawn/Yuxuan Tong, 2025-04-18 22:49:31 +08:00, committed by GitHub
parent c98fb3197b · commit b00f77d855
268 changed files with 10660 additions and 9233 deletions

.github/workflows/pre-commit-full.yml vendored Normal file

@ -0,0 +1,30 @@
name: pre-commit-full
# Run weekly on Sunday at 00:00 UTC
on:
schedule:
- cron: '0 0 * * 0'
# Allow manual triggering
workflow_dispatch:
# Declare permissions just read content.
permissions:
contents: read
jobs:
pre-commit-full:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- uses: pre-commit/action@v3.0.1
env:
RUFF_OUTPUT_FORMAT: github
with:
extra_args: --all-files

.github/workflows/pre-commit.yml vendored Normal file

@ -0,0 +1,30 @@
# c.f. https://github.com/pre-commit/action?tab=readme-ov-file#using-this-action
name: pre-commit
# No need to avoid / cancel lightweight pre-commit jobs
on:
pull_request:
push:
branches:
- main
- v0.2.x
# Declare permissions just read content.
permissions:
contents: read
jobs:
pre-commit:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- uses: pre-commit/action@v3.0.1
env:
RUFF_OUTPUT_FORMAT: github

View File

@ -1,40 +0,0 @@
name: Pylint Check
on:
push:
paths:
- '**.py'
- 'requirements.txt'
- 'pyproject.toml'
pull_request:
paths:
- '**.py'
- 'requirements.txt'
- 'pyproject.toml'
jobs:
lint:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install pylint (version from requirements.txt)
run: |
PYLINT_VERSION=$(grep '^pylint' requirements.txt)
if [ -z "$PYLINT_VERSION" ]; then
echo "No pylint version found in requirements.txt"
exit 1
fi
# only install pylint to avoid dependency problems on CPU
pip install "$PYLINT_VERSION"
- name: Run pylint
run: |
pylint --recursive=y --rcfile=pyproject.toml ./

View File

@ -1,56 +0,0 @@
name: yapf
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
- v0.2.x
paths:
- "**/*.py"
- .github/workflows/yapf_format.yml
pull_request:
branches:
- main
- v0.2.x
paths:
- "**/*.py"
- .github/workflows/yapf_format.yml
# Cancel jobs on the same ref if a new one is triggered
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
# Declare permissions just read content.
permissions:
contents: read
jobs:
yapf:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
# - name: checkout
# run: |
# commits=${{ github.event.pull_request.commits }}
# if [[ -n "$commits" ]]; then
# # Prepare enough depth for diffs with main
# git fetch --depth="$(( commits + 1 ))"
# fi
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install --upgrade yapf
pip install toml==0.10.2
- name: Running yapf
run: |
yapf -r -vv -d --style=./.style.yapf verl tests examples recipe

.pre-commit-config.yaml Normal file

@ -0,0 +1,8 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.11.4"
hooks:
- id: ruff
pass_filenames: true
entry: bash -c 'ruff check --fix --show-fixes --output-format=${RUFF_OUTPUT_FORMAT:-full} "$@"'
- id: ruff-format

View File

@ -1,5 +0,0 @@
[style]
based_on_style = google
column_limit = 120
indent_width = 4
split_arguments_when_comma_terminated: true

View File

@ -1,9 +1,8 @@
{
"pylint.enabled": true,
"[python]": {
"editor.defaultFormatter": "eeyore.yapf",
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.codeActionsOnSave": {
"source.organizeImports": "never",
"source.organizeImports": "always",
}
}
}

View File

@ -170,16 +170,34 @@ verl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. The
- [all-hands/openhands-lm-32b-v0.1](https://www.all-hands.dev/blog/introducing-openhands-lm-32b----a-strong-open-coding-agent-model): A strong, open coding agent model, trained with [multi-turn fine-tuning](https://github.com/volcengine/verl/pull/195)
## Contribution Guide
Contributions from the community are welcome! Please check out our [project roadmap](https://github.com/volcengine/verl/issues/710) and [good first issues](https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22) to see where you can contribute.
### Code formatting
We use yapf (Google style) to enforce strict code formatting when reviewing PRs. To reformat your code locally, make sure you have installed the **latest** version of `yapf`
### Code Linting and Formatting
> [!WARNING]
> We are [migrating to `ruff` as the linter and formatter and `pre-commit` as the managing tool](https://github.com/volcengine/verl/pull/1010).
>
> If your branch is based on a previous commit using `yapf` and `pylint`, simply merging might trigger overwhelming linting errors, while **you are only expected to resolve ones in the files related to your PR**.
>
> To resolve this issue, please try the following workaround to only include the files you **really changed** in the PR:
>
> 1. In your branch, fix linting and format with `ruff`: `ruff check --fix && ruff format`
> 2. Squash into a new single commit: `git reset --soft $(git merge-base main HEAD) && git add -A && git commit -m "feat: ..."`
> 3. Merge with the latest main: `git merge origin/main`
> 4. Force push to your branch: `git push --force`
We use pre-commit to help improve code quality. To initialize pre-commit, run:
```bash
pip3 install yapf --upgrade
pip install pre-commit
pre-commit install
```
Then, make sure you are at top level of verl repo and run
You can also manually run pre-commit by:
```bash
bash scripts/format.sh
pre-commit run
```
### Adding CI tests

View File

@ -3,12 +3,12 @@ FROM nvcr.io/nvidia/pytorch:24.05-py3
# uninstall nv-pytorch fork
RUN pip3 uninstall pytorch-quantization \
pytorch-triton \
torch \
torch-tensorrt \
torchvision \
xgboost transformer_engine flash_attn \
apex megatron-core -y
pytorch-triton \
torch \
torch-tensorrt \
torchvision \
xgboost transformer_engine flash_attn \
apex megatron-core -y
RUN pip3 install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124
@ -38,7 +38,7 @@ RUN pip3 install --no-cache-dir \
'wandb'
# full dependencies
RUN pip3 install pytest yapf py-spy pyext liger-kernel
RUN pip3 install pytest pre-commit py-spy pyext liger-kernel
# =============== Megatron dependencies (optional) =================
# install Transformer Engine, which requires FA 2.5.8. Do it in a separate step for docker cache

View File

@ -51,7 +51,7 @@ RUN pip install --no-cache-dir "vllm==0.8.3" "torch==2.6.0" "torchvision==0.21.0
"transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \
"numpy<2.0.0" "pyarrow>=15.0.0" pandas \
ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \
pytest yapf py-spy pyext pre-commit ruff
pytest py-spy pyext pre-commit ruff
# Install flash-attn-2.7.4.post1 (cxx11abi=False)
RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \

View File

@ -27,7 +27,7 @@ RUN apt-get update && \
RUN pip install --no-cache-dir vllm==0.8.2 torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata==0.11.0 \
transformers>=4.49.0 accelerate datasets peft hf-transfer \
ray[default] codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \
pytest yapf py-spy pyext pre-commit ruff
pytest pre-commit py-spy pyext pre-commit ruff
# Install flash_attn-2.7.4.post1
RUN pip uninstall -y transformer-engine flash-attn && \

View File

@ -43,7 +43,7 @@ RUN pip install "sglang[all]==0.4.4.post4" --no-cache-dir --find-links https://f
RUN pip install --no-cache-dir torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 tensordict torchdata \
transformers>=4.49.0 accelerate datasets peft hf_transfer \
ray codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb liger-kernel \
pytest yapf py-spy pyext
pytest pre-commit py-spy pyext
# Install flash_attn-2.7.4.post1
RUN pip uninstall -y transformer-engine flash-attn && \

View File

@ -31,43 +31,43 @@
# -- Project information -----------------------------------------------------
project = u'verl'
# pylint: disable=W0622
copyright = u'2024 ByteDance Seed Foundation MLSys Team'
author = u'Guangming Sheng, Chi Zhang, Yanghua Peng, Haibin Lin'
project = "verl"
copyright = "2024 ByteDance Seed Foundation MLSys Team"
author = "Guangming Sheng, Chi Zhang, Yanghua Peng, Haibin Lin"
# -- General configuration ---------------------------------------------------
# The master toctree document.
master_doc = 'index'
master_doc = "index"
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['recommonmark',
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.autosectionlabel',
extensions = [
"recommonmark",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.autosectionlabel",
]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
source_suffix = ['.rst', 'rest', '.md']
source_suffix = [".rst", "rest", ".md"]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = u'en'
language = "en"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# -- Options for HTML output -------------------------------------------------
@ -75,9 +75,9 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
html_theme = "sphinx_rtd_theme"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = ["_static"]

View File

@ -113,14 +113,35 @@ verl is free software; you can redistribute it and/or modify it under the terms
of the Apache License 2.0. We welcome contributions.
Join us on `GitHub <https://github.com/volcengine/verl>`_, `Slack <https://join.slack.com/t/verlgroup/shared_invite/zt-2w5p9o4c3-yy0x2Q56s_VlGLsJ93A6vA>`_ and `Wechat <https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/WeChat.JPG>`_ for discussions.
Code formatting
^^^^^^^^^^^^^^^^^^^^^^^^
We use yapf (Google style) to enforce strict code formatting when reviewing MRs. Run yapf at the top level of verl repo:
Contributions from the community are welcome! Please check out our `project roadmap <https://github.com/volcengine/verl/issues/710>`_ and `good first issues <https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22>`_ to see where you can contribute.
Code Linting and Formatting
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. warning::
We are `migrating to ``ruff`` as the linter and formatter and ``pre-commit`` as the managing tool <https://github.com/volcengine/verl/pull/1010>`_.
If your branch is based on a previous commit using ``yapf`` and ``pylint``, simply merging might trigger overwhelming linting errors, while **you are only expected to resolve ones in the files related to your PR**.
To resolve this issue, please try the following workaround to only include the files you **really changed** in the PR:
1. In your branch, fix linting and format with ``ruff``: ``ruff check --fix && ruff format``
2. Squash into a new single commit: ``git reset --soft $(git merge-base main HEAD) && git add -A && git commit -m "feat: ..."``
3. Merge with the latest main: ``git merge origin/main``
4. Force push to your branch: ``git push --force``
We use pre-commit to help improve code quality. To initialize pre-commit, run:
.. code-block:: bash
pip3 install yapf
yapf -ir -vv --style ./.style.yapf verl examples tests
pip install pre-commit
pre-commit install
You can also manually run pre-commit by:
.. code-block:: bash
pre-commit run
Adding CI tests
^^^^^^^^^^^^^^^^^^^^^^^^
@ -129,4 +150,6 @@ If possible, please add CI test(s) for your new feature:
1. Find the most relevant workflow yml file, which usually corresponds to a ``hydra`` default config (e.g. ``ppo_trainer``, ``ppo_megatron_trainer``, ``sft_trainer``, etc).
2. Add related path patterns to the ``paths`` section if not already included.
3. Minimize the workload of the test script(s) (see existing scripts for examples).
3. Minimize the workload of the test script(s) (see existing scripts for examples).
We are HIRING! Send us an `email <mailto:haibin.lin@bytedance.com>`_ if you are interested in internship/FTE opportunities in MLSys/LLM reasoning/multimodal alignment.

View File

@ -16,131 +16,124 @@
- All the training data is used to train SFT and RL.
- Both chosen and rejected is used to train SFT
"""
import argparse
import os
import pandas as pd
from datasets import load_dataset
from tqdm.auto import tqdm
from verl.utils.fs import copy, makedirs
def generate_sft_dataset(target_hdfs_path_dir, local_dir='~/data/full_hh_rlh/sft'):
dataset = load_dataset('Dahoas/full-hh-rlhf')
output = {'prompt': [], 'response': []}
for data in tqdm(dataset['train']):
def generate_sft_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_rlh/sft"):
dataset = load_dataset("Dahoas/full-hh-rlhf")
output = {"prompt": [], "response": []}
for data in tqdm(dataset["train"]):
# add chosen
output['prompt'].append(data['prompt'])
output['response'].append(data['chosen'])
output["prompt"].append(data["prompt"])
output["response"].append(data["chosen"])
# add rejection
output['prompt'].append(data['prompt'])
output['response'].append(data['rejected'])
output["prompt"].append(data["prompt"])
output["response"].append(data["rejected"])
df = pd.DataFrame(output)
local_dir = os.path.expanduser(local_dir)
os.makedirs(local_dir, exist_ok=True)
local_path = os.path.join(local_dir, 'train.parquet')
local_path = os.path.join(local_dir, "train.parquet")
df.to_parquet(path=local_path)
if target_hdfs_path_dir is not None:
hdfs_dir = target_hdfs_path_dir + '/' + 'train.parquet'
hdfs_dir = target_hdfs_path_dir + "/" + "train.parquet"
makedirs(hdfs_dir)
copy(local_path, hdfs_dir)
def generate_rm_dataset(target_hdfs_path_dir, local_dir='~/data/full_hh_rlh/rm'):
train_dataset = load_dataset('Dahoas/full-hh-rlhf', split='train[:75%]')
test_dataset = load_dataset('Dahoas/full-hh-rlhf', split='train[-25%:]')
def generate_rm_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_rlh/rm"):
train_dataset = load_dataset("Dahoas/full-hh-rlhf", split="train[:75%]")
test_dataset = load_dataset("Dahoas/full-hh-rlhf", split="train[-25%:]")
local_dir = os.path.expanduser(local_dir)
os.makedirs(local_dir, exist_ok=True)
for dataset, name in zip([train_dataset, test_dataset], ['train', 'test']):
output = {'prompt': [], 'chosen': [], 'rejected': []}
for dataset, name in zip([train_dataset, test_dataset], ["train", "test"]):
output = {"prompt": [], "chosen": [], "rejected": []}
for data in tqdm(dataset):
# add chosen
output['prompt'].append(data['prompt'])
output['chosen'].append(data['chosen'])
output['rejected'].append(data['rejected'])
output["prompt"].append(data["prompt"])
output["chosen"].append(data["chosen"])
output["rejected"].append(data["rejected"])
df = pd.DataFrame(output)
local_path = os.path.join(local_dir, name + '.parquet')
local_path = os.path.join(local_dir, name + ".parquet")
df.to_parquet(path=local_path)
if target_hdfs_path_dir is not None:
hdfs_dir = target_hdfs_path_dir + '/' + name + '.parquet'
hdfs_dir = target_hdfs_path_dir + "/" + name + ".parquet"
makedirs(hdfs_dir)
copy(local_path, hdfs_dir)
def generate_rl_dataset(target_hdfs_path_dir, local_dir='~/data/full_hh_rlhf/rl'):
dataset = load_dataset('Dahoas/full-hh-rlhf')
train_dataset = dataset['train']
def generate_rl_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_rlhf/rl"):
dataset = load_dataset("Dahoas/full-hh-rlhf")
train_dataset = dataset["train"]
data_source = 'Dahoas/full-hh-rlhf'
data_source = "Dahoas/full-hh-rlhf"
# add a row to each data item that represents a unique id
def make_map_fn(split):
def process_fn(example, idx):
prompt = example.pop('prompt')
response = example.pop('response')
prompt = example.pop("prompt")
response = example.pop("response")
data = {
"data_source": data_source,
"prompt": [{
"role": "user",
"content": prompt
}],
"prompt": [{"role": "user", "content": prompt}],
"ability": "alignment",
"reward_model": {
"style": "model",
"ground_truth": response # should not be used
"ground_truth": response, # should not be used
},
"extra_info": {
'split': split,
'index': idx
}
"extra_info": {"split": split, "index": idx},
}
return data
return process_fn
train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
local_dir = os.path.expanduser(local_dir)
local_path = os.path.join(local_dir, 'train.parquet')
local_path = os.path.join(local_dir, "train.parquet")
train_dataset.to_parquet(local_path)
if target_hdfs_path_dir is not None:
hdfs_dir = target_hdfs_path_dir + '/' + 'train.parquet'
hdfs_dir = target_hdfs_path_dir + "/" + "train.parquet"
makedirs(hdfs_dir)
copy(local_path, hdfs_dir)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--split', type=str, choices=['sft', 'rm', 'rl'], required=True)
parser.add_argument('--local_dir', type=str, default='~/data/full_hh_rlhf')
parser.add_argument('--hdfs_dir', type=str, required=False, default=None)
parser.add_argument("--split", type=str, choices=["sft", "rm", "rl"], required=True)
parser.add_argument("--local_dir", type=str, default="~/data/full_hh_rlhf")
parser.add_argument("--hdfs_dir", type=str, required=False, default=None)
args = parser.parse_args()
if args.split == 'sft':
if args.split == "sft":
generate_sft_dataset(args.hdfs_dir, os.path.join(args.local_dir, args.split))
elif args.split == 'rm':
elif args.split == "rm":
generate_rm_dataset(args.hdfs_dir, os.path.join(args.local_dir, args.split))
elif args.split == 'rl':
elif args.split == "rl":
generate_rl_dataset(args.hdfs_dir, os.path.join(args.local_dir, args.split))
else:
raise NotImplementedError

View File

@ -15,71 +15,70 @@
Preprocess the Geometry3k dataset to parquet format
"""
import argparse
import os
import datasets
from verl.utils.hdfs_io import copy, makedirs
import argparse
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--local_dir', default='~/data/geo3k')
parser.add_argument('--hdfs_dir', default=None)
parser.add_argument("--local_dir", default="~/data/geo3k")
parser.add_argument("--hdfs_dir", default=None)
args = parser.parse_args()
data_source = 'hiyouga/geometry3k'
data_source = "hiyouga/geometry3k"
dataset = datasets.load_dataset(data_source)
train_dataset = dataset['train']
test_dataset = dataset['test']
train_dataset = dataset["train"]
test_dataset = dataset["test"]
instruction_following = (
r'You FIRST think about the reasoning process as an internal monologue and then provide the final answer. '
r'The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}.'
r"You FIRST think about the reasoning process as an internal monologue and then provide the final answer. "
r"The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}."
)
# add a row to each data item that represents a unique id
def make_map_fn(split):
def process_fn(example, idx):
problem = example.pop('problem')
prompt = problem + ' ' + instruction_following
answer = example.pop('answer')
images = example.pop('images')
problem = example.pop("problem")
prompt = problem + " " + instruction_following
answer = example.pop("answer")
images = example.pop("images")
data = {
"data_source": data_source,
"prompt": [{
"role": "user",
"content": prompt,
}],
"prompt": [
{
"role": "user",
"content": prompt,
}
],
"images": images,
"ability": "math",
"reward_model": {
"style": "rule",
"ground_truth": answer
},
"reward_model": {"style": "rule", "ground_truth": answer},
"extra_info": {
'split': split,
'index': idx,
'answer': answer,
"split": split,
"index": idx,
"answer": answer,
"question": problem,
}
},
}
return data
return process_fn
train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True, num_proc=8)
test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True, num_proc=8)
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True, num_proc=8)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True, num_proc=8)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)

View File

@ -15,78 +15,77 @@
Preprocess the GSM8k dataset to parquet format
"""
import re
import argparse
import os
import re
import datasets
from verl.utils.hdfs_io import copy, makedirs
import argparse
def extract_solution(solution_str):
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
assert solution is not None
final_solution = solution.group(0)
final_solution = final_solution.split('#### ')[1].replace(',', '')
final_solution = final_solution.split("#### ")[1].replace(",", "")
return final_solution
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--local_dir', default='~/data/gsm8k')
parser.add_argument('--hdfs_dir', default=None)
parser.add_argument("--local_dir", default="~/data/gsm8k")
parser.add_argument("--hdfs_dir", default=None)
args = parser.parse_args()
data_source = 'openai/gsm8k'
data_source = "openai/gsm8k"
dataset = datasets.load_dataset(data_source, 'main')
dataset = datasets.load_dataset(data_source, "main")
train_dataset = dataset['train']
test_dataset = dataset['test']
train_dataset = dataset["train"]
test_dataset = dataset["test"]
instruction_following = "Let's think step by step and output the final answer after \"####\"."
instruction_following = 'Let\'s think step by step and output the final answer after "####".'
# add a row to each data item that represents a unique id
def make_map_fn(split):
def process_fn(example, idx):
question_raw = example.pop('question')
question_raw = example.pop("question")
question = question_raw + ' ' + instruction_following
question = question_raw + " " + instruction_following
answer_raw = example.pop('answer')
answer_raw = example.pop("answer")
solution = extract_solution(answer_raw)
data = {
"data_source": data_source,
"prompt": [{
"role": "user",
"content": question,
}],
"prompt": [
{
"role": "user",
"content": question,
}
],
"ability": "math",
"reward_model": {
"style": "rule",
"ground_truth": solution
},
"reward_model": {"style": "rule", "ground_truth": solution},
"extra_info": {
'split': split,
'index': idx,
'answer': answer_raw,
"split": split,
"index": idx,
"answer": answer_raw,
"question": question_raw,
}
},
}
return data
return process_fn
train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)

View File

@ -16,12 +16,13 @@ Preprocess Hellaswag dataset.
"""
import re
import argparse
import os
import re
import datasets
from verl.utils.hdfs_io import copy, makedirs
import argparse
def preprocess(text):
@ -33,25 +34,24 @@ def preprocess(text):
return text
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--local_dir', default='/opt/tiger/hellaswag')
parser.add_argument('--hdfs_dir', default=None)
parser.add_argument("--local_dir", default="/opt/tiger/hellaswag")
parser.add_argument("--hdfs_dir", default=None)
args = parser.parse_args()
data_source = 'Rowan/hellaswag'
data_source = "Rowan/hellaswag"
dataset = datasets.load_dataset(data_source, trust_remote_code=True)
train_dataset = dataset['train']
val_dataset = dataset['validation']
test_dataset = dataset['test']
train_dataset = dataset["train"]
val_dataset = dataset["validation"]
test_dataset = dataset["test"]
instruction = 'Please complete the following sentence.\n'
instruction = "Please complete the following sentence.\n"
def make_map_fn(split):
def process_fn(doc, idx):
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
query = preprocess(doc["activity_label"] + ": " + ctx)
@ -60,41 +60,35 @@ if __name__ == '__main__':
data = {
"data_source": data_source,
"prompt": [{
"role": "user",
"content": query
}],
"prompt": [{"role": "user", "content": query}],
"ability": "nlp",
"reward_model": {
"style": "model",
"eval": "multiple_choice", # using loglikelihood
"ground_truth": gold,
"choices": choices
"choices": choices,
},
"extra_info": {
'split': split,
'index': idx
}
"extra_info": {"split": split, "index": idx},
}
return data
return process_fn
# filter data that doesn't have a label
train_dataset = train_dataset.filter(lambda x: len(x['label']) > 0)
val_dataset = val_dataset.filter(lambda x: len(x['label']) > 0)
test_dataset = test_dataset.filter(lambda x: len(x['label']) > 0)
train_dataset = train_dataset.filter(lambda x: len(x["label"]) > 0)
val_dataset = val_dataset.filter(lambda x: len(x["label"]) > 0)
test_dataset = test_dataset.filter(lambda x: len(x["label"]) > 0)
train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
val_dataset = val_dataset.map(function=make_map_fn('validation'), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
val_dataset = val_dataset.map(function=make_map_fn("validation"), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
val_dataset.to_parquet(os.path.join(local_dir, 'validation.parquet'))
test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
val_dataset.to_parquet(os.path.join(local_dir, "validation.parquet"))
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)

View File

@ -15,75 +15,65 @@
Preprocess the MATH-lighteval dataset to parquet format
"""
import argparse
import os
import datasets
from verl.utils.hdfs_io import copy, makedirs
import argparse
from verl.utils.reward_score.math import remove_boxed, last_boxed_only_string
from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed
def extract_solution(solution_str):
return remove_boxed(last_boxed_only_string(solution_str))
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--local_dir', default='~/data/math')
parser.add_argument('--hdfs_dir', default=None)
parser.add_argument("--local_dir", default="~/data/math")
parser.add_argument("--hdfs_dir", default=None)
args = parser.parse_args()
# 'lighteval/MATH' is no longer available on huggingface.
# Use mirror repo: DigitalLearningGmbH/MATH-lighteval
data_source = 'DigitalLearningGmbH/MATH-lighteval'
data_source = "DigitalLearningGmbH/MATH-lighteval"
print(f"Loading the {data_source} dataset from huggingface...", flush=True)
dataset = datasets.load_dataset(data_source, trust_remote_code=True)
train_dataset = dataset['train']
test_dataset = dataset['test']
train_dataset = dataset["train"]
test_dataset = dataset["test"]
instruction_following = "Let's think step by step and output the final answer within \\boxed{}."
# add a row to each data item that represents a unique id
def make_map_fn(split):
def process_fn(example, idx):
question = example.pop('problem')
question = example.pop("problem")
question = question + ' ' + instruction_following
question = question + " " + instruction_following
answer = example.pop('solution')
answer = example.pop("solution")
solution = extract_solution(answer)
data = {
"data_source": data_source,
"prompt": [{
"role": "user",
"content": question
}],
"prompt": [{"role": "user", "content": question}],
"ability": "math",
"reward_model": {
"style": "rule",
"ground_truth": solution
},
"extra_info": {
'split': split,
'index': idx
}
"reward_model": {"style": "rule", "ground_truth": solution},
"extra_info": {"split": split, "index": idx},
}
return data
return process_fn
train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)

View File

@ -15,87 +15,71 @@
Create a simple multi-turn dataset for testing
"""
import os
import pandas as pd
import argparse
import os
import pandas as pd
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--local_dir', default='~/data/multiturn')
parser.add_argument('--hdfs_dir', default=None)
parser.add_argument("--local_dir", default="~/data/multiturn")
parser.add_argument("--hdfs_dir", default=None)
args = parser.parse_args()
# Create example conversations
conversations = []
# Conversation 1
conversations.append({
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "What is the capital of France?"
}, {
"role": "assistant",
"content": "The capital of France is Paris."
}, {
"role": "user",
"content": "And what about Germany?"
}, {
"role": "assistant",
"content": "The capital of Germany is Berlin."
}]
})
conversations.append(
{
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of France?"},
{"role": "assistant", "content": "The capital of France is Paris."},
{"role": "user", "content": "And what about Germany?"},
{"role": "assistant", "content": "The capital of Germany is Berlin."},
]
}
)
# Conversation 2
conversations.append({
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Can you explain quantum computing?"
}, {
"role":
"assistant",
"content":
"Quantum computing is a type of computing that uses quantum-mechanical phenomena, such as superposition and entanglement, to perform operations on data."
}, {
"role": "user",
"content": "How is it different from classical computing?"
}, {
"role":
"assistant",
"content":
"Classical computing uses bits that are either 0 or 1, while quantum computing uses quantum bits or qubits that can exist in multiple states simultaneously due to superposition."
}]
})
conversations.append(
{
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Can you explain quantum computing?"},
{
"role": "assistant",
"content": "Quantum computing is a type of computing that uses quantum-mechanical phenomena, such as superposition and entanglement, to perform operations on data.",
},
{"role": "user", "content": "How is it different from classical computing?"},
{
"role": "assistant",
"content": "Classical computing uses bits that are either 0 or 1, while quantum computing uses quantum bits or qubits that can exist in multiple states simultaneously due to superposition.",
},
]
}
)
# Conversation 3
conversations.append({
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Write a simple Python function to calculate factorial."
}, {
"role":
"assistant",
"content":
"```python\ndef factorial(n):\n if n == 0 or n == 1:\n return 1\n else:\n return n * factorial(n-1)\n```\n\nThis is a recursive function to calculate the factorial of a number."
}, {
"role": "user",
"content": "Can you make it iterative instead?"
}, {
"role":
"assistant",
"content":
"```python\ndef factorial(n):\n result = 1\n for i in range(1, n+1):\n result *= i\n return result\n```\n\nThis is an iterative version of the factorial function."
}]
})
conversations.append(
{
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write a simple Python function to calculate factorial."},
{
"role": "assistant",
"content": "```python\ndef factorial(n):\n if n == 0 or n == 1:\n return 1\n else:\n return n * factorial(n-1)\n```\n\nThis is a recursive function to calculate the factorial of a number.",
},
{"role": "user", "content": "Can you make it iterative instead?"},
{
"role": "assistant",
"content": "```python\ndef factorial(n):\n result = 1\n for i in range(1, n+1):\n result *= i\n return result\n```\n\nThis is an iterative version of the factorial function.",
},
]
}
)
# Create train and test datasets
train_data = conversations[:2] # First 2 conversations for training
@ -109,13 +93,14 @@ def main():
train_df = pd.DataFrame(train_data)
test_df = pd.DataFrame(test_data)
train_df.to_parquet(os.path.join(local_dir, 'train.parquet'))
test_df.to_parquet(os.path.join(local_dir, 'test.parquet'))
train_df.to_parquet(os.path.join(local_dir, "train.parquet"))
test_df.to_parquet(os.path.join(local_dir, "test.parquet"))
# Handle HDFS if specified
if args.hdfs_dir is not None:
try:
from verl.utils.hdfs_io import copy, makedirs
makedirs(args.hdfs_dir)
copy(src=local_dir, dst=args.hdfs_dir)
except ImportError:
@ -127,5 +112,5 @@ def main():
print(f"Data saved to {local_dir}")
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -313,6 +313,7 @@
"outputs": [],
"source": [
"import torch\n",
"\n",
"try:\n",
" assert torch.cuda.is_available() is True\n",
" torch.ones(1, dtype=torch.bfloat16).cuda()\n",
@ -320,12 +321,10 @@
" print(\"Please switch to an env with GPUs supporting bfloat16 (L4 RTX 5000, A5000, A100, H100, A10, etc)\")\n",
"\n",
"try:\n",
" import verl\n",
" pass\n",
"except Exception as e:\n",
" print(\"Please install verl via pip and restart the kernel\")\n",
" raise e\n",
"\n",
"import flash_attn"
" raise e"
]
},
{
@ -560,6 +559,7 @@
],
"source": [
"import inspect\n",
"\n",
"from verl.utils.reward_score.gsm8k import compute_score as gsm8k_reward\n",
"\n",
"print(inspect.getsource(gsm8k_reward))"

View File

@ -37,10 +37,12 @@
},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"import ray\n",
"import torch\n",
"import warnings\n",
"warnings.filterwarnings('ignore')"
"\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
@ -146,10 +148,10 @@
"class Accumulator:\n",
" def __init__(self):\n",
" self.value = 0\n",
" \n",
"\n",
" def add(self, x):\n",
" self.value += x\n",
" \n",
"\n",
" def get_value(self):\n",
" return self.value"
]
@ -184,7 +186,7 @@
}
],
"source": [
"value_ref = accumulator.get_value.remote() # Check the current value. Note that this function returns immediately and does not actually wait for the remote execution to complete.\n",
"value_ref = accumulator.get_value.remote() # Check the current value. Note that this function returns immediately and does not actually wait for the remote execution to complete.\n",
"# Get the value\n",
"value = ray.get(value_ref)\n",
"print(value)"
@ -232,8 +234,8 @@
},
"outputs": [],
"source": [
"from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool\n",
"from verl.single_controller.base import Worker"
"from verl.single_controller.base import Worker\n",
"from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool"
]
},
{
@ -259,16 +261,15 @@
"source": [
"@ray.remote\n",
"class GPUAccumulator(Worker):\n",
"\n",
" def __init__(self) -> None:\n",
" super().__init__()\n",
" # The initial value of each rank is the same as the rank\n",
" self.value = torch.zeros(size=(1,), device='cuda') + self.rank\n",
" self.value = torch.zeros(size=(1,), device=\"cuda\") + self.rank\n",
"\n",
" def add(self, x):\n",
" self.value += x\n",
" print(f'rank {self.rank}, value: {self.value}')\n",
" return self.value.cpu()\n"
" print(f\"rank {self.rank}, value: {self.value}\")\n",
" return self.value.cpu()"
]
},
{
@ -291,7 +292,7 @@
"# Each worker's initial value is its rank, and then each rank's value is incremented by 1, so the values obtained on each rank are [1, 2, 3, 4]\n",
"class_with_args = RayClassWithInitArgs(cls=GPUAccumulator)\n",
"worker_group = RayWorkerGroup(resource_pool, class_with_args)\n",
"print(worker_group.execute_all_sync('add', x=[1,1,1,1]))"
"print(worker_group.execute_all_sync(\"add\", x=[1, 1, 1, 1]))"
]
},
{
@ -329,7 +330,7 @@
"outputs": [],
"source": [
"# Create a new resource pool and then merge the newly created resource pool with the previous one.\n",
"resource_pool_1 = RayResourcePool([4], use_gpu=True, name_prefix='a')\n",
"resource_pool_1 = RayResourcePool([4], use_gpu=True, name_prefix=\"a\")\n",
"resource_pool_merge = merge_resource_pool(resource_pool, resource_pool_1)"
]
},
@ -365,7 +366,7 @@
],
"source": [
"# Run 'add' on the second set of 4 GPUs; the result should be [2, 3, 4, 5].\n",
"output_1 = worker_group_1.execute_all_sync('add', x=[2,2,2,2])\n",
"output_1 = worker_group_1.execute_all_sync(\"add\", x=[2, 2, 2, 2])\n",
"print(output_1)"
]
},
@ -387,7 +388,7 @@
],
"source": [
"# Run 'add' on the merged set of 8 GPUs; the result should be [3, 4, 5, 6, 7, 8, 9, 10].\n",
"output_merge = worker_group_merge.execute_all_sync('add', x=[3,3,3,3,3,3,3,3])\n",
"output_merge = worker_group_merge.execute_all_sync(\"add\", x=[3, 3, 3, 3, 3, 3, 3, 3])\n",
"print(output_merge)"
]
},
@ -437,7 +438,7 @@
},
"outputs": [],
"source": [
"from verl.single_controller.base.decorator import register, Dispatch, Execute"
"from verl.single_controller.base.decorator import Dispatch, Execute, register"
]
},
{
@ -451,18 +452,17 @@
"source": [
"@ray.remote\n",
"class GPUAccumulatorDecorator(Worker):\n",
"\n",
" def __init__(self) -> None:\n",
" super().__init__()\n",
" # The initial value of each rank is the same as the rank\n",
" self.value = torch.zeros(size=(1,), device='cuda') + self.rank\n",
" \n",
" self.value = torch.zeros(size=(1,), device=\"cuda\") + self.rank\n",
"\n",
" # map from a single input to all the worker\n",
" @register(Dispatch.ONE_TO_ALL)\n",
" def add(self, x):\n",
" print(x)\n",
" self.value = self.value + x\n",
" print(f'rank {self.rank}, value: {self.value}')\n",
" print(f\"rank {self.rank}, value: {self.value}\")\n",
" return self.value.cpu()"
]
},
@ -518,7 +518,7 @@
},
"outputs": [],
"source": [
"from verl.single_controller.base.decorator import register, Dispatch, collect_all_to_all, Execute"
"from verl.single_controller.base.decorator import Dispatch, collect_all_to_all, register"
]
},
{
@ -559,7 +559,7 @@
" def foo_rank_zero(self, x, y):\n",
" return self._x + y + x\n",
"\n",
" @register(dispatch_mode={'dispatch_fn': two_to_all_dispatch_fn, 'collect_fn': collect_all_to_all})\n",
" @register(dispatch_mode={\"dispatch_fn\": two_to_all_dispatch_fn, \"collect_fn\": collect_all_to_all})\n",
" def foo_custom(self, x, y):\n",
" return self._x + y + x"
]
@ -691,26 +691,24 @@
}
],
"source": [
"import os\n",
"import sys\n",
"import site\n",
"\n",
"current_pythonpath = os.environ.get(\"PYTHONPATH\", \"\")\n",
"\n",
"current_pythonpath = os.environ.get('PYTHONPATH', '')\n",
"\n",
"new_path = '/opt/tiger/Megatron-LM'\n",
"new_path = \"/opt/tiger/Megatron-LM\"\n",
"\n",
"if current_pythonpath:\n",
" new_pythonpath = f'{new_path}:{current_pythonpath}'\n",
" new_pythonpath = f\"{new_path}:{current_pythonpath}\"\n",
"else:\n",
" new_pythonpath = new_path\n",
"\n",
"os.environ['PYTHONPATH'] = new_pythonpath\n",
"os.environ[\"PYTHONPATH\"] = new_pythonpath\n",
"\n",
"print(new_path)\n",
"sys.path.append(new_path)\n",
"\n",
"import megatron\n",
"\n",
"print(megatron.__file__)"
]
},
@ -723,12 +721,13 @@
},
"outputs": [],
"source": [
"from verl.single_controller.base.decorator import register, Dispatch, Execute\n",
"from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n",
"from verl.single_controller.base.megatron.worker import MegatronWorker\n",
"from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup\n",
"from megatron.core import parallel_state as mpu\n",
"from omegaconf import OmegaConf\n",
"from megatron.core import parallel_state as mpu"
"\n",
"from verl.single_controller.base.decorator import Dispatch, Execute, register\n",
"from verl.single_controller.base.megatron.worker import MegatronWorker\n",
"from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n",
"from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup"
]
},
{
@ -756,52 +755,56 @@
"class MLPLayerWorker(MegatronWorker):\n",
" def __init__(self):\n",
" super().__init__()\n",
" rank = int(os.environ['LOCAL_RANK'])\n",
" rank = int(os.environ[\"LOCAL_RANK\"])\n",
" torch.distributed.init_process_group(backend=\"nccl\")\n",
" torch.cuda.set_device(rank)\n",
"\n",
" mpu.initialize_model_parallel(\n",
" tensor_model_parallel_size=4,\n",
" pipeline_model_parallel_size=1,\n",
" virtual_pipeline_model_parallel_size=None,\n",
" pipeline_model_parallel_split_rank=None,\n",
" use_sharp=False,\n",
" context_parallel_size=1,\n",
" expert_model_parallel_size=1,\n",
" nccl_communicator_config_path=None,\n",
" )\n",
" tensor_model_parallel_size=4,\n",
" pipeline_model_parallel_size=1,\n",
" virtual_pipeline_model_parallel_size=None,\n",
" pipeline_model_parallel_split_rank=None,\n",
" use_sharp=False,\n",
" context_parallel_size=1,\n",
" expert_model_parallel_size=1,\n",
" nccl_communicator_config_path=None,\n",
" )\n",
" from megatron.core import tensor_parallel\n",
" tensor_parallel.model_parallel_cuda_manual_seed(10)\n",
"\n",
" tensor_parallel.model_parallel_cuda_manual_seed(10)\n",
"\n",
" @register(Dispatch.ONE_TO_ALL)\n",
" def init_model(self, config):\n",
" from omegaconf import OmegaConf\n",
" from verl.utils.megatron_utils import init_model_parallel_config\n",
"\n",
" from verl.models.llama.megatron.layers import ParallelLlamaMLP\n",
" megatron_config = OmegaConf.create({\n",
" 'sequence_parallel': False,\n",
" 'param_dtype': 'fp32',\n",
" 'tensor_model_parallel_size': mpu.get_tensor_model_parallel_world_size(),\n",
" 'pipeline_model_parallel_rank': mpu.get_pipeline_model_parallel_rank(),\n",
" 'pipeline_model_parallel_size': mpu.get_pipeline_model_parallel_world_size(),\n",
" 'virtual_pipeline_model_parallel_rank': mpu.get_virtual_pipeline_model_parallel_rank(),\n",
" 'virtual_pipeline_model_parallel_size': mpu.get_virtual_pipeline_model_parallel_world_size()\n",
" })\n",
" from verl.utils.megatron_utils import init_model_parallel_config\n",
"\n",
" megatron_config = OmegaConf.create(\n",
" {\n",
" \"sequence_parallel\": False,\n",
" \"param_dtype\": \"fp32\",\n",
" \"tensor_model_parallel_size\": mpu.get_tensor_model_parallel_world_size(),\n",
" \"pipeline_model_parallel_rank\": mpu.get_pipeline_model_parallel_rank(),\n",
" \"pipeline_model_parallel_size\": mpu.get_pipeline_model_parallel_world_size(),\n",
" \"virtual_pipeline_model_parallel_rank\": mpu.get_virtual_pipeline_model_parallel_rank(),\n",
" \"virtual_pipeline_model_parallel_size\": mpu.get_virtual_pipeline_model_parallel_world_size(),\n",
" }\n",
" )\n",
"\n",
" megatron_config = init_model_parallel_config(megatron_config)\n",
" self.parallel_layer = ParallelLlamaMLP(config=config, megatron_config=megatron_config)\n",
" \n",
"\n",
" @register(Dispatch.ONE_TO_ALL)\n",
" def get_weights(self):\n",
" output = {}\n",
" for key, val in self.parallel_layer.named_parameters():\n",
" output[key] = val\n",
" return output\n",
" \n",
"\n",
" @register(Dispatch.MEGATRON_COMPUTE)\n",
" def run_layer(self, x):\n",
" x = x.to('cuda')\n",
" x = x.to(\"cuda\")\n",
" y = self.parallel_layer(x)\n",
" return y"
]
@ -816,9 +819,10 @@
"outputs": [],
"source": [
"layer_cls = RayClassWithInitArgs(cls=MLPLayerWorker)\n",
"layer_worker_group = NVMegatronRayWorkerGroup(resource_pool=resource_pool,\n",
" ray_cls_with_init=layer_cls,\n",
" )\n"
"layer_worker_group = NVMegatronRayWorkerGroup(\n",
" resource_pool=resource_pool,\n",
" ray_cls_with_init=layer_cls,\n",
")"
]
},
{
@ -855,13 +859,15 @@
"seq_len = 2048\n",
"hidden_size = 4096\n",
"\n",
"config = OmegaConf.create({\n",
" 'hidden_size': hidden_size,\n",
" 'intermediate_size': ffn_hidden_size,\n",
" 'hidden_act': 'silu',\n",
" 'pretraining_tp': 1,\n",
" 'tp': layer_worker_group.tp_size,\n",
"})"
"config = OmegaConf.create(\n",
" {\n",
" \"hidden_size\": hidden_size,\n",
" \"intermediate_size\": ffn_hidden_size,\n",
" \"hidden_act\": \"silu\",\n",
" \"pretraining_tp\": 1,\n",
" \"tp\": layer_worker_group.tp_size,\n",
" }\n",
")"
]
},
{
@ -916,7 +922,9 @@
}
],
"source": [
"output = layer_worker_group.run_layer([x]) # This must be a list of size 1, ensuring that the input equals the data parallel (dp).\n",
"output = layer_worker_group.run_layer(\n",
" [x]\n",
") # This must be a list of size 1, ensuring that the input equals the data parallel (dp).\n",
"print(output[0].shape)"
]
},

View File

@ -15,23 +15,23 @@
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""
from verl import DataProto
import torch
from verl.utils.reward_score import gsm8k, math
from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.utils.reward_score import gsm8k, math
def _select_rm_score_fn(data_source):
if data_source == 'openai/gsm8k':
if data_source == "openai/gsm8k":
return gsm8k.compute_score
elif data_source == 'lighteval/MATH':
elif data_source == "lighteval/MATH":
return math.compute_score
else:
raise NotImplementedError
class RewardManager():
class RewardManager:
def __init__(self, tokenizer, num_examine) -> None:
self.tokenizer = tokenizer
self.num_examine = num_examine # the number of batches of decoded responses to print to the console
@ -40,35 +40,35 @@ class RewardManager():
"""We will expand this function gradually based on the available datasets"""
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
if 'rm_scores' in data.batch.keys():
return data.batch['rm_scores']
if "rm_scores" in data.batch.keys():
return data.batch["rm_scores"]
reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
already_print_data_sources = {}
for i in range(len(data)):
data_item = data[i] # DataProtoItem
prompt_ids = data_item.batch['prompts']
prompt_ids = data_item.batch["prompts"]
prompt_length = prompt_ids.shape[-1]
valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
valid_prompt_ids = prompt_ids[-valid_prompt_length:]
response_ids = data_item.batch['responses']
valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
response_ids = data_item.batch["responses"]
valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
valid_response_ids = response_ids[:valid_response_length]
# decode
sequences = torch.cat((valid_prompt_ids, valid_response_ids))
sequences_str = self.tokenizer.decode(sequences)
ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']
ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"]
# select rm_score
data_source = data_item.non_tensor_batch['data_source']
data_source = data_item.non_tensor_batch["data_source"]
compute_score_fn = _select_rm_score_fn(data_source)
score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth)
@ -87,28 +87,29 @@ class RewardManager():
return reward_tensor
import ray
import hydra
import ray
from split_monkey_patch import fit
@hydra.main(config_path='config', config_name='ppo_trainer_split', version_base=None)
@hydra.main(config_path="config", config_name="ppo_trainer_split", version_base=None)
def main(config):
if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}})
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
ray.get(main_task.remote(config))
@ray.remote
def main_task(config):
from verl.utils.fs import copy_to_local
from transformers import AutoTokenizer
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
from verl.utils.fs import copy_to_local
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
@ -117,19 +118,22 @@ def main_task(config):
# instantiate tokenizer
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer(local_path)
# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
if config.actor_rollout_ref.actor.strategy == "fsdp":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron':
elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
@ -143,8 +147,8 @@ def main_task(config):
}
# NOTE: initialze two resource pool
actor_rollout_ref_pool_id = 'actor_rollout_ref_pool'
critic_pool_id = 'critic_pool'
actor_rollout_ref_pool_id = "actor_rollout_ref_pool"
critic_pool_id = "critic_pool"
if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0:
resource_pool_spec = {
actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,
@ -155,13 +159,13 @@ def main_task(config):
actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),
critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),
}
print(f'resource_pool_spec: {resource_pool_spec}')
print(f"resource_pool_spec: {resource_pool_spec}")
mapping = {
Role.ActorRollout: actor_rollout_ref_pool_id,
Role.Critic: critic_pool_id,
}
#use reference model
# use reference model
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = actor_rollout_ref_pool_id
@ -173,9 +177,9 @@ def main_task(config):
# - finally, we combine all the rewards together
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy == 'fsdp':
if config.reward_model.strategy == "fsdp":
from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == 'megatron':
elif config.reward_model.strategy == "megatron":
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
@ -190,16 +194,18 @@ def main_task(config):
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
RayPPOTrainer.fit = fit
trainer = RayPPOTrainer(config=config,
tokenizer=tokenizer,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn)
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -14,13 +14,24 @@
"""
An naive implementation of split placment example
"""
from pprint import pprint
from verl import DataProto
from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, _timer, compute_timing_metrics, AdvantageEstimator
import uuid
from copy import deepcopy
from pprint import pprint
import numpy as np
import torch
import uuid
from verl import DataProto
from verl.trainer.ppo.ray_trainer import (
AdvantageEstimator,
_timer,
apply_kl_penalty,
compute_advantage,
compute_data_metrics,
compute_timing_metrics,
reduce_metrics,
)
def fit(self):
@ -29,13 +40,16 @@ def fit(self):
The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""
from verl.utils.tracking import Tracking
from omegaconf import OmegaConf
logger = Tracking(project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True))
from verl.utils.tracking import Tracking
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)
self.global_steps = 0
@ -44,11 +58,11 @@ def fit(self):
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
pprint(f'Initial validation metrics: {val_metrics}')
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get('val_only', False):
if self.config.trainer.get("val_only", False):
return
# we start from step 1
@ -63,18 +77,18 @@ def fit(self):
batch: DataProto = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])
gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"])
is_last_step = self.global_steps >= self.total_training_steps
with _timer('step', timing_raw):
with _timer("step", timing_raw):
# generate a batch
with _timer('gen', timing_raw):
with _timer("gen", timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with _timer('gen_max', timing_raw):
with _timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info['do_sample'] = False
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
@ -83,12 +97,13 @@ def fit(self):
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
batch.batch['reward_baselines'] = reward_baseline_tensor
batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
dtype=object)
batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
@ -99,26 +114,26 @@ def fit(self):
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# recompute old_log_probs
with _timer('old_log_prob', timing_raw):
with _timer("old_log_prob", timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
with _timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with _timer('values', timing_raw):
with _timer("values", timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with _timer('adv', timing_raw):
with _timer("adv", timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
@ -129,57 +144,63 @@ def fit(self):
# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch['token_level_scores'] = reward_tensor
batch.batch["token_level_scores"] = reward_tensor
# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
batch, kl_metrics = apply_kl_penalty(batch,
kl_ctrl=self.kl_ctrl_in_reward,
kl_penalty=self.config.algorithm.kl_penalty)
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics)
else:
batch.batch['token_level_rewards'] = batch.batch['token_level_scores']
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# compute advantages, executed on the driver process
batch = compute_advantage(batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n)
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
)
# update critic
if self.use_critic:
with _timer('update_critic_call', timing_raw):
with _timer("update_critic_call", timing_raw):
critic_output = self.critic_wg.update_critic(batch)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with _timer('update_actor_call', timing_raw):
with _timer("update_actor_call", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
# NOTE: make sure you set blocking=False in update_actor and update_critic in the worker class
with _timer('update_actor_critic', timing_raw):
with _timer("update_actor_critic", timing_raw):
critic_output = critic_output.get()
critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
actor_output = actor_output.get()
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# validate
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
(is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
with _timer('testing', timing_raw):
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
):
with _timer("testing", timing_raw):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and (is_last_step or \
self.global_steps % self.config.trainer.save_freq == 0):
with _timer('save_checkpoint', timing_raw):
if self.config.trainer.save_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.save_freq == 0
):
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# collect metrics
@ -190,7 +211,7 @@ def fit(self):
logger.log(data=metrics, step=self.global_steps)
if self.global_steps >= self.total_training_steps:
pprint(f'Final validation metrics: {last_val_metrics}')
pprint(f"Final validation metrics: {last_val_metrics}")
return
self.global_steps += 1
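
The `fit()` loop above wraps every stage (`gen`, `old_log_prob`, `adv`, `update_actor_critic`, ...) in `_timer(name, timing_raw)` and later folds the collected timings into the logged metrics. As a rough sketch only (an assumption about the shape of the helper, not the implementation actually imported from `verl.trainer.ppo.ray_trainer`), such a helper can be a small context manager:

```python
# Sketch of a _timer-shaped helper (assumption; not verl's actual code):
# time the wrapped block and record elapsed seconds under `name`.
import time
from contextlib import contextmanager


@contextmanager
def _timer(name, timing_raw):
    start = time.perf_counter()
    try:
        yield
    finally:
        timing_raw[name] = time.perf_counter() - start


timing_raw = {}
with _timer("step", timing_raw):
    pass  # the wrapped work goes here
print(timing_raw)  # e.g. {'step': 1.2e-06}
```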

View File

@ -23,6 +23,45 @@ license = {file = "LICENSE"} # or "Apache-2.0", if you prefer an SPDX identifie
readme = {file = "README.md", content-type = "text/markdown"}
requires-python = ">=3.8"
# -------------------------------
# tool.ruff - Linting configuration
# -------------------------------
[tool.ruff]
line-length = 120
# Enable import sorting
[tool.ruff.lint]
isort = {known-first-party = ["verl"]}
# c.f. https://github.com/vllm-project/vllm/blob/ce8d6b75fc0586045df75ee1568a5b5f9957251b/pyproject.toml
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
"I",
"G",
]
ignore = [
# star imports
"F405", "F403",
# lambda expression assignment
"E731",
# Loop control variable not used within loop body
"B007",
# f-string format
"UP032",
# Can remove once 3.10+ is the minimum Python version
"UP007",
]
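
To make the effect of the `select` list concrete, here is a small hypothetical snippet (not from the repo) with the kind of findings these rule families report; the rule codes reflect my understanding of ruff's default rules and are worth double-checking against the ruff version actually in use:

```python
# Hypothetical example only; rule codes are my best understanding of ruff defaults.
from verl import DataProto  # first-party (see known-first-party above)
import os                   # I001: import block is un-sorted (stdlib after first-party)
import numpy as np          # F401: imported but unused (Pyflakes)


def pick_label(flag, cache={}):   # B006: mutable default argument (flake8-bugbear)
    if flag:                      # SIM108: could be a ternary (flake8-simplify)
        label = "on"
    else:
        label = "off"
    return label
```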
# -------------------------------
# tool.setuptools - Additional config
# -------------------------------
@ -45,117 +84,4 @@ version = {file = "verl/version/version"}
verl = [
"version/*",
"trainer/config/*.yaml"
]
[tool.pylint.message_control]
disable = [
"abstract-method",
"anomalous-backslash-in-string",
"arguments-differ",
"arguments-renamed",
"assignment-from-none",
"attribute-defined-outside-init",
"bad-str-strip-call",
"bare-except",
"broad-exception-caught",
"broad-exception-raised",
"cell-var-from-loop",
"chained-comparison",
"consider-iterating-dictionary",
"consider-using-enumerate",
"consider-using-f-string",
"consider-using-from-import",
"consider-using-generator",
"consider-using-in",
"consider-using-max-builtin",
"consider-using-set-comprehension",
"consider-using-sys-exit",
"consider-using-with",
"cyclic-import",
"dangerous-default-value",
"duplicate-code",
"eval-used",
"expression-not-assigned",
"f-string-without-interpolation",
"fixme",
"function-redefined",
"global-statement",
"global-variable-not-assigned",
"import-error",
"import-outside-toplevel",
"import-self",
"inconsistent-return-statements",
"invalid-character-zero-width-space",
"invalid-name",
"line-too-long",
"logging-fstring-interpolation",
"logging-not-lazy",
"missing-class-docstring",
"missing-final-newline",
"missing-function-docstring",
"missing-module-docstring",
"multiple-imports",
"no-else-continue",
"no-else-raise",
"no-else-return",
"no-member",
"no-self-argument",
"no-value-for-parameter",
"not-an-iterable",
"not-callable",
"notimplemented-raised",
"pointless-exception-statement",
"pointless-string-statement",
"pointless-statement",
"possibly-used-before-assignment",
"protected-access",
"raise-missing-from",
"raising-format-tuple",
"redefined-argument-from-local",
"redefined-builtin",
"redefined-outer-name",
"redundant-u-string-prefix",
"reimported",
"simplifiable-if-expression",
"simplifiable-if-statement",
"singleton-comparison",
"super-init-not-called",
"superfluous-parens",
"too-few-public-methods",
"too-many-arguments",
"too-many-boolean-expressions",
"too-many-branches",
"too-many-instance-attributes",
"too-many-lines",
"too-many-locals",
"too-many-positional-arguments",
"too-many-return-statements",
"too-many-statements",
"trailing-newlines",
"trailing-newlines",
"trailing-whitespace",
"unbalanced-tuple-unpacking",
"undefined-loop-variable",
"undefined-variable",
"ungrouped-imports",
"unidiomatic-typecheck",
"unnecessary-comprehension",
"unnecessary-lambda",
"unnecessary-lambda-assignment",
"unnecessary-pass",
"unspecified-encoding",
"unused-argument",
"unused-import",
"unused-variable",
"unused-wildcard-import",
"use-a-generator",
"use-dict-literal",
"used-before-assignment",
"useless-object-inheritance",
"useless-parent-delegation",
"useless-return",
"wildcard-import",
"wrong-import-order",
"wrong-import-position",
]

View File

@ -17,17 +17,22 @@ This trainer supports model-agnostic model initialization with huggingface
"""
import uuid
from pprint import pprint
from copy import deepcopy
from collections import defaultdict
from tqdm import tqdm
from copy import deepcopy
from pprint import pprint
import numpy as np
import torch
from tqdm import tqdm
from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage, AdvantageEstimator
from verl.trainer.ppo.metric_utils import (compute_data_metrics, compute_throughout_metrics, compute_timing_metrics,
reduce_metrics)
from verl.trainer.ppo.metric_utils import (
compute_data_metrics,
compute_throughout_metrics,
compute_timing_metrics,
reduce_metrics,
)
from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage
class RayDAPOTrainer(RayPPOTrainer):
@ -41,13 +46,16 @@ class RayDAPOTrainer(RayPPOTrainer):
The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""
from verl.utils.tracking import Tracking
from omegaconf import OmegaConf
logger = Tracking(project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True))
from verl.utils.tracking import Tracking
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)
self.global_steps = 0
@ -56,11 +64,11 @@ class RayDAPOTrainer(RayPPOTrainer):
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
pprint(f'Initial validation metrics: {val_metrics}')
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get('val_only', False):
if self.config.trainer.get("val_only", False):
return
# add tqdm
@ -81,28 +89,28 @@ class RayDAPOTrainer(RayPPOTrainer):
new_batch: DataProto = DataProto.from_single_dict(batch_dict)
num_gen_batches += 1
# pop those keys for generation
if 'multi_modal_inputs' in new_batch.non_tensor_batch.keys():
if "multi_modal_inputs" in new_batch.non_tensor_batch.keys():
gen_batch = new_batch.pop(
batch_keys=['input_ids', 'attention_mask', 'position_ids'],
non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs'],
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
)
else:
gen_batch = new_batch.pop(
batch_keys=['input_ids', 'attention_mask', 'position_ids'],
non_tensor_batch_keys=['raw_prompt_ids'],
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids"],
)
is_last_step = self.global_steps >= self.total_training_steps
with _timer('step', timing_raw):
with _timer("step", timing_raw):
# generate a batch
with _timer('gen', timing_raw):
with _timer("gen", timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with _timer('gen_max', timing_raw):
with _timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info['do_sample'] = False
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
new_batch = new_batch.union(gen_baseline_output)
@ -111,17 +119,18 @@ class RayDAPOTrainer(RayPPOTrainer):
new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
new_batch.batch['reward_baselines'] = reward_baseline_tensor
new_batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
new_batch.non_tensor_batch['uid'] = np.array(
[str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object)
new_batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
new_batch = new_batch.union(gen_batch_output)
with _timer('reward', timing_raw):
with _timer("reward", timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
@ -134,30 +143,31 @@ class RayDAPOTrainer(RayPPOTrainer):
reward_extra_infos_dict: dict[str, list]
try:
reward_result = self.reward_fn(new_batch, return_dict=True)
reward_tensor = reward_result['reward_tensor']
reward_extra_infos_dict = reward_result['reward_extra_info']
reward_tensor = reward_result["reward_tensor"]
reward_extra_infos_dict = reward_result["reward_extra_info"]
except Exception as e:
print(f'Error in reward_fn: {e}')
print(f"Error in reward_fn: {e}")
reward_tensor = self.reward_fn(new_batch)
reward_extra_infos_dict = {}
new_batch.batch['token_level_scores'] = reward_tensor
new_batch.batch["token_level_scores"] = reward_tensor
print(f'{list(reward_extra_infos_dict.keys())=}')
print(f"{list(reward_extra_infos_dict.keys())=}")
if reward_extra_infos_dict:
new_batch.non_tensor_batch.update({
k: np.array(v) for k, v in reward_extra_infos_dict.items()
})
new_batch.non_tensor_batch.update(
{k: np.array(v) for k, v in reward_extra_infos_dict.items()}
)
# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
new_batch, kl_metrics = apply_kl_penalty(new_batch,
kl_ctrl=self.kl_ctrl_in_reward,
kl_penalty=self.config.algorithm.kl_penalty)
new_batch, kl_metrics = apply_kl_penalty(
new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(
kl_metrics) # TODO: This will be cleared if we use multiple generation batches
kl_metrics
) # TODO: This will be cleared if we use multiple generation batches
else:
new_batch.batch['token_level_rewards'] = new_batch.batch['token_level_scores']
new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]
if not self.config.algorithm.filter_groups.enable:
batch = new_batch
@ -165,16 +175,19 @@ class RayDAPOTrainer(RayPPOTrainer):
metric_name = self.config.algorithm.filter_groups.metric
if metric_name == "seq_final_reward":
# Turn to numpy for easier filtering
new_batch.non_tensor_batch["seq_final_reward"] = new_batch.batch['token_level_rewards'].sum(
dim=-1).numpy()
new_batch.non_tensor_batch["seq_final_reward"] = (
new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
)
elif metric_name == "seq_reward":
new_batch.non_tensor_batch["seq_reward"] = new_batch.batch['token_level_scores'].sum(
dim=-1).numpy()
new_batch.non_tensor_batch["seq_reward"] = (
new_batch.batch["token_level_scores"].sum(dim=-1).numpy()
)
# Collect the sequence reward for each trajectory
prompt_uid2metric_vals = defaultdict(list)
for uid, metric_val in zip(new_batch.non_tensor_batch['uid'],
new_batch.non_tensor_batch[metric_name]):
for uid, metric_val in zip(
new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name]
):
prompt_uid2metric_vals[uid].append(metric_val)
prompt_uid2metric_std = {}
@ -182,13 +195,14 @@ class RayDAPOTrainer(RayPPOTrainer):
prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)
kept_prompt_uids = [
uid for uid, std in prompt_uid2metric_std.items()
uid
for uid, std in prompt_uid2metric_std.items()
if std > 0 or len(prompt_uid2metric_vals[uid]) == 1
]
num_prompt_in_batch += len(kept_prompt_uids)
kept_traj_idxs = []
for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch['uid']):
for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
if traj_from_prompt_uid in kept_prompt_uids:
kept_traj_idxs.append(idx)
@ -200,14 +214,14 @@ class RayDAPOTrainer(RayPPOTrainer):
prompt_bsz = self.config.data.train_batch_size
if num_prompt_in_batch < prompt_bsz:
print(f'{num_prompt_in_batch=} < {prompt_bsz=}')
print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
print(f'{num_gen_batches=}. Keep generating...')
print(f"{num_gen_batches=}. Keep generating...")
continue
else:
raise ValueError(
f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.'
f"{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data."
)
else:
# Align the batch
@ -221,60 +235,66 @@ class RayDAPOTrainer(RayPPOTrainer):
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# recompute old_log_probs
with _timer('old_log_prob', timing_raw):
with _timer("old_log_prob", timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
with _timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with _timer('values', timing_raw):
with _timer("values", timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with _timer('adv', timing_raw):
with _timer("adv", timing_raw):
# compute advantages, executed on the driver process
batch = compute_advantage(batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n)
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
)
# update critic
if self.use_critic:
with _timer('update_critic', timing_raw):
with _timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with _timer('update_actor', timing_raw):
with _timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# validate
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
(is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
with _timer('testing', timing_raw):
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
):
with _timer("testing", timing_raw):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and (is_last_step or
self.global_steps % self.config.trainer.save_freq == 0):
with _timer('save_checkpoint', timing_raw):
if self.config.trainer.save_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.save_freq == 0
):
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# collect metrics
@ -294,7 +314,7 @@ class RayDAPOTrainer(RayPPOTrainer):
logger.log(data=metrics, step=self.global_steps)
if is_last_step:
pprint(f'Final validation metrics: {last_val_metrics}')
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return
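
The `filter_groups` branch above keeps only prompts whose rollouts disagree on the chosen metric (std > 0, or a single-sample group), since a group of identical rewards carries no group-relative learning signal; a toy illustration of that filter (my own, not repo code):

```python
# Toy illustration of the dynamic-sampling filter above: prompt "b" has
# zero reward std across its rollouts, so all of its trajectories are dropped.
from collections import defaultdict

import numpy as np

uids    = ["a", "a", "b", "b"]   # 2 prompts x 2 rollouts
rewards = [1.0, 0.0, 1.0, 1.0]   # prompt "b" is degenerate (identical rewards)

vals = defaultdict(list)
for uid, r in zip(uids, rewards):
    vals[uid].append(r)

kept_uids = {u for u, v in vals.items() if np.std(v) > 0 or len(v) == 1}
kept_idxs = [i for i, u in enumerate(uids) if u in kept_uids]
print(kept_uids, kept_idxs)  # {'a'} [0, 1]
```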

View File

@ -14,15 +14,18 @@
"""
Note that we don't combine the main with ray_trainer, as ray_trainer is used by other mains.
"""
from .dapo_ray_trainer import RayDAPOTrainer
import os
import ray
import hydra
import ray
from .dapo_ray_trainer import RayDAPOTrainer
def get_custom_reward_fn(config):
import importlib.util, os
import importlib.util
import os
reward_fn_config = config.get("custom_reward_function") or {}
file_path = reward_fn_config.get("path")
@ -49,7 +52,7 @@ def get_custom_reward_fn(config):
return getattr(module, function_name)
@hydra.main(config_path='config', config_name='dapo_trainer', version_base=None)
@hydra.main(config_path="config", config_name="dapo_trainer", version_base=None)
def main(config):
run_ppo(config)
@ -57,16 +60,14 @@ def main(config):
def run_ppo(config) -> None:
# TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices
# isolation, will solve in the future
os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get('CUDA_VISIBLE_DEVICES', '')
os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "")
if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={
'env_vars': {
'TOKENIZERS_PARALLELISM': 'true',
'NCCL_DEBUG': 'WARN',
'VLLM_LOGGING_LEVEL': 'WARN'
ray.init(
runtime_env={
"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
}
})
)
runner = TaskRunner.remote()
ray.get(runner.run.remote(config))
@ -74,12 +75,14 @@ def run_ppo(config) -> None:
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
class TaskRunner:
def run(self, config):
from verl.utils.fs import copy_to_local
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
from verl.utils.fs import copy_to_local
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
@ -87,21 +90,24 @@ class TaskRunner:
local_path = copy_to_local(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_tokenizer, hf_processor
from verl.utils import hf_processor, hf_tokenizer
tokenizer = hf_tokenizer(local_path)
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none
# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
if config.actor_rollout_ref.actor.strategy == "fsdp":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron':
elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
@ -112,10 +118,10 @@ class TaskRunner:
role_worker_mapping = {
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
Role.Critic: ray.remote(CriticWorker),
Role.RefPolicy: ray.remote(ActorRolloutRefWorker)
Role.RefPolicy: ray.remote(ActorRolloutRefWorker),
}
global_pool_id = 'global_pool'
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
@ -132,9 +138,9 @@ class TaskRunner:
# - finally, we combine all the rewards together
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy == 'fsdp':
if config.reward_model.strategy == "fsdp":
from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == 'megatron':
elif config.reward_model.strategy == "megatron":
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
@ -147,47 +153,55 @@ class TaskRunner:
mapping[Role.RefPolicy] = global_pool_id
reward_manager_name = config.reward_model.get("reward_manager", "naive")
if reward_manager_name == 'naive':
if reward_manager_name == "naive":
from verl.workers.reward_manager import NaiveRewardManager
reward_manager_cls = NaiveRewardManager
elif reward_manager_name == 'prime':
elif reward_manager_name == "prime":
from verl.workers.reward_manager import PrimeRewardManager
reward_manager_cls = PrimeRewardManager
elif reward_manager_name == 'dapo':
elif reward_manager_name == "dapo":
from verl.workers.reward_manager import DAPORewardManager
reward_manager_cls = DAPORewardManager
else:
raise NotImplementedError
compute_score = get_custom_reward_fn(config)
reward_fn = reward_manager_cls(tokenizer=tokenizer,
num_examine=0,
compute_score=compute_score,
reward_fn_key=config.data.reward_fn_key,
max_resp_len=config.data.max_response_length,
overlong_buffer_cfg=config.reward_model.overlong_buffer)
reward_fn = reward_manager_cls(
tokenizer=tokenizer,
num_examine=0,
compute_score=compute_score,
reward_fn_key=config.data.reward_fn_key,
max_resp_len=config.data.max_response_length,
overlong_buffer_cfg=config.reward_model.overlong_buffer,
)
# Note that we always use function-based RM for validation
val_reward_fn = reward_manager_cls(tokenizer=tokenizer,
num_examine=1,
compute_score=compute_score,
reward_fn_key=config.data.reward_fn_key,
max_resp_len=config.data.max_response_length,
overlong_buffer_cfg=config.reward_model.overlong_buffer)
val_reward_fn = reward_manager_cls(
tokenizer=tokenizer,
num_examine=1,
compute_score=compute_score,
reward_fn_key=config.data.reward_fn_key,
max_resp_len=config.data.max_response_length,
overlong_buffer_cfg=config.reward_model.overlong_buffer,
)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayDAPOTrainer(config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn)
trainer = RayDAPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
if __name__ == '__main__':
if __name__ == "__main__":
main()
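
`get_custom_reward_fn` above dynamically imports a user module from `custom_reward_function.path` and returns one of its functions by name, which is then handed to the reward manager as `compute_score`. A hypothetical user file might look like the sketch below; the argument list is an assumption for illustration only — the real signature is whatever the selected reward manager calls `compute_score` with:

```python
# my_reward.py -- hypothetical module pointed to by custom_reward_function.path
# (the function name is configured alongside it). The signature is assumed.
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """Toy rule-based score: 1.0 on exact string match with the ground truth."""
    return float(solution_str.strip() == str(ground_truth).strip())
```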

View File

@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.

View File

@ -28,13 +28,14 @@
"""
Note that we don't combine the main with ray_trainer, as ray_trainer is used by other mains.
"""
import hydra
import ray
from .prime_ray_trainer import RayPRIMETrainer
import ray
import hydra
@hydra.main(config_path='config', config_name='prime_trainer', version_base=None)
@hydra.main(config_path="config", config_name="prime_trainer", version_base=None)
def main(config):
run_prime(config)
@ -43,10 +44,7 @@ def run_prime(config, compute_score=None):
if not ray.is_initialized():
# this is for local ray cluster
ray.init(
runtime_env={'env_vars': {
'TOKENIZERS_PARALLELISM': 'true',
'NCCL_DEBUG': 'WARN'
}},
runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}},
)
ray.get(main_task.remote(config, compute_score))
@ -54,10 +52,13 @@ def run_prime(config, compute_score=None):
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
def main_task(config, compute_score=None):
from verl.utils.fs import copy_local_path_from_hdfs
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
from verl.utils.fs import copy_local_path_from_hdfs
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
@ -66,19 +67,22 @@ def main_task(config, compute_score=None):
# instantiate tokenizer
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer(local_path)
# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
if config.actor_rollout_ref.actor.strategy == "fsdp":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron':
elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.workers.megatron_workers import ActorRolloutRefWorker
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
@ -90,7 +94,7 @@ def main_task(config, compute_score=None):
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
}
global_pool_id = 'global_pool'
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
@ -98,22 +102,25 @@ def main_task(config, compute_score=None):
Role.ActorRollout: global_pool_id,
}
#use reference model
# use reference model
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id
if config.reward_model.enable:
from .prime_fsdp_workers import PRIMERewardModelWorker
role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker)
mapping[Role.RewardModel] = global_pool_id
reward_manager_name = config.reward_model.get("reward_manager", "naive")
if reward_manager_name == 'naive':
if reward_manager_name == "naive":
from verl.workers.reward_manager import NaiveRewardManager
reward_manager_cls = NaiveRewardManager
elif reward_manager_name == 'prime':
elif reward_manager_name == "prime":
from verl.workers.reward_manager import PrimeRewardManager
reward_manager_cls = PrimeRewardManager
else:
raise NotImplementedError
@ -124,16 +131,18 @@ def main_task(config, compute_score=None):
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayPRIMETrainer(config=config,
tokenizer=tokenizer,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn)
trainer = RayPRIMETrainer(
config=config,
tokenizer=tokenizer,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -13,6 +13,7 @@
# limitations under the License.
import torch
import verl
import verl.utils.torch_functional as verl_F
@ -23,45 +24,48 @@ def compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Ten
reward_tensor = reward_tensor_original.clone()
reward_tensor[~mask_tensor] = 0
for start_pos in range(0, reward_tensor.shape[0], n_samples):
cur_rewards_mean = torch.cat([
reward_tensor[pos:pos + 1][mask_tensor[pos:pos + 1]].mean(dim=0, keepdim=True)
for pos in range(start_pos, start_pos + n_samples)
],
dim=0)
cur_rewards_mean = torch.cat(
[
reward_tensor[pos : pos + 1][mask_tensor[pos : pos + 1]].mean(dim=0, keepdim=True)
for pos in range(start_pos, start_pos + n_samples)
],
dim=0,
)
cur_rewards_sum = cur_rewards_mean.sum()
cur_reward_baseline = cur_rewards_sum / (n_samples - 1)
reward_tensor[start_pos:start_pos + n_samples][
mask_tensor[start_pos:start_pos + n_samples]] = \
reward_tensor[start_pos:start_pos + n_samples][
mask_tensor[start_pos:start_pos + n_samples]] * (
n_samples / (n_samples - 1)) - cur_reward_baseline
reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] = (
reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]]
* (n_samples / (n_samples - 1))
- cur_reward_baseline
)
return reward_tensor
reward_tensors = []
with torch.no_grad():
if 'rm_scores' in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.:
reward_tensor = data.batch['rm_scores']
if "rm_scores" in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.0:
reward_tensor = data.batch["rm_scores"]
reward_mask = response_mask.bool()
reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef)
if 'acc' in data.batch.keys() and config.algorithm.reward_gt_coef != 0.:
if "acc" in data.batch.keys() and config.algorithm.reward_gt_coef != 0.0:
reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32)
reward_mask = torch.zeros_like(response_mask, dtype=torch.bool)
prompt_ids = data.batch['prompts']
prompt_ids = data.batch["prompts"]
prompt_length = prompt_ids.shape[-1]
valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(-1)
valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(-1)
reward_mask[
torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
valid_response_length - 1] = True
valid_response_length - 1,
] = True
reward_tensor[
torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
valid_response_length - 1] = data.batch['acc']
valid_response_length - 1,
] = data.batch["acc"]
reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_gt_coef)
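
`masked_rloo` above rescales each reward by `n/(n-1)` and subtracts `sum(rewards)/(n-1)`, which is algebraically the leave-one-out baseline: each sample's reward minus the mean of the other `n-1` samples in its group. A quick numeric check (my own illustration, not repo code):

```python
# Leave-one-out check: r_i * n/(n-1) - sum(r)/(n-1) == r_i - mean of the others.
import torch

rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])  # one prompt, n_samples = 4
n = rewards.numel()
baseline = rewards.sum() / (n - 1)
adv = rewards * (n / (n - 1)) - baseline

for i in range(n):
    others_mean = (rewards.sum() - rewards[i]) / (n - 1)
    assert torch.isclose(adv[i], rewards[i] - others_mean)

print(adv)  # tensor([ 0.6667, -0.6667, -0.6667,  0.6667])
```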
@ -81,7 +85,7 @@ def compute_ce_dpo_loss_rm(token_level_scores, acc, response_mask, beta):
return cur_dpo_loss
def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode='none'):
def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode="none"):
# we always assume that the BoN size equals n_samples
# mode1: use acc as rm
# mode2: use Q as rm
@ -97,15 +101,15 @@ def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_m
else:
other_Q[i] = 0
dpo_loss = -torch.log(torch.sigmoid((cur_Q - other_Q) * ((acc > 0).float() * 2 - 1)))
if bon_mode == 'none':
if bon_mode == "none":
dpo_loss = dpo_loss.mean()
else:
weight = torch.zeros_like(dpo_loss)
n_samples = acc_bc.shape[1]
if bon_mode == 'bon_rm':
if bon_mode == "bon_rm":
for i in range(token_level_scores.shape[0]):
weight[i] = n_samples * torch.pow((Q_bc[i] * beta <= cur_Q[i]).float().mean(), n_samples - 1)
elif bon_mode == 'bon_acc':
elif bon_mode == "bon_acc":
for i in range(token_level_scores.shape[0]):
weight[i] = n_samples * torch.pow((acc_bc[i] <= acc[i]).float().mean(), n_samples - 1)
else:
@ -118,22 +122,24 @@ def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_m
def compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples):
dpo_acc = []
for start_id in range(0, token_level_scores.shape[0], n_samples):
cur_scores = (token_level_scores[start_id:start_id + n_samples] *
response_mask[start_id:start_id + n_samples]).sum(dim=1)
cur_scores = (
token_level_scores[start_id : start_id + n_samples] * response_mask[start_id : start_id + n_samples]
).sum(dim=1)
def get_upper_triangle(tensor_x):
diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0)
upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1)
return diff_matrix[upper_tri_indices]
cur_acc_diff = get_upper_triangle(acc[start_id:start_id + n_samples]) # in range [-1,1]
cur_acc_diff = get_upper_triangle(acc[start_id : start_id + n_samples]) # in range [-1,1]
cur_score_diff = get_upper_triangle(cur_scores) # in R
cur_score_prediction = (cur_score_diff > 0).float() # in [0,1]
if cur_acc_diff.abs().sum() == 0:
cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5
else:
cur_acc = (((cur_score_diff > 0) == (cur_acc_diff > 0)).float() *
cur_acc_diff.abs()).sum() / cur_acc_diff.abs().sum()
cur_acc = (
((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * cur_acc_diff.abs()
).sum() / cur_acc_diff.abs().sum()
dpo_acc.append(cur_acc.unsqueeze(0))

View File

@ -14,62 +14,56 @@
"""
Implement a multiprocess PPOCritic
"""
import itertools
from typing import Iterable
import torch
import torch.distributed
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm
from verl import DataProto
from verl.trainer.ppo import core_algos
from verl.workers.critic import BasePPOCritic
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import masked_mean
from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis
from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm
__all__ = ['DataParallelPRIMERewardModel']
__all__ = ["DataParallelPRIMERewardModel"]
class DataParallelPRIMERewardModel:
def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, reward_optimizer: optim.Optimizer):
self.config = config
self.reward_module = reward_module
self.ref_module = ref_module
self.reward_optimizer = reward_optimizer
self.use_remove_padding = self.config.model.get('use_remove_padding', False)
print(f'Reward model use_remove_padding={self.use_remove_padding}')
self.use_remove_padding = self.config.model.get("use_remove_padding", False)
print(f"Reward model use_remove_padding={self.use_remove_padding}")
self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1)
self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
def _forward_micro_batch(self, micro_batch, prompt_length):
from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis, rearrange
from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
input_ids = micro_batch['input_ids']
input_ids = micro_batch["input_ids"]
batch_size, seqlen = input_ids.shape
attention_mask = micro_batch['attention_mask']
position_ids = micro_batch['position_ids']
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]
num_actions = micro_batch['input_ids'].shape[-1] - prompt_length
max_positions = micro_batch['attention_mask'][:, prompt_length:].sum(-1)
num_actions = micro_batch["input_ids"].shape[-1] - prompt_length
max_positions = micro_batch["attention_mask"][:, prompt_length:].sum(-1)
if self.use_remove_padding:
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
# for computing the log_prob
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz)
@ -77,90 +71,93 @@ class DataParallelPRIMERewardModel:
# pad and slice the inputs if sp > 1
if self.ulysses_sequence_parallel_size > 1:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size)
input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None,
self.ulysses_sequence_parallel_size)
input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size
)
input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size
)
input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)
rm_output_logits = self.reward_module(input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
use_cache=False).logits.squeeze(
0) # copied. I don't really know why there is a squeeze
rm_output_logits = self.reward_module(
input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False
).logits.squeeze(0) # copied. I don't really know why there is a squeeze
rm_log_labels = verl_F.logprobs_from_logits(logits=rm_output_logits, labels=input_ids_rmpad_rolled)
if self.ulysses_sequence_parallel_size > 1:
rm_log_labels = gather_outpus_and_unpad(rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size)
rm_log_labels = pad_input(hidden_states=rm_log_labels.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen).squeeze(-1)[:, -num_actions - 1:-1]
rm_log_labels = pad_input(
hidden_states=rm_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen
).squeeze(-1)[:, -num_actions - 1 : -1]
else:
rm_output_logits = self.reward_module(input_ids=micro_batch['input_ids'],
attention_mask=micro_batch['attention_mask'],
position_ids=micro_batch['position_ids'],
use_cache=False).logits
rm_log_prob = torch.nn.functional.log_softmax(rm_output_logits[:, :-1, :],
dim=-1) # (batch_size, seq_length, vocab_size)
rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch['input_ids'][:, 1:].unsqueeze(-1)).squeeze(
-1) # (batch, seq_length)
rm_output_logits = self.reward_module(
input_ids=micro_batch["input_ids"],
attention_mask=micro_batch["attention_mask"],
position_ids=micro_batch["position_ids"],
use_cache=False,
).logits
rm_log_prob = torch.nn.functional.log_softmax(
rm_output_logits[:, :-1, :], dim=-1
) # (batch_size, seq_length, vocab_size)
rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1)).squeeze(
-1
) # (batch, seq_length)
if self.ref_module is not None:
# do not have to pad again
with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
if self.ulysses_sequence_parallel_size > 1 and self.use_remove_padding:
ref_output_logits = self.ref_module(input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
use_cache=False).logits.squeeze(0)
ref_log_labels = verl_F.logprobs_from_logits(logits=ref_output_logits,
labels=input_ids_rmpad_rolled)
ref_log_labels = gather_outpus_and_unpad(ref_log_labels,
gather_dim=0,
unpad_dim=0,
padding_size=pad_size)
ref_log_labels = pad_input(hidden_states=ref_log_labels.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen).squeeze(-1)[:, -num_actions - 1:-1]
ref_output_logits = self.ref_module(
input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False
).logits.squeeze(0)
ref_log_labels = verl_F.logprobs_from_logits(
logits=ref_output_logits, labels=input_ids_rmpad_rolled
)
ref_log_labels = gather_outpus_and_unpad(
ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size
)
ref_log_labels = pad_input(
hidden_states=ref_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen
).squeeze(-1)[:, -num_actions - 1 : -1]
else:
ref_output_logits = self.ref_module(input_ids=micro_batch['input_ids'],
attention_mask=micro_batch['attention_mask'],
position_ids=micro_batch['position_ids'],
use_cache=False).logits
ref_log_prob = torch.nn.functional.log_softmax(ref_output_logits[:, :-1, :],
dim=-1) # (batch_size, seq_length, vocab_size)
ref_log_labels = ref_log_prob.gather(dim=-1,
index=micro_batch['input_ids'][:, 1:].unsqueeze(-1)).squeeze(
-1) # (batch, seq_length)
ref_output_logits = self.ref_module(
input_ids=micro_batch["input_ids"],
attention_mask=micro_batch["attention_mask"],
position_ids=micro_batch["position_ids"],
use_cache=False,
).logits
ref_log_prob = torch.nn.functional.log_softmax(
ref_output_logits[:, :-1, :], dim=-1
) # (batch_size, seq_length, vocab_size)
ref_log_labels = ref_log_prob.gather(
dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1)
).squeeze(-1) # (batch, seq_length)
else:
ref_log_labels = micro_batch['old_log_probs']
ref_log_labels = micro_batch["old_log_probs"]
ref_log_labels.to(rm_log_labels.dtype)
q = rm_log_labels[:, -num_actions:] - ref_log_labels[:, -num_actions:] # this is actually diff of q
# trim unnecessary logprobs here
for i in range(micro_batch['input_ids'].shape[0]):
q[i, max_positions[i]:] = 0
for i in range(micro_batch["input_ids"].shape[0]):
q[i, max_positions[i] :] = 0
# reward computation does not need gradient. only q needs
with torch.no_grad():
# generalized estimation of r should go before the reward filling. r means process reward for policy model, or the advantage of reward model.
lam = self.config.get('lambda', 0.)
beta = self.config.model.get('beta_train', 0.05)
if lam == 0.:
lam = self.config.get("lambda", 0.0)
beta = self.config.model.get("beta_train", 0.05)
if lam == 0.0:
r = q * beta
else:
# reward coefficient takes no effect here
acc = micro_batch['acc']
acc = micro_batch["acc"]
q_ = q * beta
r = torch.zeros_like(q)
lastgaelam = 0
# change the last token and mask out all paddings to make this process easier if we rely on outcome reward to calculate V
for i in range(q.shape[0]):
if self.config.prime_use_gt:
q_[i, max_positions[i] - 1] = acc[i] - q_[i, :max_positions[i] - 1].sum()
q_[i, max_positions[i]:] = 0
q_[i, max_positions[i] - 1] = acc[i] - q_[i, : max_positions[i] - 1].sum()
q_[i, max_positions[i] :] = 0
for t in reversed(range(num_actions)):
delta = q_[:, t]
@ -169,12 +166,12 @@ class DataParallelPRIMERewardModel:
token_level_score = torch.zeros_like(q)
if self.config.prime_granularity == 'token':
for i in range(micro_batch['input_ids'].shape[0]):
token_level_score[i, :max_positions[i] - 1] = r[i, :max_positions[i] - 1]
elif self.config.prime_granularity == 'whole':
for i in range(micro_batch['input_ids'].shape[0]):
token_level_score[i, max_positions[i] - 1] = r[i, :max_positions[i]]
if self.config.prime_granularity == "token":
for i in range(micro_batch["input_ids"].shape[0]):
token_level_score[i, : max_positions[i] - 1] = r[i, : max_positions[i] - 1]
elif self.config.prime_granularity == "whole":
for i in range(micro_batch["input_ids"].shape[0]):
token_level_score[i, max_positions[i] - 1] = r[i, : max_positions[i]]
else:
raise NotImplementedError
@ -186,13 +183,14 @@ class DataParallelPRIMERewardModel:
if isinstance(self.reward_module, FSDP):
grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(self.reward_module.parameters(),
max_norm=self.config.model.optim.grad_clip)
grad_norm = torch.nn.utils.clip_grad_norm_(
self.reward_module.parameters(), max_norm=self.config.model.optim.grad_clip
)
self.reward_optimizer.step()
return grad_norm
def prime_norm(self, token_level_scores):
if self.config.prime_norm == 'batch_norm':
if self.config.prime_norm == "batch_norm":
reverse_cumsum = torch.cumsum(token_level_scores.flip(dims=[1]), dim=-1).flip(dims=[1])
token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6)
return token_level_scores
@ -200,15 +198,15 @@ class DataParallelPRIMERewardModel:
def compute_rm_score(self, data: DataProto):
self.reward_module.eval()
self.ref_module.eval()
micro_batch_size = data.meta_info['micro_batch_size']
select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'acc']
micro_batch_size = data.meta_info["micro_batch_size"]
select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "acc"]
batch = data.select(batch_keys=select_keys).batch
use_dynamic_bsz = data.meta_info['use_dynamic_bsz']
prompt_length = data.batch['input_ids'].shape[-1] - data.batch['responses'].shape[-1]
use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
prompt_length = data.batch["input_ids"].shape[-1] - data.batch["responses"].shape[-1]
if use_dynamic_bsz:
# split using dynamic bsz
max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size
max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
else:
micro_batches = batch.split(micro_batch_size)
@ -231,21 +229,25 @@ class DataParallelPRIMERewardModel:
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
rm_scores = rm_scores[revert_indices]
return rm_scores, q.detach(), {
'reward_model/reward': rm_scores.sum(dim=-1).mean().item(),
'reward_model/raw_reward': q.sum(dim=-1).mean().item()
}
return (
rm_scores,
q.detach(),
{
"reward_model/reward": rm_scores.sum(dim=-1).mean().item(),
"reward_model/raw_reward": q.sum(dim=-1).mean().item(),
},
)
def update_rm(self, data: DataProto):
# make sure we are in training mode
self.reward_module.train()
metrics = {}
beta = self.config.model.get('beta_train', 0.05)
beta = self.config.model.get("beta_train", 0.05)
select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'acc', 'prompts']
select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "acc", "prompts"]
for key in ['Q_bc', 'acc_bc']:
for key in ["Q_bc", "acc_bc"]:
if key in data.batch.keys():
select_keys.append(key)
@ -271,10 +273,10 @@ class DataParallelPRIMERewardModel:
for data in micro_batches:
data = data.cuda()
attention_mask = data['attention_mask']
acc = data['acc']
attention_mask = data["attention_mask"]
acc = data["acc"]
prompt_ids = data['prompts']
prompt_ids = data["prompts"]
prompt_length = prompt_ids.shape[-1]
response_mask = attention_mask[:, prompt_length:]
@ -284,37 +286,38 @@ class DataParallelPRIMERewardModel:
rm_scores_lst.append(rm_score)
q_lst.append(q.detach())
if self.config.model.loss_type == 'ce':
if self.config.model.loss_type == "ce":
dpo_loss = compute_ce_dpo_loss_rm(q, acc, response_mask=response_mask, beta=beta)
elif self.config.model.loss_type == 'dpo':
elif self.config.model.loss_type == "dpo":
# the implementation of dpo is actually detached, which means we have to know the average value of w/l reward before the update.
dpo_loss = compute_detach_dpo_loss_rm(q,
acc,
Q_bc=data['Q_bc'],
acc_bc=data['acc_bc'],
response_mask=response_mask,
beta=beta)
elif self.config.model.loss_type == 'bon_acc':
dpo_loss = compute_detach_dpo_loss_rm(
q, acc, Q_bc=data["Q_bc"], acc_bc=data["acc_bc"], response_mask=response_mask, beta=beta
)
elif self.config.model.loss_type == "bon_acc":
# change the original distribution of each sample to BoN distribution, then update reward model
dpo_loss = compute_detach_dpo_loss_rm(q,
acc,
Q_bc=data['Q_bc'],
acc_bc=data['acc_bc'],
response_mask=response_mask,
beta=beta,
bon_mode='bon_acc')
elif self.config.model.loss_type == 'bon_rm':
dpo_loss = compute_detach_dpo_loss_rm(q,
acc,
Q_bc=data['Q_bc'],
acc_bc=data['acc_bc'],
response_mask=response_mask,
beta=beta,
bon_mode='bon_rm')
dpo_loss = compute_detach_dpo_loss_rm(
q,
acc,
Q_bc=data["Q_bc"],
acc_bc=data["acc_bc"],
response_mask=response_mask,
beta=beta,
bon_mode="bon_acc",
)
elif self.config.model.loss_type == "bon_rm":
dpo_loss = compute_detach_dpo_loss_rm(
q,
acc,
Q_bc=data["Q_bc"],
acc_bc=data["acc_bc"],
response_mask=response_mask,
beta=beta,
bon_mode="bon_rm",
)
else:
raise NotImplementedError
data = {'reward_model/dpo_loss': dpo_loss.detach().item()}
data = {"reward_model/dpo_loss": dpo_loss.detach().item()}
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
@ -327,7 +330,7 @@ class DataParallelPRIMERewardModel:
append_to_dict(metrics, data)
grad_norm = self._optimizer_step()
data = {'reward_model/grad_norm': grad_norm.detach().item()}
data = {"reward_model/grad_norm": grad_norm.detach().item()}
append_to_dict(metrics, data)
self.reward_optimizer.zero_grad()
@ -336,9 +339,11 @@ class DataParallelPRIMERewardModel:
rm_scores = self.prime_norm(rm_scores)
metrics.update({
'reward_model/reward': rm_scores.sum(dim=-1).mean().item(),
'reward_model/raw_reward': q.sum(dim=-1).mean().item()
})
metrics.update(
{
"reward_model/reward": rm_scores.sum(dim=-1).mean().item(),
"reward_model/raw_reward": q.sum(dim=-1).mean().item(),
}
)
return rm_scores, metrics
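# Editor's sketch (illustrative only, not this repo's compute_ce_dpo_loss_rm /
# compute_detach_dpo_loss_rm): the "detached" comment above means the per-prompt baselines
# Q_bc / acc_bc come from an earlier forward pass and carry no gradient, so only the current
# implicit reward q is updated. Assumed shapes are noted inline.
import torch
import torch.nn.functional as F

def detached_dpo_loss_sketch(q, acc, Q_bc, acc_bc, response_mask, beta=0.05):
    # q: (bsz, resp_len) per-token implicit rewards; acc: (bsz,) 0/1 correctness labels
    # Q_bc: (bsz, n) detached sequence rewards of the n rollouts sharing the prompt
    # acc_bc: (bsz, n) accuracy labels of those rollouts
    seq_r = beta * (q * response_mask).sum(dim=-1, keepdim=True)       # (bsz, 1), keeps gradient
    sign = torch.sign(acc.unsqueeze(-1) - acc_bc)                      # +1 where the current sample is better
    pairwise = -F.logsigmoid(sign * (seq_r - Q_bc.detach()))           # DPO-style preference term
    weight = sign.abs()                                                # ignore ties
    return (pairwise * weight).sum() / weight.sum().clamp(min=1.0)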

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import os
import warnings
@ -19,54 +18,56 @@ import warnings
import torch
import torch.distributed
from torch.distributed.device_mesh import init_device_mesh
import verl.utils.torch_functional as verl_F
from omegaconf import DictConfig, open_dict
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import hf_tokenizer
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_model_to_cpu, load_fsdp_optimizer, \
load_fsdp_model_to_gpu
from verl.utils.import_utils import import_external_libs
from verl.utils.model import compute_position_id_with_mask
from verl.utils.flops_counter import FlopsCounter
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fsdp_utils import (
get_fsdp_wrap_policy,
get_init_weight_context_manager,
init_fn,
load_fsdp_model_to_gpu,
load_fsdp_optimizer,
offload_fsdp_model_to_cpu,
offload_fsdp_optimizer,
)
from verl.utils.import_utils import import_external_libs
from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
from codetiming import Timer
from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy
from .prime_core_algos import compute_dpo_accuracy, compute_dpo_abs_accuracy
from .prime_core_algos import compute_dpo_abs_accuracy, compute_dpo_accuracy
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))
class PRIMERewardModelWorker(Worker):
def __init__(self, config):
super().__init__()
import torch.distributed
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl")
self.config = config
# build device mesh for Ulysses Sequence Parallel
world_size = torch.distributed.get_world_size()
from torch.distributed.device_mesh import init_device_mesh
fsdp_size = self.config.model.fsdp_config.fsdp_size
self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1)
self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh('cuda',
mesh_shape=(dp, self.ulysses_sequence_parallel_size),
mesh_dim_names=['dp', 'sp'])
self.ulysses_device_mesh = init_device_mesh(
"cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
)
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
@ -75,40 +76,42 @@ class PRIMERewardModelWorker(Worker):
self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload
# normalize config
self.config.mini_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size)
self.config.mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
if self.config.micro_batch_size is not None:
self.config.micro_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size)
self.config.micro_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
self.config.micro_batch_size_per_gpu = self.config.micro_batch_size
assert self.config.mini_batch_size % self.config.micro_batch_size_per_gpu == 0
def _build_reward_ref_model_optimizer(self, config):
# the following line is necessary
from verl.utils.model import LambdaLayer, print_model_size, squeeze
from verl.utils.torch_dtypes import PrecisionType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision
from torch import optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from verl.utils.model import print_model_size
from verl.utils.torch_dtypes import PrecisionType
local_path = copy_local_path_from_hdfs(config.model.path)
tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path)
self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False))
self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False))
from omegaconf import OmegaConf
override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id,
"bos_token_id": self.tokenizer.bos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_config)
if self.rank == 0:
print(f'Reward model overriding config {override_config_kwargs}')
print(f"Reward model overriding config {override_config_kwargs}")
torch_dtype = self.config.model.fsdp_config.get('model_dtype', 'fp32')
torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
torch_dtype = PrecisionType.to_dtype(torch_dtype)
from transformers import AutoConfig, AutoModelForCausalLM
from torch import nn
trust_remote_code = False
reward_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
@ -117,34 +120,37 @@ class PRIMERewardModelWorker(Worker):
init_context = get_init_weight_context_manager(use_meta_tensor=not reward_model_config.tie_word_embeddings)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
setattr(reward_model_config, 'classifier_dropout', 0.)
setattr(reward_model_config, 'hidden_dropout', '0')
reward_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=reward_model_config,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
reward_model_config.classifier_dropout = 0.0
reward_model_config.hidden_dropout = "0"
reward_module = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=reward_model_config,
attn_implementation="flash_attention_2",
trust_remote_code=trust_remote_code,
)
if config.model.get('use_remove_padding', False) or self.ulysses_sequence_parallel_size > 1:
if config.model.get("use_remove_padding", False) or self.ulysses_sequence_parallel_size > 1:
from verl.models.transformers.monkey_patch import apply_monkey_patch
apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)
# some parameters may not be in torch_dtype
reward_module.to(torch_dtype)
if config.model.get('enable_gradient_checkpointing', False):
reward_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
if config.model.get("enable_gradient_checkpointing", False):
reward_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
if self.rank == 0:
print_model_size(reward_module)
self.reward_model_config = reward_model_config
fsdp_config = self.config.model.fsdp_config
mixed_precision_config = fsdp_config.get('mixed_precision', None)
mixed_precision_config = fsdp_config.get("mixed_precision", None)
if mixed_precision_config is not None:
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16'))
reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32'))
buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32'))
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
else:
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
@ -154,78 +160,89 @@ class PRIMERewardModelWorker(Worker):
auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config.wrap_policy)
log_gpu_memory_usage('Before reward model FSDP', logger=None)
log_gpu_memory_usage("Before reward model FSDP", logger=None)
fsdp_mesh = self.device_mesh
sharding_strategy = get_sharding_strategy(fsdp_mesh)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
setattr(reward_model_config, 'classifier_dropout', 0.)
setattr(reward_model_config, 'hidden_dropout', '0')
ref_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=copy_local_path_from_hdfs(
config.model.ref_path),
torch_dtype=torch_dtype,
config=reward_model_config,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
reward_model_config.classifier_dropout = 0.0
reward_model_config.hidden_dropout = "0"
ref_module = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=copy_local_path_from_hdfs(config.model.ref_path),
torch_dtype=torch_dtype,
config=reward_model_config,
attn_implementation="flash_attention_2",
trust_remote_code=trust_remote_code,
)
# some parameters may not be in torch_dtype
ref_module.to(torch_dtype)
reward_module = FSDP(reward_module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
sync_module_states=True,
forward_prefetch=False,
device_mesh=self.device_mesh,
cpu_offload=None)
reward_module = FSDP(
reward_module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
sync_module_states=True,
forward_prefetch=False,
device_mesh=self.device_mesh,
cpu_offload=None,
)
log_gpu_memory_usage('After reward FSDP', logger=None)
log_gpu_memory_usage("After reward FSDP", logger=None)
ref_module = FSDP(ref_module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
sync_module_states=True,
forward_prefetch=False,
device_mesh=self.device_mesh,
cpu_offload=None)
ref_module = FSDP(
ref_module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
sync_module_states=True,
forward_prefetch=False,
device_mesh=self.device_mesh,
cpu_offload=None,
)
reward_optimizer = optim.AdamW(reward_module.parameters(),
lr=config.model.optim.lr,
betas=config.model.optim.get('betas', (0.9, 0.999)),
weight_decay=config.model.optim.get('weight_decay', 1e-2))
reward_optimizer = optim.AdamW(
reward_module.parameters(),
lr=config.model.optim.lr,
betas=config.model.optim.get("betas", (0.9, 0.999)),
weight_decay=config.model.optim.get("weight_decay", 1e-2),
)
total_steps = config.model.optim.get('total_training_steps', 0)
num_warmup_steps = int(config.model.optim.get('lr_warmup_steps', -1))
total_steps = config.model.optim.get("total_training_steps", 0)
num_warmup_steps = int(config.model.optim.get("lr_warmup_steps", -1))
if num_warmup_steps < 0:
num_warmup_steps_ratio = config.model.optim.get('lr_warmup_steps_ratio', 0.)
num_warmup_steps_ratio = config.model.optim.get("lr_warmup_steps_ratio", 0.0)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}')
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
from verl.utils.torch_functional import get_constant_schedule_with_warmup
reward_lr_scheduler = get_constant_schedule_with_warmup(optimizer=reward_optimizer,
num_warmup_steps=num_warmup_steps)
reward_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=reward_optimizer, num_warmup_steps=num_warmup_steps
)
return reward_module, ref_module, reward_optimizer, reward_lr_scheduler
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get('external_lib', None))
import_external_libs(self.config.model.get("external_lib", None))
from .prime_dp_rm import DataParallelPRIMERewardModel
self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = self._build_reward_ref_model_optimizer(
config=self.config)
self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = (
self._build_reward_ref_model_optimizer(config=self.config)
)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.reward_module)
@ -233,47 +250,51 @@ class PRIMERewardModelWorker(Worker):
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.reward_optimizer)
self.rm = DataParallelPRIMERewardModel(config=self.config,
reward_module=self.reward_module,
ref_module=self.ref_module,
reward_optimizer=self.reward_optimizer)
self.rm = DataParallelPRIMERewardModel(
config=self.config,
reward_module=self.reward_module,
ref_module=self.ref_module,
reward_optimizer=self.reward_optimizer,
)
self.flops_counter = FlopsCounter(self.reward_model_config)
self.checkpoint_manager = FSDPCheckpointManager(model=self.reward_module,
optimizer=self.reward_optimizer,
lr_scheduler=self.reward_lr_scheduler,
tokenizer=self.tokenizer)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.reward_module,
optimizer=self.reward_optimizer,
lr_scheduler=self.reward_lr_scheduler,
tokenizer=self.tokenizer,
)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_rm_score(self, data: DataProto):
data = data.to('cuda')
data = data.to("cuda")
if self._is_offload_param:
load_fsdp_model_to_gpu(self.reward_module)
load_fsdp_model_to_gpu(self.ref_module)
micro_batch_size = self.config.micro_batch_size_per_gpu
data.meta_info['micro_batch_size'] = micro_batch_size
data.meta_info['max_token_len'] = self.config.forward_max_token_len_per_gpu
data.meta_info['use_dynamic_bsz'] = self.config.use_dynamic_bsz
data.meta_info["micro_batch_size"] = micro_batch_size
data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz
# perform forward computation
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
rm_scores, q, metrics = self.rm.compute_rm_score(data=data)
prompt_length = data.batch['prompts'].shape[-1]
response_mask = data.batch['attention_mask'][:, prompt_length:]
acc = data.batch['acc']
prompt_length = data.batch["prompts"].shape[-1]
response_mask = data.batch["attention_mask"][:, prompt_length:]
acc = data.batch["acc"]
dpo_acc = compute_dpo_accuracy(rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info['n'])
dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info['n'])
dpo_acc = compute_dpo_accuracy(rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info["n"])
dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info["n"])
metrics['reward_model/dpo_acc'] = dpo_acc.detach().item()
metrics['reward_model/dpo_acc_abs'] = dpo_acc_abs.detach().item()
metrics["reward_model/dpo_acc"] = dpo_acc.detach().item()
metrics["reward_model/dpo_acc_abs"] = dpo_acc_abs.detach().item()
output = DataProto.from_dict(tensors={'rm_scores': rm_scores, 'q': q}, meta_info={'metrics': metrics})
output = DataProto.from_dict(tensors={"rm_scores": rm_scores, "q": q}, meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to('cpu')
output = output.to("cpu")
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.reward_module)
offload_fsdp_model_to_cpu(self.ref_module)
@ -281,7 +302,7 @@ class PRIMERewardModelWorker(Worker):
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_rm(self, data: DataProto):
data = data.to('cuda')
data = data.to("cuda")
if self._is_offload_param:
load_fsdp_model_to_gpu(self.ref_module)
load_fsdp_model_to_gpu(self.reward_module)
@ -296,22 +317,21 @@ class PRIMERewardModelWorker(Worker):
self.reward_lr_scheduler.step()
lr = self.reward_lr_scheduler.get_last_lr()[0]
metrics['rm/lr'] = lr
metrics["rm/lr"] = lr
prompt_length = data.batch['prompts'].shape[-1]
response_mask = data.batch['attention_mask'][:, prompt_length:]
acc = data.batch['acc']
prompt_length = data.batch["prompts"].shape[-1]
response_mask = data.batch["attention_mask"][:, prompt_length:]
acc = data.batch["acc"]
dpo_acc_before = compute_dpo_accuracy(rm_scores,
acc,
response_mask=response_mask,
n_samples=data.meta_info['n'])
dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info['n'])
dpo_acc_before = compute_dpo_accuracy(
rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info["n"]
)
dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info["n"])
metrics['reward_model/dpo_acc_before'] = dpo_acc_before.detach().item()
metrics['reward_model/dpo_acc_abs_before'] = dpo_acc_abs.detach().item()
metrics["reward_model/dpo_acc_before"] = dpo_acc_before.detach().item()
metrics["reward_model/dpo_acc_abs_before"] = dpo_acc_abs.detach().item()
output = DataProto.from_dict(tensors={'rm_scores': rm_scores}, meta_info={'metrics': metrics})
output = DataProto.from_dict(tensors={"rm_scores": rm_scores}, meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
if self._is_offload_param:
@ -319,19 +339,19 @@ class PRIMERewardModelWorker(Worker):
offload_fsdp_model_to_cpu(self.ref_module)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.reward_optimizer)
output = output.to('cpu')
output = output.to("cpu")
return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.reward_module)
self.checkpoint_manager.save_checkpoint(local_path=local_path,
hdfs_path=hdfs_path,
global_step=global_step,
max_ckpt_to_keep=max_ckpt_to_keep)
self.checkpoint_manager.save_checkpoint(
local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep
)
torch.distributed.barrier()
if self._is_offload_param:
@ -340,6 +360,7 @@ class PRIMERewardModelWorker(Worker):
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, del_local_after_load=True):
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.reward_module)

View File

@ -28,129 +28,115 @@ from omegaconf import OmegaConf, open_dict
from verl import DataProto
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.trainer.ppo.ray_trainer import Role, WorkerType, ResourcePoolManager, reduce_metrics, _timer
from verl.trainer.ppo.metric_utils import _compute_response_info
from verl.trainer.ppo.core_algos import agg_loss
from verl.utils.py_functional import append_to_dict
from verl.trainer.ppo.metric_utils import _compute_response_info
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType, _timer, reduce_metrics
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from . import prime_core_algos
def compute_advantage(data: DataProto, adv_estimator, config):
if adv_estimator == 'rloo':
responses = data.batch['responses']
if adv_estimator == "rloo":
responses = data.batch["responses"]
response_length = responses.size(-1)
attention_mask = data.batch['attention_mask']
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
advantages, returns = prime_core_algos.compute_rloo_advantage_return(data, response_mask,
config.actor_rollout_ref.rollout.n, config)
data.batch['advantages'] = advantages
data.batch['returns'] = returns
advantages, returns = prime_core_algos.compute_rloo_advantage_return(
data, response_mask, config.actor_rollout_ref.rollout.n, config
)
data.batch["advantages"] = advantages
data.batch["returns"] = returns
else:
raise NotImplementedError
return data
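# Editor's sketch of the generic RLOO idea (the actual prime_core_algos.compute_rloo_advantage_return
# is defined elsewhere and may differ): with n rollouts per prompt, each sample's baseline is the mean
# reward of its n - 1 siblings, so its advantage is the reward minus a leave-one-out mean.
import torch

def rloo_advantage_sketch(rewards, n_samples):
    # rewards: (bsz,) sequence-level rewards, laid out so every n_samples consecutive rows share a prompt
    grouped = rewards.view(-1, n_samples)                                  # (num_prompts, n)
    loo_mean = (grouped.sum(dim=-1, keepdim=True) - grouped) / (n_samples - 1)
    return (grouped - loo_mean).view(-1)                                   # same layout as the input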
def compute_data_metrics(batch, use_critic=True):
advantages = batch.batch["advantages"]
returns = batch.batch["returns"]
advantages = batch.batch['advantages']
returns = batch.batch['returns']
max_response_length = batch.batch["responses"].shape[-1]
max_response_length = batch.batch['responses'].shape[-1]
prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool()
response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool()
prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
max_prompt_length = prompt_mask.size(-1)
response_info = _compute_response_info(batch)
prompt_length = response_info['prompt_length']
response_length = response_info['response_length']
prompt_length = response_info["prompt_length"]
response_length = response_info["response_length"]
valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)
if use_critic:
values = batch.batch['values']
values = batch.batch["values"]
valid_values = torch.masked_select(values, response_mask)
return_diff_var = torch.var(valid_returns - valid_values)
return_var = torch.var(valid_returns)
metrics = {
# adv
'critic/advantages/mean':
torch.mean(valid_adv).detach().item(),
'critic/advantages/max':
torch.max(valid_adv).detach().item(),
'critic/advantages/min':
torch.min(valid_adv).detach().item(),
"critic/advantages/mean": torch.mean(valid_adv).detach().item(),
"critic/advantages/max": torch.max(valid_adv).detach().item(),
"critic/advantages/min": torch.min(valid_adv).detach().item(),
# returns
'critic/returns/mean':
torch.mean(valid_returns).detach().item(),
'critic/returns/max':
torch.max(valid_returns).detach().item(),
'critic/returns/min':
torch.min(valid_returns).detach().item(),
**({
# values
'critic/values/mean': torch.mean(valid_values).detach().item(),
'critic/values/max': torch.max(valid_values).detach().item(),
'critic/values/min': torch.min(valid_values).detach().item(),
# vf explained var
'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
} if use_critic else {}),
"critic/returns/mean": torch.mean(valid_returns).detach().item(),
"critic/returns/max": torch.max(valid_returns).detach().item(),
"critic/returns/min": torch.min(valid_returns).detach().item(),
**(
{
# values
"critic/values/mean": torch.mean(valid_values).detach().item(),
"critic/values/max": torch.max(valid_values).detach().item(),
"critic/values/min": torch.min(valid_values).detach().item(),
# vf explained var
"critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
}
if use_critic
else {}
),
# response length
'response_length/mean':
torch.mean(response_length).detach().item(),
'response_length/max':
torch.max(response_length).detach().item(),
'response_length/min':
torch.min(response_length).detach().item(),
'response_length/clip_ratio':
torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(),
"response_length/mean": torch.mean(response_length).detach().item(),
"response_length/max": torch.max(response_length).detach().item(),
"response_length/min": torch.min(response_length).detach().item(),
"response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
.detach()
.item(),
# prompt length
'prompt_length/mean':
torch.mean(prompt_length).detach().item(),
'prompt_length/max':
torch.max(prompt_length).detach().item(),
'prompt_length/min':
torch.min(prompt_length).detach().item(),
'prompt_length/clip_ratio':
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
"prompt_length/mean": torch.mean(prompt_length).detach().item(),
"prompt_length/max": torch.max(prompt_length).detach().item(),
"prompt_length/min": torch.min(prompt_length).detach().item(),
"prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
}
return metrics
def compute_response_mask(data: DataProto):
responses = data.batch['responses']
responses = data.batch["responses"]
response_length = responses.size(1)
attention_mask = data.batch['attention_mask']
attention_mask = data.batch["attention_mask"]
return attention_mask[:, -response_length:]
def compute_timing_metrics(batch, timing_raw):
response_info = _compute_response_info(batch)
num_prompt_tokens = torch.sum(response_info['prompt_length']).item()
num_response_tokens = torch.sum(response_info['response_length']).item()
num_prompt_tokens = torch.sum(response_info["prompt_length"]).item()
num_response_tokens = torch.sum(response_info["response_length"]).item()
num_overall_tokens = num_prompt_tokens + num_response_tokens
num_tokens_of_section = {
'gen': num_response_tokens,
**{
name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor']
},
"gen": num_response_tokens,
**{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]},
}
return {
**{f"timing_s/{name}": value for name, value in timing_raw.items()},
**{
f'timing_s/{name}': value for name, value in timing_raw.items()
},
**{
f'timing_per_token_ms/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys(
)) & set(timing_raw.keys())
f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name]
for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
},
}
@ -162,19 +148,27 @@ class RayPRIMETrainer(RayPPOTrainer):
# TODO: support each role have individual ray_worker_group_cls,
# i.e., support different backend of different role
def __init__(self,
config,
tokenizer,
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
reward_fn=None,
val_reward_fn=None):
def __init__(
self,
config,
tokenizer,
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
reward_fn=None,
val_reward_fn=None,
):
# assert torch.cuda.is_available(), 'cuda must be available on driver'
super().__init__(config, tokenizer, role_worker_mapping, resource_pool_manager, ray_worker_group_cls, reward_fn,
val_reward_fn)
super().__init__(
config,
tokenizer,
role_worker_mapping,
resource_pool_manager,
ray_worker_group_cls,
reward_fn,
val_reward_fn,
)
self.use_critic = False
@ -185,39 +179,43 @@ class RayPRIMETrainer(RayPPOTrainer):
def _create_dataloader(self):
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
# TODO: we have to make sure the batch size is divisible by the dp size
self.train_dataset = RLHFDataset(data_files=self.config.data.train_files,
tokenizer=self.tokenizer,
config=self.config.data)
self.train_dataset = RLHFDataset(
data_files=self.config.data.train_files, tokenizer=self.tokenizer, config=self.config.data
)
# use sampler for better ckpt resume
if self.config.data.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(self.config.data.get('seed', 1))
train_dataloader_generator.manual_seed(self.config.data.get("seed", 1))
sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=self.train_dataset)
self.train_dataloader = DataLoader(dataset=self.train_dataset,
batch_size=int(self.config.data.train_batch_size *
self.config.data.oversample_factor),
drop_last=True,
collate_fn=collate_fn,
sampler=sampler)
self.train_dataloader = DataLoader(
dataset=self.train_dataset,
batch_size=int(self.config.data.train_batch_size * self.config.data.oversample_factor),
drop_last=True,
collate_fn=collate_fn,
sampler=sampler,
)
self.val_dataset = RLHFDataset(data_files=self.config.data.val_files,
tokenizer=self.tokenizer,
config=self.config.data)
self.val_dataloader = DataLoader(dataset=self.val_dataset,
batch_size=len(self.val_dataset),
shuffle=True,
drop_last=True,
collate_fn=collate_fn)
self.val_dataset = RLHFDataset(
data_files=self.config.data.val_files, tokenizer=self.tokenizer, config=self.config.data
)
self.val_dataloader = DataLoader(
dataset=self.val_dataset,
batch_size=len(self.val_dataset),
shuffle=True,
drop_last=True,
collate_fn=collate_fn,
)
assert len(self.train_dataloader) >= 1
assert len(self.val_dataloader) >= 1
print(f'Size of train dataloader: {len(self.train_dataloader)}')
print(f'Size of val dataloader: {len(self.val_dataloader)}')
print(f"Size of train dataloader: {len(self.train_dataloader)}")
print(f"Size of val dataloader: {len(self.val_dataloader)}")
# inject total_training_steps to actor/critic optim_config. This is hacky.
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
@ -226,7 +224,7 @@ class RayPRIMETrainer(RayPPOTrainer):
total_training_steps = self.config.trainer.total_training_steps
self.total_training_steps = total_training_steps
print(f'Total training steps: {self.total_training_steps}')
print(f"Total training steps: {self.total_training_steps}")
OmegaConf.set_struct(self.config, True)
with open_dict(self.config):
@ -235,45 +233,58 @@ class RayPRIMETrainer(RayPPOTrainer):
def _save_checkpoint(self):
# path: given_path + `/global_step_{global_steps}` + `/actor`
local_global_step_folder = os.path.join(self.config.trainer.default_local_dir,
f'global_step_{self.global_steps}')
print(f'local_global_step_folder: {local_global_step_folder}')
actor_local_path = os.path.join(local_global_step_folder, 'actor')
local_global_step_folder = os.path.join(
self.config.trainer.default_local_dir, f"global_step_{self.global_steps}"
)
print(f"local_global_step_folder: {local_global_step_folder}")
actor_local_path = os.path.join(local_global_step_folder, "actor")
actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor')
self.actor_rollout_wg.save_checkpoint(actor_local_path,
actor_remote_path,
self.global_steps,
remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)
actor_remote_path = (
None
if self.config.trainer.default_hdfs_dir is None
else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor")
)
self.actor_rollout_wg.save_checkpoint(
actor_local_path,
actor_remote_path,
self.global_steps,
remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save,
)
if self.use_rm:
reward_local_path = os.path.join(local_global_step_folder, 'reward')
reward_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'reward')
self.rm_wg.save_checkpoint(reward_local_path,
reward_remote_path,
self.global_steps,
remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)
reward_local_path = os.path.join(local_global_step_folder, "reward")
reward_remote_path = (
None
if self.config.trainer.default_hdfs_dir is None
else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "reward")
)
self.rm_wg.save_checkpoint(
reward_local_path,
reward_remote_path,
self.global_steps,
remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save,
)
# save dataloader
dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt')
dataloader_local_path = os.path.join(local_global_step_folder, "data.pt")
import dill
torch.save(self.train_dataloader, dataloader_local_path, pickle_module=dill)
# latest checkpointed iteration tracker (for atomic usage)
local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir,
'latest_checkpointed_iteration.txt')
with open(local_latest_checkpointed_iteration, 'w') as f:
local_latest_checkpointed_iteration = os.path.join(
self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt"
)
with open(local_latest_checkpointed_iteration, "w") as f:
f.write(str(self.global_steps))
def _load_checkpoint(self):
if self.config.trainer.resume_mode == 'disable':
if self.config.trainer.resume_mode == "disable":
return 0
# load from hdfs
if self.config.trainer.default_hdfs_dir is not None:
NotImplementedError('load from hdfs is not implemented yet')
NotImplementedError("load from hdfs is not implemented yet")
else:
checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path
if not os.path.isabs(checkpoint_folder):
@ -282,37 +293,40 @@ class RayPRIMETrainer(RayPPOTrainer):
global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest
# find global_step_folder
if self.config.trainer.resume_mode == 'auto':
if self.config.trainer.resume_mode == "auto":
if global_step_folder is None:
print('Training from scratch')
print("Training from scratch")
return 0
else:
if self.config.trainer.resume_mode == "resume_path":
assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
assert 'global_step_' in self.config.trainer.resume_from_path, "resume ckpt must specify the global_steps"
assert "global_step_" in self.config.trainer.resume_from_path, (
"resume ckpt must specify the global_steps"
)
global_step_folder = self.config.trainer.resume_from_path
if not os.path.isabs(global_step_folder):
working_dir = os.getcwd()
global_step_folder = os.path.join(working_dir, global_step_folder)
print(f'Load from checkpoint folder: {global_step_folder}')
print(f"Load from checkpoint folder: {global_step_folder}")
# set global step
self.global_steps = int(global_step_folder.split('global_step_')[-1])
self.global_steps = int(global_step_folder.split("global_step_")[-1])
print(f'Setting global step to {self.global_steps}')
print(f'Resuming from {global_step_folder}')
print(f"Setting global step to {self.global_steps}")
print(f"Resuming from {global_step_folder}")
actor_path = os.path.join(global_step_folder, 'actor')
reward_path = os.path.join(global_step_folder, 'reward')
actor_path = os.path.join(global_step_folder, "actor")
reward_path = os.path.join(global_step_folder, "reward")
# load actor
self.actor_rollout_wg.load_checkpoint(actor_path,
del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
self.actor_rollout_wg.load_checkpoint(
actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
)
# load rm
if self.use_rm:
self.rm_wg.load_checkpoint(reward_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
# load dataloader,
# TODO: from remote not implemented yet
dataloader_local_path = os.path.join(global_step_folder, 'data.pt')
dataloader_local_path = os.path.join(global_step_folder, "data.pt")
self.train_dataloader = torch.load(dataloader_local_path)
if isinstance(self.train_dataloader.dataset, RLHFDataset):
self.train_dataloader.dataset.resume_dataset_state()
@ -323,13 +337,16 @@ class RayPRIMETrainer(RayPPOTrainer):
The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
The lightweight advantage computation is done on the driver process.
"""
from verl.utils.tracking import Tracking
from omegaconf import OmegaConf
logger = Tracking(project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True))
from verl.utils.tracking import Tracking
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)
self.global_steps = 0
@ -338,11 +355,11 @@ class RayPRIMETrainer(RayPPOTrainer):
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
pprint(f'Initial validation metrics: {val_metrics}')
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get('val_only', False):
if self.config.trainer.get("val_only", False):
return
# we start from step 1
@ -356,17 +373,17 @@ class RayPRIMETrainer(RayPPOTrainer):
batch: DataProto = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])
gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"])
with _timer('step', timing_raw):
with _timer("step", timing_raw):
# generate a batch
with _timer('gen', timing_raw):
with _timer("gen", timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
if self.config.algorithm.adv_estimator == 'remax':
with _timer('gen_max', timing_raw):
if self.config.algorithm.adv_estimator == "remax":
with _timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info['do_sample'] = False
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
@ -375,12 +392,13 @@ class RayPRIMETrainer(RayPPOTrainer):
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
batch.batch['reward_baselines'] = reward_baseline_tensor
batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
dtype=object)
batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
@ -391,96 +409,105 @@ class RayPRIMETrainer(RayPPOTrainer):
# self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# verify
with _timer('verify', timing_raw):
with _timer("verify", timing_raw):
scores = self.reward_fn.verify(batch)
metrics['acc'] = statistics.mean(scores)
metrics["acc"] = statistics.mean(scores)
# filter the batch. 1/oversample_factor samples will be kept. If there is a filter, prompts passing it will be prioritized.
batch = self.filter_and_downsample(scores, batch)
batch.meta_info['n'] = self.config.actor_rollout_ref.rollout.n
batch.meta_info["n"] = self.config.actor_rollout_ref.rollout.n
n_samples = self.config.actor_rollout_ref.rollout.n
# recompute old_log_probs
with _timer('old_log_prob', timing_raw):
with _timer("old_log_prob", timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
entropys = old_log_prob.batch['entropys']
entropys = old_log_prob.batch["entropys"]
response_masks = compute_response_mask(batch)
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
entropy_loss = agg_loss(loss_mat=entropys,
loss_mask=response_masks,
loss_agg_mode=loss_agg_mode)
entropy_loss = agg_loss(
loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode
)
old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop('entropys')
old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
with _timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
with _timer('adv', timing_raw):
with _timer("adv", timing_raw):
if self.use_rm:
update_style = self.config.reward_model.model.get('update', 'none')
if update_style == 'none': # only run forward
update_style = self.config.reward_model.model.get("update", "none")
if update_style == "none": # only run forward
reward_output = self.rm_wg.compute_rm_score(batch)
elif update_style == 'after': # update and directly return the reward
elif update_style == "after": # update and directly return the reward
reward_output = self.rm_wg.update_rm(batch)
elif update_style == 'before': # update reward model, and then run forward
elif update_style == "before": # update reward model, and then run forward
reward_output = self.rm_wg.update_rm(batch)
if 'metrics' in reward_output.meta_info.keys():
reward_output_metrics = reduce_metrics(reward_output.meta_info['metrics'])
if "metrics" in reward_output.meta_info.keys():
reward_output_metrics = reduce_metrics(reward_output.meta_info["metrics"])
metrics.update(reward_output_metrics)
reward_output = self.rm_wg.compute_rm_score(batch)
elif update_style == 'reverse': # run forward to calculate statistics, then update reward model
elif (
update_style == "reverse"
): # run forward to calculate statistics, then update reward model
reward_output = self.rm_wg.compute_rm_score(batch)
# broadcast q and acc tensor to each result
bc_td = DataProto.from_dict(
tensors={
'Q_bc':
reward_output.batch['q'].sum(dim=-1).view(-1, n_samples).unsqueeze(
1).expand(-1, n_samples, -1).reshape(-1, n_samples),
'acc_bc':
batch.batch['acc'].view(-1, n_samples).unsqueeze(1).expand(
-1, n_samples, -1).reshape(-1, n_samples)
})
"Q_bc": reward_output.batch["q"]
.sum(dim=-1)
.view(-1, n_samples)
.unsqueeze(1)
.expand(-1, n_samples, -1)
.reshape(-1, n_samples),
"acc_bc": batch.batch["acc"]
.view(-1, n_samples)
.unsqueeze(1)
.expand(-1, n_samples, -1)
.reshape(-1, n_samples),
}
)
batch = batch.union(bc_td)
reward_output = self.rm_wg.update_rm(batch)
else:
raise NotImplementedError
batch = batch.union(reward_output)
if 'metrics' in reward_output.meta_info.keys():
reward_output_metrics = reduce_metrics(reward_output.meta_info['metrics'])
if "metrics" in reward_output.meta_info.keys():
reward_output_metrics = reduce_metrics(reward_output.meta_info["metrics"])
metrics.update(reward_output_metrics)
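# Editor's aside: a tiny worked example of the broadcast pattern used for Q_bc / acc_bc in the
# "reverse" branch above. With 2 prompts and n_samples = 3 (6 rollouts), every rollout ends up
# holding the full row of its prompt's per-rollout statistics, which is what the detached DPO
# update needs to compare each sample against its siblings.
import torch

n_samples = 3
q_seq = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])   # 6 sequence rewards = 2 prompts x 3 samples
Q_bc = (
    q_seq.view(-1, n_samples)      # (2, 3): group rollouts by prompt
    .unsqueeze(1)                  # (2, 1, 3)
    .expand(-1, n_samples, -1)     # (2, 3, 3): repeat the row once per rollout
    .reshape(-1, n_samples)        # (6, 3): row i holds the stats of sample i's prompt group
)
# Q_bc[0] == Q_bc[1] == Q_bc[2] == tensor([1., 2., 3.]); Q_bc[3:] all equal tensor([4., 5., 6.])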
# compute advantages, executed on the driver process
batch = compute_advantage(batch,
adv_estimator=self.config.algorithm.adv_estimator,
config=self.config)
batch = compute_advantage(
batch, adv_estimator=self.config.algorithm.adv_estimator, config=self.config
)
# update actor
with _timer('update_actor', timing_raw):
with _timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# validate
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
self.global_steps % self.config.trainer.test_freq == 0:
with _timer('testing', timing_raw):
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and self.global_steps % self.config.trainer.test_freq == 0
):
with _timer("testing", timing_raw):
val_metrics: dict = self._validate()
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and \
self.global_steps % self.config.trainer.save_freq == 0:
with _timer('save_checkpoint', timing_raw):
if self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0:
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# collect metrics
@ -493,15 +520,16 @@ class RayPRIMETrainer(RayPPOTrainer):
self.global_steps += 1
if self.global_steps >= self.total_training_steps:
# perform validation after training
if self.val_reward_fn is not None:
val_metrics = self._validate()
pprint(f'Final validation metrics: {val_metrics}')
pprint(f"Final validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.save_freq > 0 and \
(self.global_steps - 1) % self.config.trainer.save_freq != 0:
with _timer('save_checkpoint', timing_raw):
if (
self.config.trainer.save_freq > 0
and (self.global_steps - 1) % self.config.trainer.save_freq != 0
):
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
return
@ -517,18 +545,24 @@ class RayPRIMETrainer(RayPPOTrainer):
if self.config.data.filter_accuracy:
acc_tensor = torch.mean(reward_matrix, dim=-1)
filter_mask[(acc_tensor > self.config.data.accuracy_upper_bound) |
(acc_tensor < self.config.data.accuracy_lower_bound)] = False
filter_mask[
(acc_tensor > self.config.data.accuracy_upper_bound)
| (acc_tensor < self.config.data.accuracy_lower_bound)
] = False
if self.config.data.filter_truncate:
length_matrix = batch.batch['attention_mask'][:, -batch.batch['responses'].shape[-1]:].sum(dim=-1).reshape(
-1, n_samples)
length_matrix = (
batch.batch["attention_mask"][:, -batch.batch["responses"].shape[-1] :]
.sum(dim=-1)
.reshape(-1, n_samples)
)
length_tensor = torch.max(length_matrix, dim=-1)[0]
filter_mask[length_tensor >= self.config.data.max_response_length - 1] = False
reorder_index = torch.argsort(filter_mask, descending=True)
reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples).unsqueeze(0)).view(-1)
batch.reorder(reorder_index[:int(len(batch) //
self.config.data.oversample_factor)]) # this operation is inplace
batch.reorder(
reorder_index[: int(len(batch) // self.config.data.oversample_factor)]
) # this operation is inplace
return batch
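# Editor's aside: a small numeric illustration of the downsampling arithmetic above. With
# oversample_factor = 2, 4 prompts and n_samples = 2 (8 rollouts), only the first
# 8 // oversample_factor = 4 rollouts survive, and the argsort pushes filter-passing prompts
# to the front so those survivors are the rollouts of the prompts that passed.
import torch

n_samples, oversample_factor, num_rollouts = 2, 2, 8
filter_mask = torch.tensor([False, True, False, True])                 # per-prompt keep flags
reorder_index = torch.argsort(filter_mask.int(), descending=True)      # e.g. tensor([1, 3, 0, 2])
per_sample = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(n_samples)).view(-1)
kept = per_sample[: num_rollouts // oversample_factor]
# kept holds both rollouts of prompts 1 and 3, e.g. tensor([2, 3, 6, 7])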

View File

@ -15,54 +15,44 @@
Preprocess the dataset to parquet format
"""
import argparse
import os
from datasets import load_dataset, concatenate_datasets
from functools import partial
from datasets import concatenate_datasets, load_dataset
from verl.utils.hdfs_io import copy, makedirs
import argparse
def example_map_fn(example, idx, process_fn, data_source, ability, split):
question, solution = process_fn(example)
data = {
"data_source": data_source,
"prompt": [{
"role": "user",
"content": question
}],
"prompt": [{"role": "user", "content": question}],
"ability": ability,
"reward_model": {
"style": "rule",
"ground_truth": solution
},
"extra_info": {
'split': split,
'index': idx
}
"reward_model": {"style": "rule", "ground_truth": solution},
"extra_info": {"split": split, "index": idx},
}
return data
def build_aime2024_dataset():
def process_aime2024(example):
return example["Problem"], str(example["Answer"])
data_source = 'Maxwell-Jia/AIME_2024'
data_source = "Maxwell-Jia/AIME_2024"
print(f"Loading the {data_source} dataset from huggingface...", flush=True)
dataset = load_dataset(data_source, split="train")
map_fn = partial(example_map_fn,
process_fn=process_aime2024,
data_source=data_source,
ability="English",
split="test")
map_fn = partial(
example_map_fn, process_fn=process_aime2024, data_source=data_source, ability="English", split="test"
)
dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names)
return dataset
def build_gpqa_dimond_dataset():
import random
GPQA_QUERY_TEMPLATE = "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}"
def process_gpqa_diamond(example):
@ -70,49 +60,40 @@ def build_gpqa_dimond_dataset():
random.shuffle(choices)
gold_index = random.randint(0, 3)
choices.insert(gold_index, example["Correct Answer"])
query_prompt = GPQA_QUERY_TEMPLATE.format(A=choices[0],
B=choices[1],
C=choices[2],
D=choices[3],
Question=example["Question"])
query_prompt = GPQA_QUERY_TEMPLATE.format(
A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=example["Question"]
)
gold_choice = "ABCD"[gold_index]
return query_prompt, gold_choice
data_source = 'Idavidrein/gpqa'
data_source = "Idavidrein/gpqa"
print(f"Loading the {data_source} dataset from huggingface...", flush=True)
dataset = load_dataset(data_source, "gpqa_diamond", split="train")
map_fn = partial(example_map_fn,
process_fn=process_gpqa_diamond,
data_source=data_source,
ability="Math",
split="test")
map_fn = partial(
example_map_fn, process_fn=process_gpqa_diamond, data_source=data_source, ability="Math", split="test"
)
dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names)
return dataset
def build_cnmo2024_dataset():
def process_cnmo2024(example):
return example["question"], example["answer"]
data_source = 'opencompass/LiveMathBench'
data_source = "opencompass/LiveMathBench"
print(f"Loading the {data_source} dataset from huggingface...", flush=True)
dataset_en = load_dataset(data_source, "v202412_CNMO_en", split="test")
map_fn_en = partial(example_map_fn,
process_fn=process_cnmo2024,
data_source='opencompass/cnmo2024_en',
ability="Math",
split="test")
map_fn_en = partial(
example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_en", ability="Math", split="test"
)
dataset_en = dataset_en.map(map_fn_en, with_indices=True, remove_columns=dataset_en.column_names)
dataset_zh = load_dataset(data_source, "v202412_CNMO_cn", split="test")
map_fn_zh = partial(example_map_fn,
process_fn=process_cnmo2024,
data_source='opencompass/cnmo2024_zh',
ability="Math",
split="test")
map_fn_zh = partial(
example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_zh", ability="Math", split="test"
)
dataset_zh = dataset_zh.map(map_fn_zh, with_indices=True, remove_columns=dataset_zh.column_names)
dataset = concatenate_datasets([dataset_en, dataset_zh])
@ -120,22 +101,28 @@ def build_cnmo2024_dataset():
def build_livecodebench_dataset():
import json, pickle, zlib, base64
import base64
import json
import pickle
import zlib
def process_livecodebench(example):
# Construct Query Prompt
# From https://github.com/LiveCodeBench/LiveCodeBench/blob/998c52d394b836f15fff3b9a29866191108ff81b/lcb_runner/prompts/code_generation.py#L140
query_prompt = (
"You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests.\n\n"
f"Question: {example['question_content']}\n\n")
f"Question: {example['question_content']}\n\n"
)
if example["starter_code"]:
query_prompt += (
"You will use the following starter code to write the solution to the problem and enclose your code within delimiters.\n"
f"```python\n{example['starter_code']}\n```")
f"```python\n{example['starter_code']}\n```"
)
else:
query_prompt += (
"Read the inputs from stdin solve the problem and write the answer to stdout (do not directly test on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python program runs, it reads the inputs, runs the algorithm and writes output to STDOUT."
f"```python\n# YOUR CODE HERE\n```")
"```python\n# YOUR CODE HERE\n```"
)
# Construct test cases
public_test_cases = json.loads(example["public_test_cases"])
@ -143,7 +130,8 @@ def build_livecodebench_dataset():
private_test_cases = json.loads(example["private_test_cases"])
except:
private_test_cases = json.loads(
pickle.loads(zlib.decompress(base64.b64decode(example["private_test_cases"].encode("utf-8")))))
pickle.loads(zlib.decompress(base64.b64decode(example["private_test_cases"].encode("utf-8"))))
)
full_test_cases = public_test_cases + private_test_cases
metadata = json.loads(example["metadata"])
@ -155,16 +143,14 @@ def build_livecodebench_dataset():
text_cases_compressed = base64.b64encode(zlib.compress(pickle.dumps(json.dumps(test_cases)))).decode("utf-8")
return query_prompt, text_cases_compressed
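# Editor's aside: the compressed field above is a json -> pickle -> zlib -> base64 pipeline,
# and the private_test_cases branch earlier simply unwinds it in reverse. A minimal round trip
# with dummy data:
import base64
import json
import pickle
import zlib

cases = [{"input": "1 2\n", "output": "3\n", "testtype": "stdin"}]
packed = base64.b64encode(zlib.compress(pickle.dumps(json.dumps(cases)))).decode("utf-8")
unpacked = json.loads(pickle.loads(zlib.decompress(base64.b64decode(packed.encode("utf-8")))))
assert unpacked == cases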
data_source = 'livecodebench/code_generation_lite'
data_source = "livecodebench/code_generation_lite"
print(f"Loading the {data_source} dataset from huggingface...", flush=True)
dataset = load_dataset(data_source, split="test")
# R1 evaluation uses LiveCodeBench 24.08-25.01
dataset = dataset.filter(lambda line: "2024-08-00T00:00:00" <= line["contest_date"] < "2025-01-00T00:00:00")
map_fn = partial(example_map_fn,
process_fn=process_livecodebench,
data_source=data_source,
ability="Code",
split="test")
map_fn = partial(
example_map_fn, process_fn=process_livecodebench, data_source=data_source, ability="Code", split="test"
)
dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names, num_proc=8)
return dataset
@ -178,18 +164,18 @@ TASK2DATA = {
}
SUPPORTED_TASKS = TASK2DATA.keys()
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--local_dir', default='~/data/r1')
parser.add_argument('--hdfs_dir', default=None)
parser.add_argument('--tasks', default="all")
parser.add_argument("--local_dir", default="~/data/r1")
parser.add_argument("--hdfs_dir", default=None)
parser.add_argument("--tasks", default="all")
args = parser.parse_args()
if args.tasks.lower() == "all":
args.tasks = SUPPORTED_TASKS
else:
args.tasks = [task.strip() for task in args.tasks.split(',') if task.strip()]
args.tasks = [task.strip() for task in args.tasks.split(",") if task.strip()]
for task in args.tasks:
if task not in SUPPORTED_TASKS:
raise NotImplementedError(f"{task} has not been supported.")
@ -202,7 +188,7 @@ if __name__ == '__main__':
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)

View File

@ -17,17 +17,20 @@ The input is a parquet file that contains N generated sequences and (optional) t
"""
import hydra
from verl.utils.fs import copy_to_local
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import hydra
import numpy as np
import pandas as pd
import ray
from tqdm import tqdm
from verl.utils.fs import copy_to_local
def get_custom_reward_fn(config):
import importlib.util, os
import importlib.util
import os
reward_fn_config = config.get("custom_reward_function") or {}
file_path = reward_fn_config.get("path")
@ -56,12 +59,12 @@ def get_custom_reward_fn(config):
@ray.remote
def process_item(reward_fn, data_source, response_lst, reward_data):
ground_truth = reward_data['ground_truth']
ground_truth = reward_data["ground_truth"]
score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst]
return data_source, np.mean(score_lst)
@hydra.main(config_path='config', config_name='evaluation', version_base=None)
@hydra.main(config_path="config", config_name="evaluation", version_base=None)
def main(config):
local_path = copy_to_local(config.data.path)
dataset = pd.read_parquet(local_path)
@ -97,10 +100,10 @@ def main(config):
metric_dict = {}
for data_source, rewards in data_source_reward.items():
metric_dict[f'test_score/{data_source}'] = np.mean(rewards)
metric_dict[f"test_score/{data_source}"] = np.mean(rewards)
print(metric_dict)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -14,14 +14,17 @@
def reward_func(data_source, solution_str, ground_truth, extra_info=None):
if data_source in ['Maxwell-Jia/AIME_2024', "opencompass/cnmo2024_en", "opencompass/cnmo2024_zh"]:
if data_source in ["Maxwell-Jia/AIME_2024", "opencompass/cnmo2024_en", "opencompass/cnmo2024_zh"]:
from recipe.r1.tasks import math
return math.compute_score(solution_str, ground_truth)
elif data_source == 'Idavidrein/gpqa':
elif data_source == "Idavidrein/gpqa":
from recipe.r1.tasks import gpqa
return gpqa.compute_score(solution_str, ground_truth)
elif data_source in ['livecodebench/code_generation_lite', 'livecodebench/code_generation']:
elif data_source in ["livecodebench/code_generation_lite", "livecodebench/code_generation"]:
from recipe.r1.tasks import livecodebench
return livecodebench.compute_score(solution_str, ground_truth)
else:
raise NotImplementedError

View File

@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import base64
import json
import multiprocessing
import pickle
import zlib
import base64
# Reuse `run_test` for convenience
from verl.utils.reward_score.prime_code.testing_util import run_test
@ -48,12 +48,12 @@ def check_correctness(in_outs, generation, timeout, debug=True):
# consider that all tests failed
result = [[-1 for i in range(len(in_outs["inputs"]))]]
if debug:
print(f"global timeout")
print("global timeout")
return result[0], metadata_list[0]
def compute_score(completion, test_cases):
solution = completion.split('```python')[-1].split('```')[0]
solution = completion.split("```python")[-1].split("```")[0]
# extract test cases
try:
@ -65,7 +65,7 @@ def compute_score(completion, test_cases):
try:
res, metadata = check_correctness(in_outs=in_outs, generation=solution, timeout=6, debug=False)
success = all(map(lambda x: x == True, res))
except Exception as e:
except Exception:
pass
return success
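Two of the changes in the hunks above are `ruff check --fix` auto-fixes rather than reformatting: the pointless `f` prefix on `f"global timeout"` is dropped, and `except Exception as e` loses the never-used `e` binding. My reading is that these correspond to the F541 and F841 rules, though the diff itself does not name them. A toy sketch of the fixed style:

```python
def check_with_timeout_message(simulate_timeout: bool = True) -> bool:
    try:
        if simulate_timeout:
            raise TimeoutError("simulated global timeout")
        return True
    except Exception:  # the unused `as e` binding is removed by the auto-fix
        print("global timeout")  # plain string, since the f-prefix carried no placeholders
        return False


print(check_with_timeout_message())
```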

View File

@ -14,7 +14,7 @@
try:
from math_verify.metric import math_metric
from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig
from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
except ImportError:
print("To use Math-Verify, please install it first by running `pip install math-verify`.")
@ -24,13 +24,13 @@ def compute_score(model_output: str, ground_truth: str) -> bool:
gold_extraction_target=(LatexExtractionConfig(),),
pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
)
ret_score = 0.
ret_score = 0.0
# Wrap the ground truth in \boxed{} format for verification
ground_truth_boxed = "\\boxed{" + ground_truth + "}"
try:
ret_score, _ = verify_func([ground_truth_boxed], [model_output])
except Exception as e:
except Exception:
pass
return ret_score

View File

@ -13,7 +13,7 @@ peft
pyarrow>=15.0.0
pybind11
pylatexenc
pylint==3.3.6
pre-commit
ray[default]
tensordict<=0.6.2
torchdata

View File

@ -13,47 +13,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple, Dict
import re
import os
import torch
import argparse
import os
import warnings
import numpy as np
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
from concurrent.futures import ThreadPoolExecutor
from safetensors.torch import load_file
from torch.distributed._tensor import Shard, Placement
from verl.utils.megatron_utils import get_model, convert_config
from megatron.core.models.gpt.gpt_model import ModelType
from megatron.core import parallel_state as mpu
import torch
from megatron.core import dist_checkpointing
from megatron.core import parallel_state as mpu
from megatron.core.dist_checkpointing.serialization import StrictHandling
from megatron.core.models.gpt.gpt_model import ModelType
from transformers import AutoConfig, AutoModelForCausalLM
from verl.utils.megatron_utils import convert_config, get_model
def _init_args():
parser = argparse.ArgumentParser()
parser.add_argument('--hf_model_path', type=str, required=True, help="The path for the huggingface model")
parser.add_argument('--output_path', type=str, required=True, help="The path for the output mcore model")
parser.add_argument('--test', action='store_true', help="Whether to test the conversion")
parser.add_argument("--hf_model_path", type=str, required=True, help="The path for the huggingface model")
parser.add_argument("--output_path", type=str, required=True, help="The path for the output mcore model")
parser.add_argument("--test", action="store_true", help="Whether to test the conversion")
args = parser.parse_args()
return args
class MegatronConfig:
def __init__(self):
self.params_dtype = torch.bfloat16
class ModelConfig:
def __init__(self):
self.path = None
class Config:
def __init__(self):
self.model = ModelConfig()
@ -65,15 +58,17 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False):
return
# init torch distributed and mpu
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
torch.distributed.init_process_group('nccl')
mpu.initialize_model_parallel(tensor_model_parallel_size=1,
virtual_pipeline_model_parallel_size=None,
context_parallel_size=1,
expert_model_parallel_size=1)
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.distributed.init_process_group("nccl")
mpu.initialize_model_parallel(
tensor_model_parallel_size=1,
virtual_pipeline_model_parallel_size=None,
context_parallel_size=1,
expert_model_parallel_size=1,
)
# init hf config
hf_config = AutoConfig.from_pretrained(hf_model_path)
@ -87,17 +82,20 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False):
# init megatron model
def megatron_model_provider(pre_process, post_process):
from verl.utils.model import get_parallel_gptmodel_from_config
parallel_model = get_parallel_gptmodel_from_config(tfconfig,
hf_config,
pre_process,
post_process,
share_embeddings_and_output_weights=tie_word_embeddings,
value=False)
parallel_model = get_parallel_gptmodel_from_config(
tfconfig,
hf_config,
pre_process,
post_process,
share_embeddings_and_output_weights=tie_word_embeddings,
value=False,
)
return parallel_model
model = get_model(model_provider_func=megatron_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=True)
model = get_model(
model_provider_func=megatron_model_provider, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
@ -108,11 +106,14 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False):
# load hf state dict to megatron model
from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel
load_state_dict_to_megatron_gptmodel(state_dict=ref_state_dict,
wrapped_models=model,
config=hf_config,
params_dtype=torch.bfloat16,
is_value_model=False)
load_state_dict_to_megatron_gptmodel(
state_dict=ref_state_dict,
wrapped_models=model,
config=hf_config,
params_dtype=torch.bfloat16,
is_value_model=False,
)
ssd = model[0].module.module.sharded_state_dict()
del ref_state_dict, hf_model
@ -122,9 +123,9 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False):
if test:
########### test ###########
# load model
model_test = get_model(model_provider_func=megatron_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=True)
model_test = get_model(
model_provider_func=megatron_model_provider, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True
)
ssd2 = model_test[0].module.module.sharded_state_dict()
dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.ASSUME_OK_UNEXPECTED)
@ -136,7 +137,7 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False):
d1 = sd[k].data
if k in sd2:
d2 = sd2[k].data
assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
assert d1.shape == d2.shape, f"{k=} {d1.shape=} {d2.shape=}"
assert (d1 == d2).all(), f"{k} is not equal"
for k in sd2.keys():
if sd2[k] is None:
@ -144,24 +145,24 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False):
d1 = sd2[k].data
if k in sd:
d2 = sd[k].data
assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
assert d1.shape == d2.shape, f"{k=} {d1.shape=} {d2.shape=}"
assert (d1 == d2).all(), f"{k} is not equal"
# load value model
def megatron_value_model_provider(pre_process, post_process):
from verl.utils.model import get_parallel_gptmodel_from_config
parallel_model = get_parallel_gptmodel_from_config(tfconfig,
hf_config,
pre_process,
post_process,
share_embeddings_and_output_weights=False,
value=True)
parallel_model = get_parallel_gptmodel_from_config(
tfconfig, hf_config, pre_process, post_process, share_embeddings_and_output_weights=False, value=True
)
parallel_model.cuda()
return parallel_model
model_value = get_model(model_provider_func=megatron_value_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=True)
model_value = get_model(
model_provider_func=megatron_value_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=True,
)
ssd2 = model_value[0].module.module.sharded_state_dict()
dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.IGNORE_ALL)
@ -173,7 +174,7 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False):
d1 = sd[k].data
if k in sd2:
d2 = sd2[k].data
assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
assert d1.shape == d2.shape, f"{k=} {d1.shape=} {d2.shape=}"
assert (d1 == d2).all(), f"{k} is not equal"
for k in sd2.keys():
if sd2[k] is None:
@ -181,7 +182,7 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False):
d1 = sd2[k].data
if k in sd:
d2 = sd[k].data
assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
assert d1.shape == d2.shape, f"{k=} {d1.shape=} {d2.shape=}"
assert (d1 == d2).all(), f"{k} is not equal"

View File

@ -14,28 +14,35 @@
"""Diagnose script for checking OS/hardware/python/pip/verl/network.
The output of this script can be a very good hint for diagnosing an issue or problem.
"""
import os
import platform
import socket
import subprocess
import sys
import time
import psutil
import platform, subprocess, sys, os
import socket, time
try:
from urllib.request import urlopen
from urllib.parse import urlparse
from urllib.request import urlopen
except ImportError:
from urlparse import urlparse
from urllib2 import urlopen
from urlparse import urlparse
import argparse
import importlib.metadata
import torch
URLS = {
'PYPI': 'https://pypi.python.org/pypi/pip',
"PYPI": "https://pypi.python.org/pypi/pip",
}
REGIONAL_URLS = {
'cn': {
'PYPI(douban)': 'https://pypi.douban.com/',
'Conda(tsinghua)': 'https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/',
"cn": {
"PYPI(douban)": "https://pypi.douban.com/",
"Conda(tsinghua)": "https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/",
}
}
@ -47,7 +54,7 @@ def test_connection(name, url, timeout=10):
try:
ip = socket.gethostbyname(urlinfo.netloc)
except Exception as e:
print('Error resolving DNS for {}: {}, {}'.format(name, url, e))
print("Error resolving DNS for {}: {}, {}".format(name, url, e))
return
dns_elapsed = time.time() - start
start = time.time()
@ -61,26 +68,27 @@ def test_connection(name, url, timeout=10):
def check_python():
print('----------Python Info----------')
print('Version :', platform.python_version())
print('Compiler :', platform.python_compiler())
print('Build :', platform.python_build())
print('Arch :', platform.architecture())
print("----------Python Info----------")
print("Version :", platform.python_version())
print("Compiler :", platform.python_compiler())
print("Build :", platform.python_build())
print("Arch :", platform.architecture())
def check_pip():
print('------------Pip Info-----------')
print("------------Pip Info-----------")
try:
import pip
print('Version :', pip.__version__)
print('Directory :', os.path.dirname(pip.__file__))
print("Version :", pip.__version__)
print("Directory :", os.path.dirname(pip.__file__))
except ImportError:
print('No corresponding pip install for current python.')
print("No corresponding pip install for current python.")
def _get_current_git_commit():
try:
result = subprocess.run(['git', 'rev-parse', 'HEAD'], capture_output=True, text=True, check=True)
result = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True)
return result.stdout.strip()
except subprocess.CalledProcessError as e:
print(f"Error running git command: {e.stderr.strip()}")
@ -91,22 +99,24 @@ def _get_current_git_commit():
def check_verl():
print('----------verl Info-----------')
print("----------verl Info-----------")
try:
sys.path.insert(0, os.getcwd())
import verl
print('Version :', verl.__version__)
print("Version :", verl.__version__)
verl_dir = os.path.dirname(verl.__file__)
print('Directory :', verl_dir)
print("Directory :", verl_dir)
try:
commit_hash = _get_current_git_commit()
print('Commit Hash :', commit_hash)
print("Commit Hash :", commit_hash)
except AttributeError:
print('Commit hash not found. ')
print("Commit hash not found. ")
except ImportError as e:
print(f'No verl installed: {e}')
print(f"No verl installed: {e}")
except Exception as e:
import traceback
if not isinstance(e, IOError):
print("An error occurred trying to import verl.")
print("This is very likely due to missing missing or incompatible library files.")
@ -114,36 +124,36 @@ def check_verl():
def check_os():
print('----------Platform Info----------')
print('Platform :', platform.platform())
print('system :', platform.system())
print('node :', platform.node())
print('release :', platform.release())
print('version :', platform.version())
print("----------Platform Info----------")
print("Platform :", platform.platform())
print("system :", platform.system())
print("node :", platform.node())
print("release :", platform.release())
print("version :", platform.version())
def check_hardware():
print('----------Hardware Info----------')
print('machine :', platform.machine())
print('processor :', platform.processor())
if sys.platform.startswith('darwin'):
pipe = subprocess.Popen(('sysctl', '-a'), stdout=subprocess.PIPE)
print("----------Hardware Info----------")
print("machine :", platform.machine())
print("processor :", platform.processor())
if sys.platform.startswith("darwin"):
pipe = subprocess.Popen(("sysctl", "-a"), stdout=subprocess.PIPE)
output = pipe.communicate()[0]
for line in output.split(b'\n'):
if b'brand_string' in line or b'features' in line:
for line in output.split(b"\n"):
if b"brand_string" in line or b"features" in line:
print(line.strip())
elif sys.platform.startswith('linux'):
subprocess.call(['lscpu'])
elif sys.platform.startswith('win32'):
subprocess.call(['wmic', 'cpu', 'get', 'name'])
elif sys.platform.startswith("linux"):
subprocess.call(["lscpu"])
elif sys.platform.startswith("win32"):
subprocess.call(["wmic", "cpu", "get", "name"])
def check_network(args):
print('----------Network Test----------')
print("----------Network Test----------")
if args.timeout > 0:
print('Setting timeout: {}'.format(args.timeout))
print("Setting timeout: {}".format(args.timeout))
socket.setdefaulttimeout(10)
for region in args.region.strip().split(','):
for region in args.region.strip().split(","):
r = region.strip().lower()
if not r:
continue
@ -151,20 +161,21 @@ def check_network(args):
URLS.update(REGIONAL_URLS[r])
else:
import warnings
warnings.warn('Region {} do not need specific test, please refer to global sites.'.format(r))
warnings.warn("Region {} do not need specific test, please refer to global sites.".format(r))
for name, url in URLS.items():
test_connection(name, url, args.timeout)
def check_environment():
print('----------Environment----------')
print("----------Environment----------")
for k, v in os.environ.items():
if k.startswith('VERL_') or k.startswith('OMP_') or k.startswith('KMP_') or k == 'CC' or k == 'CXX':
if k.startswith("VERL_") or k.startswith("OMP_") or k.startswith("KMP_") or k == "CC" or k == "CXX":
print('{}="{}"'.format(k, v))
def check_pip_package_versions():
packages = ['vllm', 'sglang', 'ray', 'torch']
packages = ["vllm", "sglang", "ray", "torch"]
for package in packages:
try:
version = importlib.metadata.version(package)
@ -179,8 +190,9 @@ def check_cuda_versions():
cuda_runtime_version = torch.version.cuda
print(f"CUDA Runtime : {cuda_runtime_version}")
import subprocess
nvcc_output = subprocess.check_output(['nvcc', '--version']).decode('utf-8')
cuda_compiler_version = next((line for line in nvcc_output.splitlines() if 'release' in line), None)
nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
cuda_compiler_version = next((line for line in nvcc_output.splitlines() if "release" in line), None)
if cuda_compiler_version:
print(f"CUDA Compiler : {cuda_compiler_version.strip()}")
else:
@ -206,19 +218,23 @@ def _get_gpu_info():
Get GPU type, GPU memory, and GPU count using nvidia-smi command.
"""
try:
result = subprocess.run(['nvidia-smi', '--query-gpu=gpu_name,memory.total', '--format=csv,noheader,nounits'],
capture_output=True,
text=True,
check=True)
gpu_lines = result.stdout.strip().split('\n')
result = subprocess.run(
["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader,nounits"],
capture_output=True,
text=True,
check=True,
)
gpu_lines = result.stdout.strip().split("\n")
gpu_count = len(gpu_lines)
gpu_info = []
for line in gpu_lines:
gpu_name, gpu_memory = line.split(', ')
gpu_info.append({
'type': gpu_name,
'memory': float(gpu_memory) / 1024 # Convert to GB
})
gpu_name, gpu_memory = line.split(", ")
gpu_info.append(
{
"type": gpu_name,
"memory": float(gpu_memory) / 1024, # Convert to GB
}
)
return gpu_count, gpu_info
except subprocess.CalledProcessError:
print("Failed to execute nvidia-smi command.")
@ -231,39 +247,43 @@ def _get_system_info():
"""
cpu_memory = _get_cpu_memory()
gpu_count, gpu_info = _get_gpu_info()
return {'cpu_memory': cpu_memory, 'gpu_count': gpu_count, 'gpu_info': gpu_info}
return {"cpu_memory": cpu_memory, "gpu_count": gpu_count, "gpu_info": gpu_info}
def check_system_info():
print('----------System Info----------')
print("----------System Info----------")
system_info = _get_system_info()
print(f"CPU Memory\t: {system_info['cpu_memory']:.2f} GB")
print(f"GPU Count\t: {system_info['gpu_count']}")
for i, gpu in enumerate(system_info['gpu_info']):
for i, gpu in enumerate(system_info["gpu_info"]):
print(f"GPU {i + 1}\tType : {gpu['type']}")
print(f"GPU {i + 1}\tMemory : {gpu['memory']:.2f} GB")
def parse_args():
"""Parse arguments."""
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description='Diagnose script for checking the current system.')
choices = ['python', 'pip', 'verl', 'system', 'os', 'environment']
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="Diagnose script for checking the current system.",
)
choices = ["python", "pip", "verl", "system", "os", "environment"]
for choice in choices:
parser.add_argument('--' + choice, default=1, type=int, help='Diagnose {}.'.format(choice))
parser.add_argument('--network', default=0, type=int, help='Diagnose network.')
parser.add_argument('--hardware', default=0, type=int, help='Diagnose hardware.')
parser.add_argument('--region',
default='',
type=str,
help="Additional sites in which region(s) to test. \
Specify 'cn' for example to test mirror sites in China.")
parser.add_argument('--timeout', default=10, type=int, help="Connection test timeout threshold, 0 to disable.")
parser.add_argument("--" + choice, default=1, type=int, help="Diagnose {}.".format(choice))
parser.add_argument("--network", default=0, type=int, help="Diagnose network.")
parser.add_argument("--hardware", default=0, type=int, help="Diagnose hardware.")
parser.add_argument(
"--region",
default="",
type=str,
help="Additional sites in which region(s) to test. \
Specify 'cn' for example to test mirror sites in China.",
)
parser.add_argument("--timeout", default=10, type=int, help="Connection test timeout threshold, 0 to disable.")
args = parser.parse_args()
return args
if __name__ == '__main__':
if __name__ == "__main__":
args = parse_args()
if args.python:
check_python()

View File

@ -1,3 +0,0 @@
#!/bin/bash
pip3 install --upgrade yapf
python3 -m yapf -ir -vv --style ./.style.yapf verl tests examples recipe scripts

View File

@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple, Dict
import re
import os
import torch
import argparse
import numpy as np
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
import os
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple
import numpy as np
import torch
from safetensors.torch import load_file
from torch.distributed._tensor import Shard, Placement
from torch.distributed._tensor import Placement, Shard
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
try:
# for torch 2.5+
from torch.distributed.tensor import DTensor
@ -29,28 +31,31 @@ except ImportError:
from torch.distributed._tensor import DTensor
parser = argparse.ArgumentParser()
parser.add_argument('--backend', type=str, required=True, help="The backend of the model", choices=["fsdp", "megatron"])
parser.add_argument('--tie-word-embedding', action='store_true', help="Whether to tie word embedding weights")
parser.add_argument('--is-value-model', action='store_true', help="Whether the model loaded as value model")
parser.add_argument('--hf_model_path', type=str, required=True, help="The path for the huggingface model")
parser.add_argument("--backend", type=str, required=True, help="The backend of the model", choices=["fsdp", "megatron"])
parser.add_argument("--tie-word-embedding", action="store_true", help="Whether to tie word embedding weights")
parser.add_argument("--is-value-model", action="store_true", help="Whether the model loaded as value model")
parser.add_argument("--hf_model_path", type=str, required=True, help="The path for the huggingface model")
parser.add_argument(
'--local_dir',
"--local_dir",
type=str,
required=True,
help=
"The path for your saved model. For megatron, point to the base dir of model, rng, optimizer checkpoints, commonly be `config.default_local_dir/global_step_\{global_step\}`."
help="The path for your saved model. For megatron, point to the base dir of model, rng, optimizer checkpoints, commonly be `config.default_local_dir/global_step_\{global_step\}`.",
)
parser.add_argument('--target_dir', required=False, default="tmp", type=str, help="The path for the target model")
parser.add_argument("--target_dir", required=False, default="tmp", type=str, help="The path for the target model")
parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
parser.add_argument("--test", action="store_true", help="test correctness of hf_model")
parser.add_argument("--test_hf_dir",
type=str,
required=False,
help="test correctness of hf_model, , with hf_model in checkpoint.contents")
parser.add_argument(
"--test_hf_dir",
type=str,
required=False,
help="test correctness of hf_model, , with hf_model in checkpoint.contents",
)
args = parser.parse_args()
os.makedirs(args.target_dir, exist_ok=True)
if args.test:
assert args.test_hf_dir is not None, f'You must run verl save checkpoint first, with hf_model in checkpoint.contents, and provide the directory here'
assert args.test_hf_dir is not None, (
"You must run verl save checkpoint first, with hf_model in checkpoint.contents, and provide the directory here"
)
def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
@ -67,6 +72,7 @@ def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
def upload_model_to_huggingface(hf_path):
# Push to hugging face
from huggingface_hub import HfApi
api = HfApi()
api.create_repo(repo_id=args.hf_upload_path, private=False, exist_ok=True)
api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model")
@ -85,9 +91,9 @@ def convert_fsdp_checkpoints_to_hfmodels():
break
assert world_size, "No model file with the proper format"
state_dict = torch.load(os.path.join(local_dir, f'model_world_size_{world_size}_rank_{rank}.pt'),
map_location='cpu',
weights_only=False)
state_dict = torch.load(
os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt"), map_location="cpu", weights_only=False
)
pivot_key = sorted(list(state_dict.keys()))[0]
weight = state_dict[pivot_key]
@ -99,13 +105,13 @@ def convert_fsdp_checkpoints_to_hfmodels():
else:
# for non-DTensor
mesh = np.array([int(world_size)], dtype=np.int64)
mesh_dim_names = ('fsdp',)
mesh_dim_names = ("fsdp",)
print(f'Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}')
print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")
assert mesh_dim_names in (('fsdp',), ('ddp', 'fsdp')), f'Unsupported mesh_dim_names {mesh_dim_names}'
assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}"
if 'tp' in mesh_dim_names:
if "tp" in mesh_dim_names:
# fsdp * tp
total_shards = mesh.shape[-1] * mesh.shape[-2]
mesh_shape = (mesh.shape[-2], mesh.shape[-1])
@ -114,21 +120,21 @@ def convert_fsdp_checkpoints_to_hfmodels():
total_shards = mesh.shape[-1]
mesh_shape = (mesh.shape[-1],)
print(f'Processing model shards with {total_shards} {mesh_shape} in total')
print(f"Processing model shards with {total_shards} {mesh_shape} in total")
model_state_dict_lst = []
model_state_dict_lst.append(state_dict)
model_state_dict_lst.extend([""] * (total_shards - 1))
def process_one_shard(rank):
model_path = os.path.join(local_dir, f'model_world_size_{world_size}_rank_{rank}.pt')
state_dict = torch.load(model_path, map_location='cpu', weights_only=False)
def process_one_shard(rank, model_state_dict_lst):
model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
model_state_dict_lst[rank] = state_dict
return state_dict
with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
for rank in range(1, total_shards):
executor.submit(process_one_shard, rank)
executor.submit(process_one_shard, rank, model_state_dict_lst)
state_dict = {}
param_placements: Dict[str, List[Placement]] = {}
keys = set(model_state_dict_lst[0].keys())
@ -144,9 +150,7 @@ def convert_fsdp_checkpoints_to_hfmodels():
state_dict[key].append(tensor._local_tensor.bfloat16())
placements = tuple(tensor.placements)
# replicated placement at dp dimension can be discarded
if mesh_dim_names[0] == 'dp':
placements = placements[1:]
elif mesh_dim_names[0] == 'ddp':
if mesh_dim_names[0] == "dp" or mesh_dim_names[0] == "ddp":
placements = placements[1:]
if key not in param_placements:
param_placements[key] = placements
@ -175,27 +179,27 @@ def convert_fsdp_checkpoints_to_hfmodels():
else:
state_dict[key] = torch.cat(state_dict[key], dim=0)
print('Writing to local disk')
print("Writing to local disk")
if args.target_dir is None:
hf_path = os.path.join(local_dir, 'huggingface')
hf_path = os.path.join(local_dir, "huggingface")
else:
hf_path = args.target_dir
config = AutoConfig.from_pretrained(args.hf_model_path)
if 'ForTokenClassification' in config.architectures[0]:
if "ForTokenClassification" in config.architectures[0]:
auto_model = AutoModelForTokenClassification
elif 'ForCausalLM' in config.architectures[0]:
elif "ForCausalLM" in config.architectures[0]:
auto_model = AutoModelForCausalLM
elif 'ForConditionalGeneration' in config.architectures[0]:
elif "ForConditionalGeneration" in config.architectures[0]:
auto_model = AutoModelForVision2Seq
else:
raise NotImplementedError(f'Unknown architecture {config["architectures"]}')
raise NotImplementedError(f"Unknown architecture {config['architectures']}")
with torch.device('meta'):
with torch.device("meta"):
model = auto_model.from_config(config, torch_dtype=torch.bfloat16)
model.to_empty(device='cpu')
model.to_empty(device="cpu")
print(f'Saving model to {hf_path}')
print(f"Saving model to {hf_path}")
model.save_pretrained(hf_path, state_dict=state_dict)
del state_dict
del model
@ -217,7 +221,7 @@ def check_megatron_checkpoint_path(model_path):
for sharded_dir in sharded_dirs:
match = re.match(r"mp_rank_(\d\d)_(\d\d\d)", sharded_dir)
assert match, f"Invalid sharded dir {sharded_dir}"
assert f"model.pt" in os.listdir(os.path.join(model_path, sharded_dir)), f"model.pt not found in {sharded_dir}"
assert "model.pt" in os.listdir(os.path.join(model_path, sharded_dir)), f"model.pt not found in {sharded_dir}"
tp_rank = int(match.group(1))
pp_rank = int(match.group(2))
if tp_size < tp_rank + 1:
@ -228,7 +232,7 @@ def check_megatron_checkpoint_path(model_path):
def convert_megatron_checkpoints_to_hfmodels():
from verl.utils.megatron_utils import get_model_checkpoint_path, get_hf_model_checkpoint_path
from verl.utils.megatron_utils import get_hf_model_checkpoint_path, get_model_checkpoint_path
local_path = args.local_dir
@ -243,11 +247,11 @@ def convert_megatron_checkpoints_to_hfmodels():
for j in range(tp_size):
model_state_dict_lst[i].append("")
print(f'sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {mp_size}')
print(f"sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {mp_size}")
def process_one_shard(shard_dir):
def process_one_shard(shard_dir, model_state_dict_lst):
model_path = os.path.join(model_ckpt_path, shard_dir, "model.pt")
state_dict = torch.load(model_path, map_location='cpu', weights_only=False)
state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
tp_rank, pp_rank = get_tp_pp_rank_from_sharded_dir(shard_dir)
model_state_dict_lst[pp_rank][tp_rank] = state_dict
@ -255,12 +259,12 @@ def convert_megatron_checkpoints_to_hfmodels():
# for rank in range(1, mp_size):
# executor.submit(process_one_shard, sharded_dirs[rank])
for sharded_dir in sharded_dirs:
process_one_shard(sharded_dir)
process_one_shard(sharded_dir, model_state_dict_lst)
state_dict = {}
config = AutoConfig.from_pretrained(args.hf_model_path)
if args.test:
ref_state_dict = load_file(os.path.join(args.test_hf_dir, 'model.safetensors'))
ref_state_dict = load_file(os.path.join(args.test_hf_dir, "model.safetensors"))
def merge_across_tp(key, tp_data):
if "linear_fc1.weight" in key:
@ -274,7 +278,7 @@ def convert_megatron_checkpoints_to_hfmodels():
gate = torch.cat(gate_lst, dim=0)
up = torch.cat(up_lst, dim=0)
tp_data = [gate, up]
elif "self_attention.linear_qkv." in key and 'layer_norm' not in key:
elif "self_attention.linear_qkv." in key and "layer_norm" not in key:
# if the tensor is qkv, for each param on tp, split into q, k, v
# concat q, k, v separately.
q_lst = []
@ -291,7 +295,7 @@ def convert_megatron_checkpoints_to_hfmodels():
split_size = [
kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition,
kv_size_per_tp // num_query_groups_per_partition,
kv_size_per_tp // num_query_groups_per_partition
kv_size_per_tp // num_query_groups_per_partition,
]
q, k, v = chunk.split(split_size)
q_lst.append(q)
@ -323,16 +327,16 @@ def convert_megatron_checkpoints_to_hfmodels():
if "extra_state" in key:
continue
if args.tie_word_embedding and ("output_layer" in key):
print(f'skip lm_head and reward_head loading because of tie_word_embeddings')
print("skip lm_head and reward_head loading because of tie_word_embeddings")
continue
new_key = key
if "decoder.layers." in key:
local_layer_no = int(key.split('.')[2])
local_layer_no = int(key.split(".")[2])
layers_handled = max(local_layer_no, layers_handled)
global_layer_no = local_layer_no + layers_cum
new_key_list = key.split('.')
new_key_list = key.split(".")
new_key_list[2] = str(global_layer_no)
new_key = '.'.join(new_key_list)
new_key = ".".join(new_key_list)
tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)]
merged = merge_across_tp(new_key, tp_data)
@ -340,7 +344,7 @@ def convert_megatron_checkpoints_to_hfmodels():
state_dict[new_key] = merged
elif len(merged) == 3:
# split qkv
for n, d in zip(['q', 'k', 'v'], merged):
for n, d in zip(["q", "k", "v"], merged):
state_dict[new_key.replace("linear_qkv", f"linear_{n}")] = d
elif len(merged) == 2:
# split gate up
@ -370,7 +374,6 @@ def convert_megatron_checkpoints_to_hfmodels():
]
if args.test:
for original_name, loaded_weight in state_dict.items():
name = _replace_name(original_name, params_mapping)
if not name or name.endswith(".bias") and name not in ref_state_dict:
@ -380,31 +383,31 @@ def convert_megatron_checkpoints_to_hfmodels():
if args.tie_word_embedding and "lm_head.weight" in name:
continue
if name not in ref_state_dict:
raise RuntimeError(f'key: {name} not exist in state_dict')
raise RuntimeError(f"key: {name} not exist in state_dict")
param = ref_state_dict[name]
assert loaded_weight.dtype == param.dtype
torch.testing.assert_close(loaded_weight, param, atol=1e-4, rtol=1e-4)
print('Writing to local disk')
print("Writing to local disk")
if args.target_dir is None:
hf_path = os.path.join(args.local_dir, 'huggingface')
hf_path = os.path.join(args.local_dir, "huggingface")
else:
hf_path = args.target_dir
if 'ForTokenClassification' in config.architectures[0]:
if "ForTokenClassification" in config.architectures[0]:
auto_model = AutoModelForTokenClassification
elif 'ForCausalLM' in config.architectures[0]:
elif "ForCausalLM" in config.architectures[0]:
auto_model = AutoModelForCausalLM
elif 'ForConditionalGeneration' in config.architectures[0]:
elif "ForConditionalGeneration" in config.architectures[0]:
auto_model = AutoModelForVision2Seq
else:
raise NotImplementedError(f'Unknown architecture {config["architectures"]}')
raise NotImplementedError(f"Unknown architecture {config['architectures']}")
with torch.device('meta'):
with torch.device("meta"):
model = auto_model.from_config(config, torch_dtype=torch.bfloat16)
model.to_empty(device='cpu')
model.to_empty(device="cpu")
print(f'Saving model to {hf_path}')
print(f"Saving model to {hf_path}")
model.save_pretrained(hf_path, state_dict=state_dict)
del state_dict
del model
@ -435,7 +438,7 @@ def _replace_name(megatron_name, name_mapping):
return param_name
if __name__ == '__main__':
if __name__ == "__main__":
if args.backend == "fsdp":
convert_fsdp_checkpoints_to_hfmodels()
elif args.backend == "megatron":
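Not every change in the checkpoint-merger script above is cosmetic: `process_one_shard` now takes `model_state_dict_lst` as an explicit argument instead of closing over it, in both the FSDP and the Megatron paths. A minimal sketch of that pattern with toy data; the function and list names mirror the diff, the payloads are invented:

```python
from concurrent.futures import ThreadPoolExecutor


def process_one_shard(rank: int, model_state_dict_lst: list) -> None:
    # The shared list is passed in explicitly rather than captured from the
    # enclosing scope, and each worker writes into its own pre-allocated slot.
    model_state_dict_lst[rank] = {"rank": rank, "weights": [rank] * 3}


total_shards = 4
model_state_dict_lst = [None] * total_shards
model_state_dict_lst[0] = {"rank": 0, "weights": [0] * 3}  # rank 0 is loaded in the main thread

with ThreadPoolExecutor(max_workers=min(4, total_shards)) as executor:
    for rank in range(1, total_shards):
        executor.submit(process_one_shard, rank, model_state_dict_lst)

print(model_state_dict_lst)
```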

100
setup.py
View File

@ -13,75 +13,79 @@
# limitations under the License.
# setup.py is the fallback installation script when pyproject.toml does not work
from setuptools import setup, find_packages
import os
from pathlib import Path
from setuptools import find_packages, setup
version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))
with open(os.path.join(version_folder, 'verl/version/version')) as f:
with open(os.path.join(version_folder, "verl/version/version")) as f:
__version__ = f.read().strip()
install_requires = [
'accelerate',
'codetiming',
'datasets',
'dill',
'hydra-core',
'numpy',
'pandas',
'datasets',
'peft',
'pyarrow>=15.0.0',
'pybind11',
'pylatexenc',
'ray[default]>=2.10',
'tensordict<=0.6.2',
'torchdata',
'transformers',
'wandb',
"accelerate",
"codetiming",
"datasets",
"dill",
"hydra-core",
"numpy",
"pandas",
"datasets",
"peft",
"pyarrow>=15.0.0",
"pybind11",
"pylatexenc",
"ray[default]>=2.10",
"tensordict<=0.6.2",
"torchdata",
"transformers",
"wandb",
]
TEST_REQUIRES = ['pytest', 'yapf', 'py-spy']
PRIME_REQUIRES = ['pyext']
GEO_REQUIRES = ['mathruler']
GPU_REQUIRES = ['liger-kernel', 'flash-attn']
MATH_REQUIRES = ['math-verify'] # Add math-verify as an optional dependency
VLLM_REQUIRES = ['tensordict<=0.6.2', 'vllm<=0.8.2']
TEST_REQUIRES = ["pytest", "pre-commit", "py-spy"]
PRIME_REQUIRES = ["pyext"]
GEO_REQUIRES = ["mathruler"]
GPU_REQUIRES = ["liger-kernel", "flash-attn"]
MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency
VLLM_REQUIRES = ["tensordict<=0.6.2", "vllm<=0.8.2"]
SGLANG_REQUIRES = [
'tensordict<=0.6.2',
'sglang[all]==0.4.4.post4',
'torch-memory-saver>=0.0.5'
"tensordict<=0.6.2",
"sglang[all]==0.4.4.post4",
"torch-memory-saver>=0.0.5",
]
extras_require = {
'test': TEST_REQUIRES,
'prime': PRIME_REQUIRES,
'geo': GEO_REQUIRES,
'gpu': GPU_REQUIRES,
'math': MATH_REQUIRES,
'vllm': VLLM_REQUIRES,
'sglang': SGLANG_REQUIRES,
"test": TEST_REQUIRES,
"prime": PRIME_REQUIRES,
"geo": GEO_REQUIRES,
"gpu": GPU_REQUIRES,
"math": MATH_REQUIRES,
"vllm": VLLM_REQUIRES,
"sglang": SGLANG_REQUIRES,
}
from pathlib import Path
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()
setup(
name='verl',
name="verl",
version=__version__,
package_dir={'': '.'},
packages=find_packages(where='.'),
url='https://github.com/volcengine/verl',
license='Apache 2.0',
author='Bytedance - Seed - MLSys',
author_email='zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk',
description='verl: Volcano Engine Reinforcement Learning for LLM',
package_dir={"": "."},
packages=find_packages(where="."),
url="https://github.com/volcengine/verl",
license="Apache 2.0",
author="Bytedance - Seed - MLSys",
author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk",
description="verl: Volcano Engine Reinforcement Learning for LLM",
install_requires=install_requires,
extras_require=extras_require,
package_data={'': ['version/*'],
'verl': ['trainer/config/*.yaml'],},
package_data={
"": ["version/*"],
"verl": ["trainer/config/*.yaml"],
},
include_package_data=True,
long_description=long_description,
long_description_content_type='text/markdown'
)
long_description_content_type="text/markdown",
)

View File

@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.

View File

@ -12,65 +12,65 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import shutil
import tempfile
import torch
import copy
import torch.distributed
from torch.distributed import init_device_mesh
from verl.utils.distributed import initialize_global_process_group
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import Qwen2Config
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, \
CPUOffload
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.distributed import initialize_global_process_group
def test_fsdp_ckpt():
assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
local_rank, rank, world_size = initialize_global_process_group()
device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=('dp',))
device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",))
model_name = 'Qwen/Qwen2.5-0.5B-Instruct'
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
config = Qwen2Config(num_hidden_layers=1)
with torch.device('cuda'):
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2')
model = model.to(device='cuda')
with torch.device("cuda"):
model = AutoModelForCausalLM.from_config(
config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model = model.to(device="cuda")
# Wrap model with FSDP
mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)
model = FSDP(model,
use_orig_params=False,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
device_mesh=device_mesh)
model = FSDP(
model,
use_orig_params=False,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
device_mesh=device_mesh,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
# Create checkpoint manager
tokenizer = AutoTokenizer.from_pretrained(model_name)
checkpoint_manager = FSDPCheckpointManager(model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
tokenizer=tokenizer)
checkpoint_manager = FSDPCheckpointManager(
model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer
)
# Generate sample input
batch_size = 2
seq_len = 32
vocab_size = 32000
# First input for initial update
input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda')
input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
attention_mask1 = torch.ones_like(input_ids1)
# Second input for verification
input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda')
input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
attention_mask2 = torch.ones_like(input_ids2)
# Step 1: Initial update and save checkpoint
@ -83,7 +83,7 @@ def test_fsdp_ckpt():
# Save checkpoint after first update
temp_dir = tempfile.mkdtemp()
checkpoint_path = os.path.join(temp_dir, 'checkpoint')
checkpoint_path = os.path.join(temp_dir, "checkpoint")
checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0)
# Step 2: Second update and forward pass
@ -122,5 +122,5 @@ def test_fsdp_ckpt():
torch.distributed.barrier()
if __name__ == '__main__':
if __name__ == "__main__":
test_fsdp_ckpt()

View File

@ -14,104 +14,108 @@
import os
os.environ['NCCL_DEBUG'] = 'WARN'
os.environ["NCCL_DEBUG"] = "WARN"
from verl.protocol import all_gather_data_proto, DataProto
from verl.utils.distributed import initialize_global_process_group
import numpy as np
import torch
import torch.distributed
import numpy as np
from verl.protocol import DataProto, all_gather_data_proto
from verl.utils.distributed import initialize_global_process_group
def test_all_gather_data_proto():
device_mesh = torch.distributed.device_mesh.init_device_mesh('cuda', mesh_shape=[2, 2], mesh_dim_names=['dp', 'tp'])
device_mesh = torch.distributed.device_mesh.init_device_mesh("cuda", mesh_shape=[2, 2], mesh_dim_names=["dp", "tp"])
global_rank = torch.distributed.get_rank()
obs = torch.tensor([[1 * global_rank, 2 * global_rank + 1], [3 * global_rank, 4 * global_rank + 1]])
labels = ['a', 'b'] if global_rank % 2 == 0 else ['b', 'a']
labels = ["a", "b"] if global_rank % 2 == 0 else ["b", "a"]
labels = np.array(labels, dtype=object)
data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'})
data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})
all_gather_data_proto(data=data, process_group=device_mesh.get_group('dp'))
all_gather_data_proto(data=data, process_group=device_mesh.get_group("dp"))
if global_rank == 0:
expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device='cuda')
expected_labels = ['a', 'b', 'a', 'b']
expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device="cuda")
expected_labels = ["a", "b", "a", "b"]
elif global_rank == 1:
expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device='cuda')
expected_labels = ['b', 'a', 'b', 'a']
expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device="cuda")
expected_labels = ["b", "a", "b", "a"]
elif global_rank == 2:
expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device='cuda')
expected_labels = ['a', 'b', 'a', 'b']
expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device="cuda")
expected_labels = ["a", "b", "a", "b"]
elif global_rank == 3:
expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device='cuda')
expected_labels = ['b', 'a', 'b', 'a']
expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device="cuda")
expected_labels = ["b", "a", "b", "a"]
torch.testing.assert_close(data.batch['obs'], expected_obs, atol=0, rtol=0)
assert (data.non_tensor_batch['labels'] == expected_labels).all()
assert data.meta_info == {'info': 'test_info'}
torch.testing.assert_close(data.batch["obs"], expected_obs, atol=0, rtol=0)
assert (data.non_tensor_batch["labels"] == expected_labels).all()
assert data.meta_info == {"info": "test_info"}
def test_vocab_parallel_entropy():
from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.torch_functional import entropy_from_logits
from megatron.core import parallel_state as mpu
mpu.initialize_model_parallel(tensor_model_parallel_size=2,
pipeline_model_parallel_size=1,
virtual_pipeline_model_parallel_size=None)
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy
from verl.utils.torch_functional import entropy_from_logits
mpu.initialize_model_parallel(
tensor_model_parallel_size=2, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None
)
batch_size = 2
seqlen = 128
vocab_size = 155136
logits = torch.randn(batch_size * seqlen, vocab_size, device='cuda', requires_grad=True)
target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device='cuda', dtype=torch.int64)
logits = torch.randn(batch_size * seqlen, vocab_size, device="cuda", requires_grad=True)
target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device="cuda", dtype=torch.int64)
# broadcast across tp
torch.distributed.broadcast(logits,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
torch.distributed.broadcast(target,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
torch.distributed.broadcast(
logits, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()
)
torch.distributed.broadcast(
target, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()
)
tp_rank = mpu.get_tensor_model_parallel_rank()
vocab_size_per_tp = vocab_size // mpu.get_tensor_model_parallel_world_size()
# get the local logits of each tp
vocab_parallel_logits = logits.clone().detach()[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) *
vocab_size_per_tp].requires_grad_()
vocab_parallel_logits = (
logits.clone().detach()[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp].requires_grad_()
)
logits.grad = None
vocab_parallel_logits.grad = None
log_gpu_memory_usage('begin')
log_gpu_memory_usage("begin")
output_entropy = vocab_parallel_entropy(vocab_parallel_logits)
log_gpu_memory_usage('after forward')
log_gpu_memory_usage("after forward")
grad_output = torch.randn_like(output_entropy)
output_entropy.backward(grad_output)
log_gpu_memory_usage('after backward')
log_gpu_memory_usage("after backward")
target_entropy = entropy_from_logits(logits)
torch.testing.assert_close(output_entropy, target_entropy)
target_entropy.backward(grad_output)
torch.testing.assert_close(logits.grad[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) * vocab_size_per_tp],
vocab_parallel_logits.grad)
torch.testing.assert_close(
logits.grad[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits.grad
)
# make sure logits is not altered
torch.testing.assert_close(logits[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) * vocab_size_per_tp],
vocab_parallel_logits)
torch.testing.assert_close(
logits[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits
)
if mpu.get_tensor_model_parallel_rank() == 0:
print('test_vocab_parallel_entropy passes')
print("test_vocab_parallel_entropy passes")
mpu.destroy_model_parallel()
if __name__ == '__main__':
if __name__ == "__main__":
local_rank, rank, world_size = initialize_global_process_group()
test_all_gather_data_proto()
test_vocab_parallel_entropy()
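One formatting detail worth noting in the entropy test above: `ruff format` puts spaces around the `:` of a slice when its bounds are expressions, which is why `tp_rank * vocab_size_per_tp:(tp_rank + 1) * vocab_size_per_tp` became `tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp`. A tiny illustration of the same slice in that style; the sizes are made up:

```python
vocab_size, tp_size, tp_rank = 12, 4, 1
vocab_size_per_tp = vocab_size // tp_size  # 3 entries per tensor-parallel rank

values = list(range(vocab_size))
# Spaces around ":" because both slice bounds are arithmetic expressions.
local_shard = values[tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp]
print(local_shard)  # [3, 4, 5]
```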

View File

@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.e2e.envs.digit_completion import DigitCompletion, generate_ground_truth_response
from torch.utils import data
import os
if __name__ == '__main__':
from torch.utils import data
from tests.e2e.envs.digit_completion import DigitCompletion
if __name__ == "__main__":
simple_task = DigitCompletion(max_number=9, max_diff=9, max_num_in_response=9)
all_prompts = simple_task.get_all_prompts()
@ -25,15 +27,13 @@ if __name__ == '__main__':
train_data = list(train_data)
test_data = list(test_data)
train_data = [[{'role': 'user', 'content': str(item)}] \
for item in train_data]
test_data = [[{'role': 'user', 'content': str(item)}] \
for item in test_data]
train_data = [[{"role": "user", "content": str(item)}] for item in train_data]
test_data = [[{"role": "user", "content": str(item)}] for item in test_data]
print(f'Size of train: {len(train_data)}, size of test: {len(test_data)}')
print(f"Size of train: {len(train_data)}, size of test: {len(test_data)}")
train_data = {'prompt': train_data}
test_data = {'prompt': test_data}
train_data = {"prompt": train_data}
test_data = {"prompt": test_data}
model_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)))
@ -42,5 +42,5 @@ if __name__ == '__main__':
train_data_frame = pd.DataFrame(train_data)
test_data_frame = pd.DataFrame(test_data)
train_data_frame.to_parquet(os.path.join(model_folder, 'train.parquet'))
test_data_frame.to_parquet(os.path.join(model_folder, 'test.parquet'))
train_data_frame.to_parquet(os.path.join(model_folder, "train.parquet"))
test_data_frame.to_parquet(os.path.join(model_folder, "test.parquet"))

View File

@ -15,28 +15,30 @@
Create a random model and tokenizer for PPO training
"""
import torch
import os
from transformers import AutoModelForCausalLM, LlamaConfig, AutoTokenizer
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaConfig
from tests.e2e.envs.digit_completion import CharTokenizer
tokenizer = CharTokenizer(
characters=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ',', ':'],
characters=["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ",", ":"],
model_max_length=2048,
chat_template=
"{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set role = message['role'] %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ sep_token }}{% endif %}"
chat_template="{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set role = message['role'] %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ sep_token }}{% endif %}",
)
config = LlamaConfig(vocab_size=(tokenizer.vocab_size + 16 - 1) // 16 * 16,
hidden_size=128,
intermediate_size=344,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=4,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id)
config = LlamaConfig(
vocab_size=(tokenizer.vocab_size + 16 - 1) // 16 * 16,
hidden_size=128,
intermediate_size=344,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=4,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
)
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
@ -50,12 +52,11 @@ tokenizer.save_pretrained(tokenizer_folder)
load_tokenizer = AutoTokenizer.from_pretrained(tokenizer_folder)
chat = [{'role': 'user', 'content': '1,0:2,3'}]
chat = [{"role": "user", "content": "1,0:2,3"}]
load_tokenizer.padding_side = 'left'
load_tokenizer.padding_side = "left"
print(
load_tokenizer.apply_chat_template(chat,
tokenize=True,
add_generation_prompt=True,
max_length=10,
padding='max_length'))
load_tokenizer.apply_chat_template(
chat, tokenize=True, add_generation_prompt=True, max_length=10, padding="max_length"
)
)

View File

@ -14,54 +14,55 @@
"""
Using FSDPTrainer
"""
import os
import hydra
import ray
import torch
from transformers import PreTrainedTokenizer, AutoTokenizer
from transformers import AutoTokenizer
from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.utils.fs import copy_to_local
from tests.e2e.envs.digit_completion import CharTokenizer
def make_reward_function(tokenizer, num_examine):
def arithmetic_sequence_reward_function(data: DataProto, return_dict: bool = False):
from tests.e2e.envs.digit_completion.task import compute_reward
reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
for i in range(data.batch.batch_size[0]):
data_item = data[i] # DataProtoItem
prompt_ids = data_item.batch['prompts']
prompt_ids = data_item.batch["prompts"]
prompt_length = prompt_ids.shape[-1]
# extract raw prompt
valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
valid_prompt_ids = prompt_ids[-valid_prompt_length:]
# extract response
response_ids = data_item.batch['responses']
response_ids = data_item.batch["responses"]
response_length = response_ids.shape[-1]
response_mask = data.batch['attention_mask'][i][-response_length:]
valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
response_mask = data.batch["attention_mask"][i][-response_length:]
valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
valid_response_ids = response_ids[:valid_response_length]
# decode
prompt = tokenizer.decode(valid_prompt_ids)
response = tokenizer.decode(valid_response_ids)
# remove bos and eos
prompt = prompt.replace(tokenizer.sep_token, '')
response = response.replace(tokenizer.eos_token, '')
prompt = prompt.replace(tokenizer.sep_token, "")
response = response.replace(tokenizer.eos_token, "")
if i < num_examine:
print(prompt, response)
reward_output = compute_reward(prompt, response)
dense_reward = reward_output[0].tolist()
ground_truth_response = reward_output[1]['ground_truth_response']
ground_truth_response = reward_output[1]["ground_truth_response"]
if len(dense_reward) > 0:
last_reward = dense_reward[-1]
else:
@ -85,26 +86,29 @@ def make_reward_function(tokenizer, num_examine):
return arithmetic_sequence_reward_function
@hydra.main(config_path='../../../../verl/trainer/config', config_name='ppo_trainer', version_base=None)
@hydra.main(config_path="../../../../verl/trainer/config", config_name="ppo_trainer", version_base=None)
def main(config):
ray.init(
runtime_env={
'env_vars': {
'MEGATRON_USE_CUDA_TIMER': '0',
'MEGATRON_START_PROCESS_TIMER': 'False',
'TOKENIZERS_PARALLELISM': 'true',
'NCCL_DEBUG': 'WARN'
"env_vars": {
"MEGATRON_USE_CUDA_TIMER": "0",
"MEGATRON_START_PROCESS_TIMER": "False",
"TOKENIZERS_PARALLELISM": "true",
"NCCL_DEBUG": "WARN",
}
})
}
)
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
# print the config
# print initial config
print('Config after normalizing batch_size')
print("Config after normalizing batch_size")
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
# download the checkpoint from hdfs
@ -112,18 +116,18 @@ def main(config):
local_path = os.path.expanduser(local_path)
# instantiate tokenizern
tokenizer = AutoTokenizer.from_pretrained(local_path)
print(f'Tokenizer vocab_size: {tokenizer.vocab_size}')
print(f"Tokenizer vocab_size: {tokenizer.vocab_size}")
# define worker classes
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
role_worker_mapping = {
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
Role.Critic: ray.remote(CriticWorker),
}
global_pool_id = 'global_pool'
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
@ -141,15 +145,17 @@ def main(config):
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayPPOTrainer(config=config,
tokenizer=tokenizer,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
reward_fn=reward_fn,
val_reward_fn=reward_fn)
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
reward_fn=reward_fn,
val_reward_fn=reward_fn,
)
trainer.init_workers()
trainer.fit()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -16,17 +16,17 @@ import argparse
def check_congratulations_in_file(output_file):
with open(output_file, 'r') as f:
with open(output_file) as f:
output = f.read()
success_message = "Congratulations!!! You have called my_reward_function successfully!!!"
assert success_message in output, f'Success message of my_reward_function not found in {output_file}'
assert success_message in output, f"Success message of my_reward_function not found in {output_file}"
print("Check passes")
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--output_file', required=True, type=str)
parser.add_argument("--output_file", required=True, type=str)
args = parser.parse_args()

View File

@ -20,10 +20,10 @@ import numpy as np
def extract_reward_from_line(line):
# TODO: this function needs error handling
try:
key_vals = line.split(' - ')
key_vals = line.split(" - ")
for key_val in key_vals:
key, val = key_val.split(':')
if key == 'critic/rewards/mean':
key, val = key_val.split(":")
if key == "critic/rewards/mean":
reward = float(val)
return reward
return -np.inf
@ -31,23 +31,23 @@ def extract_reward_from_line(line):
return -np.inf
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--output_file', required=True, type=str)
parser.add_argument('--target', type=float, default=0.2, help='target reward score')
parser.add_argument("--output_file", required=True, type=str)
parser.add_argument("--target", type=float, default=0.2, help="target reward score")
args = parser.parse_args()
with open(args.output_file, 'r') as f:
output = f.read().split('\n')
with open(args.output_file) as f:
output = f.read().split("\n")
best_reward = -np.inf
for line in output:
if line.startswith('step'):
if line.startswith("step"):
reward = extract_reward_from_line(line)
if reward > best_reward:
best_reward = reward
print(f'Best reward is {best_reward}')
assert best_reward > args.target, f'Best reward must be greater than {args.target}. best_reward: {best_reward}'
print('Check passes')
print(f"Best reward is {best_reward}")
assert best_reward > args.target, f"Best reward must be greater than {args.target}. best_reward: {best_reward}"
print("Check passes")

View File

@ -14,4 +14,4 @@
from .digit_completion import DigitCompletion
__all__ = ['DigitCompletion']
__all__ = ["DigitCompletion"]

View File

@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import AutoTokenizer, LlamaConfig
from .task import DigitCompletion, generate_ground_truth_response
from .tokenizer import CharTokenizer
from transformers import AutoTokenizer, LlamaConfig
AutoTokenizer.register(LlamaConfig, CharTokenizer, exist_ok=True)
__all__ = ['DigitCompletion', 'generate_ground_truth_response', 'CharTokenizer']
__all__ = ["DigitCompletion", "generate_ground_truth_response", "CharTokenizer"]

View File

@ -16,7 +16,7 @@
import numpy as np
class DigitCompletion(object):
class DigitCompletion:
"""
The implementation of a simple digit completion task.
The prompt is a sequence of numbers with a fixed difference. The task is to complete the next N numbers.
@ -54,16 +54,18 @@ class DigitCompletion(object):
self.np_rng = np.random.default_rng(seed=seed)
def __str__(self):
return f'Prompt length: {self.prompt_length}. Response length: {self.response_length}, ' \
f'Max number: {self.max_number}. Max diff: {self.max_diff}, ' \
f'Max number in response: {self.max_num_in_response}'
return (
f"Prompt length: {self.prompt_length}. Response length: {self.response_length}, "
f"Max number: {self.max_number}. Max diff: {self.max_diff}, "
f"Max number in response: {self.max_num_in_response}"
)
def get_state(self):
return {'rng': self.np_rng}
return {"rng": self.np_rng}
def set_state(self, state):
assert 'rng' in state, 'rng must be inside state'
self.np_rng = state['rng']
assert "rng" in state, "rng must be inside state"
self.np_rng = state["rng"]
@property
def prompt_length(self):
@ -84,7 +86,7 @@ class DigitCompletion(object):
for diff in range(0, self.max_diff + 1):
second_num = self.add(first_num, diff)
for num_to_complete in range(self.max_num_in_response + 1):
prompt = str(first_num) + ',' + str(second_num) + f':{self.max_number},{num_to_complete}'
prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}"
all_prompts.append(prompt)
return all_prompts
@ -94,7 +96,7 @@ class DigitCompletion(object):
diff = self.np_rng.integers(self.max_diff + 1)
second_num = self.add(first_num, diff)
num_to_complete = self.np_rng.integers(self.max_num_in_response + 1)
prompt = str(first_num) + ',' + str(second_num) + f':{self.max_number},{num_to_complete}'
prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}"
return prompt
def sample_batch_str_prompts(self, batch_size):
@ -116,9 +118,9 @@ def compute_position_id_with_mask(mask):
def generate_ground_truth_response(prompt: str):
"""Generate ground truth response given a prompt."""
num, info = prompt.split(':')
num1, num2 = num.split(',')
max_number, num_to_gen = info.split(',')
num, info = prompt.split(":")
num1, num2 = num.split(",")
max_number, num_to_gen = info.split(",")
num1 = int(num1)
num2 = int(num2)
max_number = int(max_number)
@ -130,11 +132,11 @@ def generate_ground_truth_response(prompt: str):
curr = (last_num + diff) % max_number
results.append(str(curr))
last_num = curr
response = ','.join(results)
response = ",".join(results)
return response
def compute_reward(prompt: str, response: str, sequence_reward=1.):
def compute_reward(prompt: str, response: str, sequence_reward=1.0):
"""We compute dense reward here so that we can directly train RL without SFT"""
response_length = len(response)
ground_truth_response = generate_ground_truth_response(prompt)
@ -157,21 +159,21 @@ def compute_reward(prompt: str, response: str, sequence_reward=1.):
# no matches
break
return reward, {'ground_truth_response': ground_truth_response}
return reward, {"ground_truth_response": ground_truth_response}
if __name__ == '__main__':
if __name__ == "__main__":
task = DigitCompletion(max_number=20, max_diff=3, max_num_in_response=5)
print(task.sample_str_prompts())
prompt = '7,8:20,0'
response = ''
prompt = "7,8:20,0"
response = ""
print(compute_reward(prompt, response))
prompt = '7,8:20,0'
response = 'E000'
prompt = "7,8:20,0"
response = "E000"
print(compute_reward(prompt, response))
prompt = '9,10:20,2'
response = '11,12,13'
prompt = "9,10:20,2"
response = "11,12,13"
print(compute_reward(prompt, response))
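For reference, a minimal standalone sketch (not part of this diff; the helper name toy_ground_truth is hypothetical) of how a prompt of the form "first,second:max_number,num_to_complete" maps to its ground-truth continuation, mirroring the parsing shown in generate_ground_truth_response above:

    def toy_ground_truth(prompt: str) -> str:
        # parse "9,10:20,2": two seed numbers, the modulus, and how many numbers to complete
        nums, info = prompt.split(":")
        first, second = (int(x) for x in nums.split(","))
        max_number, num_to_gen = (int(x) for x in info.split(","))
        diff, last, results = second - first, second, []
        for _ in range(num_to_gen):
            last = (last + diff) % max_number
            results.append(str(last))
        return ",".join(results)

    assert toy_ground_truth("9,10:20,2") == "11,12"
    assert toy_ground_truth("7,8:20,0") == ""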

View File

@ -27,7 +27,6 @@ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
class CharTokenizer(PreTrainedTokenizer):
def __init__(self, characters: Sequence[str], model_max_length: int, chat_template, **kwargs):
"""Character tokenizer for Hugging Face transformers.
@ -47,10 +46,10 @@ class CharTokenizer(PreTrainedTokenizer):
model_max_length (int): Model maximum sequence length.
"""
eos_token_str = 'E'
sep_token_str = 'S'
pad_token_str = 'P'
unk_token_str = 'U'
eos_token_str = "E"
sep_token_str = "S"
pad_token_str = "P"
unk_token_str = "U"
self.characters = characters
self.model_max_length = model_max_length
@ -64,9 +63,7 @@ class CharTokenizer(PreTrainedTokenizer):
eos_token_str: 1,
pad_token_str: 2,
unk_token_str: 3,
**{
ch: i + 4 for i, ch in enumerate(characters)
},
**{ch: i + 4 for i, ch in enumerate(characters)},
}
self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
@ -101,9 +98,9 @@ class CharTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens):
return "".join(tokens)
def build_inputs_with_special_tokens(self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
sep = [self.sep_token_id]
cls = [self.cls_token_id]
result = cls + token_ids_0 + sep
@ -133,11 +130,11 @@ class CharTokenizer(PreTrainedTokenizer):
return {
"char_ords": [ord(ch) for ch in self.characters],
"model_max_length": self.model_max_length,
"chat_template": self.chat_template
"chat_template": self.chat_template,
}
@classmethod
def from_config(cls, config: Dict) -> "DigitCompletionTokenizer":
def from_config(cls, config: Dict):
cfg = {}
cfg["characters"] = [chr(i) for i in config["char_ords"]]
cfg["model_max_length"] = config["model_max_length"]

View File

@ -15,14 +15,15 @@
import torch
import torch.distributed
from tensordict import TensorDict
from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer
from torch.distributed.device_mesh import init_device_mesh
from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer
from verl.utils.distributed import initialize_global_process_group
def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = 4):
"""Test consistency between original forward pass and SP+rmpad forward passes.
Args:
trainer: The FSDPSFTTrainer instance to test
total_steps: Number of steps to test (default: 4)
@ -88,28 +89,28 @@ def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int =
def create_trainer(config):
"""Create and initialize a trainer instance with the given config.
Args:
config: Configuration object with training parameters
Returns:
FSDPSFTTrainer: Initialized trainer instance
"""
local_rank, rank, world_size = initialize_global_process_group()
device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',))
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
dp_size = world_size // config.ulysses_sequence_parallel_size
ulysses_device_mesh = init_device_mesh(device_type='cuda',
mesh_shape=(dp_size, config.ulysses_sequence_parallel_size),
mesh_dim_names=('dp', 'sp'))
ulysses_device_mesh = init_device_mesh(
device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp")
)
return FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh)
def main(config):
"""Main function to run trainer tests.
Args:
config: Configuration object with training parameters
"""
@ -117,7 +118,7 @@ def main(config):
test_trainer_forward_consistency(trainer)
if __name__ == '__main__':
if __name__ == "__main__":
import hydra
from omegaconf import DictConfig

View File

@ -17,20 +17,23 @@ Test memory buffers
- We wrap one of the models with a memory buffer and then compare the parameters
"""
import torch
import gc
from transformers import LlamaModel, LlamaConfig
import torch
from transformers import LlamaConfig, LlamaModel
from verl.utils.memory_buffer import MemoryBufferModuleWrapper
def test_memory_buffers():
llama_config = LlamaConfig(vocab_size=256,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=2,
num_attention_heads=16,
num_key_value_heads=16)
llama_config = LlamaConfig(
vocab_size=256,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=2,
num_attention_heads=16,
num_key_value_heads=16,
)
model = LlamaModel(config=llama_config).cuda()
model_copy = LlamaModel(config=llama_config).cuda()
@ -45,7 +48,7 @@ def test_memory_buffers():
r_before = torch.cuda.memory_reserved(0) / norm_factor
a_before = torch.cuda.memory_allocated(0) / norm_factor
print(f'Before Total memory: {t_before} GB, reserved: {r_before} GB, allocated: {a_before} GB')
print(f"Before Total memory: {t_before} GB, reserved: {r_before} GB, allocated: {a_before} GB")
model_wrapper = MemoryBufferModuleWrapper(model)
@ -56,15 +59,15 @@ def test_memory_buffers():
gc.collect()
torch.cuda.empty_cache()
print(f'After Total memory: {t} GB, reserved: {r} GB, allocated: {a} GB')
print(f"After Total memory: {t} GB, reserved: {r} GB, allocated: {a} GB")
change_ratio = (a - a_before) / a_before
assert change_ratio < 0.01, f'make sure the allocated change is less than 1%, Got {change_ratio}'
assert change_ratio < 0.01, f"make sure the allocated change is less than 1%, Got {change_ratio}"
for (name1, param1), (name2, param2) in zip(model.named_parameters(), model_copy.named_parameters()):
assert name1 == name2
assert torch.eq(param1.data, param2.data).all(), f'{param1.data}, {param2.data}, {name1}'
assert torch.eq(param1.data, param2.data).all(), f"{param1.data}, {param2.data}, {name1}"
if __name__ == '__main__':
if __name__ == "__main__":
test_memory_buffers()
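A minimal sketch of the measurement pattern this test relies on (assumes a CUDA device is available; the helper name is made up for illustration): read torch.cuda.memory_allocated before and after an operation and bound the relative change.

    import torch

    def allocated_change_ratio(fn, device=0):
        # relative change in allocated CUDA memory caused by fn()
        before = torch.cuda.memory_allocated(device)
        fn()
        torch.cuda.empty_cache()
        after = torch.cuda.memory_allocated(device)
        return (after - before) / max(before, 1)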

View File

@ -14,33 +14,31 @@
def test_flash_attn_cross_entropy():
from verl.utils.torch_functional import logprobs_from_logits_naive
from verl.utils.debug import log_gpu_memory_usage
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
import torch
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
from torch import nn
log_gpu_memory_usage('At start')
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.torch_functional import logprobs_from_logits_naive
hidden_states = torch.randn(size=(2048, 5120), device='cuda', requires_grad=True, dtype=torch.bfloat16)
log_gpu_memory_usage("At start")
linear = nn.Linear(in_features=5120, out_features=155136, bias=False, device='cuda', dtype=torch.bfloat16)
hidden_states = torch.randn(size=(2048, 5120), device="cuda", requires_grad=True, dtype=torch.bfloat16)
linear = nn.Linear(in_features=5120, out_features=155136, bias=False, device="cuda", dtype=torch.bfloat16)
logits = linear(hidden_states)
# logits = logits.float()
labels = torch.randint(low=0, high=155136, size=(2048,), device='cuda')
labels = torch.randint(low=0, high=155136, size=(2048,), device="cuda")
log_gpu_memory_usage('before computation')
log_gpu_memory_usage("before computation")
# output = checkpoint.checkpoint(logprobs_from_logits, logits, labels, use_reentrant=True)
output = -cross_entropy_loss(logits, labels)[0]
# output = logprobs_from_logits(logits, labels)
log_gpu_memory_usage('After forward')
log_gpu_memory_usage("After forward")
output.sum().backward()
log_gpu_memory_usage('After backward')
log_gpu_memory_usage("After backward")
groundtruth = logprobs_from_logits_naive(logits.float(), labels)
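For context, a sketch of the standard gather-based formulation that a naive log-prob reference like logprobs_from_logits_naive typically follows (this is not verl's exact implementation), which is what the Triton cross_entropy_loss output is checked against:

    import torch

    def naive_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # log-softmax over the vocabulary, then pick the log-prob of each label token
        logp = torch.log_softmax(logits, dim=-1)
        return torch.gather(logp, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

    logits = torch.randn(4, 11)
    labels = torch.randint(0, 11, (4,))
    ref = naive_logprobs(logits, labels)  # per-token log-probs; equals -cross_entropy(..., reduction="none")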

View File

@ -12,39 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from verl.utils.model import create_random_mask
from flash_attn.bert_padding import unpad_input
import torch
import pytest
import torch
from flash_attn.bert_padding import unpad_input
from verl.utils.model import create_random_mask
def test_log_probs_from_logits_response_rmpad():
from verl.utils.torch_functional import log_probs_from_logits_response, log_probs_from_logits_response_rmpad
vocab_size = 32000
batch_size = 2
prompt_length = 256
response_length = 256
input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, prompt_length + response_length), device='cuda')
attention_mask = create_random_mask(input_ids=input_ids,
max_ratio_of_left_padding=0.2,
max_ratio_of_valid_token=0.8,
min_ratio_of_valid_token=0.6)
input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, prompt_length + response_length), device="cuda")
attention_mask = create_random_mask(
input_ids=input_ids, max_ratio_of_left_padding=0.2, max_ratio_of_valid_token=0.8, min_ratio_of_valid_token=0.6
)
response_mask = attention_mask[:, -response_length:]
assert torch.all(response_mask[:, 0] == 1)
logits = torch.randn(batch_size, prompt_length + response_length, vocab_size, device='cuda')
logits = torch.randn(batch_size, prompt_length + response_length, vocab_size, device="cuda")
logits_rmpad = unpad_input(logits, attention_mask)[0]
expected_output = log_probs_from_logits_response(input_ids=input_ids,
logits=logits,
response_length=response_length)
actual_output = log_probs_from_logits_response_rmpad(input_ids=input_ids,
attention_mask=attention_mask,
logits_rmpad=logits_rmpad,
response_length=response_length)
expected_output = log_probs_from_logits_response(
input_ids=input_ids, logits=logits, response_length=response_length
)
actual_output = log_probs_from_logits_response_rmpad(
input_ids=input_ids, attention_mask=attention_mask, logits_rmpad=logits_rmpad, response_length=response_length
)
# This should bitwise align as this operation only contains gather operators
assert torch.all(torch.eq(actual_output * response_mask, expected_output * response_mask))
@ -52,13 +52,14 @@ def test_log_probs_from_logits_response_rmpad():
@pytest.mark.parametrize("dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16])
def test_logprobs_from_logits_v2(dtype):
from verl.utils.torch_functional import logprobs_from_logits_v2, logprobs_from_logits_naive
from verl.utils.torch_functional import logprobs_from_logits_naive, logprobs_from_logits_v2
vocab_size = 32000
batch_size = 2
seq_len = 512
labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device='cuda')
logits = torch.randn(batch_size, seq_len, vocab_size, device='cuda', dtype=dtype)
labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device="cuda")
logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda", dtype=dtype)
expected_output = logprobs_from_logits_naive(labels=labels, logits=logits)
actual_output = logprobs_from_logits_v2(labels=labels, logits=logits)
@ -71,10 +72,12 @@ def test_logprobs_from_logits_v2(dtype):
def test_lr_scheduler():
from torch import nn
model = nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
from verl.utils.torch_functional import get_constant_schedule_with_warmup
constant_lr = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=2)
lr_lst = []
@ -86,11 +89,11 @@ def test_lr_scheduler():
torch.testing.assert_close(lr_lst, [0.0, 0.0005, 0.001, 0.001, 0.001])
from verl.utils.torch_functional import get_cosine_schedule_with_warmup
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
cosine_lr = get_cosine_schedule_with_warmup(optimizer=optimizer,
num_warmup_steps=2,
num_training_steps=5,
min_lr_ratio=0.1)
cosine_lr = get_cosine_schedule_with_warmup(
optimizer=optimizer, num_warmup_steps=2, num_training_steps=5, min_lr_ratio=0.1
)
lr_lst = []
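A quick check of the expected schedule above, assuming the usual linear-warmup-then-constant rule (warmup multiplier min(step / num_warmup_steps, 1)): with a base lr of 1e-3 and 2 warmup steps, the first five values are 0.0, 0.0005, 0.001, 0.001, 0.001, matching the asserted list.

    base_lr, warmup = 1e-3, 2
    lrs = [base_lr * min(step / warmup, 1.0) for step in range(5)]
    print(lrs)  # [0.0, 0.0005, 0.001, 0.001, 0.001]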

View File

@ -13,19 +13,26 @@
# limitations under the License.
import torch
from verl.utils.model import create_random_mask, compute_position_id_with_mask
from verl.utils.torch_functional import masked_mean, log_probs_from_logits_all_rmpad, logprobs_from_logits
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis, rearrange
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from transformers import (
AutoModelForCausalLM,
AutoModelForTokenClassification,
GemmaConfig,
LlamaConfig,
MistralConfig,
Qwen2Config,
)
from verl.utils.model import compute_position_id_with_mask, create_random_mask
from verl.utils.torch_functional import log_probs_from_logits_all_rmpad, masked_mean
from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config
from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForSequenceClassification
# TODO(sgm): add more models for test
# we only need one scale for each model
test_configs = [
LlamaConfig(num_hidden_layers=1),
MistralConfig(num_hidden_layers=1),
GemmaConfig(num_hidden_layers=1),
Qwen2Config(num_hidden_layers=1)
Qwen2Config(num_hidden_layers=1),
]
@ -36,56 +43,67 @@ def test_hf_casual_models():
for config in test_configs:
# config = AutoConfig.from_pretrained(test_case)
with torch.device('cuda'):
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2')
model = model.to(device='cuda')
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
attention_mask = create_random_mask(input_ids=input_ids,
max_ratio_of_left_padding=0.1,
max_ratio_of_valid_token=0.8,
min_ratio_of_valid_token=0.5)
with torch.device("cuda"):
model = AutoModelForCausalLM.from_config(
config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model = model.to(device="cuda")
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
attention_mask = create_random_mask(
input_ids=input_ids,
max_ratio_of_left_padding=0.1,
max_ratio_of_valid_token=0.8,
min_ratio_of_valid_token=0.5,
)
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
attention_mask
) # TODO(sgm): we can construct the position_ids_rmpad here
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
# input with input_ids_rmpad and position_ids to enable flash attention varlen
logits_rmpad = model(input_ids_rmpad, position_ids=position_ids_rmpad,
use_cache=False).logits # (1, total_nnz, vocab_size)
logits_rmpad = model(
input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False
).logits # (1, total_nnz, vocab_size)
origin_logits = model(input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False).logits
origin_logits = model(
input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
).logits
origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask)
logits_rmpad = logits_rmpad.squeeze(0)
log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
logits_rmpad=logits_rmpad,
indices=indices,
batch_size=batch_size,
seqlen=seqlen,
response_length=response_length) # (batch, seqlen)
origin_log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
logits_rmpad=origin_logits_rmpad,
indices=origin_logits_indices,
batch_size=batch_size,
seqlen=seqlen,
response_length=response_length) # (batch, seqlen)
log_probs = log_probs_from_logits_all_rmpad(
input_ids_rmpad=input_ids_rmpad,
logits_rmpad=logits_rmpad,
indices=indices,
batch_size=batch_size,
seqlen=seqlen,
response_length=response_length,
) # (batch, seqlen)
origin_log_probs = log_probs_from_logits_all_rmpad(
input_ids_rmpad=input_ids_rmpad,
logits_rmpad=origin_logits_rmpad,
indices=origin_logits_indices,
batch_size=batch_size,
seqlen=seqlen,
response_length=response_length,
) # (batch, seqlen)
torch.testing.assert_close(masked_mean(log_probs, attention_mask[:, -response_length - 1:-1]),
masked_mean(origin_log_probs, attention_mask[:, -response_length - 1:-1]),
atol=1e-2,
rtol=1e-5)
print(f'Check pass')
torch.testing.assert_close(
masked_mean(log_probs, attention_mask[:, -response_length - 1 : -1]),
masked_mean(origin_log_probs, attention_mask[:, -response_length - 1 : -1]),
atol=1e-2,
rtol=1e-5,
)
print("Check pass")
def test_hf_value_models():
@ -95,47 +113,54 @@ def test_hf_value_models():
for config in test_configs:
# config = AutoConfig.from_pretrained(test_case)
config.num_labels = 1
setattr(config, 'classifier_dropout', 0)
setattr(config, 'hidden_dropout', 0)
with torch.device('cuda'):
model = AutoModelForTokenClassification.from_config(config=config,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2')
model = model.to(device='cuda')
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
attention_mask = create_random_mask(input_ids=input_ids,
max_ratio_of_left_padding=0.1,
max_ratio_of_valid_token=0.8,
min_ratio_of_valid_token=0.5)
config.classifier_dropout = 0
config.hidden_dropout = 0
with torch.device("cuda"):
model = AutoModelForTokenClassification.from_config(
config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model = model.to(device="cuda")
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
attention_mask = create_random_mask(
input_ids=input_ids,
max_ratio_of_left_padding=0.1,
max_ratio_of_valid_token=0.8,
min_ratio_of_valid_token=0.5,
)
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
attention_mask
) # TODO(sgm): we can construct the position_ids_rmpad here
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
origin_logits = model(input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False).logits
origin_logits = model(
input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
).logits
# input with input_ids_rmpad and position_ids to enable flash attention varlen
rmpad_logits = model(input_ids_rmpad, position_ids=position_ids_rmpad,
use_cache=False).logits # (1, total_nnz, 1)
rmpad_logits = model(
input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False
).logits # (1, total_nnz, 1)
rmpad_logits = rmpad_logits.squeeze(0)
pad_logits = pad_input(rmpad_logits, indices, batch_size, seqlen=seqlen)
torch.testing.assert_close(masked_mean(pad_logits, attention_mask[:, :, None]),
masked_mean(origin_logits, attention_mask[:, :, None]),
atol=1e-2,
rtol=1e-5)
print('Value model check pass')
torch.testing.assert_close(
masked_mean(pad_logits, attention_mask[:, :, None]),
masked_mean(origin_logits, attention_mask[:, :, None]),
atol=1e-2,
rtol=1e-5,
)
print("Value model check pass")
if __name__ == '__main__':
if __name__ == "__main__":
test_hf_casual_models()
test_hf_value_models()
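For readers unfamiliar with the "rmpad" layout used throughout these tests, a conceptual sketch (not flash_attn's unpad_input itself): valid tokens from all sequences are packed into one (total_nnz, ...) tensor using the attention mask, and the saved indices let results be scattered back into the padded layout.

    import torch

    input_ids = torch.tensor([[0, 0, 5, 6], [7, 8, 9, 1]])
    attention_mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    input_ids_rmpad = input_ids.flatten()[indices]  # packed valid tokens, shape (6,)
    restored = torch.zeros_like(input_ids).flatten()
    restored[indices] = input_ids_rmpad  # scatter back into the padded layout
    restored = restored.view_as(input_ids)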

View File

@ -11,25 +11,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import contextlib
import copy
from dataclasses import dataclass
import pytest
import torch
import copy
import torch.distributed
from flash_attn.bert_padding import index_first_axis, rearrange, unpad_input
from torch.distributed import init_device_mesh
from verl.utils.distributed import initialize_global_process_group
from verl.utils.model import create_random_mask, compute_position_id_with_mask
from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group
from verl.workers.sharding_manager import FSDPUlyssesShardingManager
from verl.protocol import DataProto
from flash_attn.bert_padding import unpad_input, index_first_axis, rearrange
from transformers import LlamaConfig, Qwen2Config, PretrainedConfig
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, LlamaConfig, PretrainedConfig, Qwen2Config
from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.protocol import DataProto
from verl.utils.distributed import initialize_global_process_group
from verl.utils.model import compute_position_id_with_mask, create_random_mask
from verl.utils.ulysses import (
gather_outpus_and_unpad,
get_ulysses_sequence_parallel_world_size,
set_ulysses_sequence_parallel_group,
ulysses_pad_and_slice_inputs,
)
from verl.workers.sharding_manager import FSDPUlyssesShardingManager
# TODO(sgm): add more models for test
# we only need one scale for each model
@ -44,27 +47,25 @@ class SequenceParallelConfig:
def test_configs():
return [
SequenceParallelConfig(LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32),
sp_size=8,
is_valid=True),
SequenceParallelConfig(Qwen2Config(num_hidden_layers=2,
num_attention_heads=28,
num_key_value_heads=4,
hidden_size=3584),
sp_size=4,
is_valid=True),
SequenceParallelConfig(Qwen2Config(num_hidden_layers=2,
num_attention_heads=28,
num_key_value_heads=4,
hidden_size=3584),
sp_size=8,
is_valid=False),
SequenceParallelConfig(Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4),
sp_size=4,
is_valid=True),
SequenceParallelConfig(Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4),
sp_size=8,
is_valid=True),
SequenceParallelConfig(
LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32), sp_size=8, is_valid=True
),
SequenceParallelConfig(
Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584),
sp_size=4,
is_valid=True,
),
SequenceParallelConfig(
Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584),
sp_size=8,
is_valid=False,
),
SequenceParallelConfig(
Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=4, is_valid=True
),
SequenceParallelConfig(
Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=8, is_valid=True
),
]
@ -91,9 +92,9 @@ def test_hf_casual_fwd_bwd(test_config):
def _hf_casual_fwd(config, sp_size, dp_size):
assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
ulysses_device_mesh = init_device_mesh(device_type='cuda',
mesh_shape=(dp_size, sp_size),
mesh_dim_names=('dp', 'sp'))
ulysses_device_mesh = init_device_mesh(
device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp")
)
sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh)
batch_size = 1
@ -101,27 +102,27 @@ def _hf_casual_fwd(config, sp_size, dp_size):
response_length = 127
# patch before load
with torch.device('cuda'):
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2')
with torch.device("cuda"):
model = AutoModelForCausalLM.from_config(
config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
apply_monkey_patch(model, sp_size)
model = model.to(device='cuda')
model = model.to(device="cuda")
sync_model_parameters_global(model)
# different ranks will generate different input_ids following fsdp
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
attention_mask = create_random_mask(input_ids=input_ids,
max_ratio_of_left_padding=0,
max_ratio_of_valid_token=0.9,
min_ratio_of_valid_token=0.8)
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
attention_mask = create_random_mask(
input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8
)
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
attention_mask
) # TODO(sgm): we can construct the position_ids_rmpad here
model_inputs = {
'input_ids': input_ids.cuda(),
'attention_mask': attention_mask.cuda(),
'position_ids': position_ids.int().cuda()
"input_ids": input_ids.cuda(),
"attention_mask": attention_mask.cuda(),
"position_ids": position_ids.int().cuda(),
}
model_inputs = DataProto.from_dict(model_inputs)
@ -129,33 +130,38 @@ def _hf_casual_fwd(config, sp_size, dp_size):
# 1. perform ulysses forward
with sharding_manager:
model_inputs = sharding_manager.preprocess_data(model_inputs)
input_ids = model_inputs.batch['input_ids']
attention_mask = model_inputs.batch['attention_mask']
position_ids = model_inputs.batch['position_ids']
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids = model_inputs.batch["input_ids"]
attention_mask = model_inputs.batch["attention_mask"]
position_ids = model_inputs.batch["position_ids"]
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
# slice input tensor for ulysses
# input_ids are padded and sliced
# position_ids are only padded but not sliced
input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size())
input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()
)
# input with input_ids_rmpad and position_ids to enable flash attention varlen
logits_split_in_seq = model(input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded,
use_cache=False).logits # (1, total_nnz/n, vocab_size)
logits_split_in_seq = model(
input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False
).logits # (1, total_nnz/n, vocab_size)
# all_gather output
logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)
# 2. perform normal forward
set_ulysses_sequence_parallel_group(None)
logits_rmpad_local = model(input_ids_rmpad, position_ids=position_ids_rmpad,
use_cache=False).logits # (1, total_nnz, vocab_size)
logits_rmpad_local = model(
input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False
).logits # (1, total_nnz, vocab_size)
mean_local = logits_rmpad_local.mean()
mean_full = logits_full.mean()
@ -165,9 +171,9 @@ def _hf_casual_fwd(config, sp_size, dp_size):
def _hf_casual_fwd_bwd(config, sp_size, dp_size):
assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
ulysses_device_mesh = init_device_mesh(device_type='cuda',
mesh_shape=(dp_size, sp_size),
mesh_dim_names=('dp', 'sp'))
ulysses_device_mesh = init_device_mesh(
device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp")
)
sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh)
batch_size = 1
@ -175,27 +181,27 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size):
response_length = 127
# patch before load
with torch.device('cuda'):
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2')
with torch.device("cuda"):
model = AutoModelForCausalLM.from_config(
config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
apply_monkey_patch(model, sp_size)
model = model.to(device='cuda')
model = model.to(device="cuda")
sync_model_parameters_global(model)
# different ranks will generate different input_ids following fsdp
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
attention_mask = create_random_mask(input_ids=input_ids,
max_ratio_of_left_padding=0,
max_ratio_of_valid_token=0.9,
min_ratio_of_valid_token=0.8)
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
attention_mask = create_random_mask(
input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8
)
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
attention_mask
) # TODO(sgm): we can construct the position_ids_rmpad here
model_inputs = {
'input_ids': input_ids.cuda(),
'attention_mask': attention_mask.cuda(),
'position_ids': position_ids.int().cuda()
"input_ids": input_ids.cuda(),
"attention_mask": attention_mask.cuda(),
"position_ids": position_ids.int().cuda(),
}
model_inputs = DataProto.from_dict(model_inputs)
@ -203,25 +209,29 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size):
# 1. perform ulysses forward
with sharding_manager:
model_inputs = sharding_manager.preprocess_data(model_inputs)
input_ids = model_inputs.batch['input_ids']
attention_mask = model_inputs.batch['attention_mask']
position_ids = model_inputs.batch['position_ids']
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids = model_inputs.batch["input_ids"]
attention_mask = model_inputs.batch["attention_mask"]
position_ids = model_inputs.batch["position_ids"]
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
# slice input tensor for ulysses
# input_ids are padded and sliced
# position_ids are only padded but not sliced
input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size())
input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()
)
# input with input_ids_rmpad and position_ids to enable flash attention varlen
logits_split_in_seq = model(input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded,
use_cache=False).logits # (1, total_nnz/n, vocab_size)
logits_split_in_seq = model(
input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False
).logits # (1, total_nnz/n, vocab_size)
# all_gather output
logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)
@ -231,8 +241,9 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size):
input_ids_full = copy.deepcopy(input_ids_rmpad)
position_ids_full = copy.deepcopy(position_ids_rmpad)
model_no_sp = copy.deepcopy(model)
logits_rmpad_local = model_no_sp(input_ids_full, position_ids=position_ids_full,
use_cache=False).logits # (1, total_nnz, vocab_size)
logits_rmpad_local = model_no_sp(
input_ids_full, position_ids=position_ids_full, use_cache=False
).logits # (1, total_nnz, vocab_size)
mean_local = logits_rmpad_local.mean()
mean_full = logits_full.mean()
@ -247,5 +258,5 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size):
torch.testing.assert_close(grad, grad_full, atol=1e-2, rtol=1e-5)
if __name__ == '__main__':
if __name__ == "__main__":
pytest.main([__file__, "-svv"])
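A conceptual sketch of the pad-then-slice step the comments above describe (an assumption-level illustration, not verl's ulysses_pad_and_slice_inputs): the packed token dimension is right-padded to a multiple of sp_size, input_ids are sliced per sequence-parallel rank, and position_ids are only padded.

    import torch
    import torch.nn.functional as F

    def pad_and_slice(input_ids, position_ids, sp_size, sp_rank):
        # input_ids, position_ids: (1, total_nnz) after removing padding
        pad_size = (-input_ids.shape[-1]) % sp_size
        input_ids = F.pad(input_ids, (0, pad_size))
        position_ids = F.pad(position_ids, (0, pad_size))
        chunk = input_ids.shape[-1] // sp_size
        return input_ids[:, sp_rank * chunk : (sp_rank + 1) * chunk], position_ids, pad_size

    ids = torch.arange(10).unsqueeze(0)
    pos = torch.arange(10).unsqueeze(0)
    sliced, pos_padded, pad = pad_and_slice(ids, pos, sp_size=4, sp_rank=1)
    # pad == 2, pos_padded keeps length 12, rank 1 holds tokens 3..5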

View File

@ -12,20 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import sys
import os
import sys
import time
import ray
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.worker import Worker
from verl.single_controller.base.decorator import register, Dispatch
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@ray.remote
class TestActor(Worker):
def __init__(self) -> None:
super().__init__()
@ -41,7 +40,7 @@ if __name__ == "__main__":
ray.init()
# test single-node-no-partition
print(f"test single-node-no-partition")
print("test single-node-no-partition")
resource_pool = RayResourcePool([2], use_gpu=True)
class_with_args = RayClassWithInitArgs(cls=TestActor)
@ -56,8 +55,10 @@ if __name__ == "__main__":
_ = wg.foo(wait_time)
print("foo started")
print(time.time(),
f"wait 6x wait time {wait_time*6} to let signal returned to process but still not exceed process wait time")
print(
time.time(),
f"wait 6x wait time {wait_time * 6} to let signal returned to process but still not exceed process wait time",
)
time.sleep(wait_time * 6)
ray.shutdown()

View File

@ -17,44 +17,42 @@ In client, we can get the server handler and send RPC request
import ray
import torch
from server import Trainer
from tensordict import TensorDict
from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from tensordict import TensorDict
from server import Trainer
def compute_position_id_with_mask(mask):
return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)
if __name__ == '__main__':
ray.init(address='auto', namespace='verl')
if __name__ == "__main__":
ray.init(address="auto", namespace="verl")
# get the worker group using names
worker_names = ['trainerTrainer_0:0', 'trainerTrainer_0:1']
worker_names = ["trainerTrainer_0:0", "trainerTrainer_0:1"]
cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names,
ray_cls_with_init=cls_with_init_args)
worker_group = NVMegatronRayWorkerGroup.from_detached(
worker_names=worker_names, ray_cls_with_init=cls_with_init_args
)
batch_size = 16
sequence_length = 1024
# give Trainer some data to train
input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device='cuda')
input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device="cuda")
attention_mask = torch.ones_like(input_ids)
position_ids = compute_position_id_with_mask(attention_mask)
data = DataProto(batch=TensorDict(
{
'input_ids': input_ids,
'attention_mask': attention_mask,
'position_ids': position_ids
}, batch_size=batch_size),
meta_info={})
data = DataProto(
batch=TensorDict(
{"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids},
batch_size=batch_size,
),
meta_info={},
)
output = worker_group.train_model(data)
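A small worked example of compute_position_id_with_mask as defined above: with a left-padded attention mask, positions stay at 0 over the padding and then count up from 0 at the first valid token.

    import torch

    mask = torch.tensor([[0, 0, 1, 1, 1]])
    position_ids = torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)
    print(position_ids)  # tensor([[0, 0, 0, 1, 2]])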

View File

@ -17,46 +17,41 @@ Server starts a Trainer. Client sends data to the server to train.
import os
os.environ['MEGATRON_USE_CUDA_TIMER'] = '0'
os.environ['MEGATRON_START_PROCESS_TIMER'] = 'False'
os.environ['NCCL_DEBUG'] = 'WARN'
import torch
from torch import nn
os.environ["MEGATRON_USE_CUDA_TIMER"] = "0"
os.environ["MEGATRON_START_PROCESS_TIMER"] = "False"
os.environ["NCCL_DEBUG"] = "WARN"
import ray
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.single_controller.base.megatron.worker import MegatronWorker
from verl.single_controller.base.decorator import register, Dispatch
from verl import DataProto
from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP
import torch
from megatron.core import parallel_state as mpu
from megatron.core.models.gpt.gpt_model import ModelType
from megatron.core import tensor_parallel
from verl.utils.megatron_utils import get_model, init_megatron_optim_config, mcore_model_parallel_config
from verl.utils.megatron.optimizer import get_megatron_optimizer
from megatron.core.models.gpt.gpt_model import ModelType
from omegaconf import OmegaConf
from tensordict import TensorDict
from torch import nn
from transformers import LlamaConfig
from omegaconf import OmegaConf
from tensordict import TensorDict
from verl import DataProto
from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.megatron.worker import MegatronWorker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.utils.megatron.optimizer import get_megatron_optimizer
from verl.utils.megatron_utils import get_model, init_megatron_optim_config, mcore_model_parallel_config
@ray.remote
class Trainer(MegatronWorker):
def __init__(self):
super().__init__()
if not torch.distributed.is_initialized():
rank = int(os.environ['LOCAL_RANK'])
rank = int(os.environ["LOCAL_RANK"])
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(rank)
os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
mpu.initialize_model_parallel(
tensor_model_parallel_size=2,
pipeline_model_parallel_size=1,
@ -71,12 +66,14 @@ class Trainer(MegatronWorker):
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
actor_model_config = LlamaConfig(vocab_size=256,
hidden_size=2048,
intermediate_size=5504,
num_hidden_layers=24,
num_attention_heads=16,
num_key_value_heads=16)
actor_model_config = LlamaConfig(
vocab_size=256,
hidden_size=2048,
intermediate_size=5504,
num_hidden_layers=24,
num_attention_heads=16,
num_key_value_heads=16,
)
megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16)
self.megatron_config = megatron_config
@ -86,19 +83,23 @@ class Trainer(MegatronWorker):
vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() # this will be set inside get_model
# this_megatron_config = copy.deepcopy(megatron_config)
# this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
parallel_model = ParallelLlamaForCausalLMRmPadPP(config=actor_model_config,
megatron_config=megatron_config,
pre_process=pre_process,
post_process=post_process)
parallel_model = ParallelLlamaForCausalLMRmPadPP(
config=actor_model_config,
megatron_config=megatron_config,
pre_process=pre_process,
post_process=post_process,
)
parallel_model.cuda()
return parallel_model
actor_module = get_model(model_provider_func=megatron_actor_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=True)
actor_module = get_model(
model_provider_func=megatron_actor_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=True,
)
actor_module = nn.ModuleList(actor_module)
optim_config = OmegaConf.create({'lr': 1e-6, 'clip_grad': 1.0})
optim_config = OmegaConf.create({"lr": 1e-6, "clip_grad": 1.0})
optim_config = init_megatron_optim_config(optim_config)
self.optimizer_config = optim_config
@ -109,33 +110,34 @@ class Trainer(MegatronWorker):
@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
def train_model(self, data: DataProto) -> DataProto:
input_ids = data.batch['input_ids']
attention_mask = data.batch['attention_mask']
position_ids = data.batch['position_ids']
input_ids = data.batch["input_ids"]
attention_mask = data.batch["attention_mask"]
position_ids = data.batch["position_ids"]
self.optimizer.zero_grad()
self.model.zero_grad_buffer(
zero_buffer=(not self.optimizer_config.use_distributed_optimizer
)) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
zero_buffer=(not self.optimizer_config.use_distributed_optimizer)
) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
# update for 1 iteration
output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits
output.mean().backward()
update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(self.megatron_config,
self.megatron_config.timers)
update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(
self.megatron_config, self.megatron_config.timers
)
return DataProto(batch=TensorDict({'loss': output.detach()}, batch_size=output.shape[0]))
return DataProto(batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0]))
if __name__ == '__main__':
ray.init(address='auto', namespace='verl')
if __name__ == "__main__":
ray.init(address="auto", namespace="verl")
resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
worker_group = NVMegatronRayWorkerGroup(
resource_pool=resource_pool,
ray_cls_with_init=cls_with_init_args,
name_prefix='trainer',
name_prefix="trainer",
detached=True,
)

View File

@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import os
import subprocess
import time
def test():
@ -34,12 +34,13 @@ def test():
print(
time.time(),
f"wait 1.5 wait time {wait_time*1.5} to let signal returned to process but still not exceed process wait time")
f"wait 1.5 wait time {wait_time * 1.5} to let signal returned to process but still not exceed process wait time",
)
time.sleep(wait_time * 1.5)
print(time.time(), f"start checking")
print(time.time(), "start checking")
assert p.poll() is not None, f"process {p} still alive, expecting signal raised abort"
assert p.returncode != 0, f"process {p} exit with code 0, expecting not-zero exit code"
print(f"test passed")
print("test passed")
if __name__ == "__main__":

View File

@ -14,35 +14,37 @@
import ray
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray.base import (
RayClassWithInitArgs,
RayResourcePool,
RayWorkerGroup,
create_colocated_worker_cls,
)
@ray.remote
class Actor(Worker):
def __init__(self) -> None:
super().__init__()
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def add(self, data: DataProto):
data.batch['a'] += self.rank
data.batch["a"] += self.rank
return data
@ray.remote
class Critic(Worker):
def __init__(self, config) -> None:
super().__init__()
self.config = config
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def sub(self, data: DataProto):
data.batch['a'] -= self.config['b']
data.batch["a"] -= self.config["b"]
return data
@ -50,10 +52,11 @@ def test_colocated_workers():
ray.init()
import torch
data = DataProto.from_dict({'a': torch.zeros(10)})
data = DataProto.from_dict({"a": torch.zeros(10)})
# create separate workers on the same resource pool
actor_cls = RayClassWithInitArgs(cls=Actor)
critic_cls = RayClassWithInitArgs(cls=Critic, config={'b': 10})
critic_cls = RayClassWithInitArgs(cls=Critic, config={"b": 10})
resource_pool = RayResourcePool(process_on_nodes=[2])
actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)
@ -63,13 +66,13 @@ def test_colocated_workers():
expected_critic_output = critic_wg.sub(data)
# create colocated workers
cls_dict = {'actor': actor_cls, 'critic': critic_cls}
cls_dict = {"actor": actor_cls, "critic": critic_cls}
ray_cls_with_init = create_colocated_worker_cls(cls_dict)
wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())
colocated_actor_wg = spawn_wg['actor']
colocated_critic_wg = spawn_wg['critic']
colocated_actor_wg = spawn_wg["actor"]
colocated_critic_wg = spawn_wg["critic"]
actor_output = colocated_actor_wg.add(data)
critic_output = colocated_critic_wg.sub(data)

View File

@ -15,27 +15,21 @@
In this test, we instantiate a data parallel worker with 8 GPUs
"""
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
from verl.single_controller.base.decorator import Dispatch, register
import ray
import tensordict
import torch
from codetiming import Timer
from torch import distributed as dist
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils.ray_utils import parallel_put
from codetiming import Timer
import tensordict
@ray.remote
class DummyWorker(Worker):
def __init__(self):
super().__init__()
dist.init_process_group()
@ -44,7 +38,7 @@ class DummyWorker(Worker):
def do_nothing(self, data):
for key in data.batch.keys():
data.batch[key] += 1
if tensordict.__version__ >= '0.5.0':
if tensordict.__version__ >= "0.5.0":
data.batch = data.batch.consolidate()
return data
@ -75,35 +69,39 @@ def test_data_transfer():
for i in range(wg.world_size):
# consolidate is necessary
if tensordict.__version__ >= '0.5.0':
if tensordict.__version__ >= "0.5.0":
data_list[i].batch = data_list[i].batch.consolidate()
with Timer(name='ray.pickle', initial_text=True):
with Timer(name="ray.pickle", initial_text=True):
for i in range(wg.world_size):
ray.cloudpickle.pickle.dumps(data_list[i])
with Timer(name='raw.pickle', initial_text=True):
with Timer(name="raw.pickle", initial_text=True):
import pickle
for i in range(wg.world_size):
pickle.dumps(data_list[i])
# we put in advance
with Timer(name='put', initial_text=True):
with Timer(name="put", initial_text=True):
# takes around 40 seconds
data_list_ref = parallel_put(data_list)
# for i in range(wg.world_size):
# data_list[i] = ray.put(data_list[i])
with Timer(name='launch', initial_text=True):
with Timer(name="launch", initial_text=True):
output_ref = wg.do_nothing(data_list_ref)
with Timer(name='get', initial_text=True):
with Timer(name="get", initial_text=True):
# takes around 40 seconds
output_lst = ray.get(output_ref)
for input_data, output_data in zip(data_list, output_lst):
for key in input_data.batch.keys():
assert torch.all(torch.eq(input_data.batch[key] + 1,
output_data.batch[key])), (input_data.batch[key], output_data.batch[key], key)
assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), (
input_data.batch[key],
output_data.batch[key],
key,
)
ray.shutdown()

View File

@ -13,28 +13,27 @@
# limitations under the License.
import os
import ray
import torch
from verl import DataProto
from tensordict import TensorDict
from verl import DataProto
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs
from verl.single_controller.ray import RayWorkerGroup
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool
os.environ['RAY_DEDUP_LOGS'] = '0'
os.environ['NCCL_DEBUG'] = 'WARN'
os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["NCCL_DEBUG"] = "WARN"
@ray.remote
class ModelActor(Worker):
def __init__(self):
pass
class HackSelf():
class HackSelf:
def __init__(self):
pass
@ -44,11 +43,11 @@ def get_aux_metrics(self, test_proto):
decode_count = []
for i in range(sequence_ids.size(0)):
decode_count.append(len(sequence_ids[i].tolist()))
ret_proto = DataProto(batch=TensorDict({
"sequence_ids": sequence_ids,
"decode_count": torch.tensor(decode_count)
},
batch_size=sequence_ids.size(0)))
ret_proto = DataProto(
batch=TensorDict(
{"sequence_ids": sequence_ids, "decode_count": torch.tensor(decode_count)}, batch_size=sequence_ids.size(0)
)
)
return ret_proto
@ -57,17 +56,21 @@ def test():
ray.init()
# create 2 workers, each hold a GPU
resource_pool = RayResourcePool([2], use_gpu=True, name_prefix='a')
resource_pool = RayResourcePool([2], use_gpu=True, name_prefix="a")
class_with_args = RayClassWithInitArgs(cls=ModelActor)
shard_wg = RayWorkerGroup(resource_pool, class_with_args)
test_bs = 8
test_proto = DataProto(TensorDict({
"sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64),
},
batch_size=test_bs),
meta_info={"query_length": 1536})
test_proto = DataProto(
TensorDict(
{
"sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64),
},
batch_size=test_bs,
),
meta_info={"query_length": 1536},
)
# Sharding among different ranks
ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto)

View File

@ -16,8 +16,8 @@ import time
import ray
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool
@ray.remote
@ -34,7 +34,7 @@ def test():
ray.init()
# test single-node-no-partition
print(f"test single-node-no-partition")
print("test single-node-no-partition")
resource_pool = RayResourcePool([8], use_gpu=True)
class_with_args = RayClassWithInitArgs(cls=TestActor)
@ -63,7 +63,7 @@ def test():
time.sleep(5)
# test single-node-multi-partition
print(f"test single-node-multi-partition")
print("test single-node-multi-partition")
rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="rm")
ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="ref")
total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool)
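The `print(f"test single-node-no-partition")` → `print("test single-node-no-partition")` edits in this file are a lint auto-fix rather than a formatting change: they match pyflakes rule F541 (f-string without any placeholders), which `ruff check --fix` resolves by dropping the redundant `f` prefix, assuming the `F` rules are selected in the project's configuration.

```python
# Sketch of the F541 auto-fix ("f-string without any placeholders"). The f-prefix is
# only kept when the string actually interpolates something.
stage = "single-node-no-partition"

print(f"test {stage}")                    # kept: the f-string has a placeholder
print("test single-node-no-partition")    # fixed: the former f"..." had none
```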


@ -14,17 +14,17 @@
"""
e2e test verl.single_controller.ray
"""
import os
import ray
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from verl.single_controller.base.worker import Worker
from verl.single_controller.base.decorator import register, Dispatch, collect_all_to_all, Execute
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@ray.remote
class TestActor(Worker):
def __init__(self) -> None:
super().__init__()
@ -40,9 +40,9 @@ def test_basics():
resource_pool = RayResourcePool([4], use_gpu=True)
class_with_args = RayClassWithInitArgs(cls=TestActor)
worker_group = RayWorkerGroup(resource_pool=resource_pool,
ray_cls_with_init=class_with_args,
name_prefix="worker_group_basic")
worker_group = RayWorkerGroup(
resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic"
)
output = worker_group.execute_all_sync("getenv", key="RAY_LOCAL_WORLD_SIZE")
assert output == ["4", "4", "4", "4"]
@ -53,5 +53,5 @@ def test_basics():
ray.shutdown()
if __name__ == '__main__':
if __name__ == "__main__":
test_basics()


@ -17,7 +17,6 @@ import ray
@ray.remote
class TestWorker:
def __init__(self, rank, world_size, group_name):
self.rank = rank
self.world_size = world_size
@ -26,6 +25,7 @@ class TestWorker:
def init(self):
from verl.utils.rendezvous.ray_backend import create_nccl_communicator_in_ray
self.communicator = create_nccl_communicator_in_ray(self.rank, self.world_size, self.group_name)
def test(self):


@ -15,12 +15,12 @@
e2e test verl.single_controller.ray
"""
import torch
import ray
import torch
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from verl.single_controller.base.decorator import Dispatch, Execute, collect_all_to_all, register
from verl.single_controller.base.worker import Worker
from verl.single_controller.base.decorator import register, Dispatch, collect_all_to_all, Execute
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
def two_to_all_dispatch_fn(worker_group, *args, **kwargs):
@ -60,7 +60,7 @@ class TestActor(Worker):
def foo_all_to_all(self, x, y):
return self._x + y + x
@register(dispatch_mode={'dispatch_fn': two_to_all_dispatch_fn, 'collect_fn': collect_all_to_all})
@register(dispatch_mode={"dispatch_fn": two_to_all_dispatch_fn, "collect_fn": collect_all_to_all})
def foo_custom(self, x, y):
return self._x + y + x
@ -94,9 +94,9 @@ def test_basics():
resource_pool = RayResourcePool([4], use_gpu=True)
class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)
worker_group = RayWorkerGroup(resource_pool=resource_pool,
ray_cls_with_init=class_with_args,
name_prefix="worker_group_basic")
worker_group = RayWorkerGroup(
resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic"
)
print(worker_group.worker_names)
@ -124,5 +124,5 @@ def test_basics():
ray.shutdown()
if __name__ == '__main__':
if __name__ == "__main__":
test_basics()


@ -14,54 +14,52 @@
import os
os.environ['RAY_DEDUP_LOGS'] = '0'
os.environ['NCCL_DEBUG'] = 'WARN'
os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["NCCL_DEBUG"] = "WARN"
import ray
import torch
import torch.distributed
import ray
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@ray.remote
class TestAllGatherActor(Worker):
def __init__(self, size) -> None:
super().__init__()
self.size = size
def init(self):
torch.distributed.init_process_group()
self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device='cuda')
self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device="cuda")
self.tensor += self.rank
def all_gather(self):
world_size = self._world_size
output = torch.zeros(size=(self.tensor.shape[0] * world_size,),
dtype=self.tensor.dtype,
device=self.tensor.device)
output = torch.zeros(
size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device
)
torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False)
return output
@ray.remote
class TestAllGatherActorV2(Worker):
def __init__(self, size) -> None:
super().__init__()
self.size = size
torch.distributed.init_process_group()
self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device='cuda')
self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device="cuda")
self.tensor += self.rank
def all_gather(self):
world_size = self._world_size
output = torch.zeros(size=(self.tensor.shape[0] * world_size,),
dtype=self.tensor.dtype,
device=self.tensor.device)
output = torch.zeros(
size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device
)
torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False)
return output
@ -78,8 +76,8 @@ def test_all_gather_torch():
worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch")
worker_group.execute_all_sync('init')
output = worker_group.execute_all_sync('all_gather')
worker_group.execute_all_sync("init")
output = worker_group.execute_all_sync("all_gather")
for i in range(1, len(output)):
assert torch.all(output[i] == output[0])
@ -102,7 +100,7 @@ def test_all_gather_torch_v2():
worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch")
output = worker_group.execute_all_sync('all_gather')
output = worker_group.execute_all_sync("all_gather")
for i in range(1, len(output)):
assert torch.all(output[i] == output[0])


@ -13,29 +13,30 @@
# limitations under the License.
import os
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, CPUOffload
from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType
import time
import torch
from verl.utils.distributed import initialize_global_process_group
from verl.third_party.vllm import LLM
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from vllm import SamplingParams
import time
import torch.distributed as dist
from verl.third_party.vllm import LLM
from verl.utils.distributed import initialize_global_process_group
def main():
assert torch.cuda.is_available(), 'CUDA must be present to run FSDP vLLM example'
assert torch.cuda.is_available(), "CUDA must be present to run FSDP vLLM example"
local_rank, rank, world_size = initialize_global_process_group()
local_cache_path = '~/.cache/verl/rlhf'
local_cache_path = "~/.cache/verl/rlhf"
local_cache_path = os.path.expanduser(local_cache_path)
hdfs_path = 'Qwen/Qwen2-7B-Instruct'
hdfs_path = "Qwen/Qwen2-7B-Instruct"
from verl.utils.fs import copy_to_local
local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)
tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)
actor_model_config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=True)
@ -51,14 +52,16 @@ def main():
"The future of AI is",
]
tokenizer.pad_token = tokenizer.eos_token
prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True)
input_ids = prompts['input_ids']
attention_mask = prompts['attention_mask']
prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True)
input_ids = prompts["input_ids"]
attention_mask = prompts["attention_mask"]
from verl.utils.torch_functional import pad_sequence_to_length
input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True).cuda()
attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True).cuda()
from transformers import GenerationConfig
generation_config = GenerationConfig(do_sample=False)
actor_model.cuda()
output = actor_model.generate(
@ -72,61 +75,63 @@ def main():
# renormalize_logits=True,
output_scores=False, # this is potentially very large
return_dict_in_generate=True,
use_cache=False) # may OOM when use_cache = True
use_cache=False,
) # may OOM when use_cache = True
seq = output.sequences
response = seq[:, max_prompt_length:]
print(f'hf response: {tokenizer.batch_decode(response)}')
print(f"hf response: {tokenizer.batch_decode(response)}")
tensor_model_parallel_size = 4
from torch.distributed.device_mesh import init_device_mesh
device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp'])
device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)
fsdp_model = FSDP(actor_model,
use_orig_params=True,
auto_wrap_policy=None,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
cpu_offload=CPUOffload(offload_params=False),
sync_module_states=False,
device_mesh=device_mesh)
fsdp_model = FSDP(
actor_model,
use_orig_params=True,
auto_wrap_policy=None,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
cpu_offload=CPUOffload(offload_params=False),
sync_module_states=False,
device_mesh=device_mesh,
)
FSDP.set_state_dict_type(fsdp_model,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig())
FSDP.set_state_dict_type(
fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig()
)
state_dict = fsdp_model.state_dict()
sampling_params = SamplingParams(temperature=0,
top_p=1,
n=1,
max_tokens=response_length,
logprobs=1,
ignore_eos=True,
detokenize=False)
sampling_params = SamplingParams(
temperature=0, top_p=1, n=1, max_tokens=response_length, logprobs=1, ignore_eos=True, detokenize=False
)
print(actor_model_config)
llm = LLM(model=None,
tokenizer=tokenizer,
model_hf_config=actor_model_config,
tensor_parallel_size=tensor_model_parallel_size,
enforce_eager=True,
dtype='bfloat16',
load_format='dummy_dtensor',
gpu_memory_utilization=0.8,
trust_remote_code=True)
llm = LLM(
model=None,
tokenizer=tokenizer,
model_hf_config=actor_model_config,
tensor_parallel_size=tensor_model_parallel_size,
enforce_eager=True,
dtype="bfloat16",
load_format="dummy_dtensor",
gpu_memory_utilization=0.8,
trust_remote_code=True,
)
# Warmup iterations
for _ in range(10):
torch.cuda.synchronize()
llm.sync_model_weights(actor_weights=state_dict, load_format='dtensor')
llm.sync_model_weights(actor_weights=state_dict, load_format="dtensor")
torch.cuda.synchronize()
dist.barrier()
start_time = time.time()
llm.sync_model_weights(actor_weights=state_dict, load_format='dtensor')
llm.sync_model_weights(actor_weights=state_dict, load_format="dtensor")
torch.cuda.synchronize()
dist.barrier()
end_time = time.time()
@ -142,14 +147,15 @@ def main():
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
from verl.workers.rollout.vllm_rollout.vllm_rollout import _pre_process_inputs
for i in range(batch_size):
idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))
print('start generation')
print("start generation")
outputs = llm.generate(prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False)
vllm_output = outputs[0].cuda()
if torch.distributed.get_rank() == 0:
print(f'hf response: {tokenizer.batch_decode(response)}')
print(f'vllm response: {tokenizer.batch_decode(vllm_output)}')
print(f"hf response: {tokenizer.batch_decode(response)}")
print(f"vllm response: {tokenizer.batch_decode(vllm_output)}")
if __name__ == "__main__":


@ -26,13 +26,11 @@
# limitations under the License.
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from sglang.srt.entrypoints.verl_engine import VerlEngine
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import GenerationConfig
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from verl.utils.torch_functional import pad_sequence_to_length
@ -53,7 +51,7 @@ def levenshtein(s1, s2):
dp[i][j] = min(
dp[i - 1][j] + 1, # Deletion
dp[i][j - 1] + 1, # Insertion
dp[i - 1][j - 1] + cost # Substitution
dp[i - 1][j - 1] + cost, # Substitution
)
return dp[m][n]
@ -98,19 +96,20 @@ def initialize_global_process_group(timeout_second=36000):
def test_sglang_spmd():
assert torch.cuda.device_count() >= 2, 'At least 2 GPUs is required to run tp+dp tests.'
assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests."
initialize_global_process_group()
# fill rollout config
max_prompt_length = 16
max_response_length = 16
# Initialize model and token
local_cache_path = '~/.cache/verl/rlhf'
local_cache_path = "~/.cache/verl/rlhf"
local_cache_path = os.path.expanduser(local_cache_path)
hdfs_path = 'Qwen/Qwen2-7B-Instruct'
hdfs_path = "Qwen/Qwen2-7B-Instruct"
from verl.utils.fs import copy_to_local
local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)
tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side='left')
tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left")
preencode_prompts = [
"Who won the Champions League in 2019?",
@ -118,9 +117,9 @@ def test_sglang_spmd():
"What's your name",
]
tokenizer.pad_token = tokenizer.eos_token
prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True)
input_ids = prompts['input_ids']
attention_mask = prompts['attention_mask']
prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True)
input_ids = prompts["input_ids"]
attention_mask = prompts["attention_mask"]
input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True)
attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True)
@ -128,17 +127,19 @@ def test_sglang_spmd():
actor_model = AutoModelForCausalLM.from_pretrained(local_model_path)
actor_model.to(torch.bfloat16)
sampling_params = dict(n=1,
temperature=0,
top_p=1,
top_k=-1,
max_new_tokens=max_response_length,
presence_penalty=0.0,
frequency_penalty=0.0,
repetition_penalty=1.0,
skip_special_tokens=True,
spaces_between_special_tokens=True,
ignore_eos=False)
sampling_params = dict(
n=1,
temperature=0,
top_p=1,
top_k=-1,
max_new_tokens=max_response_length,
presence_penalty=0.0,
frequency_penalty=0.0,
repetition_penalty=1.0,
skip_special_tokens=True,
spaces_between_special_tokens=True,
ignore_eos=False,
)
tensor_parallel_size = 4
device_mesh_kwargs = dict(mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"])
@ -147,13 +148,15 @@ def test_sglang_spmd():
for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
if k in os.environ:
del os.environ[k]
print('building sglang rollout engine')
llm = VerlEngine(model_path=local_model_path,
dtype="bfloat16",
mem_fraction_static=0.5,
device_mesh_cpu=inference_device_mesh_cpu["tp"],
base_gpu_id=0,
gpu_id_step=1)
print("building sglang rollout engine")
llm = VerlEngine(
model_path=local_model_path,
dtype="bfloat16",
mem_fraction_static=0.5,
device_mesh_cpu=inference_device_mesh_cpu["tp"],
base_gpu_id=0,
gpu_id_step=1,
)
llm.release_memory_occupation()
print("start generation")
@ -174,7 +177,8 @@ def test_sglang_spmd():
# renormalize_logits=True,
output_scores=False, # this is potentially very large
return_dict_in_generate=True,
use_cache=False) # may OOM when use_cache = True
use_cache=False,
) # may OOM when use_cache = True
seq = output.sequences
response = seq[:, max_prompt_length:]
@ -184,7 +188,7 @@ def test_sglang_spmd():
idx_list = []
batch_size = input_ids.shape[0]
pad_token_id = (tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id)
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
for i in range(batch_size):
idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))
@ -197,8 +201,7 @@ def test_sglang_spmd():
sglang_response_tokens.append(generated_text)
print(f"sglang response: {sglang_response_tokens}")
assert are_lists_similar(hf_response_tokens, sglang_response_tokens), \
f"Strings differ more than 10%:\n"
assert are_lists_similar(hf_response_tokens, sglang_response_tokens), "Strings differ more than 10%:\n"
print("Check Pass")


@ -13,16 +13,12 @@
# limitations under the License.
import os
import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from vllm import SamplingParams
from verl.third_party.vllm import LLM, vllm_version
from verl.utils.model import update_model_config
from vllm import SamplingParams
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from transformers import GenerationConfig
from verl.utils.torch_functional import pad_sequence_to_length
from verl.workers.rollout.vllm_rollout.vllm_rollout import _pre_process_inputs
@ -43,7 +39,7 @@ def levenshtein(s1, s2):
dp[i][j] = min(
dp[i - 1][j] + 1, # Deletion
dp[i][j - 1] + 1, # Insertion
dp[i - 1][j - 1] + cost # Substitution
dp[i - 1][j - 1] + cost, # Substitution
)
return dp[m][n]
@ -70,17 +66,18 @@ def are_lists_similar(a, b):
def test_vllm_with_hf():
assert torch.cuda.device_count() >= 2, 'At least 2 GPUs is required to run tp+dp tests.'
assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests."
# fill rollout config
max_prompt_length = 16
max_response_length = 16
# Initialize model and token
local_cache_path = '~/.cache/verl/rlhf'
local_cache_path = "~/.cache/verl/rlhf"
local_cache_path = os.path.expanduser(local_cache_path)
hdfs_path = 'deepseek-ai/deepseek-llm-7b-chat'
hdfs_path = "deepseek-ai/deepseek-llm-7b-chat"
from verl.utils.fs import copy_to_local
local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)
tokenizer = AutoTokenizer.from_pretrained(local_model_path)
@ -90,9 +87,9 @@ def test_vllm_with_hf():
"What's your name",
]
tokenizer.pad_token = tokenizer.eos_token
prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True)
input_ids = prompts['input_ids']
attention_mask = prompts['attention_mask']
prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True)
input_ids = prompts["input_ids"]
attention_mask = prompts["attention_mask"]
input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True)
attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True)
@ -105,28 +102,27 @@ def test_vllm_with_hf():
temperature = 0
top_p = 1
kwargs = dict(n=1,
temperature=temperature,
top_p=top_p,
max_tokens=max_response_length,
logprobs=1,
ignore_eos=True)
kwargs = dict(
n=1, temperature=temperature, top_p=top_p, max_tokens=max_response_length, logprobs=1, ignore_eos=True
)
if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
kwargs['detokenize'] = False
if vllm_version in ("0.4.2", "0.5.4", "0.6.3"):
kwargs["detokenize"] = False
sampling_params = SamplingParams(**kwargs)
tensor_parallel_size = 4
llm = LLM(model=actor_model,
tokenizer=tokenizer,
model_hf_config=actor_model_config,
tensor_parallel_size=tensor_parallel_size,
dtype='bfloat16',
gpu_memory_utilization=0.1,
load_format='hf')
llm = LLM(
model=actor_model,
tokenizer=tokenizer,
model_hf_config=actor_model_config,
tensor_parallel_size=tensor_parallel_size,
dtype="bfloat16",
gpu_memory_utilization=0.1,
load_format="hf",
)
print('start generation')
print("start generation")
input_ids = input_ids.cuda()
attention_mask = attention_mask.cuda()
batch_size = input_ids.size(0)
@ -140,6 +136,7 @@ def test_vllm_with_hf():
llm.free_cache_engine()
llm = None
import gc
torch.cuda.empty_cache()
gc.collect()
@ -156,18 +153,18 @@ def test_vllm_with_hf():
# renormalize_logits=True,
output_scores=False, # this is potentially very large
return_dict_in_generate=True,
use_cache=False) # may OOM when use_cache = True
use_cache=False,
) # may OOM when use_cache = True
seq = output.sequences
response = seq[:, max_prompt_length:]
hf_response_tokens = tokenizer.batch_decode(response)
vllm_response_tokens = tokenizer.batch_decode(vllm_output)
print(f'hf response: {hf_response_tokens}')
print(f'vllm response: {vllm_response_tokens}')
assert are_lists_similar(hf_response_tokens, vllm_response_tokens), \
f'Strings differ more than 10%:\n'
print('Check Pass')
print(f"hf response: {hf_response_tokens}")
print(f"vllm response: {vllm_response_tokens}")
assert are_lists_similar(hf_response_tokens, vllm_response_tokens), "Strings differ more than 10%:\n"
print("Check Pass")
# if __name__ == "__main__":


@ -13,16 +13,14 @@
# limitations under the License.
import os
import torch
import transformers
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, CPUOffload
from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams
from verl.utils.model import update_model_config
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from transformers import GenerationConfig
from verl.utils.distributed import initialize_global_process_group
from verl.utils.torch_functional import pad_sequence_to_length
@ -43,7 +41,7 @@ def levenshtein(s1, s2):
dp[i][j] = min(
dp[i - 1][j] + 1, # Deletion
dp[i][j - 1] + 1, # Insertion
dp[i - 1][j - 1] + cost # Substitution
dp[i - 1][j - 1] + cost, # Substitution
)
return dp[m][n]
@ -70,16 +68,17 @@ def are_lists_similar(a, b):
def test_vllm_spmd():
assert torch.cuda.device_count() >= 2, 'At least 2 GPUs is required to run tp+dp tests.'
assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests."
local_rank, rank, world_size = initialize_global_process_group()
# Initialize model and token
local_cache_path = '~/.cache/verl/rlhf'
local_cache_path = "~/.cache/verl/rlhf"
local_cache_path = os.path.expanduser(local_cache_path)
hdfs_path = 'Qwen/Qwen2-7B-Instruct'
hdfs_path = "Qwen/Qwen2-7B-Instruct"
from verl.utils.fs import copy_to_local
local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)
tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side='left', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left", trust_remote_code=True)
actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True)
actor_model.to(torch.bfloat16)
@ -93,46 +92,46 @@ def test_vllm_spmd():
"What's your name",
]
tokenizer.pad_token = tokenizer.eos_token
prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True)
input_ids = prompts['input_ids']
attention_mask = prompts['attention_mask']
prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True)
input_ids = prompts["input_ids"]
attention_mask = prompts["attention_mask"]
input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True)
attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True)
print('start generation')
print("start generation")
input_ids = input_ids.cuda()
attention_mask = attention_mask.cuda()
temperature = 0
top_p = 1
kwargs = dict(n=1,
temperature=temperature,
top_p=top_p,
max_tokens=max_response_length,
logprobs=1,
ignore_eos=True)
kwargs = dict(
n=1, temperature=temperature, top_p=top_p, max_tokens=max_response_length, logprobs=1, ignore_eos=True
)
tensor_parallel_size = 4
from torch.distributed.device_mesh import init_device_mesh
device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp'])
device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)
fsdp_model = FSDP(actor_model,
use_orig_params=True,
auto_wrap_policy=None,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
cpu_offload=CPUOffload(offload_params=False),
sync_module_states=False,
device_mesh=device_mesh)
fsdp_model = FSDP(
actor_model,
use_orig_params=True,
auto_wrap_policy=None,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
cpu_offload=CPUOffload(offload_params=False),
sync_module_states=False,
device_mesh=device_mesh,
)
FSDP.set_state_dict_type(fsdp_model,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig())
FSDP.set_state_dict_type(
fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig()
)
state_dict = fsdp_model.state_dict()
@ -142,7 +141,7 @@ def test_vllm_spmd():
enable_sleep_mode=True,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend="external_launcher",
dtype='bfloat16',
dtype="bfloat16",
enforce_eager=True,
gpu_memory_utilization=0.8,
disable_custom_all_reduce=True,
@ -162,7 +161,8 @@ def test_vllm_spmd():
world_size = torch.distributed.get_world_size()
model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
model.load_weights(
((name, param.full_tensor() if world_size != 1 else param) for name, param in state_dict.items()))
((name, param.full_tensor() if world_size != 1 else param) for name, param in state_dict.items())
)
outputs = llm.generate(preencode_prompts, sampling_params=sampling_params, use_tqdm=False)
verl_vllm_response_tokens = []
@ -171,11 +171,10 @@ def test_vllm_spmd():
verl_vllm_response_tokens.append(generated_text)
if torch.distributed.get_rank() == 0:
print(f'vllm response: {vllm_response_tokens}')
print(f'verl-vllm response: {verl_vllm_response_tokens}')
assert are_lists_similar(vllm_response_tokens, verl_vllm_response_tokens), \
f'Strings differ more than 10%:\n'
print('Check Pass')
print(f"vllm response: {vllm_response_tokens}")
print(f"verl-vllm response: {verl_vllm_response_tokens}")
assert are_lists_similar(vllm_response_tokens, verl_vllm_response_tokens), "Strings differ more than 10%:\n"
print("Check Pass")
torch.distributed.destroy_process_group()
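The reshaped `FSDP(...)`, `LLM(...)` and `dict(...)` calls in the two vLLM tests above follow the formatter's Black-compatible layout algorithm. A small sketch, assuming a line length of roughly 120 characters; `build_config` is a hypothetical helper used only for illustration:

```python
# Sketch of ruff format's call layout: try one line, then one indented line of
# arguments, then one argument per line with a trailing comma.
def build_config(**kwargs):
    return kwargs

# 1) When the whole call fits within the line length, it stays on a single line.
cfg = build_config(temperature=0, top_p=1, n=1, ignore_eos=True)

# 2) When it does not fit, but the arguments fit on a single indented line,
#    they are placed there (no trailing comma is added in this form).
cfg = build_config(
    temperature=0, top_p=1, n=1, max_tokens=16, logprobs=1, ignore_eos=True, detokenize=False
)

# 3) When even that is too long, each argument gets its own line and a trailing
#    comma, which is the shape the long FSDP(...) calls end up in.
cfg = build_config(
    use_orig_params=True,
    auto_wrap_policy=None,
    sharding_strategy="FULL_SHARD",
    cpu_offload=False,
    sync_module_states=False,
)
print(cfg)
```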


@ -12,21 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
from verl.utils.reward_score import _default_compute_score
from verl.utils.reward_score.prime_code import apps_check_correctness
import asyncio
from verl.workers.reward_manager.prime import parallel_compute_score_async
prime_math_answers = [
"""\\begin{bmatrix}\n -7 & 6 & -8 \\\\\n 11 & -9 & 12 \\\\\n 15 & -16 & 19 \n \\end{bmatrix}""",
"""\\frac{\\sqrt{505}}{7}""", """x^2 + y^2 + 4x - 6y + 13"""
"""\\frac{\\sqrt{505}}{7}""",
"""x^2 + y^2 + 4x - 6y + 13""",
]
prime_math_gts = [
"""\\begin{pmatrix}\n -7 & 6 & -8 \\\\\n 11 & -9 & 12 \\\\\n 15 & -16 & 19\n \\end{pmatrix}""", # mat test
"""\\frac{\\sqrt{505}}{7}""", # frac test
"""(x + 2)^2 + (y - 3)^2 """ # symbolic test
"""(x + 2)^2 + (y - 3)^2 """, # symbolic test
]
prime_code_answers = [
@ -83,7 +84,7 @@ if __name__ == '__main__':
] * 2
prime_code_gts = [
"""{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"2\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}""", # A correct sample
"""{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}"""
"""{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}""",
] # A failed sample with first several in-out passed
prime_code_scores = [1.0, 0.9]
@ -99,18 +100,17 @@ def test_parallelism():
while len(sequences_str) < 32:
sequences_str.extend(prime_code_answers)
ground_truth.extend(prime_code_gts)
data_sources.extend(['codecontests'] * len(prime_code_answers))
data_sources.extend(["codecontests"] * len(prime_code_answers))
sequences_str.extend(prime_math_answers)
ground_truth.extend(prime_math_gts)
data_sources.extend(['numina_aops_forum'] * len(prime_math_answers))
data_sources.extend(["numina_aops_forum"] * len(prime_math_answers))
scores = asyncio.run(
parallel_compute_score_async(_default_compute_score,
sequences_str,
ground_truth,
data_sources,
num_processes=16))
parallel_compute_score_async(
_default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16
)
)
print(scores)
@ -118,7 +118,7 @@ def test_prime_code():
"""
Test PRIME code sandbox.
"""
data_source = 'codecontests'
data_source = "codecontests"
for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores):
score = _default_compute_score(data_source, completion, ground_truth)
assert float(score) == score_
@ -127,13 +127,13 @@ def test_prime_code():
def test_check_correctness():
completion = prime_code_answers[0]
ground_truth = json.loads(prime_code_gts[0])
ground_truth_single = {'inputs': ground_truth['inputs'][:1], 'outputs': ground_truth['outputs'][:1]}
ground_truth_single = {"inputs": ground_truth["inputs"][:1], "outputs": ground_truth["outputs"][:1]}
res, meta = apps_check_correctness(in_outs=ground_truth_single, generation=completion, timeout=5, debug=False)
print(res, meta)
def test_prime_math():
data_source = 'numina_aops_forum'
data_source = "numina_aops_forum"
for completion, ground_truth in zip(prime_math_answers, prime_math_gts):
score = _default_compute_score(data_source, completion, ground_truth)
assert float(score) == 1.0


@ -19,21 +19,21 @@ license_head_prime = "Copyright 2024 PRIME team and/or its affiliates"
license_head_individual = "Copyright 2025 Individual Contributor:"
license_headers = [license_head_bytedance, license_head_bytedance_25, license_head_prime, license_head_individual]
from pathlib import Path
from argparse import ArgumentParser
from pathlib import Path
if __name__ == '__main__':
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument('--directory', '-d', required=True, type=str)
parser.add_argument("--directory", "-d", required=True, type=str)
args = parser.parse_args()
directory_in_str = args.directory
pathlist = Path(directory_in_str).glob('**/*.py')
pathlist = Path(directory_in_str).glob("**/*.py")
for path in pathlist:
# because path is object not string
path_in_str = str(path.absolute())
print(path_in_str)
with open(path_in_str, 'r', encoding='utf-8') as f:
with open(path_in_str, encoding="utf-8") as f:
file_content = f.read()
has_license = False
@ -41,4 +41,4 @@ if __name__ == '__main__':
if lh in file_content:
has_license = True
break
assert has_license, f'file {path_in_str} does not contain license'
assert has_license, f"file {path_in_str} does not contain license"


@ -15,9 +15,11 @@
def test_import():
import verl
print(verl.__version__)
def test_single_controller_import():
import verl.single_controller
print(verl.single_controller.__version__)


@ -13,41 +13,38 @@
# limitations under the License.
import random
import numpy as np
import pytest
import torch
from tensordict import TensorDict
from verl.protocol import union_tensor_dict, union_numpy_dict
from verl import DataProto
import numpy as np
from verl.protocol import union_numpy_dict, union_tensor_dict
def test_union_tensor_dict():
obs = torch.randn(100, 10)
data1 = TensorDict({'obs': obs, 'act': torch.randn(100, 3)}, batch_size=[100])
data2 = TensorDict({'obs': obs, 'next_obs': torch.randn(100, 10), 'rew': torch.randn(100)}, batch_size=[100])
data1 = TensorDict({"obs": obs, "act": torch.randn(100, 3)}, batch_size=[100])
data2 = TensorDict({"obs": obs, "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100])
data_with_copied_obs = TensorDict({
'obs': obs.clone(),
'next_obs': torch.randn(100, 10),
'rew': torch.randn(100)
},
batch_size=[100])
data_with_copied_obs = TensorDict(
{"obs": obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100]
)
data = union_tensor_dict(data1, data2)
with pytest.raises(AssertionError):
data = union_tensor_dict(data1, data_with_copied_obs)
data = np.random.random(100)
data2 = [float('nan') for _ in range(99)]
data2.append('nan')
data2 = [float("nan") for _ in range(99)]
data2.append("nan")
data2 = np.array(data2, dtype=object)
data3 = np.tile(data2, (2, 1))
a = {'a': data, 'b': data2, 'c': data3}
b = {'a': data, 'b': data2, 'c': data3}
b_ = {'a': np.random.random(100)}
a = {"a": data, "b": data2, "c": data3}
b = {"a": data, "b": data2, "c": data3}
b_ = {"a": np.random.random(100)}
union_numpy_dict(a, b)
with pytest.raises(AssertionError):
union_numpy_dict(a, b_)
@ -56,21 +53,21 @@ def test_union_tensor_dict():
def test_tensor_dict_constructor():
obs = torch.randn(100, 10)
act = torch.randn(100, 10, 3)
data = DataProto.from_dict(tensors={'obs': obs, 'act': act})
data = DataProto.from_dict(tensors={"obs": obs, "act": act})
assert data.batch.batch_size == torch.Size([100])
with pytest.raises(AssertionError):
data = DataProto.from_dict(tensors={'obs': obs, 'act': act}, num_batch_dims=2)
data = DataProto.from_dict(tensors={"obs": obs, "act": act}, num_batch_dims=2)
with pytest.raises(AssertionError):
data = DataProto.from_dict(tensors={'obs': obs, 'act': act}, num_batch_dims=3)
data = DataProto.from_dict(tensors={"obs": obs, "act": act}, num_batch_dims=3)
def test_tensor_dict_make_iterator():
obs = torch.randn(100, 10)
labels = [random.choice(['abc', 'cde']) for _ in range(100)]
dataset = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels})
labels = [random.choice(["abc", "cde"]) for _ in range(100)]
dataset = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels})
data_iter_1 = dataset.make_iterator(mini_batch_size=10, epochs=2, seed=1)
data_list_1 = []
@ -85,94 +82,94 @@ def test_tensor_dict_make_iterator():
for data1, data2 in zip(data_list_1, data_list_2):
assert isinstance(data1, DataProto)
assert isinstance(data2, DataProto)
result = torch.all(torch.eq(data1.batch['obs'], data2.batch['obs']))
result = torch.all(torch.eq(data1.batch["obs"], data2.batch["obs"]))
if not result.item():
print(data1.batch['obs'])
print(data2.batch['obs'])
print(data1.batch["obs"])
print(data2.batch["obs"])
assert False
non_tensor_result = np.all(np.equal(data1.non_tensor_batch['labels'], data2.non_tensor_batch['labels']))
non_tensor_result = np.all(np.equal(data1.non_tensor_batch["labels"], data2.non_tensor_batch["labels"]))
if not non_tensor_result.item():
print(data1.non_tensor_batch['labels'])
print(data2.non_tensor_batch['labels'])
print(data1.non_tensor_batch["labels"])
print(data2.non_tensor_batch["labels"])
def test_reorder():
obs = torch.tensor([1, 2, 3, 4, 5, 6])
labels = ['a', 'b', 'c', 'd', 'e', 'f']
data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'name': 'abdce'})
labels = ["a", "b", "c", "d", "e", "f"]
data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"})
data.reorder(torch.tensor([3, 4, 2, 0, 1, 5]))
assert torch.all(torch.eq(data.batch['obs'], torch.tensor([4, 5, 3, 1, 2, 6])))
assert np.all(data.non_tensor_batch['labels'] == np.array(['d', 'e', 'c', 'a', 'b', 'f']))
assert data.meta_info == {'name': 'abdce'}
assert torch.all(torch.eq(data.batch["obs"], torch.tensor([4, 5, 3, 1, 2, 6])))
assert np.all(data.non_tensor_batch["labels"] == np.array(["d", "e", "c", "a", "b", "f"]))
assert data.meta_info == {"name": "abdce"}
def test_chunk_concat():
obs = torch.tensor([1, 2, 3, 4, 5, 6])
labels = ['a', 'b', 'c', 'd', 'e', 'f']
data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'name': 'abdce'})
labels = ["a", "b", "c", "d", "e", "f"]
data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"})
with pytest.raises(AssertionError):
data.chunk(5)
data_split = data.chunk(2)
assert len(data_split) == 2
assert torch.all(torch.eq(data_split[0].batch['obs'], torch.tensor([1, 2, 3])))
assert np.all(data_split[0].non_tensor_batch['labels'] == np.array(['a', 'b', 'c']))
assert data_split[0].meta_info == {'name': 'abdce'}
assert torch.all(torch.eq(data_split[0].batch["obs"], torch.tensor([1, 2, 3])))
assert np.all(data_split[0].non_tensor_batch["labels"] == np.array(["a", "b", "c"]))
assert data_split[0].meta_info == {"name": "abdce"}
assert torch.all(torch.eq(data_split[1].batch['obs'], torch.tensor([4, 5, 6])))
assert np.all(data_split[1].non_tensor_batch['labels'] == np.array(['d', 'e', 'f']))
assert data_split[1].meta_info == {'name': 'abdce'}
assert torch.all(torch.eq(data_split[1].batch["obs"], torch.tensor([4, 5, 6])))
assert np.all(data_split[1].non_tensor_batch["labels"] == np.array(["d", "e", "f"]))
assert data_split[1].meta_info == {"name": "abdce"}
concat_data = DataProto.concat(data_split)
assert torch.all(torch.eq(concat_data.batch['obs'], data.batch['obs']))
assert np.all(concat_data.non_tensor_batch['labels'] == data.non_tensor_batch['labels'])
assert torch.all(torch.eq(concat_data.batch["obs"], data.batch["obs"]))
assert np.all(concat_data.non_tensor_batch["labels"] == data.non_tensor_batch["labels"])
assert concat_data.meta_info == data.meta_info
def test_pop():
obs = torch.randn(100, 10)
act = torch.randn(100, 3)
dataset = DataProto.from_dict({'obs': obs, 'act': act}, meta_info={'2': 2, '1': 1})
poped_dataset = dataset.pop(batch_keys=['obs'], meta_info_keys=['2'])
dataset = DataProto.from_dict({"obs": obs, "act": act}, meta_info={"2": 2, "1": 1})
poped_dataset = dataset.pop(batch_keys=["obs"], meta_info_keys=["2"])
assert poped_dataset.batch.keys() == {'obs'}
assert poped_dataset.meta_info.keys() == {'2'}
assert poped_dataset.batch.keys() == {"obs"}
assert poped_dataset.meta_info.keys() == {"2"}
assert dataset.batch.keys() == {'act'}
assert dataset.meta_info.keys() == {'1'}
assert dataset.batch.keys() == {"act"}
assert dataset.meta_info.keys() == {"1"}
def test_repeat():
# Create a DataProto object with some batch and non-tensor data
obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
labels = ['a', 'b', 'c']
data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'})
labels = ["a", "b", "c"]
data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})
# Test interleave=True
repeated_data_interleave = data.repeat(repeat_times=2, interleave=True)
expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]])
expected_labels_interleave = ['a', 'a', 'b', 'b', 'c', 'c']
expected_labels_interleave = ["a", "a", "b", "b", "c", "c"]
assert torch.all(torch.eq(repeated_data_interleave.batch['obs'], expected_obs_interleave))
assert (repeated_data_interleave.non_tensor_batch['labels'] == expected_labels_interleave).all()
assert repeated_data_interleave.meta_info == {'info': 'test_info'}
assert torch.all(torch.eq(repeated_data_interleave.batch["obs"], expected_obs_interleave))
assert (repeated_data_interleave.non_tensor_batch["labels"] == expected_labels_interleave).all()
assert repeated_data_interleave.meta_info == {"info": "test_info"}
# Test interleave=False
repeated_data_no_interleave = data.repeat(repeat_times=2, interleave=False)
expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]])
expected_labels_no_interleave = ['a', 'b', 'c', 'a', 'b', 'c']
expected_labels_no_interleave = ["a", "b", "c", "a", "b", "c"]
assert torch.all(torch.eq(repeated_data_no_interleave.batch['obs'], expected_obs_no_interleave))
assert (repeated_data_no_interleave.non_tensor_batch['labels'] == expected_labels_no_interleave).all()
assert repeated_data_no_interleave.meta_info == {'info': 'test_info'}
assert torch.all(torch.eq(repeated_data_no_interleave.batch["obs"], expected_obs_no_interleave))
assert (repeated_data_no_interleave.non_tensor_batch["labels"] == expected_labels_no_interleave).all()
assert repeated_data_no_interleave.meta_info == {"info": "test_info"}
def test_dataproto_pad_unpad():
obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
labels = ['a', 'b', 'c']
data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'})
labels = ["a", "b", "c"]
data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
@ -180,115 +177,116 @@ def test_dataproto_pad_unpad():
assert pad_size == 1
expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2]])
expected_labels = ['a', 'b', 'c', 'a']
expected_labels = ["a", "b", "c", "a"]
assert torch.all(torch.eq(padded_data.batch['obs'], expected_obs))
assert (padded_data.non_tensor_batch['labels'] == expected_labels).all()
assert padded_data.meta_info == {'info': 'test_info'}
assert torch.all(torch.eq(padded_data.batch["obs"], expected_obs))
assert (padded_data.non_tensor_batch["labels"] == expected_labels).all()
assert padded_data.meta_info == {"info": "test_info"}
unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size)
assert torch.all(torch.eq(unpadd_data.batch['obs'], obs))
assert (unpadd_data.non_tensor_batch['labels'] == labels).all()
assert unpadd_data.meta_info == {'info': 'test_info'}
assert torch.all(torch.eq(unpadd_data.batch["obs"], obs))
assert (unpadd_data.non_tensor_batch["labels"] == labels).all()
assert unpadd_data.meta_info == {"info": "test_info"}
padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=3)
assert pad_size == 0
expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
expected_labels = ['a', 'b', 'c']
expected_labels = ["a", "b", "c"]
assert torch.all(torch.eq(padded_data.batch['obs'], expected_obs))
assert (padded_data.non_tensor_batch['labels'] == expected_labels).all()
assert padded_data.meta_info == {'info': 'test_info'}
assert torch.all(torch.eq(padded_data.batch["obs"], expected_obs))
assert (padded_data.non_tensor_batch["labels"] == expected_labels).all()
assert padded_data.meta_info == {"info": "test_info"}
unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size)
assert torch.all(torch.eq(unpadd_data.batch['obs'], obs))
assert (unpadd_data.non_tensor_batch['labels'] == labels).all()
assert unpadd_data.meta_info == {'info': 'test_info'}
assert torch.all(torch.eq(unpadd_data.batch["obs"], obs))
assert (unpadd_data.non_tensor_batch["labels"] == labels).all()
assert unpadd_data.meta_info == {"info": "test_info"}
padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=7)
assert pad_size == 4
expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]])
expected_labels = ['a', 'b', 'c', 'a', 'b', 'c', 'a']
assert torch.all(torch.eq(padded_data.batch['obs'], expected_obs))
assert (padded_data.non_tensor_batch['labels'] == expected_labels).all()
assert padded_data.meta_info == {'info': 'test_info'}
expected_labels = ["a", "b", "c", "a", "b", "c", "a"]
assert torch.all(torch.eq(padded_data.batch["obs"], expected_obs))
assert (padded_data.non_tensor_batch["labels"] == expected_labels).all()
assert padded_data.meta_info == {"info": "test_info"}
unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size)
assert torch.all(torch.eq(unpadd_data.batch['obs'], obs))
assert (unpadd_data.non_tensor_batch['labels'] == labels).all()
assert unpadd_data.meta_info == {'info': 'test_info'}
assert torch.all(torch.eq(unpadd_data.batch["obs"], obs))
assert (unpadd_data.non_tensor_batch["labels"] == labels).all()
assert unpadd_data.meta_info == {"info": "test_info"}
def test_dataproto_fold_unfold():
from verl.protocol import fold_batch_dim, unfold_batch_dim, DataProto
from verl.protocol import DataProto, fold_batch_dim, unfold_batch_dim
obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
labels = ['a', 'b', 'c']
data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'})
labels = ["a", "b", "c"]
data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})
data1 = data.repeat(repeat_times=2, interleave=True)
data2 = fold_batch_dim(data1, new_batch_size=3)
torch.testing.assert_close(data2.batch['obs'], torch.tensor([[[1, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]))
assert (data2.non_tensor_batch['labels'] == [['a', 'a'], ['b', 'b'], ['c', 'c']]).all()
torch.testing.assert_close(data2.batch["obs"], torch.tensor([[[1, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]))
assert (data2.non_tensor_batch["labels"] == [["a", "a"], ["b", "b"], ["c", "c"]]).all()
data2.reorder(indices=torch.tensor([1, 2, 0]))
data3 = unfold_batch_dim(data2, batch_dims=2)
torch.testing.assert_close(data3.batch['obs'], torch.tensor([[3, 4], [3, 4], [5, 6], [5, 6], [1, 2], [1, 2]]))
assert (data3.non_tensor_batch['labels'] == ['b', 'b', 'c', 'c', 'a', 'a']).all()
assert data3.meta_info == {'info': 'test_info'}
torch.testing.assert_close(data3.batch["obs"], torch.tensor([[3, 4], [3, 4], [5, 6], [5, 6], [1, 2], [1, 2]]))
assert (data3.non_tensor_batch["labels"] == ["b", "b", "c", "c", "a", "a"]).all()
assert data3.meta_info == {"info": "test_info"}
def test_torch_save_data_proto():
obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
labels = ['a', 'b', 'c']
data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'})
data.save_to_disk('test_data.pt')
loaded_data = DataProto.load_from_disk('test_data.pt')
labels = ["a", "b", "c"]
data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})
data.save_to_disk("test_data.pt")
loaded_data = DataProto.load_from_disk("test_data.pt")
assert torch.all(torch.eq(loaded_data.batch['obs'], data.batch['obs']))
assert (loaded_data.non_tensor_batch['labels'] == data.non_tensor_batch['labels']).all()
assert torch.all(torch.eq(loaded_data.batch["obs"], data.batch["obs"]))
assert (loaded_data.non_tensor_batch["labels"] == data.non_tensor_batch["labels"]).all()
assert loaded_data.meta_info == data.meta_info
import os
os.remove('test_data.pt')
os.remove("test_data.pt")
def test_len():
obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
labels = np.array(['a', 'b', 'c'], dtype=object)
data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'})
labels = np.array(["a", "b", "c"], dtype=object)
data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})
assert len(data) == 3
data = DataProto(batch=None, non_tensor_batch={'labels': labels}, meta_info={'info': 'test_info'})
data = DataProto(batch=None, non_tensor_batch={"labels": labels}, meta_info={"info": "test_info"})
assert len(data) == 3
data = DataProto(batch=None, non_tensor_batch={}, meta_info={'info': 'test_info'})
data = DataProto(batch=None, non_tensor_batch={}, meta_info={"info": "test_info"})
assert len(data) == 0
data = DataProto(batch=None, non_tensor_batch=None, meta_info={'info': 'test_info'})
data = DataProto(batch=None, non_tensor_batch=None, meta_info={"info": "test_info"})
assert len(data) == 0
def test_seqlen_balancing():
from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
input_ids = torch.randint(low=0, high=10, size=(20, 100))
from verl.utils.model import create_random_mask
attention_mask = create_random_mask(input_ids=input_ids,
max_ratio_of_left_padding=0.1,
max_ratio_of_valid_token=0.9,
min_ratio_of_valid_token=0.5)
data = {'input_ids': input_ids, 'attention_mask': attention_mask}
attention_mask = create_random_mask(
input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5
)
data = {"input_ids": input_ids, "attention_mask": attention_mask}
dataproto = DataProto.from_single_dict(data)
micro_batches, micro_bsz_idx_lst = rearrange_micro_batches(dataproto.batch, max_token_len=300)
batch = torch.cat(micro_batches)
@ -298,4 +296,4 @@ def test_seqlen_balancing():
reverse_idx_map = get_reverse_idx(micro_bsz_idx)
reverse_idx_map = torch.tensor(reverse_idx_map)
new_batch = batch[reverse_idx_map]
torch.testing.assert_close(new_batch, dataproto.batch)
torch.testing.assert_close(new_batch, dataproto.batch)
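In the MultiTurnSFTDataset test below, multi-line assert messages lose their backslash continuations in favour of parentheses, the style the ruff-formatted codebase settles on. A small sketch of the two forms (whether the rewrite came from the formatter or was done by hand, the result is the same):

```python
# Sketch of the assert-message rewrite: a backslash continuation becomes a
# parenthesized message, which the formatter can wrap freely.
item = {"loss_mask": [1, 0, 1], "input_ids": [5, 6, 7]}

# pre-ruff style: explicit backslash continuation
assert len(item["loss_mask"]) == len(item["input_ids"]), \
    "Loss mask shape doesn't match input_ids shape"

# post-ruff style: the message is wrapped in parentheses instead
assert len(item["loss_mask"]) == len(item["input_ids"]), (
    "Loss mask shape doesn't match input_ids shape"
)
```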


@ -14,10 +14,13 @@
"""
Test the MultiTurnSFTDataset implementation
"""
import os
import pandas as pd
import torch
from transformers import AutoTokenizer
from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
@ -25,51 +28,35 @@ def test_multiturn_sft_dataset():
print("Starting test...")
# Create a temporary parquet file with test data
test_data = {
'messages': [[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "What is 2+2?"
}, {
"role": "assistant",
"content": "2+2 equals 4."
}, {
"role": "user",
"content": "And what is 4+4?"
}, {
"role": "assistant",
"content": "4+4 equals 8."
}],
[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Tell me a joke."
}, {
"role": "assistant",
"content": "Why did the chicken cross the road?"
}, {
"role": "user",
"content": "Why?"
}, {
"role": "assistant",
"content": "To get to the other side!"
}]]
"messages": [
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "2+2 equals 4."},
{"role": "user", "content": "And what is 4+4?"},
{"role": "assistant", "content": "4+4 equals 8."},
],
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Tell me a joke."},
{"role": "assistant", "content": "Why did the chicken cross the road?"},
{"role": "user", "content": "Why?"},
{"role": "assistant", "content": "To get to the other side!"},
],
]
}
# Create test directory if it doesn't exist
os.makedirs('test_data', exist_ok=True)
test_file = 'test_data/test.parquet'
os.makedirs("test_data", exist_ok=True)
test_file = "test_data/test.parquet"
# Save test data to parquet
df = pd.DataFrame(test_data)
df.to_parquet(test_file)
# Initialize tokenizer and dataset
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-Coder-7B-Instruct')
config = {'max_length': 512, 'truncation': 'error', 'multiturn': {'messages_key': 'messages'}}
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct")
config = {"max_length": 512, "truncation": "error", "multiturn": {"messages_key": "messages"}}
dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config)
# Test 1: Dataset Length
@ -80,23 +67,22 @@ def test_multiturn_sft_dataset():
item1 = dataset[1] # Joke conversation
# Test 2: Required Keys and Types
required_keys = ['input_ids', 'attention_mask', 'position_ids', 'loss_mask']
required_keys = ["input_ids", "attention_mask", "position_ids", "loss_mask"]
for key in required_keys:
assert key in item0, f"Missing key {key} in dataset item"
assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key}"
assert item0[key].dtype == torch.long, f"Expected torch.long for {key}, got {item0[key].dtype}"
# Test 3: Shape Consistency
assert item0['loss_mask'].shape == item0['input_ids'].shape, \
"Loss mask shape doesn't match input_ids shape"
assert item0['attention_mask'].shape == item0['input_ids'].shape, \
assert item0["loss_mask"].shape == item0["input_ids"].shape, "Loss mask shape doesn't match input_ids shape"
assert item0["attention_mask"].shape == item0["input_ids"].shape, (
"Attention mask shape doesn't match input_ids shape"
assert item0['position_ids'].shape == item0['input_ids'].shape, \
"Position IDs shape doesn't match input_ids shape"
)
assert item0["position_ids"].shape == item0["input_ids"].shape, "Position IDs shape doesn't match input_ids shape"
# Test 4: Loss Mask Pattern - Math Conversation
loss_mask0 = item0['loss_mask']
input_ids0 = item0['input_ids']
loss_mask0 = item0["loss_mask"]
input_ids0 = item0["input_ids"]
# Find assistant response positions
assistant_positions0 = torch.where(loss_mask0 == 1)[0]
@ -109,8 +95,8 @@ def test_multiturn_sft_dataset():
assert "4+4 equals 8" in assistant_text0, "Second assistant response not found"
# Test 5: Loss Mask Pattern - Joke Conversation
loss_mask1 = item1['loss_mask']
input_ids1 = item1['input_ids']
loss_mask1 = item1["loss_mask"]
input_ids1 = item1["input_ids"]
# Find assistant response positions
assistant_positions1 = torch.where(loss_mask1 == 1)[0]
@ -123,7 +109,7 @@ def test_multiturn_sft_dataset():
assert "other side" in assistant_text1, "Second assistant response not found"
# Test 6: Attention Mask Pattern
attention_mask0 = item0['attention_mask']
attention_mask0 = item0["attention_mask"]
sequence_length = torch.sum(attention_mask0)
assert sequence_length > 0, "No tokens marked as attended in attention mask"
assert torch.all(attention_mask0[:sequence_length] == 1), "Incorrect attention mask pattern"
@ -131,9 +117,10 @@ def test_multiturn_sft_dataset():
assert torch.all(attention_mask0[sequence_length:] == 0), "Padding not properly masked"
# Test 7: Position IDs Pattern
position_ids0 = item0['position_ids']
assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), \
position_ids0 = item0["position_ids"]
assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), (
"Position IDs not sequential for non-padded tokens"
)
if sequence_length < len(position_ids0):
assert torch.all(position_ids0[sequence_length:] == 0), "Padding position IDs not zero"
@ -147,16 +134,16 @@ def test_multiturn_sft_dataset():
print(f"\nAssistant responses (from loss mask):\n{assistant_text}")
# Verify that loss mask is set for all assistant responses
for msg in test_data['messages'][0]: # First conversation
if msg['role'] == 'assistant':
for msg in test_data["messages"][0]: # First conversation
if msg["role"] == "assistant":
# The content should appear in the masked text
assert msg['content'] in assistant_text, \
f"Assistant message '{msg['content']}' not found in masked text"
assert msg["content"] in assistant_text, f"Assistant message '{msg['content']}' not found in masked text"
# The content should NOT appear in the non-masked text
non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0])
assert msg['content'] not in non_assistant_text, \
assert msg["content"] not in non_assistant_text, (
f"Assistant message '{msg['content']}' found in non-assistant text"
)
# Test 9: Verify non-assistant parts have loss_mask=0
# Get non-assistant text
@ -164,30 +151,31 @@ def test_multiturn_sft_dataset():
print(f"\nNon-assistant text (from loss mask):\n{non_assistant_text}")
# Verify that system and user messages are in the non-assistant text
for msg in test_data['messages'][0]: # First conversation
if msg['role'] in ['system', 'user']:
assert msg['content'] in non_assistant_text, \
for msg in test_data["messages"][0]: # First conversation
if msg["role"] in ["system", "user"]:
assert msg["content"] in non_assistant_text, (
f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text"
)
# And verify they're NOT in the assistant text
assert msg['content'] not in assistant_text, \
assert msg["content"] not in assistant_text, (
f"{msg['role'].title()} message '{msg['content']}' found in assistant text"
)
# Test 10: Verify padding behavior
padding_config = {'max_length': 1024, 'truncation': 'error', 'multiturn': {'messages_key': 'messages'}}
padding_config = {"max_length": 1024, "truncation": "error", "multiturn": {"messages_key": "messages"}}
small_dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=padding_config)
padded_item = small_dataset[0]
# Get actual sequence length (before padding)
actual_length = torch.sum(padded_item['attention_mask'])
actual_length = torch.sum(padded_item["attention_mask"])
# Verify padding tokens
assert torch.all(padded_item['input_ids'][actual_length:] == tokenizer.pad_token_id), \
assert torch.all(padded_item["input_ids"][actual_length:] == tokenizer.pad_token_id), (
"Padding tokens not set correctly"
assert torch.all(padded_item['attention_mask'][actual_length:] == 0), \
"Attention mask not set correctly for padding"
assert torch.all(padded_item['loss_mask'][actual_length:] == 0), \
"Loss mask not set correctly for padding"
)
assert torch.all(padded_item["attention_mask"][actual_length:] == 0), "Attention mask not set correctly for padding"
assert torch.all(padded_item["loss_mask"][actual_length:] == 0), "Loss mask not set correctly for padding"
print("All tests passed!")
print("Starting test...")


@ -12,32 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
def get_gsm8k_data():
# prepare test dataset
url = "https://github.com/eric-haibin-lin/verl-data/raw/refs/heads/main/gsm8k/train.parquet"
local_folder = os.path.expanduser('~/verl-data/gsm8k/')
local_path = os.path.join(local_folder, 'train.parquet')
local_folder = os.path.expanduser("~/verl-data/gsm8k/")
local_path = os.path.join(local_folder, "train.parquet")
os.makedirs(local_folder, exist_ok=True)
return local_path
def test_rl_dataset():
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer('deepseek-ai/deepseek-coder-1.3b-instruct')
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
tokenizer = hf_tokenizer("deepseek-ai/deepseek-coder-1.3b-instruct")
local_path = get_gsm8k_data()
config = OmegaConf.create({
"prompt_key": "prompt",
"max_prompt_length": 256,
"filter_overlong_prompts": True,
"filter_overlong_prompts_workers": 2,
})
config = OmegaConf.create(
{
"prompt_key": "prompt",
"max_prompt_length": 256,
"filter_overlong_prompts": True,
"filter_overlong_prompts_workers": 2,
}
)
dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config)
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn)
@ -56,29 +59,34 @@ def test_rl_dataset():
non_tensors[key] = val
data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)
assert 'input_ids' in data_proto.batch
assert "input_ids" in data_proto.batch
data = dataset[0]['input_ids']
data = dataset[0]["input_ids"]
output = tokenizer.batch_decode([data])[0]
print(f'type: type{output}')
print(f'\n\noutput: {output}')
print(f"type: type{output}")
print(f"\n\noutput: {output}")
def test_image_rl_data():
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils import hf_tokenizer, hf_processor
tokenizer = hf_tokenizer('Qwen/Qwen2-VL-2B-Instruct')
processor = hf_processor('Qwen/Qwen2-VL-2B-Instruct')
config = OmegaConf.create({
"prompt_key": "prompt",
"max_prompt_length": 1024,
"filter_overlong_prompts": True,
"filter_overlong_prompts_workers": 2,
})
dataset = RLHFDataset(data_files=os.path.expanduser("~/data/geo3k/train.parquet"),
tokenizer=tokenizer,
config=config,
processor=processor)
tokenizer = hf_tokenizer("Qwen/Qwen2-VL-2B-Instruct")
processor = hf_processor("Qwen/Qwen2-VL-2B-Instruct")
config = OmegaConf.create(
{
"prompt_key": "prompt",
"max_prompt_length": 1024,
"filter_overlong_prompts": True,
"filter_overlong_prompts_workers": 2,
}
)
dataset = RLHFDataset(
data_files=os.path.expanduser("~/data/geo3k/train.parquet"),
tokenizer=tokenizer,
config=config,
processor=processor,
)
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn)
@ -97,10 +105,57 @@ def test_image_rl_data():
data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)
assert 'multi_modal_data' in data_proto.non_tensor_batch
assert 'multi_modal_inputs' in data_proto.non_tensor_batch
assert "multi_modal_data" in data_proto.non_tensor_batch
assert "multi_modal_inputs" in data_proto.non_tensor_batch
data = dataset[0]['input_ids']
data = dataset[0]["input_ids"]
output = tokenizer.batch_decode([data])[0]
print(f'type: type{output}')
print(f'\n\noutput: {output}')
print(f"type: type{output}")
print(f"\n\noutput: {output}")
def test_image_rl_data():
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
tokenizer = hf_tokenizer("Qwen/Qwen2-VL-2B-Instruct")
processor = hf_processor("Qwen/Qwen2-VL-2B-Instruct")
config = OmegaConf.create(
{
"prompt_key": "prompt",
"max_prompt_length": 1024,
"filter_overlong_prompts": True,
"filter_overlong_prompts_workers": 2,
}
)
dataset = RLHFDataset(
data_files=os.path.expanduser("~/data/geo3k/train.parquet"),
tokenizer=tokenizer,
config=config,
processor=processor,
)
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn)
a = next(iter(dataloader))
from verl import DataProto
tensors = {}
non_tensors = {}
for key, val in a.items():
if isinstance(val, torch.Tensor):
tensors[key] = val
else:
non_tensors[key] = val
data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)
assert "multi_modal_data" in data_proto.non_tensor_batch
assert "multi_modal_inputs" in data_proto.non_tensor_batch
data = dataset[0]["input_ids"]
output = tokenizer.batch_decode([data])[0]
print(f"type: type{output}")
print(f"\n\noutput: {output}")


@ -13,7 +13,6 @@
# limitations under the License.
import os
from transformers import AutoTokenizer
from verl.utils import hf_tokenizer
from verl.utils.dataset.rm_dataset import RMDataset
@ -21,8 +20,8 @@ from verl.utils.dataset.rm_dataset import RMDataset
def get_rm_data():
# prepare test dataset
url = "https://github.com/eric-haibin-lin/verl-data/raw/refs/heads/main/full_hh_rlhf/rm/test.parquet"
local_folder = os.path.expanduser('~/verl-data/full_hh_rlhf/rm/')
local_path = os.path.join(local_folder, 'test.parquet')
local_folder = os.path.expanduser("~/verl-data/full_hh_rlhf/rm/")
local_path = os.path.join(local_folder, "test.parquet")
os.makedirs(local_folder, exist_ok=True)
return local_path
@ -31,7 +30,7 @@ def test_rm_dataset():
tokenizer = hf_tokenizer("facebook/opt-1.3b")
local_path = get_rm_data()
dataset = RMDataset(parquet_files=local_path, tokenizer=tokenizer, max_length=512)
data = dataset[0]['input_ids']
data = dataset[0]["input_ids"]
output = tokenizer.batch_decode(data)
assert len(output) > 1
assert type(output[0]) == str


@ -13,7 +13,6 @@
# limitations under the License.
import os
from transformers import AutoTokenizer
from verl.utils import hf_tokenizer
from verl.utils.dataset.sft_dataset import SFTDataset
@ -21,46 +20,56 @@ from verl.utils.dataset.sft_dataset import SFTDataset
def get_gsm8k_data():
# prepare test dataset
url = "https://github.com/eric-haibin-lin/verl-data/raw/refs/heads/main/gsm8k/train.parquet"
local_folder = os.path.expanduser('~/verl-data/gsm8k/')
local_path = os.path.join(local_folder, 'train.parquet')
local_folder = os.path.expanduser("~/verl-data/gsm8k/")
local_path = os.path.join(local_folder, "train.parquet")
return local_path
def test_sft_cot_dataset():
tokenizer = hf_tokenizer('deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct')
tokenizer = hf_tokenizer("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")
local_path = get_gsm8k_data()
from omegaconf import OmegaConf
dataset = SFTDataset(parquet_files=local_path,
tokenizer=tokenizer,
config=OmegaConf.create({
'prompt_key': 'prompt',
'prompt_dict_keys': ['content'],
'response_key': 'extra_info',
'response_dict_keys': ['answer'],
'max_length': 512,
}))
data = dataset[0]['input_ids']
dataset = SFTDataset(
parquet_files=local_path,
tokenizer=tokenizer,
config=OmegaConf.create(
{
"prompt_key": "prompt",
"prompt_dict_keys": ["content"],
"response_key": "extra_info",
"response_dict_keys": ["answer"],
"max_length": 512,
}
),
)
data = dataset[0]["input_ids"]
output = tokenizer.batch_decode([data])[0]
assert len(output) > 1
assert type(output) == str
def test_sft_dataset():
tokenizer = hf_tokenizer('deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct')
tokenizer = hf_tokenizer("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")
local_path = get_gsm8k_data()
from omegaconf import OmegaConf
dataset = SFTDataset(parquet_files=local_path,
tokenizer=tokenizer,
config=OmegaConf.create({
"prompt_key": 'extra_info',
'prompt_dict_keys': ['question'],
'response_key': 'extra_info',
'response_dict_keys': ['answer'],
'max_length': 512
}))
data = dataset[0]['input_ids']
dataset = SFTDataset(
parquet_files=local_path,
tokenizer=tokenizer,
config=OmegaConf.create(
{
"prompt_key": "extra_info",
"prompt_dict_keys": ["question"],
"response_key": "extra_info",
"response_dict_keys": ["answer"],
"max_length": 512,
}
),
)
data = dataset[0]["input_ids"]
output = tokenizer.batch_decode([data])[0]
assert len(output) > 1
assert type(output) == str


@ -13,9 +13,9 @@
# limitations under the License.
import os
import sys
import importlib.util
import pytest
from verl.utils.import_utils import load_extern_type
# Path to the test module
@ -84,7 +84,7 @@ def test_load_extern_type_invalid_module():
# Create a temporary file with syntax errors
import tempfile
with tempfile.NamedTemporaryFile(suffix='.py', mode='w+', delete=False) as temp_file:
with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False) as temp_file:
temp_file.write("This is not valid Python syntax :")
temp_path = temp_file.name


@ -16,24 +16,26 @@ import os
version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))
with open(os.path.join(version_folder, 'version/version')) as f:
with open(os.path.join(version_folder, "version/version")) as f:
__version__ = f.read().strip()
from .protocol import DataProto
from .utils.logging_utils import set_basic_config
import logging
from .protocol import DataProto
from .utils.logging_utils import set_basic_config
set_basic_config(level=logging.WARNING)
from . import single_controller
__all__ = ['DataProto', "__version__"]
__all__ = ["DataProto", "__version__"]
if os.getenv('VERL_USE_MODELSCOPE', 'False').lower() == 'true':
if os.getenv("VERL_USE_MODELSCOPE", "False").lower() == "true":
import importlib
if importlib.util.find_spec("modelscope") is None:
raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`')
raise ImportError("You are using the modelscope hub, please install modelscope by `pip install modelscope -U`")
# Patch hub to download models from modelscope to speed up.
from modelscope.utils.hf_util import patch_hub
patch_hub()


@ -13,12 +13,13 @@
# limitations under the License.
from .modeling_llama_megatron import (
# original model with megatron
ParallelLlamaModel,
ParallelLlamaForCausalLM,
# rmpad with megatron
ParallelLlamaForCausalLMRmPad,
ParallelLlamaForValueRmPad,
# rmpad with megatron and pipeline parallelism
ParallelLlamaForCausalLMRmPadPP,
ParallelLlamaForValueRmPadPP)
ParallelLlamaForValueRmPad,
ParallelLlamaForValueRmPadPP,
# original model with megatron
ParallelLlamaModel,
)


@ -12,11 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from packaging.version import Version
import torch
import time
from typing import Dict, Any, Callable, Optional
import torch
import torch.distributed as dist
@ -29,7 +27,7 @@ def _megatron_calc_layer_map(config):
"""
from megatron.core import mpu
print(f'get megatron data parallel size: {mpu.get_data_parallel_world_size()}')
print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}")
pp_size = mpu.get_pipeline_model_parallel_world_size()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
@ -40,8 +38,9 @@ def _megatron_calc_layer_map(config):
for pp_rank_idx in range(pp_size):
for virtual_pp_rank_idx in range(virtual_pp_size):
layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
pp_rank_idx * num_layers_per_model)
layer_offset = (
virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
)
for layer_idx in range(num_layers_per_model):
layer_map[layer_offset + layer_idx] = (
pp_rank_idx,
@ -51,20 +50,17 @@ def _megatron_calc_layer_map(config):
return layer_map
def load_state_dict_to_megatron_llama(state_dict,
wrapped_models,
config,
params_dtype,
is_value_model=False,
tie_word_embeddings=False):
"""Load merged state_dict to sharded Megatron module in training.
"""
from megatron.core import mpu
from verl.utils.megatron_utils import print_rank_0, unwrap_model
from megatron.core.transformer.module import Float16Module
def load_state_dict_to_megatron_llama(
state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False
):
"""Load merged state_dict to sharded Megatron module in training."""
from megatron.core import DistributedDataParallel as LocalDDP
from megatron.core import mpu
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP
from verl.utils.megatron_utils import print_rank_0, unwrap_model
start_time = time.time()
def _get_gpt_model(model):
@ -72,9 +68,9 @@ def load_state_dict_to_megatron_llama(state_dict,
def fetch_params(module):
for param in module.parameters():
torch.distributed.fetch(param.data,
src=mpu.get_data_parallel_src_rank(),
group=mpu.get_data_parallel_group())
torch.distributed.fetch(
param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()
)
dp_rank = mpu.get_data_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
@ -92,7 +88,9 @@ def load_state_dict_to_megatron_llama(state_dict,
assert len(wrapped_models) == virtual_pp_size
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f'num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}'
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, (
f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}"
)
models = [None] * len(wrapped_models)
@ -148,16 +146,16 @@ def load_state_dict_to_megatron_llama(state_dict,
if gate_name in state_dict and up_name in state_dict:
gate_weight = state_dict[gate_name]
up_weight = state_dict[up_name]
new_gate_up_weight = torch.empty(config.intermediate_size * 2,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
new_gate_up_weight = torch.empty(
config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()
)
for i in range(tp_size):
intermediate_size_tp = config.intermediate_size // tp_size
gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_(
torch.cat([gate_weight_tp, up_weight_tp], dim=0))
gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(
torch.cat([gate_weight_tp, up_weight_tp], dim=0)
)
tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
if tensor is not None:
@ -171,7 +169,7 @@ def load_state_dict_to_megatron_llama(state_dict,
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
assert (q_name in state_dict and k_name in state_dict and v_name in state_dict)
assert q_name in state_dict and k_name in state_dict and v_name in state_dict
full_weight_q = state_dict[q_name]
full_weight_k = state_dict[k_name]
full_weight_v = state_dict[v_name]
@ -182,31 +180,29 @@ def load_state_dict_to_megatron_llama(state_dict,
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
total_size = q_size_tp + 2 * kv_size_tp
new_weight_qkv = torch.empty(total_size * tp_size,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
new_weight_qkv = torch.empty(
total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()
)
for i in range(tp_size):
q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp]
v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp]
new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]
v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]
new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
else:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head
total_size = q_size_tp + 2 * kv_size_tp
new_weight_qkv = torch.empty(total_size * tp_size,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
new_weight_qkv = torch.empty(
total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()
)
for i in range(tp_size):
q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
k_part = full_weight_k[start_idx:end_idx]
v_part = full_weight_v[start_idx:end_idx]
new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
if tensor is not None:
@ -235,9 +231,9 @@ def load_state_dict_to_megatron_llama(state_dict,
for vpp_rank in range(vpp_size):
num_layer_vpp_chunk = num_layer_per_pp // vpp_size
num_layer_this_model = num_layer_vpp_chunk
offset = vpp_rank * (
config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + \
(mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk)
offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (
mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk
)
layer_list.extend(list(range(offset, offset + num_layer_this_model)))
else:
num_layer_this_model = num_layer_per_pp
@ -275,8 +271,11 @@ def load_state_dict_to_megatron_llama(state_dict,
f"{layer_name}.post_attention_layernorm.weight",
)
_fetch_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight")
_fetch_tp_shard_tensor_gate_up(
sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.mlp.gate_proj.weight",
f"{layer_name}.mlp.up_proj.weight",
)
_fetch_tp_shard_tensor(
sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
@ -297,15 +296,15 @@ def load_state_dict_to_megatron_llama(state_dict,
lm_head_weight = gpt_model_module.lm_head.weight
if is_value_model:
if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1:
if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1:
_fetch_tensor(lm_head_weight, "lm_head.weight")
print_rank_0('load lm_head weight')
elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1:
print_rank_0("load lm_head weight")
elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
_fetch_tensor(lm_head_weight, "reward_head.weight")
print_rank_0('load lm_head from value_head weight')
print_rank_0("load lm_head from value_head weight")
else:
_fetch_tensor(None, "lm_head.weight")
print_rank_0('fail to match lm_head in value_model')
print_rank_0("fail to match lm_head in value_model")
else:
_fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight")


@ -12,11 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from packaging.version import Version
import torch
import time
from typing import Dict, Any, Callable, Optional
import torch
import torch.distributed as dist
@ -29,7 +27,7 @@ def _megatron_calc_layer_map(config):
"""
from megatron.core import mpu
print(f'get megatron data parallel size: {mpu.get_data_parallel_world_size()}')
print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}")
pp_size = mpu.get_pipeline_model_parallel_world_size()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
@ -40,8 +38,9 @@ def _megatron_calc_layer_map(config):
for pp_rank_idx in range(pp_size):
for virtual_pp_rank_idx in range(virtual_pp_size):
layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
pp_rank_idx * num_layers_per_model)
layer_offset = (
virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
)
for layer_idx in range(num_layers_per_model):
layer_map[layer_offset + layer_idx] = (
pp_rank_idx,
@ -51,20 +50,17 @@ def _megatron_calc_layer_map(config):
return layer_map
def load_state_dict_to_megatron_llama(state_dict,
wrapped_models,
config,
params_dtype,
is_value_model=False,
tie_word_embeddings=False):
"""Load merged state_dict to sharded Megatron module in training.
"""
from megatron.core import mpu
from verl.utils.megatron_utils import print_rank_0, unwrap_model
from megatron.core.transformer.module import Float16Module
def load_state_dict_to_megatron_llama(
state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False
):
"""Load merged state_dict to sharded Megatron module in training."""
from megatron.core import DistributedDataParallel as LocalDDP
from megatron.core import mpu
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP
from verl.utils.megatron_utils import print_rank_0, unwrap_model
start_time = time.time()
def _get_gpt_model(model):
@ -72,9 +68,9 @@ def load_state_dict_to_megatron_llama(state_dict,
def broadcast_params(module):
for param in module.parameters():
torch.distributed.broadcast(param.data,
src=mpu.get_data_parallel_src_rank(),
group=mpu.get_data_parallel_group())
torch.distributed.broadcast(
param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()
)
dp_rank = mpu.get_data_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
@ -92,7 +88,9 @@ def load_state_dict_to_megatron_llama(state_dict,
assert len(wrapped_models) == virtual_pp_size
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f'num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}'
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, (
f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}"
)
models = [None] * len(wrapped_models)
@ -171,8 +169,9 @@ def load_state_dict_to_megatron_llama(state_dict,
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
assert tensor.shape == chunk_shape, (
f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
)
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
@ -217,8 +216,9 @@ def load_state_dict_to_megatron_llama(state_dict,
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
assert tensor.shape == chunk_shape, (
f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
)
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
@ -238,16 +238,16 @@ def load_state_dict_to_megatron_llama(state_dict,
if torch.distributed.get_rank() == 0:
gate_weight = state_dict[gate_name]
up_weight = state_dict[up_name]
new_gate_up_weight = torch.empty(config.intermediate_size * 2,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
new_gate_up_weight = torch.empty(
config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()
)
for i in range(tp_size):
intermediate_size_tp = config.intermediate_size // tp_size
gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_(
torch.cat([gate_weight_tp, up_weight_tp], dim=0))
gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(
torch.cat([gate_weight_tp, up_weight_tp], dim=0)
)
tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
chunk_shape = tensor_chunk[0].shape
@ -270,9 +270,9 @@ def load_state_dict_to_megatron_llama(state_dict,
requires_grad=False,
)
else:
assert (
tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
assert tensor.shape == chunk_shape, (
f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
)
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
@ -290,7 +290,7 @@ def load_state_dict_to_megatron_llama(state_dict,
tp_size = mpu.get_tensor_model_parallel_world_size()
if torch.distributed.get_rank() == 0:
assert (q_name in state_dict and k_name in state_dict and v_name in state_dict)
assert q_name in state_dict and k_name in state_dict and v_name in state_dict
full_weight_q = state_dict[q_name]
full_weight_k = state_dict[k_name]
full_weight_v = state_dict[v_name]
@ -301,33 +301,33 @@ def load_state_dict_to_megatron_llama(state_dict,
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
total_size = q_size_tp + 2 * kv_size_tp
new_weight_qkv = torch.empty(total_size * tp_size,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
new_weight_qkv = torch.empty(
total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()
)
for i in range(tp_size):
q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp]
v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp]
new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part],
dim=0))
q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]
v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]
new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(
torch.cat([q_part, k_part, v_part], dim=0)
)
else:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head
total_size = q_size_tp + 2 * kv_size_tp
new_weight_qkv = torch.empty(total_size * tp_size,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
new_weight_qkv = torch.empty(
total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()
)
for i in range(tp_size):
q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
k_part = full_weight_k[start_idx:end_idx]
v_part = full_weight_v[start_idx:end_idx]
new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part],
dim=0))
new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(
torch.cat([q_part, k_part, v_part], dim=0)
)
tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
chunk_shape = tensor_chunk[0].shape
@ -350,8 +350,9 @@ def load_state_dict_to_megatron_llama(state_dict,
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
assert tensor.shape == chunk_shape, (
f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
)
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
@ -406,8 +407,11 @@ def load_state_dict_to_megatron_llama(state_dict,
f"{layer_name}.post_attention_layernorm.weight",
)
_broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight")
_broadcast_tp_shard_tensor_gate_up(
sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.mlp.gate_proj.weight",
f"{layer_name}.mlp.up_proj.weight",
)
_broadcast_tp_shard_tensor(
sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
@ -429,15 +433,15 @@ def load_state_dict_to_megatron_llama(state_dict,
lm_head_weight = gpt_model_module.lm_head.weight
if is_value_model:
if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1:
if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1:
_broadcast_tensor(lm_head_weight, "lm_head.weight")
print_rank_0('load lm_head weight')
elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1:
print_rank_0("load lm_head weight")
elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
_broadcast_tensor(lm_head_weight, "reward_head.weight")
print_rank_0('load lm_head from value_head weight')
print_rank_0("load lm_head from value_head weight")
else:
_broadcast_tensor(None, "lm_head.weight")
print_rank_0('fail to match lm_head in value_model')
print_rank_0("fail to match lm_head in value_model")
else:
_broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight")
dist.barrier()
@ -446,4 +450,4 @@ def load_state_dict_to_megatron_llama(state_dict,
broadcast_params(wrapped_model)
torch.cuda.empty_cache()
print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")


@ -30,8 +30,9 @@ def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int
tp_size = mpu.get_tensor_model_parallel_world_size()
dp_size = mpu.get_data_parallel_world_size()
pp_size = mpu.get_pipeline_model_parallel_world_size()
assert (tp_size * dp_size * pp_size == torch.distributed.get_world_size()
), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}"
assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), (
f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}"
)
# We only support TP-DP-PP grouping, for correctness when resharding
return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank
@ -54,8 +55,9 @@ def _megatron_calc_layer_map(config):
for pp_rank_idx in range(pp_size):
for virtual_pp_rank_idx in range(virtual_pp_size):
layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
pp_rank_idx * num_layers_per_model)
layer_offset = (
virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
)
for layer_idx in range(num_layers_per_model):
layer_map[layer_offset + layer_idx] = (
pp_rank_idx,
@ -107,9 +109,11 @@ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=Fals
for i, wrapped_model in enumerate(wrapped_models):
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
assert len(models[i].model.layers
) == num_layers_per_model, 'len model layers {} not equal to num_layers_per_model {}'.format(
len(models[i].model.layers), num_layers_per_model)
assert len(models[i].model.layers) == num_layers_per_model, (
"len model layers {} not equal to num_layers_per_model {}".format(
len(models[i].model.layers), num_layers_per_model
)
)
state_dict = dict()
@ -247,7 +251,7 @@ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=Fals
gate_weight_list = []
up_weight_list = []
for i in range(tp_size):
gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)]
gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]
gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
gate_weight_list.append(gate_weight_tp)
@ -306,10 +310,10 @@ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=Fals
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
total_size = q_size_tp + 2 * kv_size_tp
for i in range(tp_size):
qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
q_part = qkv_part[:q_size_tp]
k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp]
v_part = qkv_part[q_size_tp + kv_size_tp:total_size]
k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]
v_part = qkv_part[q_size_tp + kv_size_tp : total_size]
q_weight_list.append(q_part)
k_weight_list.append(k_part)
v_weight_list.append(v_part)
@ -318,10 +322,10 @@ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=Fals
kv_size_tp = hidden_size_per_head
total_size = q_size_tp + 2 * kv_size_tp
for i in range(tp_size):
qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
q_part = qkv_part[:q_size_tp]
k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp]
v_part = qkv_part[q_size_tp + kv_size_tp:total_size]
k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]
v_part = qkv_part[q_size_tp + kv_size_tp : total_size]
q_weight_list.append(q_part)
if i * config.num_key_value_heads % tp_size == 0:
k_weight_list.append(k_part)
@ -384,10 +388,12 @@ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=Fals
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight,
f"{layer_name}.mlp.gate_proj.weight",
f"{layer_name}.mlp.up_proj.weight",
src_pp_rank=src_pp_rank)
_broadcast_tp_shard_tensor_gate_up(
sync_layer.mlp.gate_up_proj.weight,
f"{layer_name}.mlp.gate_proj.weight",
f"{layer_name}.mlp.up_proj.weight",
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor(
sync_layer.mlp.down_proj.weight,
@ -410,14 +416,19 @@ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=Fals
if is_value_model:
if pp_rank == pp_size - 1:
print(f'gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}')
_broadcast_tensor(gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None,
"lm_head.weight",
src_pp_rank=pp_size - 1)
_broadcast_tensor(gpt_model_module.reward_head.weight if pp_rank == pp_size - 1 and
getattr(gpt_model_module, "reward_weight", None) is not None else None,
"reward_head.weight",
src_pp_rank=pp_size - 1)
print(f"gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}")
_broadcast_tensor(
gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None,
"lm_head.weight",
src_pp_rank=pp_size - 1,
)
_broadcast_tensor(
gpt_model_module.reward_head.weight
if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None
else None,
"reward_head.weight",
src_pp_rank=pp_size - 1,
)
else:
_broadcast_tp_shard_tensor(


@ -22,31 +22,29 @@ import math
from typing import Optional, Tuple
import torch
from megatron.core import ModelParallelConfig, tensor_parallel
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers import LlamaConfig
from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear
from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear
from verl.utils.megatron import tensor_parallel as tp_utils
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.get_default_dtype())
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
@ -99,9 +97,10 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -
(self.scaling_factor - 1))**(self.dim / (self.dim - 2))
inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
@ -114,7 +113,6 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding):
def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device=None):
super().__init__(dim, max_position_embeddings, base, device)
@ -122,7 +120,8 @@ class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding):
self.high_freq_factor = config.rope_scaling["high_freq_factor"] # `1` in the original implementation
self.low_freq_factor = config.rope_scaling["low_freq_factor"] # `4` in the original implementation
self.old_context_len = config.rope_scaling[
"original_max_position_embeddings"] # `8192` in the original implementation
"original_max_position_embeddings"
] # `8192` in the original implementation
low_freq_wavelen = self.old_context_len / self.low_freq_factor
high_freq_wavelen = self.old_context_len / self.high_freq_factor
@ -131,8 +130,9 @@ class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding):
# wavelen < high_freq_wavelen: do nothing; wavelen > low_freq_wavelen: divide by factor
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq)
# otherwise: interpolate between the two, using a smooth factor
smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / (self.high_freq_factor -
self.low_freq_factor)
smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / (
self.high_freq_factor - self.low_freq_factor
)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
@ -140,15 +140,15 @@ class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding):
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.get_default_dtype())
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
@ -189,47 +189,56 @@ class ParallelLlamaAttention(nn.Module):
# assign values after tp
tp_size = mpu.get_tensor_model_parallel_world_size()
assert self.num_heads % tp_size == 0, f'num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}'
assert self.num_key_value_heads % tp_size == 0, \
f'num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}'
assert self.num_heads % tp_size == 0, (
f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}"
)
assert self.num_key_value_heads % tp_size == 0, (
f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}"
)
self.num_heads_per_tp = self.num_heads // tp_size
self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size
self.hidden_size_per_tp = self.hidden_size // tp_size
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads}).")
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()
if megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
assert column_kwargs.get("config", False), "must have ModelParallelConfig"
assert row_kwargs.get("config", False), "must have ModelParallelConfig"
tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
# [self.q_size, self.k_size, self.v_size]
self.qkv_proj = QKVParallelLinear(input_size=self.hidden_size,
num_heads=self.num_heads,
num_key_value_heads=self.num_key_value_heads,
head_dim=self.head_dim,
bias=config.attention_bias,
gather_output=False,
skip_bias_add=False,
**column_kwargs)
self.qkv_proj = QKVParallelLinear(
input_size=self.hidden_size,
num_heads=self.num_heads,
num_key_value_heads=self.num_key_value_heads,
head_dim=self.head_dim,
bias=config.attention_bias,
gather_output=False,
skip_bias_add=False,
**column_kwargs,
)
self.q_size = self.num_heads_per_tp * self.head_dim
self.k_size = self.num_key_value_heads_per_tp * self.head_dim
self.v_size = self.num_key_value_heads_per_tp * self.head_dim
self.o_proj = tensor_parallel.RowParallelLinear(input_size=self.num_heads * self.head_dim,
output_size=self.hidden_size,
bias=config.attention_bias,
input_is_parallel=True,
skip_bias_add=False,
**row_kwargs)
self.o_proj = tensor_parallel.RowParallelLinear(
input_size=self.num_heads * self.head_dim,
output_size=self.hidden_size,
bias=config.attention_bias,
input_is_parallel=True,
skip_bias_add=False,
**row_kwargs,
)
self._init_rope()
@ -297,12 +306,14 @@ class ParallelLlamaAttention(nn.Module):
if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}")
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
@ -312,7 +323,8 @@ class ParallelLlamaAttention(nn.Module):
if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)
@ -326,10 +338,9 @@ Remove padding Attention
- Compatible with sequence parallel
"""
from transformers.utils import is_flash_attn_2_available
import torch.nn.functional as F
from einops import rearrange
from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
@ -358,40 +369,34 @@ from flash_attn.layers.rotary import apply_rotary_emb
# use flash-attn rotary embeddings with rmpad
# cos/sin should be: (seq_length, rotary_dim / 2)
def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):
q_embed = apply_rotary_emb(q,
cos,
sin,
interleaved=False,
inplace=False,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen)
k_embed = apply_rotary_emb(k,
cos,
sin,
interleaved=False,
inplace=False,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen)
q_embed = apply_rotary_emb(
q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
)
k_embed = apply_rotary_emb(
k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
)
return q_embed, k_embed
class ParallelLlamaAttentionRmPad(ParallelLlamaAttention):
def forward(self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: torch.Tensor = None,
max_seqlen_in_batch: int = None):
def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: torch.Tensor = None,
max_seqlen_in_batch: int = None,
):
total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel
if self.megatron_config.sequence_parallel:
total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()
qkv = self.qkv_proj(hidden_states)[0]
query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size],
dim=-1) # (total_nnz, 1, hidden_size)
query_states, key_states, value_states = qkv.split(
[self.q_size, self.k_size, self.v_size], dim=-1
) # (total_nnz, 1, hidden_size)
if self.megatron_config.sequence_parallel:
sequence_parallel_pad = total_nnz - cu_seqlens[-1]
@ -408,13 +413,10 @@ class ParallelLlamaAttentionRmPad(ParallelLlamaAttention):
value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)
cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)
cos, sin = cos[:, :cos.shape[1] // 2], sin[:, :sin.shape[1] // 2] # flash attn only needs half
query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states,
key_states,
cos,
sin,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen_in_batch)
cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half
query_states, key_states = apply_rotary_pos_emb_rmpad_flash(
query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch
)
# query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices,
# TODO: llama does not have dropout in the config??


@ -21,19 +21,18 @@
from typing import Optional, Tuple
import torch
from megatron.core import ModelParallelConfig
from torch import nn
from transformers import LlamaConfig
from megatron.core import ModelParallelConfig
from verl.utils.megatron_utils import TransformerConfig, convert_config
from .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad
from .parallel_mlp import ParallelLlamaMLP
from .parallel_rmsnorm import ParallelLlamaRMSNorm
from verl.utils.megatron_utils import TransformerConfig, convert_config
class ParallelLlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int):
super().__init__()
self.config: TransformerConfig = convert_config(config, megatron_config)
@ -101,7 +100,6 @@ class ParallelLlamaDecoderLayer(nn.Module):
class ParallelLlamaDecoderLayerRmPad(nn.Module):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int):
super().__init__()
self.config: TransformerConfig = convert_config(config, megatron_config)
@ -120,7 +118,7 @@ class ParallelLlamaDecoderLayerRmPad(nn.Module):
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None
max_seqlen_in_batch: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states # (total_nnz // sp, 1, hidden_size)
@ -129,12 +127,14 @@ class ParallelLlamaDecoderLayerRmPad(nn.Module):
# Self Attention
# (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)
# -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)
hidden_states = self.self_attn(hidden_states=hidden_states,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
hidden_states = self.self_attn(
hidden_states=hidden_states,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch,
)
hidden_states = residual + hidden_states


@ -13,23 +13,23 @@
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py
from typing import Optional, Tuple
from megatron.core import tensor_parallel
class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
def __init__(self,
input_size,
num_heads,
num_key_value_heads,
head_dim,
*,
bias=True,
gather_output=True,
skip_bias_add=False,
**kwargs):
def __init__(
self,
input_size,
num_heads,
num_key_value_heads,
head_dim,
*,
bias=True,
gather_output=True,
skip_bias_add=False,
**kwargs,
):
# Keep input parameters, and already restrict the head numbers
self.input_size = input_size
self.q_output_size = num_heads * head_dim
@ -41,44 +41,48 @@ class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
input_size = self.input_size
output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim
super().__init__(input_size=input_size,
output_size=output_size,
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
**kwargs)
super().__init__(
input_size=input_size,
output_size=output_size,
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
**kwargs,
)
class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):
def __init__(self,
input_size,
gate_ouput_size,
up_output_size,
*,
bias=True,
gather_output=True,
skip_bias_add=False,
**kwargs):
def __init__(
self,
input_size,
gate_ouput_size,
up_output_size,
*,
bias=True,
gather_output=True,
skip_bias_add=False,
**kwargs,
):
# Keep input parameters, and already restrict the head numbers
self.input_size = input_size
self.output_size = gate_ouput_size + up_output_size
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
super().__init__(input_size=self.input_size,
output_size=self.output_size,
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
**kwargs)
super().__init__(
input_size=self.input_size,
output_size=self.output_size,
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
**kwargs,
)
import torch
class LinearForLastLayer(torch.nn.Linear):
def __init__(
self,
input_size,
@ -90,7 +94,7 @@ class LinearForLastLayer(torch.nn.Linear):
super().__init__(in_features=input_size, out_features=output_size, bias=bias)
self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel:
setattr(self.weight, 'sequence_parallel', True)
self.weight.sequence_parallel = True
def forward(
self,


@ -18,18 +18,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from megatron.core import ModelParallelConfig, tensor_parallel
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers.activations import ACT2FN
from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear
from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear
from verl.utils.megatron import tensor_parallel as tp_utils
class ParallelLlamaMLP(nn.Module):
def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:
super().__init__()
self.config = config
@ -41,8 +39,8 @@ class ParallelLlamaMLP(nn.Module):
row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()
if megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
assert column_kwargs.get("config", False), "must have ModelParallelConfig"
assert row_kwargs.get("config", False), "must have ModelParallelConfig"
tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
@ -59,12 +57,14 @@ class ParallelLlamaMLP(nn.Module):
)
self.gate_size = self.intermediate_size // tp_size
self.down_proj = tensor_parallel.RowParallelLinear(input_size=self.intermediate_size,
output_size=self.hidden_size,
bias=False,
input_is_parallel=True,
skip_bias_add=False,
**row_kwargs)
self.down_proj = tensor_parallel.RowParallelLinear(
input_size=self.intermediate_size,
output_size=self.hidden_size,
bias=False,
input_is_parallel=True,
skip_bias_add=False,
**row_kwargs,
)
self.act_fn = ACT2FN[config.hidden_act]
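The MLP builds a merged gate/up column-parallel projection followed by a row-parallel down projection; its forward is not shown in this hunk. Below is a non-parallel sketch of the SwiGLU-style computation such gate/up/down layers typically implement, with plain `nn.Linear` standing in for the Megatron parallel linears:

```python
import torch
from torch import nn
import torch.nn.functional as F

class ToySwiGLUMLP(nn.Module):
    """Single-GPU sketch (assumed): gate/up fused into one projection, split, then down-projected."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.intermediate_size = intermediate_size
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).split(self.intermediate_size, dim=-1)
        return self.down_proj(F.silu(gate) * up)
```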


@ -13,17 +13,17 @@
# limitations under the License.
import numbers
import torch
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
from megatron.core import ModelParallelConfig
from torch import nn
from transformers import LlamaConfig
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
from verl.utils.megatron import sequence_parallel as sp_utils
class ParallelLlamaRMSNorm(nn.Module):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
@ -39,8 +39,10 @@ class ParallelLlamaRMSNorm(nn.Module):
sp_utils.mark_parameter_as_sequence_parallel(self.weight)
def forward(self, hidden_states):
return fused_rms_norm_affine(input=hidden_states,
weight=self.weight,
normalized_shape=self.normalized_shape,
eps=self.variance_epsilon,
memory_efficient=True)
return fused_rms_norm_affine(
input=hidden_states,
weight=self.weight,
normalized_shape=self.normalized_shape,
eps=self.variance_epsilon,
memory_efficient=True,
)
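For reference, a plain-PyTorch RMSNorm that is assumed equivalent to the `fused_rms_norm_affine` call above (standard T5/Llama formulation; the fused Apex kernel mainly saves memory traffic):

```python
import torch
from torch import nn

class ReferenceRMSNorm(nn.Module):
    """Unfused sketch of RMSNorm: scale by the reciprocal RMS, then apply the learned weight."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```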


@ -23,10 +23,7 @@ from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from megatron.core import mpu
from megatron.core import ModelParallelConfig, mpu, tensor_parallel
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
@ -35,7 +32,9 @@ from transformers.models.llama.modeling_llama import CausalLMOutputWithPast
from verl.utils.megatron import sequence_parallel as sp_utils
from verl.utils.megatron import tensor_parallel as tp_utils
from verl.utils.megatron_utils import TransformerConfig, convert_config
from .layers import ParallelLlamaDecoderLayer, ParallelLlamaRMSNorm, ParallelLlamaDecoderLayerRmPad
from .layers import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad, ParallelLlamaRMSNorm
"""
TODO:
1. Add weight initialization. Here we need to be careful on TP weight init.
@ -87,14 +86,15 @@ class ParallelLlamaModel(nn.Module):
self.vocab_size = config.vocab_size
embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
if megatron_config is not None:
assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
assert embedding_kwargs.get("config", False), "must have ModelParallelConfig"
tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
**embedding_kwargs)
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(
num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs
)
self.layers = nn.ModuleList(
[ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)])
[ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)]
)
self.norm = ParallelLlamaRMSNorm(config, megatron_config)
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
@ -111,10 +111,12 @@ class ParallelLlamaModel(nn.Module):
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype,
tgt_len=input_shape[-1]).to(inputs_embeds.device)
combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
combined_attention_mask)
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
@ -157,7 +159,6 @@ class ParallelLlamaModel(nn.Module):
class ParallelLlamaForCausalLM(nn.Module):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
super().__init__()
self.config: TransformerConfig = convert_config(config, megatron_config)
@ -166,15 +167,17 @@ class ParallelLlamaForCausalLM(nn.Module):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
assert column_kwargs.get("config", False), "must have ModelParallelConfig"
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size,
output_size=config.vocab_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs)
self.lm_head = tensor_parallel.ColumnParallelLinear(
input_size=config.hidden_size,
output_size=config.vocab_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs,
)
def forward(
self,
@ -233,23 +236,26 @@ class ParallelLlamaModelRmPad(nn.Module):
embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
self.megatron_config = megatron_config
if megatron_config is not None:
assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
assert embedding_kwargs.get("config", False), "must have ModelParallelConfig"
tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
**embedding_kwargs)
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(
num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs
)
self.layers = nn.ModuleList(
[ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)])
[ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)]
)
self.norm = ParallelLlamaRMSNorm(config, megatron_config)
def forward(self,
input_ids: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]:
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Args:
@ -268,12 +274,14 @@ class ParallelLlamaModelRmPad(nn.Module):
hidden_states = inputs_embeds
for idx, decoder_layer in enumerate(self.layers):
layer_outputs = decoder_layer(hidden_states,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
layer_outputs = decoder_layer(
hidden_states,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch,
)
hidden_states = layer_outputs
@ -283,7 +291,6 @@ class ParallelLlamaModelRmPad(nn.Module):
class ParallelLlamaForCausalLMRmPad(nn.Module):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
super().__init__()
self.config: TransformerConfig = convert_config(config, megatron_config)
@ -295,14 +302,16 @@ class ParallelLlamaForCausalLMRmPad(nn.Module):
def _init_head(self, config):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
assert column_kwargs.get("config", False), "must have ModelParallelConfig"
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size,
output_size=config.vocab_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs)
self.lm_head = tensor_parallel.ColumnParallelLinear(
input_size=config.hidden_size,
output_size=config.vocab_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs,
)
def _forward_head(self, hidden_states):
# all_gather from sequence parallel region is performed inside lm_head
@ -329,8 +338,9 @@ class ParallelLlamaForCausalLMRmPad(nn.Module):
batch_size, sequence_length = input_ids.shape
# remove padding here
input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
attention_mask) # (total_nnz, 1)
input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(
input_ids.unsqueeze(dim=-1), attention_mask
) # (total_nnz, 1)
# pad input_ids to multiple of tp for all tp ranks
# TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
@ -339,12 +349,14 @@ class ParallelLlamaForCausalLMRmPad(nn.Module):
input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad)
outputs = self.model(input_ids=input_ids,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
outputs = self.model(
input_ids=input_ids,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch,
)
hidden_states = outputs
@ -357,8 +369,9 @@ class ParallelLlamaForCausalLMRmPad(nn.Module):
logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension
# add removed padding back
logits = pad_input(logits, indices, batch_size,
seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)
logits = pad_input(
logits, indices, batch_size, seqlen=sequence_length
) # (batch_size, sequence_length, vocab_size)
return CausalLMOutputWithPast(
loss=None,
@ -370,11 +383,10 @@ class ParallelLlamaForCausalLMRmPad(nn.Module):
class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):
def _init_head(self, config):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
assert column_kwargs.get("config", False), "must have ModelParallelConfig"
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)
# lm_head is effectively the same as sequence parallel
@ -423,12 +435,12 @@ class ParallelLlamaModelRmPadPP(nn.Module):
self.megatron_config = megatron_config
embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
if megatron_config is not None:
assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
assert embedding_kwargs.get("config", False), "must have ModelParallelConfig"
tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
if pre_process:
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
**embedding_kwargs)
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(
num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs
)
else:
self.embed_tokens = None
@ -442,9 +454,7 @@ class ParallelLlamaModelRmPadPP(nn.Module):
self.layers = nn.ModuleList()
self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size
self.num_layer_this_model = self.num_layer_vpp_chunk
offset = vpp_rank * (
config.num_hidden_layers // vpp_size) + \
(pp_rank * self.num_layer_vpp_chunk)
offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk)
else:
self.num_layer_this_model = self.num_layer_per_pp
offset = pp_rank * self.num_layer_per_pp
@ -452,7 +462,7 @@ class ParallelLlamaModelRmPadPP(nn.Module):
self.layers = nn.ModuleList()
for i in range(self.num_layer_this_model):
layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=offset + i)
self.layers.add_module(f'{i}', layer)
self.layers.add_module(f"{i}", layer)
if post_process:
self.norm = ParallelLlamaRMSNorm(config, megatron_config)
@ -469,13 +479,15 @@ class ParallelLlamaModelRmPadPP(nn.Module):
forward_step_func"""
self.input_tensor = input_tensor
def forward(self,
input_ids: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]:
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Args:
@ -501,12 +513,14 @@ class ParallelLlamaModelRmPadPP(nn.Module):
hidden_states = self.input_tensor
for idx, decoder_layer in enumerate(self.layers):
layer_outputs = decoder_layer(hidden_states,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
layer_outputs = decoder_layer(
hidden_states,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch,
)
hidden_states = layer_outputs
@ -517,21 +531,23 @@ class ParallelLlamaModelRmPadPP(nn.Module):
class ParallelLlamaForCausalLMRmPadPP(nn.Module):
def __init__(self,
config: LlamaConfig,
megatron_config: ModelParallelConfig,
pre_process,
post_process,
share_embeddings_and_output_weights=False):
def __init__(
self,
config: LlamaConfig,
megatron_config: ModelParallelConfig,
pre_process,
post_process,
share_embeddings_and_output_weights=False,
):
super().__init__()
self.config: TransformerConfig = convert_config(config, megatron_config)
self.megatron_config = megatron_config
self.model = ParallelLlamaModelRmPadPP(config,
megatron_config=megatron_config,
pre_process=pre_process,
post_process=post_process)
assert share_embeddings_and_output_weights == False, f'Llama Model not supports sharing embedding and output weights'
self.model = ParallelLlamaModelRmPadPP(
config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process
)
assert share_embeddings_and_output_weights == False, (
"Llama Model not supports sharing embedding and output weights"
)
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.vocab_size = config.vocab_size
self.pre_process = pre_process
@ -553,14 +569,16 @@ class ParallelLlamaForCausalLMRmPadPP(nn.Module):
def _init_head(self, config):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
assert column_kwargs.get("config", False), "must have ModelParallelConfig"
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size,
output_size=config.vocab_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs)
self.lm_head = tensor_parallel.ColumnParallelLinear(
input_size=config.hidden_size,
output_size=config.vocab_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs,
)
def _forward_head(self, hidden_states):
# all_gather from sequence parallel region is performed inside lm_head
@ -592,8 +610,9 @@ class ParallelLlamaForCausalLMRmPadPP(nn.Module):
# In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model
batch_size, sequence_length = input_ids.shape
# remove padding here
input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
attention_mask) # (total_nnz, 1)
input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(
input_ids.unsqueeze(dim=-1), attention_mask
) # (total_nnz, 1)
# pad input_ids to multiple of tp for all tp ranks
# TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
@ -602,12 +621,14 @@ class ParallelLlamaForCausalLMRmPadPP(nn.Module):
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad)
outputs = self.model(input_ids=input_ids_rmpad,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
outputs = self.model(
input_ids=input_ids_rmpad,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch,
)
if self.post_process:
hidden_states = outputs
@ -620,8 +641,9 @@ class ParallelLlamaForCausalLMRmPadPP(nn.Module):
totol_nnz = cu_seqlens[-1]
logits = logits[:totol_nnz] # (total_nnz_padded)
# add removed padding back. If input is already rmpad, we let the caller pad_input
logits = pad_input(logits, indices, batch_size,
seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)
logits = pad_input(
logits, indices, batch_size, seqlen=sequence_length
) # (batch_size, sequence_length, vocab_size)
return CausalLMOutputWithPast(
loss=None,
@ -635,11 +657,10 @@ class ParallelLlamaForCausalLMRmPadPP(nn.Module):
class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):
def _init_head(self, config):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
assert column_kwargs.get("config", False), "must have ModelParallelConfig"
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)
# lm_head is effectively the same as sequence parallel
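With virtual pipeline parallelism enabled, `ParallelLlamaModelRmPadPP` above places each pipeline rank's layers at `offset = vpp_rank * (num_hidden_layers // vpp_size) + pp_rank * num_layer_vpp_chunk`, so every rank owns several interleaved chunks of layers. A worked example with hypothetical sizes (not taken from this diff):

```python
# Hypothetical sizes: 32 layers, pipeline parallel size 4, virtual pipeline size 2.
num_hidden_layers, pp_size, vpp_size = 32, 4, 2
num_layer_per_pp = num_hidden_layers // pp_size      # 8 layers per pipeline rank in total
num_layer_vpp_chunk = num_layer_per_pp // vpp_size   # 4 layers per virtual chunk

for pp_rank in range(pp_size):
    chunks = []
    for vpp_rank in range(vpp_size):
        offset = vpp_rank * (num_hidden_layers // vpp_size) + pp_rank * num_layer_vpp_chunk
        chunks.append(list(range(offset, offset + num_layer_vpp_chunk)))
    print(f"pp_rank {pp_rank}: {chunks}")
# pp_rank 0: [[0, 1, 2, 3], [16, 17, 18, 19]]
# pp_rank 1: [[4, 5, 6, 7], [20, 21, 22, 23]]
# ...
```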


@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import time
import torch
import torch.distributed as dist
from .saver import _megatron_calc_global_rank
@ -26,7 +28,6 @@ def _megatron_calc_layer_map(config):
mapping from the global layer index to
a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
"""
import megatron
from megatron.core import mpu
pp_size = mpu.get_pipeline_model_parallel_world_size()
@ -38,8 +39,9 @@ def _megatron_calc_layer_map(config):
for pp_rank_idx in range(pp_size):
for virtual_pp_rank_idx in range(virtual_pp_size):
layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
pp_rank_idx * num_layers_per_model)
layer_offset = (
virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
)
for layer_idx in range(num_layers_per_model):
layer_map[layer_offset + layer_idx] = (
pp_rank_idx,
@ -50,15 +52,14 @@ def _megatron_calc_layer_map(config):
def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False):
"""Load merged state_dict to sharded Megatron module in training.
"""
import megatron
from megatron.core import mpu
from verl.utils.megatron_utils import print_rank_0, unwrap_model
from megatron.core.transformer.module import Float16Module
"""Load merged state_dict to sharded Megatron module in training."""
from megatron.core import DistributedDataParallel as LocalDDP
from megatron.core import mpu
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP
from verl.utils.megatron_utils import print_rank_0, unwrap_model
start_time = time.time()
def _get_gpt_model(model):
@ -66,9 +67,9 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
def broadcast_params(module):
for param in module.parameters():
torch.distributed.broadcast(param.data,
src=mpu.get_data_parallel_src_rank(),
group=mpu.get_data_parallel_group())
torch.distributed.broadcast(
param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()
)
dp_rank = mpu.get_data_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
@ -167,8 +168,9 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
assert tensor.shape == chunk_shape, (
f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
)
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
@ -213,8 +215,9 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
assert tensor.shape == chunk_shape, (
f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
)
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
@ -234,16 +237,16 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
if torch.distributed.get_rank() == src_rank:
gate_weight = state_dict[gate_name]
up_weight = state_dict[up_name]
new_gate_up_weight = torch.empty(config.intermediate_size * 2,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
new_gate_up_weight = torch.empty(
config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()
)
for i in range(tp_size):
intermediate_size_tp = config.intermediate_size // tp_size
gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_(
torch.cat([gate_weight_tp, up_weight_tp], dim=0))
gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(
torch.cat([gate_weight_tp, up_weight_tp], dim=0)
)
tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
chunk_shape = tensor_chunk[0].shape
@ -266,9 +269,9 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
requires_grad=False,
)
else:
assert (
tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
assert tensor.shape == chunk_shape, (
f"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
)
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
@ -286,7 +289,7 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
tp_size = mpu.get_tensor_model_parallel_world_size()
if torch.distributed.get_rank() == src_rank:
assert (q_name in state_dict and k_name in state_dict and v_name in state_dict)
assert q_name in state_dict and k_name in state_dict and v_name in state_dict
full_weight_q = state_dict[q_name]
full_weight_k = state_dict[k_name]
full_weight_v = state_dict[v_name]
@ -302,18 +305,19 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
sizes.append(config.hidden_size)
new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device())
for i in range(tp_size):
q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp]
v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp]
q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]
v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]
num_query_groups_per_partition = models[0].config.num_query_groups // tp_size
new_weight_qkv_this_tp = new_weight_qkv[i * total_size:(i + 1) * total_size]
new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size]
q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0)
k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0)
v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0)
total_size_per_head = total_size // num_query_groups_per_partition
for j in range(num_query_groups_per_partition):
new_weight_qkv_this_tp[j * total_size_per_head:(j + 1) * total_size_per_head].copy_(
torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0))
new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(
torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0)
)
else:
q_size_tp = config.hidden_size // tp_size
@ -324,19 +328,20 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
sizes.append(config.hidden_size)
new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device())
for i in range(tp_size):
q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
k_part = full_weight_k[start_idx:end_idx]
v_part = full_weight_v[start_idx:end_idx]
new_weight_qkv_this_tp = new_weight_qkv[i * total_size:(i + 1) * total_size]
new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size]
q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0)
k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0)
v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0)
total_size_per_head = total_size // config.num_attention_heads
for j in range(config.num_attention_heads):
new_weight_qkv_this_tp[j * total_size_per_head:(j + 1) * total_size_per_head].copy_(
torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0))
new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(
torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0)
)
tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
chunk_shape = tensor_chunk[0].shape
@ -359,8 +364,9 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
assert tensor.shape == chunk_shape, (
f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
)
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
@ -409,7 +415,8 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
f"{layer_name}.self_attn.q_proj.bias",
f"{layer_name}.self_attn.k_proj.bias",
f"{layer_name}.self_attn.v_proj.bias",
bias=True)
bias=True,
)
_broadcast_tp_shard_tensor(
sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None,
@ -421,8 +428,11 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
f"{layer_name}.post_attention_layernorm.weight",
)
_broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight")
_broadcast_tp_shard_tensor_gate_up(
sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.mlp.gate_proj.weight",
f"{layer_name}.mlp.up_proj.weight",
)
_broadcast_tp_shard_tensor(
sync_layer.mlp.linear_fc2.weight if dst_pp_rank == pp_rank else None,
@ -445,14 +455,14 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
if is_value_model:
# if torch.distributed.get_rank() == src_rank:
if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1:
if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1:
_broadcast_tensor(lm_head_weight, "lm_head.weight")
elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1:
elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
_broadcast_tensor(lm_head_weight, "reward_head.weight")
print_rank_0('load lm_head from value_head weight')
print_rank_0("load lm_head from value_head weight")
else:
_broadcast_tensor(None, "lm_head.weight")
print_rank_0('fail to match lm_head in value_model')
print_rank_0("fail to match lm_head in value_model")
# else:
# _broadcast_tensor(lm_head_weight, "lm_head.weight")
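The loader above interleaves each tensor-parallel rank's slice of `gate_proj` and `up_proj` into one fused `gate_up` weight before chunking it across ranks, so each shard ends up holding its gate rows followed by its up rows. A toy sketch of that interleaving (hypothetical small sizes, CPU tensors):

```python
import torch

intermediate_size, hidden_size, tp_size = 8, 4, 2      # hypothetical toy sizes
itp = intermediate_size // tp_size
gate = torch.arange(intermediate_size * hidden_size, dtype=torch.float32).reshape(intermediate_size, hidden_size)
up = gate + 1000.0

fused = torch.empty(2 * intermediate_size, hidden_size)
for i in range(tp_size):
    # Rows [2*itp*i, 2*itp*(i+1)) hold rank i's gate slice followed by its up slice.
    fused[2 * itp * i : 2 * itp * (i + 1)].copy_(
        torch.cat([gate[i * itp : (i + 1) * itp], up[i * itp : (i + 1) * itp]], dim=0)
    )

shards = torch.chunk(fused, tp_size, dim=0)
assert torch.equal(shards[0][:itp], gate[:itp]) and torch.equal(shards[0][itp:], up[:itp])
```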


@ -13,23 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from verl.utils.megatron_utils import print_rank_0, unwrap_model
from megatron.core import mpu
from megatron.core.transformer.module import Float16Module
from megatron.core.distributed import DistributedDataParallel as LocalDDP
from torch.nn.parallel import DistributedDataParallel as torchDDP
import torch
import time
import torch
import torch.distributed as dist
from megatron.core import mpu
from megatron.core.distributed import DistributedDataParallel as LocalDDP
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP
from verl.utils.megatron_utils import print_rank_0, unwrap_model
def _megatron_calc_global_rank(tp_rank: int = 0,
dp_rank: int = 0,
pp_rank: int = 0,
cp_rank: int = 0,
ep_rank: int = 0):
def _megatron_calc_global_rank(
tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0
):
"""Calculate global rank with support for CP/EP parallelism"""
# Get parallel sizes for each dimension
@ -41,8 +39,9 @@ def _megatron_calc_global_rank(tp_rank: int = 0,
# Verify total GPU count matches (must be consistent with parallel_state.py)
total_size = tp_size * dp_size * pp_size * cp_size
assert total_size == torch.distributed.get_world_size(), \
assert total_size == torch.distributed.get_world_size(), (
f"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}"
)
# Core calculation logic (corresponds to RankGenerator order parameter)
# Assumes default order is "tp-cp-ep-dp-pp"
@ -67,8 +66,9 @@ def _megatron_calc_layer_map(config):
for pp_rank_idx in range(pp_size):
for virtual_pp_rank_idx in range(virtual_pp_size):
layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
pp_rank_idx * num_layers_per_model)
layer_offset = (
virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
)
for layer_idx in range(num_layers_per_model):
layer_map[layer_offset + layer_idx] = (
pp_rank_idx,
@ -121,9 +121,11 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F
for i, wrapped_model in enumerate(wrapped_models):
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
assert len(models[i].decoder.layers
) == num_layers_per_model, 'len model layers {} not equal to num_layers_per_model {}'.format(
len(models[i].decoder.layers), num_layers_per_model)
assert len(models[i].decoder.layers) == num_layers_per_model, (
"len model layers {} not equal to num_layers_per_model {}".format(
len(models[i].decoder.layers), num_layers_per_model
)
)
state_dict = dict()
@ -261,7 +263,7 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F
gate_weight_list = []
up_weight_list = []
for i in range(tp_size):
gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)]
gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]
gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
gate_weight_list.append(gate_weight_tp)
@ -321,13 +323,13 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F
total_size = q_size_tp + 2 * kv_size_tp
for i in range(tp_size):
num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size
qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
q_size_chunk = q_size_tp // num_query_groups_per_partition
kv_size_chunk = kv_size_tp // num_query_groups_per_partition
for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):
q_part = qkv_part_chunk[:q_size_chunk]
k_part = qkv_part_chunk[q_size_chunk:q_size_chunk + kv_size_chunk]
v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk:]
k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]
v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]
q_weight_list.append(q_part)
k_weight_list.append(k_part)
v_weight_list.append(v_part)
@ -337,13 +339,13 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F
total_size = q_size_tp + 2 * kv_size_tp
for i in range(tp_size):
num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size
qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
q_size_chunk = q_size_tp // num_query_groups_per_partition
kv_size_chunk = kv_size_tp // num_query_groups_per_partition
for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):
q_part = qkv_part_chunk[:q_size_chunk]
k_part = qkv_part_chunk[q_size_chunk:q_size_chunk + kv_size_chunk]
v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk:]
k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]
v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]
q_weight_list.append(q_part)
if i * config.num_key_value_heads % tp_size == 0:
k_weight_list.append(k_part)
@ -393,8 +395,10 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F
src_pp_rank=src_pp_rank,
)
if getattr(sync_layer.self_attention.linear_qkv, 'bias',
None) is not None and sync_layer.self_attention.linear_qkv.bias.numel() > 0:
if (
getattr(sync_layer.self_attention.linear_qkv, "bias", None) is not None
and sync_layer.self_attention.linear_qkv.bias.numel() > 0
):
_broadcast_tp_shard_tensor_qkv(
sync_layer.self_attention.linear_qkv.bias,
f"{layer_name}.self_attn.q_proj.bias",
@ -416,10 +420,12 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.linear_fc1.weight,
f"{layer_name}.mlp.gate_proj.weight",
f"{layer_name}.mlp.up_proj.weight",
src_pp_rank=src_pp_rank)
_broadcast_tp_shard_tensor_gate_up(
sync_layer.mlp.linear_fc1.weight,
f"{layer_name}.mlp.gate_proj.weight",
f"{layer_name}.mlp.up_proj.weight",
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor(
sync_layer.mlp.linear_fc2.weight,
@ -439,7 +445,7 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F
)
if tie_word_embeddings:
print_rank_0(f"tie word embedding skip load lm_head...")
print_rank_0("tie word embedding skip load lm_head...")
else:
print_rank_0("collecting lm_head...")
@ -459,7 +465,6 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F
dist.barrier()
torch.cuda.empty_cache()
if torch.distributed.get_rank() == 0:
for k, v in state_dict.items():
if dtype != v.dtype:
state_dict[k] = v.to(dtype)
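The merge path goes the other way: each tensor-parallel rank's fused QKV shard is walked per query group, and within every group the leading rows are collected as Q, the next as K, and the last as V (cf. the `qkv_part.chunk(...)` loops above). A toy sketch with hypothetical sizes:

```python
import torch

# Hypothetical shard: 2 query groups, 4 query rows and 1 KV row per group (GQA-style).
num_query_groups_per_partition, q_size_chunk, kv_size_chunk = 2, 4, 1
total_size = num_query_groups_per_partition * (q_size_chunk + 2 * kv_size_chunk)
qkv_part = torch.arange(total_size * 3, dtype=torch.float32).reshape(total_size, 3)

q_rows, k_rows, v_rows = [], [], []
for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):
    q_rows.append(qkv_part_chunk[:q_size_chunk])
    k_rows.append(qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk])
    v_rows.append(qkv_part_chunk[q_size_chunk + kv_size_chunk :])

q, k, v = torch.cat(q_rows), torch.cat(k_rows), torch.cat(v_rows)
print(q.shape, k.shape, v.shape)  # torch.Size([8, 3]) torch.Size([2, 3]) torch.Size([2, 3])
```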
