From b00f77d8559b48d57a33c0132a5ba1c81891a536 Mon Sep 17 00:00:00 2001
From: Shawn/Yuxuan Tong
Date: Fri, 18 Apr 2025 22:49:31 +0800
Subject: [PATCH] [dev] feat: migrate from yapf & pylint to ruff based on pre-commit (#1010)

> [!WARNING]
> We are [migrating to `ruff` as the linter and formatter and `pre-commit` as the managing tool](https://github.com/volcengine/verl/pull/1010).
>
> If your branch is based on a previous commit that used `yapf` and `pylint`, simply merging might trigger overwhelming linting errors, even though **you are only expected to resolve the ones in files related to your PR**.
>
> To resolve this issue, please try the following workaround so that the PR only includes the files you **really changed**:
>
> 1. In your branch, fix linting and formatting with `ruff`: `ruff check --fix && ruff format`
> 2. Squash into a single commit in a new branch: `git reset --soft $(git merge-base main HEAD) && git add -A && git commit -m "feat: ..."`
> 3. Merge with the latest main: `git merge origin/main`
> 4. Force push to your branch: `git push --force`

We add the reminder above to the documentation to tell contributors how to avoid overwhelming linting errors.

### Motivation

Following the discussion in #896, this PR migrates from yapf & pylint to ruff, managed by pre-commit, which allows unified version control of the tooling and automatic hooks on commit.

### Summary

The `pre-commit` hook and CI

- check staged / committed files in commits / PRs
- check all files on a weekly schedule (this should fail until all files have been fixed to the ruff standard)

### Explanation for the Failing CI Workflow `pre-commit`

For now, we only apply `ruff format` and `ruff check --fix` **without resolving all the errors**, since there are too many to resolve at once, which causes the `pre-commit` CI workflow to fail. We leave resolving the remaining errors to future commits. Specifically, the `pre-commit` hook and CI require every commit to fix its related files with `ruff`, so all files will be fixed incrementally.

### Reviewing Suggestion

The commit https://github.com/volcengine/verl/pull/1010/commits/3d93f51ba8838909096c9233426557ffc0df3431 is huge since we apply `ruff` to all the files. To review the main changes, please check the commits before and after it.
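For reference, the workaround above can be run end to end roughly as follows. This is a minimal sketch, assuming your work lives on a feature branch, the remote is named `origin`, and `ruff` / `pre-commit` are installed; branch names and the commit message are placeholders to adapt to your PR.

```bash
# One-time setup: install the pre-commit hook so future commits are checked automatically.
pip install pre-commit
pre-commit install

# 1. Fix linting and formatting with ruff in your branch.
ruff check --fix
ruff format

# 2. Squash your work into a single commit on top of the merge base with main.
git reset --soft "$(git merge-base main HEAD)"
git add -A
git commit -m "feat: ..."   # replace with your actual commit message

# 3. Bring in the latest main, which already contains the ruff migration.
git fetch origin
git merge origin/main

# 4. Force-push the rewritten history to your PR branch.
git push --force
```
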
--- .github/workflows/pre-commit-full.yml | 30 + .github/workflows/pre-commit.yml | 30 + .github/workflows/pylint.yml | 40 - .github/workflows/yapf_format.yml | 56 -- .pre-commit-config.yaml | 8 + .style.yapf | 5 - .vscode/settings.json | 5 +- README.md | 28 +- docker/Dockerfile.ngc.vllm | 14 +- docker/Dockerfile.ngc.vllm0.8 | 2 +- docker/Dockerfile.ngc.vllm0.8.sagemaker | 2 +- docker/Dockerfile.sglang | 2 +- docs/conf.py | 30 +- docs/index.rst | 35 +- examples/data_preprocess/full_hh_rlhf.py | 87 +- examples/data_preprocess/geo3k.py | 59 +- examples/data_preprocess/gsm8k.py | 61 +- examples/data_preprocess/hellaswag.py | 52 +- examples/data_preprocess/math_dataset.py | 48 +- examples/data_preprocess/multiturn.py | 123 ++- .../ppo_trainer/verl_getting_started.ipynb | 8 +- examples/ray/tutorial.ipynb | 146 +-- examples/split_placement/main_ppo_split.py | 84 +- .../split_placement/split_monkey_patch.py | 115 ++- pyproject.toml | 152 +--- recipe/dapo/src/dapo_ray_trainer.py | 164 ++-- recipe/dapo/src/main_dapo.py | 108 ++- recipe/prime/__init__.py | 2 +- recipe/prime/main_prime.py | 59 +- recipe/prime/prime_core_algos.py | 60 +- recipe/prime/prime_dp_rm.py | 277 +++--- recipe/prime/prime_fsdp_workers.py | 287 +++--- recipe/prime/prime_ray_trainer.py | 444 ++++----- recipe/r1/data_process.py | 108 +-- recipe/r1/main_eval.py | 23 +- recipe/r1/reward_score.py | 9 +- recipe/r1/tasks/livecodebench.py | 10 +- recipe/r1/tasks/math.py | 6 +- requirements.txt | 2 +- scripts/converter_hf_to_mcore.py | 115 +-- scripts/diagnose.py | 174 ++-- scripts/format.sh | 3 - scripts/model_merger.py | 143 +-- setup.py | 100 +- tests/__init__.py | 2 +- tests/checkpoint/test_fsdp_ckpt.py | 58 +- tests/distributed/test_tensor_dict.py | 94 +- .../data/create_dataset.py | 24 +- .../model/create_model_tokenizer.py | 43 +- .../arithmetic_sequence/rl/main_trainer.py | 66 +- tests/e2e/check_custom_rwd_fn.py | 8 +- tests/e2e/check_results.py | 24 +- tests/e2e/envs/__init__.py | 2 +- tests/e2e/envs/digit_completion/__init__.py | 6 +- tests/e2e/envs/digit_completion/task.py | 46 +- tests/e2e/envs/digit_completion/tokenizer.py | 23 +- tests/e2e/sft/test_sp_loss_match.py | 21 +- tests/gpu_utility/test_memory_buffers.py | 29 +- tests/gpu_utility/test_ops.py | 24 +- tests/gpu_utility/test_torch_functional.py | 49 +- tests/model/test_transformer.py | 171 ++-- tests/model/test_transformers_ulysses.py | 193 ++-- tests/ray/check_worker_alive/main.py | 17 +- tests/ray/detached_worker/client.py | 34 +- tests/ray/detached_worker/server.py | 98 +- tests/ray/test_check_worker_alive.py | 9 +- tests/ray/test_colocated_workers.py | 29 +- tests/ray/test_data_transfer.py | 38 +- tests/ray/test_driverfunc_to_worker.py | 39 +- tests/ray/test_high_level_scheduling_api.py | 6 +- tests/ray/test_ray_local_envs.py | 14 +- tests/ray/test_rvdz.py | 2 +- tests/ray/test_worker_group_basics.py | 16 +- tests/ray/test_worker_group_torch.py | 32 +- tests/rollout/run_fsdp_vllm.py | 108 +-- tests/rollout/test_sglang_spmd.py | 73 +- tests/rollout/test_vllm_hf_loader.py | 69 +- tests/rollout/test_vllm_spmd.py | 83 +- tests/sandbox/test_sandbox.py | 28 +- tests/sanity/check_license.py | 12 +- tests/sanity/test_import.py | 2 + tests/utility/test_tensor_dict_utilities.py | 216 +++-- .../dataset/test_multiturn_sft_dataset.py | 118 ++- tests/verl/utils/dataset/test_rl_dataset.py | 123 ++- tests/verl/utils/dataset/test_rm_dataset.py | 7 +- tests/verl/utils/dataset/test_sft_dataset.py | 59 +- tests/verl/utils/test_import_utils.py | 6 +- verl/__init__.py | 16 +- 
verl/models/llama/megatron/__init__.py | 9 +- .../megatron/checkpoint_utils/llama_loader.py | 109 ++- .../llama_loader_depracated.py | 132 +-- .../megatron/checkpoint_utils/llama_saver.py | 63 +- .../megatron/layers/parallel_attention.py | 156 ++-- .../llama/megatron/layers/parallel_decoder.py | 24 +- .../llama/megatron/layers/parallel_linear.py | 76 +- .../llama/megatron/layers/parallel_mlp.py | 24 +- .../llama/megatron/layers/parallel_rmsnorm.py | 16 +- .../llama/megatron/modeling_llama_megatron.py | 245 ++--- verl/models/mcore/loader.py | 106 ++- verl/models/mcore/saver.py | 69 +- verl/models/mcore/util.py | 113 +-- verl/models/qwen2/megatron/__init__.py | 9 +- .../megatron/checkpoint_utils/qwen2_loader.py | 129 +-- .../qwen2_loader_depracated.py | 150 +-- .../megatron/checkpoint_utils/qwen2_saver.py | 76 +- .../megatron/layers/parallel_attention.py | 115 ++- .../qwen2/megatron/layers/parallel_decoder.py | 24 +- .../qwen2/megatron/layers/parallel_linear.py | 73 +- .../qwen2/megatron/layers/parallel_mlp.py | 24 +- .../qwen2/megatron/layers/parallel_rmsnorm.py | 16 +- .../qwen2/megatron/modeling_qwen2_megatron.py | 242 ++--- verl/models/registry.py | 19 +- verl/models/transformers/llama.py | 59 +- verl/models/transformers/monkey_patch.py | 33 +- verl/models/transformers/qwen2.py | 58 +- verl/models/transformers/qwen2_vl.py | 53 +- verl/models/weight_loader_registry.py | 26 +- verl/protocol.py | 165 ++-- verl/single_controller/__init__.py | 4 +- verl/single_controller/base/__init__.py | 4 +- verl/single_controller/base/decorator.py | 104 ++- .../single_controller/base/megatron/worker.py | 35 +- .../base/megatron/worker_group.py | 8 +- .../base/register_center/ray.py | 1 - verl/single_controller/base/worker.py | 57 +- verl/single_controller/base/worker_group.py | 64 +- verl/single_controller/ray/__init__.py | 2 +- verl/single_controller/ray/base.py | 204 ++--- verl/single_controller/ray/megatron.py | 39 +- verl/third_party/sglang/__init__.py | 2 +- verl/third_party/sglang/parallel_state.py | 10 +- verl/third_party/vllm/__init__.py | 43 +- .../vllm/vllm_v_0_3_1/arg_utils.py | 330 ++++--- verl/third_party/vllm/vllm_v_0_3_1/config.py | 177 ++-- verl/third_party/vllm/vllm_v_0_3_1/llm.py | 44 +- .../vllm/vllm_v_0_3_1/llm_engine_sp.py | 175 ++-- .../vllm/vllm_v_0_3_1/model_loader.py | 138 +-- .../vllm/vllm_v_0_3_1/model_runner.py | 91 +- .../vllm/vllm_v_0_3_1/parallel_state.py | 22 +- .../vllm/vllm_v_0_3_1/tokenizer.py | 26 +- .../vllm/vllm_v_0_3_1/weight_loaders.py | 29 +- verl/third_party/vllm/vllm_v_0_3_1/worker.py | 66 +- .../vllm/vllm_v_0_4_2/arg_utils.py | 369 ++++---- verl/third_party/vllm/vllm_v_0_4_2/config.py | 68 +- .../vllm_v_0_4_2/dtensor_weight_loaders.py | 60 +- .../vllm/vllm_v_0_4_2/hf_weight_loader.py | 9 +- verl/third_party/vllm/vllm_v_0_4_2/llm.py | 74 +- .../vllm/vllm_v_0_4_2/llm_engine_sp.py | 62 +- .../vllm_v_0_4_2/megatron_weight_loaders.py | 147 +-- .../vllm/vllm_v_0_4_2/model_loader.py | 150 +-- .../vllm/vllm_v_0_4_2/model_runner.py | 113 ++- .../vllm/vllm_v_0_4_2/parallel_state.py | 58 +- .../vllm/vllm_v_0_4_2/spmd_gpu_executor.py | 38 +- .../vllm/vllm_v_0_4_2/tokenizer.py | 26 +- verl/third_party/vllm/vllm_v_0_4_2/worker.py | 67 +- .../vllm/vllm_v_0_5_4/arg_utils.py | 479 +++++----- verl/third_party/vllm/vllm_v_0_5_4/config.py | 110 ++- .../vllm_v_0_5_4/dtensor_weight_loaders.py | 90 +- .../vllm/vllm_v_0_5_4/hf_weight_loader.py | 9 +- verl/third_party/vllm/vllm_v_0_5_4/llm.py | 48 +- .../vllm/vllm_v_0_5_4/llm_engine_sp.py | 78 +- 
.../vllm_v_0_5_4/megatron_weight_loaders.py | 121 +-- .../vllm/vllm_v_0_5_4/model_loader.py | 208 +++-- .../vllm/vllm_v_0_5_4/model_runner.py | 84 +- .../vllm/vllm_v_0_5_4/parallel_state.py | 76 +- .../vllm/vllm_v_0_5_4/spmd_gpu_executor.py | 50 +- .../vllm/vllm_v_0_5_4/tokenizer.py | 26 +- verl/third_party/vllm/vllm_v_0_5_4/worker.py | 117 ++- verl/third_party/vllm/vllm_v_0_6_3/config.py | 9 +- .../vllm_v_0_6_3/dtensor_weight_loaders.py | 17 +- .../vllm/vllm_v_0_6_3/hf_weight_loader.py | 2 +- verl/third_party/vllm/vllm_v_0_6_3/llm.py | 12 +- .../vllm/vllm_v_0_6_3/llm_engine_sp.py | 38 +- .../vllm_v_0_6_3/megatron_weight_loaders.py | 26 +- .../vllm/vllm_v_0_6_3/model_loader.py | 37 +- .../vllm/vllm_v_0_6_3/model_runner.py | 17 +- .../vllm/vllm_v_0_6_3/parallel_state.py | 7 +- .../vllm/vllm_v_0_6_3/spmd_gpu_executor.py | 5 +- .../vllm/vllm_v_0_6_3/tokenizer.py | 5 +- verl/third_party/vllm/vllm_v_0_6_3/worker.py | 50 +- verl/trainer/fsdp_sft_trainer.py | 346 +++---- verl/trainer/main_eval.py | 25 +- verl/trainer/main_generation.py | 80 +- verl/trainer/main_ppo.py | 105 ++- verl/trainer/ppo/core_algos.py | 115 +-- verl/trainer/ppo/metric_utils.py | 190 ++-- verl/trainer/ppo/ray_trainer.py | 627 +++++++------ verl/utils/__init__.py | 4 +- verl/utils/checkpoint/__init__.py | 2 +- verl/utils/checkpoint/checkpoint_manager.py | 48 +- .../checkpoint/fsdp_checkpoint_manager.py | 90 +- .../checkpoint/megatron_checkpoint_manager.py | 216 +++-- verl/utils/dataset/multiturn_sft_dataset.py | 80 +- verl/utils/dataset/rl_dataset.py | 70 +- verl/utils/dataset/rm_dataset.py | 54 +- verl/utils/dataset/sft_dataset.py | 77 +- verl/utils/debug/__init__.py | 2 +- verl/utils/debug/performance.py | 5 +- verl/utils/debug/trajectory_tracker.py | 41 +- verl/utils/distributed.py | 7 +- verl/utils/flops_counter.py | 34 +- verl/utils/fs.py | 17 +- verl/utils/fsdp_utils.py | 62 +- verl/utils/hdfs_io.py | 17 +- verl/utils/import_utils.py | 14 +- verl/utils/logger/aggregate_logger.py | 12 +- verl/utils/logging_utils.py | 9 +- verl/utils/megatron/memory.py | 11 +- verl/utils/megatron/optimizer.py | 32 +- verl/utils/megatron/pipeline_parallel.py | 20 +- verl/utils/megatron/sequence_parallel.py | 6 +- verl/utils/megatron/tensor_parallel.py | 58 +- verl/utils/megatron_utils.py | 221 ++--- verl/utils/memory_buffer.py | 27 +- verl/utils/model.py | 259 +++--- verl/utils/py_functional.py | 6 +- verl/utils/ray_utils.py | 5 +- verl/utils/rendezvous/ray_backend.py | 14 +- verl/utils/reward_score/__init__.py | 24 +- verl/utils/reward_score/geo3k.py | 3 +- verl/utils/reward_score/gsm8k.py | 16 +- verl/utils/reward_score/math.py | 14 +- verl/utils/reward_score/math_dapo.py | 67 +- verl/utils/reward_score/math_verify.py | 8 +- .../utils/reward_score/prime_code/__init__.py | 11 +- .../reward_score/prime_code/testing_util.py | 72 +- verl/utils/reward_score/prime_code/utils.py | 17 +- .../utils/reward_score/prime_math/__init__.py | 73 +- verl/utils/reward_score/prime_math/grader.py | 87 +- .../reward_score/prime_math/math_normalize.py | 3 +- verl/utils/seqlen_balancing.py | 48 +- verl/utils/tokenizer.py | 17 +- verl/utils/torch_dtypes.py | 12 +- verl/utils/torch_functional.py | 107 +-- verl/utils/tracking.py | 114 +-- verl/utils/ulysses.py | 58 +- verl/workers/actor/base.py | 11 +- verl/workers/actor/dp_actor.py | 242 ++--- verl/workers/actor/megatron_actor.py | 263 +++--- verl/workers/critic/base.py | 4 +- verl/workers/critic/dp_critic.py | 163 ++-- verl/workers/critic/megatron_critic.py | 162 ++-- verl/workers/fsdp_workers.py | 857 
++++++++++-------- verl/workers/megatron_workers.py | 496 +++++----- verl/workers/reward_manager/__init__.py | 4 +- verl/workers/reward_manager/batch.py | 52 +- verl/workers/reward_manager/dapo.py | 53 +- verl/workers/reward_manager/naive.py | 33 +- verl/workers/reward_manager/prime.py | 70 +- verl/workers/reward_model/base.py | 1 - .../reward_model/megatron/reward_model.py | 163 ++-- verl/workers/rollout/__init__.py | 2 +- verl/workers/rollout/base.py | 4 +- verl/workers/rollout/hf_rollout.py | 55 +- verl/workers/rollout/naive/naive_rollout.py | 32 +- .../rollout/sglang_rollout/sglang_rollout.py | 152 ++-- verl/workers/rollout/tokenizer.py | 16 +- verl/workers/rollout/vllm_rollout/__init__.py | 16 +- .../rollout/vllm_rollout/fire_vllm_rollout.py | 72 +- .../rollout/vllm_rollout/vllm_rollout.py | 145 +-- .../rollout/vllm_rollout/vllm_rollout_spmd.py | 172 ++-- verl/workers/sharding_manager/__init__.py | 4 +- verl/workers/sharding_manager/base.py | 1 - verl/workers/sharding_manager/fsdp_sglang.py | 61 +- verl/workers/sharding_manager/fsdp_ulysses.py | 25 +- verl/workers/sharding_manager/fsdp_vllm.py | 97 +- .../workers/sharding_manager/megatron_vllm.py | 154 ++-- .../sharding_manager/patch/fsdp_vllm_patch.py | 49 +- 268 files changed, 10660 insertions(+), 9233 deletions(-) create mode 100644 .github/workflows/pre-commit-full.yml create mode 100644 .github/workflows/pre-commit.yml delete mode 100644 .github/workflows/pylint.yml delete mode 100644 .github/workflows/yapf_format.yml create mode 100644 .pre-commit-config.yaml delete mode 100644 .style.yapf delete mode 100755 scripts/format.sh diff --git a/.github/workflows/pre-commit-full.yml b/.github/workflows/pre-commit-full.yml new file mode 100644 index 000000000..46ef0b5c8 --- /dev/null +++ b/.github/workflows/pre-commit-full.yml @@ -0,0 +1,30 @@ +name: pre-commit-full + +# Run weekly on Sunday at 00:00 UTC +on: + schedule: + - cron: '0 0 * * 0' + # Allow manual triggering + workflow_dispatch: + +# Declare permissions just read content. +permissions: + contents: read + +jobs: + pre-commit-full: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.12"] + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.python-version }} + - uses: pre-commit/action@v3.0.1 + env: + RUFF_OUTPUT_FORMAT: github + with: + extra_args: --all-files \ No newline at end of file diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 000000000..b0fa24fd6 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,30 @@ +# c.f. https://github.com/pre-commit/action?tab=readme-ov-file#using-this-action +name: pre-commit + +# No need to avoid / cancel lightweight pre-commit jobs +on: + pull_request: + push: + branches: + - main + - v0.2.x + +# Declare permissions just read content. 
+permissions: + contents: read + +jobs: + pre-commit: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.12"] + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.python-version }} + - uses: pre-commit/action@v3.0.1 + env: + RUFF_OUTPUT_FORMAT: github \ No newline at end of file diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml deleted file mode 100644 index 17b0e372b..000000000 --- a/.github/workflows/pylint.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Pylint Check - -on: - push: - paths: - - '**.py' - - 'requirements.txt' - - 'pyproject.toml' - pull_request: - paths: - - '**.py' - - 'requirements.txt' - - 'pyproject.toml' - -jobs: - lint: - runs-on: ubuntu-latest - - steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.12' - - - name: Install pylint (version from requirements.txt) - run: | - PYLINT_VERSION=$(grep '^pylint' requirements.txt) - if [ -z "$PYLINT_VERSION" ]; then - echo "No pylint version found in requirements.txt" - exit 1 - fi - # only install pylint to avoid dependency problems on CPU - pip install "$PYLINT_VERSION" - - - name: Run pylint - run: | - pylint --recursive=y --rcfile=pyproject.toml ./ diff --git a/.github/workflows/yapf_format.yml b/.github/workflows/yapf_format.yml deleted file mode 100644 index 78468f835..000000000 --- a/.github/workflows/yapf_format.yml +++ /dev/null @@ -1,56 +0,0 @@ -name: yapf - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - - v0.2.x - paths: - - "**/*.py" - - .github/workflows/yapf_format.yml - pull_request: - branches: - - main - - v0.2.x - paths: - - "**/*.py" - - .github/workflows/yapf_format.yml - -# Cancel jobs on the same ref if a new one is triggered -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -# Declare permissions just read content. 
-permissions: - contents: read - -jobs: - yapf: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.12"] - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - # - name: checkout - # run: | - # commits=${{ github.event.pull_request.commits }} - # if [[ -n "$commits" ]]; then - # # Prepare enough depth for diffs with main - # git fetch --depth="$(( commits + 1 ))" - # fi - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install --upgrade yapf - pip install toml==0.10.2 - - name: Running yapf - run: | - yapf -r -vv -d --style=./.style.yapf verl tests examples recipe diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..7f29d8b44 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,8 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.11.4" + hooks: + - id: ruff + pass_filenames: true + entry: bash -c 'ruff check --fix --show-fixes --output-format=${RUFF_OUTPUT_FORMAT:-full} "$@"' + - id: ruff-format diff --git a/.style.yapf b/.style.yapf deleted file mode 100644 index a09a27c31..000000000 --- a/.style.yapf +++ /dev/null @@ -1,5 +0,0 @@ -[style] -based_on_style = google -column_limit = 120 -indent_width = 4 -split_arguments_when_comma_terminated: true \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 29f3e4cb0..ce6fc42a1 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,9 +1,8 @@ { - "pylint.enabled": true, "[python]": { - "editor.defaultFormatter": "eeyore.yapf", + "editor.defaultFormatter": "charliermarsh.ruff", "editor.codeActionsOnSave": { - "source.organizeImports": "never", + "source.organizeImports": "always", } } } \ No newline at end of file diff --git a/README.md b/README.md index 18a683acf..d6df6c6c7 100644 --- a/README.md +++ b/README.md @@ -170,16 +170,34 @@ verl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. The - [all-hands/openhands-lm-32b-v0.1](https://www.all-hands.dev/blog/introducing-openhands-lm-32b----a-strong-open-coding-agent-model): A strong, open coding agent model, trained with [multi-turn fine-tuning](https://github.com/volcengine/verl/pull/195) ## Contribution Guide + Contributions from the community are welcome! Please check out our [project roadmap](https://github.com/volcengine/verl/issues/710) and [good first issues](https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22) to see where you can contribute. -### Code formatting -We use yapf (Google style) to enforce strict code formatting when reviewing PRs. To reformat your code locally, make sure you have installed the **latest** version of `yapf` +### Code Linting and Formatting + +> [!WARNING] +> We are [immigrating to `ruff` as the linter and formatter and `pre-commit` as the managing tool](https://github.com/volcengine/verl/pull/1010). +> +> If your branch is based on a previous commit using `yapf` and `pylint`, simply merging might trigger overwhelming linting errors, while **you are only expected to resolve ones in the files related to your PR**. +> +> To resolve this issue, please try the following workaround to only include the files you **really changed** in the PR: +> +> 1. 
In your branch, fix linting and format with `ruff`: `ruff check --fix && ruff-format` +> 2. Squash into a new single commit: `git reset --soft $(git merge-base main HEAD) && git add -A && git commit -m "feat: ..."` +> 3. Merge with the latest main: `git merge origin/main` +> 4. Force push to your branch: `git push --force` + +We use pre-commit to help improve code quality. To initialize pre-commit, run: + ```bash -pip3 install yapf --upgrade +pip install pre-commit +pre-commit install ``` -Then, make sure you are at top level of verl repo and run + +You can also manually run pre-commit by: + ```bash -bash scripts/format.sh +pre-commit run ``` ### Adding CI tests diff --git a/docker/Dockerfile.ngc.vllm b/docker/Dockerfile.ngc.vllm index 3a43a2504..7f29f8a55 100644 --- a/docker/Dockerfile.ngc.vllm +++ b/docker/Dockerfile.ngc.vllm @@ -3,12 +3,12 @@ FROM nvcr.io/nvidia/pytorch:24.05-py3 # uninstall nv-pytorch fork RUN pip3 uninstall pytorch-quantization \ - pytorch-triton \ - torch \ - torch-tensorrt \ - torchvision \ - xgboost transformer_engine flash_attn \ - apex megatron-core -y + pytorch-triton \ + torch \ + torch-tensorrt \ + torchvision \ + xgboost transformer_engine flash_attn \ + apex megatron-core -y RUN pip3 install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124 @@ -38,7 +38,7 @@ RUN pip3 install --no-cache-dir \ 'wandb' # full dependencies -RUN pip3 install pytest yapf py-spy pyext liger-kernel +RUN pip3 install pytest pre-commit py-spy pyext liger-kernel # =============== Megatron dependencies (optional) ================= # install Transformer Engine, which requires FA 2.5.8. Do it in a separate step for docker cache diff --git a/docker/Dockerfile.ngc.vllm0.8 b/docker/Dockerfile.ngc.vllm0.8 index df70cd673..127839fe7 100644 --- a/docker/Dockerfile.ngc.vllm0.8 +++ b/docker/Dockerfile.ngc.vllm0.8 @@ -51,7 +51,7 @@ RUN pip install --no-cache-dir "vllm==0.8.3" "torch==2.6.0" "torchvision==0.21.0 "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ "numpy<2.0.0" "pyarrow>=15.0.0" pandas \ ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \ - pytest yapf py-spy pyext pre-commit ruff + pytest py-spy pyext pre-commit ruff # Install flash-attn-2.7.4.post1 (cxx11abi=False) RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \ diff --git a/docker/Dockerfile.ngc.vllm0.8.sagemaker b/docker/Dockerfile.ngc.vllm0.8.sagemaker index b9c458d0b..d14cf725e 100644 --- a/docker/Dockerfile.ngc.vllm0.8.sagemaker +++ b/docker/Dockerfile.ngc.vllm0.8.sagemaker @@ -27,7 +27,7 @@ RUN apt-get update && \ RUN pip install --no-cache-dir vllm==0.8.2 torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata==0.11.0 \ transformers>=4.49.0 accelerate datasets peft hf-transfer \ ray[default] codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \ - pytest yapf py-spy pyext pre-commit ruff + pytest pre-commit py-spy pyext pre-commit ruff # Install flash_attn-2.7.4.post1 RUN pip uninstall -y transformer-engine flash-attn && \ diff --git a/docker/Dockerfile.sglang b/docker/Dockerfile.sglang index c5ab53c49..8fe3b7d5f 100644 --- a/docker/Dockerfile.sglang +++ b/docker/Dockerfile.sglang @@ -43,7 +43,7 @@ RUN pip install "sglang[all]==0.4.4.post4" --no-cache-dir --find-links https://f RUN pip install 
--no-cache-dir torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 tensordict torchdata \ transformers>=4.49.0 accelerate datasets peft hf_transfer \ ray codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb liger-kernel \ - pytest yapf py-spy pyext + pytest pre-commit py-spy pyext # Install flash_attn-2.7.4.post1 RUN pip uninstall -y transformer-engine flash-attn && \ diff --git a/docs/conf.py b/docs/conf.py index 24a488d36..796ca160c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -31,43 +31,43 @@ # -- Project information ----------------------------------------------------- -project = u'verl' -# pylint: disable=W0622 -copyright = u'2024 ByteDance Seed Foundation MLSys Team' -author = u'Guangming Sheng, Chi Zhang, Yanghua Peng, Haibin Lin' +project = "verl" +copyright = "2024 ByteDance Seed Foundation MLSys Team" +author = "Guangming Sheng, Chi Zhang, Yanghua Peng, Haibin Lin" # -- General configuration --------------------------------------------------- # The master toctree document. -master_doc = 'index' +master_doc = "index" # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['recommonmark', - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.autosectionlabel', +extensions = [ + "recommonmark", + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.autosectionlabel", ] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: -source_suffix = ['.rst', 'rest', '.md'] +source_suffix = [".rst", "rest", ".md"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = u'en' +language = "en" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- @@ -75,9 +75,9 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] \ No newline at end of file +html_static_path = ["_static"] diff --git a/docs/index.rst b/docs/index.rst index 0cf9b1aaa..122af513d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -113,14 +113,35 @@ verl is free software; you can redistribute it and/or modify it under the terms of the Apache License 2.0. We welcome contributions. Join us on `GitHub `_, `Slack `_ and `Wechat `_ for discussions. -Code formatting -^^^^^^^^^^^^^^^^^^^^^^^^ -We use yapf (Google style) to enforce strict code formatting when reviewing MRs. 
Run yapf at the top level of verl repo: +Contributions from the community are welcome! Please check out our `project roadmap `_ and `good first issues `_ to see where you can contribute. + +Code Linting and Formatting +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. warning:: + We are `immigrating to ``ruff`` as the linter and formatter and ``pre-commit`` as the managing tool `_. + + If your branch is based on a previous commit using ``yapf`` and ``pylint``, simply merging might trigger overwhelming linting errors, while **you are only expected to resolve ones in the files related to your PR**. + + To resolve this issue, please try the following workaround to only include the files you **really changed** in the PR: + + 1. In your branch, fix linting and format with ``ruff``: ``ruff check --fix && ruff-format`` + 2. Squash into a new single commit: ``git reset --soft $(git merge-base main HEAD) && git add -A && git commit -m "feat: ..."`` + 3. Merge with the latest main: ``git merge origin/main`` + 4. Force push to your branch: ``git push --force`` + +We use pre-commit to help improve code quality. To initialize pre-commit, run: .. code-block:: bash - pip3 install yapf - yapf -ir -vv --style ./.style.yapf verl examples tests + pip install pre-commit + pre-commit install + +You can also manually run pre-commit by: + +.. code-block:: bash + + pre-commit run Adding CI tests ^^^^^^^^^^^^^^^^^^^^^^^^ @@ -129,4 +150,6 @@ If possible, please add CI test(s) for your new feature: 1. Find the most relevant workflow yml file, which usually corresponds to a ``hydra`` default config (e.g. ``ppo_trainer``, ``ppo_megatron_trainer``, ``sft_trainer``, etc). 2. Add related path patterns to the ``paths`` section if not already included. -3. Minimize the workload of the test script(s) (see existing scripts for examples). \ No newline at end of file +3. Minimize the workload of the test script(s) (see existing scripts for examples). + +We are HIRING! Send us an `email `_ if you are interested in internship/FTE opportunities in MLSys/LLM reasoning/multimodal alignment. diff --git a/examples/data_preprocess/full_hh_rlhf.py b/examples/data_preprocess/full_hh_rlhf.py index 07e0884cb..10a0aa9d7 100644 --- a/examples/data_preprocess/full_hh_rlhf.py +++ b/examples/data_preprocess/full_hh_rlhf.py @@ -16,131 +16,124 @@ - All the training data is used to train SFT and RL. 
- Both chosen and rejected is used to train SFT """ + import argparse import os import pandas as pd from datasets import load_dataset - from tqdm.auto import tqdm from verl.utils.fs import copy, makedirs -def generate_sft_dataset(target_hdfs_path_dir, local_dir='~/data/full_hh_rlh/sft'): - dataset = load_dataset('Dahoas/full-hh-rlhf') - output = {'prompt': [], 'response': []} - for data in tqdm(dataset['train']): +def generate_sft_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_rlh/sft"): + dataset = load_dataset("Dahoas/full-hh-rlhf") + output = {"prompt": [], "response": []} + for data in tqdm(dataset["train"]): # add chosen - output['prompt'].append(data['prompt']) - output['response'].append(data['chosen']) + output["prompt"].append(data["prompt"]) + output["response"].append(data["chosen"]) # add rejection - output['prompt'].append(data['prompt']) - output['response'].append(data['rejected']) + output["prompt"].append(data["prompt"]) + output["response"].append(data["rejected"]) df = pd.DataFrame(output) local_dir = os.path.expanduser(local_dir) os.makedirs(local_dir, exist_ok=True) - local_path = os.path.join(local_dir, 'train.parquet') + local_path = os.path.join(local_dir, "train.parquet") df.to_parquet(path=local_path) if target_hdfs_path_dir is not None: - hdfs_dir = target_hdfs_path_dir + '/' + 'train.parquet' + hdfs_dir = target_hdfs_path_dir + "/" + "train.parquet" makedirs(hdfs_dir) copy(local_path, hdfs_dir) -def generate_rm_dataset(target_hdfs_path_dir, local_dir='~/data/full_hh_rlh/rm'): - train_dataset = load_dataset('Dahoas/full-hh-rlhf', split='train[:75%]') - test_dataset = load_dataset('Dahoas/full-hh-rlhf', split='train[-25%:]') +def generate_rm_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_rlh/rm"): + train_dataset = load_dataset("Dahoas/full-hh-rlhf", split="train[:75%]") + test_dataset = load_dataset("Dahoas/full-hh-rlhf", split="train[-25%:]") local_dir = os.path.expanduser(local_dir) os.makedirs(local_dir, exist_ok=True) - for dataset, name in zip([train_dataset, test_dataset], ['train', 'test']): - output = {'prompt': [], 'chosen': [], 'rejected': []} + for dataset, name in zip([train_dataset, test_dataset], ["train", "test"]): + output = {"prompt": [], "chosen": [], "rejected": []} for data in tqdm(dataset): # add chosen - output['prompt'].append(data['prompt']) - output['chosen'].append(data['chosen']) - output['rejected'].append(data['rejected']) + output["prompt"].append(data["prompt"]) + output["chosen"].append(data["chosen"]) + output["rejected"].append(data["rejected"]) df = pd.DataFrame(output) - local_path = os.path.join(local_dir, name + '.parquet') + local_path = os.path.join(local_dir, name + ".parquet") df.to_parquet(path=local_path) if target_hdfs_path_dir is not None: - hdfs_dir = target_hdfs_path_dir + '/' + name + '.parquet' + hdfs_dir = target_hdfs_path_dir + "/" + name + ".parquet" makedirs(hdfs_dir) copy(local_path, hdfs_dir) -def generate_rl_dataset(target_hdfs_path_dir, local_dir='~/data/full_hh_rlhf/rl'): - dataset = load_dataset('Dahoas/full-hh-rlhf') - train_dataset = dataset['train'] +def generate_rl_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_rlhf/rl"): + dataset = load_dataset("Dahoas/full-hh-rlhf") + train_dataset = dataset["train"] - data_source = 'Dahoas/full-hh-rlhf' + data_source = "Dahoas/full-hh-rlhf" # add a row to each data item that represents a unique id def make_map_fn(split): - def process_fn(example, idx): - prompt = example.pop('prompt') - response = example.pop('response') + prompt = 
example.pop("prompt") + response = example.pop("response") data = { "data_source": data_source, - "prompt": [{ - "role": "user", - "content": prompt - }], + "prompt": [{"role": "user", "content": prompt}], "ability": "alignment", "reward_model": { "style": "model", - "ground_truth": response # should not be used + "ground_truth": response, # should not be used }, - "extra_info": { - 'split': split, - 'index': idx - } + "extra_info": {"split": split, "index": idx}, } return data return process_fn - train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True) + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) local_dir = os.path.expanduser(local_dir) - local_path = os.path.join(local_dir, 'train.parquet') + local_path = os.path.join(local_dir, "train.parquet") train_dataset.to_parquet(local_path) if target_hdfs_path_dir is not None: - hdfs_dir = target_hdfs_path_dir + '/' + 'train.parquet' + hdfs_dir = target_hdfs_path_dir + "/" + "train.parquet" makedirs(hdfs_dir) copy(local_path, hdfs_dir) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--split', type=str, choices=['sft', 'rm', 'rl'], required=True) - parser.add_argument('--local_dir', type=str, default='~/data/full_hh_rlhf') - parser.add_argument('--hdfs_dir', type=str, required=False, default=None) + parser.add_argument("--split", type=str, choices=["sft", "rm", "rl"], required=True) + parser.add_argument("--local_dir", type=str, default="~/data/full_hh_rlhf") + parser.add_argument("--hdfs_dir", type=str, required=False, default=None) args = parser.parse_args() - if args.split == 'sft': + if args.split == "sft": generate_sft_dataset(args.hdfs_dir, os.path.join(args.local_dir, args.split)) - elif args.split == 'rm': + elif args.split == "rm": generate_rm_dataset(args.hdfs_dir, os.path.join(args.local_dir, args.split)) - elif args.split == 'rl': + elif args.split == "rl": generate_rl_dataset(args.hdfs_dir, os.path.join(args.local_dir, args.split)) else: raise NotImplementedError diff --git a/examples/data_preprocess/geo3k.py b/examples/data_preprocess/geo3k.py index 3a0c77673..eb6a388fe 100644 --- a/examples/data_preprocess/geo3k.py +++ b/examples/data_preprocess/geo3k.py @@ -15,71 +15,70 @@ Preprocess the Geometry3k dataset to parquet format """ +import argparse import os + import datasets from verl.utils.hdfs_io import copy, makedirs -import argparse -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--local_dir', default='~/data/geo3k') - parser.add_argument('--hdfs_dir', default=None) + parser.add_argument("--local_dir", default="~/data/geo3k") + parser.add_argument("--hdfs_dir", default=None) args = parser.parse_args() - data_source = 'hiyouga/geometry3k' + data_source = "hiyouga/geometry3k" dataset = datasets.load_dataset(data_source) - train_dataset = dataset['train'] - test_dataset = dataset['test'] + train_dataset = dataset["train"] + test_dataset = dataset["test"] instruction_following = ( - r'You FIRST think about the reasoning process as an internal monologue and then provide the final answer. ' - r'The reasoning process MUST BE enclosed within tags. The final answer MUST BE put in \boxed{}.' + r"You FIRST think about the reasoning process as an internal monologue and then provide the final answer. " + r"The reasoning process MUST BE enclosed within tags. The final answer MUST BE put in \boxed{}." 
) # add a row to each data item that represents a unique id def make_map_fn(split): - def process_fn(example, idx): - problem = example.pop('problem') - prompt = problem + ' ' + instruction_following - answer = example.pop('answer') - images = example.pop('images') + problem = example.pop("problem") + prompt = problem + " " + instruction_following + answer = example.pop("answer") + images = example.pop("images") data = { "data_source": data_source, - "prompt": [{ - "role": "user", - "content": prompt, - }], + "prompt": [ + { + "role": "user", + "content": prompt, + } + ], "images": images, "ability": "math", - "reward_model": { - "style": "rule", - "ground_truth": answer - }, + "reward_model": {"style": "rule", "ground_truth": answer}, "extra_info": { - 'split': split, - 'index': idx, - 'answer': answer, + "split": split, + "index": idx, + "answer": answer, "question": problem, - } + }, } return data return process_fn - train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True, num_proc=8) - test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True, num_proc=8) + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True, num_proc=8) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True, num_proc=8) local_dir = args.local_dir hdfs_dir = args.hdfs_dir - train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet')) - test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet')) + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) if hdfs_dir is not None: makedirs(hdfs_dir) diff --git a/examples/data_preprocess/gsm8k.py b/examples/data_preprocess/gsm8k.py index b82d7d71a..f39c4f09e 100644 --- a/examples/data_preprocess/gsm8k.py +++ b/examples/data_preprocess/gsm8k.py @@ -15,78 +15,77 @@ Preprocess the GSM8k dataset to parquet format """ -import re +import argparse import os +import re + import datasets from verl.utils.hdfs_io import copy, makedirs -import argparse def extract_solution(solution_str): solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) assert solution is not None final_solution = solution.group(0) - final_solution = final_solution.split('#### ')[1].replace(',', '') + final_solution = final_solution.split("#### ")[1].replace(",", "") return final_solution -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--local_dir', default='~/data/gsm8k') - parser.add_argument('--hdfs_dir', default=None) + parser.add_argument("--local_dir", default="~/data/gsm8k") + parser.add_argument("--hdfs_dir", default=None) args = parser.parse_args() - data_source = 'openai/gsm8k' + data_source = "openai/gsm8k" - dataset = datasets.load_dataset(data_source, 'main') + dataset = datasets.load_dataset(data_source, "main") - train_dataset = dataset['train'] - test_dataset = dataset['test'] + train_dataset = dataset["train"] + test_dataset = dataset["test"] - instruction_following = "Let's think step by step and output the final answer after \"####\"." + instruction_following = 'Let\'s think step by step and output the final answer after "####".' 
# add a row to each data item that represents a unique id def make_map_fn(split): - def process_fn(example, idx): - question_raw = example.pop('question') + question_raw = example.pop("question") - question = question_raw + ' ' + instruction_following + question = question_raw + " " + instruction_following - answer_raw = example.pop('answer') + answer_raw = example.pop("answer") solution = extract_solution(answer_raw) data = { "data_source": data_source, - "prompt": [{ - "role": "user", - "content": question, - }], + "prompt": [ + { + "role": "user", + "content": question, + } + ], "ability": "math", - "reward_model": { - "style": "rule", - "ground_truth": solution - }, + "reward_model": {"style": "rule", "ground_truth": solution}, "extra_info": { - 'split': split, - 'index': idx, - 'answer': answer_raw, + "split": split, + "index": idx, + "answer": answer_raw, "question": question_raw, - } + }, } return data return process_fn - train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True) - test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True) + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) local_dir = args.local_dir hdfs_dir = args.hdfs_dir - train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet')) - test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet')) + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) if hdfs_dir is not None: makedirs(hdfs_dir) diff --git a/examples/data_preprocess/hellaswag.py b/examples/data_preprocess/hellaswag.py index 39c8e8f55..1b3f20080 100644 --- a/examples/data_preprocess/hellaswag.py +++ b/examples/data_preprocess/hellaswag.py @@ -16,12 +16,13 @@ Preprocess Hellaswag dataset. 
""" -import re +import argparse import os +import re + import datasets from verl.utils.hdfs_io import copy, makedirs -import argparse def preprocess(text): @@ -33,25 +34,24 @@ def preprocess(text): return text -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--local_dir', default='/opt/tiger/hellaswag') - parser.add_argument('--hdfs_dir', default=None) + parser.add_argument("--local_dir", default="/opt/tiger/hellaswag") + parser.add_argument("--hdfs_dir", default=None) args = parser.parse_args() - data_source = 'Rowan/hellaswag' + data_source = "Rowan/hellaswag" dataset = datasets.load_dataset(data_source, trust_remote_code=True) - train_dataset = dataset['train'] - val_dataset = dataset['validation'] - test_dataset = dataset['test'] + train_dataset = dataset["train"] + val_dataset = dataset["validation"] + test_dataset = dataset["test"] - instruction = 'Please complete the following sentence.\n' + instruction = "Please complete the following sentence.\n" def make_map_fn(split): - def process_fn(doc, idx): ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize() query = preprocess(doc["activity_label"] + ": " + ctx) @@ -60,41 +60,35 @@ if __name__ == '__main__': data = { "data_source": data_source, - "prompt": [{ - "role": "user", - "content": query - }], + "prompt": [{"role": "user", "content": query}], "ability": "nlp", "reward_model": { "style": "model", "eval": "multiple_choice", # using loglikelihood "ground_truth": gold, - "choices": choices + "choices": choices, }, - "extra_info": { - 'split': split, - 'index': idx - } + "extra_info": {"split": split, "index": idx}, } return data return process_fn # filter data that doesn't have a label - train_dataset = train_dataset.filter(lambda x: len(x['label']) > 0) - val_dataset = val_dataset.filter(lambda x: len(x['label']) > 0) - test_dataset = test_dataset.filter(lambda x: len(x['label']) > 0) + train_dataset = train_dataset.filter(lambda x: len(x["label"]) > 0) + val_dataset = val_dataset.filter(lambda x: len(x["label"]) > 0) + test_dataset = test_dataset.filter(lambda x: len(x["label"]) > 0) - train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True) - val_dataset = val_dataset.map(function=make_map_fn('validation'), with_indices=True) - test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True) + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + val_dataset = val_dataset.map(function=make_map_fn("validation"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) local_dir = args.local_dir hdfs_dir = args.hdfs_dir - train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet')) - val_dataset.to_parquet(os.path.join(local_dir, 'validation.parquet')) - test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet')) + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + val_dataset.to_parquet(os.path.join(local_dir, "validation.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) if hdfs_dir is not None: makedirs(hdfs_dir) diff --git a/examples/data_preprocess/math_dataset.py b/examples/data_preprocess/math_dataset.py index 632f28801..e2e5d3524 100644 --- a/examples/data_preprocess/math_dataset.py +++ b/examples/data_preprocess/math_dataset.py @@ -15,75 +15,65 @@ Preprocess the MATH-lighteval dataset to parquet format """ +import argparse import os + import datasets from verl.utils.hdfs_io import copy, 
makedirs -import argparse - -from verl.utils.reward_score.math import remove_boxed, last_boxed_only_string +from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed def extract_solution(solution_str): return remove_boxed(last_boxed_only_string(solution_str)) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--local_dir', default='~/data/math') - parser.add_argument('--hdfs_dir', default=None) + parser.add_argument("--local_dir", default="~/data/math") + parser.add_argument("--hdfs_dir", default=None) args = parser.parse_args() # 'lighteval/MATH' is no longer available on huggingface. # Use mirror repo: DigitalLearningGmbH/MATH-lighteval - data_source = 'DigitalLearningGmbH/MATH-lighteval' + data_source = "DigitalLearningGmbH/MATH-lighteval" print(f"Loading the {data_source} dataset from huggingface...", flush=True) dataset = datasets.load_dataset(data_source, trust_remote_code=True) - train_dataset = dataset['train'] - test_dataset = dataset['test'] + train_dataset = dataset["train"] + test_dataset = dataset["test"] instruction_following = "Let's think step by step and output the final answer within \\boxed{}." # add a row to each data item that represents a unique id def make_map_fn(split): - def process_fn(example, idx): - question = example.pop('problem') + question = example.pop("problem") - question = question + ' ' + instruction_following + question = question + " " + instruction_following - answer = example.pop('solution') + answer = example.pop("solution") solution = extract_solution(answer) data = { "data_source": data_source, - "prompt": [{ - "role": "user", - "content": question - }], + "prompt": [{"role": "user", "content": question}], "ability": "math", - "reward_model": { - "style": "rule", - "ground_truth": solution - }, - "extra_info": { - 'split': split, - 'index': idx - } + "reward_model": {"style": "rule", "ground_truth": solution}, + "extra_info": {"split": split, "index": idx}, } return data return process_fn - train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True) - test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True) + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) local_dir = args.local_dir hdfs_dir = args.hdfs_dir - train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet')) - test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet')) + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) if hdfs_dir is not None: makedirs(hdfs_dir) diff --git a/examples/data_preprocess/multiturn.py b/examples/data_preprocess/multiturn.py index 98407bec2..f312cfbf4 100644 --- a/examples/data_preprocess/multiturn.py +++ b/examples/data_preprocess/multiturn.py @@ -15,87 +15,71 @@ Create a simple multi-turn dataset for testing """ -import os -import pandas as pd import argparse +import os + +import pandas as pd def main(): parser = argparse.ArgumentParser() - parser.add_argument('--local_dir', default='~/data/multiturn') - parser.add_argument('--hdfs_dir', default=None) + parser.add_argument("--local_dir", default="~/data/multiturn") + parser.add_argument("--hdfs_dir", default=None) args = parser.parse_args() # Create example conversations conversations = [] # Conversation 1 - conversations.append({ - "messages": [{ - "role": "system", 
- "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "What is the capital of France?" - }, { - "role": "assistant", - "content": "The capital of France is Paris." - }, { - "role": "user", - "content": "And what about Germany?" - }, { - "role": "assistant", - "content": "The capital of Germany is Berlin." - }] - }) + conversations.append( + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."}, + {"role": "user", "content": "And what about Germany?"}, + {"role": "assistant", "content": "The capital of Germany is Berlin."}, + ] + } + ) # Conversation 2 - conversations.append({ - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Can you explain quantum computing?" - }, { - "role": - "assistant", - "content": - "Quantum computing is a type of computing that uses quantum-mechanical phenomena, such as superposition and entanglement, to perform operations on data." - }, { - "role": "user", - "content": "How is it different from classical computing?" - }, { - "role": - "assistant", - "content": - "Classical computing uses bits that are either 0 or 1, while quantum computing uses quantum bits or qubits that can exist in multiple states simultaneously due to superposition." - }] - }) + conversations.append( + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Can you explain quantum computing?"}, + { + "role": "assistant", + "content": "Quantum computing is a type of computing that uses quantum-mechanical phenomena, such as superposition and entanglement, to perform operations on data.", + }, + {"role": "user", "content": "How is it different from classical computing?"}, + { + "role": "assistant", + "content": "Classical computing uses bits that are either 0 or 1, while quantum computing uses quantum bits or qubits that can exist in multiple states simultaneously due to superposition.", + }, + ] + } + ) # Conversation 3 - conversations.append({ - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Write a simple Python function to calculate factorial." - }, { - "role": - "assistant", - "content": - "```python\ndef factorial(n):\n if n == 0 or n == 1:\n return 1\n else:\n return n * factorial(n-1)\n```\n\nThis is a recursive function to calculate the factorial of a number." - }, { - "role": "user", - "content": "Can you make it iterative instead?" - }, { - "role": - "assistant", - "content": - "```python\ndef factorial(n):\n result = 1\n for i in range(1, n+1):\n result *= i\n return result\n```\n\nThis is an iterative version of the factorial function." 
- }] - }) + conversations.append( + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Write a simple Python function to calculate factorial."}, + { + "role": "assistant", + "content": "```python\ndef factorial(n):\n if n == 0 or n == 1:\n return 1\n else:\n return n * factorial(n-1)\n```\n\nThis is a recursive function to calculate the factorial of a number.", + }, + {"role": "user", "content": "Can you make it iterative instead?"}, + { + "role": "assistant", + "content": "```python\ndef factorial(n):\n result = 1\n for i in range(1, n+1):\n result *= i\n return result\n```\n\nThis is an iterative version of the factorial function.", + }, + ] + } + ) # Create train and test datasets train_data = conversations[:2] # First 2 conversations for training @@ -109,13 +93,14 @@ def main(): train_df = pd.DataFrame(train_data) test_df = pd.DataFrame(test_data) - train_df.to_parquet(os.path.join(local_dir, 'train.parquet')) - test_df.to_parquet(os.path.join(local_dir, 'test.parquet')) + train_df.to_parquet(os.path.join(local_dir, "train.parquet")) + test_df.to_parquet(os.path.join(local_dir, "test.parquet")) # Handle HDFS if specified if args.hdfs_dir is not None: try: from verl.utils.hdfs_io import copy, makedirs + makedirs(args.hdfs_dir) copy(src=local_dir, dst=args.hdfs_dir) except ImportError: @@ -127,5 +112,5 @@ def main(): print(f"Data saved to {local_dir}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/ppo_trainer/verl_getting_started.ipynb b/examples/ppo_trainer/verl_getting_started.ipynb index 73df079ea..297e6addf 100644 --- a/examples/ppo_trainer/verl_getting_started.ipynb +++ b/examples/ppo_trainer/verl_getting_started.ipynb @@ -313,6 +313,7 @@ "outputs": [], "source": [ "import torch\n", + "\n", "try:\n", " assert torch.cuda.is_available() is True\n", " torch.ones(1, dtype=torch.bfloat16).cuda()\n", @@ -320,12 +321,10 @@ " print(\"Please switch to an env with GPUs supporting bfloat16 (L4 RTX 5000, A5000, A100, H100, A10, etc)\")\n", "\n", "try:\n", - " import verl\n", + " pass\n", "except Exception as e:\n", " print(\"Please install verl via pip and restart the kernel\")\n", - " raise e\n", - "\n", - "import flash_attn" + " raise e" ] }, { @@ -560,6 +559,7 @@ ], "source": [ "import inspect\n", + "\n", "from verl.utils.reward_score.gsm8k import compute_score as gsm8k_reward\n", "\n", "print(inspect.getsource(gsm8k_reward))" diff --git a/examples/ray/tutorial.ipynb b/examples/ray/tutorial.ipynb index 37784f6f9..58c9a3302 100644 --- a/examples/ray/tutorial.ipynb +++ b/examples/ray/tutorial.ipynb @@ -37,10 +37,12 @@ }, "outputs": [], "source": [ + "import warnings\n", + "\n", "import ray\n", "import torch\n", - "import warnings\n", - "warnings.filterwarnings('ignore')" + "\n", + "warnings.filterwarnings(\"ignore\")" ] }, { @@ -146,10 +148,10 @@ "class Accumulator:\n", " def __init__(self):\n", " self.value = 0\n", - " \n", + "\n", " def add(self, x):\n", " self.value += x\n", - " \n", + "\n", " def get_value(self):\n", " return self.value" ] @@ -184,7 +186,7 @@ } ], "source": [ - "value_ref = accumulator.get_value.remote() # Check the current value. Note that this function returns immediately and does not actually wait for the remote execution to complete.\n", + "value_ref = accumulator.get_value.remote() # Check the current value. 
Note that this function returns immediately and does not actually wait for the remote execution to complete.\n", "# Get the value\n", "value = ray.get(value_ref)\n", "print(value)" @@ -232,8 +234,8 @@ }, "outputs": [], "source": [ - "from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool\n", - "from verl.single_controller.base import Worker" + "from verl.single_controller.base import Worker\n", + "from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool" ] }, { @@ -259,16 +261,15 @@ "source": [ "@ray.remote\n", "class GPUAccumulator(Worker):\n", - "\n", " def __init__(self) -> None:\n", " super().__init__()\n", " # The initial value of each rank is the same as the rank\n", - " self.value = torch.zeros(size=(1,), device='cuda') + self.rank\n", + " self.value = torch.zeros(size=(1,), device=\"cuda\") + self.rank\n", "\n", " def add(self, x):\n", " self.value += x\n", - " print(f'rank {self.rank}, value: {self.value}')\n", - " return self.value.cpu()\n" + " print(f\"rank {self.rank}, value: {self.value}\")\n", + " return self.value.cpu()" ] }, { @@ -291,7 +292,7 @@ "# Each worker's initial value is its rank, and then each rank's value is incremented by 1, so the values obtained on each rank are [1, 2, 3, 4]\n", "class_with_args = RayClassWithInitArgs(cls=GPUAccumulator)\n", "worker_group = RayWorkerGroup(resource_pool, class_with_args)\n", - "print(worker_group.execute_all_sync('add', x=[1,1,1,1]))" + "print(worker_group.execute_all_sync(\"add\", x=[1, 1, 1, 1]))" ] }, { @@ -329,7 +330,7 @@ "outputs": [], "source": [ "# Create a new resource pool and then merge the newly created resource pool with the previous one.\n", - "resource_pool_1 = RayResourcePool([4], use_gpu=True, name_prefix='a')\n", + "resource_pool_1 = RayResourcePool([4], use_gpu=True, name_prefix=\"a\")\n", "resource_pool_merge = merge_resource_pool(resource_pool, resource_pool_1)" ] }, @@ -365,7 +366,7 @@ ], "source": [ "# Run 'add' on the second set of 4 GPUs; the result should be [2, 3, 4, 5].\n", - "output_1 = worker_group_1.execute_all_sync('add', x=[2,2,2,2])\n", + "output_1 = worker_group_1.execute_all_sync(\"add\", x=[2, 2, 2, 2])\n", "print(output_1)" ] }, @@ -387,7 +388,7 @@ ], "source": [ "# Run 'add' on the merged set of 8 GPUs; the result should be [3, 4, 5, 6, 7, 8, 9, 10].\n", - "output_merge = worker_group_merge.execute_all_sync('add', x=[3,3,3,3,3,3,3,3])\n", + "output_merge = worker_group_merge.execute_all_sync(\"add\", x=[3, 3, 3, 3, 3, 3, 3, 3])\n", "print(output_merge)" ] }, @@ -437,7 +438,7 @@ }, "outputs": [], "source": [ - "from verl.single_controller.base.decorator import register, Dispatch, Execute" + "from verl.single_controller.base.decorator import Dispatch, Execute, register" ] }, { @@ -451,18 +452,17 @@ "source": [ "@ray.remote\n", "class GPUAccumulatorDecorator(Worker):\n", - "\n", " def __init__(self) -> None:\n", " super().__init__()\n", " # The initial value of each rank is the same as the rank\n", - " self.value = torch.zeros(size=(1,), device='cuda') + self.rank\n", - " \n", + " self.value = torch.zeros(size=(1,), device=\"cuda\") + self.rank\n", + "\n", " # map from a single input to all the worker\n", " @register(Dispatch.ONE_TO_ALL)\n", " def add(self, x):\n", " print(x)\n", " self.value = self.value + x\n", - " print(f'rank {self.rank}, value: {self.value}')\n", + " print(f\"rank {self.rank}, value: {self.value}\")\n", " return self.value.cpu()" ] }, @@ -518,7 
+518,7 @@ }, "outputs": [], "source": [ - "from verl.single_controller.base.decorator import register, Dispatch, collect_all_to_all, Execute" + "from verl.single_controller.base.decorator import Dispatch, collect_all_to_all, register" ] }, { @@ -559,7 +559,7 @@ " def foo_rank_zero(self, x, y):\n", " return self._x + y + x\n", "\n", - " @register(dispatch_mode={'dispatch_fn': two_to_all_dispatch_fn, 'collect_fn': collect_all_to_all})\n", + " @register(dispatch_mode={\"dispatch_fn\": two_to_all_dispatch_fn, \"collect_fn\": collect_all_to_all})\n", " def foo_custom(self, x, y):\n", " return self._x + y + x" ] @@ -691,26 +691,24 @@ } ], "source": [ - "import os\n", "import sys\n", - "import site\n", "\n", + "current_pythonpath = os.environ.get(\"PYTHONPATH\", \"\")\n", "\n", - "current_pythonpath = os.environ.get('PYTHONPATH', '')\n", - "\n", - "new_path = '/opt/tiger/Megatron-LM'\n", + "new_path = \"/opt/tiger/Megatron-LM\"\n", "\n", "if current_pythonpath:\n", - " new_pythonpath = f'{new_path}:{current_pythonpath}'\n", + " new_pythonpath = f\"{new_path}:{current_pythonpath}\"\n", "else:\n", " new_pythonpath = new_path\n", "\n", - "os.environ['PYTHONPATH'] = new_pythonpath\n", + "os.environ[\"PYTHONPATH\"] = new_pythonpath\n", "\n", "print(new_path)\n", "sys.path.append(new_path)\n", "\n", "import megatron\n", + "\n", "print(megatron.__file__)" ] }, @@ -723,12 +721,13 @@ }, "outputs": [], "source": [ - "from verl.single_controller.base.decorator import register, Dispatch, Execute\n", - "from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n", - "from verl.single_controller.base.megatron.worker import MegatronWorker\n", - "from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup\n", + "from megatron.core import parallel_state as mpu\n", "from omegaconf import OmegaConf\n", - "from megatron.core import parallel_state as mpu" + "\n", + "from verl.single_controller.base.decorator import Dispatch, Execute, register\n", + "from verl.single_controller.base.megatron.worker import MegatronWorker\n", + "from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n", + "from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup" ] }, { @@ -756,52 +755,56 @@ "class MLPLayerWorker(MegatronWorker):\n", " def __init__(self):\n", " super().__init__()\n", - " rank = int(os.environ['LOCAL_RANK'])\n", + " rank = int(os.environ[\"LOCAL_RANK\"])\n", " torch.distributed.init_process_group(backend=\"nccl\")\n", " torch.cuda.set_device(rank)\n", "\n", " mpu.initialize_model_parallel(\n", - " tensor_model_parallel_size=4,\n", - " pipeline_model_parallel_size=1,\n", - " virtual_pipeline_model_parallel_size=None,\n", - " pipeline_model_parallel_split_rank=None,\n", - " use_sharp=False,\n", - " context_parallel_size=1,\n", - " expert_model_parallel_size=1,\n", - " nccl_communicator_config_path=None,\n", - " )\n", + " tensor_model_parallel_size=4,\n", + " pipeline_model_parallel_size=1,\n", + " virtual_pipeline_model_parallel_size=None,\n", + " pipeline_model_parallel_split_rank=None,\n", + " use_sharp=False,\n", + " context_parallel_size=1,\n", + " expert_model_parallel_size=1,\n", + " nccl_communicator_config_path=None,\n", + " )\n", " from megatron.core import tensor_parallel\n", - " tensor_parallel.model_parallel_cuda_manual_seed(10)\n", "\n", + " tensor_parallel.model_parallel_cuda_manual_seed(10)\n", "\n", " @register(Dispatch.ONE_TO_ALL)\n", " def init_model(self, config):\n", " from omegaconf import 
OmegaConf\n", - " from verl.utils.megatron_utils import init_model_parallel_config\n", + "\n", " from verl.models.llama.megatron.layers import ParallelLlamaMLP\n", - " megatron_config = OmegaConf.create({\n", - " 'sequence_parallel': False,\n", - " 'param_dtype': 'fp32',\n", - " 'tensor_model_parallel_size': mpu.get_tensor_model_parallel_world_size(),\n", - " 'pipeline_model_parallel_rank': mpu.get_pipeline_model_parallel_rank(),\n", - " 'pipeline_model_parallel_size': mpu.get_pipeline_model_parallel_world_size(),\n", - " 'virtual_pipeline_model_parallel_rank': mpu.get_virtual_pipeline_model_parallel_rank(),\n", - " 'virtual_pipeline_model_parallel_size': mpu.get_virtual_pipeline_model_parallel_world_size()\n", - " })\n", + " from verl.utils.megatron_utils import init_model_parallel_config\n", + "\n", + " megatron_config = OmegaConf.create(\n", + " {\n", + " \"sequence_parallel\": False,\n", + " \"param_dtype\": \"fp32\",\n", + " \"tensor_model_parallel_size\": mpu.get_tensor_model_parallel_world_size(),\n", + " \"pipeline_model_parallel_rank\": mpu.get_pipeline_model_parallel_rank(),\n", + " \"pipeline_model_parallel_size\": mpu.get_pipeline_model_parallel_world_size(),\n", + " \"virtual_pipeline_model_parallel_rank\": mpu.get_virtual_pipeline_model_parallel_rank(),\n", + " \"virtual_pipeline_model_parallel_size\": mpu.get_virtual_pipeline_model_parallel_world_size(),\n", + " }\n", + " )\n", "\n", " megatron_config = init_model_parallel_config(megatron_config)\n", " self.parallel_layer = ParallelLlamaMLP(config=config, megatron_config=megatron_config)\n", - " \n", + "\n", " @register(Dispatch.ONE_TO_ALL)\n", " def get_weights(self):\n", " output = {}\n", " for key, val in self.parallel_layer.named_parameters():\n", " output[key] = val\n", " return output\n", - " \n", + "\n", " @register(Dispatch.MEGATRON_COMPUTE)\n", " def run_layer(self, x):\n", - " x = x.to('cuda')\n", + " x = x.to(\"cuda\")\n", " y = self.parallel_layer(x)\n", " return y" ] @@ -816,9 +819,10 @@ "outputs": [], "source": [ "layer_cls = RayClassWithInitArgs(cls=MLPLayerWorker)\n", - "layer_worker_group = NVMegatronRayWorkerGroup(resource_pool=resource_pool,\n", - " ray_cls_with_init=layer_cls,\n", - " )\n" + "layer_worker_group = NVMegatronRayWorkerGroup(\n", + " resource_pool=resource_pool,\n", + " ray_cls_with_init=layer_cls,\n", + ")" ] }, { @@ -855,13 +859,15 @@ "seq_len = 2048\n", "hidden_size = 4096\n", "\n", - "config = OmegaConf.create({\n", - " 'hidden_size': hidden_size,\n", - " 'intermediate_size': ffn_hidden_size,\n", - " 'hidden_act': 'silu',\n", - " 'pretraining_tp': 1,\n", - " 'tp': layer_worker_group.tp_size,\n", - "})" + "config = OmegaConf.create(\n", + " {\n", + " \"hidden_size\": hidden_size,\n", + " \"intermediate_size\": ffn_hidden_size,\n", + " \"hidden_act\": \"silu\",\n", + " \"pretraining_tp\": 1,\n", + " \"tp\": layer_worker_group.tp_size,\n", + " }\n", + ")" ] }, { @@ -916,7 +922,9 @@ } ], "source": [ - "output = layer_worker_group.run_layer([x]) # This must be a list of size 1, ensuring that the input equals the data parallel (dp).\n", + "output = layer_worker_group.run_layer(\n", + " [x]\n", + ") # This must be a list of size 1, ensuring that the input equals the data parallel (dp).\n", "print(output[0].shape)" ] }, diff --git a/examples/split_placement/main_ppo_split.py b/examples/split_placement/main_ppo_split.py index 85347be09..44ae088d4 100644 --- a/examples/split_placement/main_ppo_split.py +++ b/examples/split_placement/main_ppo_split.py @@ -15,23 +15,23 @@ Note that we don't combine 
the main with ray_trainer as ray_trainer is used by other main. """ -from verl import DataProto import torch -from verl.utils.reward_score import gsm8k, math + +from verl import DataProto from verl.trainer.ppo.ray_trainer import RayPPOTrainer +from verl.utils.reward_score import gsm8k, math def _select_rm_score_fn(data_source): - if data_source == 'openai/gsm8k': + if data_source == "openai/gsm8k": return gsm8k.compute_score - elif data_source == 'lighteval/MATH': + elif data_source == "lighteval/MATH": return math.compute_score else: raise NotImplementedError -class RewardManager(): - +class RewardManager: def __init__(self, tokenizer, num_examine) -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console @@ -40,35 +40,35 @@ class RewardManager(): """We will expand this function gradually based on the available datasets""" # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn - if 'rm_scores' in data.batch.keys(): - return data.batch['rm_scores'] + if "rm_scores" in data.batch.keys(): + return data.batch["rm_scores"] - reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) already_print_data_sources = {} for i in range(len(data)): data_item = data[i] # DataProtoItem - prompt_ids = data_item.batch['prompts'] + prompt_ids = data_item.batch["prompts"] prompt_length = prompt_ids.shape[-1] - valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum() + valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() valid_prompt_ids = prompt_ids[-valid_prompt_length:] - response_ids = data_item.batch['responses'] - valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum() + response_ids = data_item.batch["responses"] + valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() valid_response_ids = response_ids[:valid_response_length] # decode sequences = torch.cat((valid_prompt_ids, valid_response_ids)) sequences_str = self.tokenizer.decode(sequences) - ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] + ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] # select rm_score - data_source = data_item.non_tensor_batch['data_source'] + data_source = data_item.non_tensor_batch["data_source"] compute_score_fn = _select_rm_score_fn(data_source) score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth) @@ -87,28 +87,29 @@ class RewardManager(): return reward_tensor -import ray import hydra +import ray from split_monkey_patch import fit -@hydra.main(config_path='config', config_name='ppo_trainer_split', version_base=None) +@hydra.main(config_path="config", config_name="ppo_trainer_split", version_base=None) def main(config): if not ray.is_initialized(): # this is for local ray cluster - ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}}) + ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}) ray.get(main_task.remote(config)) @ray.remote def main_task(config): - from verl.utils.fs import copy_to_local - from transformers import AutoTokenizer - # print initial config from pprint import pprint + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values 
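The RewardManager above slices the valid tokens out of padded prompt and response tensors using the attention mask. A minimal sketch of that convention, with hypothetical toy tensors that are not taken from the repo: prompts are left-padded and responses are right-padded, so the valid prompt tokens sit at the end of the prompt slice and the valid response tokens at the start of the response slice.

import torch

prompt_ids = torch.tensor([0, 0, 11, 12, 13])    # left-padded prompt (pad id 0)
response_ids = torch.tensor([21, 22, 0, 0])      # right-padded response
attention_mask = torch.tensor([0, 0, 1, 1, 1, 1, 1, 0, 0])  # covers prompt + response

prompt_length = prompt_ids.shape[-1]
valid_prompt_length = attention_mask[:prompt_length].sum()
valid_response_length = attention_mask[prompt_length:].sum()

valid_prompt_ids = prompt_ids[-valid_prompt_length:]       # tensor([11, 12, 13])
valid_response_ids = response_ids[:valid_response_length]  # tensor([21, 22])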
OmegaConf.resolve(config) @@ -117,19 +118,22 @@ def main_task(config): # instantiate tokenizer from verl.utils import hf_tokenizer + tokenizer = hf_tokenizer(local_path) # define worker classes - if config.actor_rollout_ref.actor.strategy == 'fsdp': + if config.actor_rollout_ref.actor.strategy == "fsdp": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + ray_worker_group_cls = RayWorkerGroup - elif config.actor_rollout_ref.actor.strategy == 'megatron': + elif config.actor_rollout_ref.actor.strategy == "megatron": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + ray_worker_group_cls = NVMegatronRayWorkerGroup else: @@ -143,8 +147,8 @@ def main_task(config): } # NOTE: initialze two resource pool - actor_rollout_ref_pool_id = 'actor_rollout_ref_pool' - critic_pool_id = 'critic_pool' + actor_rollout_ref_pool_id = "actor_rollout_ref_pool" + critic_pool_id = "critic_pool" if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0: resource_pool_spec = { actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, @@ -155,13 +159,13 @@ def main_task(config): actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), } - print(f'resource_pool_spec: {resource_pool_spec}') + print(f"resource_pool_spec: {resource_pool_spec}") mapping = { Role.ActorRollout: actor_rollout_ref_pool_id, Role.Critic: critic_pool_id, } - #use reference model + # use reference model if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) mapping[Role.RefPolicy] = actor_rollout_ref_pool_id @@ -173,9 +177,9 @@ def main_task(config): # - finally, we combine all the rewards together # - The reward type depends on the tag of the data if config.reward_model.enable: - if config.reward_model.strategy == 'fsdp': + if config.reward_model.strategy == "fsdp": from verl.workers.fsdp_workers import RewardModelWorker - elif config.reward_model.strategy == 'megatron': + elif config.reward_model.strategy == "megatron": from verl.workers.megatron_workers import RewardModelWorker else: raise NotImplementedError @@ -190,16 +194,18 @@ def main_task(config): resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) RayPPOTrainer.fit = fit - trainer = RayPPOTrainer(config=config, - tokenizer=tokenizer, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn) + trainer = RayPPOTrainer( + config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + ) trainer.init_workers() trainer.fit() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git 
a/examples/split_placement/split_monkey_patch.py b/examples/split_placement/split_monkey_patch.py index c6826771f..fa6eda0b9 100644 --- a/examples/split_placement/split_monkey_patch.py +++ b/examples/split_placement/split_monkey_patch.py @@ -14,13 +14,24 @@ """ An naive implementation of split placment example """ -from pprint import pprint -from verl import DataProto -from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, _timer, compute_timing_metrics, AdvantageEstimator + +import uuid from copy import deepcopy +from pprint import pprint + import numpy as np import torch -import uuid + +from verl import DataProto +from verl.trainer.ppo.ray_trainer import ( + AdvantageEstimator, + _timer, + apply_kl_penalty, + compute_advantage, + compute_data_metrics, + compute_timing_metrics, + reduce_metrics, +) def fit(self): @@ -29,13 +40,16 @@ def fit(self): The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. The light-weight advantage computation is done on the driver process. """ - from verl.utils.tracking import Tracking from omegaconf import OmegaConf - logger = Tracking(project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True)) + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) self.global_steps = 0 @@ -44,11 +58,11 @@ def fit(self): # perform validation before training # currently, we only support validation using the reward_function. 
- if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True): + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): val_metrics = self._validate() - pprint(f'Initial validation metrics: {val_metrics}') + pprint(f"Initial validation metrics: {val_metrics}") logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get('val_only', False): + if self.config.trainer.get("val_only", False): return # we start from step 1 @@ -63,18 +77,18 @@ def fit(self): batch: DataProto = DataProto.from_single_dict(batch_dict) # pop those keys for generation - gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']) + gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"]) is_last_step = self.global_steps >= self.total_training_steps - with _timer('step', timing_raw): + with _timer("step", timing_raw): # generate a batch - with _timer('gen', timing_raw): + with _timer("gen", timing_raw): gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - with _timer('gen_max', timing_raw): + with _timer("gen_max", timing_raw): gen_baseline_batch = deepcopy(gen_batch) - gen_baseline_batch.meta_info['do_sample'] = False + gen_baseline_batch.meta_info["do_sample"] = False gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) batch = batch.union(gen_baseline_output) @@ -83,12 +97,13 @@ def fit(self): batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - batch.batch['reward_baselines'] = reward_baseline_tensor + batch.batch["reward_baselines"] = reward_baseline_tensor del gen_baseline_batch, gen_baseline_output - batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], - dtype=object) + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) # repeat to align with repeated responses in rollout batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) @@ -99,26 +114,26 @@ def fit(self): self._balance_batch(batch, metrics=metrics) # compute global_valid tokens - batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() # recompute old_log_probs - with _timer('old_log_prob', timing_raw): + with _timer("old_log_prob", timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) batch = batch.union(old_log_prob) if self.use_reference_policy: # compute reference log_prob - with _timer('ref', timing_raw): + with _timer("ref", timing_raw): ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) # compute values if self.use_critic: - with _timer('values', timing_raw): + with _timer("values", timing_raw): values = self.critic_wg.compute_values(batch) batch = batch.union(values) - with _timer('adv', timing_raw): + with _timer("adv", timing_raw): # compute scores. Support both model and function-based. # We first compute the scores using reward model. Then, we call reward_fn to combine # the results from reward model and rule-based results. 
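The patched fit() above wraps every stage in _timer(name, timing_raw) so per-stage wall-clock times accumulate in a shared dict that later feeds compute_timing_metrics. The real helper lives in verl.trainer.ppo.ray_trainer and is not shown in this patch; the following is only a minimal sketch of the pattern, under the assumption that it records elapsed seconds keyed by stage name.

import time
from contextlib import contextmanager

@contextmanager
def _timer(name: str, timing_raw: dict):
    # accumulate elapsed wall-clock seconds for this stage under `name`
    start = time.perf_counter()
    try:
        yield
    finally:
        timing_raw[name] = timing_raw.get(name, 0.0) + time.perf_counter() - start

timing_raw = {}
with _timer("gen", timing_raw):
    time.sleep(0.01)  # stand-in for actor_rollout_wg.generate_sequences(gen_batch)
print(timing_raw)  # e.g. {'gen': 0.010...}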
@@ -129,57 +144,63 @@ def fit(self): # we combine with rule-based rm reward_tensor = self.reward_fn(batch) - batch.batch['token_level_scores'] = reward_tensor + batch.batch["token_level_scores"] = reward_tensor # compute rewards. apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: - batch, kl_metrics = apply_kl_penalty(batch, - kl_ctrl=self.kl_ctrl_in_reward, - kl_penalty=self.config.algorithm.kl_penalty) + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) metrics.update(kl_metrics) else: - batch.batch['token_level_rewards'] = batch.batch['token_level_scores'] + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] # compute advantages, executed on the driver process - batch = compute_advantage(batch, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n) + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + ) # update critic if self.use_critic: - with _timer('update_critic_call', timing_raw): + with _timer("update_critic_call", timing_raw): critic_output = self.critic_wg.update_critic(batch) # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: # update actor - with _timer('update_actor_call', timing_raw): + with _timer("update_actor_call", timing_raw): actor_output = self.actor_rollout_wg.update_actor(batch) # NOTE: make sure you set blocking=False in update_actor and update_crtic in the worker class - with _timer('update_actor_critic', timing_raw): + with _timer("update_actor_critic", timing_raw): critic_output = critic_output.get() - critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) metrics.update(critic_output_metrics) actor_output = actor_output.get() - actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) # validate - if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \ - (is_last_step or self.global_steps % self.config.trainer.test_freq == 0): - with _timer('testing', timing_raw): + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with _timer("testing", timing_raw): val_metrics: dict = self._validate() if is_last_step: last_val_metrics = val_metrics metrics.update(val_metrics) - if self.config.trainer.save_freq > 0 and (is_last_step or \ - self.global_steps % self.config.trainer.save_freq == 0): - with _timer('save_checkpoint', timing_raw): + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with _timer("save_checkpoint", timing_raw): self._save_checkpoint() # collect metrics @@ -190,7 +211,7 @@ def fit(self): logger.log(data=metrics, step=self.global_steps) if self.global_steps >= self.total_training_steps: - pprint(f'Final validation metrics: {last_val_metrics}') + pprint(f"Final validation metrics: {last_val_metrics}") return self.global_steps += 1 diff --git a/pyproject.toml b/pyproject.toml index f45c177e7..6e2e89f8b 
100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,45 @@ license = {file = "LICENSE"} # or "Apache-2.0", if you prefer an SPDX identifie readme = {file = "README.md", content-type = "text/markdown"} requires-python = ">=3.8" +# ------------------------------- +# tool.ruff - Linting configuration +# ------------------------------- +[tool.ruff] +line-length = 120 + +# Enable import sorting + +[tool.ruff.lint] +isort = {known-first-party = ["verl"]} +# c.f. https://github.com/vllm-project/vllm/blob/ce8d6b75fc0586045df75ee1568a5b5f9957251b/pyproject.toml +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + "G", +] +ignore = [ + # star imports + "F405", "F403", + # lambda expression assignment + "E731", + # Loop control variable not used within loop body + "B007", + # f-string format + "UP032", + # Can remove once 3.10+ is the minimum Python version + "UP007", +] + # ------------------------------- # tool.setuptools - Additional config # ------------------------------- @@ -45,117 +84,4 @@ version = {file = "verl/version/version"} verl = [ "version/*", "trainer/config/*.yaml" -] - - -[tool.pylint.message_control] -disable = [ - "abstract-method", - "anomalous-backslash-in-string", - "arguments-differ", - "arguments-renamed", - "assignment-from-none", - "attribute-defined-outside-init", - "bad-str-strip-call", - "bare-except", - "broad-exception-caught", - "broad-exception-raised", - "cell-var-from-loop", - "chained-comparison", - "consider-iterating-dictionary", - "consider-using-enumerate", - "consider-using-f-string", - "consider-using-from-import", - "consider-using-generator", - "consider-using-in", - "consider-using-max-builtin", - "consider-using-set-comprehension", - "consider-using-sys-exit", - "consider-using-with", - "cyclic-import", - "dangerous-default-value", - "duplicate-code", - "eval-used", - "expression-not-assigned", - "f-string-without-interpolation", - "fixme", - "function-redefined", - "global-statement", - "global-variable-not-assigned", - "import-error", - "import-outside-toplevel", - "import-self", - "inconsistent-return-statements", - "invalid-character-zero-width-space", - "invalid-name", - "line-too-long", - "logging-fstring-interpolation", - "logging-not-lazy", - "missing-class-docstring", - "missing-final-newline", - "missing-function-docstring", - "missing-module-docstring", - "multiple-imports", - "no-else-continue", - "no-else-raise", - "no-else-return", - "no-member", - "no-self-argument", - "no-value-for-parameter", - "not-an-iterable", - "not-callable", - "notimplemented-raised", - "pointless-exception-statement", - "pointless-string-statement", - "pointless-statement", - "possibly-used-before-assignment", - "protected-access", - "raise-missing-from", - "raising-format-tuple", - "redefined-argument-from-local", - "redefined-builtin", - "redefined-outer-name", - "redundant-u-string-prefix", - "reimported", - "simplifiable-if-expression", - "simplifiable-if-statement", - "singleton-comparison", - "super-init-not-called", - "superfluous-parens", - "too-few-public-methods", - "too-many-arguments", - "too-many-boolean-expressions", - "too-many-branches", - "too-many-instance-attributes", - "too-many-lines", - "too-many-locals", - "too-many-positional-arguments", - "too-many-return-statements", - "too-many-statements", - "trailing-newlines", - "trailing-newlines", - "trailing-whitespace", - "unbalanced-tuple-unpacking", - "undefined-loop-variable", - 
"undefined-variable", - "ungrouped-imports", - "unidiomatic-typecheck", - "unnecessary-comprehension", - "unnecessary-lambda", - "unnecessary-lambda-assignment", - "unnecessary-pass", - "unspecified-encoding", - "unused-argument", - "unused-import", - "unused-variable", - "unused-wildcard-import", - "use-a-generator", - "use-dict-literal", - "used-before-assignment", - "useless-object-inheritance", - "useless-parent-delegation", - "useless-return", - "wildcard-import", - "wrong-import-order", - "wrong-import-position", ] \ No newline at end of file diff --git a/recipe/dapo/src/dapo_ray_trainer.py b/recipe/dapo/src/dapo_ray_trainer.py index eb98270fe..620259a5f 100644 --- a/recipe/dapo/src/dapo_ray_trainer.py +++ b/recipe/dapo/src/dapo_ray_trainer.py @@ -17,17 +17,22 @@ This trainer supports model-agonistic model initialization with huggingface """ import uuid -from pprint import pprint -from copy import deepcopy from collections import defaultdict -from tqdm import tqdm +from copy import deepcopy +from pprint import pprint + import numpy as np import torch +from tqdm import tqdm from verl import DataProto -from verl.trainer.ppo.ray_trainer import RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage, AdvantageEstimator -from verl.trainer.ppo.metric_utils import (compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, - reduce_metrics) +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + reduce_metrics, +) +from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage class RayDAPOTrainer(RayPPOTrainer): @@ -41,13 +46,16 @@ class RayDAPOTrainer(RayPPOTrainer): The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. The light-weight advantage computation is done on the driver process. """ - from verl.utils.tracking import Tracking from omegaconf import OmegaConf - logger = Tracking(project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True)) + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) self.global_steps = 0 @@ -56,11 +64,11 @@ class RayDAPOTrainer(RayPPOTrainer): # perform validation before training # currently, we only support validation using the reward_function. 
- if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True): + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): val_metrics = self._validate() - pprint(f'Initial validation metrics: {val_metrics}') + pprint(f"Initial validation metrics: {val_metrics}") logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get('val_only', False): + if self.config.trainer.get("val_only", False): return # add tqdm @@ -81,28 +89,28 @@ class RayDAPOTrainer(RayPPOTrainer): new_batch: DataProto = DataProto.from_single_dict(batch_dict) num_gen_batches += 1 # pop those keys for generation - if 'multi_modal_inputs' in new_batch.non_tensor_batch.keys(): + if "multi_modal_inputs" in new_batch.non_tensor_batch.keys(): gen_batch = new_batch.pop( - batch_keys=['input_ids', 'attention_mask', 'position_ids'], - non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs'], + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"], ) else: gen_batch = new_batch.pop( - batch_keys=['input_ids', 'attention_mask', 'position_ids'], - non_tensor_batch_keys=['raw_prompt_ids'], + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids"], ) is_last_step = self.global_steps >= self.total_training_steps - with _timer('step', timing_raw): + with _timer("step", timing_raw): # generate a batch - with _timer('gen', timing_raw): + with _timer("gen", timing_raw): gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - with _timer('gen_max', timing_raw): + with _timer("gen_max", timing_raw): gen_baseline_batch = deepcopy(gen_batch) - gen_baseline_batch.meta_info['do_sample'] = False + gen_baseline_batch.meta_info["do_sample"] = False gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) new_batch = new_batch.union(gen_baseline_output) @@ -111,17 +119,18 @@ class RayDAPOTrainer(RayPPOTrainer): new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - new_batch.batch['reward_baselines'] = reward_baseline_tensor + new_batch.batch["reward_baselines"] = reward_baseline_tensor del gen_baseline_batch, gen_baseline_output - new_batch.non_tensor_batch['uid'] = np.array( - [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object) + new_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object + ) # repeat to align with repeated responses in rollout new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) new_batch = new_batch.union(gen_batch_output) - with _timer('reward', timing_raw): + with _timer("reward", timing_raw): # compute scores. Support both model and function-based. # We first compute the scores using reward model. Then, we call reward_fn to combine # the results from reward model and rule-based results. 
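Above, each prompt in new_batch receives a fresh uid before the batch is repeated with interleave=True, so all rollout.n responses generated from the same prompt share that prompt's uid, which the group-level filtering further below relies on. A small numpy sketch of the intended alignment, assuming DataProto.repeat(..., interleave=True) behaves like an element-wise repeat_interleave:

import uuid

import numpy as np

n = 4  # stands in for config.actor_rollout_ref.rollout.n
uids = np.array([str(uuid.uuid4()) for _ in range(2)], dtype=object)  # one uid per prompt

# interleaved repetition keeps each prompt's copies contiguous: [u0, u0, u0, u0, u1, u1, u1, u1]
repeated_uids = np.repeat(uids, n)

assert (repeated_uids[:n] == uids[0]).all() and (repeated_uids[n:] == uids[1]).all()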
@@ -134,30 +143,31 @@ class RayDAPOTrainer(RayPPOTrainer): reward_extra_infos_dict: dict[str, list] try: reward_result = self.reward_fn(new_batch, return_dict=True) - reward_tensor = reward_result['reward_tensor'] - reward_extra_infos_dict = reward_result['reward_extra_info'] + reward_tensor = reward_result["reward_tensor"] + reward_extra_infos_dict = reward_result["reward_extra_info"] except Exception as e: - print(f'Error in reward_fn: {e}') + print(f"Error in reward_fn: {e}") reward_tensor = self.reward_fn(new_batch) reward_extra_infos_dict = {} - new_batch.batch['token_level_scores'] = reward_tensor + new_batch.batch["token_level_scores"] = reward_tensor - print(f'{list(reward_extra_infos_dict.keys())=}') + print(f"{list(reward_extra_infos_dict.keys())=}") if reward_extra_infos_dict: - new_batch.non_tensor_batch.update({ - k: np.array(v) for k, v in reward_extra_infos_dict.items() - }) + new_batch.non_tensor_batch.update( + {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + ) # compute rewards. apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: - new_batch, kl_metrics = apply_kl_penalty(new_batch, - kl_ctrl=self.kl_ctrl_in_reward, - kl_penalty=self.config.algorithm.kl_penalty) + new_batch, kl_metrics = apply_kl_penalty( + new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) metrics.update( - kl_metrics) # TODO: This will be cleared if we use multiple genenration batches + kl_metrics + ) # TODO: This will be cleared if we use multiple genenration batches else: - new_batch.batch['token_level_rewards'] = new_batch.batch['token_level_scores'] + new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] if not self.config.algorithm.filter_groups.enable: batch = new_batch @@ -165,16 +175,19 @@ class RayDAPOTrainer(RayPPOTrainer): metric_name = self.config.algorithm.filter_groups.metric if metric_name == "seq_final_reward": # Turn to numpy for easier filtering - new_batch.non_tensor_batch["seq_final_reward"] = new_batch.batch['token_level_rewards'].sum( - dim=-1).numpy() + new_batch.non_tensor_batch["seq_final_reward"] = ( + new_batch.batch["token_level_rewards"].sum(dim=-1).numpy() + ) elif metric_name == "seq_reward": - new_batch.non_tensor_batch["seq_reward"] = new_batch.batch['token_level_scores'].sum( - dim=-1).numpy() + new_batch.non_tensor_batch["seq_reward"] = ( + new_batch.batch["token_level_scores"].sum(dim=-1).numpy() + ) # Collect the sequence reward for each trajectory prompt_uid2metric_vals = defaultdict(list) - for uid, metric_val in zip(new_batch.non_tensor_batch['uid'], - new_batch.non_tensor_batch[metric_name]): + for uid, metric_val in zip( + new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name] + ): prompt_uid2metric_vals[uid].append(metric_val) prompt_uid2metric_std = {} @@ -182,13 +195,14 @@ class RayDAPOTrainer(RayPPOTrainer): prompt_uid2metric_std[prompt_uid] = np.std(metric_vals) kept_prompt_uids = [ - uid for uid, std in prompt_uid2metric_std.items() + uid + for uid, std in prompt_uid2metric_std.items() if std > 0 or len(prompt_uid2metric_vals[uid]) == 1 ] num_prompt_in_batch += len(kept_prompt_uids) kept_traj_idxs = [] - for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch['uid']): + for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]): if traj_from_prompt_uid in kept_prompt_uids: kept_traj_idxs.append(idx) @@ -200,14 +214,14 @@ class RayDAPOTrainer(RayPPOTrainer): prompt_bsz = self.config.data.train_batch_size if 
num_prompt_in_batch < prompt_bsz: - print(f'{num_prompt_in_batch=} < {prompt_bsz=}') + print(f"{num_prompt_in_batch=} < {prompt_bsz=}") max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: - print(f'{num_gen_batches=}. Keep generating...') + print(f"{num_gen_batches=}. Keep generating...") continue else: raise ValueError( - f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.' + f"{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data." ) else: # Align the batch @@ -221,60 +235,66 @@ class RayDAPOTrainer(RayPPOTrainer): self._balance_batch(batch, metrics=metrics) # compute global_valid tokens - batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() # recompute old_log_probs - with _timer('old_log_prob', timing_raw): + with _timer("old_log_prob", timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) batch = batch.union(old_log_prob) if self.use_reference_policy: # compute reference log_prob - with _timer('ref', timing_raw): + with _timer("ref", timing_raw): ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) # compute values if self.use_critic: - with _timer('values', timing_raw): + with _timer("values", timing_raw): values = self.critic_wg.compute_values(batch) batch = batch.union(values) - with _timer('adv', timing_raw): + with _timer("adv", timing_raw): # compute advantages, executed on the driver process - batch = compute_advantage(batch, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n) + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + ) # update critic if self.use_critic: - with _timer('update_critic', timing_raw): + with _timer("update_critic", timing_raw): critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) metrics.update(critic_output_metrics) # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: # update actor - with _timer('update_actor', timing_raw): + with _timer("update_actor", timing_raw): actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) # validate - if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \ - (is_last_step or self.global_steps % self.config.trainer.test_freq == 0): - with _timer('testing', timing_raw): + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with _timer("testing", timing_raw): val_metrics: dict = self._validate() if is_last_step: last_val_metrics = val_metrics metrics.update(val_metrics) - if self.config.trainer.save_freq > 0 and (is_last_step or - self.global_steps % 
self.config.trainer.save_freq == 0): - with _timer('save_checkpoint', timing_raw): + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with _timer("save_checkpoint", timing_raw): self._save_checkpoint() # collect metrics @@ -294,7 +314,7 @@ class RayDAPOTrainer(RayPPOTrainer): logger.log(data=metrics, step=self.global_steps) if is_last_step: - pprint(f'Final validation metrics: {last_val_metrics}') + pprint(f"Final validation metrics: {last_val_metrics}") progress_bar.close() return diff --git a/recipe/dapo/src/main_dapo.py b/recipe/dapo/src/main_dapo.py index df96d4bbe..e4069f955 100644 --- a/recipe/dapo/src/main_dapo.py +++ b/recipe/dapo/src/main_dapo.py @@ -14,15 +14,18 @@ """ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ -from .dapo_ray_trainer import RayDAPOTrainer import os -import ray + import hydra +import ray + +from .dapo_ray_trainer import RayDAPOTrainer def get_custom_reward_fn(config): - import importlib.util, os + import importlib.util + import os reward_fn_config = config.get("custom_reward_function") or {} file_path = reward_fn_config.get("path") @@ -49,7 +52,7 @@ def get_custom_reward_fn(config): return getattr(module, function_name) -@hydra.main(config_path='config', config_name='dapo_trainer', version_base=None) +@hydra.main(config_path="config", config_name="dapo_trainer", version_base=None) def main(config): run_ppo(config) @@ -57,16 +60,14 @@ def main(config): def run_ppo(config) -> None: # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices # isolation, will solve in the future - os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get('CUDA_VISIBLE_DEVICES', '') + os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "") if not ray.is_initialized(): # this is for local ray cluster - ray.init(runtime_env={ - 'env_vars': { - 'TOKENIZERS_PARALLELISM': 'true', - 'NCCL_DEBUG': 'WARN', - 'VLLM_LOGGING_LEVEL': 'WARN' + ray.init( + runtime_env={ + "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} } - }) + ) runner = TaskRunner.remote() ray.get(runner.run.remote(config)) @@ -74,12 +75,14 @@ def run_ppo(config) -> None: @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head class TaskRunner: - def run(self, config): - from verl.utils.fs import copy_to_local # print initial config from pprint import pprint + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values OmegaConf.resolve(config) @@ -87,21 +90,24 @@ class TaskRunner: local_path = copy_to_local(config.actor_rollout_ref.model.path) # instantiate tokenizer - from verl.utils import hf_tokenizer, hf_processor + from verl.utils import hf_processor, hf_tokenizer + tokenizer = hf_tokenizer(local_path) processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none # define worker classes - if config.actor_rollout_ref.actor.strategy == 'fsdp': + if config.actor_rollout_ref.actor.strategy == "fsdp": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + ray_worker_group_cls = RayWorkerGroup - elif 
config.actor_rollout_ref.actor.strategy == 'megatron': + elif config.actor_rollout_ref.actor.strategy == "megatron": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + ray_worker_group_cls = NVMegatronRayWorkerGroup else: @@ -112,10 +118,10 @@ class TaskRunner: role_worker_mapping = { Role.ActorRollout: ray.remote(ActorRolloutRefWorker), Role.Critic: ray.remote(CriticWorker), - Role.RefPolicy: ray.remote(ActorRolloutRefWorker) + Role.RefPolicy: ray.remote(ActorRolloutRefWorker), } - global_pool_id = 'global_pool' + global_pool_id = "global_pool" resource_pool_spec = { global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, } @@ -132,9 +138,9 @@ class TaskRunner: # - finally, we combine all the rewards together # - The reward type depends on the tag of the data if config.reward_model.enable: - if config.reward_model.strategy == 'fsdp': + if config.reward_model.strategy == "fsdp": from verl.workers.fsdp_workers import RewardModelWorker - elif config.reward_model.strategy == 'megatron': + elif config.reward_model.strategy == "megatron": from verl.workers.megatron_workers import RewardModelWorker else: raise NotImplementedError @@ -147,47 +153,55 @@ class TaskRunner: mapping[Role.RefPolicy] = global_pool_id reward_manager_name = config.reward_model.get("reward_manager", "naive") - if reward_manager_name == 'naive': + if reward_manager_name == "naive": from verl.workers.reward_manager import NaiveRewardManager + reward_manager_cls = NaiveRewardManager - elif reward_manager_name == 'prime': + elif reward_manager_name == "prime": from verl.workers.reward_manager import PrimeRewardManager + reward_manager_cls = PrimeRewardManager - elif reward_manager_name == 'dapo': + elif reward_manager_name == "dapo": from verl.workers.reward_manager import DAPORewardManager + reward_manager_cls = DAPORewardManager else: - raise NotImplementedError compute_score = get_custom_reward_fn(config) - reward_fn = reward_manager_cls(tokenizer=tokenizer, - num_examine=0, - compute_score=compute_score, - reward_fn_key=config.data.reward_fn_key, - max_resp_len=config.data.max_response_length, - overlong_buffer_cfg=config.reward_model.overlong_buffer) + reward_fn = reward_manager_cls( + tokenizer=tokenizer, + num_examine=0, + compute_score=compute_score, + reward_fn_key=config.data.reward_fn_key, + max_resp_len=config.data.max_response_length, + overlong_buffer_cfg=config.reward_model.overlong_buffer, + ) # Note that we always use function-based RM for validation - val_reward_fn = reward_manager_cls(tokenizer=tokenizer, - num_examine=1, - compute_score=compute_score, - reward_fn_key=config.data.reward_fn_key, - max_resp_len=config.data.max_response_length, - overlong_buffer_cfg=config.reward_model.overlong_buffer) + val_reward_fn = reward_manager_cls( + tokenizer=tokenizer, + num_examine=1, + compute_score=compute_score, + reward_fn_key=config.data.reward_fn_key, + max_resp_len=config.data.max_response_length, + overlong_buffer_cfg=config.reward_model.overlong_buffer, + ) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - trainer = RayDAPOTrainer(config=config, - tokenizer=tokenizer, - processor=processor, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - 
ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn) + trainer = RayDAPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + ) trainer.init_workers() trainer.fit() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/recipe/prime/__init__.py b/recipe/prime/__init__.py index b1697c70a..6b76ea65c 100644 --- a/recipe/prime/__init__.py +++ b/recipe/prime/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/recipe/prime/main_prime.py b/recipe/prime/main_prime.py index 0bdae571a..3d69eaaf7 100644 --- a/recipe/prime/main_prime.py +++ b/recipe/prime/main_prime.py @@ -28,13 +28,14 @@ """ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ + +import hydra +import ray + from .prime_ray_trainer import RayPRIMETrainer -import ray -import hydra - -@hydra.main(config_path='config', config_name='prime_trainer', version_base=None) +@hydra.main(config_path="config", config_name="prime_trainer", version_base=None) def main(config): run_prime(config) @@ -43,10 +44,7 @@ def run_prime(config, compute_score=None): if not ray.is_initialized(): # this is for local ray cluster ray.init( - runtime_env={'env_vars': { - 'TOKENIZERS_PARALLELISM': 'true', - 'NCCL_DEBUG': 'WARN' - }}, + runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}, ) ray.get(main_task.remote(config, compute_score)) @@ -54,10 +52,13 @@ def run_prime(config, compute_score=None): @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head def main_task(config, compute_score=None): - from verl.utils.fs import copy_local_path_from_hdfs # print initial config from pprint import pprint + from omegaconf import OmegaConf + + from verl.utils.fs import copy_local_path_from_hdfs + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values OmegaConf.resolve(config) @@ -66,19 +67,22 @@ def main_task(config, compute_score=None): # instantiate tokenizer from verl.utils import hf_tokenizer + tokenizer = hf_tokenizer(local_path) # define worker classes - if config.actor_rollout_ref.actor.strategy == 'fsdp': + if config.actor_rollout_ref.actor.strategy == "fsdp": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup - elif config.actor_rollout_ref.actor.strategy == 'megatron': + elif config.actor_rollout_ref.actor.strategy == "megatron": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker + ray_worker_group_cls = NVMegatronRayWorkerGroup else: @@ -90,7 +94,7 @@ def main_task(config, 
compute_score=None): Role.ActorRollout: ray.remote(ActorRolloutRefWorker), } - global_pool_id = 'global_pool' + global_pool_id = "global_pool" resource_pool_spec = { global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, } @@ -98,22 +102,25 @@ def main_task(config, compute_score=None): Role.ActorRollout: global_pool_id, } - #use reference model + # use reference model if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) mapping[Role.RefPolicy] = global_pool_id if config.reward_model.enable: from .prime_fsdp_workers import PRIMERewardModelWorker + role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker) mapping[Role.RewardModel] = global_pool_id reward_manager_name = config.reward_model.get("reward_manager", "naive") - if reward_manager_name == 'naive': + if reward_manager_name == "naive": from verl.workers.reward_manager import NaiveRewardManager + reward_manager_cls = NaiveRewardManager - elif reward_manager_name == 'prime': + elif reward_manager_name == "prime": from verl.workers.reward_manager import PrimeRewardManager + reward_manager_cls = PrimeRewardManager else: raise NotImplementedError @@ -124,16 +131,18 @@ def main_task(config, compute_score=None): resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - trainer = RayPRIMETrainer(config=config, - tokenizer=tokenizer, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn) + trainer = RayPRIMETrainer( + config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + ) trainer.init_workers() trainer.fit() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/recipe/prime/prime_core_algos.py b/recipe/prime/prime_core_algos.py index 699df1aeb..d2a723316 100644 --- a/recipe/prime/prime_core_algos.py +++ b/recipe/prime/prime_core_algos.py @@ -13,6 +13,7 @@ # limitations under the License. 
import torch + import verl import verl.utils.torch_functional as verl_F @@ -23,45 +24,48 @@ def compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Ten reward_tensor = reward_tensor_original.clone() reward_tensor[~mask_tensor] = 0 for start_pos in range(0, reward_tensor.shape[0], n_samples): - cur_rewards_mean = torch.cat([ - reward_tensor[pos:pos + 1][mask_tensor[pos:pos + 1]].mean(dim=0, keepdim=True) - for pos in range(start_pos, start_pos + n_samples) - ], - dim=0) + cur_rewards_mean = torch.cat( + [ + reward_tensor[pos : pos + 1][mask_tensor[pos : pos + 1]].mean(dim=0, keepdim=True) + for pos in range(start_pos, start_pos + n_samples) + ], + dim=0, + ) cur_rewards_sum = cur_rewards_mean.sum() cur_reward_baseline = cur_rewards_sum / (n_samples - 1) - reward_tensor[start_pos:start_pos + n_samples][ - mask_tensor[start_pos:start_pos + n_samples]] = \ - reward_tensor[start_pos:start_pos + n_samples][ - mask_tensor[start_pos:start_pos + n_samples]] * ( - n_samples / (n_samples - 1)) - cur_reward_baseline + reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] = ( + reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] + * (n_samples / (n_samples - 1)) + - cur_reward_baseline + ) return reward_tensor reward_tensors = [] with torch.no_grad(): - - if 'rm_scores' in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.: - reward_tensor = data.batch['rm_scores'] + if "rm_scores" in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.0: + reward_tensor = data.batch["rm_scores"] reward_mask = response_mask.bool() reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef) - if 'acc' in data.batch.keys() and config.algorithm.reward_gt_coef != 0.: + if "acc" in data.batch.keys() and config.algorithm.reward_gt_coef != 0.0: reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32) reward_mask = torch.zeros_like(response_mask, dtype=torch.bool) - prompt_ids = data.batch['prompts'] + prompt_ids = data.batch["prompts"] prompt_length = prompt_ids.shape[-1] - valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(-1) + valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(-1) reward_mask[ torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device), - valid_response_length - 1] = True + valid_response_length - 1, + ] = True reward_tensor[ torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device), - valid_response_length - 1] = data.batch['acc'] + valid_response_length - 1, + ] = data.batch["acc"] reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_gt_coef) @@ -81,7 +85,7 @@ def compute_ce_dpo_loss_rm(token_level_scores, acc, response_mask, beta): return cur_dpo_loss -def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode='none'): +def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode="none"): # we always assume that the BoN size equals n_samples # mode1: use acc as rm # mode2: use Q as rm @@ -97,15 +101,15 @@ def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_m else: other_Q[i] = 0 dpo_loss = -torch.log(torch.sigmoid((cur_Q - other_Q) * ((acc > 0).float() * 2 - 1))) - if bon_mode == 'none': + if bon_mode == "none": dpo_loss = dpo_loss.mean() else: weight = 
torch.zeros_like(dpo_loss) n_samples = acc_bc.shape[1] - if bon_mode == 'bon_rm': + if bon_mode == "bon_rm": for i in range(token_level_scores.shape[0]): weight[i] = n_samples * torch.pow((Q_bc[i] * beta <= cur_Q[i]).float().mean(), n_samples - 1) - elif bon_mode == 'bon_acc': + elif bon_mode == "bon_acc": for i in range(token_level_scores.shape[0]): weight[i] = n_samples * torch.pow((acc_bc[i] <= acc[i]).float().mean(), n_samples - 1) else: @@ -118,22 +122,24 @@ def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_m def compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples): dpo_acc = [] for start_id in range(0, token_level_scores.shape[0], n_samples): - cur_scores = (token_level_scores[start_id:start_id + n_samples] * - response_mask[start_id:start_id + n_samples]).sum(dim=1) + cur_scores = ( + token_level_scores[start_id : start_id + n_samples] * response_mask[start_id : start_id + n_samples] + ).sum(dim=1) def get_upper_triangle(tensor_x): diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0) upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1) return diff_matrix[upper_tri_indices] - cur_acc_diff = get_upper_triangle(acc[start_id:start_id + n_samples]) # in range [-1,1] + cur_acc_diff = get_upper_triangle(acc[start_id : start_id + n_samples]) # in range [-1,1] cur_score_diff = get_upper_triangle(cur_scores) # in R cur_score_prediction = (cur_score_diff > 0).float() # in [0,1] if cur_acc_diff.abs().sum() == 0: cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5 else: - cur_acc = (((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * - cur_acc_diff.abs()).sum() / cur_acc_diff.abs().sum() + cur_acc = ( + ((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * cur_acc_diff.abs() + ).sum() / cur_acc_diff.abs().sum() dpo_acc.append(cur_acc.unsqueeze(0)) diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py index 16c2d32f4..45e37dead 100644 --- a/recipe/prime/prime_dp_rm.py +++ b/recipe/prime/prime_dp_rm.py @@ -14,62 +14,56 @@ """ Implement a multiprocess PPOCritic """ + import itertools -from typing import Iterable import torch import torch.distributed +from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input from torch import nn, optim - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm -from verl import DataProto -from verl.trainer.ppo import core_algos -from verl.workers.critic import BasePPOCritic -from verl.utils.py_functional import append_to_dict -from verl.utils.torch_functional import masked_mean -from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad -from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.utils.py_functional import append_to_dict +from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs -from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm -__all__ = ['DataParallelPRIMERewardModel'] +__all__ = ["DataParallelPRIMERewardModel"] class DataParallelPRIMERewardModel: - def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, reward_optimizer: optim.Optimizer): 
self.config = config self.reward_module = reward_module self.ref_module = ref_module self.reward_optimizer = reward_optimizer - self.use_remove_padding = self.config.model.get('use_remove_padding', False) - print(f'Reward model use_remove_padding={self.use_remove_padding}') + self.use_remove_padding = self.config.model.get("use_remove_padding", False) + print(f"Reward model use_remove_padding={self.use_remove_padding}") - self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) def _forward_micro_batch(self, micro_batch, prompt_length): - from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis, rearrange - from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad - - input_ids = micro_batch['input_ids'] + input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape - attention_mask = micro_batch['attention_mask'] - position_ids = micro_batch['position_ids'] + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] - num_actions = micro_batch['input_ids'].shape[-1] - prompt_length - max_positions = micro_batch['attention_mask'][:, prompt_length:].sum(-1) + num_actions = micro_batch["input_ids"].shape[-1] - prompt_length + max_positions = micro_batch["attention_mask"][:, prompt_length:].sum(-1) if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), - attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), - indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) # for compute the log_prob input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) @@ -77,90 +71,93 @@ class DataParallelPRIMERewardModel: # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size) - input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None, - self.ulysses_sequence_parallel_size) + input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + ) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size + ) input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) - rm_output_logits = self.reward_module(input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - use_cache=False).logits.squeeze( - 0) # copied. I don't really know why there is a squeeze + rm_output_logits = self.reward_module( + input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False + ).logits.squeeze(0) # copied. 
I don't really know why there is a squeeze rm_log_labels = verl_F.logprobs_from_logits(logits=rm_output_logits, labels=input_ids_rmpad_rolled) if self.ulysses_sequence_parallel_size > 1: rm_log_labels = gather_outpus_and_unpad(rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size) - rm_log_labels = pad_input(hidden_states=rm_log_labels.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen).squeeze(-1)[:, -num_actions - 1:-1] + rm_log_labels = pad_input( + hidden_states=rm_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ).squeeze(-1)[:, -num_actions - 1 : -1] else: - rm_output_logits = self.reward_module(input_ids=micro_batch['input_ids'], - attention_mask=micro_batch['attention_mask'], - position_ids=micro_batch['position_ids'], - use_cache=False).logits - rm_log_prob = torch.nn.functional.log_softmax(rm_output_logits[:, :-1, :], - dim=-1) # (batch_size, seq_length, vocab_size) - rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch['input_ids'][:, 1:].unsqueeze(-1)).squeeze( - -1) # (batch, seq_length) + rm_output_logits = self.reward_module( + input_ids=micro_batch["input_ids"], + attention_mask=micro_batch["attention_mask"], + position_ids=micro_batch["position_ids"], + use_cache=False, + ).logits + rm_log_prob = torch.nn.functional.log_softmax( + rm_output_logits[:, :-1, :], dim=-1 + ) # (batch_size, seq_length, vocab_size) + rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1)).squeeze( + -1 + ) # (batch, seq_length) if self.ref_module is not None: # do not have to pad again - with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): if self.ulysses_sequence_parallel_size > 1 and self.use_remove_padding: - ref_output_logits = self.ref_module(input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - use_cache=False).logits.squeeze(0) - ref_log_labels = verl_F.logprobs_from_logits(logits=ref_output_logits, - labels=input_ids_rmpad_rolled) - ref_log_labels = gather_outpus_and_unpad(ref_log_labels, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size) - ref_log_labels = pad_input(hidden_states=ref_log_labels.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen).squeeze(-1)[:, -num_actions - 1:-1] + ref_output_logits = self.ref_module( + input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False + ).logits.squeeze(0) + ref_log_labels = verl_F.logprobs_from_logits( + logits=ref_output_logits, labels=input_ids_rmpad_rolled + ) + ref_log_labels = gather_outpus_and_unpad( + ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) + ref_log_labels = pad_input( + hidden_states=ref_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ).squeeze(-1)[:, -num_actions - 1 : -1] else: - ref_output_logits = self.ref_module(input_ids=micro_batch['input_ids'], - attention_mask=micro_batch['attention_mask'], - position_ids=micro_batch['position_ids'], - use_cache=False).logits - ref_log_prob = torch.nn.functional.log_softmax(ref_output_logits[:, :-1, :], - dim=-1) # (batch_size, seq_length, vocab_size) - ref_log_labels = ref_log_prob.gather(dim=-1, - index=micro_batch['input_ids'][:, 1:].unsqueeze(-1)).squeeze( - -1) # (batch, seq_length) + ref_output_logits = self.ref_module( + input_ids=micro_batch["input_ids"], + attention_mask=micro_batch["attention_mask"], + 
position_ids=micro_batch["position_ids"], + use_cache=False, + ).logits + ref_log_prob = torch.nn.functional.log_softmax( + ref_output_logits[:, :-1, :], dim=-1 + ) # (batch_size, seq_length, vocab_size) + ref_log_labels = ref_log_prob.gather( + dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1) + ).squeeze(-1) # (batch, seq_length) else: - ref_log_labels = micro_batch['old_log_probs'] + ref_log_labels = micro_batch["old_log_probs"] ref_log_labels.to(rm_log_labels.dtype) q = rm_log_labels[:, -num_actions:] - ref_log_labels[:, -num_actions:] # this is actually diff of q # trim unnecessary logprobs here - for i in range(micro_batch['input_ids'].shape[0]): - q[i, max_positions[i]:] = 0 + for i in range(micro_batch["input_ids"].shape[0]): + q[i, max_positions[i] :] = 0 # reward computation does not need gradient. only q needs with torch.no_grad(): - # generalized estimation of r should go before the reward filling. r means process reward for policy model, or the advantage of reward model. - lam = self.config.get('lambda', 0.) - beta = self.config.model.get('beta_train', 0.05) - if lam == 0.: + lam = self.config.get("lambda", 0.0) + beta = self.config.model.get("beta_train", 0.05) + if lam == 0.0: r = q * beta else: # reward coefficient takes no effect here - acc = micro_batch['acc'] + acc = micro_batch["acc"] q_ = q * beta r = torch.zeros_like(q) lastgaelam = 0 # change the last token and mask out all paddings to make this process easier if we rely on outcome reward to calculate V for i in range(q.shape[0]): if self.config.prime_use_gt: - q_[i, max_positions[i] - 1] = acc[i] - q_[i, :max_positions[i] - 1].sum() - q_[i, max_positions[i]:] = 0 + q_[i, max_positions[i] - 1] = acc[i] - q_[i, : max_positions[i] - 1].sum() + q_[i, max_positions[i] :] = 0 for t in reversed(range(num_actions)): delta = q_[:, t] @@ -169,12 +166,12 @@ class DataParallelPRIMERewardModel: token_level_score = torch.zeros_like(q) - if self.config.prime_granularity == 'token': - for i in range(micro_batch['input_ids'].shape[0]): - token_level_score[i, :max_positions[i] - 1] = r[i, :max_positions[i] - 1] - elif self.config.prime_granularity == 'whole': - for i in range(micro_batch['input_ids'].shape[0]): - token_level_score[i, max_positions[i] - 1] = r[i, :max_positions[i]] + if self.config.prime_granularity == "token": + for i in range(micro_batch["input_ids"].shape[0]): + token_level_score[i, : max_positions[i] - 1] = r[i, : max_positions[i] - 1] + elif self.config.prime_granularity == "whole": + for i in range(micro_batch["input_ids"].shape[0]): + token_level_score[i, max_positions[i] - 1] = r[i, : max_positions[i]] else: raise NotImplementedError @@ -186,13 +183,14 @@ class DataParallelPRIMERewardModel: if isinstance(self.reward_module, FSDP): grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip) else: - grad_norm = torch.nn.utils.clip_grad_norm_(self.reward_module.parameters(), - max_norm=self.config.model.optim.grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_( + self.reward_module.parameters(), max_norm=self.config.model.optim.grad_clip + ) self.reward_optimizer.step() return grad_norm def prime_norm(self, token_level_scores): - if self.config.prime_norm == 'batch_norm': + if self.config.prime_norm == "batch_norm": reverse_cumsum = torch.cumsum(token_level_scores.flip(dims=[1]), dim=-1).flip(dims=[1]) token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6) return token_level_scores @@ -200,15 +198,15 @@ class DataParallelPRIMERewardModel: def 
compute_rm_score(self, data: DataProto): self.reward_module.eval() self.ref_module.eval() - micro_batch_size = data.meta_info['micro_batch_size'] - select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'acc'] + micro_batch_size = data.meta_info["micro_batch_size"] + select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "acc"] batch = data.select(batch_keys=select_keys).batch - use_dynamic_bsz = data.meta_info['use_dynamic_bsz'] - prompt_length = data.batch['input_ids'].shape[-1] - data.batch['responses'].shape[-1] + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + prompt_length = data.batch["input_ids"].shape[-1] - data.batch["responses"].shape[-1] if use_dynamic_bsz: # split using dynamic bsz - max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) else: micro_batches = batch.split(micro_batch_size) @@ -231,21 +229,25 @@ class DataParallelPRIMERewardModel: revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) rm_scores = rm_scores[revert_indices] - return rm_scores, q.detach(), { - 'reward_model/reward': rm_scores.sum(dim=-1).mean().item(), - 'reward_model/raw_reward': q.sum(dim=-1).mean().item() - } + return ( + rm_scores, + q.detach(), + { + "reward_model/reward": rm_scores.sum(dim=-1).mean().item(), + "reward_model/raw_reward": q.sum(dim=-1).mean().item(), + }, + ) def update_rm(self, data: DataProto): # make sure we are in training mode self.reward_module.train() metrics = {} - beta = self.config.model.get('beta_train', 0.05) + beta = self.config.model.get("beta_train", 0.05) - select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'acc', 'prompts'] + select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "acc", "prompts"] - for key in ['Q_bc', 'acc_bc']: + for key in ["Q_bc", "acc_bc"]: if key in data.batch.keys(): select_keys.append(key) @@ -271,10 +273,10 @@ class DataParallelPRIMERewardModel: for data in micro_batches: data = data.cuda() - attention_mask = data['attention_mask'] - acc = data['acc'] + attention_mask = data["attention_mask"] + acc = data["acc"] - prompt_ids = data['prompts'] + prompt_ids = data["prompts"] prompt_length = prompt_ids.shape[-1] response_mask = attention_mask[:, prompt_length:] @@ -284,37 +286,38 @@ class DataParallelPRIMERewardModel: rm_scores_lst.append(rm_score) q_lst.append(q.detach()) - if self.config.model.loss_type == 'ce': + if self.config.model.loss_type == "ce": dpo_loss = compute_ce_dpo_loss_rm(q, acc, response_mask=response_mask, beta=beta) - elif self.config.model.loss_type == 'dpo': + elif self.config.model.loss_type == "dpo": # the implementation of dpo is actually detached, which means we have to know the average value of w/l reward before the update. 
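# Illustrative aside (not part of this patch): the detached DPO / BoN branches below
# need per-group statistics gathered before the update. Q_bc and acc_bc are simply each
# prompt group's values broadcast to all of its n samples; a toy-shaped sketch:
import torch

n = 4                                   # hypothetical rollout.n, i.e. responses per prompt
q_sum = torch.randn(8)                  # summed implicit reward per response, 2 prompts x 4 samples
Q_bc = q_sum.view(-1, n).unsqueeze(1).expand(-1, n, -1).reshape(-1, n)  # (8, 4)
# row i now holds the q-sums of every response sampled from the same prompt as sample i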
- dpo_loss = compute_detach_dpo_loss_rm(q, - acc, - Q_bc=data['Q_bc'], - acc_bc=data['acc_bc'], - response_mask=response_mask, - beta=beta) - elif self.config.model.loss_type == 'bon_acc': + dpo_loss = compute_detach_dpo_loss_rm( + q, acc, Q_bc=data["Q_bc"], acc_bc=data["acc_bc"], response_mask=response_mask, beta=beta + ) + elif self.config.model.loss_type == "bon_acc": # change the original distribution of each sample to BoN distribution, then update reward model - dpo_loss = compute_detach_dpo_loss_rm(q, - acc, - Q_bc=data['Q_bc'], - acc_bc=data['acc_bc'], - response_mask=response_mask, - beta=beta, - bon_mode='bon_acc') - elif self.config.model.loss_type == 'bon_rm': - dpo_loss = compute_detach_dpo_loss_rm(q, - acc, - Q_bc=data['Q_bc'], - acc_bc=data['acc_bc'], - response_mask=response_mask, - beta=beta, - bon_mode='bon_rm') + dpo_loss = compute_detach_dpo_loss_rm( + q, + acc, + Q_bc=data["Q_bc"], + acc_bc=data["acc_bc"], + response_mask=response_mask, + beta=beta, + bon_mode="bon_acc", + ) + elif self.config.model.loss_type == "bon_rm": + dpo_loss = compute_detach_dpo_loss_rm( + q, + acc, + Q_bc=data["Q_bc"], + acc_bc=data["acc_bc"], + response_mask=response_mask, + beta=beta, + bon_mode="bon_rm", + ) else: raise NotImplementedError - data = {'reward_model/dpo_loss': dpo_loss.detach().item()} + data = {"reward_model/dpo_loss": dpo_loss.detach().item()} if self.config.use_dynamic_bsz: # relative to the dynamic bsz @@ -327,7 +330,7 @@ class DataParallelPRIMERewardModel: append_to_dict(metrics, data) grad_norm = self._optimizer_step() - data = {'reward_model/grad_norm': grad_norm.detach().item()} + data = {"reward_model/grad_norm": grad_norm.detach().item()} append_to_dict(metrics, data) self.reward_optimizer.zero_grad() @@ -336,9 +339,11 @@ class DataParallelPRIMERewardModel: rm_scores = self.prime_norm(rm_scores) - metrics.update({ - 'reward_model/reward': rm_scores.sum(dim=-1).mean().item(), - 'reward_model/raw_reward': q.sum(dim=-1).mean().item() - }) + metrics.update( + { + "reward_model/reward": rm_scores.sum(dim=-1).mean().item(), + "reward_model/raw_reward": q.sum(dim=-1).mean().item(), + } + ) return rm_scores, metrics diff --git a/recipe/prime/prime_fsdp_workers.py b/recipe/prime/prime_fsdp_workers.py index 9a94afc46..9e7b8f5db 100644 --- a/recipe/prime/prime_fsdp_workers.py +++ b/recipe/prime/prime_fsdp_workers.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
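# Illustrative aside (not part of this patch): the per-token quantity q produced by
# DataParallelPRIMERewardModel._forward_micro_batch above is an implicit reward, the
# difference of response-token log-probs under the reward model and the reference
# model; with lambda == 0 the token-level score is just beta * q. A toy-shaped sketch:
import torch
import torch.nn.functional as F

batch, resp_len, vocab = 2, 5, 11       # hypothetical toy shapes
beta = 0.05
response_ids = torch.randint(vocab, (batch, resp_len))
rm_logits = torch.randn(batch, resp_len, vocab)    # stand-in for reward-model logits
ref_logits = torch.randn(batch, resp_len, vocab)   # stand-in for reference-model logits

rm_logp = F.log_softmax(rm_logits, dim=-1).gather(-1, response_ids.unsqueeze(-1)).squeeze(-1)
ref_logp = F.log_softmax(ref_logits, dim=-1).gather(-1, response_ids.unsqueeze(-1)).squeeze(-1)
q = rm_logp - ref_logp                  # per-token implicit reward
token_level_score = beta * q            # the lam == 0.0 branch above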
-import copy import logging import os import warnings @@ -19,54 +18,56 @@ import warnings import torch import torch.distributed from torch.distributed.device_mesh import init_device_mesh -import verl.utils.torch_functional as verl_F -from omegaconf import DictConfig, open_dict + from verl import DataProto from verl.single_controller.base import Worker -from verl.single_controller.base.decorator import register, Dispatch +from verl.single_controller.base.decorator import Dispatch, register from verl.utils import hf_tokenizer -from verl.utils.debug import log_gpu_memory_usage -from verl.utils.fs import copy_local_path_from_hdfs -from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager -from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_model_to_cpu, load_fsdp_optimizer, \ - load_fsdp_model_to_gpu -from verl.utils.import_utils import import_external_libs -from verl.utils.model import compute_position_id_with_mask -from verl.utils.flops_counter import FlopsCounter from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager +from verl.utils.debug import log_gpu_memory_usage +from verl.utils.flops_counter import FlopsCounter +from verl.utils.fs import copy_local_path_from_hdfs +from verl.utils.fsdp_utils import ( + get_fsdp_wrap_policy, + get_init_weight_context_manager, + init_fn, + load_fsdp_model_to_gpu, + load_fsdp_optimizer, + offload_fsdp_model_to_cpu, + offload_fsdp_optimizer, +) +from verl.utils.import_utils import import_external_libs +from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager -from codetiming import Timer -from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy -from .prime_core_algos import compute_dpo_accuracy, compute_dpo_abs_accuracy +from .prime_core_algos import compute_dpo_abs_accuracy, compute_dpo_accuracy logger = logging.getLogger(__file__) -logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) +logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) class PRIMERewardModelWorker(Worker): - def __init__(self, config): super().__init__() import torch.distributed + if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl") self.config = config # build device mesh for Ulysses Sequence Parallel world_size = torch.distributed.get_world_size() - from torch.distributed.device_mesh import init_device_mesh fsdp_size = self.config.model.fsdp_config.fsdp_size self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh('cuda', - mesh_shape=(dp, self.ulysses_sequence_parallel_size), - mesh_dim_names=['dp', 'sp']) + self.ulysses_device_mesh = init_device_mesh( + "cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) @@ -75,40 +76,42 @@ class PRIMERewardModelWorker(Worker): self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload # normalize config - self.config.mini_batch_size //= 
(torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size) + self.config.mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size if self.config.micro_batch_size is not None: - self.config.micro_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size) + self.config.micro_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size self.config.micro_batch_size_per_gpu = self.config.micro_batch_size assert self.config.mini_batch_size % self.config.micro_batch_size_per_gpu == 0 def _build_reward_ref_model_optimizer(self, config): # the following line is necessary - from verl.utils.model import LambdaLayer, print_model_size, squeeze - from verl.utils.torch_dtypes import PrecisionType - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision from torch import optim + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import MixedPrecision + + from verl.utils.model import print_model_size + from verl.utils.torch_dtypes import PrecisionType local_path = copy_local_path_from_hdfs(config.model.path) tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path) - self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False)) + self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) from omegaconf import OmegaConf - override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create())) + + override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) override_config_kwargs = { - 'bos_token_id': self.tokenizer.bos_token_id, - 'eos_token_id': self.tokenizer.eos_token_id, - 'pad_token_id': self.tokenizer.pad_token_id, + "bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, } override_config_kwargs.update(override_config) if self.rank == 0: - print(f'Reward model overriding config {override_config_kwargs}') + print(f"Reward model overriding config {override_config_kwargs}") - torch_dtype = self.config.model.fsdp_config.get('model_dtype', 'fp32') + torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") torch_dtype = PrecisionType.to_dtype(torch_dtype) from transformers import AutoConfig, AutoModelForCausalLM - from torch import nn trust_remote_code = False reward_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) @@ -117,34 +120,37 @@ class PRIMERewardModelWorker(Worker): init_context = get_init_weight_context_manager(use_meta_tensor=not reward_model_config.tie_word_embeddings) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") - setattr(reward_model_config, 'classifier_dropout', 0.) 
- setattr(reward_model_config, 'hidden_dropout', '0') - reward_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path, - torch_dtype=torch_dtype, - config=reward_model_config, - attn_implementation='flash_attention_2', - trust_remote_code=trust_remote_code) + reward_model_config.classifier_dropout = 0.0 + reward_model_config.hidden_dropout = "0" + reward_module = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=reward_model_config, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) - if config.model.get('use_remove_padding', False) or self.ulysses_sequence_parallel_size > 1: + if config.model.get("use_remove_padding", False) or self.ulysses_sequence_parallel_size > 1: from verl.models.transformers.monkey_patch import apply_monkey_patch + apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size) # some parameters may not in torch_dtype reward_module.to(torch_dtype) - if config.model.get('enable_gradient_checkpointing', False): - reward_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False}) + if config.model.get("enable_gradient_checkpointing", False): + reward_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) if self.rank == 0: print_model_size(reward_module) self.reward_model_config = reward_model_config fsdp_config = self.config.model.fsdp_config - mixed_precision_config = fsdp_config.get('mixed_precision', None) + mixed_precision_config = fsdp_config.get("mixed_precision", None) if mixed_precision_config is not None: - param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16')) - reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32')) - buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32')) + param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) + reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) + buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) else: param_dtype = torch.bfloat16 reduce_dtype = torch.float32 @@ -154,78 +160,89 @@ class PRIMERewardModelWorker(Worker): auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config.wrap_policy) - log_gpu_memory_usage('Before reward model FSDP', logger=None) + log_gpu_memory_usage("Before reward model FSDP", logger=None) fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") - setattr(reward_model_config, 'classifier_dropout', 0.) 
- setattr(reward_model_config, 'hidden_dropout', '0') - ref_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=copy_local_path_from_hdfs( - config.model.ref_path), - torch_dtype=torch_dtype, - config=reward_model_config, - attn_implementation='flash_attention_2', - trust_remote_code=trust_remote_code) + reward_model_config.classifier_dropout = 0.0 + reward_model_config.hidden_dropout = "0" + ref_module = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=copy_local_path_from_hdfs(config.model.ref_path), + torch_dtype=torch_dtype, + config=reward_model_config, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) # some parameters may not in torch_dtype ref_module.to(torch_dtype) - reward_module = FSDP(reward_module, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), - sharding_strategy=sharding_strategy, - mixed_precision=mixed_precision, - sync_module_states=True, - forward_prefetch=False, - device_mesh=self.device_mesh, - cpu_offload=None) + reward_module = FSDP( + reward_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + sync_module_states=True, + forward_prefetch=False, + device_mesh=self.device_mesh, + cpu_offload=None, + ) - log_gpu_memory_usage('After reward FSDP', logger=None) + log_gpu_memory_usage("After reward FSDP", logger=None) - ref_module = FSDP(ref_module, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), - sharding_strategy=sharding_strategy, - mixed_precision=mixed_precision, - sync_module_states=True, - forward_prefetch=False, - device_mesh=self.device_mesh, - cpu_offload=None) + ref_module = FSDP( + ref_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + sync_module_states=True, + forward_prefetch=False, + device_mesh=self.device_mesh, + cpu_offload=None, + ) - reward_optimizer = optim.AdamW(reward_module.parameters(), - lr=config.model.optim.lr, - betas=config.model.optim.get('betas', (0.9, 0.999)), - weight_decay=config.model.optim.get('weight_decay', 1e-2)) + reward_optimizer = optim.AdamW( + reward_module.parameters(), + lr=config.model.optim.lr, + betas=config.model.optim.get("betas", (0.9, 0.999)), + weight_decay=config.model.optim.get("weight_decay", 1e-2), + ) - total_steps = config.model.optim.get('total_training_steps', 0) - num_warmup_steps = int(config.model.optim.get('lr_warmup_steps', -1)) + total_steps = config.model.optim.get("total_training_steps", 0) + num_warmup_steps = int(config.model.optim.get("lr_warmup_steps", -1)) if num_warmup_steps < 0: - num_warmup_steps_ratio = config.model.optim.get('lr_warmup_steps_ratio', 0.) 
+ num_warmup_steps_ratio = config.model.optim.get("lr_warmup_steps_ratio", 0.0) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) - print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}') + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") from verl.utils.torch_functional import get_constant_schedule_with_warmup - reward_lr_scheduler = get_constant_schedule_with_warmup(optimizer=reward_optimizer, - num_warmup_steps=num_warmup_steps) + + reward_lr_scheduler = get_constant_schedule_with_warmup( + optimizer=reward_optimizer, num_warmup_steps=num_warmup_steps + ) return reward_module, ref_module, reward_optimizer, reward_lr_scheduler @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get('external_lib', None)) + import_external_libs(self.config.model.get("external_lib", None)) from .prime_dp_rm import DataParallelPRIMERewardModel - self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = self._build_reward_ref_model_optimizer( - config=self.config) + + self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = ( + self._build_reward_ref_model_optimizer(config=self.config) + ) if self._is_offload_param: offload_fsdp_model_to_cpu(self.reward_module) @@ -233,47 +250,51 @@ class PRIMERewardModelWorker(Worker): if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.reward_optimizer) - self.rm = DataParallelPRIMERewardModel(config=self.config, - reward_module=self.reward_module, - ref_module=self.ref_module, - reward_optimizer=self.reward_optimizer) + self.rm = DataParallelPRIMERewardModel( + config=self.config, + reward_module=self.reward_module, + ref_module=self.ref_module, + reward_optimizer=self.reward_optimizer, + ) self.flops_counter = FlopsCounter(self.reward_model_config) - self.checkpoint_manager = FSDPCheckpointManager(model=self.reward_module, - optimizer=self.reward_optimizer, - lr_scheduler=self.reward_lr_scheduler, - tokenizer=self.tokenizer) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.reward_module, + optimizer=self.reward_optimizer, + lr_scheduler=self.reward_lr_scheduler, + tokenizer=self.tokenizer, + ) @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_rm_score(self, data: DataProto): - data = data.to('cuda') + data = data.to("cuda") if self._is_offload_param: load_fsdp_model_to_gpu(self.reward_module) load_fsdp_model_to_gpu(self.ref_module) micro_batch_size = self.config.micro_batch_size_per_gpu - data.meta_info['micro_batch_size'] = micro_batch_size - data.meta_info['max_token_len'] = self.config.forward_max_token_len_per_gpu - data.meta_info['use_dynamic_bsz'] = self.config.use_dynamic_bsz + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz # perform forward computation with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) rm_scores, q, metrics = self.rm.compute_rm_score(data=data) - prompt_length = data.batch['prompts'].shape[-1] - response_mask = data.batch['attention_mask'][:, prompt_length:] - acc = data.batch['acc'] + prompt_length = data.batch["prompts"].shape[-1] + response_mask = data.batch["attention_mask"][:, prompt_length:] + acc = data.batch["acc"] - dpo_acc = compute_dpo_accuracy(rm_scores, acc, 
response_mask=response_mask, n_samples=data.meta_info['n']) - dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info['n']) + dpo_acc = compute_dpo_accuracy(rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info["n"]) + dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info["n"]) - metrics['reward_model/dpo_acc'] = dpo_acc.detach().item() - metrics['reward_model/dpo_acc_abs'] = dpo_acc_abs.detach().item() + metrics["reward_model/dpo_acc"] = dpo_acc.detach().item() + metrics["reward_model/dpo_acc_abs"] = dpo_acc_abs.detach().item() - output = DataProto.from_dict(tensors={'rm_scores': rm_scores, 'q': q}, meta_info={'metrics': metrics}) + output = DataProto.from_dict(tensors={"rm_scores": rm_scores, "q": q}, meta_info={"metrics": metrics}) output = self.ulysses_sharding_manager.postprocess_data(data=output) - output = output.to('cpu') + output = output.to("cpu") if self._is_offload_param: offload_fsdp_model_to_cpu(self.reward_module) offload_fsdp_model_to_cpu(self.ref_module) @@ -281,7 +302,7 @@ class PRIMERewardModelWorker(Worker): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_rm(self, data: DataProto): - data = data.to('cuda') + data = data.to("cuda") if self._is_offload_param: load_fsdp_model_to_gpu(self.ref_module) load_fsdp_model_to_gpu(self.reward_module) @@ -296,22 +317,21 @@ class PRIMERewardModelWorker(Worker): self.reward_lr_scheduler.step() lr = self.reward_lr_scheduler.get_last_lr()[0] - metrics['rm/lr'] = lr + metrics["rm/lr"] = lr - prompt_length = data.batch['prompts'].shape[-1] - response_mask = data.batch['attention_mask'][:, prompt_length:] - acc = data.batch['acc'] + prompt_length = data.batch["prompts"].shape[-1] + response_mask = data.batch["attention_mask"][:, prompt_length:] + acc = data.batch["acc"] - dpo_acc_before = compute_dpo_accuracy(rm_scores, - acc, - response_mask=response_mask, - n_samples=data.meta_info['n']) - dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info['n']) + dpo_acc_before = compute_dpo_accuracy( + rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info["n"] + ) + dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info["n"]) - metrics['reward_model/dpo_acc_before'] = dpo_acc_before.detach().item() - metrics['reward_model/dpo_acc_abs_before'] = dpo_acc_abs.detach().item() + metrics["reward_model/dpo_acc_before"] = dpo_acc_before.detach().item() + metrics["reward_model/dpo_acc_abs_before"] = dpo_acc_abs.detach().item() - output = DataProto.from_dict(tensors={'rm_scores': rm_scores}, meta_info={'metrics': metrics}) + output = DataProto.from_dict(tensors={"rm_scores": rm_scores}, meta_info={"metrics": metrics}) output = self.ulysses_sharding_manager.postprocess_data(data=output) if self._is_offload_param: @@ -319,19 +339,19 @@ class PRIMERewardModelWorker(Worker): offload_fsdp_model_to_cpu(self.ref_module) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.reward_optimizer) - output = output.to('cpu') + output = output.to("cpu") return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): import torch + if self._is_offload_param: load_fsdp_model_to_gpu(self.reward_module) - self.checkpoint_manager.save_checkpoint(local_path=local_path, - hdfs_path=hdfs_path, - global_step=global_step, - max_ckpt_to_keep=max_ckpt_to_keep) + 
self.checkpoint_manager.save_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + ) torch.distributed.barrier() if self._is_offload_param: @@ -340,6 +360,7 @@ class PRIMERewardModelWorker(Worker): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, local_path, del_local_after_load=True): import torch + if self._is_offload_param: load_fsdp_model_to_gpu(self.reward_module) diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py index 7b1d40458..a21416e19 100644 --- a/recipe/prime/prime_ray_trainer.py +++ b/recipe/prime/prime_ray_trainer.py @@ -28,129 +28,115 @@ from omegaconf import OmegaConf, open_dict from verl import DataProto from verl.single_controller.ray import RayWorkerGroup -from verl.trainer.ppo.ray_trainer import RayPPOTrainer -from verl.trainer.ppo.ray_trainer import Role, WorkerType, ResourcePoolManager, reduce_metrics, _timer -from verl.trainer.ppo.metric_utils import _compute_response_info from verl.trainer.ppo.core_algos import agg_loss -from verl.utils.py_functional import append_to_dict +from verl.trainer.ppo.metric_utils import _compute_response_info +from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType, _timer, reduce_metrics from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn + from . import prime_core_algos def compute_advantage(data: DataProto, adv_estimator, config): - if adv_estimator == 'rloo': - responses = data.batch['responses'] + if adv_estimator == "rloo": + responses = data.batch["responses"] response_length = responses.size(-1) - attention_mask = data.batch['attention_mask'] + attention_mask = data.batch["attention_mask"] response_mask = attention_mask[:, -response_length:] - advantages, returns = prime_core_algos.compute_rloo_advantage_return(data, response_mask, - config.actor_rollout_ref.rollout.n, config) - data.batch['advantages'] = advantages - data.batch['returns'] = returns + advantages, returns = prime_core_algos.compute_rloo_advantage_return( + data, response_mask, config.actor_rollout_ref.rollout.n, config + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns else: raise NotImplementedError return data def compute_data_metrics(batch, use_critic=True): + advantages = batch.batch["advantages"] + returns = batch.batch["returns"] - advantages = batch.batch['advantages'] - returns = batch.batch['returns'] + max_response_length = batch.batch["responses"].shape[-1] - max_response_length = batch.batch['responses'].shape[-1] - - prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool() - response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool() + prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() + response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool() max_prompt_length = prompt_mask.size(-1) response_info = _compute_response_info(batch) - prompt_length = response_info['prompt_length'] - response_length = response_info['response_length'] + prompt_length = response_info["prompt_length"] + response_length = response_info["response_length"] valid_adv = torch.masked_select(advantages, response_mask) valid_returns = torch.masked_select(returns, response_mask) if use_critic: - values = batch.batch['values'] + values = batch.batch["values"] valid_values = torch.masked_select(values, response_mask) return_diff_var = 
torch.var(valid_returns - valid_values) return_var = torch.var(valid_returns) metrics = { # adv - 'critic/advantages/mean': - torch.mean(valid_adv).detach().item(), - 'critic/advantages/max': - torch.max(valid_adv).detach().item(), - 'critic/advantages/min': - torch.min(valid_adv).detach().item(), + "critic/advantages/mean": torch.mean(valid_adv).detach().item(), + "critic/advantages/max": torch.max(valid_adv).detach().item(), + "critic/advantages/min": torch.min(valid_adv).detach().item(), # returns - 'critic/returns/mean': - torch.mean(valid_returns).detach().item(), - 'critic/returns/max': - torch.max(valid_returns).detach().item(), - 'critic/returns/min': - torch.min(valid_returns).detach().item(), - **({ - # values - 'critic/values/mean': torch.mean(valid_values).detach().item(), - 'critic/values/max': torch.max(valid_values).detach().item(), - 'critic/values/min': torch.min(valid_values).detach().item(), - # vf explained var - 'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), - } if use_critic else {}), - + "critic/returns/mean": torch.mean(valid_returns).detach().item(), + "critic/returns/max": torch.max(valid_returns).detach().item(), + "critic/returns/min": torch.min(valid_returns).detach().item(), + **( + { + # values + "critic/values/mean": torch.mean(valid_values).detach().item(), + "critic/values/max": torch.max(valid_values).detach().item(), + "critic/values/min": torch.min(valid_values).detach().item(), + # vf explained var + "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + } + if use_critic + else {} + ), # response length - 'response_length/mean': - torch.mean(response_length).detach().item(), - 'response_length/max': - torch.max(response_length).detach().item(), - 'response_length/min': - torch.min(response_length).detach().item(), - 'response_length/clip_ratio': - torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(), + "response_length/mean": torch.mean(response_length).detach().item(), + "response_length/max": torch.max(response_length).detach().item(), + "response_length/min": torch.min(response_length).detach().item(), + "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) + .detach() + .item(), # prompt length - 'prompt_length/mean': - torch.mean(prompt_length).detach().item(), - 'prompt_length/max': - torch.max(prompt_length).detach().item(), - 'prompt_length/min': - torch.min(prompt_length).detach().item(), - 'prompt_length/clip_ratio': - torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + "prompt_length/mean": torch.mean(prompt_length).detach().item(), + "prompt_length/max": torch.max(prompt_length).detach().item(), + "prompt_length/min": torch.min(prompt_length).detach().item(), + "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), } return metrics def compute_response_mask(data: DataProto): - responses = data.batch['responses'] + responses = data.batch["responses"] response_length = responses.size(1) - attention_mask = data.batch['attention_mask'] + attention_mask = data.batch["attention_mask"] return attention_mask[:, -response_length:] def compute_timing_metrics(batch, timing_raw): response_info = _compute_response_info(batch) - num_prompt_tokens = torch.sum(response_info['prompt_length']).item() - num_response_tokens = torch.sum(response_info['response_length']).item() + num_prompt_tokens = 
torch.sum(response_info["prompt_length"]).item() + num_response_tokens = torch.sum(response_info["response_length"]).item() num_overall_tokens = num_prompt_tokens + num_response_tokens num_tokens_of_section = { - 'gen': num_response_tokens, - **{ - name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor'] - }, + "gen": num_response_tokens, + **{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]}, } return { + **{f"timing_s/{name}": value for name, value in timing_raw.items()}, **{ - f'timing_s/{name}': value for name, value in timing_raw.items() - }, - **{ - f'timing_per_token_ms/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys( - )) & set(timing_raw.keys()) + f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] + for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()) }, } @@ -162,19 +148,27 @@ class RayPRIMETrainer(RayPPOTrainer): # TODO: support each role have individual ray_worker_group_cls, # i.e., support different backend of different role - def __init__(self, - config, - tokenizer, - role_worker_mapping: dict[Role, WorkerType], - resource_pool_manager: ResourcePoolManager, - ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, - reward_fn=None, - val_reward_fn=None): - + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + reward_fn=None, + val_reward_fn=None, + ): # assert torch.cuda.is_available(), 'cuda must be available on driver' - super().__init__(config, tokenizer, role_worker_mapping, resource_pool_manager, ray_worker_group_cls, reward_fn, - val_reward_fn) + super().__init__( + config, + tokenizer, + role_worker_mapping, + resource_pool_manager, + ray_worker_group_cls, + reward_fn, + val_reward_fn, + ) self.use_critic = False @@ -185,39 +179,43 @@ class RayPRIMETrainer(RayPPOTrainer): def _create_dataloader(self): from torch.utils.data import DataLoader, RandomSampler, SequentialSampler + # TODO: we have to make sure the batch size is divisible by the dp size - self.train_dataset = RLHFDataset(data_files=self.config.data.train_files, - tokenizer=self.tokenizer, - config=self.config.data) + self.train_dataset = RLHFDataset( + data_files=self.config.data.train_files, tokenizer=self.tokenizer, config=self.config.data + ) # use sampler for better ckpt resume if self.config.data.shuffle: train_dataloader_generator = torch.Generator() - train_dataloader_generator.manual_seed(self.config.data.get('seed', 1)) + train_dataloader_generator.manual_seed(self.config.data.get("seed", 1)) sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator) else: sampler = SequentialSampler(data_source=self.train_dataset) - self.train_dataloader = DataLoader(dataset=self.train_dataset, - batch_size=int(self.config.data.train_batch_size * - self.config.data.oversample_factor), - drop_last=True, - collate_fn=collate_fn, - sampler=sampler) + self.train_dataloader = DataLoader( + dataset=self.train_dataset, + batch_size=int(self.config.data.train_batch_size * self.config.data.oversample_factor), + drop_last=True, + collate_fn=collate_fn, + sampler=sampler, + ) - self.val_dataset = RLHFDataset(data_files=self.config.data.val_files, - tokenizer=self.tokenizer, - config=self.config.data) - self.val_dataloader = DataLoader(dataset=self.val_dataset, 
- batch_size=len(self.val_dataset), - shuffle=True, - drop_last=True, - collate_fn=collate_fn) + self.val_dataset = RLHFDataset( + data_files=self.config.data.val_files, tokenizer=self.tokenizer, config=self.config.data + ) + self.val_dataloader = DataLoader( + dataset=self.val_dataset, + batch_size=len(self.val_dataset), + shuffle=True, + drop_last=True, + collate_fn=collate_fn, + ) assert len(self.train_dataloader) >= 1 assert len(self.val_dataloader) >= 1 - print(f'Size of train dataloader: {len(self.train_dataloader)}') - print(f'Size of val dataloader: {len(self.val_dataloader)}') + print(f"Size of train dataloader: {len(self.train_dataloader)}") + print(f"Size of val dataloader: {len(self.val_dataloader)}") # inject total_training_steps to actor/critic optim_config. This is hacky. total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs @@ -226,7 +224,7 @@ class RayPRIMETrainer(RayPPOTrainer): total_training_steps = self.config.trainer.total_training_steps self.total_training_steps = total_training_steps - print(f'Total training steps: {self.total_training_steps}') + print(f"Total training steps: {self.total_training_steps}") OmegaConf.set_struct(self.config, True) with open_dict(self.config): @@ -235,45 +233,58 @@ class RayPRIMETrainer(RayPPOTrainer): def _save_checkpoint(self): # path: given_path + `/global_step_{global_steps}` + `/actor` - local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, - f'global_step_{self.global_steps}') - print(f'local_global_step_folder: {local_global_step_folder}') - actor_local_path = os.path.join(local_global_step_folder, 'actor') + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" + ) + print(f"local_global_step_folder: {local_global_step_folder}") + actor_local_path = os.path.join(local_global_step_folder, "actor") - actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( - self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor') - self.actor_rollout_wg.save_checkpoint(actor_local_path, - actor_remote_path, - self.global_steps, - remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save) + actor_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + ) + self.actor_rollout_wg.save_checkpoint( + actor_local_path, + actor_remote_path, + self.global_steps, + remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save, + ) if self.use_rm: - reward_local_path = os.path.join(local_global_step_folder, 'reward') - reward_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( - self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'reward') - self.rm_wg.save_checkpoint(reward_local_path, - reward_remote_path, - self.global_steps, - remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save) + reward_local_path = os.path.join(local_global_step_folder, "reward") + reward_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "reward") + ) + self.rm_wg.save_checkpoint( + reward_local_path, + reward_remote_path, + self.global_steps, + remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save, + ) # save dataloader - dataloader_local_path = 
os.path.join(local_global_step_folder, 'data.pt') + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") import dill + torch.save(self.train_dataloader, dataloader_local_path, pickle_module=dill) # latest checkpointed iteration tracker (for atomic usage) - local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir, - 'latest_checkpointed_iteration.txt') - with open(local_latest_checkpointed_iteration, 'w') as f: + local_latest_checkpointed_iteration = os.path.join( + self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" + ) + with open(local_latest_checkpointed_iteration, "w") as f: f.write(str(self.global_steps)) def _load_checkpoint(self): - if self.config.trainer.resume_mode == 'disable': + if self.config.trainer.resume_mode == "disable": return 0 # load from hdfs if self.config.trainer.default_hdfs_dir is not None: - NotImplementedError('load from hdfs is not implemented yet') + NotImplementedError("load from hdfs is not implemented yet") else: checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path if not os.path.isabs(checkpoint_folder): @@ -282,37 +293,40 @@ class RayPRIMETrainer(RayPPOTrainer): global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest # find global_step_folder - if self.config.trainer.resume_mode == 'auto': + if self.config.trainer.resume_mode == "auto": if global_step_folder is None: - print('Training from scratch') + print("Training from scratch") return 0 else: if self.config.trainer.resume_mode == "resume_path": assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" - assert 'global_step_' in self.config.trainer.resume_from_path, "resume ckpt must specify the global_steps" + assert "global_step_" in self.config.trainer.resume_from_path, ( + "resume ckpt must specify the global_steps" + ) global_step_folder = self.config.trainer.resume_from_path if not os.path.isabs(global_step_folder): working_dir = os.getcwd() global_step_folder = os.path.join(working_dir, global_step_folder) - print(f'Load from checkpoint folder: {global_step_folder}') + print(f"Load from checkpoint folder: {global_step_folder}") # set global step - self.global_steps = int(global_step_folder.split('global_step_')[-1]) + self.global_steps = int(global_step_folder.split("global_step_")[-1]) - print(f'Setting global step to {self.global_steps}') - print(f'Resuming from {global_step_folder}') + print(f"Setting global step to {self.global_steps}") + print(f"Resuming from {global_step_folder}") - actor_path = os.path.join(global_step_folder, 'actor') - reward_path = os.path.join(global_step_folder, 'reward') + actor_path = os.path.join(global_step_folder, "actor") + reward_path = os.path.join(global_step_folder, "reward") # load actor - self.actor_rollout_wg.load_checkpoint(actor_path, - del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + self.actor_rollout_wg.load_checkpoint( + actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) # load rm if self.use_rm: self.rm_wg.load_checkpoint(reward_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load) # load dataloader, # TODO: from remote not implemented yet - dataloader_local_path = os.path.join(global_step_folder, 'data.pt') + dataloader_local_path = os.path.join(global_step_folder, "data.pt") self.train_dataloader = torch.load(dataloader_local_path) if isinstance(self.train_dataloader.dataset, RLHFDataset): 
self.train_dataloader.dataset.resume_dataset_state() @@ -323,13 +337,16 @@ class RayPRIMETrainer(RayPPOTrainer): The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. The light-weight advantage computation is done on the driver process. """ - from verl.utils.tracking import Tracking from omegaconf import OmegaConf - logger = Tracking(project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True)) + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) self.global_steps = 0 @@ -338,11 +355,11 @@ class RayPRIMETrainer(RayPPOTrainer): # perform validation before training # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True): + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): val_metrics = self._validate() - pprint(f'Initial validation metrics: {val_metrics}') + pprint(f"Initial validation metrics: {val_metrics}") logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get('val_only', False): + if self.config.trainer.get("val_only", False): return # we start from step 1 @@ -356,17 +373,17 @@ class RayPRIMETrainer(RayPPOTrainer): batch: DataProto = DataProto.from_single_dict(batch_dict) # pop those keys for generation - gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']) + gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"]) - with _timer('step', timing_raw): + with _timer("step", timing_raw): # generate a batch - with _timer('gen', timing_raw): + with _timer("gen", timing_raw): gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - if self.config.algorithm.adv_estimator == 'remax': - with _timer('gen_max', timing_raw): + if self.config.algorithm.adv_estimator == "remax": + with _timer("gen_max", timing_raw): gen_baseline_batch = deepcopy(gen_batch) - gen_baseline_batch.meta_info['do_sample'] = False + gen_baseline_batch.meta_info["do_sample"] = False gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) batch = batch.union(gen_baseline_output) @@ -375,12 +392,13 @@ class RayPRIMETrainer(RayPPOTrainer): batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - batch.batch['reward_baselines'] = reward_baseline_tensor + batch.batch["reward_baselines"] = reward_baseline_tensor del gen_baseline_batch, gen_baseline_output - batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], - dtype=object) + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) # repeat to align with repeated responses in rollout batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) @@ -391,96 +409,105 @@ class RayPRIMETrainer(RayPPOTrainer): # self._balance_batch(batch, metrics=metrics) # compute global_valid tokens - batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() + batch.meta_info["global_token_num"] = 
torch.sum(batch.batch["attention_mask"], dim=-1).tolist() # verify - with _timer('verify', timing_raw): + with _timer("verify", timing_raw): scores = self.reward_fn.verify(batch) - metrics['acc'] = statistics.mean(scores) + metrics["acc"] = statistics.mean(scores) # filter the batch. 1/oversample_factor samples will be kept. If there is a filter, prompts passing it will be prioritized. batch = self.filter_and_downsample(scores, batch) - batch.meta_info['n'] = self.config.actor_rollout_ref.rollout.n + batch.meta_info["n"] = self.config.actor_rollout_ref.rollout.n n_samples = self.config.actor_rollout_ref.rollout.n # recompute old_log_probs - with _timer('old_log_prob', timing_raw): + with _timer("old_log_prob", timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - entropys = old_log_prob.batch['entropys'] + entropys = old_log_prob.batch["entropys"] response_masks = compute_response_mask(batch) loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_loss = agg_loss(loss_mat=entropys, - loss_mask=response_masks, - loss_agg_mode=loss_agg_mode) + entropy_loss = agg_loss( + loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode + ) old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()} metrics.update(old_log_prob_metrics) - old_log_prob.batch.pop('entropys') + old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) if self.use_reference_policy: # compute reference log_prob - with _timer('ref', timing_raw): + with _timer("ref", timing_raw): ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) - with _timer('adv', timing_raw): - + with _timer("adv", timing_raw): if self.use_rm: - update_style = self.config.reward_model.model.get('update', 'none') - if update_style == 'none': # only run forward + update_style = self.config.reward_model.model.get("update", "none") + if update_style == "none": # only run forward reward_output = self.rm_wg.compute_rm_score(batch) - elif update_style == 'after': # update and directly return the reward + elif update_style == "after": # update and directly return the reward reward_output = self.rm_wg.update_rm(batch) - elif update_style == 'before': # update reward model, and then run forward + elif update_style == "before": # update reward model, and then run forward reward_output = self.rm_wg.update_rm(batch) - if 'metrics' in reward_output.meta_info.keys(): - reward_output_metrics = reduce_metrics(reward_output.meta_info['metrics']) + if "metrics" in reward_output.meta_info.keys(): + reward_output_metrics = reduce_metrics(reward_output.meta_info["metrics"]) metrics.update(reward_output_metrics) reward_output = self.rm_wg.compute_rm_score(batch) - elif update_style == 'reverse': # run forward to calculate statistics, then update reward model + elif ( + update_style == "reverse" + ): # run forward to calculate statistics, then update reward model reward_output = self.rm_wg.compute_rm_score(batch) # broadcast q and acc tensor to each result bc_td = DataProto.from_dict( tensors={ - 'Q_bc': - reward_output.batch['q'].sum(dim=-1).view(-1, n_samples).unsqueeze( - 1).expand(-1, n_samples, -1).reshape(-1, n_samples), - 'acc_bc': - batch.batch['acc'].view(-1, n_samples).unsqueeze(1).expand( - -1, n_samples, -1).reshape(-1, n_samples) - }) + "Q_bc": reward_output.batch["q"] + .sum(dim=-1) + .view(-1, n_samples) + .unsqueeze(1) + .expand(-1, n_samples, -1) + .reshape(-1, n_samples), + "acc_bc": batch.batch["acc"] + .view(-1, n_samples) + 
.unsqueeze(1) + .expand(-1, n_samples, -1) + .reshape(-1, n_samples), + } + ) batch = batch.union(bc_td) reward_output = self.rm_wg.update_rm(batch) else: raise NotImplementedError batch = batch.union(reward_output) - if 'metrics' in reward_output.meta_info.keys(): - reward_output_metrics = reduce_metrics(reward_output.meta_info['metrics']) + if "metrics" in reward_output.meta_info.keys(): + reward_output_metrics = reduce_metrics(reward_output.meta_info["metrics"]) metrics.update(reward_output_metrics) # compute advantages, executed on the driver process - batch = compute_advantage(batch, - adv_estimator=self.config.algorithm.adv_estimator, - config=self.config) + batch = compute_advantage( + batch, adv_estimator=self.config.algorithm.adv_estimator, config=self.config + ) # update actor - with _timer('update_actor', timing_raw): + with _timer("update_actor", timing_raw): actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) # validate - if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \ - self.global_steps % self.config.trainer.test_freq == 0: - with _timer('testing', timing_raw): + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and self.global_steps % self.config.trainer.test_freq == 0 + ): + with _timer("testing", timing_raw): val_metrics: dict = self._validate() metrics.update(val_metrics) - if self.config.trainer.save_freq > 0 and \ - self.global_steps % self.config.trainer.save_freq == 0: - with _timer('save_checkpoint', timing_raw): + if self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0: + with _timer("save_checkpoint", timing_raw): self._save_checkpoint() # collect metrics @@ -493,15 +520,16 @@ class RayPRIMETrainer(RayPPOTrainer): self.global_steps += 1 if self.global_steps >= self.total_training_steps: - # perform validation after training if self.val_reward_fn is not None: val_metrics = self._validate() - pprint(f'Final validation metrics: {val_metrics}') + pprint(f"Final validation metrics: {val_metrics}") logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.save_freq > 0 and \ - (self.global_steps - 1) % self.config.trainer.save_freq != 0: - with _timer('save_checkpoint', timing_raw): + if ( + self.config.trainer.save_freq > 0 + and (self.global_steps - 1) % self.config.trainer.save_freq != 0 + ): + with _timer("save_checkpoint", timing_raw): self._save_checkpoint() return @@ -517,18 +545,24 @@ class RayPRIMETrainer(RayPPOTrainer): if self.config.data.filter_accuracy: acc_tensor = torch.mean(reward_matrix, dim=-1) - filter_mask[(acc_tensor > self.config.data.accuracy_upper_bound) | - (acc_tensor < self.config.data.accuracy_lower_bound)] = False + filter_mask[ + (acc_tensor > self.config.data.accuracy_upper_bound) + | (acc_tensor < self.config.data.accuracy_lower_bound) + ] = False if self.config.data.filter_truncate: - length_matrix = batch.batch['attention_mask'][:, -batch.batch['responses'].shape[-1]:].sum(dim=-1).reshape( - -1, n_samples) + length_matrix = ( + batch.batch["attention_mask"][:, -batch.batch["responses"].shape[-1] :] + .sum(dim=-1) + .reshape(-1, n_samples) + ) length_tensor = torch.max(length_matrix, dim=-1)[0] filter_mask[length_tensor >= self.config.data.max_response_length - 1] = False reorder_index = torch.argsort(filter_mask, descending=True) 
reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples).unsqueeze(0)).view(-1) - batch.reorder(reorder_index[:int(len(batch) // - self.config.data.oversample_factor)]) # this operation is inplace + batch.reorder( + reorder_index[: int(len(batch) // self.config.data.oversample_factor)] + ) # this operation is inplace return batch diff --git a/recipe/r1/data_process.py b/recipe/r1/data_process.py index 93107865c..19ffd0104 100644 --- a/recipe/r1/data_process.py +++ b/recipe/r1/data_process.py @@ -15,54 +15,44 @@ Preprocess the dataset to parquet format """ +import argparse import os -from datasets import load_dataset, concatenate_datasets from functools import partial +from datasets import concatenate_datasets, load_dataset + from verl.utils.hdfs_io import copy, makedirs -import argparse def example_map_fn(example, idx, process_fn, data_source, ability, split): question, solution = process_fn(example) data = { "data_source": data_source, - "prompt": [{ - "role": "user", - "content": question - }], + "prompt": [{"role": "user", "content": question}], "ability": ability, - "reward_model": { - "style": "rule", - "ground_truth": solution - }, - "extra_info": { - 'split': split, - 'index': idx - } + "reward_model": {"style": "rule", "ground_truth": solution}, + "extra_info": {"split": split, "index": idx}, } return data def build_aime2024_dataset(): - def process_aime2024(example): return example["Problem"], str(example["Answer"]) - data_source = 'Maxwell-Jia/AIME_2024' + data_source = "Maxwell-Jia/AIME_2024" print(f"Loading the {data_source} dataset from huggingface...", flush=True) dataset = load_dataset(data_source, split="train") - map_fn = partial(example_map_fn, - process_fn=process_aime2024, - data_source=data_source, - ability="English", - split="test") + map_fn = partial( + example_map_fn, process_fn=process_aime2024, data_source=data_source, ability="English", split="test" + ) dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names) return dataset def build_gpqa_dimond_dataset(): import random + GPQA_QUERY_TEMPLATE = "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. 
Think step by step before answering.\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}" def process_gpqa_diamond(example): @@ -70,49 +60,40 @@ def build_gpqa_dimond_dataset(): random.shuffle(choices) gold_index = random.randint(0, 3) choices.insert(gold_index, example["Correct Answer"]) - query_prompt = GPQA_QUERY_TEMPLATE.format(A=choices[0], - B=choices[1], - C=choices[2], - D=choices[3], - Question=example["Question"]) + query_prompt = GPQA_QUERY_TEMPLATE.format( + A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=example["Question"] + ) gold_choice = "ABCD"[gold_index] return query_prompt, gold_choice - data_source = 'Idavidrein/gpqa' + data_source = "Idavidrein/gpqa" print(f"Loading the {data_source} dataset from huggingface...", flush=True) dataset = load_dataset(data_source, "gpqa_diamond", split="train") - map_fn = partial(example_map_fn, - process_fn=process_gpqa_diamond, - data_source=data_source, - ability="Math", - split="test") + map_fn = partial( + example_map_fn, process_fn=process_gpqa_diamond, data_source=data_source, ability="Math", split="test" + ) dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names) return dataset def build_cnmo2024_dataset(): - def process_cnmo2024(example): return example["question"], example["answer"] - data_source = 'opencompass/LiveMathBench' + data_source = "opencompass/LiveMathBench" print(f"Loading the {data_source} dataset from huggingface...", flush=True) dataset_en = load_dataset(data_source, "v202412_CNMO_en", split="test") - map_fn_en = partial(example_map_fn, - process_fn=process_cnmo2024, - data_source='opencompass/cnmo2024_en', - ability="Math", - split="test") + map_fn_en = partial( + example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_en", ability="Math", split="test" + ) dataset_en = dataset_en.map(map_fn_en, with_indices=True, remove_columns=dataset_en.column_names) dataset_zh = load_dataset(data_source, "v202412_CNMO_cn", split="test") - map_fn_zh = partial(example_map_fn, - process_fn=process_cnmo2024, - data_source='opencompass/cnmo2024_zh', - ability="Math", - split="test") + map_fn_zh = partial( + example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_zh", ability="Math", split="test" + ) dataset_zh = dataset_zh.map(map_fn_zh, with_indices=True, remove_columns=dataset_zh.column_names) dataset = concatenate_datasets([dataset_en, dataset_zh]) @@ -120,22 +101,28 @@ def build_cnmo2024_dataset(): def build_livecodebench_dataset(): - import json, pickle, zlib, base64 + import base64 + import json + import pickle + import zlib def process_livecodebench(example): # Construct Query Prompt # From https://github.com/LiveCodeBench/LiveCodeBench/blob/998c52d394b836f15fff3b9a29866191108ff81b/lcb_runner/prompts/code_generation.py#L140 query_prompt = ( "You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests.\n\n" - f"Question: {example['question_content']}\n\n") + f"Question: {example['question_content']}\n\n" + ) if example["starter_code"]: query_prompt += ( "You will use the following starter code to write the solution to the problem and enclose your code within delimiters.\n" - f"```python\n{example['starter_code']}\n```") + f"```python\n{example['starter_code']}\n```" + ) else: query_prompt += ( "Read the inputs from stdin solve the problem and write the answer to stdout (do not directly test on the sample inputs). 
Enclose your code within delimiters as follows. Ensure that when the python program runs, it reads the inputs, runs the algorithm and writes output to STDOUT." - f"```python\n# YOUR CODE HERE\n```") + "```python\n# YOUR CODE HERE\n```" + ) # Construct test cases public_test_cases = json.loads(example["public_test_cases"]) @@ -143,7 +130,8 @@ def build_livecodebench_dataset(): private_test_cases = json.loads(example["private_test_cases"]) except: private_test_cases = json.loads( - pickle.loads(zlib.decompress(base64.b64decode(example["private_test_cases"].encode("utf-8"))))) + pickle.loads(zlib.decompress(base64.b64decode(example["private_test_cases"].encode("utf-8")))) + ) full_test_cases = public_test_cases + private_test_cases metadata = json.loads(example["metadata"]) @@ -155,16 +143,14 @@ def build_livecodebench_dataset(): text_cases_compressed = base64.b64encode(zlib.compress(pickle.dumps(json.dumps(test_cases)))).decode("utf-8") return query_prompt, text_cases_compressed - data_source = 'livecodebench/code_generation_lite' + data_source = "livecodebench/code_generation_lite" print(f"Loading the {data_source} dataset from huggingface...", flush=True) dataset = load_dataset(data_source, split="test") # R1 Evaluation use LiveCodeBench 24.08-25.01 dataset = dataset.filter(lambda line: "2024-08-00T00:00:00" <= line["contest_date"] < "2025-01-00T00:00:00") - map_fn = partial(example_map_fn, - process_fn=process_livecodebench, - data_source=data_source, - ability="Code", - split="test") + map_fn = partial( + example_map_fn, process_fn=process_livecodebench, data_source=data_source, ability="Code", split="test" + ) dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names, num_proc=8) return dataset @@ -178,18 +164,18 @@ TASK2DATA = { } SUPPORTED_TASKS = TASK2DATA.keys() -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--local_dir', default='~/data/r1') - parser.add_argument('--hdfs_dir', default=None) - parser.add_argument('--tasks', default="all") + parser.add_argument("--local_dir", default="~/data/r1") + parser.add_argument("--hdfs_dir", default=None) + parser.add_argument("--tasks", default="all") args = parser.parse_args() if args.tasks.lower() == "all": args.tasks = SUPPORTED_TASKS else: - args.tasks = [task.strip() for task in args.tasks.split(',') if task.strip()] + args.tasks = [task.strip() for task in args.tasks.split(",") if task.strip()] for task in args.tasks: if task not in SUPPORTED_TASKS: raise NotImplementedError(f"{task} has not been supported.") @@ -202,7 +188,7 @@ if __name__ == '__main__': local_dir = args.local_dir hdfs_dir = args.hdfs_dir - test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet')) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) if hdfs_dir is not None: makedirs(hdfs_dir) diff --git a/recipe/r1/main_eval.py b/recipe/r1/main_eval.py index 40e871cf8..5766a46f4 100644 --- a/recipe/r1/main_eval.py +++ b/recipe/r1/main_eval.py @@ -17,17 +17,20 @@ The input is a parquet file that contains N generated sequences and (optional) t """ -import hydra -from verl.utils.fs import copy_to_local -import pandas as pd -import numpy as np -from tqdm import tqdm from collections import defaultdict + +import hydra +import numpy as np +import pandas as pd import ray +from tqdm import tqdm + +from verl.utils.fs import copy_to_local def get_custom_reward_fn(config): - import importlib.util, os + import importlib.util + import os reward_fn_config = 
config.get("custom_reward_function") or {} file_path = reward_fn_config.get("path") @@ -56,12 +59,12 @@ def get_custom_reward_fn(config): @ray.remote def process_item(reward_fn, data_source, response_lst, reward_data): - ground_truth = reward_data['ground_truth'] + ground_truth = reward_data["ground_truth"] score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst] return data_source, np.mean(score_lst) -@hydra.main(config_path='config', config_name='evaluation', version_base=None) +@hydra.main(config_path="config", config_name="evaluation", version_base=None) def main(config): local_path = copy_to_local(config.data.path) dataset = pd.read_parquet(local_path) @@ -97,10 +100,10 @@ def main(config): metric_dict = {} for data_source, rewards in data_source_reward.items(): - metric_dict[f'test_score/{data_source}'] = np.mean(rewards) + metric_dict[f"test_score/{data_source}"] = np.mean(rewards) print(metric_dict) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/recipe/r1/reward_score.py b/recipe/r1/reward_score.py index 2fe2e28e1..2010665aa 100644 --- a/recipe/r1/reward_score.py +++ b/recipe/r1/reward_score.py @@ -14,14 +14,17 @@ def reward_func(data_source, solution_str, ground_truth, extra_info=None): - if data_source in ['Maxwell-Jia/AIME_2024', "opencompass/cnmo2024_en", "opencompass/cnmo2024_zh"]: + if data_source in ["Maxwell-Jia/AIME_2024", "opencompass/cnmo2024_en", "opencompass/cnmo2024_zh"]: from recipe.r1.tasks import math + return math.compute_score(solution_str, ground_truth) - elif data_source == 'Idavidrein/gpqa': + elif data_source == "Idavidrein/gpqa": from recipe.r1.tasks import gpqa + return gpqa.compute_score(solution_str, ground_truth) - elif data_source in ['livecodebench/code_generation_lite', 'livecodebench/code_generation']: + elif data_source in ["livecodebench/code_generation_lite", "livecodebench/code_generation"]: from recipe.r1.tasks import livecodebench + return livecodebench.compute_score(solution_str, ground_truth) else: raise NotImplementedError diff --git a/recipe/r1/tasks/livecodebench.py b/recipe/r1/tasks/livecodebench.py index 38c26e96f..22fc4965a 100644 --- a/recipe/r1/tasks/livecodebench.py +++ b/recipe/r1/tasks/livecodebench.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import multiprocessing +import base64 import json +import multiprocessing import pickle import zlib -import base64 # Reuse `run_test` for convenience from verl.utils.reward_score.prime_code.testing_util import run_test @@ -48,12 +48,12 @@ def check_correctness(in_outs, generation, timeout, debug=True): # consider that all tests failed result = [[-1 for i in range(len(in_outs["inputs"]))]] if debug: - print(f"global timeout") + print("global timeout") return result[0], metadata_list[0] def compute_score(completion, test_cases): - solution = completion.split('```python')[-1].split('```')[0] + solution = completion.split("```python")[-1].split("```")[0] # extract test cases try: @@ -65,7 +65,7 @@ def compute_score(completion, test_cases): try: res, metadata = check_correctness(in_outs=in_outs, generation=solution, timeout=6, debug=False) success = all(map(lambda x: x == True, res)) - except Exception as e: + except Exception: pass return success diff --git a/recipe/r1/tasks/math.py b/recipe/r1/tasks/math.py index 83a83bb04..a06fb8cd9 100644 --- a/recipe/r1/tasks/math.py +++ b/recipe/r1/tasks/math.py @@ -14,7 +14,7 @@ try: from math_verify.metric import math_metric - from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig + from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig except ImportError: print("To use Math-Verify, please install it first by running `pip install math-verify`.") @@ -24,13 +24,13 @@ def compute_score(model_output: str, ground_truth: str) -> bool: gold_extraction_target=(LatexExtractionConfig(),), pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()), ) - ret_score = 0. + ret_score = 0.0 # Wrap the ground truth in \boxed{} format for verification ground_truth_boxed = "\\boxed{" + ground_truth + "}" try: ret_score, _ = verify_func([ground_truth_boxed], [model_output]) - except Exception as e: + except Exception: pass return ret_score diff --git a/requirements.txt b/requirements.txt index 596ec4cde..1ce8baf8e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ peft pyarrow>=15.0.0 pybind11 pylatexenc -pylint==3.3.6 +pre-commit ray[default] tensordict<=0.6.2 torchdata diff --git a/scripts/converter_hf_to_mcore.py b/scripts/converter_hf_to_mcore.py index c7c10b3fc..eca361f4a 100644 --- a/scripts/converter_hf_to_mcore.py +++ b/scripts/converter_hf_to_mcore.py @@ -13,47 +13,40 @@ # See the License for the specific language governing permissions and # limitations under the License. 
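# Aside (not part of this patch): a hedged sketch of calling the Math-Verify
# based scorer shown above (recipe/r1/tasks/math.py). It wraps the ground truth
# in \boxed{...} before verification and requires `pip install math-verify`.
# The sample strings are invented; exact scores depend on math_verify parsing.
from recipe.r1.tasks.math import compute_score

print(compute_score("The sum is \\boxed{\\frac{1}{2}}", "\\frac{1}{2}"))  # expected ~1.0
print(compute_score("I believe the sum is 3", "\\frac{1}{2}"))  # expected 0.0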
-from typing import List, Tuple, Dict -import re -import os -import torch import argparse +import os import warnings -import numpy as np -from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq -from concurrent.futures import ThreadPoolExecutor -from safetensors.torch import load_file -from torch.distributed._tensor import Shard, Placement -from verl.utils.megatron_utils import get_model, convert_config -from megatron.core.models.gpt.gpt_model import ModelType -from megatron.core import parallel_state as mpu + +import torch from megatron.core import dist_checkpointing +from megatron.core import parallel_state as mpu from megatron.core.dist_checkpointing.serialization import StrictHandling +from megatron.core.models.gpt.gpt_model import ModelType +from transformers import AutoConfig, AutoModelForCausalLM + +from verl.utils.megatron_utils import convert_config, get_model def _init_args(): parser = argparse.ArgumentParser() - parser.add_argument('--hf_model_path', type=str, required=True, help="The path for the huggingface model") - parser.add_argument('--output_path', type=str, required=True, help="The path for the output mcore model") - parser.add_argument('--test', action='store_true', help="Whether to test the conversion") + parser.add_argument("--hf_model_path", type=str, required=True, help="The path for the huggingface model") + parser.add_argument("--output_path", type=str, required=True, help="The path for the output mcore model") + parser.add_argument("--test", action="store_true", help="Whether to test the conversion") args = parser.parse_args() return args class MegatronConfig: - def __init__(self): self.params_dtype = torch.bfloat16 class ModelConfig: - def __init__(self): self.path = None class Config: - def __init__(self): self.model = ModelConfig() @@ -65,15 +58,17 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False): return # init torch distributed and mpu - os.environ['RANK'] = '0' - os.environ['WORLD_SIZE'] = '1' - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12355' - torch.distributed.init_process_group('nccl') - mpu.initialize_model_parallel(tensor_model_parallel_size=1, - virtual_pipeline_model_parallel_size=None, - context_parallel_size=1, - expert_model_parallel_size=1) + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + torch.distributed.init_process_group("nccl") + mpu.initialize_model_parallel( + tensor_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, + context_parallel_size=1, + expert_model_parallel_size=1, + ) # init hf config hf_config = AutoConfig.from_pretrained(hf_model_path) @@ -87,17 +82,20 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False): # init megatron model def megatron_model_provider(pre_process, post_process): from verl.utils.model import get_parallel_gptmodel_from_config - parallel_model = get_parallel_gptmodel_from_config(tfconfig, - hf_config, - pre_process, - post_process, - share_embeddings_and_output_weights=tie_word_embeddings, - value=False) + + parallel_model = get_parallel_gptmodel_from_config( + tfconfig, + hf_config, + pre_process, + post_process, + share_embeddings_and_output_weights=tie_word_embeddings, + value=False, + ) return parallel_model - model = get_model(model_provider_func=megatron_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=True) + model = get_model( + 
model_provider_func=megatron_model_provider, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -108,11 +106,14 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False): # load hf state dict to megatron model from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel - load_state_dict_to_megatron_gptmodel(state_dict=ref_state_dict, - wrapped_models=model, - config=hf_config, - params_dtype=torch.bfloat16, - is_value_model=False) + + load_state_dict_to_megatron_gptmodel( + state_dict=ref_state_dict, + wrapped_models=model, + config=hf_config, + params_dtype=torch.bfloat16, + is_value_model=False, + ) ssd = model[0].module.module.sharded_state_dict() del ref_state_dict, hf_model @@ -122,9 +123,9 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False): if test: ########### test ########### # load model - model_test = get_model(model_provider_func=megatron_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=True) + model_test = get_model( + model_provider_func=megatron_model_provider, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True + ) ssd2 = model_test[0].module.module.sharded_state_dict() dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.ASSUME_OK_UNEXPECTED) @@ -136,7 +137,7 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False): d1 = sd[k].data if k in sd2: d2 = sd2[k].data - assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}' + assert d1.shape == d2.shape, f"{k=} {d1.shape=} {d2.shape=}" assert (d1 == d2).all(), f"{k} is not equal" for k in sd2.keys(): if sd2[k] is None: @@ -144,24 +145,24 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False): d1 = sd2[k].data if k in sd: d2 = sd[k].data - assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}' + assert d1.shape == d2.shape, f"{k=} {d1.shape=} {d2.shape=}" assert (d1 == d2).all(), f"{k} is not equal" # load value model def megatron_value_model_provider(pre_process, post_process): from verl.utils.model import get_parallel_gptmodel_from_config - parallel_model = get_parallel_gptmodel_from_config(tfconfig, - hf_config, - pre_process, - post_process, - share_embeddings_and_output_weights=False, - value=True) + + parallel_model = get_parallel_gptmodel_from_config( + tfconfig, hf_config, pre_process, post_process, share_embeddings_and_output_weights=False, value=True + ) parallel_model.cuda() return parallel_model - model_value = get_model(model_provider_func=megatron_value_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=True) + model_value = get_model( + model_provider_func=megatron_value_model_provider, + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=True, + ) ssd2 = model_value[0].module.module.sharded_state_dict() dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.IGNORE_ALL) @@ -173,7 +174,7 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False): d1 = sd[k].data if k in sd2: d2 = sd2[k].data - assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}' + assert d1.shape == d2.shape, f"{k=} {d1.shape=} {d2.shape=}" assert (d1 == d2).all(), f"{k} is not equal" for k in sd2.keys(): if sd2[k] is None: @@ -181,7 +182,7 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False): d1 = sd2[k].data if k in sd: d2 = sd[k].data - assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}' + assert d1.shape == d2.shape, f"{k=} {d1.shape=} {d2.shape=}" assert (d1 == d2).all(), 
f"{k} is not equal" diff --git a/scripts/diagnose.py b/scripts/diagnose.py index eed205014..cc7a25fc3 100644 --- a/scripts/diagnose.py +++ b/scripts/diagnose.py @@ -14,28 +14,35 @@ """Diagnose script for checking OS/hardware/python/pip/verl/network. The output of this script can be a very good hint to issue/problem. """ + +import os +import platform +import socket import subprocess +import sys +import time + import psutil -import platform, subprocess, sys, os -import socket, time + try: - from urllib.request import urlopen from urllib.parse import urlparse + from urllib.request import urlopen except ImportError: - from urlparse import urlparse from urllib2 import urlopen + from urlparse import urlparse import argparse import importlib.metadata + import torch URLS = { - 'PYPI': 'https://pypi.python.org/pypi/pip', + "PYPI": "https://pypi.python.org/pypi/pip", } REGIONAL_URLS = { - 'cn': { - 'PYPI(douban)': 'https://pypi.douban.com/', - 'Conda(tsinghua)': 'https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/', + "cn": { + "PYPI(douban)": "https://pypi.douban.com/", + "Conda(tsinghua)": "https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/", } } @@ -47,7 +54,7 @@ def test_connection(name, url, timeout=10): try: ip = socket.gethostbyname(urlinfo.netloc) except Exception as e: - print('Error resolving DNS for {}: {}, {}'.format(name, url, e)) + print("Error resolving DNS for {}: {}, {}".format(name, url, e)) return dns_elapsed = time.time() - start start = time.time() @@ -61,26 +68,27 @@ def test_connection(name, url, timeout=10): def check_python(): - print('----------Python Info----------') - print('Version :', platform.python_version()) - print('Compiler :', platform.python_compiler()) - print('Build :', platform.python_build()) - print('Arch :', platform.architecture()) + print("----------Python Info----------") + print("Version :", platform.python_version()) + print("Compiler :", platform.python_compiler()) + print("Build :", platform.python_build()) + print("Arch :", platform.architecture()) def check_pip(): - print('------------Pip Info-----------') + print("------------Pip Info-----------") try: import pip - print('Version :', pip.__version__) - print('Directory :', os.path.dirname(pip.__file__)) + + print("Version :", pip.__version__) + print("Directory :", os.path.dirname(pip.__file__)) except ImportError: - print('No corresponding pip install for current python.') + print("No corresponding pip install for current python.") def _get_current_git_commit(): try: - result = subprocess.run(['git', 'rev-parse', 'HEAD'], capture_output=True, text=True, check=True) + result = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True) return result.stdout.strip() except subprocess.CalledProcessError as e: print(f"Error running git command: {e.stderr.strip()}") @@ -91,22 +99,24 @@ def _get_current_git_commit(): def check_verl(): - print('----------verl Info-----------') + print("----------verl Info-----------") try: sys.path.insert(0, os.getcwd()) import verl - print('Version :', verl.__version__) + + print("Version :", verl.__version__) verl_dir = os.path.dirname(verl.__file__) - print('Directory :', verl_dir) + print("Directory :", verl_dir) try: commit_hash = _get_current_git_commit() - print('Commit Hash :', commit_hash) + print("Commit Hash :", commit_hash) except AttributeError: - print('Commit hash not found. ') + print("Commit hash not found. 
") except ImportError as e: - print(f'No verl installed: {e}') + print(f"No verl installed: {e}") except Exception as e: import traceback + if not isinstance(e, IOError): print("An error occurred trying to import verl.") print("This is very likely due to missing missing or incompatible library files.") @@ -114,36 +124,36 @@ def check_verl(): def check_os(): - print('----------Platform Info----------') - print('Platform :', platform.platform()) - print('system :', platform.system()) - print('node :', platform.node()) - print('release :', platform.release()) - print('version :', platform.version()) + print("----------Platform Info----------") + print("Platform :", platform.platform()) + print("system :", platform.system()) + print("node :", platform.node()) + print("release :", platform.release()) + print("version :", platform.version()) def check_hardware(): - print('----------Hardware Info----------') - print('machine :', platform.machine()) - print('processor :', platform.processor()) - if sys.platform.startswith('darwin'): - pipe = subprocess.Popen(('sysctl', '-a'), stdout=subprocess.PIPE) + print("----------Hardware Info----------") + print("machine :", platform.machine()) + print("processor :", platform.processor()) + if sys.platform.startswith("darwin"): + pipe = subprocess.Popen(("sysctl", "-a"), stdout=subprocess.PIPE) output = pipe.communicate()[0] - for line in output.split(b'\n'): - if b'brand_string' in line or b'features' in line: + for line in output.split(b"\n"): + if b"brand_string" in line or b"features" in line: print(line.strip()) - elif sys.platform.startswith('linux'): - subprocess.call(['lscpu']) - elif sys.platform.startswith('win32'): - subprocess.call(['wmic', 'cpu', 'get', 'name']) + elif sys.platform.startswith("linux"): + subprocess.call(["lscpu"]) + elif sys.platform.startswith("win32"): + subprocess.call(["wmic", "cpu", "get", "name"]) def check_network(args): - print('----------Network Test----------') + print("----------Network Test----------") if args.timeout > 0: - print('Setting timeout: {}'.format(args.timeout)) + print("Setting timeout: {}".format(args.timeout)) socket.setdefaulttimeout(10) - for region in args.region.strip().split(','): + for region in args.region.strip().split(","): r = region.strip().lower() if not r: continue @@ -151,20 +161,21 @@ def check_network(args): URLS.update(REGIONAL_URLS[r]) else: import warnings - warnings.warn('Region {} do not need specific test, please refer to global sites.'.format(r)) + + warnings.warn("Region {} do not need specific test, please refer to global sites.".format(r)) for name, url in URLS.items(): test_connection(name, url, args.timeout) def check_environment(): - print('----------Environment----------') + print("----------Environment----------") for k, v in os.environ.items(): - if k.startswith('VERL_') or k.startswith('OMP_') or k.startswith('KMP_') or k == 'CC' or k == 'CXX': + if k.startswith("VERL_") or k.startswith("OMP_") or k.startswith("KMP_") or k == "CC" or k == "CXX": print('{}="{}"'.format(k, v)) def check_pip_package_versions(): - packages = ['vllm', 'sglang', 'ray', 'torch'] + packages = ["vllm", "sglang", "ray", "torch"] for package in packages: try: version = importlib.metadata.version(package) @@ -179,8 +190,9 @@ def check_cuda_versions(): cuda_runtime_version = torch.version.cuda print(f"CUDA Runtime : {cuda_runtime_version}") import subprocess - nvcc_output = subprocess.check_output(['nvcc', '--version']).decode('utf-8') - cuda_compiler_version = next((line for line in 
nvcc_output.splitlines() if 'release' in line), None) + + nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8") + cuda_compiler_version = next((line for line in nvcc_output.splitlines() if "release" in line), None) if cuda_compiler_version: print(f"CUDA Compiler : {cuda_compiler_version.strip()}") else: @@ -206,19 +218,23 @@ def _get_gpu_info(): Get GPU type, GPU memory, and GPU count using nvidia-smi command. """ try: - result = subprocess.run(['nvidia-smi', '--query-gpu=gpu_name,memory.total', '--format=csv,noheader,nounits'], - capture_output=True, - text=True, - check=True) - gpu_lines = result.stdout.strip().split('\n') + result = subprocess.run( + ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader,nounits"], + capture_output=True, + text=True, + check=True, + ) + gpu_lines = result.stdout.strip().split("\n") gpu_count = len(gpu_lines) gpu_info = [] for line in gpu_lines: - gpu_name, gpu_memory = line.split(', ') - gpu_info.append({ - 'type': gpu_name, - 'memory': float(gpu_memory) / 1024 # Convert to GB - }) + gpu_name, gpu_memory = line.split(", ") + gpu_info.append( + { + "type": gpu_name, + "memory": float(gpu_memory) / 1024, # Convert to GB + } + ) return gpu_count, gpu_info except subprocess.CalledProcessError: print("Failed to execute nvidia-smi command.") @@ -231,39 +247,43 @@ def _get_system_info(): """ cpu_memory = _get_cpu_memory() gpu_count, gpu_info = _get_gpu_info() - return {'cpu_memory': cpu_memory, 'gpu_count': gpu_count, 'gpu_info': gpu_info} + return {"cpu_memory": cpu_memory, "gpu_count": gpu_count, "gpu_info": gpu_info} def check_system_info(): - print('----------System Info----------') + print("----------System Info----------") system_info = _get_system_info() print(f"CPU Memory\t: {system_info['cpu_memory']:.2f} GB") print(f"GPU Count\t: {system_info['gpu_count']}") - for i, gpu in enumerate(system_info['gpu_info']): + for i, gpu in enumerate(system_info["gpu_info"]): print(f"GPU {i + 1}\tType : {gpu['type']}") print(f"GPU {i + 1}\tMemory : {gpu['memory']:.2f} GB") def parse_args(): """Parse arguments.""" - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, - description='Diagnose script for checking the current system.') - choices = ['python', 'pip', 'verl', 'system', 'os', 'environment'] + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="Diagnose script for checking the current system.", + ) + choices = ["python", "pip", "verl", "system", "os", "environment"] for choice in choices: - parser.add_argument('--' + choice, default=1, type=int, help='Diagnose {}.'.format(choice)) - parser.add_argument('--network', default=0, type=int, help='Diagnose network.') - parser.add_argument('--hardware', default=0, type=int, help='Diagnose hardware.') - parser.add_argument('--region', - default='', - type=str, - help="Additional sites in which region(s) to test. \ - Specify 'cn' for example to test mirror sites in China.") - parser.add_argument('--timeout', default=10, type=int, help="Connection test timeout threshold, 0 to disable.") + parser.add_argument("--" + choice, default=1, type=int, help="Diagnose {}.".format(choice)) + parser.add_argument("--network", default=0, type=int, help="Diagnose network.") + parser.add_argument("--hardware", default=0, type=int, help="Diagnose hardware.") + parser.add_argument( + "--region", + default="", + type=str, + help="Additional sites in which region(s) to test. 
\ + Specify 'cn' for example to test mirror sites in China.", + ) + parser.add_argument("--timeout", default=10, type=int, help="Connection test timeout threshold, 0 to disable.") args = parser.parse_args() return args -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() if args.python: check_python() diff --git a/scripts/format.sh b/scripts/format.sh deleted file mode 100755 index cd2d2d575..000000000 --- a/scripts/format.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -pip3 install --upgrade yapf -python3 -m yapf -ir -vv --style ./.style.yapf verl tests examples recipe scripts diff --git a/scripts/model_merger.py b/scripts/model_merger.py index 72e4a6d0a..4901fa0d0 100644 --- a/scripts/model_merger.py +++ b/scripts/model_merger.py @@ -12,16 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Dict -import re -import os -import torch import argparse -import numpy as np -from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq +import os +import re from concurrent.futures import ThreadPoolExecutor +from typing import Dict, List, Tuple + +import numpy as np +import torch from safetensors.torch import load_file -from torch.distributed._tensor import Shard, Placement +from torch.distributed._tensor import Placement, Shard +from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq + try: # for torch 2.5+ from torch.distributed.tensor import DTensor @@ -29,28 +31,31 @@ except ImportError: from torch.distributed._tensor import DTensor parser = argparse.ArgumentParser() -parser.add_argument('--backend', type=str, required=True, help="The backend of the model", choices=["fsdp", "megatron"]) -parser.add_argument('--tie-word-embedding', action='store_true', help="Whether to tie word embedding weights") -parser.add_argument('--is-value-model', action='store_true', help="Whether the model loaded as value model") -parser.add_argument('--hf_model_path', type=str, required=True, help="The path for the huggingface model") +parser.add_argument("--backend", type=str, required=True, help="The backend of the model", choices=["fsdp", "megatron"]) +parser.add_argument("--tie-word-embedding", action="store_true", help="Whether to tie word embedding weights") +parser.add_argument("--is-value-model", action="store_true", help="Whether the model loaded as value model") +parser.add_argument("--hf_model_path", type=str, required=True, help="The path for the huggingface model") parser.add_argument( - '--local_dir', + "--local_dir", type=str, required=True, - help= - "The path for your saved model. For megatron, point to the base dir of model, rng, optimizer checkpoints, commonly be `config.default_local_dir/global_step_\{global_step\}`." + help="The path for your saved model. 
For megatron, point to the base dir of model, rng, optimizer checkpoints, commonly be `config.default_local_dir/global_step_\{global_step\}`.", ) -parser.add_argument('--target_dir', required=False, default="tmp", type=str, help="The path for the target model") +parser.add_argument("--target_dir", required=False, default="tmp", type=str, help="The path for the target model") parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload") parser.add_argument("--test", action="store_true", help="test correctness of hf_model") -parser.add_argument("--test_hf_dir", - type=str, - required=False, - help="test correctness of hf_model, , with hf_model in checkpoint.contents") +parser.add_argument( + "--test_hf_dir", + type=str, + required=False, + help="test correctness of hf_model, , with hf_model in checkpoint.contents", +) args = parser.parse_args() os.makedirs(args.target_dir, exist_ok=True) if args.test: - assert args.test_hf_dir is not None, f'You must run verl save checkpoint first, with hf_model in checkpoint.contents, and provide the directory here' + assert args.test_hf_dir is not None, ( + "You must run verl save checkpoint first, with hf_model in checkpoint.contents, and provide the directory here" + ) def merge_by_placement(tensors: List[torch.Tensor], placement: Placement): @@ -67,6 +72,7 @@ def merge_by_placement(tensors: List[torch.Tensor], placement: Placement): def upload_model_to_huggingface(hf_path): # Push to hugging face from huggingface_hub import HfApi + api = HfApi() api.create_repo(repo_id=args.hf_upload_path, private=False, exist_ok=True) api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model") @@ -85,9 +91,9 @@ def convert_fsdp_checkpoints_to_hfmodels(): break assert world_size, "No model file with the proper format" - state_dict = torch.load(os.path.join(local_dir, f'model_world_size_{world_size}_rank_{rank}.pt'), - map_location='cpu', - weights_only=False) + state_dict = torch.load( + os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt"), map_location="cpu", weights_only=False + ) pivot_key = sorted(list(state_dict.keys()))[0] weight = state_dict[pivot_key] @@ -99,13 +105,13 @@ def convert_fsdp_checkpoints_to_hfmodels(): else: # for non-DTensor mesh = np.array([int(world_size)], dtype=np.int64) - mesh_dim_names = ('fsdp',) + mesh_dim_names = ("fsdp",) - print(f'Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}') + print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") - assert mesh_dim_names in (('fsdp',), ('ddp', 'fsdp')), f'Unsupported mesh_dim_names {mesh_dim_names}' + assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}" - if 'tp' in mesh_dim_names: + if "tp" in mesh_dim_names: # fsdp * tp total_shards = mesh.shape[-1] * mesh.shape[-2] mesh_shape = (mesh.shape[-2], mesh.shape[-1]) @@ -114,21 +120,21 @@ def convert_fsdp_checkpoints_to_hfmodels(): total_shards = mesh.shape[-1] mesh_shape = (mesh.shape[-1],) - print(f'Processing model shards with {total_shards} {mesh_shape} in total') + print(f"Processing model shards with {total_shards} {mesh_shape} in total") model_state_dict_lst = [] model_state_dict_lst.append(state_dict) model_state_dict_lst.extend([""] * (total_shards - 1)) - def process_one_shard(rank): - model_path = os.path.join(local_dir, f'model_world_size_{world_size}_rank_{rank}.pt') - state_dict = torch.load(model_path, map_location='cpu', weights_only=False) + def process_one_shard(rank, 
model_state_dict_lst): + model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt") + state_dict = torch.load(model_path, map_location="cpu", weights_only=False) model_state_dict_lst[rank] = state_dict return state_dict with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: for rank in range(1, total_shards): - executor.submit(process_one_shard, rank) + executor.submit(process_one_shard, rank, model_state_dict_lst) state_dict = {} param_placements: Dict[str, List[Placement]] = {} keys = set(model_state_dict_lst[0].keys()) @@ -144,9 +150,7 @@ def convert_fsdp_checkpoints_to_hfmodels(): state_dict[key].append(tensor._local_tensor.bfloat16()) placements = tuple(tensor.placements) # replicated placement at dp dimension can be discarded - if mesh_dim_names[0] == 'dp': - placements = placements[1:] - elif mesh_dim_names[0] == 'ddp': + if mesh_dim_names[0] == "dp" or mesh_dim_names[0] == "ddp": placements = placements[1:] if key not in param_placements: param_placements[key] = placements @@ -175,27 +179,27 @@ def convert_fsdp_checkpoints_to_hfmodels(): else: state_dict[key] = torch.cat(state_dict[key], dim=0) - print('Writing to local disk') + print("Writing to local disk") if args.target_dir is None: - hf_path = os.path.join(local_dir, 'huggingface') + hf_path = os.path.join(local_dir, "huggingface") else: hf_path = args.target_dir config = AutoConfig.from_pretrained(args.hf_model_path) - if 'ForTokenClassification' in config.architectures[0]: + if "ForTokenClassification" in config.architectures[0]: auto_model = AutoModelForTokenClassification - elif 'ForCausalLM' in config.architectures[0]: + elif "ForCausalLM" in config.architectures[0]: auto_model = AutoModelForCausalLM - elif 'ForConditionalGeneration' in config.architectures[0]: + elif "ForConditionalGeneration" in config.architectures[0]: auto_model = AutoModelForVision2Seq else: - raise NotImplementedError(f'Unknown architecture {config["architectures"]}') + raise NotImplementedError(f"Unknown architecture {config['architectures']}") - with torch.device('meta'): + with torch.device("meta"): model = auto_model.from_config(config, torch_dtype=torch.bfloat16) - model.to_empty(device='cpu') + model.to_empty(device="cpu") - print(f'Saving model to {hf_path}') + print(f"Saving model to {hf_path}") model.save_pretrained(hf_path, state_dict=state_dict) del state_dict del model @@ -217,7 +221,7 @@ def check_megatron_checkpoint_path(model_path): for sharded_dir in sharded_dirs: match = re.match(r"mp_rank_(\d\d)_(\d\d\d)", sharded_dir) assert match, f"Invalid sharded dir {sharded_dir}" - assert f"model.pt" in os.listdir(os.path.join(model_path, sharded_dir)), f"model.pt not found in {sharded_dir}" + assert "model.pt" in os.listdir(os.path.join(model_path, sharded_dir)), f"model.pt not found in {sharded_dir}" tp_rank = int(match.group(1)) pp_rank = int(match.group(2)) if tp_size < tp_rank + 1: @@ -228,7 +232,7 @@ def check_megatron_checkpoint_path(model_path): def convert_megatron_checkpoints_to_hfmodels(): - from verl.utils.megatron_utils import get_model_checkpoint_path, get_hf_model_checkpoint_path + from verl.utils.megatron_utils import get_hf_model_checkpoint_path, get_model_checkpoint_path local_path = args.local_dir @@ -243,11 +247,11 @@ def convert_megatron_checkpoints_to_hfmodels(): for j in range(tp_size): model_state_dict_lst[i].append("") - print(f'sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {mp_size}') + print(f"sharded_dirs: {sharded_dirs}, tp_size: 
{tp_size}, pp_size: {pp_size}, mp_size: {mp_size}") - def process_one_shard(shard_dir): + def process_one_shard(shard_dir, model_state_dict_lst): model_path = os.path.join(model_ckpt_path, shard_dir, "model.pt") - state_dict = torch.load(model_path, map_location='cpu', weights_only=False) + state_dict = torch.load(model_path, map_location="cpu", weights_only=False) tp_rank, pp_rank = get_tp_pp_rank_from_sharded_dir(shard_dir) model_state_dict_lst[pp_rank][tp_rank] = state_dict @@ -255,12 +259,12 @@ def convert_megatron_checkpoints_to_hfmodels(): # for rank in range(1, mp_size): # executor.submit(process_one_shard, sharded_dirs[rank]) for sharded_dir in sharded_dirs: - process_one_shard(sharded_dir) + process_one_shard(sharded_dir, model_state_dict_lst) state_dict = {} config = AutoConfig.from_pretrained(args.hf_model_path) if args.test: - ref_state_dict = load_file(os.path.join(args.test_hf_dir, 'model.safetensors')) + ref_state_dict = load_file(os.path.join(args.test_hf_dir, "model.safetensors")) def merge_across_tp(key, tp_data): if "linear_fc1.weight" in key: @@ -274,7 +278,7 @@ def convert_megatron_checkpoints_to_hfmodels(): gate = torch.cat(gate_lst, dim=0) up = torch.cat(up_lst, dim=0) tp_data = [gate, up] - elif "self_attention.linear_qkv." in key and 'layer_norm' not in key: + elif "self_attention.linear_qkv." in key and "layer_norm" not in key: # if the tensor is qkv, for each param on tp, split into q, k, v # concat q, k, v separately. q_lst = [] @@ -291,7 +295,7 @@ def convert_megatron_checkpoints_to_hfmodels(): split_size = [ kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, kv_size_per_tp // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition + kv_size_per_tp // num_query_groups_per_partition, ] q, k, v = chunk.split(split_size) q_lst.append(q) @@ -323,16 +327,16 @@ def convert_megatron_checkpoints_to_hfmodels(): if "extra_state" in key: continue if args.tie_word_embedding and ("output_layer" in key): - print(f'skip lm_head and reward_head loading because of tie_word_embeddings') + print("skip lm_head and reward_head loading because of tie_word_embeddings") continue new_key = key if "decoder.layers." 
in key: - local_layer_no = int(key.split('.')[2]) + local_layer_no = int(key.split(".")[2]) layers_handled = max(local_layer_no, layers_handled) global_layer_no = local_layer_no + layers_cum - new_key_list = key.split('.') + new_key_list = key.split(".") new_key_list[2] = str(global_layer_no) - new_key = '.'.join(new_key_list) + new_key = ".".join(new_key_list) tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)] merged = merge_across_tp(new_key, tp_data) @@ -340,7 +344,7 @@ def convert_megatron_checkpoints_to_hfmodels(): state_dict[new_key] = merged elif len(merged) == 3: # split qkv - for n, d in zip(['q', 'k', 'v'], merged): + for n, d in zip(["q", "k", "v"], merged): state_dict[new_key.replace("linear_qkv", f"linear_{n}")] = d elif len(merged) == 2: # split gate up @@ -370,7 +374,6 @@ def convert_megatron_checkpoints_to_hfmodels(): ] if args.test: - for original_name, loaded_weight in state_dict.items(): name = _replace_name(original_name, params_mapping) if not name or name.endswith(".bias") and name not in ref_state_dict: @@ -380,31 +383,31 @@ def convert_megatron_checkpoints_to_hfmodels(): if args.tie_word_embedding and "lm_head.weight" in name: continue if name not in ref_state_dict: - raise RuntimeError(f'key: {name} not exist in state_dict') + raise RuntimeError(f"key: {name} not exist in state_dict") param = ref_state_dict[name] assert loaded_weight.dtype == param.dtype torch.testing.assert_close(loaded_weight, param, atol=1e-4, rtol=1e-4) - print('Writing to local disk') + print("Writing to local disk") if args.target_dir is None: - hf_path = os.path.join(args.local_dir, 'huggingface') + hf_path = os.path.join(args.local_dir, "huggingface") else: hf_path = args.target_dir - if 'ForTokenClassification' in config.architectures[0]: + if "ForTokenClassification" in config.architectures[0]: auto_model = AutoModelForTokenClassification - elif 'ForCausalLM' in config.architectures[0]: + elif "ForCausalLM" in config.architectures[0]: auto_model = AutoModelForCausalLM - elif 'ForConditionalGeneration' in config.architectures[0]: + elif "ForConditionalGeneration" in config.architectures[0]: auto_model = AutoModelForVision2Seq else: - raise NotImplementedError(f'Unknown architecture {config["architectures"]}') + raise NotImplementedError(f"Unknown architecture {config['architectures']}") - with torch.device('meta'): + with torch.device("meta"): model = auto_model.from_config(config, torch_dtype=torch.bfloat16) - model.to_empty(device='cpu') + model.to_empty(device="cpu") - print(f'Saving model to {hf_path}') + print(f"Saving model to {hf_path}") model.save_pretrained(hf_path, state_dict=state_dict) del state_dict del model @@ -435,7 +438,7 @@ def _replace_name(megatron_name, name_mapping): return param_name -if __name__ == '__main__': +if __name__ == "__main__": if args.backend == "fsdp": convert_fsdp_checkpoints_to_hfmodels() elif args.backend == "megatron": diff --git a/setup.py b/setup.py index 443ba5e8a..a65efef24 100644 --- a/setup.py +++ b/setup.py @@ -13,75 +13,79 @@ # limitations under the License. 
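# Aside (not part of this patch): a hypothetical invocation of the
# scripts/model_merger.py CLI reformatted above, merging sharded FSDP
# checkpoints back into a HuggingFace-format directory. All paths below are
# placeholders, not taken from the patch; only the flag names come from the
# argparse definitions shown in the diff.
import subprocess

subprocess.run(
    [
        "python", "scripts/model_merger.py",
        "--backend", "fsdp",
        "--hf_model_path", "Qwen/Qwen2.5-0.5B-Instruct",
        "--local_dir", "checkpoints/global_step_100/actor",
        "--target_dir", "checkpoints/global_step_100/actor/huggingface",
    ],
    check=True,
)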
# setup.py is the fallback installation script when pyproject.toml does not work -from setuptools import setup, find_packages import os +from pathlib import Path + +from setuptools import find_packages, setup version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) -with open(os.path.join(version_folder, 'verl/version/version')) as f: +with open(os.path.join(version_folder, "verl/version/version")) as f: __version__ = f.read().strip() install_requires = [ - 'accelerate', - 'codetiming', - 'datasets', - 'dill', - 'hydra-core', - 'numpy', - 'pandas', - 'datasets', - 'peft', - 'pyarrow>=15.0.0', - 'pybind11', - 'pylatexenc', - 'ray[default]>=2.10', - 'tensordict<=0.6.2', - 'torchdata', - 'transformers', - 'wandb', + "accelerate", + "codetiming", + "datasets", + "dill", + "hydra-core", + "numpy", + "pandas", + "datasets", + "peft", + "pyarrow>=15.0.0", + "pybind11", + "pylatexenc", + "ray[default]>=2.10", + "tensordict<=0.6.2", + "torchdata", + "transformers", + "wandb", ] -TEST_REQUIRES = ['pytest', 'yapf', 'py-spy'] -PRIME_REQUIRES = ['pyext'] -GEO_REQUIRES = ['mathruler'] -GPU_REQUIRES = ['liger-kernel', 'flash-attn'] -MATH_REQUIRES = ['math-verify'] # Add math-verify as an optional dependency -VLLM_REQUIRES = ['tensordict<=0.6.2', 'vllm<=0.8.2'] +TEST_REQUIRES = ["pytest", "pre-commit", "py-spy"] +PRIME_REQUIRES = ["pyext"] +GEO_REQUIRES = ["mathruler"] +GPU_REQUIRES = ["liger-kernel", "flash-attn"] +MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency +VLLM_REQUIRES = ["tensordict<=0.6.2", "vllm<=0.8.2"] SGLANG_REQUIRES = [ - 'tensordict<=0.6.2', - 'sglang[all]==0.4.4.post4', - 'torch-memory-saver>=0.0.5' + "tensordict<=0.6.2", + "sglang[all]==0.4.4.post4", + "torch-memory-saver>=0.0.5", ] extras_require = { - 'test': TEST_REQUIRES, - 'prime': PRIME_REQUIRES, - 'geo': GEO_REQUIRES, - 'gpu': GPU_REQUIRES, - 'math': MATH_REQUIRES, - 'vllm': VLLM_REQUIRES, - 'sglang': SGLANG_REQUIRES, + "test": TEST_REQUIRES, + "prime": PRIME_REQUIRES, + "geo": GEO_REQUIRES, + "gpu": GPU_REQUIRES, + "math": MATH_REQUIRES, + "vllm": VLLM_REQUIRES, + "sglang": SGLANG_REQUIRES, } -from pathlib import Path + this_directory = Path(__file__).parent long_description = (this_directory / "README.md").read_text() setup( - name='verl', + name="verl", version=__version__, - package_dir={'': '.'}, - packages=find_packages(where='.'), - url='https://github.com/volcengine/verl', - license='Apache 2.0', - author='Bytedance - Seed - MLSys', - author_email='zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk', - description='verl: Volcano Engine Reinforcement Learning for LLM', + package_dir={"": "."}, + packages=find_packages(where="."), + url="https://github.com/volcengine/verl", + license="Apache 2.0", + author="Bytedance - Seed - MLSys", + author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk", + description="verl: Volcano Engine Reinforcement Learning for LLM", install_requires=install_requires, extras_require=extras_require, - package_data={'': ['version/*'], - 'verl': ['trainer/config/*.yaml'],}, + package_data={ + "": ["version/*"], + "verl": ["trainer/config/*.yaml"], + }, include_package_data=True, long_description=long_description, - long_description_content_type='text/markdown' -) \ No newline at end of file + long_description_content_type="text/markdown", +) diff --git a/tests/__init__.py b/tests/__init__.py index 7a7aadbc9..1ce90c5eb 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on 
an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/tests/checkpoint/test_fsdp_ckpt.py b/tests/checkpoint/test_fsdp_ckpt.py index c8e0bcc18..7e00fcea7 100644 --- a/tests/checkpoint/test_fsdp_ckpt.py +++ b/tests/checkpoint/test_fsdp_ckpt.py @@ -12,65 +12,65 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import tempfile import shutil +import tempfile + import torch -import copy import torch.distributed from torch.distributed import init_device_mesh -from verl.utils.distributed import initialize_global_process_group -from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers import Qwen2Config +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, \ - CPUOffload +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager +from verl.utils.distributed import initialize_global_process_group def test_fsdp_ckpt(): assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test" local_rank, rank, world_size = initialize_global_process_group() - device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=('dp',)) + device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",)) - model_name = 'Qwen/Qwen2.5-0.5B-Instruct' + model_name = "Qwen/Qwen2.5-0.5B-Instruct" config = Qwen2Config(num_hidden_layers=1) - with torch.device('cuda'): - model = AutoModelForCausalLM.from_config(config=config, - torch_dtype=torch.bfloat16, - attn_implementation='flash_attention_2') - model = model.to(device='cuda') + with torch.device("cuda"): + model = AutoModelForCausalLM.from_config( + config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model = model.to(device="cuda") # Wrap model with FSDP mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) - model = FSDP(model, - use_orig_params=False, - device_id=torch.cuda.current_device(), - sharding_strategy=ShardingStrategy.FULL_SHARD, - mixed_precision=mixed_precision, - device_mesh=device_mesh) + model = FSDP( + model, + use_orig_params=False, + device_id=torch.cuda.current_device(), + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mixed_precision, + device_mesh=device_mesh, + ) optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9) # Create checkpoint manager tokenizer = AutoTokenizer.from_pretrained(model_name) - checkpoint_manager = FSDPCheckpointManager(model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - tokenizer=tokenizer) + checkpoint_manager = FSDPCheckpointManager( + model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer + ) # Generate sample input batch_size = 2 seq_len = 32 vocab_size = 32000 # First input for initial update - input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda') + 
input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") attention_mask1 = torch.ones_like(input_ids1) # Second input for verification - input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda') + input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") attention_mask2 = torch.ones_like(input_ids2) # Step 1: Initial update and save checkpoint @@ -83,7 +83,7 @@ def test_fsdp_ckpt(): # Save checkpoint after first update temp_dir = tempfile.mkdtemp() - checkpoint_path = os.path.join(temp_dir, 'checkpoint') + checkpoint_path = os.path.join(temp_dir, "checkpoint") checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0) # Step 2: Second update and forward pass @@ -122,5 +122,5 @@ def test_fsdp_ckpt(): torch.distributed.barrier() -if __name__ == '__main__': +if __name__ == "__main__": test_fsdp_ckpt() diff --git a/tests/distributed/test_tensor_dict.py b/tests/distributed/test_tensor_dict.py index 5ec7618c5..27da6f5a2 100644 --- a/tests/distributed/test_tensor_dict.py +++ b/tests/distributed/test_tensor_dict.py @@ -14,104 +14,108 @@ import os -os.environ['NCCL_DEBUG'] = 'WARN' +os.environ["NCCL_DEBUG"] = "WARN" -from verl.protocol import all_gather_data_proto, DataProto -from verl.utils.distributed import initialize_global_process_group +import numpy as np import torch import torch.distributed -import numpy as np + +from verl.protocol import DataProto, all_gather_data_proto +from verl.utils.distributed import initialize_global_process_group def test_all_gather_data_proto(): - device_mesh = torch.distributed.device_mesh.init_device_mesh('cuda', mesh_shape=[2, 2], mesh_dim_names=['dp', 'tp']) + device_mesh = torch.distributed.device_mesh.init_device_mesh("cuda", mesh_shape=[2, 2], mesh_dim_names=["dp", "tp"]) global_rank = torch.distributed.get_rank() obs = torch.tensor([[1 * global_rank, 2 * global_rank + 1], [3 * global_rank, 4 * global_rank + 1]]) - labels = ['a', 'b'] if global_rank % 2 == 0 else ['b', 'a'] + labels = ["a", "b"] if global_rank % 2 == 0 else ["b", "a"] labels = np.array(labels, dtype=object) - data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'}) + data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) - all_gather_data_proto(data=data, process_group=device_mesh.get_group('dp')) + all_gather_data_proto(data=data, process_group=device_mesh.get_group("dp")) if global_rank == 0: - expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device='cuda') - expected_labels = ['a', 'b', 'a', 'b'] + expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device="cuda") + expected_labels = ["a", "b", "a", "b"] elif global_rank == 1: - expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device='cuda') - expected_labels = ['b', 'a', 'b', 'a'] + expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device="cuda") + expected_labels = ["b", "a", "b", "a"] elif global_rank == 2: - expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device='cuda') - expected_labels = ['a', 'b', 'a', 'b'] + expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device="cuda") + expected_labels = ["a", "b", "a", "b"] elif global_rank == 3: - expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device='cuda') - expected_labels = ['b', 'a', 'b', 'a'] + expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device="cuda") + 
expected_labels = ["b", "a", "b", "a"] - torch.testing.assert_close(data.batch['obs'], expected_obs, atol=0, rtol=0) - assert (data.non_tensor_batch['labels'] == expected_labels).all() - assert data.meta_info == {'info': 'test_info'} + torch.testing.assert_close(data.batch["obs"], expected_obs, atol=0, rtol=0) + assert (data.non_tensor_batch["labels"] == expected_labels).all() + assert data.meta_info == {"info": "test_info"} def test_vocab_parallel_entropy(): - from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy - from verl.utils.debug import log_gpu_memory_usage - from verl.utils.torch_functional import entropy_from_logits - from megatron.core import parallel_state as mpu - mpu.initialize_model_parallel(tensor_model_parallel_size=2, - pipeline_model_parallel_size=1, - virtual_pipeline_model_parallel_size=None) + from verl.utils.debug import log_gpu_memory_usage + from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy + from verl.utils.torch_functional import entropy_from_logits + + mpu.initialize_model_parallel( + tensor_model_parallel_size=2, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None + ) batch_size = 2 seqlen = 128 vocab_size = 155136 - logits = torch.randn(batch_size * seqlen, vocab_size, device='cuda', requires_grad=True) - target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device='cuda', dtype=torch.int64) + logits = torch.randn(batch_size * seqlen, vocab_size, device="cuda", requires_grad=True) + target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device="cuda", dtype=torch.int64) # broadcast across tp - torch.distributed.broadcast(logits, - mpu.get_tensor_model_parallel_src_rank(), - group=mpu.get_tensor_model_parallel_group()) - torch.distributed.broadcast(target, - mpu.get_tensor_model_parallel_src_rank(), - group=mpu.get_tensor_model_parallel_group()) + torch.distributed.broadcast( + logits, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() + ) + torch.distributed.broadcast( + target, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() + ) tp_rank = mpu.get_tensor_model_parallel_rank() vocab_size_per_tp = vocab_size // mpu.get_tensor_model_parallel_world_size() # get the local logits of each tp - vocab_parallel_logits = logits.clone().detach()[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) * - vocab_size_per_tp].requires_grad_() + vocab_parallel_logits = ( + logits.clone().detach()[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp].requires_grad_() + ) logits.grad = None vocab_parallel_logits.grad = None - log_gpu_memory_usage('begin') + log_gpu_memory_usage("begin") output_entropy = vocab_parallel_entropy(vocab_parallel_logits) - log_gpu_memory_usage('after forward') + log_gpu_memory_usage("after forward") grad_output = torch.randn_like(output_entropy) output_entropy.backward(grad_output) - log_gpu_memory_usage('after backward') + log_gpu_memory_usage("after backward") target_entropy = entropy_from_logits(logits) torch.testing.assert_close(output_entropy, target_entropy) target_entropy.backward(grad_output) - torch.testing.assert_close(logits.grad[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) * vocab_size_per_tp], - vocab_parallel_logits.grad) + torch.testing.assert_close( + logits.grad[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits.grad + ) # make sure logits is not altered - torch.testing.assert_close(logits[:, tp_rank * 
vocab_size_per_tp:(tp_rank + 1) * vocab_size_per_tp], - vocab_parallel_logits) + torch.testing.assert_close( + logits[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits + ) if mpu.get_tensor_model_parallel_rank() == 0: - print('test_vocab_parallel_entropy passes') + print("test_vocab_parallel_entropy passes") mpu.destroy_model_parallel() -if __name__ == '__main__': +if __name__ == "__main__": local_rank, rank, world_size = initialize_global_process_group() test_all_gather_data_proto() test_vocab_parallel_entropy() diff --git a/tests/e2e/arithmetic_sequence/data/create_dataset.py b/tests/e2e/arithmetic_sequence/data/create_dataset.py index e023a2917..1729fd6af 100644 --- a/tests/e2e/arithmetic_sequence/data/create_dataset.py +++ b/tests/e2e/arithmetic_sequence/data/create_dataset.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests.e2e.envs.digit_completion import DigitCompletion, generate_ground_truth_response -from torch.utils import data import os -if __name__ == '__main__': +from torch.utils import data + +from tests.e2e.envs.digit_completion import DigitCompletion + +if __name__ == "__main__": simple_task = DigitCompletion(max_number=9, max_diff=9, max_num_in_response=9) all_prompts = simple_task.get_all_prompts() @@ -25,15 +27,13 @@ if __name__ == '__main__': train_data = list(train_data) test_data = list(test_data) - train_data = [[{'role': 'user', 'content': str(item)}] \ - for item in train_data] - test_data = [[{'role': 'user', 'content': str(item)}] \ - for item in test_data] + train_data = [[{"role": "user", "content": str(item)}] for item in train_data] + test_data = [[{"role": "user", "content": str(item)}] for item in test_data] - print(f'Size of train: {len(train_data)}, size of test: {len(test_data)}') + print(f"Size of train: {len(train_data)}, size of test: {len(test_data)}") - train_data = {'prompt': train_data} - test_data = {'prompt': test_data} + train_data = {"prompt": train_data} + test_data = {"prompt": test_data} model_folder = os.path.join(os.path.dirname(os.path.abspath(__file__))) @@ -42,5 +42,5 @@ if __name__ == '__main__': train_data_frame = pd.DataFrame(train_data) test_data_frame = pd.DataFrame(test_data) - train_data_frame.to_parquet(os.path.join(model_folder, 'train.parquet')) - test_data_frame.to_parquet(os.path.join(model_folder, 'test.parquet')) + train_data_frame.to_parquet(os.path.join(model_folder, "train.parquet")) + test_data_frame.to_parquet(os.path.join(model_folder, "test.parquet")) diff --git a/tests/e2e/arithmetic_sequence/model/create_model_tokenizer.py b/tests/e2e/arithmetic_sequence/model/create_model_tokenizer.py index 9a3135f0c..88e9501ed 100644 --- a/tests/e2e/arithmetic_sequence/model/create_model_tokenizer.py +++ b/tests/e2e/arithmetic_sequence/model/create_model_tokenizer.py @@ -15,28 +15,30 @@ Create a random model and tokenizer for PPO training """ -import torch import os -from transformers import AutoModelForCausalLM, LlamaConfig, AutoTokenizer + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaConfig from tests.e2e.envs.digit_completion import CharTokenizer tokenizer = CharTokenizer( - characters=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ',', ':'], + characters=["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ",", ":"], model_max_length=2048, - chat_template= - "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% 
for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set role = message['role'] %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ sep_token }}{% endif %}" + chat_template="{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set role = message['role'] %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ sep_token }}{% endif %}", ) -config = LlamaConfig(vocab_size=(tokenizer.vocab_size + 16 - 1) // 16 * 16, - hidden_size=128, - intermediate_size=344, - num_hidden_layers=4, - num_attention_heads=4, - num_key_value_heads=4, - pad_token_id=tokenizer.pad_token_id, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id) +config = LlamaConfig( + vocab_size=(tokenizer.vocab_size + 16 - 1) // 16 * 16, + hidden_size=128, + intermediate_size=344, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=4, + pad_token_id=tokenizer.pad_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, +) model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) @@ -50,12 +52,11 @@ tokenizer.save_pretrained(tokenizer_folder) load_tokenizer = AutoTokenizer.from_pretrained(tokenizer_folder) -chat = [{'role': 'user', 'content': '1,0:2,3'}] +chat = [{"role": "user", "content": "1,0:2,3"}] -load_tokenizer.padding_side = 'left' +load_tokenizer.padding_side = "left" print( - load_tokenizer.apply_chat_template(chat, - tokenize=True, - add_generation_prompt=True, - max_length=10, - padding='max_length')) + load_tokenizer.apply_chat_template( + chat, tokenize=True, add_generation_prompt=True, max_length=10, padding="max_length" + ) +) diff --git a/tests/e2e/arithmetic_sequence/rl/main_trainer.py b/tests/e2e/arithmetic_sequence/rl/main_trainer.py index c1861b987..41f6436e0 100644 --- a/tests/e2e/arithmetic_sequence/rl/main_trainer.py +++ b/tests/e2e/arithmetic_sequence/rl/main_trainer.py @@ -14,54 +14,55 @@ """ Using FSDPTrainer """ + import os + import hydra import ray import torch -from transformers import PreTrainedTokenizer, AutoTokenizer +from transformers import AutoTokenizer from verl import DataProto from verl.trainer.ppo.ray_trainer import RayPPOTrainer from verl.utils.fs import copy_to_local -from tests.e2e.envs.digit_completion import CharTokenizer def make_reward_function(tokenizer, num_examine): - def arithmetic_sequence_reward_function(data: DataProto, return_dict: bool = False): from tests.e2e.envs.digit_completion.task import compute_reward - reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) for i in range(data.batch.batch_size[0]): data_item = data[i] # DataProtoItem - prompt_ids = data_item.batch['prompts'] + prompt_ids = data_item.batch["prompts"] prompt_length = prompt_ids.shape[-1] # extract raw prompt - valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum() + valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() valid_prompt_ids = prompt_ids[-valid_prompt_length:] # extract response - response_ids = data_item.batch['responses'] + response_ids = 
data_item.batch["responses"] response_length = response_ids.shape[-1] - response_mask = data.batch['attention_mask'][i][-response_length:] - valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum() + response_mask = data.batch["attention_mask"][i][-response_length:] + valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() valid_response_ids = response_ids[:valid_response_length] # decode prompt = tokenizer.decode(valid_prompt_ids) response = tokenizer.decode(valid_response_ids) # remove bos and eos - prompt = prompt.replace(tokenizer.sep_token, '') - response = response.replace(tokenizer.eos_token, '') + prompt = prompt.replace(tokenizer.sep_token, "") + response = response.replace(tokenizer.eos_token, "") if i < num_examine: print(prompt, response) reward_output = compute_reward(prompt, response) dense_reward = reward_output[0].tolist() - ground_truth_response = reward_output[1]['ground_truth_response'] + ground_truth_response = reward_output[1]["ground_truth_response"] if len(dense_reward) > 0: last_reward = dense_reward[-1] else: @@ -85,26 +86,29 @@ def make_reward_function(tokenizer, num_examine): return arithmetic_sequence_reward_function -@hydra.main(config_path='../../../../verl/trainer/config', config_name='ppo_trainer', version_base=None) +@hydra.main(config_path="../../../../verl/trainer/config", config_name="ppo_trainer", version_base=None) def main(config): ray.init( runtime_env={ - 'env_vars': { - 'MEGATRON_USE_CUDA_TIMER': '0', - 'MEGATRON_START_PROCESS_TIMER': 'False', - 'TOKENIZERS_PARALLELISM': 'true', - 'NCCL_DEBUG': 'WARN' + "env_vars": { + "MEGATRON_USE_CUDA_TIMER": "0", + "MEGATRON_START_PROCESS_TIMER": "False", + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", } - }) + } + ) # print initial config from pprint import pprint + from omegaconf import OmegaConf + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values # print the config # print initial config - print('Config after normalizing batch_size') + print("Config after normalizing batch_size") pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values # download the checkpoint from hdfs @@ -112,18 +116,18 @@ def main(config): local_path = os.path.expanduser(local_path) # instantiate tokenizern tokenizer = AutoTokenizer.from_pretrained(local_path) - print(f'Tokenizer vocab_size: {tokenizer.vocab_size}') + print(f"Tokenizer vocab_size: {tokenizer.vocab_size}") # define worker classes - from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker role_worker_mapping = { Role.ActorRollout: ray.remote(ActorRolloutRefWorker), Role.Critic: ray.remote(CriticWorker), } - global_pool_id = 'global_pool' + global_pool_id = "global_pool" resource_pool_spec = { global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, } @@ -141,15 +145,17 @@ def main(config): resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - trainer = RayPPOTrainer(config=config, - tokenizer=tokenizer, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - reward_fn=reward_fn, - val_reward_fn=reward_fn) + trainer = RayPPOTrainer( + config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + reward_fn=reward_fn, + 
val_reward_fn=reward_fn, + ) trainer.init_workers() trainer.fit() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/e2e/check_custom_rwd_fn.py b/tests/e2e/check_custom_rwd_fn.py index 4cabc6a2c..8d77a5372 100644 --- a/tests/e2e/check_custom_rwd_fn.py +++ b/tests/e2e/check_custom_rwd_fn.py @@ -16,17 +16,17 @@ import argparse def check_congratulations_in_file(output_file): - with open(output_file, 'r') as f: + with open(output_file) as f: output = f.read() success_message = "Congratulations!!! You have called my_reward_function successfully!!!" - assert success_message in output, f'Success message of my_reward_function not found in {output_file}' + assert success_message in output, f"Success message of my_reward_function not found in {output_file}" print("Check passes") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--output_file', required=True, type=str) + parser.add_argument("--output_file", required=True, type=str) args = parser.parse_args() diff --git a/tests/e2e/check_results.py b/tests/e2e/check_results.py index e20ee2a56..9453282fb 100644 --- a/tests/e2e/check_results.py +++ b/tests/e2e/check_results.py @@ -20,10 +20,10 @@ import numpy as np def extract_reward_from_line(line): # TODO: this function needs error handling try: - key_vals = line.split(' - ') + key_vals = line.split(" - ") for key_val in key_vals: - key, val = key_val.split(':') - if key == 'critic/rewards/mean': + key, val = key_val.split(":") + if key == "critic/rewards/mean": reward = float(val) return reward return -np.inf @@ -31,23 +31,23 @@ def extract_reward_from_line(line): return -np.inf -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--output_file', required=True, type=str) - parser.add_argument('--target', type=float, default=0.2, help='target reward score') + parser.add_argument("--output_file", required=True, type=str) + parser.add_argument("--target", type=float, default=0.2, help="target reward score") args = parser.parse_args() - with open(args.output_file, 'r') as f: - output = f.read().split('\n') + with open(args.output_file) as f: + output = f.read().split("\n") best_reward = -np.inf for line in output: - if line.startswith('step'): + if line.startswith("step"): reward = extract_reward_from_line(line) if reward > best_reward: best_reward = reward - print(f'Best reward is {best_reward}') - assert best_reward > args.target, f'Best reward must be greater than {args.target}. best_reward: {best_reward}' - print('Check passes') + print(f"Best reward is {best_reward}") + assert best_reward > args.target, f"Best reward must be greater than {args.target}. best_reward: {best_reward}" + print("Check passes") diff --git a/tests/e2e/envs/__init__.py b/tests/e2e/envs/__init__.py index 7d3914794..eb85e22f3 100644 --- a/tests/e2e/envs/__init__.py +++ b/tests/e2e/envs/__init__.py @@ -14,4 +14,4 @@ from .digit_completion import DigitCompletion -__all__ = ['DigitCompletion'] \ No newline at end of file +__all__ = ["DigitCompletion"] diff --git a/tests/e2e/envs/digit_completion/__init__.py b/tests/e2e/envs/digit_completion/__init__.py index 3e3aa76da..80893ae41 100644 --- a/tests/e2e/envs/digit_completion/__init__.py +++ b/tests/e2e/envs/digit_completion/__init__.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from transformers import AutoTokenizer, LlamaConfig + from .task import DigitCompletion, generate_ground_truth_response from .tokenizer import CharTokenizer -from transformers import AutoTokenizer, LlamaConfig - AutoTokenizer.register(LlamaConfig, CharTokenizer, exist_ok=True) -__all__ = ['DigitCompletion', 'generate_ground_truth_response', 'CharTokenizer'] \ No newline at end of file +__all__ = ["DigitCompletion", "generate_ground_truth_response", "CharTokenizer"] diff --git a/tests/e2e/envs/digit_completion/task.py b/tests/e2e/envs/digit_completion/task.py index 7322027c7..c3643a86b 100644 --- a/tests/e2e/envs/digit_completion/task.py +++ b/tests/e2e/envs/digit_completion/task.py @@ -16,7 +16,7 @@ import numpy as np -class DigitCompletion(object): +class DigitCompletion: """ The implementation of a simple digit completion task. The prompt is a sequence of numbers with fixed difference. The task is to complete the next N numbers. @@ -54,16 +54,18 @@ class DigitCompletion(object): self.np_rng = np.random.default_rng(seed=seed) def __str__(self): - return f'Prompt length: {self.prompt_length}. Response length: {self.response_length}, ' \ - f'Max number: {self.max_number}. Max diff: {self.max_diff}, ' \ - f'Max number in response: {self.max_num_in_response}' + return ( + f"Prompt length: {self.prompt_length}. Response length: {self.response_length}, " + f"Max number: {self.max_number}. Max diff: {self.max_diff}, " + f"Max number in response: {self.max_num_in_response}" + ) def get_state(self): - return {'rng': self.np_rng} + return {"rng": self.np_rng} def set_state(self, state): - assert 'rng' in state, 'rng must be inside state' - self.np_rng = state['rng'] + assert "rng" in state, "rng must be inside state" + self.np_rng = state["rng"] @property def prompt_length(self): @@ -84,7 +86,7 @@ class DigitCompletion(object): for diff in range(0, self.max_diff + 1): second_num = self.add(first_num, diff) for num_to_complete in range(self.max_num_in_response + 1): - prompt = str(first_num) + ',' + str(second_num) + f':{self.max_number},{num_to_complete}' + prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}" all_prompts.append(prompt) return all_prompts @@ -94,7 +96,7 @@ class DigitCompletion(object): diff = self.np_rng.integers(self.max_diff + 1) second_num = self.add(first_num, diff) num_to_complete = self.np_rng.integers(self.max_num_in_response + 1) - prompt = str(first_num) + ',' + str(second_num) + f':{self.max_number},{num_to_complete}' + prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}" return prompt def sample_batch_str_prompts(self, batch_size): @@ -116,9 +118,9 @@ def compute_position_id_with_mask(mask): def generate_ground_truth_response(prompt: str): """Generate ground truth response given a prompt.""" - num, info = prompt.split(':') - num1, num2 = num.split(',') - max_number, num_to_gen = info.split(',') + num, info = prompt.split(":") + num1, num2 = num.split(",") + max_number, num_to_gen = info.split(",") num1 = int(num1) num2 = int(num2) max_number = int(max_number) @@ -130,11 +132,11 @@ def generate_ground_truth_response(prompt: str): curr = (last_num + diff) % max_number results.append(str(curr)) last_num = curr - response = ','.join(results) + response = ",".join(results) return response -def compute_reward(prompt: str, response: str, sequence_reward=1.): +def compute_reward(prompt: str, response: str, sequence_reward=1.0): """We compute dense reward here so that we can directly train RL without 
SFT""" response_length = len(response) ground_truth_response = generate_ground_truth_response(prompt) @@ -157,21 +159,21 @@ def compute_reward(prompt: str, response: str, sequence_reward=1.): # no matches break - return reward, {'ground_truth_response': ground_truth_response} + return reward, {"ground_truth_response": ground_truth_response} -if __name__ == '__main__': +if __name__ == "__main__": task = DigitCompletion(max_number=20, max_diff=3, max_num_in_response=5) print(task.sample_str_prompts()) - prompt = '7,8:20,0' - response = '' + prompt = "7,8:20,0" + response = "" print(compute_reward(prompt, response)) - prompt = '7,8:20,0' - response = 'E000' + prompt = "7,8:20,0" + response = "E000" print(compute_reward(prompt, response)) - prompt = '9,10:20,2' - response = '11,12,13' + prompt = "9,10:20,2" + response = "11,12,13" print(compute_reward(prompt, response)) diff --git a/tests/e2e/envs/digit_completion/tokenizer.py b/tests/e2e/envs/digit_completion/tokenizer.py index 1b8d94ee3..9581dc841 100644 --- a/tests/e2e/envs/digit_completion/tokenizer.py +++ b/tests/e2e/envs/digit_completion/tokenizer.py @@ -27,7 +27,6 @@ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer class CharTokenizer(PreTrainedTokenizer): - def __init__(self, characters: Sequence[str], model_max_length: int, chat_template, **kwargs): """Character tokenizer for Hugging Face transformers. @@ -47,10 +46,10 @@ class CharTokenizer(PreTrainedTokenizer): model_max_length (int): Model maximum sequence length. """ - eos_token_str = 'E' - sep_token_str = 'S' - pad_token_str = 'P' - unk_token_str = 'U' + eos_token_str = "E" + sep_token_str = "S" + pad_token_str = "P" + unk_token_str = "U" self.characters = characters self.model_max_length = model_max_length @@ -64,9 +63,7 @@ class CharTokenizer(PreTrainedTokenizer): eos_token_str: 1, pad_token_str: 2, unk_token_str: 3, - **{ - ch: i + 4 for i, ch in enumerate(characters) - }, + **{ch: i + 4 for i, ch in enumerate(characters)}, } self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()} @@ -101,9 +98,9 @@ class CharTokenizer(PreTrainedTokenizer): def convert_tokens_to_string(self, tokens): return "".join(tokens) - def build_inputs_with_special_tokens(self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None) -> List[int]: + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: sep = [self.sep_token_id] cls = [self.cls_token_id] result = cls + token_ids_0 + sep @@ -133,11 +130,11 @@ class CharTokenizer(PreTrainedTokenizer): return { "char_ords": [ord(ch) for ch in self.characters], "model_max_length": self.model_max_length, - "chat_template": self.chat_template + "chat_template": self.chat_template, } @classmethod - def from_config(cls, config: Dict) -> "DigitCompletionTokenizer": + def from_config(cls, config: Dict): cfg = {} cfg["characters"] = [chr(i) for i in config["char_ords"]] cfg["model_max_length"] = config["model_max_length"] diff --git a/tests/e2e/sft/test_sp_loss_match.py b/tests/e2e/sft/test_sp_loss_match.py index 9dbb38da0..dc8d47821 100644 --- a/tests/e2e/sft/test_sp_loss_match.py +++ b/tests/e2e/sft/test_sp_loss_match.py @@ -15,14 +15,15 @@ import torch import torch.distributed from tensordict import TensorDict -from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer from torch.distributed.device_mesh import init_device_mesh + +from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer from verl.utils.distributed import 
initialize_global_process_group def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = 4): """Test consistency between original forward pass and SP+rmpad forward passes. - + Args: trainer: The FSDPSFTTrainer instance to test total_steps: Number of steps to test (default: 4) @@ -88,28 +89,28 @@ def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = def create_trainer(config): """Create and initialize a trainer instance with the given config. - + Args: config: Configuration object with training parameters - + Returns: FSDPSFTTrainer: Initialized trainer instance """ local_rank, rank, world_size = initialize_global_process_group() - device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',)) + device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) dp_size = world_size // config.ulysses_sequence_parallel_size - ulysses_device_mesh = init_device_mesh(device_type='cuda', - mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), - mesh_dim_names=('dp', 'sp')) + ulysses_device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp") + ) return FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh) def main(config): """Main function to run trainer tests. - + Args: config: Configuration object with training parameters """ @@ -117,7 +118,7 @@ def main(config): test_trainer_forward_consistency(trainer) -if __name__ == '__main__': +if __name__ == "__main__": import hydra from omegaconf import DictConfig diff --git a/tests/gpu_utility/test_memory_buffers.py b/tests/gpu_utility/test_memory_buffers.py index 0116c8be0..1dd230340 100644 --- a/tests/gpu_utility/test_memory_buffers.py +++ b/tests/gpu_utility/test_memory_buffers.py @@ -17,20 +17,23 @@ Test memory buffers - We use Memory buffer to make one of the models and then compare the parameters """ -import torch import gc -from transformers import LlamaModel, LlamaConfig +import torch +from transformers import LlamaConfig, LlamaModel + from verl.utils.memory_buffer import MemoryBufferModuleWrapper def test_memory_buffers(): - llama_config = LlamaConfig(vocab_size=256, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=2, - num_attention_heads=16, - num_key_value_heads=16) + llama_config = LlamaConfig( + vocab_size=256, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=2, + num_attention_heads=16, + num_key_value_heads=16, + ) model = LlamaModel(config=llama_config).cuda() model_copy = LlamaModel(config=llama_config).cuda() @@ -45,7 +48,7 @@ def test_memory_buffers(): r_before = torch.cuda.memory_reserved(0) / norm_factor a_before = torch.cuda.memory_allocated(0) / norm_factor - print(f'Before Total memory: {t_before} GB, reserved: {r_before} GB, allocated: {a_before} GB') + print(f"Before Total memory: {t_before} GB, reserved: {r_before} GB, allocated: {a_before} GB") model_wrapper = MemoryBufferModuleWrapper(model) @@ -56,15 +59,15 @@ def test_memory_buffers(): gc.collect() torch.cuda.empty_cache() - print(f'After Total memory: {t} GB, reserved: {r} GB, allocated: {a} GB') + print(f"After Total memory: {t} GB, reserved: {r} GB, allocated: {a} GB") change_ratio = (a - a_before) / a_before - assert change_ratio < 0.01, f'make sure the allocated change is less than 1%, Got {change_ratio}' + assert change_ratio < 0.01, f"make sure the allocated change is less 
than 1%, Got {change_ratio}" for (name1, param1), (name2, param2) in zip(model.named_parameters(), model_copy.named_parameters()): assert name1 == name2 - assert torch.eq(param1.data, param2.data).all(), f'{param1.data}, {param2.data}, {name1}' + assert torch.eq(param1.data, param2.data).all(), f"{param1.data}, {param2.data}, {name1}" -if __name__ == '__main__': +if __name__ == "__main__": test_memory_buffers() diff --git a/tests/gpu_utility/test_ops.py b/tests/gpu_utility/test_ops.py index b7663cfd2..4bfb22298 100644 --- a/tests/gpu_utility/test_ops.py +++ b/tests/gpu_utility/test_ops.py @@ -14,33 +14,31 @@ def test_flash_attn_cross_entropy(): - from verl.utils.torch_functional import logprobs_from_logits_naive - - from verl.utils.debug import log_gpu_memory_usage - - from flash_attn.ops.triton.cross_entropy import cross_entropy_loss - import torch + from flash_attn.ops.triton.cross_entropy import cross_entropy_loss from torch import nn - log_gpu_memory_usage('At start') + from verl.utils.debug import log_gpu_memory_usage + from verl.utils.torch_functional import logprobs_from_logits_naive - hidden_states = torch.randn(size=(2048, 5120), device='cuda', requires_grad=True, dtype=torch.bfloat16) + log_gpu_memory_usage("At start") - linear = nn.Linear(in_features=5120, out_features=155136, bias=False, device='cuda', dtype=torch.bfloat16) + hidden_states = torch.randn(size=(2048, 5120), device="cuda", requires_grad=True, dtype=torch.bfloat16) + + linear = nn.Linear(in_features=5120, out_features=155136, bias=False, device="cuda", dtype=torch.bfloat16) logits = linear(hidden_states) # logits = logits.float() - labels = torch.randint(low=0, high=155136, size=(2048,), device='cuda') + labels = torch.randint(low=0, high=155136, size=(2048,), device="cuda") - log_gpu_memory_usage('before computation') + log_gpu_memory_usage("before computation") # output = checkpoint.checkpoint(logprobs_from_logits, logits, labels, use_reentrant=True) output = -cross_entropy_loss(logits, labels)[0] # output = logprobs_from_logits(logits, labels) - log_gpu_memory_usage('After forward') + log_gpu_memory_usage("After forward") output.sum().backward() - log_gpu_memory_usage('After backward') + log_gpu_memory_usage("After backward") groundtruth = logprobs_from_logits_naive(logits.float(), labels) diff --git a/tests/gpu_utility/test_torch_functional.py b/tests/gpu_utility/test_torch_functional.py index 6dfa2e867..b3c163dbe 100644 --- a/tests/gpu_utility/test_torch_functional.py +++ b/tests/gpu_utility/test_torch_functional.py @@ -12,39 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from verl.utils.model import create_random_mask -from flash_attn.bert_padding import unpad_input -import torch import pytest +import torch +from flash_attn.bert_padding import unpad_input + +from verl.utils.model import create_random_mask def test_log_probs_from_logits_response_rmpad(): from verl.utils.torch_functional import log_probs_from_logits_response, log_probs_from_logits_response_rmpad + vocab_size = 32000 batch_size = 2 prompt_length = 256 response_length = 256 - input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, prompt_length + response_length), device='cuda') - attention_mask = create_random_mask(input_ids=input_ids, - max_ratio_of_left_padding=0.2, - max_ratio_of_valid_token=0.8, - min_ratio_of_valid_token=0.6) + input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, prompt_length + response_length), device="cuda") + attention_mask = create_random_mask( + input_ids=input_ids, max_ratio_of_left_padding=0.2, max_ratio_of_valid_token=0.8, min_ratio_of_valid_token=0.6 + ) response_mask = attention_mask[:, -response_length:] assert torch.all(response_mask[:, 0] == 1) - logits = torch.randn(batch_size, prompt_length + response_length, vocab_size, device='cuda') + logits = torch.randn(batch_size, prompt_length + response_length, vocab_size, device="cuda") logits_rmpad = unpad_input(logits, attention_mask)[0] - expected_output = log_probs_from_logits_response(input_ids=input_ids, - logits=logits, - response_length=response_length) - actual_output = log_probs_from_logits_response_rmpad(input_ids=input_ids, - attention_mask=attention_mask, - logits_rmpad=logits_rmpad, - response_length=response_length) + expected_output = log_probs_from_logits_response( + input_ids=input_ids, logits=logits, response_length=response_length + ) + actual_output = log_probs_from_logits_response_rmpad( + input_ids=input_ids, attention_mask=attention_mask, logits_rmpad=logits_rmpad, response_length=response_length + ) # This should bitwise align as only this operation only contains gather operators assert torch.all(torch.eq(actual_output * response_mask, expected_output * response_mask)) @@ -52,13 +52,14 @@ def test_log_probs_from_logits_response_rmpad(): @pytest.mark.parametrize("dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16]) def test_logprobs_from_logits_v2(dtype): - from verl.utils.torch_functional import logprobs_from_logits_v2, logprobs_from_logits_naive + from verl.utils.torch_functional import logprobs_from_logits_naive, logprobs_from_logits_v2 + vocab_size = 32000 batch_size = 2 seq_len = 512 - labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device='cuda') - logits = torch.randn(batch_size, seq_len, vocab_size, device='cuda', dtype=dtype) + labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device="cuda") + logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda", dtype=dtype) expected_output = logprobs_from_logits_naive(labels=labels, logits=logits) actual_output = logprobs_from_logits_v2(labels=labels, logits=logits) @@ -71,10 +72,12 @@ def test_logprobs_from_logits_v2(dtype): def test_lr_scheduler(): from torch import nn + model = nn.Linear(10, 10) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) from verl.utils.torch_functional import get_constant_schedule_with_warmup + constant_lr = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=2) lr_lst = [] @@ -86,11 +89,11 @@ def test_lr_scheduler(): torch.testing.assert_close(lr_lst, [0.0, 0.0005, 0.001, 0.001, 
0.001]) from verl.utils.torch_functional import get_cosine_schedule_with_warmup + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - cosine_lr = get_cosine_schedule_with_warmup(optimizer=optimizer, - num_warmup_steps=2, - num_training_steps=5, - min_lr_ratio=0.1) + cosine_lr = get_cosine_schedule_with_warmup( + optimizer=optimizer, num_warmup_steps=2, num_training_steps=5, min_lr_ratio=0.1 + ) lr_lst = [] diff --git a/tests/model/test_transformer.py b/tests/model/test_transformer.py index 1dd7fcff1..111230a8a 100644 --- a/tests/model/test_transformer.py +++ b/tests/model/test_transformer.py @@ -13,19 +13,26 @@ # limitations under the License. import torch -from verl.utils.model import create_random_mask, compute_position_id_with_mask -from verl.utils.torch_functional import masked_mean, log_probs_from_logits_all_rmpad, logprobs_from_logits -from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis, rearrange +from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input +from transformers import ( + AutoModelForCausalLM, + AutoModelForTokenClassification, + GemmaConfig, + LlamaConfig, + MistralConfig, + Qwen2Config, +) + +from verl.utils.model import compute_position_id_with_mask, create_random_mask +from verl.utils.torch_functional import log_probs_from_logits_all_rmpad, masked_mean -from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config -from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForSequenceClassification # TODO(sgm): add more models for test # we only need one scale for each model test_configs = [ LlamaConfig(num_hidden_layers=1), MistralConfig(num_hidden_layers=1), GemmaConfig(num_hidden_layers=1), - Qwen2Config(num_hidden_layers=1) + Qwen2Config(num_hidden_layers=1), ] @@ -36,56 +43,67 @@ def test_hf_casual_models(): for config in test_configs: # config = AutoConfig.from_pretrained(test_case) - with torch.device('cuda'): - model = AutoModelForCausalLM.from_config(config=config, - torch_dtype=torch.bfloat16, - attn_implementation='flash_attention_2') - model = model.to(device='cuda') - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda') - attention_mask = create_random_mask(input_ids=input_ids, - max_ratio_of_left_padding=0.1, - max_ratio_of_valid_token=0.8, - min_ratio_of_valid_token=0.5) + with torch.device("cuda"): + model = AutoModelForCausalLM.from_config( + config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model = model.to(device="cuda") + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") + attention_mask = create_random_mask( + input_ids=input_ids, + max_ratio_of_left_padding=0.1, + max_ratio_of_valid_token=0.8, + min_ratio_of_valid_token=0.5, + ) position_ids = compute_position_id_with_mask( - attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here + attention_mask + ) # TODO(sgm): we can construct the position_ids_rmpad here - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), - attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), - indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) # input with input_ids_rmpad and postition_ids to enable flash attention varlen - logits_rmpad = model(input_ids_rmpad, position_ids=position_ids_rmpad, - use_cache=False).logits # (1, total_nnz, vocab_size) + logits_rmpad = model( + input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False + ).logits # (1, total_nnz, vocab_size) - origin_logits = model(input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - use_cache=False).logits + origin_logits = model( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + ).logits origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask) logits_rmpad = logits_rmpad.squeeze(0) - log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad, - logits_rmpad=logits_rmpad, - indices=indices, - batch_size=batch_size, - seqlen=seqlen, - response_length=response_length) # (batch, seqlen) - origin_log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad, - logits_rmpad=origin_logits_rmpad, - indices=origin_logits_indices, - batch_size=batch_size, - seqlen=seqlen, - response_length=response_length) # (batch, seqlen) + log_probs = log_probs_from_logits_all_rmpad( + input_ids_rmpad=input_ids_rmpad, + logits_rmpad=logits_rmpad, + indices=indices, + batch_size=batch_size, + seqlen=seqlen, + response_length=response_length, + ) # (batch, seqlen) + origin_log_probs = log_probs_from_logits_all_rmpad( + input_ids_rmpad=input_ids_rmpad, + logits_rmpad=origin_logits_rmpad, + indices=origin_logits_indices, + batch_size=batch_size, + seqlen=seqlen, + response_length=response_length, + ) # (batch, seqlen) - torch.testing.assert_close(masked_mean(log_probs, attention_mask[:, -response_length - 1:-1]), - masked_mean(origin_log_probs, attention_mask[:, -response_length - 1:-1]), - atol=1e-2, - rtol=1e-5) - print(f'Check pass') + torch.testing.assert_close( + masked_mean(log_probs, attention_mask[:, -response_length - 1 : -1]), + masked_mean(origin_log_probs, attention_mask[:, -response_length - 1 : -1]), + atol=1e-2, + rtol=1e-5, + ) + print("Check pass") def test_hf_value_models(): @@ -95,47 +113,54 @@ def test_hf_value_models(): for config in test_configs: # config = AutoConfig.from_pretrained(test_case) config.num_labels = 1 - setattr(config, 'classifier_dropout', 0) - setattr(config, 'hidden_dropout', 0) - with torch.device('cuda'): - model = AutoModelForTokenClassification.from_config(config=config, - torch_dtype=torch.bfloat16, - attn_implementation='flash_attention_2') - model = model.to(device='cuda') - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda') - attention_mask = create_random_mask(input_ids=input_ids, - max_ratio_of_left_padding=0.1, - max_ratio_of_valid_token=0.8, - min_ratio_of_valid_token=0.5) + config.classifier_dropout = 0 + config.hidden_dropout = 0 + with torch.device("cuda"): + model = AutoModelForTokenClassification.from_config( + config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model = model.to(device="cuda") + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") + attention_mask = create_random_mask( + input_ids=input_ids, + max_ratio_of_left_padding=0.1, + max_ratio_of_valid_token=0.8, + 
min_ratio_of_valid_token=0.5, + ) position_ids = compute_position_id_with_mask( - attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here + attention_mask + ) # TODO(sgm): we can construct the position_ids_rmpad here - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), - attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), - indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) - origin_logits = model(input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - use_cache=False).logits + origin_logits = model( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + ).logits # input with input_ids_rmpad and postition_ids to enable flash attention varlen - rmpad_logits = model(input_ids_rmpad, position_ids=position_ids_rmpad, - use_cache=False).logits # (1, total_nnz, 1) + rmpad_logits = model( + input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False + ).logits # (1, total_nnz, 1) rmpad_logits = rmpad_logits.squeeze(0) pad_logits = pad_input(rmpad_logits, indices, batch_size, seqlen=seqlen) - torch.testing.assert_close(masked_mean(pad_logits, attention_mask[:, :, None]), - masked_mean(origin_logits, attention_mask[:, :, None]), - atol=1e-2, - rtol=1e-5) - print('Value model check pass') + torch.testing.assert_close( + masked_mean(pad_logits, attention_mask[:, :, None]), + masked_mean(origin_logits, attention_mask[:, :, None]), + atol=1e-2, + rtol=1e-5, + ) + print("Value model check pass") -if __name__ == '__main__': +if __name__ == "__main__": test_hf_casual_models() test_hf_value_models() diff --git a/tests/model/test_transformers_ulysses.py b/tests/model/test_transformers_ulysses.py index 2bbab1d50..27f070826 100644 --- a/tests/model/test_transformers_ulysses.py +++ b/tests/model/test_transformers_ulysses.py @@ -11,25 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import time import contextlib +import copy from dataclasses import dataclass import pytest import torch -import copy import torch.distributed +from flash_attn.bert_padding import index_first_axis, rearrange, unpad_input from torch.distributed import init_device_mesh -from verl.utils.distributed import initialize_global_process_group -from verl.utils.model import create_random_mask, compute_position_id_with_mask -from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad -from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group -from verl.workers.sharding_manager import FSDPUlyssesShardingManager -from verl.protocol import DataProto -from flash_attn.bert_padding import unpad_input, index_first_axis, rearrange -from transformers import LlamaConfig, Qwen2Config, PretrainedConfig -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, LlamaConfig, PretrainedConfig, Qwen2Config + from verl.models.transformers.monkey_patch import apply_monkey_patch +from verl.protocol import DataProto +from verl.utils.distributed import initialize_global_process_group +from verl.utils.model import compute_position_id_with_mask, create_random_mask +from verl.utils.ulysses import ( + gather_outpus_and_unpad, + get_ulysses_sequence_parallel_world_size, + set_ulysses_sequence_parallel_group, + ulysses_pad_and_slice_inputs, +) +from verl.workers.sharding_manager import FSDPUlyssesShardingManager # TODO(sgm): add more models for test # we only need one scale for each model @@ -44,27 +47,25 @@ class SequenceParallelConfig: def test_configs(): return [ - SequenceParallelConfig(LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32), - sp_size=8, - is_valid=True), - SequenceParallelConfig(Qwen2Config(num_hidden_layers=2, - num_attention_heads=28, - num_key_value_heads=4, - hidden_size=3584), - sp_size=4, - is_valid=True), - SequenceParallelConfig(Qwen2Config(num_hidden_layers=2, - num_attention_heads=28, - num_key_value_heads=4, - hidden_size=3584), - sp_size=8, - is_valid=False), - SequenceParallelConfig(Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), - sp_size=4, - is_valid=True), - SequenceParallelConfig(Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), - sp_size=8, - is_valid=True), + SequenceParallelConfig( + LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32), sp_size=8, is_valid=True + ), + SequenceParallelConfig( + Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584), + sp_size=4, + is_valid=True, + ), + SequenceParallelConfig( + Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584), + sp_size=8, + is_valid=False, + ), + SequenceParallelConfig( + Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=4, is_valid=True + ), + SequenceParallelConfig( + Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=8, is_valid=True + ), ] @@ -91,9 +92,9 @@ def test_hf_casual_fwd_bwd(test_config): def _hf_casual_fwd(config, sp_size, dp_size): assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test" - ulysses_device_mesh = init_device_mesh(device_type='cuda', - mesh_shape=(dp_size, sp_size), - mesh_dim_names=('dp', 'sp')) + ulysses_device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(dp_size, sp_size), 
mesh_dim_names=("dp", "sp") + ) sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh) batch_size = 1 @@ -101,27 +102,27 @@ def _hf_casual_fwd(config, sp_size, dp_size): response_length = 127 # patch before load - with torch.device('cuda'): - model = AutoModelForCausalLM.from_config(config=config, - torch_dtype=torch.bfloat16, - attn_implementation='flash_attention_2') + with torch.device("cuda"): + model = AutoModelForCausalLM.from_config( + config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) apply_monkey_patch(model, sp_size) - model = model.to(device='cuda') + model = model.to(device="cuda") sync_model_parameters_global(model) # different rank will generate different input_ids following fsdp - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda') - attention_mask = create_random_mask(input_ids=input_ids, - max_ratio_of_left_padding=0, - max_ratio_of_valid_token=0.9, - min_ratio_of_valid_token=0.8) + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") + attention_mask = create_random_mask( + input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8 + ) position_ids = compute_position_id_with_mask( - attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here + attention_mask + ) # TODO(sgm): we can construct the position_ids_rmpad here model_inputs = { - 'input_ids': input_ids.cuda(), - 'attention_mask': attention_mask.cuda(), - 'position_ids': position_ids.int().cuda() + "input_ids": input_ids.cuda(), + "attention_mask": attention_mask.cuda(), + "position_ids": position_ids.int().cuda(), } model_inputs = DataProto.from_dict(model_inputs) @@ -129,33 +130,38 @@ def _hf_casual_fwd(config, sp_size, dp_size): # 1. perform ulysses forward with sharding_manager: model_inputs = sharding_manager.preprocess_data(model_inputs) - input_ids = model_inputs.batch['input_ids'] - attention_mask = model_inputs.batch['attention_mask'] - position_ids = model_inputs.batch['position_ids'] - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), - attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids = model_inputs.batch["input_ids"] + attention_mask = model_inputs.batch["attention_mask"] + position_ids = model_inputs.batch["position_ids"] + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), - indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), indices + ).transpose(0, 1) # slice input tensor for ulysses # input_ids are padded and sliced # postition_ids are only padded but not sliced input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()) + input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() + ) # input with input_ids_rmpad and postition_ids to enable flash attention varlen - logits_split_in_seq = model(input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, - use_cache=False).logits # (1, total_nnz/n, vocab_size) + logits_split_in_seq = model( + input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False + ).logits # (1, total_nnz/n, vocab_size) # all_gather output logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size) # 2. perform normal forward set_ulysses_sequence_parallel_group(None) - logits_rmpad_local = model(input_ids_rmpad, position_ids=position_ids_rmpad, - use_cache=False).logits # (1, total_nnz, vocab_size) + logits_rmpad_local = model( + input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False + ).logits # (1, total_nnz, vocab_size) mean_local = logits_rmpad_local.mean() mean_full = logits_full.mean() @@ -165,9 +171,9 @@ def _hf_casual_fwd(config, sp_size, dp_size): def _hf_casual_fwd_bwd(config, sp_size, dp_size): assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test" - ulysses_device_mesh = init_device_mesh(device_type='cuda', - mesh_shape=(dp_size, sp_size), - mesh_dim_names=('dp', 'sp')) + ulysses_device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp") + ) sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh) batch_size = 1 @@ -175,27 +181,27 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size): response_length = 127 # patch before load - with torch.device('cuda'): - model = AutoModelForCausalLM.from_config(config=config, - torch_dtype=torch.bfloat16, - attn_implementation='flash_attention_2') + with torch.device("cuda"): + model = AutoModelForCausalLM.from_config( + config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) apply_monkey_patch(model, sp_size) - model = model.to(device='cuda') + model = model.to(device="cuda") sync_model_parameters_global(model) # different rank will generate different input_ids following fsdp - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda') - attention_mask = create_random_mask(input_ids=input_ids, - max_ratio_of_left_padding=0, - max_ratio_of_valid_token=0.9, - min_ratio_of_valid_token=0.8) + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") + attention_mask = create_random_mask( + input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8 + ) position_ids = compute_position_id_with_mask( - attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here + attention_mask + ) # TODO(sgm): we can construct the position_ids_rmpad here model_inputs = { - 'input_ids': input_ids.cuda(), - 'attention_mask': attention_mask.cuda(), - 'position_ids': position_ids.int().cuda() + "input_ids": input_ids.cuda(), + "attention_mask": attention_mask.cuda(), + "position_ids": position_ids.int().cuda(), } model_inputs = DataProto.from_dict(model_inputs) @@ -203,25 
+209,29 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size): # 1. perform ulysses forward with sharding_manager: model_inputs = sharding_manager.preprocess_data(model_inputs) - input_ids = model_inputs.batch['input_ids'] - attention_mask = model_inputs.batch['attention_mask'] - position_ids = model_inputs.batch['position_ids'] - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), - attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids = model_inputs.batch["input_ids"] + attention_mask = model_inputs.batch["attention_mask"] + position_ids = model_inputs.batch["position_ids"] + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), - indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) # slice input tensor for ulysses # input_ids are padded and sliced # postition_ids are only padded but not sliced input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()) + input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() + ) # input with input_ids_rmpad and postition_ids to enable flash attention varlen - logits_split_in_seq = model(input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, - use_cache=False).logits # (1, total_nnz/n, vocab_size) + logits_split_in_seq = model( + input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False + ).logits # (1, total_nnz/n, vocab_size) # all_gather output logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size) @@ -231,8 +241,9 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size): input_ids_full = copy.deepcopy(input_ids_rmpad) position_ids_full = copy.deepcopy(position_ids_rmpad) model_no_sp = copy.deepcopy(model) - logits_rmpad_local = model_no_sp(input_ids_full, position_ids=position_ids_full, - use_cache=False).logits # (1, total_nnz, vocab_size) + logits_rmpad_local = model_no_sp( + input_ids_full, position_ids=position_ids_full, use_cache=False + ).logits # (1, total_nnz, vocab_size) mean_local = logits_rmpad_local.mean() mean_full = logits_full.mean() @@ -247,5 +258,5 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size): torch.testing.assert_close(grad, grad_full, atol=1e-2, rtol=1e-5) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__, "-svv"]) diff --git a/tests/ray/check_worker_alive/main.py b/tests/ray/check_worker_alive/main.py index fcebbfe20..47f8a8ce3 100644 --- a/tests/ray/check_worker_alive/main.py +++ b/tests/ray/check_worker_alive/main.py @@ -12,20 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
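The forward/backward test above checks one invariant: running the model on a sequence-parallel slice of the unpadded tokens and then all-gathering the logits along the sequence dimension must reproduce the plain, unsliced forward pass. Below is a toy, single-process sketch of that invariant; `toy_forward`, `sp_size`, and the tensor shapes are invented for illustration (the real test runs a HuggingFace causal LM under the Ulysses sharding manager, which uses all-to-all exchanges so the property also holds through attention layers).

```python
import torch


# Toy stand-in for a model: an elementwise function, so slicing the sequence
# dimension and concatenating the per-slice outputs must reproduce the full output.
def toy_forward(x: torch.Tensor) -> torch.Tensor:
    return torch.tanh(x) * 2.0


sp_size = 4                # assumed sequence-parallel world size
x = torch.randn(1, 10, 8)  # (batch=1, total_nnz=10, hidden=8), shapes invented

# Pad the sequence so it divides evenly across sp ranks, analogous to what
# ulysses_pad_and_slice_inputs does for input_ids / position_ids.
pad_size = (-x.shape[1]) % sp_size
x_padded = torch.nn.functional.pad(x, (0, 0, 0, pad_size))

# Each "rank" processes one slice; gathering along dim=1 and dropping the pad
# should match the full forward, which is what the test asserts on the logits.
slices = x_padded.chunk(sp_size, dim=1)
gathered = torch.cat([toy_forward(s) for s in slices], dim=1)[:, : x.shape[1]]

torch.testing.assert_close(gathered, toy_forward(x), atol=1e-2, rtol=1e-5)
```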
-import time -import sys import os +import sys +import time import ray -from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup +from verl.single_controller.base.decorator import Dispatch, register from verl.single_controller.base.worker import Worker -from verl.single_controller.base.decorator import register, Dispatch +from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup @ray.remote class TestActor(Worker): - def __init__(self) -> None: super().__init__() @@ -41,7 +40,7 @@ if __name__ == "__main__": ray.init() # test single-node-no-partition - print(f"test single-node-no-partition") + print("test single-node-no-partition") resource_pool = RayResourcePool([2], use_gpu=True) class_with_args = RayClassWithInitArgs(cls=TestActor) @@ -56,8 +55,10 @@ if __name__ == "__main__": _ = wg.foo(wait_time) print("foo started") - print(time.time(), - f"wait 6x wait time {wait_time*6} to let signal returned to process but still not exceed process wait time") + print( + time.time(), + f"wait 6x wait time {wait_time * 6} to let signal returned to process but still not exceed process wait time", + ) time.sleep(wait_time * 6) ray.shutdown() diff --git a/tests/ray/detached_worker/client.py b/tests/ray/detached_worker/client.py index 1773fffe7..52f2c7242 100644 --- a/tests/ray/detached_worker/client.py +++ b/tests/ray/detached_worker/client.py @@ -17,44 +17,42 @@ In client, we can get the server handler and send RPC request import ray import torch +from server import Trainer +from tensordict import TensorDict from verl import DataProto from verl.single_controller.ray import RayClassWithInitArgs from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup -from tensordict import TensorDict - -from server import Trainer - def compute_position_id_with_mask(mask): return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None) -if __name__ == '__main__': - - ray.init(address='auto', namespace='verl') +if __name__ == "__main__": + ray.init(address="auto", namespace="verl") # get the worker group using names - worker_names = ['trainerTrainer_0:0', 'trainerTrainer_0:1'] + worker_names = ["trainerTrainer_0:0", "trainerTrainer_0:1"] cls_with_init_args = RayClassWithInitArgs(cls=Trainer) - worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names, - ray_cls_with_init=cls_with_init_args) + worker_group = NVMegatronRayWorkerGroup.from_detached( + worker_names=worker_names, ray_cls_with_init=cls_with_init_args + ) batch_size = 16 sequence_length = 1024 # give Trainer some data to train - input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device='cuda') + input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device="cuda") attention_mask = torch.ones_like(input_ids) position_ids = compute_position_id_with_mask(attention_mask) - data = DataProto(batch=TensorDict( - { - 'input_ids': input_ids, - 'attention_mask': attention_mask, - 'position_ids': position_ids - }, batch_size=batch_size), - meta_info={}) + data = DataProto( + batch=TensorDict( + {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}, + batch_size=batch_size, + ), + meta_info={}, + ) output = worker_group.train_model(data) diff --git a/tests/ray/detached_worker/server.py b/tests/ray/detached_worker/server.py index 2387a1a9b..1e019863a 100644 --- a/tests/ray/detached_worker/server.py +++ 
b/tests/ray/detached_worker/server.py @@ -17,46 +17,41 @@ Server starts a Trainer. Client sends data to the server to train. import os -os.environ['MEGATRON_USE_CUDA_TIMER'] = '0' -os.environ['MEGATRON_START_PROCESS_TIMER'] = 'False' -os.environ['NCCL_DEBUG'] = 'WARN' - -import torch -from torch import nn +os.environ["MEGATRON_USE_CUDA_TIMER"] = "0" +os.environ["MEGATRON_START_PROCESS_TIMER"] = "False" +os.environ["NCCL_DEBUG"] = "WARN" import ray -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool -from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup -from verl.single_controller.base.megatron.worker import MegatronWorker -from verl.single_controller.base.decorator import register, Dispatch -from verl import DataProto -from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP - +import torch from megatron.core import parallel_state as mpu -from megatron.core.models.gpt.gpt_model import ModelType from megatron.core import tensor_parallel -from verl.utils.megatron_utils import get_model, init_megatron_optim_config, mcore_model_parallel_config -from verl.utils.megatron.optimizer import get_megatron_optimizer - +from megatron.core.models.gpt.gpt_model import ModelType +from omegaconf import OmegaConf +from tensordict import TensorDict +from torch import nn from transformers import LlamaConfig -from omegaconf import OmegaConf - -from tensordict import TensorDict +from verl import DataProto +from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP +from verl.single_controller.base.decorator import Dispatch, register +from verl.single_controller.base.megatron.worker import MegatronWorker +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool +from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup +from verl.utils.megatron.optimizer import get_megatron_optimizer +from verl.utils.megatron_utils import get_model, init_megatron_optim_config, mcore_model_parallel_config @ray.remote class Trainer(MegatronWorker): - def __init__(self): super().__init__() if not torch.distributed.is_initialized(): - rank = int(os.environ['LOCAL_RANK']) + rank = int(os.environ["LOCAL_RANK"]) torch.distributed.init_process_group(backend="nccl") torch.cuda.set_device(rank) - os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" mpu.initialize_model_parallel( tensor_model_parallel_size=2, pipeline_model_parallel_size=1, @@ -71,12 +66,14 @@ class Trainer(MegatronWorker): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): - actor_model_config = LlamaConfig(vocab_size=256, - hidden_size=2048, - intermediate_size=5504, - num_hidden_layers=24, - num_attention_heads=16, - num_key_value_heads=16) + actor_model_config = LlamaConfig( + vocab_size=256, + hidden_size=2048, + intermediate_size=5504, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=16, + ) megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16) self.megatron_config = megatron_config @@ -86,19 +83,23 @@ class Trainer(MegatronWorker): vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() # this will be set inside get_model # this_megatron_config = copy.deepcopy(megatron_config) # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank - parallel_model = ParallelLlamaForCausalLMRmPadPP(config=actor_model_config, - megatron_config=megatron_config, - pre_process=pre_process, - post_process=post_process) + parallel_model = 
ParallelLlamaForCausalLMRmPadPP( + config=actor_model_config, + megatron_config=megatron_config, + pre_process=pre_process, + post_process=post_process, + ) parallel_model.cuda() return parallel_model - actor_module = get_model(model_provider_func=megatron_actor_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=True) + actor_module = get_model( + model_provider_func=megatron_actor_model_provider, + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=True, + ) actor_module = nn.ModuleList(actor_module) - optim_config = OmegaConf.create({'lr': 1e-6, 'clip_grad': 1.0}) + optim_config = OmegaConf.create({"lr": 1e-6, "clip_grad": 1.0}) optim_config = init_megatron_optim_config(optim_config) self.optimizer_config = optim_config @@ -109,33 +110,34 @@ class Trainer(MegatronWorker): @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) def train_model(self, data: DataProto) -> DataProto: - input_ids = data.batch['input_ids'] - attention_mask = data.batch['attention_mask'] - position_ids = data.batch['position_ids'] + input_ids = data.batch["input_ids"] + attention_mask = data.batch["attention_mask"] + position_ids = data.batch["position_ids"] self.optimizer.zero_grad() self.model.zero_grad_buffer( - zero_buffer=(not self.optimizer_config.use_distributed_optimizer - )) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm + zero_buffer=(not self.optimizer_config.use_distributed_optimizer) + ) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm # update for 1 iteration output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits output.mean().backward() - update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(self.megatron_config, - self.megatron_config.timers) + update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step( + self.megatron_config, self.megatron_config.timers + ) - return DataProto(batch=TensorDict({'loss': output.detach()}, batch_size=output.shape[0])) + return DataProto(batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0])) -if __name__ == '__main__': - ray.init(address='auto', namespace='verl') +if __name__ == "__main__": + ray.init(address="auto", namespace="verl") resource_pool = RayResourcePool(process_on_nodes=[2], detached=True) cls_with_init_args = RayClassWithInitArgs(cls=Trainer) worker_group = NVMegatronRayWorkerGroup( resource_pool=resource_pool, ray_cls_with_init=cls_with_init_args, - name_prefix='trainer', + name_prefix="trainer", detached=True, ) diff --git a/tests/ray/test_check_worker_alive.py b/tests/ray/test_check_worker_alive.py index 53b7f2aab..1596fd3c9 100644 --- a/tests/ray/test_check_worker_alive.py +++ b/tests/ray/test_check_worker_alive.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
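Both the detached-worker client above and the Ulysses tests earlier derive `position_ids` from the attention mask with the same cumulative-sum recipe (`compute_position_id_with_mask`). A standalone sketch of what that recipe produces for a left-padded batch; the formula is the one from the client script, while the mask values are invented for illustration.

```python
import torch


def compute_position_id_with_mask(mask: torch.Tensor) -> torch.Tensor:
    # Positions count valid tokens along the sequence, clipped at 0 so that
    # left-padding tokens all receive position 0.
    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)


attention_mask = torch.tensor([
    [0, 0, 1, 1, 1],  # two left-pad tokens, three valid tokens
    [1, 1, 1, 1, 1],  # no padding
])
print(compute_position_id_with_mask(attention_mask))
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```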
-import time import os import subprocess +import time def test(): @@ -34,12 +34,13 @@ def test(): print( time.time(), - f"wait 1.5 wait time {wait_time*1.5} to let signal returned to process but still not exceed process wait time") + f"wait 1.5 wait time {wait_time * 1.5} to let signal returned to process but still not exceed process wait time", + ) time.sleep(wait_time * 1.5) - print(time.time(), f"start checking") + print(time.time(), "start checking") assert p.poll() is not None, f"process {p} still alive, expecting signal raised abort" assert p.returncode != 0, f"process {p} exit with code 0, expecting not-zero exit code" - print(f"test passed") + print("test passed") if __name__ == "__main__": diff --git a/tests/ray/test_colocated_workers.py b/tests/ray/test_colocated_workers.py index 96b859b4b..914ea5428 100644 --- a/tests/ray/test_colocated_workers.py +++ b/tests/ray/test_colocated_workers.py @@ -14,35 +14,37 @@ import ray -from verl.single_controller.base import Worker -from verl.single_controller.base.decorator import register, Dispatch -from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls - from verl import DataProto +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, register +from verl.single_controller.ray.base import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, + create_colocated_worker_cls, +) @ray.remote class Actor(Worker): - def __init__(self) -> None: super().__init__() @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def add(self, data: DataProto): - data.batch['a'] += self.rank + data.batch["a"] += self.rank return data @ray.remote class Critic(Worker): - def __init__(self, config) -> None: super().__init__() self.config = config @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def sub(self, data: DataProto): - data.batch['a'] -= self.config['b'] + data.batch["a"] -= self.config["b"] return data @@ -50,10 +52,11 @@ def test_colocated_workers(): ray.init() import torch - data = DataProto.from_dict({'a': torch.zeros(10)}) + + data = DataProto.from_dict({"a": torch.zeros(10)}) # create separate workers on the same resource pool actor_cls = RayClassWithInitArgs(cls=Actor) - critic_cls = RayClassWithInitArgs(cls=Critic, config={'b': 10}) + critic_cls = RayClassWithInitArgs(cls=Critic, config={"b": 10}) resource_pool = RayResourcePool(process_on_nodes=[2]) actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls) @@ -63,13 +66,13 @@ def test_colocated_workers(): expected_critic_output = critic_wg.sub(data) # create colocated workers - cls_dict = {'actor': actor_cls, 'critic': critic_cls} + cls_dict = {"actor": actor_cls, "critic": critic_cls} ray_cls_with_init = create_colocated_worker_cls(cls_dict) wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys()) - colocated_actor_wg = spawn_wg['actor'] - colocated_critic_wg = spawn_wg['critic'] + colocated_actor_wg = spawn_wg["actor"] + colocated_critic_wg = spawn_wg["critic"] actor_output = colocated_actor_wg.add(data) critic_output = colocated_critic_wg.sub(data) diff --git a/tests/ray/test_data_transfer.py b/tests/ray/test_data_transfer.py index a17affd30..fdd854e32 100644 --- a/tests/ray/test_data_transfer.py +++ b/tests/ray/test_data_transfer.py @@ -15,27 +15,21 @@ In this test, we instantiate a data parallel worker with 8 GPUs """ -from verl.single_controller.base 
import Worker -from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool - -from verl.single_controller.base.decorator import Dispatch, register - import ray +import tensordict import torch - +from codetiming import Timer from torch import distributed as dist from verl import DataProto +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, register +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup from verl.utils.ray_utils import parallel_put -from codetiming import Timer - -import tensordict - @ray.remote class DummyWorker(Worker): - def __init__(self): super().__init__() dist.init_process_group() @@ -44,7 +38,7 @@ class DummyWorker(Worker): def do_nothing(self, data): for key in data.batch.keys(): data.batch[key] += 1 - if tensordict.__version__ >= '0.5.0': + if tensordict.__version__ >= "0.5.0": data.batch = data.batch.consolidate() return data @@ -75,35 +69,39 @@ def test_data_transfer(): for i in range(wg.world_size): # consolidate is necessary - if tensordict.__version__ >= '0.5.0': + if tensordict.__version__ >= "0.5.0": data_list[i].batch = data_list[i].batch.consolidate() - with Timer(name='ray.pickle', initial_text=True): + with Timer(name="ray.pickle", initial_text=True): for i in range(wg.world_size): ray.cloudpickle.pickle.dumps(data_list[i]) - with Timer(name='raw.pickle', initial_text=True): + with Timer(name="raw.pickle", initial_text=True): import pickle + for i in range(wg.world_size): pickle.dumps(data_list[i]) # we put in advance - with Timer(name='put', initial_text=True): + with Timer(name="put", initial_text=True): # takes around 40 seconds data_list_ref = parallel_put(data_list) # for i in range(wg.world_size): # data_list[i] = ray.put(data_list[i]) - with Timer(name='launch', initial_text=True): + with Timer(name="launch", initial_text=True): output_ref = wg.do_nothing(data_list_ref) - with Timer(name='get', initial_text=True): + with Timer(name="get", initial_text=True): # takes around 40 seconds output_lst = ray.get(output_ref) for input_data, output_data in zip(data_list, output_lst): for key in input_data.batch.keys(): - assert torch.all(torch.eq(input_data.batch[key] + 1, - output_data.batch[key])), (input_data.batch[key], output_data.batch[key], key) + assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), ( + input_data.batch[key], + output_data.batch[key], + key, + ) ray.shutdown() diff --git a/tests/ray/test_driverfunc_to_worker.py b/tests/ray/test_driverfunc_to_worker.py index ea253fd36..a38d790d6 100644 --- a/tests/ray/test_driverfunc_to_worker.py +++ b/tests/ray/test_driverfunc_to_worker.py @@ -13,28 +13,27 @@ # limitations under the License. 
import os + import ray import torch -from verl import DataProto from tensordict import TensorDict +from verl import DataProto from verl.single_controller.base.worker import Worker -from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs from verl.single_controller.ray import RayWorkerGroup +from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool -os.environ['RAY_DEDUP_LOGS'] = '0' -os.environ['NCCL_DEBUG'] = 'WARN' +os.environ["RAY_DEDUP_LOGS"] = "0" +os.environ["NCCL_DEBUG"] = "WARN" @ray.remote class ModelActor(Worker): - def __init__(self): pass -class HackSelf(): - +class HackSelf: def __init__(self): pass @@ -44,11 +43,11 @@ def get_aux_metrics(self, test_proto): decode_count = [] for i in range(sequence_ids.size(0)): decode_count.append(len(sequence_ids[i].tolist())) - ret_proto = DataProto(batch=TensorDict({ - "sequence_ids": sequence_ids, - "decode_count": torch.tensor(decode_count) - }, - batch_size=sequence_ids.size(0))) + ret_proto = DataProto( + batch=TensorDict( + {"sequence_ids": sequence_ids, "decode_count": torch.tensor(decode_count)}, batch_size=sequence_ids.size(0) + ) + ) return ret_proto @@ -57,17 +56,21 @@ def test(): ray.init() # create 2 workers, each hold a GPU - resource_pool = RayResourcePool([2], use_gpu=True, name_prefix='a') + resource_pool = RayResourcePool([2], use_gpu=True, name_prefix="a") class_with_args = RayClassWithInitArgs(cls=ModelActor) shard_wg = RayWorkerGroup(resource_pool, class_with_args) test_bs = 8 - test_proto = DataProto(TensorDict({ - "sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64), - }, - batch_size=test_bs), - meta_info={"query_length": 1536}) + test_proto = DataProto( + TensorDict( + { + "sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64), + }, + batch_size=test_bs, + ), + meta_info={"query_length": 1536}, + ) # Sharding among different ranks ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto) diff --git a/tests/ray/test_high_level_scheduling_api.py b/tests/ray/test_high_level_scheduling_api.py index 2d83206e5..52cc7c7df 100644 --- a/tests/ray/test_high_level_scheduling_api.py +++ b/tests/ray/test_high_level_scheduling_api.py @@ -16,8 +16,8 @@ import time import ray -from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool from verl.single_controller.base.worker import Worker +from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool @ray.remote @@ -34,7 +34,7 @@ def test(): ray.init() # test single-node-no-partition - print(f"test single-node-no-partition") + print("test single-node-no-partition") resource_pool = RayResourcePool([8], use_gpu=True) class_with_args = RayClassWithInitArgs(cls=TestActor) @@ -63,7 +63,7 @@ def test(): time.sleep(5) # test single-node-multi-partition - print(f"test single-node-multi-partition") + print("test single-node-multi-partition") rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="rm") ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="ref") total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool) diff --git a/tests/ray/test_ray_local_envs.py b/tests/ray/test_ray_local_envs.py index 63102d0c6..1dbd1f6a0 100644 --- a/tests/ray/test_ray_local_envs.py +++ b/tests/ray/test_ray_local_envs.py @@ -14,17 +14,17 @@ """ e2e test verl.single_controller.ray """ + import os + import ray -from verl.single_controller.ray.base import 
RayResourcePool, RayClassWithInitArgs, RayWorkerGroup from verl.single_controller.base.worker import Worker -from verl.single_controller.base.decorator import register, Dispatch, collect_all_to_all, Execute +from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup @ray.remote class TestActor(Worker): - def __init__(self) -> None: super().__init__() @@ -40,9 +40,9 @@ def test_basics(): resource_pool = RayResourcePool([4], use_gpu=True) class_with_args = RayClassWithInitArgs(cls=TestActor) - worker_group = RayWorkerGroup(resource_pool=resource_pool, - ray_cls_with_init=class_with_args, - name_prefix="worker_group_basic") + worker_group = RayWorkerGroup( + resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic" + ) output = worker_group.execute_all_sync("getenv", key="RAY_LOCAL_WORLD_SIZE") assert output == ["4", "4", "4", "4"] @@ -53,5 +53,5 @@ def test_basics(): ray.shutdown() -if __name__ == '__main__': +if __name__ == "__main__": test_basics() diff --git a/tests/ray/test_rvdz.py b/tests/ray/test_rvdz.py index 9eec1d72c..7dea12f95 100644 --- a/tests/ray/test_rvdz.py +++ b/tests/ray/test_rvdz.py @@ -17,7 +17,6 @@ import ray @ray.remote class TestWorker: - def __init__(self, rank, world_size, group_name): self.rank = rank self.world_size = world_size @@ -26,6 +25,7 @@ class TestWorker: def init(self): from verl.utils.rendezvous.ray_backend import create_nccl_communicator_in_ray + self.communicator = create_nccl_communicator_in_ray(self.rank, self.world_size, self.group_name) def test(self): diff --git a/tests/ray/test_worker_group_basics.py b/tests/ray/test_worker_group_basics.py index b4b633969..02a5b94eb 100644 --- a/tests/ray/test_worker_group_basics.py +++ b/tests/ray/test_worker_group_basics.py @@ -15,12 +15,12 @@ e2e test verl.single_controller.ray """ -import torch import ray +import torch -from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup +from verl.single_controller.base.decorator import Dispatch, Execute, collect_all_to_all, register from verl.single_controller.base.worker import Worker -from verl.single_controller.base.decorator import register, Dispatch, collect_all_to_all, Execute +from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup def two_to_all_dispatch_fn(worker_group, *args, **kwargs): @@ -60,7 +60,7 @@ class TestActor(Worker): def foo_all_to_all(self, x, y): return self._x + y + x - @register(dispatch_mode={'dispatch_fn': two_to_all_dispatch_fn, 'collect_fn': collect_all_to_all}) + @register(dispatch_mode={"dispatch_fn": two_to_all_dispatch_fn, "collect_fn": collect_all_to_all}) def foo_custom(self, x, y): return self._x + y + x @@ -94,9 +94,9 @@ def test_basics(): resource_pool = RayResourcePool([4], use_gpu=True) class_with_args = RayClassWithInitArgs(cls=TestActor, x=2) - worker_group = RayWorkerGroup(resource_pool=resource_pool, - ray_cls_with_init=class_with_args, - name_prefix="worker_group_basic") + worker_group = RayWorkerGroup( + resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic" + ) print(worker_group.worker_names) @@ -124,5 +124,5 @@ def test_basics(): ray.shutdown() -if __name__ == '__main__': +if __name__ == "__main__": test_basics() diff --git a/tests/ray/test_worker_group_torch.py b/tests/ray/test_worker_group_torch.py index 13508ed30..a601c43da 100644 --- a/tests/ray/test_worker_group_torch.py +++ b/tests/ray/test_worker_group_torch.py @@ 
-14,54 +14,52 @@ import os -os.environ['RAY_DEDUP_LOGS'] = '0' -os.environ['NCCL_DEBUG'] = 'WARN' +os.environ["RAY_DEDUP_LOGS"] = "0" +os.environ["NCCL_DEBUG"] = "WARN" +import ray import torch import torch.distributed -import ray -from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup from verl.single_controller.base.worker import Worker +from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup @ray.remote class TestAllGatherActor(Worker): - def __init__(self, size) -> None: super().__init__() self.size = size def init(self): torch.distributed.init_process_group() - self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device='cuda') + self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device="cuda") self.tensor += self.rank def all_gather(self): world_size = self._world_size - output = torch.zeros(size=(self.tensor.shape[0] * world_size,), - dtype=self.tensor.dtype, - device=self.tensor.device) + output = torch.zeros( + size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device + ) torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False) return output @ray.remote class TestAllGatherActorV2(Worker): - def __init__(self, size) -> None: super().__init__() self.size = size torch.distributed.init_process_group() - self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device='cuda') + self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device="cuda") self.tensor += self.rank def all_gather(self): world_size = self._world_size - output = torch.zeros(size=(self.tensor.shape[0] * world_size,), - dtype=self.tensor.dtype, - device=self.tensor.device) + output = torch.zeros( + size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device + ) torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False) return output @@ -78,8 +76,8 @@ def test_all_gather_torch(): worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch") - worker_group.execute_all_sync('init') - output = worker_group.execute_all_sync('all_gather') + worker_group.execute_all_sync("init") + output = worker_group.execute_all_sync("all_gather") for i in range(1, len(output)): assert torch.all(output[i] == output[0]) @@ -102,7 +100,7 @@ def test_all_gather_torch_v2(): worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch") - output = worker_group.execute_all_sync('all_gather') + output = worker_group.execute_all_sync("all_gather") for i in range(1, len(output)): assert torch.all(output[i] == output[0]) diff --git a/tests/rollout/run_fsdp_vllm.py b/tests/rollout/run_fsdp_vllm.py index 93ebff295..d9cd9c9d5 100644 --- a/tests/rollout/run_fsdp_vllm.py +++ b/tests/rollout/run_fsdp_vllm.py @@ -13,29 +13,30 @@ # limitations under the License. 
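The worker-group tests above share one pattern: build a `RayResourcePool`, wrap a `Worker` subclass in `RayClassWithInitArgs`, start a `RayWorkerGroup`, and invoke a method on every rank with `execute_all_sync`. A condensed sketch of that flow using only the calls these tests exercise; `RankReporter` and its `report` method are invented for the sketch, and it assumes a running Ray cluster with two GPUs.

```python
import ray

from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup


@ray.remote
class RankReporter(Worker):  # hypothetical worker, defined only for this sketch
    def __init__(self) -> None:
        super().__init__()

    def report(self):
        # Each worker answers with its own rank and the group's world size.
        return self.rank, self._world_size


if __name__ == "__main__":
    ray.init()
    # Two workers on a single node, each bound to one GPU, as in the tests above.
    resource_pool = RayResourcePool([2], use_gpu=True)
    class_with_args = RayClassWithInitArgs(cls=RankReporter)
    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="rank_reporter")
    # execute_all_sync calls the named method on every worker and collects one result per rank.
    print(wg.execute_all_sync("report"))
    ray.shutdown()
```

If the calls behave as in the tests above, the printed result is a two-element list such as `[(0, 2), (1, 2)]`, one entry per worker.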
import os -from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, CPUOffload -from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType +import time + import torch - -from verl.utils.distributed import initialize_global_process_group -from verl.third_party.vllm import LLM - +import torch.distributed as dist +from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from vllm import SamplingParams -import time -import torch.distributed as dist +from verl.third_party.vllm import LLM +from verl.utils.distributed import initialize_global_process_group def main(): - assert torch.cuda.is_available(), 'CUDA must be present to run FSDP vLLM example' + assert torch.cuda.is_available(), "CUDA must be present to run FSDP vLLM example" local_rank, rank, world_size = initialize_global_process_group() - local_cache_path = '~/.cache/verl/rlhf' + local_cache_path = "~/.cache/verl/rlhf" local_cache_path = os.path.expanduser(local_cache_path) - hdfs_path = 'Qwen/Qwen2-7B-Instruct' + hdfs_path = "Qwen/Qwen2-7B-Instruct" from verl.utils.fs import copy_to_local + local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path) tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True) actor_model_config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=True) @@ -51,14 +52,16 @@ def main(): "The future of AI is", ] tokenizer.pad_token = tokenizer.eos_token - prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True) - input_ids = prompts['input_ids'] - attention_mask = prompts['attention_mask'] + prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True) + input_ids = prompts["input_ids"] + attention_mask = prompts["attention_mask"] from verl.utils.torch_functional import pad_sequence_to_length + input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True).cuda() attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True).cuda() from transformers import GenerationConfig + generation_config = GenerationConfig(do_sample=False) actor_model.cuda() output = actor_model.generate( @@ -72,61 +75,63 @@ def main(): # renormalize_logits=True, output_scores=False, # this is potentially very large return_dict_in_generate=True, - use_cache=False) # may OOM when use_cache = True + use_cache=False, + ) # may OOM when use_cache = True seq = output.sequences response = seq[:, max_prompt_length:] - print(f'hf response: {tokenizer.batch_decode(response)}') + print(f"hf response: {tokenizer.batch_decode(response)}") tensor_model_parallel_size = 4 from torch.distributed.device_mesh import init_device_mesh - device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp']) + + device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) - fsdp_model = FSDP(actor_model, - use_orig_params=True, - auto_wrap_policy=None, - device_id=torch.cuda.current_device(), - 
sharding_strategy=ShardingStrategy.FULL_SHARD, - mixed_precision=mixed_precision, - cpu_offload=CPUOffload(offload_params=False), - sync_module_states=False, - device_mesh=device_mesh) + fsdp_model = FSDP( + actor_model, + use_orig_params=True, + auto_wrap_policy=None, + device_id=torch.cuda.current_device(), + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mixed_precision, + cpu_offload=CPUOffload(offload_params=False), + sync_module_states=False, + device_mesh=device_mesh, + ) - FSDP.set_state_dict_type(fsdp_model, - state_dict_type=StateDictType.SHARDED_STATE_DICT, - state_dict_config=ShardedStateDictConfig()) + FSDP.set_state_dict_type( + fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig() + ) state_dict = fsdp_model.state_dict() - sampling_params = SamplingParams(temperature=0, - top_p=1, - n=1, - max_tokens=response_length, - logprobs=1, - ignore_eos=True, - detokenize=False) + sampling_params = SamplingParams( + temperature=0, top_p=1, n=1, max_tokens=response_length, logprobs=1, ignore_eos=True, detokenize=False + ) print(actor_model_config) - llm = LLM(model=None, - tokenizer=tokenizer, - model_hf_config=actor_model_config, - tensor_parallel_size=tensor_model_parallel_size, - enforce_eager=True, - dtype='bfloat16', - load_format='dummy_dtensor', - gpu_memory_utilization=0.8, - trust_remote_code=True) + llm = LLM( + model=None, + tokenizer=tokenizer, + model_hf_config=actor_model_config, + tensor_parallel_size=tensor_model_parallel_size, + enforce_eager=True, + dtype="bfloat16", + load_format="dummy_dtensor", + gpu_memory_utilization=0.8, + trust_remote_code=True, + ) # Warmup iterations for _ in range(10): torch.cuda.synchronize() - llm.sync_model_weights(actor_weights=state_dict, load_format='dtensor') + llm.sync_model_weights(actor_weights=state_dict, load_format="dtensor") torch.cuda.synchronize() dist.barrier() start_time = time.time() - llm.sync_model_weights(actor_weights=state_dict, load_format='dtensor') + llm.sync_model_weights(actor_weights=state_dict, load_format="dtensor") torch.cuda.synchronize() dist.barrier() end_time = time.time() @@ -142,14 +147,15 @@ def main(): pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id from verl.workers.rollout.vllm_rollout.vllm_rollout import _pre_process_inputs + for i in range(batch_size): idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i])) - print('start generation') + print("start generation") outputs = llm.generate(prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False) vllm_output = outputs[0].cuda() if torch.distributed.get_rank() == 0: - print(f'hf response: {tokenizer.batch_decode(response)}') - print(f'vllm response: {tokenizer.batch_decode(vllm_output)}') + print(f"hf response: {tokenizer.batch_decode(response)}") + print(f"vllm response: {tokenizer.batch_decode(vllm_output)}") if __name__ == "__main__": diff --git a/tests/rollout/test_sglang_spmd.py b/tests/rollout/test_sglang_spmd.py index da49459c8..fdb26d25b 100644 --- a/tests/rollout/test_sglang_spmd.py +++ b/tests/rollout/test_sglang_spmd.py @@ -26,13 +26,11 @@ # limitations under the License. 
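The rollout scripts above and below left-pad every prompt to `max_prompt_length` before generation, so responses can later be sliced off at a fixed offset (`seq[:, max_prompt_length:]`). A self-contained sketch of that padding step; `left_pad_to_length` is an illustrative stand-in written for this sketch, not verl's `pad_sequence_to_length`, and the token ids are made up.

```python
import torch


def left_pad_to_length(x: torch.Tensor, max_len: int, pad_value: int) -> torch.Tensor:
    """Left-pad a (batch, seq) tensor with pad_value up to max_len (illustrative stand-in)."""
    pad = max_len - x.shape[1]
    assert pad >= 0, "sequence already longer than max_len"
    return torch.nn.functional.pad(x, (pad, 0), value=pad_value)


pad_token_id = 0  # assumed pad id for the sketch
input_ids = torch.tensor([[11, 12, 13], [21, 22, 23]])
attention_mask = torch.ones_like(input_ids)

max_prompt_length = 6
input_ids = left_pad_to_length(input_ids, max_prompt_length, pad_token_id)
attention_mask = left_pad_to_length(attention_mask, max_prompt_length, 0)

print(input_ids)
# tensor([[ 0,  0,  0, 11, 12, 13],
#         [ 0,  0,  0, 21, 22, 23]])
# After generation, the response is everything past the fixed prompt window:
# response = sequences[:, max_prompt_length:]
```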
import os + import torch -from torch.distributed.device_mesh import init_device_mesh - from sglang.srt.entrypoints.verl_engine import VerlEngine - -from transformers import AutoTokenizer, AutoModelForCausalLM -from transformers import GenerationConfig +from torch.distributed.device_mesh import init_device_mesh +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from verl.utils.torch_functional import pad_sequence_to_length @@ -53,7 +51,7 @@ def levenshtein(s1, s2): dp[i][j] = min( dp[i - 1][j] + 1, # Deletion dp[i][j - 1] + 1, # Insertion - dp[i - 1][j - 1] + cost # Substitution + dp[i - 1][j - 1] + cost, # Substitution ) return dp[m][n] @@ -98,19 +96,20 @@ def initialize_global_process_group(timeout_second=36000): def test_sglang_spmd(): - assert torch.cuda.device_count() >= 2, 'At least 2 GPUs is required to run tp+dp tests.' + assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests." initialize_global_process_group() # fill rollout config max_prompt_length = 16 max_response_length = 16 # Initialize model and token - local_cache_path = '~/.cache/verl/rlhf' + local_cache_path = "~/.cache/verl/rlhf" local_cache_path = os.path.expanduser(local_cache_path) - hdfs_path = 'Qwen/Qwen2-7B-Instruct' + hdfs_path = "Qwen/Qwen2-7B-Instruct" from verl.utils.fs import copy_to_local + local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path) - tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side='left') + tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left") preencode_prompts = [ "Who won the Champions League in 2019?", @@ -118,9 +117,9 @@ def test_sglang_spmd(): "What's your name", ] tokenizer.pad_token = tokenizer.eos_token - prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True) - input_ids = prompts['input_ids'] - attention_mask = prompts['attention_mask'] + prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True) + input_ids = prompts["input_ids"] + attention_mask = prompts["attention_mask"] input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True) attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True) @@ -128,17 +127,19 @@ def test_sglang_spmd(): actor_model = AutoModelForCausalLM.from_pretrained(local_model_path) actor_model.to(torch.bfloat16) - sampling_params = dict(n=1, - temperature=0, - top_p=1, - top_k=-1, - max_new_tokens=max_response_length, - presence_penalty=0.0, - frequency_penalty=0.0, - repetition_penalty=1.0, - skip_special_tokens=True, - spaces_between_special_tokens=True, - ignore_eos=False) + sampling_params = dict( + n=1, + temperature=0, + top_p=1, + top_k=-1, + max_new_tokens=max_response_length, + presence_penalty=0.0, + frequency_penalty=0.0, + repetition_penalty=1.0, + skip_special_tokens=True, + spaces_between_special_tokens=True, + ignore_eos=False, + ) tensor_parallel_size = 4 device_mesh_kwargs = dict(mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"]) @@ -147,13 +148,15 @@ def test_sglang_spmd(): for k in ["TORCHELASTIC_USE_AGENT_STORE"]: if k in os.environ: del os.environ[k] - print('building sglang rollout engine') - llm = VerlEngine(model_path=local_model_path, - dtype="bfloat16", - mem_fraction_static=0.5, - device_mesh_cpu=inference_device_mesh_cpu["tp"], - base_gpu_id=0, - gpu_id_step=1) + print("building sglang rollout engine") + llm = VerlEngine( + model_path=local_model_path, + dtype="bfloat16", 
+ mem_fraction_static=0.5, + device_mesh_cpu=inference_device_mesh_cpu["tp"], + base_gpu_id=0, + gpu_id_step=1, + ) llm.release_memory_occupation() print("start generation") @@ -174,7 +177,8 @@ def test_sglang_spmd(): # renormalize_logits=True, output_scores=False, # this is potentially very large return_dict_in_generate=True, - use_cache=False) # may OOM when use_cache = True + use_cache=False, + ) # may OOM when use_cache = True seq = output.sequences response = seq[:, max_prompt_length:] @@ -184,7 +188,7 @@ def test_sglang_spmd(): idx_list = [] batch_size = input_ids.shape[0] - pad_token_id = (tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id) + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id for i in range(batch_size): idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i])) @@ -197,8 +201,7 @@ def test_sglang_spmd(): sglang_response_tokens.append(generated_text) print(f"sglang response: {sglang_response_tokens}") - assert are_lists_similar(hf_response_tokens, sglang_response_tokens), \ - f"Strings differ more than 10%:\n" + assert are_lists_similar(hf_response_tokens, sglang_response_tokens), "Strings differ more than 10%:\n" print("Check Pass") diff --git a/tests/rollout/test_vllm_hf_loader.py b/tests/rollout/test_vllm_hf_loader.py index bf966d4ee..e8df7eb3b 100644 --- a/tests/rollout/test_vllm_hf_loader.py +++ b/tests/rollout/test_vllm_hf_loader.py @@ -13,16 +13,12 @@ # limitations under the License. import os + import torch -import transformers +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from vllm import SamplingParams from verl.third_party.vllm import LLM, vllm_version -from verl.utils.model import update_model_config -from vllm import SamplingParams -from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM - -from transformers import GenerationConfig - from verl.utils.torch_functional import pad_sequence_to_length from verl.workers.rollout.vllm_rollout.vllm_rollout import _pre_process_inputs @@ -43,7 +39,7 @@ def levenshtein(s1, s2): dp[i][j] = min( dp[i - 1][j] + 1, # Deletion dp[i][j - 1] + 1, # Insertion - dp[i - 1][j - 1] + cost # Substitution + dp[i - 1][j - 1] + cost, # Substitution ) return dp[m][n] @@ -70,17 +66,18 @@ def are_lists_similar(a, b): def test_vllm_with_hf(): - assert torch.cuda.device_count() >= 2, 'At least 2 GPUs is required to run tp+dp tests.' + assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests." 
# fill rollout config max_prompt_length = 16 max_response_length = 16 # Initialize model and token - local_cache_path = '~/.cache/verl/rlhf' + local_cache_path = "~/.cache/verl/rlhf" local_cache_path = os.path.expanduser(local_cache_path) - hdfs_path = 'deepseek-ai/deepseek-llm-7b-chat' + hdfs_path = "deepseek-ai/deepseek-llm-7b-chat" from verl.utils.fs import copy_to_local + local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path) tokenizer = AutoTokenizer.from_pretrained(local_model_path) @@ -90,9 +87,9 @@ def test_vllm_with_hf(): "What's your name", ] tokenizer.pad_token = tokenizer.eos_token - prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True) - input_ids = prompts['input_ids'] - attention_mask = prompts['attention_mask'] + prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True) + input_ids = prompts["input_ids"] + attention_mask = prompts["attention_mask"] input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True) attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True) @@ -105,28 +102,27 @@ def test_vllm_with_hf(): temperature = 0 top_p = 1 - kwargs = dict(n=1, - temperature=temperature, - top_p=top_p, - max_tokens=max_response_length, - logprobs=1, - ignore_eos=True) + kwargs = dict( + n=1, temperature=temperature, top_p=top_p, max_tokens=max_response_length, logprobs=1, ignore_eos=True + ) - if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): - kwargs['detokenize'] = False + if vllm_version in ("0.4.2", "0.5.4", "0.6.3"): + kwargs["detokenize"] = False sampling_params = SamplingParams(**kwargs) tensor_parallel_size = 4 - llm = LLM(model=actor_model, - tokenizer=tokenizer, - model_hf_config=actor_model_config, - tensor_parallel_size=tensor_parallel_size, - dtype='bfloat16', - gpu_memory_utilization=0.1, - load_format='hf') + llm = LLM( + model=actor_model, + tokenizer=tokenizer, + model_hf_config=actor_model_config, + tensor_parallel_size=tensor_parallel_size, + dtype="bfloat16", + gpu_memory_utilization=0.1, + load_format="hf", + ) - print('start generation') + print("start generation") input_ids = input_ids.cuda() attention_mask = attention_mask.cuda() batch_size = input_ids.size(0) @@ -140,6 +136,7 @@ def test_vllm_with_hf(): llm.free_cache_engine() llm = None import gc + torch.cuda.empty_cache() gc.collect() @@ -156,18 +153,18 @@ def test_vllm_with_hf(): # renormalize_logits=True, output_scores=False, # this is potentially very large return_dict_in_generate=True, - use_cache=False) # may OOM when use_cache = True + use_cache=False, + ) # may OOM when use_cache = True seq = output.sequences response = seq[:, max_prompt_length:] hf_response_tokens = tokenizer.batch_decode(response) vllm_response_tokens = tokenizer.batch_decode(vllm_output) - print(f'hf response: {hf_response_tokens}') - print(f'vllm response: {vllm_response_tokens}') - assert are_lists_similar(hf_response_tokens, vllm_response_tokens), \ - f'Strings differ more than 10%:\n' - print('Check Pass') + print(f"hf response: {hf_response_tokens}") + print(f"vllm response: {vllm_response_tokens}") + assert are_lists_similar(hf_response_tokens, vllm_response_tokens), "Strings differ more than 10%:\n" + print("Check Pass") # if __name__ == "__main__": diff --git a/tests/rollout/test_vllm_spmd.py b/tests/rollout/test_vllm_spmd.py index b69ed8887..883cbe899 100644 --- a/tests/rollout/test_vllm_spmd.py +++ b/tests/rollout/test_vllm_spmd.py @@ -13,16 +13,14 @@ # limitations under the License. 
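Each rollout test compares the HuggingFace generations against the inference engine's generations with a Levenshtein-distance criterion rather than exact string equality. The `levenshtein` helper below follows the fragment shown in these tests; `are_lists_similar` is a sketch of the "differ more than 10%" check implied by the assertion message, since its body is not part of these hunks, and its threshold and normalization are assumptions.

```python
def levenshtein(s1, s2):
    # Classic dynamic-programming edit distance, matching the fragment in the tests.
    m, n = len(s1), len(s2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if s1[i - 1] == s2[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,  # Deletion
                dp[i][j - 1] + 1,  # Insertion
                dp[i - 1][j - 1] + cost,  # Substitution
            )
    return dp[m][n]


def are_lists_similar(a, b, threshold=0.10):
    # Assumed criterion: average normalized edit distance must stay within the threshold.
    distances = [levenshtein(x, y) / max(len(x), len(y), 1) for x, y in zip(a, b)]
    return sum(distances) / len(distances) <= threshold


print(are_lists_similar(["The capital of France is Paris."], ["The capital of France is Paris!"]))  # True
print(are_lists_similar(["The capital of France is Paris."], ["I have no idea."]))  # False
```

In the tests this check runs over the decoded HF responses and the decoded vLLM or SGLang responses, so small sampling differences pass while genuinely divergent generations fail the assertion.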
import os + import torch -import transformers -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, CPUOffload -from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType - +from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType +from transformers import AutoModelForCausalLM, AutoTokenizer from vllm import LLM, SamplingParams -from verl.utils.model import update_model_config -from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM -from transformers import GenerationConfig from verl.utils.distributed import initialize_global_process_group from verl.utils.torch_functional import pad_sequence_to_length @@ -43,7 +41,7 @@ def levenshtein(s1, s2): dp[i][j] = min( dp[i - 1][j] + 1, # Deletion dp[i][j - 1] + 1, # Insertion - dp[i - 1][j - 1] + cost # Substitution + dp[i - 1][j - 1] + cost, # Substitution ) return dp[m][n] @@ -70,16 +68,17 @@ def are_lists_similar(a, b): def test_vllm_spmd(): - assert torch.cuda.device_count() >= 2, 'At least 2 GPUs is required to run tp+dp tests.' + assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests." local_rank, rank, world_size = initialize_global_process_group() # Initialize model and token - local_cache_path = '~/.cache/verl/rlhf' + local_cache_path = "~/.cache/verl/rlhf" local_cache_path = os.path.expanduser(local_cache_path) - hdfs_path = 'Qwen/Qwen2-7B-Instruct' + hdfs_path = "Qwen/Qwen2-7B-Instruct" from verl.utils.fs import copy_to_local + local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path) - tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side='left', trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left", trust_remote_code=True) actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True) actor_model.to(torch.bfloat16) @@ -93,46 +92,46 @@ def test_vllm_spmd(): "What's your name", ] tokenizer.pad_token = tokenizer.eos_token - prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True) - input_ids = prompts['input_ids'] - attention_mask = prompts['attention_mask'] + prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True) + input_ids = prompts["input_ids"] + attention_mask = prompts["attention_mask"] input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True) attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True) - print('start generation') + print("start generation") input_ids = input_ids.cuda() attention_mask = attention_mask.cuda() temperature = 0 top_p = 1 - kwargs = dict(n=1, - temperature=temperature, - top_p=top_p, - max_tokens=max_response_length, - logprobs=1, - ignore_eos=True) + kwargs = dict( + n=1, temperature=temperature, top_p=top_p, max_tokens=max_response_length, logprobs=1, ignore_eos=True + ) tensor_parallel_size = 4 from torch.distributed.device_mesh import init_device_mesh - device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp']) + + device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) - 
fsdp_model = FSDP(actor_model, - use_orig_params=True, - auto_wrap_policy=None, - device_id=torch.cuda.current_device(), - sharding_strategy=ShardingStrategy.FULL_SHARD, - mixed_precision=mixed_precision, - cpu_offload=CPUOffload(offload_params=False), - sync_module_states=False, - device_mesh=device_mesh) + fsdp_model = FSDP( + actor_model, + use_orig_params=True, + auto_wrap_policy=None, + device_id=torch.cuda.current_device(), + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mixed_precision, + cpu_offload=CPUOffload(offload_params=False), + sync_module_states=False, + device_mesh=device_mesh, + ) - FSDP.set_state_dict_type(fsdp_model, - state_dict_type=StateDictType.SHARDED_STATE_DICT, - state_dict_config=ShardedStateDictConfig()) + FSDP.set_state_dict_type( + fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig() + ) state_dict = fsdp_model.state_dict() @@ -142,7 +141,7 @@ def test_vllm_spmd(): enable_sleep_mode=True, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend="external_launcher", - dtype='bfloat16', + dtype="bfloat16", enforce_eager=True, gpu_memory_utilization=0.8, disable_custom_all_reduce=True, @@ -162,7 +161,8 @@ def test_vllm_spmd(): world_size = torch.distributed.get_world_size() model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model model.load_weights( - ((name, param.full_tensor() if world_size != 1 else param) for name, param in state_dict.items())) + ((name, param.full_tensor() if world_size != 1 else param) for name, param in state_dict.items()) + ) outputs = llm.generate(preencode_prompts, sampling_params=sampling_params, use_tqdm=False) verl_vllm_response_tokens = [] @@ -171,11 +171,10 @@ def test_vllm_spmd(): verl_vllm_response_tokens.append(generated_text) if torch.distributed.get_rank() == 0: - print(f'vllm response: {vllm_response_tokens}') - print(f'verl-vllm response: {verl_vllm_response_tokens}') - assert are_lists_similar(vllm_response_tokens, verl_vllm_response_tokens), \ - f'Strings differ more than 10%:\n' - print('Check Pass') + print(f"vllm response: {vllm_response_tokens}") + print(f"verl-vllm response: {verl_vllm_response_tokens}") + assert are_lists_similar(vllm_response_tokens, verl_vllm_response_tokens), "Strings differ more than 10%:\n" + print("Check Pass") torch.distributed.destroy_process_group() diff --git a/tests/sandbox/test_sandbox.py b/tests/sandbox/test_sandbox.py index d60474e4a..744135c2e 100644 --- a/tests/sandbox/test_sandbox.py +++ b/tests/sandbox/test_sandbox.py @@ -12,21 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import json from verl.utils.reward_score import _default_compute_score from verl.utils.reward_score.prime_code import apps_check_correctness -import asyncio from verl.workers.reward_manager.prime import parallel_compute_score_async prime_math_answers = [ """\\begin{bmatrix}\n -7 & 6 & -8 \\\\\n 11 & -9 & 12 \\\\\n 15 & -16 & 19 \n \\end{bmatrix}""", - """\\frac{\\sqrt{505}}{7}""", """x^2 + y^2 + 4x - 6y + 13""" + """\\frac{\\sqrt{505}}{7}""", + """x^2 + y^2 + 4x - 6y + 13""", ] prime_math_gts = [ """\\begin{pmatrix}\n -7 & 6 & -8 \\\\\n 11 & -9 & 12 \\\\\n 15 & -16 & 19\n \\end{pmatrix}""", # mat test """\\frac{\\sqrt{505}}{7}""", # frac test - """(x + 2)^2 + (y - 3)^2 """ # symbolic test + """(x + 2)^2 + (y - 3)^2 """, # symbolic test ] prime_code_answers = [ @@ -83,7 +84,7 @@ if __name__ == '__main__': ] * 2 prime_code_gts = [ """{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"2\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}""", # A correct sample - """{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 
13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}""" + """{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}""", ] # A failed sample with first several in-out passed prime_code_scores = [1.0, 0.9] @@ -99,18 +100,17 @@ def test_parallelism(): while len(sequences_str) < 32: sequences_str.extend(prime_code_answers) ground_truth.extend(prime_code_gts) - data_sources.extend(['codecontests'] * len(prime_code_answers)) + data_sources.extend(["codecontests"] * len(prime_code_answers)) sequences_str.extend(prime_math_answers) ground_truth.extend(prime_math_gts) - data_sources.extend(['numina_aops_forum'] * len(prime_math_answers)) + data_sources.extend(["numina_aops_forum"] * len(prime_math_answers)) scores = asyncio.run( - parallel_compute_score_async(_default_compute_score, - sequences_str, - ground_truth, - data_sources, - num_processes=16)) + parallel_compute_score_async( + _default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16 + ) + ) print(scores) @@ -118,7 +118,7 @@ def test_prime_code(): """ Test PRIME code sandbox. 
""" - data_source = 'codecontests' + data_source = "codecontests" for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores): score = _default_compute_score(data_source, completion, ground_truth) assert float(score) == score_ @@ -127,13 +127,13 @@ def test_prime_code(): def test_check_correctness(): completion = prime_code_answers[0] ground_truth = json.loads(prime_code_gts[0]) - ground_truth_single = {'inputs': ground_truth['inputs'][:1], 'outputs': ground_truth['outputs'][:1]} + ground_truth_single = {"inputs": ground_truth["inputs"][:1], "outputs": ground_truth["outputs"][:1]} res, meta = apps_check_correctness(in_outs=ground_truth_single, generation=completion, timeout=5, debug=False) print(res, meta) def test_prime_math(): - data_source = 'numina_aops_forum' + data_source = "numina_aops_forum" for completion, ground_truth in zip(prime_math_answers, prime_math_gts): score = _default_compute_score(data_source, completion, ground_truth) assert float(score) == 1.0 diff --git a/tests/sanity/check_license.py b/tests/sanity/check_license.py index 99dc1d048..c4a00610a 100644 --- a/tests/sanity/check_license.py +++ b/tests/sanity/check_license.py @@ -19,21 +19,21 @@ license_head_prime = "Copyright 2024 PRIME team and/or its affiliates" license_head_individual = "Copyright 2025 Individual Contributor:" license_headers = [license_head_bytedance, license_head_bytedance_25, license_head_prime, license_head_individual] -from pathlib import Path from argparse import ArgumentParser +from pathlib import Path -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument('--directory', '-d', required=True, type=str) + parser.add_argument("--directory", "-d", required=True, type=str) args = parser.parse_args() directory_in_str = args.directory - pathlist = Path(directory_in_str).glob('**/*.py') + pathlist = Path(directory_in_str).glob("**/*.py") for path in pathlist: # because path is object not string path_in_str = str(path.absolute()) print(path_in_str) - with open(path_in_str, 'r', encoding='utf-8') as f: + with open(path_in_str, encoding="utf-8") as f: file_content = f.read() has_license = False @@ -41,4 +41,4 @@ if __name__ == '__main__': if lh in file_content: has_license = True break - assert has_license, f'file {path_in_str} does not contain license' + assert has_license, f"file {path_in_str} does not contain license" diff --git a/tests/sanity/test_import.py b/tests/sanity/test_import.py index 2adf63a15..4f8a918fe 100644 --- a/tests/sanity/test_import.py +++ b/tests/sanity/test_import.py @@ -15,9 +15,11 @@ def test_import(): import verl + print(verl.__version__) def test_single_controller_import(): import verl.single_controller + print(verl.single_controller.__version__) diff --git a/tests/utility/test_tensor_dict_utilities.py b/tests/utility/test_tensor_dict_utilities.py index dfd033c00..b67bb1ae7 100644 --- a/tests/utility/test_tensor_dict_utilities.py +++ b/tests/utility/test_tensor_dict_utilities.py @@ -13,41 +13,38 @@ # limitations under the License. 
import random + +import numpy as np import pytest import torch from tensordict import TensorDict -from verl.protocol import union_tensor_dict, union_numpy_dict - from verl import DataProto -import numpy as np +from verl.protocol import union_numpy_dict, union_tensor_dict def test_union_tensor_dict(): obs = torch.randn(100, 10) - data1 = TensorDict({'obs': obs, 'act': torch.randn(100, 3)}, batch_size=[100]) - data2 = TensorDict({'obs': obs, 'next_obs': torch.randn(100, 10), 'rew': torch.randn(100)}, batch_size=[100]) + data1 = TensorDict({"obs": obs, "act": torch.randn(100, 3)}, batch_size=[100]) + data2 = TensorDict({"obs": obs, "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100]) - data_with_copied_obs = TensorDict({ - 'obs': obs.clone(), - 'next_obs': torch.randn(100, 10), - 'rew': torch.randn(100) - }, - batch_size=[100]) + data_with_copied_obs = TensorDict( + {"obs": obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100] + ) data = union_tensor_dict(data1, data2) with pytest.raises(AssertionError): data = union_tensor_dict(data1, data_with_copied_obs) data = np.random.random(100) - data2 = [float('nan') for _ in range(99)] - data2.append('nan') + data2 = [float("nan") for _ in range(99)] + data2.append("nan") data2 = np.array(data2, dtype=object) data3 = np.tile(data2, (2, 1)) - a = {'a': data, 'b': data2, 'c': data3} - b = {'a': data, 'b': data2, 'c': data3} - b_ = {'a': np.random.random(100)} + a = {"a": data, "b": data2, "c": data3} + b = {"a": data, "b": data2, "c": data3} + b_ = {"a": np.random.random(100)} union_numpy_dict(a, b) with pytest.raises(AssertionError): union_numpy_dict(a, b_) @@ -56,21 +53,21 @@ def test_union_tensor_dict(): def test_tensor_dict_constructor(): obs = torch.randn(100, 10) act = torch.randn(100, 10, 3) - data = DataProto.from_dict(tensors={'obs': obs, 'act': act}) + data = DataProto.from_dict(tensors={"obs": obs, "act": act}) assert data.batch.batch_size == torch.Size([100]) with pytest.raises(AssertionError): - data = DataProto.from_dict(tensors={'obs': obs, 'act': act}, num_batch_dims=2) + data = DataProto.from_dict(tensors={"obs": obs, "act": act}, num_batch_dims=2) with pytest.raises(AssertionError): - data = DataProto.from_dict(tensors={'obs': obs, 'act': act}, num_batch_dims=3) + data = DataProto.from_dict(tensors={"obs": obs, "act": act}, num_batch_dims=3) def test_tensor_dict_make_iterator(): obs = torch.randn(100, 10) - labels = [random.choice(['abc', 'cde']) for _ in range(100)] - dataset = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}) + labels = [random.choice(["abc", "cde"]) for _ in range(100)] + dataset = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}) data_iter_1 = dataset.make_iterator(mini_batch_size=10, epochs=2, seed=1) data_list_1 = [] @@ -85,94 +82,94 @@ def test_tensor_dict_make_iterator(): for data1, data2 in zip(data_list_1, data_list_2): assert isinstance(data1, DataProto) assert isinstance(data2, DataProto) - result = torch.all(torch.eq(data1.batch['obs'], data2.batch['obs'])) + result = torch.all(torch.eq(data1.batch["obs"], data2.batch["obs"])) if not result.item(): - print(data1.batch['obs']) - print(data2.batch['obs']) + print(data1.batch["obs"]) + print(data2.batch["obs"]) assert False - non_tensor_result = np.all(np.equal(data1.non_tensor_batch['labels'], data2.non_tensor_batch['labels'])) + non_tensor_result = np.all(np.equal(data1.non_tensor_batch["labels"], data2.non_tensor_batch["labels"])) if not 
non_tensor_result.item(): - print(data1.non_tensor_batch['labels']) - print(data2.non_tensor_batch['labels']) + print(data1.non_tensor_batch["labels"]) + print(data2.non_tensor_batch["labels"]) def test_reorder(): obs = torch.tensor([1, 2, 3, 4, 5, 6]) - labels = ['a', 'b', 'c', 'd', 'e', 'f'] - data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'name': 'abdce'}) + labels = ["a", "b", "c", "d", "e", "f"] + data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"}) data.reorder(torch.tensor([3, 4, 2, 0, 1, 5])) - assert torch.all(torch.eq(data.batch['obs'], torch.tensor([4, 5, 3, 1, 2, 6]))) - assert np.all(data.non_tensor_batch['labels'] == np.array(['d', 'e', 'c', 'a', 'b', 'f'])) - assert data.meta_info == {'name': 'abdce'} + assert torch.all(torch.eq(data.batch["obs"], torch.tensor([4, 5, 3, 1, 2, 6]))) + assert np.all(data.non_tensor_batch["labels"] == np.array(["d", "e", "c", "a", "b", "f"])) + assert data.meta_info == {"name": "abdce"} def test_chunk_concat(): obs = torch.tensor([1, 2, 3, 4, 5, 6]) - labels = ['a', 'b', 'c', 'd', 'e', 'f'] - data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'name': 'abdce'}) + labels = ["a", "b", "c", "d", "e", "f"] + data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"}) with pytest.raises(AssertionError): data.chunk(5) data_split = data.chunk(2) assert len(data_split) == 2 - assert torch.all(torch.eq(data_split[0].batch['obs'], torch.tensor([1, 2, 3]))) - assert np.all(data_split[0].non_tensor_batch['labels'] == np.array(['a', 'b', 'c'])) - assert data_split[0].meta_info == {'name': 'abdce'} + assert torch.all(torch.eq(data_split[0].batch["obs"], torch.tensor([1, 2, 3]))) + assert np.all(data_split[0].non_tensor_batch["labels"] == np.array(["a", "b", "c"])) + assert data_split[0].meta_info == {"name": "abdce"} - assert torch.all(torch.eq(data_split[1].batch['obs'], torch.tensor([4, 5, 6]))) - assert np.all(data_split[1].non_tensor_batch['labels'] == np.array(['d', 'e', 'f'])) - assert data_split[1].meta_info == {'name': 'abdce'} + assert torch.all(torch.eq(data_split[1].batch["obs"], torch.tensor([4, 5, 6]))) + assert np.all(data_split[1].non_tensor_batch["labels"] == np.array(["d", "e", "f"])) + assert data_split[1].meta_info == {"name": "abdce"} concat_data = DataProto.concat(data_split) - assert torch.all(torch.eq(concat_data.batch['obs'], data.batch['obs'])) - assert np.all(concat_data.non_tensor_batch['labels'] == data.non_tensor_batch['labels']) + assert torch.all(torch.eq(concat_data.batch["obs"], data.batch["obs"])) + assert np.all(concat_data.non_tensor_batch["labels"] == data.non_tensor_batch["labels"]) assert concat_data.meta_info == data.meta_info def test_pop(): obs = torch.randn(100, 10) act = torch.randn(100, 3) - dataset = DataProto.from_dict({'obs': obs, 'act': act}, meta_info={'2': 2, '1': 1}) - poped_dataset = dataset.pop(batch_keys=['obs'], meta_info_keys=['2']) + dataset = DataProto.from_dict({"obs": obs, "act": act}, meta_info={"2": 2, "1": 1}) + poped_dataset = dataset.pop(batch_keys=["obs"], meta_info_keys=["2"]) - assert poped_dataset.batch.keys() == {'obs'} - assert poped_dataset.meta_info.keys() == {'2'} + assert poped_dataset.batch.keys() == {"obs"} + assert poped_dataset.meta_info.keys() == {"2"} - assert dataset.batch.keys() == {'act'} - assert dataset.meta_info.keys() == {'1'} + assert dataset.batch.keys() == {"act"} + assert 
dataset.meta_info.keys() == {"1"} def test_repeat(): # Create a DataProto object with some batch and non-tensor data obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) - labels = ['a', 'b', 'c'] - data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'}) + labels = ["a", "b", "c"] + data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) # Test interleave=True repeated_data_interleave = data.repeat(repeat_times=2, interleave=True) expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]]) - expected_labels_interleave = ['a', 'a', 'b', 'b', 'c', 'c'] + expected_labels_interleave = ["a", "a", "b", "b", "c", "c"] - assert torch.all(torch.eq(repeated_data_interleave.batch['obs'], expected_obs_interleave)) - assert (repeated_data_interleave.non_tensor_batch['labels'] == expected_labels_interleave).all() - assert repeated_data_interleave.meta_info == {'info': 'test_info'} + assert torch.all(torch.eq(repeated_data_interleave.batch["obs"], expected_obs_interleave)) + assert (repeated_data_interleave.non_tensor_batch["labels"] == expected_labels_interleave).all() + assert repeated_data_interleave.meta_info == {"info": "test_info"} # Test interleave=False repeated_data_no_interleave = data.repeat(repeat_times=2, interleave=False) expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]]) - expected_labels_no_interleave = ['a', 'b', 'c', 'a', 'b', 'c'] + expected_labels_no_interleave = ["a", "b", "c", "a", "b", "c"] - assert torch.all(torch.eq(repeated_data_no_interleave.batch['obs'], expected_obs_no_interleave)) - assert (repeated_data_no_interleave.non_tensor_batch['labels'] == expected_labels_no_interleave).all() - assert repeated_data_no_interleave.meta_info == {'info': 'test_info'} + assert torch.all(torch.eq(repeated_data_no_interleave.batch["obs"], expected_obs_no_interleave)) + assert (repeated_data_no_interleave.non_tensor_batch["labels"] == expected_labels_no_interleave).all() + assert repeated_data_no_interleave.meta_info == {"info": "test_info"} def test_dataproto_pad_unpad(): obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) - labels = ['a', 'b', 'c'] - data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'}) + labels = ["a", "b", "c"] + data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto @@ -180,115 +177,116 @@ def test_dataproto_pad_unpad(): assert pad_size == 1 expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2]]) - expected_labels = ['a', 'b', 'c', 'a'] + expected_labels = ["a", "b", "c", "a"] - assert torch.all(torch.eq(padded_data.batch['obs'], expected_obs)) - assert (padded_data.non_tensor_batch['labels'] == expected_labels).all() - assert padded_data.meta_info == {'info': 'test_info'} + assert torch.all(torch.eq(padded_data.batch["obs"], expected_obs)) + assert (padded_data.non_tensor_batch["labels"] == expected_labels).all() + assert padded_data.meta_info == {"info": "test_info"} unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size) - assert torch.all(torch.eq(unpadd_data.batch['obs'], obs)) - assert (unpadd_data.non_tensor_batch['labels'] == labels).all() - assert unpadd_data.meta_info == {'info': 'test_info'} + assert torch.all(torch.eq(unpadd_data.batch["obs"], obs)) + assert 
(unpadd_data.non_tensor_batch["labels"] == labels).all() + assert unpadd_data.meta_info == {"info": "test_info"} padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=3) assert pad_size == 0 expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) - expected_labels = ['a', 'b', 'c'] + expected_labels = ["a", "b", "c"] - assert torch.all(torch.eq(padded_data.batch['obs'], expected_obs)) - assert (padded_data.non_tensor_batch['labels'] == expected_labels).all() - assert padded_data.meta_info == {'info': 'test_info'} + assert torch.all(torch.eq(padded_data.batch["obs"], expected_obs)) + assert (padded_data.non_tensor_batch["labels"] == expected_labels).all() + assert padded_data.meta_info == {"info": "test_info"} unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size) - assert torch.all(torch.eq(unpadd_data.batch['obs'], obs)) - assert (unpadd_data.non_tensor_batch['labels'] == labels).all() - assert unpadd_data.meta_info == {'info': 'test_info'} + assert torch.all(torch.eq(unpadd_data.batch["obs"], obs)) + assert (unpadd_data.non_tensor_batch["labels"] == labels).all() + assert unpadd_data.meta_info == {"info": "test_info"} padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=7) assert pad_size == 4 expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]]) - expected_labels = ['a', 'b', 'c', 'a', 'b', 'c', 'a'] - assert torch.all(torch.eq(padded_data.batch['obs'], expected_obs)) - assert (padded_data.non_tensor_batch['labels'] == expected_labels).all() - assert padded_data.meta_info == {'info': 'test_info'} + expected_labels = ["a", "b", "c", "a", "b", "c", "a"] + assert torch.all(torch.eq(padded_data.batch["obs"], expected_obs)) + assert (padded_data.non_tensor_batch["labels"] == expected_labels).all() + assert padded_data.meta_info == {"info": "test_info"} unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size) - assert torch.all(torch.eq(unpadd_data.batch['obs'], obs)) - assert (unpadd_data.non_tensor_batch['labels'] == labels).all() - assert unpadd_data.meta_info == {'info': 'test_info'} + assert torch.all(torch.eq(unpadd_data.batch["obs"], obs)) + assert (unpadd_data.non_tensor_batch["labels"] == labels).all() + assert unpadd_data.meta_info == {"info": "test_info"} def test_dataproto_fold_unfold(): - from verl.protocol import fold_batch_dim, unfold_batch_dim, DataProto + from verl.protocol import DataProto, fold_batch_dim, unfold_batch_dim obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) - labels = ['a', 'b', 'c'] - data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'}) + labels = ["a", "b", "c"] + data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) data1 = data.repeat(repeat_times=2, interleave=True) data2 = fold_batch_dim(data1, new_batch_size=3) - torch.testing.assert_close(data2.batch['obs'], torch.tensor([[[1, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]])) - assert (data2.non_tensor_batch['labels'] == [['a', 'a'], ['b', 'b'], ['c', 'c']]).all() + torch.testing.assert_close(data2.batch["obs"], torch.tensor([[[1, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]])) + assert (data2.non_tensor_batch["labels"] == [["a", "a"], ["b", "b"], ["c", "c"]]).all() data2.reorder(indices=torch.tensor([1, 2, 0])) data3 = unfold_batch_dim(data2, batch_dims=2) - torch.testing.assert_close(data3.batch['obs'], torch.tensor([[3, 4], [3, 4], [5, 6], [5, 6], [1, 2], [1, 2]])) - assert (data3.non_tensor_batch['labels'] == 
['b', 'b', 'c', 'c', 'a', 'a']).all() - assert data3.meta_info == {'info': 'test_info'} + torch.testing.assert_close(data3.batch["obs"], torch.tensor([[3, 4], [3, 4], [5, 6], [5, 6], [1, 2], [1, 2]])) + assert (data3.non_tensor_batch["labels"] == ["b", "b", "c", "c", "a", "a"]).all() + assert data3.meta_info == {"info": "test_info"} def test_torch_save_data_proto(): - obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) - labels = ['a', 'b', 'c'] - data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'}) - data.save_to_disk('test_data.pt') - loaded_data = DataProto.load_from_disk('test_data.pt') + labels = ["a", "b", "c"] + data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) + data.save_to_disk("test_data.pt") + loaded_data = DataProto.load_from_disk("test_data.pt") - assert torch.all(torch.eq(loaded_data.batch['obs'], data.batch['obs'])) - assert (loaded_data.non_tensor_batch['labels'] == data.non_tensor_batch['labels']).all() + assert torch.all(torch.eq(loaded_data.batch["obs"], data.batch["obs"])) + assert (loaded_data.non_tensor_batch["labels"] == data.non_tensor_batch["labels"]).all() assert loaded_data.meta_info == data.meta_info import os - os.remove('test_data.pt') + + os.remove("test_data.pt") def test_len(): obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) - labels = np.array(['a', 'b', 'c'], dtype=object) - data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'}) + labels = np.array(["a", "b", "c"], dtype=object) + data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) assert len(data) == 3 - data = DataProto(batch=None, non_tensor_batch={'labels': labels}, meta_info={'info': 'test_info'}) + data = DataProto(batch=None, non_tensor_batch={"labels": labels}, meta_info={"info": "test_info"}) assert len(data) == 3 - data = DataProto(batch=None, non_tensor_batch={}, meta_info={'info': 'test_info'}) + data = DataProto(batch=None, non_tensor_batch={}, meta_info={"info": "test_info"}) assert len(data) == 0 - data = DataProto(batch=None, non_tensor_batch=None, meta_info={'info': 'test_info'}) + data = DataProto(batch=None, non_tensor_batch=None, meta_info={"info": "test_info"}) assert len(data) == 0 def test_seqlen_balancing(): - from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx + from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches + input_ids = torch.randint(low=0, high=10, size=(20, 100)) from verl.utils.model import create_random_mask - attention_mask = create_random_mask(input_ids=input_ids, - max_ratio_of_left_padding=0.1, - max_ratio_of_valid_token=0.9, - min_ratio_of_valid_token=0.5) - data = {'input_ids': input_ids, 'attention_mask': attention_mask} + + attention_mask = create_random_mask( + input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5 + ) + data = {"input_ids": input_ids, "attention_mask": attention_mask} dataproto = DataProto.from_single_dict(data) micro_batches, micro_bsz_idx_lst = rearrange_micro_batches(dataproto.batch, max_token_len=300) batch = torch.cat(micro_batches) @@ -298,4 +296,4 @@ def test_seqlen_balancing(): reverse_idx_map = get_reverse_idx(micro_bsz_idx) reverse_idx_map = torch.tensor(reverse_idx_map) new_batch = batch[reverse_idx_map] - torch.testing.assert_close(new_batch, dataproto.batch) \ No newline at end of file 
+ torch.testing.assert_close(new_batch, dataproto.batch) diff --git a/tests/verl/utils/dataset/test_multiturn_sft_dataset.py b/tests/verl/utils/dataset/test_multiturn_sft_dataset.py index 64594c633..8028d44e5 100644 --- a/tests/verl/utils/dataset/test_multiturn_sft_dataset.py +++ b/tests/verl/utils/dataset/test_multiturn_sft_dataset.py @@ -14,10 +14,13 @@ """ Test the MultiTurnSFTDataset implementation """ + import os + import pandas as pd import torch from transformers import AutoTokenizer + from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset @@ -25,51 +28,35 @@ def test_multiturn_sft_dataset(): print("Starting test...") # Create a temporary parquet file with test data test_data = { - 'messages': [[{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "What is 2+2?" - }, { - "role": "assistant", - "content": "2+2 equals 4." - }, { - "role": "user", - "content": "And what is 4+4?" - }, { - "role": "assistant", - "content": "4+4 equals 8." - }], - [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Tell me a joke." - }, { - "role": "assistant", - "content": "Why did the chicken cross the road?" - }, { - "role": "user", - "content": "Why?" - }, { - "role": "assistant", - "content": "To get to the other side!" - }]] + "messages": [ + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "2+2 equals 4."}, + {"role": "user", "content": "And what is 4+4?"}, + {"role": "assistant", "content": "4+4 equals 8."}, + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me a joke."}, + {"role": "assistant", "content": "Why did the chicken cross the road?"}, + {"role": "user", "content": "Why?"}, + {"role": "assistant", "content": "To get to the other side!"}, + ], + ] } # Create test directory if it doesn't exist - os.makedirs('test_data', exist_ok=True) - test_file = 'test_data/test.parquet' + os.makedirs("test_data", exist_ok=True) + test_file = "test_data/test.parquet" # Save test data to parquet df = pd.DataFrame(test_data) df.to_parquet(test_file) # Initialize tokenizer and dataset - tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-Coder-7B-Instruct') - config = {'max_length': 512, 'truncation': 'error', 'multiturn': {'messages_key': 'messages'}} + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct") + config = {"max_length": 512, "truncation": "error", "multiturn": {"messages_key": "messages"}} dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config) # Test 1: Dataset Length @@ -80,23 +67,22 @@ def test_multiturn_sft_dataset(): item1 = dataset[1] # Joke conversation # Test 2: Required Keys and Types - required_keys = ['input_ids', 'attention_mask', 'position_ids', 'loss_mask'] + required_keys = ["input_ids", "attention_mask", "position_ids", "loss_mask"] for key in required_keys: assert key in item0, f"Missing key {key} in dataset item" assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key}" assert item0[key].dtype == torch.long, f"Expected torch.long for {key}, got {item0[key].dtype}" # Test 3: Shape Consistency - assert item0['loss_mask'].shape == item0['input_ids'].shape, \ - "Loss mask shape doesn't match input_ids shape" - assert item0['attention_mask'].shape == item0['input_ids'].shape, \ + assert item0["loss_mask"].shape == 
item0["input_ids"].shape, "Loss mask shape doesn't match input_ids shape" + assert item0["attention_mask"].shape == item0["input_ids"].shape, ( "Attention mask shape doesn't match input_ids shape" - assert item0['position_ids'].shape == item0['input_ids'].shape, \ - "Position IDs shape doesn't match input_ids shape" + ) + assert item0["position_ids"].shape == item0["input_ids"].shape, "Position IDs shape doesn't match input_ids shape" # Test 4: Loss Mask Pattern - Math Conversation - loss_mask0 = item0['loss_mask'] - input_ids0 = item0['input_ids'] + loss_mask0 = item0["loss_mask"] + input_ids0 = item0["input_ids"] # Find assistant response positions assistant_positions0 = torch.where(loss_mask0 == 1)[0] @@ -109,8 +95,8 @@ def test_multiturn_sft_dataset(): assert "4+4 equals 8" in assistant_text0, "Second assistant response not found" # Test 5: Loss Mask Pattern - Joke Conversation - loss_mask1 = item1['loss_mask'] - input_ids1 = item1['input_ids'] + loss_mask1 = item1["loss_mask"] + input_ids1 = item1["input_ids"] # Find assistant response positions assistant_positions1 = torch.where(loss_mask1 == 1)[0] @@ -123,7 +109,7 @@ def test_multiturn_sft_dataset(): assert "other side" in assistant_text1, "Second assistant response not found" # Test 6: Attention Mask Pattern - attention_mask0 = item0['attention_mask'] + attention_mask0 = item0["attention_mask"] sequence_length = torch.sum(attention_mask0) assert sequence_length > 0, "No tokens marked as attended in attention mask" assert torch.all(attention_mask0[:sequence_length] == 1), "Incorrect attention mask pattern" @@ -131,9 +117,10 @@ def test_multiturn_sft_dataset(): assert torch.all(attention_mask0[sequence_length:] == 0), "Padding not properly masked" # Test 7: Position IDs Pattern - position_ids0 = item0['position_ids'] - assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), \ + position_ids0 = item0["position_ids"] + assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), ( "Position IDs not sequential for non-padded tokens" + ) if sequence_length < len(position_ids0): assert torch.all(position_ids0[sequence_length:] == 0), "Padding position IDs not zero" @@ -147,16 +134,16 @@ def test_multiturn_sft_dataset(): print(f"\nAssistant responses (from loss mask):\n{assistant_text}") # Verify that loss mask is set for all assistant responses - for msg in test_data['messages'][0]: # First conversation - if msg['role'] == 'assistant': + for msg in test_data["messages"][0]: # First conversation + if msg["role"] == "assistant": # The content should appear in the masked text - assert msg['content'] in assistant_text, \ - f"Assistant message '{msg['content']}' not found in masked text" + assert msg["content"] in assistant_text, f"Assistant message '{msg['content']}' not found in masked text" # The content should NOT appear in the non-masked text non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) - assert msg['content'] not in non_assistant_text, \ + assert msg["content"] not in non_assistant_text, ( f"Assistant message '{msg['content']}' found in non-assistant text" + ) # Test 9: Verify non-assistant parts have loss_mask=0 # Get non-assistant text @@ -164,30 +151,31 @@ def test_multiturn_sft_dataset(): print(f"\nNon-assistant text (from loss mask):\n{non_assistant_text}") # Verify that system and user messages are in the non-assistant text - for msg in test_data['messages'][0]: # First conversation - if msg['role'] in ['system', 'user']: - assert msg['content'] in 
non_assistant_text, \ + for msg in test_data["messages"][0]: # First conversation + if msg["role"] in ["system", "user"]: + assert msg["content"] in non_assistant_text, ( f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text" + ) # And verify they're NOT in the assistant text - assert msg['content'] not in assistant_text, \ + assert msg["content"] not in assistant_text, ( f"{msg['role'].title()} message '{msg['content']}' found in assistant text" + ) # Test 10: Verify padding behavior - padding_config = {'max_length': 1024, 'truncation': 'error', 'multiturn': {'messages_key': 'messages'}} + padding_config = {"max_length": 1024, "truncation": "error", "multiturn": {"messages_key": "messages"}} small_dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=padding_config) padded_item = small_dataset[0] # Get actual sequence length (before padding) - actual_length = torch.sum(padded_item['attention_mask']) + actual_length = torch.sum(padded_item["attention_mask"]) # Verify padding tokens - assert torch.all(padded_item['input_ids'][actual_length:] == tokenizer.pad_token_id), \ + assert torch.all(padded_item["input_ids"][actual_length:] == tokenizer.pad_token_id), ( "Padding tokens not set correctly" - assert torch.all(padded_item['attention_mask'][actual_length:] == 0), \ - "Attention mask not set correctly for padding" - assert torch.all(padded_item['loss_mask'][actual_length:] == 0), \ - "Loss mask not set correctly for padding" + ) + assert torch.all(padded_item["attention_mask"][actual_length:] == 0), "Attention mask not set correctly for padding" + assert torch.all(padded_item["loss_mask"][actual_length:] == 0), "Loss mask not set correctly for padding" print("All tests passed!") print("Starting test...") diff --git a/tests/verl/utils/dataset/test_rl_dataset.py b/tests/verl/utils/dataset/test_rl_dataset.py index 040da8278..e3859a093 100644 --- a/tests/verl/utils/dataset/test_rl_dataset.py +++ b/tests/verl/utils/dataset/test_rl_dataset.py @@ -12,32 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os + import torch -from torch.utils.data import DataLoader -from transformers import AutoTokenizer from omegaconf import OmegaConf +from torch.utils.data import DataLoader def get_gsm8k_data(): # prepare test dataset url = "https://github.com/eric-haibin-lin/verl-data/raw/refs/heads/main/gsm8k/train.parquet" - local_folder = os.path.expanduser('~/verl-data/gsm8k/') - local_path = os.path.join(local_folder, 'train.parquet') + local_folder = os.path.expanduser("~/verl-data/gsm8k/") + local_path = os.path.join(local_folder, "train.parquet") os.makedirs(local_folder, exist_ok=True) return local_path def test_rl_dataset(): - from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn from verl.utils import hf_tokenizer - tokenizer = hf_tokenizer('deepseek-ai/deepseek-coder-1.3b-instruct') + from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn + + tokenizer = hf_tokenizer("deepseek-ai/deepseek-coder-1.3b-instruct") local_path = get_gsm8k_data() - config = OmegaConf.create({ - "prompt_key": "prompt", - "max_prompt_length": 256, - "filter_overlong_prompts": True, - "filter_overlong_prompts_workers": 2, - }) + config = OmegaConf.create( + { + "prompt_key": "prompt", + "max_prompt_length": 256, + "filter_overlong_prompts": True, + "filter_overlong_prompts_workers": 2, + } + ) dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config) dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn) @@ -56,29 +59,34 @@ def test_rl_dataset(): non_tensors[key] = val data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors) - assert 'input_ids' in data_proto.batch + assert "input_ids" in data_proto.batch - data = dataset[0]['input_ids'] + data = dataset[0]["input_ids"] output = tokenizer.batch_decode([data])[0] - print(f'type: type{output}') - print(f'\n\noutput: {output}') + print(f"type: type{output}") + print(f"\n\noutput: {output}") def test_image_rl_data(): + from verl.utils import hf_processor, hf_tokenizer from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn - from verl.utils import hf_tokenizer, hf_processor - tokenizer = hf_tokenizer('Qwen/Qwen2-VL-2B-Instruct') - processor = hf_processor('Qwen/Qwen2-VL-2B-Instruct') - config = OmegaConf.create({ - "prompt_key": "prompt", - "max_prompt_length": 1024, - "filter_overlong_prompts": True, - "filter_overlong_prompts_workers": 2, - }) - dataset = RLHFDataset(data_files=os.path.expanduser("~/data/geo3k/train.parquet"), - tokenizer=tokenizer, - config=config, - processor=processor) + + tokenizer = hf_tokenizer("Qwen/Qwen2-VL-2B-Instruct") + processor = hf_processor("Qwen/Qwen2-VL-2B-Instruct") + config = OmegaConf.create( + { + "prompt_key": "prompt", + "max_prompt_length": 1024, + "filter_overlong_prompts": True, + "filter_overlong_prompts_workers": 2, + } + ) + dataset = RLHFDataset( + data_files=os.path.expanduser("~/data/geo3k/train.parquet"), + tokenizer=tokenizer, + config=config, + processor=processor, + ) dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn) @@ -97,10 +105,57 @@ def test_image_rl_data(): data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors) - assert 'multi_modal_data' in data_proto.non_tensor_batch - assert 'multi_modal_inputs' in data_proto.non_tensor_batch + assert "multi_modal_data" in data_proto.non_tensor_batch + assert "multi_modal_inputs" in data_proto.non_tensor_batch - data = dataset[0]['input_ids'] + data = 
dataset[0]["input_ids"] output = tokenizer.batch_decode([data])[0] - print(f'type: type{output}') - print(f'\n\noutput: {output}') + print(f"type: type{output}") + print(f"\n\noutput: {output}") + + +def test_image_rl_data(): + from verl.utils import hf_processor, hf_tokenizer + from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn + + tokenizer = hf_tokenizer("Qwen/Qwen2-VL-2B-Instruct") + processor = hf_processor("Qwen/Qwen2-VL-2B-Instruct") + config = OmegaConf.create( + { + "prompt_key": "prompt", + "max_prompt_length": 1024, + "filter_overlong_prompts": True, + "filter_overlong_prompts_workers": 2, + } + ) + dataset = RLHFDataset( + data_files=os.path.expanduser("~/data/geo3k/train.parquet"), + tokenizer=tokenizer, + config=config, + processor=processor, + ) + + dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn) + + a = next(iter(dataloader)) + + from verl import DataProto + + tensors = {} + non_tensors = {} + + for key, val in a.items(): + if isinstance(val, torch.Tensor): + tensors[key] = val + else: + non_tensors[key] = val + + data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors) + + assert "multi_modal_data" in data_proto.non_tensor_batch + assert "multi_modal_inputs" in data_proto.non_tensor_batch + + data = dataset[0]["input_ids"] + output = tokenizer.batch_decode([data])[0] + print(f"type: type{output}") + print(f"\n\noutput: {output}") diff --git a/tests/verl/utils/dataset/test_rm_dataset.py b/tests/verl/utils/dataset/test_rm_dataset.py index f40d4ac06..066937ac3 100644 --- a/tests/verl/utils/dataset/test_rm_dataset.py +++ b/tests/verl/utils/dataset/test_rm_dataset.py @@ -13,7 +13,6 @@ # limitations under the License. import os -from transformers import AutoTokenizer from verl.utils import hf_tokenizer from verl.utils.dataset.rm_dataset import RMDataset @@ -21,8 +20,8 @@ from verl.utils.dataset.rm_dataset import RMDataset def get_rm_data(): # prepare test dataset url = "https://github.com/eric-haibin-lin/verl-data/raw/refs/heads/main/full_hh_rlhf/rm/test.parquet" - local_folder = os.path.expanduser('~/verl-data/full_hh_rlhf/rm/') - local_path = os.path.join(local_folder, 'test.parquet') + local_folder = os.path.expanduser("~/verl-data/full_hh_rlhf/rm/") + local_path = os.path.join(local_folder, "test.parquet") os.makedirs(local_folder, exist_ok=True) return local_path @@ -31,7 +30,7 @@ def test_rm_dataset(): tokenizer = hf_tokenizer("facebook/opt-1.3b") local_path = get_rm_data() dataset = RMDataset(parquet_files=local_path, tokenizer=tokenizer, max_length=512) - data = dataset[0]['input_ids'] + data = dataset[0]["input_ids"] output = tokenizer.batch_decode(data) assert len(output) > 1 assert type(output[0]) == str diff --git a/tests/verl/utils/dataset/test_sft_dataset.py b/tests/verl/utils/dataset/test_sft_dataset.py index b58e9cfd4..3f6037eb0 100644 --- a/tests/verl/utils/dataset/test_sft_dataset.py +++ b/tests/verl/utils/dataset/test_sft_dataset.py @@ -13,7 +13,6 @@ # limitations under the License. 
import os -from transformers import AutoTokenizer from verl.utils import hf_tokenizer from verl.utils.dataset.sft_dataset import SFTDataset @@ -21,46 +20,56 @@ from verl.utils.dataset.sft_dataset import SFTDataset def get_gsm8k_data(): # prepare test dataset url = "https://github.com/eric-haibin-lin/verl-data/raw/refs/heads/main/gsm8k/train.parquet" - local_folder = os.path.expanduser('~/verl-data/gsm8k/') - local_path = os.path.join(local_folder, 'train.parquet') + local_folder = os.path.expanduser("~/verl-data/gsm8k/") + local_path = os.path.join(local_folder, "train.parquet") return local_path def test_sft_cot_dataset(): - tokenizer = hf_tokenizer('deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct') + tokenizer = hf_tokenizer("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct") local_path = get_gsm8k_data() from omegaconf import OmegaConf - dataset = SFTDataset(parquet_files=local_path, - tokenizer=tokenizer, - config=OmegaConf.create({ - 'prompt_key': 'prompt', - 'prompt_dict_keys': ['content'], - 'response_key': 'extra_info', - 'response_dict_keys': ['answer'], - 'max_length': 512, - })) - data = dataset[0]['input_ids'] + dataset = SFTDataset( + parquet_files=local_path, + tokenizer=tokenizer, + config=OmegaConf.create( + { + "prompt_key": "prompt", + "prompt_dict_keys": ["content"], + "response_key": "extra_info", + "response_dict_keys": ["answer"], + "max_length": 512, + } + ), + ) + + data = dataset[0]["input_ids"] output = tokenizer.batch_decode([data])[0] assert len(output) > 1 assert type(output) == str def test_sft_dataset(): - tokenizer = hf_tokenizer('deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct') + tokenizer = hf_tokenizer("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct") local_path = get_gsm8k_data() from omegaconf import OmegaConf - dataset = SFTDataset(parquet_files=local_path, - tokenizer=tokenizer, - config=OmegaConf.create({ - "prompt_key": 'extra_info', - 'prompt_dict_keys': ['question'], - 'response_key': 'extra_info', - 'response_dict_keys': ['answer'], - 'max_length': 512 - })) - data = dataset[0]['input_ids'] + dataset = SFTDataset( + parquet_files=local_path, + tokenizer=tokenizer, + config=OmegaConf.create( + { + "prompt_key": "extra_info", + "prompt_dict_keys": ["question"], + "response_key": "extra_info", + "response_dict_keys": ["answer"], + "max_length": 512, + } + ), + ) + + data = dataset[0]["input_ids"] output = tokenizer.batch_decode([data])[0] assert len(output) > 1 assert type(output) == str diff --git a/tests/verl/utils/test_import_utils.py b/tests/verl/utils/test_import_utils.py index 111fdd4e1..96d720a89 100644 --- a/tests/verl/utils/test_import_utils.py +++ b/tests/verl/utils/test_import_utils.py @@ -13,9 +13,9 @@ # limitations under the License. 
import os -import sys -import importlib.util + import pytest + from verl.utils.import_utils import load_extern_type # Path to the test module @@ -84,7 +84,7 @@ def test_load_extern_type_invalid_module(): # Create a temporary file with syntax errors import tempfile - with tempfile.NamedTemporaryFile(suffix='.py', mode='w+', delete=False) as temp_file: + with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False) as temp_file: temp_file.write("This is not valid Python syntax :") temp_path = temp_file.name diff --git a/verl/__init__.py b/verl/__init__.py index 9b174f43d..89fab359f 100644 --- a/verl/__init__.py +++ b/verl/__init__.py @@ -16,24 +16,26 @@ import os version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) -with open(os.path.join(version_folder, 'version/version')) as f: +with open(os.path.join(version_folder, "version/version")) as f: __version__ = f.read().strip() -from .protocol import DataProto - -from .utils.logging_utils import set_basic_config import logging +from .protocol import DataProto +from .utils.logging_utils import set_basic_config + set_basic_config(level=logging.WARNING) from . import single_controller -__all__ = ['DataProto', "__version__"] +__all__ = ["DataProto", "__version__"] -if os.getenv('VERL_USE_MODELSCOPE', 'False').lower() == 'true': +if os.getenv("VERL_USE_MODELSCOPE", "False").lower() == "true": import importlib + if importlib.util.find_spec("modelscope") is None: - raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`') + raise ImportError("You are using the modelscope hub, please install modelscope by `pip install modelscope -U`") # Patch hub to download models from modelscope to speed up. from modelscope.utils.hf_util import patch_hub + patch_hub() diff --git a/verl/models/llama/megatron/__init__.py b/verl/models/llama/megatron/__init__.py index b188b3ee6..853dee0ed 100644 --- a/verl/models/llama/megatron/__init__.py +++ b/verl/models/llama/megatron/__init__.py @@ -13,12 +13,13 @@ # limitations under the License. from .modeling_llama_megatron import ( - # original model with megatron - ParallelLlamaModel, ParallelLlamaForCausalLM, # rmpad with megatron ParallelLlamaForCausalLMRmPad, - ParallelLlamaForValueRmPad, # rmpad with megatron and pipeline parallelism ParallelLlamaForCausalLMRmPadPP, - ParallelLlamaForValueRmPadPP) + ParallelLlamaForValueRmPad, + ParallelLlamaForValueRmPadPP, + # original model with megatron + ParallelLlamaModel, +) diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_loader.py b/verl/models/llama/megatron/checkpoint_utils/llama_loader.py index 4f602b5f4..4eb0a7d04 100644 --- a/verl/models/llama/megatron/checkpoint_utils/llama_loader.py +++ b/verl/models/llama/megatron/checkpoint_utils/llama_loader.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import importlib -from packaging.version import Version -import torch import time -from typing import Dict, Any, Callable, Optional + +import torch import torch.distributed as dist @@ -29,7 +27,7 @@ def _megatron_calc_layer_map(config): """ from megatron.core import mpu - print(f'get megatron data parallel size: {mpu.get_data_parallel_world_size()}') + print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}") pp_size = mpu.get_pipeline_model_parallel_world_size() virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 @@ -40,8 +38,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + - pp_rank_idx * num_layers_per_model) + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -51,20 +50,17 @@ def _megatron_calc_layer_map(config): return layer_map -def load_state_dict_to_megatron_llama(state_dict, - wrapped_models, - config, - params_dtype, - is_value_model=False, - tie_word_embeddings=False): - """Load merged state_dict to sharded Megatron module in training. - """ - from megatron.core import mpu - from verl.utils.megatron_utils import print_rank_0, unwrap_model - from megatron.core.transformer.module import Float16Module +def load_state_dict_to_megatron_llama( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): + """Load merged state_dict to sharded Megatron module in training.""" from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP + from verl.utils.megatron_utils import print_rank_0, unwrap_model + start_time = time.time() def _get_gpt_model(model): @@ -72,9 +68,9 @@ def load_state_dict_to_megatron_llama(state_dict, def fetch_params(module): for param in module.parameters(): - torch.distributed.fetch(param.data, - src=mpu.get_data_parallel_src_rank(), - group=mpu.get_data_parallel_group()) + torch.distributed.fetch( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) dp_rank = mpu.get_data_parallel_rank() pp_rank = mpu.get_pipeline_model_parallel_rank() @@ -92,7 +88,9 @@ def load_state_dict_to_megatron_llama(state_dict, assert len(wrapped_models) == virtual_pp_size num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f'num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}' + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) models = [None] * len(wrapped_models) @@ -148,16 +146,16 @@ def load_state_dict_to_megatron_llama(state_dict, if gate_name in state_dict and up_name in state_dict: gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty(config.intermediate_size * 2, - config.hidden_size, - 
dtype=params_dtype, - device=torch.cuda.current_device()) + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device() + ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) if tensor is not None: @@ -171,7 +169,7 @@ def load_state_dict_to_megatron_llama(state_dict, nonlocal mp_group tp_rank = mpu.get_tensor_model_parallel_rank() tp_size = mpu.get_tensor_model_parallel_world_size() - assert (q_name in state_dict and k_name in state_dict and v_name in state_dict) + assert q_name in state_dict and k_name in state_dict and v_name in state_dict full_weight_q = state_dict[q_name] full_weight_k = state_dict[k_name] full_weight_v = state_dict[v_name] @@ -182,31 +180,29 @@ def load_state_dict_to_megatron_llama(state_dict, q_size_tp = config.hidden_size // tp_size kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty(total_size * tp_size, - config.hidden_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device() + ) for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp] - k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp] - new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) else: q_size_tp = config.hidden_size // tp_size kv_size_tp = hidden_size_per_head total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty(total_size * tp_size, - config.hidden_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device() + ) for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp] + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head k_part = full_weight_k[start_idx:end_idx] v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, 
v_part], dim=0)) tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) if tensor is not None: @@ -235,9 +231,9 @@ def load_state_dict_to_megatron_llama(state_dict, for vpp_rank in range(vpp_size): num_layer_vpp_chunk = num_layer_per_pp // vpp_size num_layer_this_model = num_layer_vpp_chunk - offset = vpp_rank * ( - config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + \ - (mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk) + offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + ( + mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk + ) layer_list.extend(list(range(offset, offset + num_layer_this_model))) else: num_layer_this_model = num_layer_per_pp @@ -275,8 +271,11 @@ def load_state_dict_to_megatron_llama(state_dict, f"{layer_name}.post_attention_layernorm.weight", ) - _fetch_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight") + _fetch_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) _fetch_tp_shard_tensor( sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, @@ -297,15 +296,15 @@ def load_state_dict_to_megatron_llama(state_dict, lm_head_weight = gpt_model_module.lm_head.weight if is_value_model: - if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: _fetch_tensor(lm_head_weight, "lm_head.weight") - print_rank_0('load lm_head weight') - elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1: + print_rank_0("load lm_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: _fetch_tensor(lm_head_weight, "reward_head.weight") - print_rank_0('load lm_head from value_head weight') + print_rank_0("load lm_head from value_head weight") else: _fetch_tensor(None, "lm_head.weight") - print_rank_0('fail to match lm_head in value_model') + print_rank_0("fail to match lm_head in value_model") else: _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight") diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py b/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py index 7f1b6a4a0..696931de1 100644 --- a/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py +++ b/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import importlib -from packaging.version import Version -import torch import time -from typing import Dict, Any, Callable, Optional + +import torch import torch.distributed as dist @@ -29,7 +27,7 @@ def _megatron_calc_layer_map(config): """ from megatron.core import mpu - print(f'get megatron data parallel size: {mpu.get_data_parallel_world_size()}') + print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}") pp_size = mpu.get_pipeline_model_parallel_world_size() virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 @@ -40,8 +38,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + - pp_rank_idx * num_layers_per_model) + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -51,20 +50,17 @@ def _megatron_calc_layer_map(config): return layer_map -def load_state_dict_to_megatron_llama(state_dict, - wrapped_models, - config, - params_dtype, - is_value_model=False, - tie_word_embeddings=False): - """Load merged state_dict to sharded Megatron module in training. - """ - from megatron.core import mpu - from verl.utils.megatron_utils import print_rank_0, unwrap_model - from megatron.core.transformer.module import Float16Module +def load_state_dict_to_megatron_llama( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): + """Load merged state_dict to sharded Megatron module in training.""" from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP + from verl.utils.megatron_utils import print_rank_0, unwrap_model + start_time = time.time() def _get_gpt_model(model): @@ -72,9 +68,9 @@ def load_state_dict_to_megatron_llama(state_dict, def broadcast_params(module): for param in module.parameters(): - torch.distributed.broadcast(param.data, - src=mpu.get_data_parallel_src_rank(), - group=mpu.get_data_parallel_group()) + torch.distributed.broadcast( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) dp_rank = mpu.get_data_parallel_rank() pp_rank = mpu.get_pipeline_model_parallel_rank() @@ -92,7 +88,9 @@ def load_state_dict_to_megatron_llama(state_dict, assert len(wrapped_models) == virtual_pp_size num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f'num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}' + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) models = [None] * len(wrapped_models) @@ -171,8 +169,9 @@ def load_state_dict_to_megatron_llama(state_dict, requires_grad=False, ) else: - assert (tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + assert tensor.shape == chunk_shape, ( + 
f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) for i in range(tp_size): @@ -217,8 +216,9 @@ def load_state_dict_to_megatron_llama(state_dict, requires_grad=False, ) else: - assert (tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) for i in range(tp_size): @@ -238,16 +238,16 @@ def load_state_dict_to_megatron_llama(state_dict, if torch.distributed.get_rank() == 0: gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty(config.intermediate_size * 2, - config.hidden_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device() + ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -270,9 +270,9 @@ def load_state_dict_to_megatron_llama(state_dict, requires_grad=False, ) else: - assert ( - tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}" + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}" + ) sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) for i in range(tp_size): @@ -290,7 +290,7 @@ def load_state_dict_to_megatron_llama(state_dict, tp_size = mpu.get_tensor_model_parallel_world_size() if torch.distributed.get_rank() == 0: - assert (q_name in state_dict and k_name in state_dict and v_name in state_dict) + assert q_name in state_dict and k_name in state_dict and v_name in state_dict full_weight_q = state_dict[q_name] full_weight_k = state_dict[k_name] full_weight_v = state_dict[v_name] @@ -301,33 +301,33 @@ def load_state_dict_to_megatron_llama(state_dict, q_size_tp = config.hidden_size // tp_size kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty(total_size * tp_size, - config.hidden_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device() + ) for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp] 
- k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp] - new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], - dim=0)) + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) else: q_size_tp = config.hidden_size // tp_size kv_size_tp = hidden_size_per_head total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty(total_size * tp_size, - config.hidden_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device() + ) for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp] + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head k_part = full_weight_k[start_idx:end_idx] v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], - dim=0)) + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -350,8 +350,9 @@ def load_state_dict_to_megatron_llama(state_dict, requires_grad=False, ) else: - assert (tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + ) sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) for i in range(tp_size): @@ -406,8 +407,11 @@ def load_state_dict_to_megatron_llama(state_dict, f"{layer_name}.post_attention_layernorm.weight", ) - _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight") + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) _broadcast_tp_shard_tensor( sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, @@ -429,15 +433,15 @@ def load_state_dict_to_megatron_llama(state_dict, lm_head_weight = gpt_model_module.lm_head.weight if is_value_model: - if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: _broadcast_tensor(lm_head_weight, "lm_head.weight") - print_rank_0('load lm_head weight') - elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1: + print_rank_0("load lm_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: _broadcast_tensor(lm_head_weight, "reward_head.weight") - print_rank_0('load lm_head from value_head weight') + print_rank_0("load lm_head from value_head weight") else: _broadcast_tensor(None, 
"lm_head.weight") - print_rank_0('fail to match lm_head in value_model') + print_rank_0("fail to match lm_head in value_model") else: _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") dist.barrier() @@ -446,4 +450,4 @@ def load_state_dict_to_megatron_llama(state_dict, broadcast_params(wrapped_model) torch.cuda.empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") \ No newline at end of file + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_saver.py b/verl/models/llama/megatron/checkpoint_utils/llama_saver.py index 5ee15a949..c71fa6541 100644 --- a/verl/models/llama/megatron/checkpoint_utils/llama_saver.py +++ b/verl/models/llama/megatron/checkpoint_utils/llama_saver.py @@ -30,8 +30,9 @@ def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int tp_size = mpu.get_tensor_model_parallel_world_size() dp_size = mpu.get_data_parallel_world_size() pp_size = mpu.get_pipeline_model_parallel_world_size() - assert (tp_size * dp_size * pp_size == torch.distributed.get_world_size() - ), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), ( + f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + ) # We only support TP-DP-PP grouping, for correctness when resharding return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank @@ -54,8 +55,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + - pp_rank_idx * num_layers_per_model) + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -107,9 +109,11 @@ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=Fals for i, wrapped_model in enumerate(wrapped_models): models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].model.layers - ) == num_layers_per_model, 'len model layers {} not equal to num_layers_per_model {}'.format( - len(models[i].model.layers), num_layers_per_model) + assert len(models[i].model.layers) == num_layers_per_model, ( + "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].model.layers), num_layers_per_model + ) + ) state_dict = dict() @@ -247,7 +251,7 @@ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=Fals gate_weight_list = [] up_weight_list = [] for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)] + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] gate_weight_list.append(gate_weight_tp) @@ -306,10 +310,10 @@ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=Fals kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size total_size = q_size_tp + 2 * kv_size_tp for i in range(tp_size): - qkv_part = full_tensor[i * total_size:(i + 1) * total_size] + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] 
q_part = qkv_part[:q_size_tp] - k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp] - v_part = qkv_part[q_size_tp + kv_size_tp:total_size] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] q_weight_list.append(q_part) k_weight_list.append(k_part) v_weight_list.append(v_part) @@ -318,10 +322,10 @@ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=Fals kv_size_tp = hidden_size_per_head total_size = q_size_tp + 2 * kv_size_tp for i in range(tp_size): - qkv_part = full_tensor[i * total_size:(i + 1) * total_size] + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] q_part = qkv_part[:q_size_tp] - k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp] - v_part = qkv_part[q_size_tp + kv_size_tp:total_size] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] q_weight_list.append(q_part) if i * config.num_key_value_heads % tp_size == 0: k_weight_list.append(k_part) @@ -384,10 +388,12 @@ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=Fals src_pp_rank=src_pp_rank, ) - _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - src_pp_rank=src_pp_rank) + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + src_pp_rank=src_pp_rank, + ) _broadcast_tp_shard_tensor( sync_layer.mlp.down_proj.weight, @@ -410,14 +416,19 @@ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=Fals if is_value_model: if pp_rank == pp_size - 1: - print(f'gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}') - _broadcast_tensor(gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, - "lm_head.weight", - src_pp_rank=pp_size - 1) - _broadcast_tensor(gpt_model_module.reward_head.weight if pp_rank == pp_size - 1 and - getattr(gpt_model_module, "reward_weight", None) is not None else None, - "reward_head.weight", - src_pp_rank=pp_size - 1) + print(f"gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}") + _broadcast_tensor( + gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + _broadcast_tensor( + gpt_model_module.reward_head.weight + if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None + else None, + "reward_head.weight", + src_pp_rank=pp_size - 1, + ) else: _broadcast_tp_shard_tensor( diff --git a/verl/models/llama/megatron/layers/parallel_attention.py b/verl/models/llama/megatron/layers/parallel_attention.py index 3d3ecd790..54f93fa9a 100644 --- a/verl/models/llama/megatron/layers/parallel_attention.py +++ b/verl/models/llama/megatron/layers/parallel_attention.py @@ -22,31 +22,29 @@ import math from typing import Optional, Tuple import torch +from megatron.core import ModelParallelConfig, tensor_parallel from megatron.core import parallel_state as mpu -from megatron.core import tensor_parallel -from megatron.core import ModelParallelConfig from torch import nn from transformers import LlamaConfig -from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear +from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear from verl.utils.megatron import tensor_parallel as tp_utils class LlamaRotaryEmbedding(nn.Module): - def __init__(self, 
dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype()) + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len @@ -99,9 +97,10 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): self.max_seq_len_cached = seq_len if seq_len > self.max_position_embeddings: - base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - - (self.scaling_factor - 1))**(self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) @@ -114,7 +113,6 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding): - def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device=None): super().__init__(dim, max_position_embeddings, base, device) @@ -122,7 +120,8 @@ class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding): self.high_freq_factor = config.rope_scaling["high_freq_factor"] # `1` in the original implementation self.low_freq_factor = config.rope_scaling["low_freq_factor"] # `4` in the original implementation self.old_context_len = config.rope_scaling[ - "original_max_position_embeddings"] # `8192` in the original implementation + "original_max_position_embeddings" + ] # `8192` in the original implementation low_freq_wavelen = self.old_context_len / self.low_freq_factor high_freq_wavelen = self.old_context_len / self.high_freq_factor @@ -131,8 +130,9 @@ class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding): # wavelen < high_freq_wavelen: do nothing; wavelen > low_freq_wavelen: divide by factor inv_freq_llama = torch.where(wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq) # otherwise: interpolate between the two, using a smooth factor - smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / (self.high_freq_factor - - self.low_freq_factor) + smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / ( + self.high_freq_factor - self.low_freq_factor + ) smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) @@ -140,15 +140,15 @@ class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding): self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache(seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype()) + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @@ -189,47 +189,56 @@ class ParallelLlamaAttention(nn.Module): # assign values after tp tp_size = mpu.get_tensor_model_parallel_world_size() - assert self.num_heads % tp_size == 0, f'num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}' - assert self.num_key_value_heads % tp_size == 0, \ - f'num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}' + assert self.num_heads % tp_size == 0, ( + f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" + ) + assert self.num_key_value_heads % tp_size == 0, ( + f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}" + ) self.num_heads_per_tp = self.num_heads // tp_size self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size self.hidden_size_per_tp = self.hidden_size // tp_size if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads}).") + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() if megatron_config is not None: - assert column_kwargs.get('config', False), 'must have ModelParallelConfig' - assert row_kwargs.get('config', False), 'must have ModelParallelConfig' + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) # [self.q_size, self.k_size, self.v_size] - self.qkv_proj = QKVParallelLinear(input_size=self.hidden_size, - num_heads=self.num_heads, - num_key_value_heads=self.num_key_value_heads, - head_dim=self.head_dim, - bias=config.attention_bias, - gather_output=False, - skip_bias_add=False, - **column_kwargs) + self.qkv_proj = QKVParallelLinear( + input_size=self.hidden_size, + num_heads=self.num_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + bias=config.attention_bias, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) self.q_size = self.num_heads_per_tp * self.head_dim self.k_size = self.num_key_value_heads_per_tp * self.head_dim self.v_size = self.num_key_value_heads_per_tp * self.head_dim - self.o_proj = tensor_parallel.RowParallelLinear(input_size=self.num_heads * self.head_dim, - output_size=self.hidden_size, - bias=config.attention_bias, - input_is_parallel=True, - skip_bias_add=False, - **row_kwargs) + self.o_proj = tensor_parallel.RowParallelLinear( + input_size=self.num_heads * self.head_dim, + output_size=self.hidden_size, + bias=config.attention_bias, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) self._init_rope() @@ -297,12 +306,14 @@ class ParallelLlamaAttention(nn.Module): if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}") + f" {attn_weights.size()}" + ) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) attn_weights = attn_weights + attention_mask # upcast attention to fp32 @@ -312,7 +323,8 @@ class ParallelLlamaAttention(nn.Module): if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") + f" {attn_output.size()}" + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) @@ -326,10 +338,9 @@ Remove padding Attention - Compatible with sequence parallel """ -from transformers.utils import is_flash_attn_2_available import torch.nn.functional as F - from einops import rearrange +from transformers.utils import is_flash_attn_2_available if is_flash_attn_2_available(): from flash_attn import flash_attn_varlen_func @@ -358,40 +369,34 @@ from flash_attn.layers.rotary import apply_rotary_emb # use flash-attn rotary embeddings with rmpad # cos/sin shoudl be: (seq_length, rotary_dim / 2) def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): - 
q_embed = apply_rotary_emb(q, - cos, - sin, - interleaved=False, - inplace=False, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen) - k_embed = apply_rotary_emb(k, - cos, - sin, - interleaved=False, - inplace=False, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen) + q_embed = apply_rotary_emb( + q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + k_embed = apply_rotary_emb( + k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) return q_embed, k_embed class ParallelLlamaAttentionRmPad(ParallelLlamaAttention): - - def forward(self, - hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: torch.Tensor = None, - max_seqlen_in_batch: int = None): + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: torch.Tensor = None, + max_seqlen_in_batch: int = None, + ): total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel if self.megatron_config.sequence_parallel: total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() qkv = self.qkv_proj(hidden_states)[0] - query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], - dim=-1) # (total_nnz, 1, hidden_size) + query_states, key_states, value_states = qkv.split( + [self.q_size, self.k_size, self.v_size], dim=-1 + ) # (total_nnz, 1, hidden_size) if self.megatron_config.sequence_parallel: sequence_parallel_pad = total_nnz - cu_seqlens[-1] @@ -408,13 +413,10 @@ class ParallelLlamaAttentionRmPad(ParallelLlamaAttention): value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) - cos, sin = cos[:, :cos.shape[1] // 2], sin[:, :sin.shape[1] // 2] # flash attn only needs half - query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states, - key_states, - cos, - sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen_in_batch) + cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half + query_states, key_states = apply_rotary_pos_emb_rmpad_flash( + query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch + ) # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices, # TODO: llama does not have dropout in the config?? 
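The `parallel_attention.py` hunks above show the main rules `ruff format` (Black-style) applies throughout this patch: spaces around `:` when a slice bound is a non-trivial expression, spaces around `**` when an operand is parenthesized, double-quoted strings, and multi-line call sites kept at one argument per line when a magic trailing comma is present. A minimal sketch of the resulting style, using hypothetical helper names (`rotate_half_sketch`, `inv_freq_sketch`) rather than code taken from the patch:

```python
import torch


def rotate_half_sketch(x: torch.Tensor) -> torch.Tensor:
    # Slice bounds that are expressions get spaces around ":" under ruff format,
    # matching the rotate_half change above.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def inv_freq_sketch(dim: int, base: float = 10000.0) -> torch.Tensor:
    # A parenthesized operand gets spaces around "**", matching the rotary
    # embedding change above.
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))


# With a magic trailing comma, ruff format keeps one argument per line --
# the same layout as the reformatted constructor calls above.
buffer = torch.empty(
    16,
    16,
    dtype=torch.float32,
)
```
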
diff --git a/verl/models/llama/megatron/layers/parallel_decoder.py b/verl/models/llama/megatron/layers/parallel_decoder.py index e51632a33..3f74b69c0 100644 --- a/verl/models/llama/megatron/layers/parallel_decoder.py +++ b/verl/models/llama/megatron/layers/parallel_decoder.py @@ -21,19 +21,18 @@ from typing import Optional, Tuple import torch +from megatron.core import ModelParallelConfig from torch import nn from transformers import LlamaConfig -from megatron.core import ModelParallelConfig + +from verl.utils.megatron_utils import TransformerConfig, convert_config from .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad from .parallel_mlp import ParallelLlamaMLP from .parallel_rmsnorm import ParallelLlamaRMSNorm -from verl.utils.megatron_utils import TransformerConfig, convert_config - class ParallelLlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) @@ -101,7 +100,6 @@ class ParallelLlamaDecoderLayer(nn.Module): class ParallelLlamaDecoderLayerRmPad(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) @@ -120,7 +118,7 @@ class ParallelLlamaDecoderLayerRmPad(nn.Module): sequence_length: int = None, indices: torch.Tensor = None, cu_seqlens: int = None, - max_seqlen_in_batch: int = None + max_seqlen_in_batch: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states # (total_nnz // sp, 1, hidden_size) @@ -129,12 +127,14 @@ class ParallelLlamaDecoderLayerRmPad(nn.Module): # Self Attention # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size) # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size) - hidden_states = self.self_attn(hidden_states=hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch) + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) hidden_states = residual + hidden_states diff --git a/verl/models/llama/megatron/layers/parallel_linear.py b/verl/models/llama/megatron/layers/parallel_linear.py index e3e4e4385..03890b028 100644 --- a/verl/models/llama/megatron/layers/parallel_linear.py +++ b/verl/models/llama/megatron/layers/parallel_linear.py @@ -13,23 +13,23 @@ # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py -from typing import Optional, Tuple from megatron.core import tensor_parallel class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): - - def __init__(self, - input_size, - num_heads, - num_key_value_heads, - head_dim, - *, - bias=True, - gather_output=True, - skip_bias_add=False, - **kwargs): + def __init__( + self, + input_size, + num_heads, + num_key_value_heads, + head_dim, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): # Keep input parameters, and already restrict the head numbers self.input_size = input_size self.q_output_size = num_heads * head_dim @@ -41,44 +41,48 @@ class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): input_size = self.input_size output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim - super().__init__(input_size=input_size, - output_size=output_size, - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - **kwargs) + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): - - def __init__(self, - input_size, - gate_ouput_size, - up_output_size, - *, - bias=True, - gather_output=True, - skip_bias_add=False, - **kwargs): + def __init__( + self, + input_size, + gate_ouput_size, + up_output_size, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): # Keep input parameters, and already restrict the head numbers self.input_size = input_size self.output_size = gate_ouput_size + up_output_size self.gather_output = gather_output self.skip_bias_add = skip_bias_add - super().__init__(input_size=self.input_size, - output_size=self.output_size, - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - **kwargs) + super().__init__( + input_size=self.input_size, + output_size=self.output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) import torch class LinearForLastLayer(torch.nn.Linear): - def __init__( self, input_size, @@ -90,7 +94,7 @@ class LinearForLastLayer(torch.nn.Linear): super().__init__(in_features=input_size, out_features=output_size, bias=bias) self.sequence_parallel = config.sequence_parallel if self.sequence_parallel: - setattr(self.weight, 'sequence_parallel', True) + self.weight.sequence_parallel = True def forward( self, diff --git a/verl/models/llama/megatron/layers/parallel_mlp.py b/verl/models/llama/megatron/layers/parallel_mlp.py index 21ad9b16a..583a317eb 100644 --- a/verl/models/llama/megatron/layers/parallel_mlp.py +++ b/verl/models/llama/megatron/layers/parallel_mlp.py @@ -18,18 +18,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from megatron.core import ModelParallelConfig, tensor_parallel from megatron.core import parallel_state as mpu -from megatron.core import tensor_parallel -from megatron.core import ModelParallelConfig from torch import nn from transformers.activations import ACT2FN -from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear +from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear from verl.utils.megatron import tensor_parallel as tp_utils class ParallelLlamaMLP(nn.Module): - def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: super().__init__() self.config = config @@ -41,8 +39,8 @@ class ParallelLlamaMLP(nn.Module): row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() if megatron_config is not None: - assert column_kwargs.get('config', False), 'must have ModelParallelConfig' - assert row_kwargs.get('config', False), 'must have ModelParallelConfig' + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) @@ -59,12 +57,14 @@ class ParallelLlamaMLP(nn.Module): ) self.gate_size = self.intermediate_size // tp_size - self.down_proj = tensor_parallel.RowParallelLinear(input_size=self.intermediate_size, - output_size=self.hidden_size, - bias=False, - input_is_parallel=True, - skip_bias_add=False, - **row_kwargs) + self.down_proj = tensor_parallel.RowParallelLinear( + input_size=self.intermediate_size, + output_size=self.hidden_size, + bias=False, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) self.act_fn = ACT2FN[config.hidden_act] diff --git a/verl/models/llama/megatron/layers/parallel_rmsnorm.py b/verl/models/llama/megatron/layers/parallel_rmsnorm.py index 7027036bf..bc2e9ae36 100644 --- a/verl/models/llama/megatron/layers/parallel_rmsnorm.py +++ b/verl/models/llama/megatron/layers/parallel_rmsnorm.py @@ -13,17 +13,17 @@ # limitations under the License. 
import numbers + import torch +from apex.normalization.fused_layer_norm import fused_rms_norm_affine from megatron.core import ModelParallelConfig from torch import nn from transformers import LlamaConfig -from apex.normalization.fused_layer_norm import fused_rms_norm_affine from verl.utils.megatron import sequence_parallel as sp_utils class ParallelLlamaRMSNorm(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): """ LlamaRMSNorm is equivalent to T5LayerNorm @@ -39,8 +39,10 @@ class ParallelLlamaRMSNorm(nn.Module): sp_utils.mark_parameter_as_sequence_parallel(self.weight) def forward(self, hidden_states): - return fused_rms_norm_affine(input=hidden_states, - weight=self.weight, - normalized_shape=self.normalized_shape, - eps=self.variance_epsilon, - memory_efficient=True) \ No newline at end of file + return fused_rms_norm_affine( + input=hidden_states, + weight=self.weight, + normalized_shape=self.normalized_shape, + eps=self.variance_epsilon, + memory_efficient=True, + ) diff --git a/verl/models/llama/megatron/modeling_llama_megatron.py b/verl/models/llama/megatron/modeling_llama_megatron.py index 83d0f0f2b..80653c037 100644 --- a/verl/models/llama/megatron/modeling_llama_megatron.py +++ b/verl/models/llama/megatron/modeling_llama_megatron.py @@ -23,10 +23,7 @@ from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint -from megatron.core import tensor_parallel -from megatron.core import ModelParallelConfig -from megatron.core import mpu - +from megatron.core import ModelParallelConfig, mpu, tensor_parallel from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.configuration_llama import LlamaConfig @@ -35,7 +32,9 @@ from transformers.models.llama.modeling_llama import CausalLMOutputWithPast from verl.utils.megatron import sequence_parallel as sp_utils from verl.utils.megatron import tensor_parallel as tp_utils from verl.utils.megatron_utils import TransformerConfig, convert_config -from .layers import ParallelLlamaDecoderLayer, ParallelLlamaRMSNorm, ParallelLlamaDecoderLayerRmPad + +from .layers import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad, ParallelLlamaRMSNorm + """ TODO: 1. Add weight initialization. Here we need to be careful on TP weight init. 
@@ -87,14 +86,15 @@ class ParallelLlamaModel(nn.Module): self.vocab_size = config.vocab_size embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() if megatron_config is not None: - assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig' + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, - embedding_dim=config.hidden_size, - **embedding_kwargs) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) self.layers = nn.ModuleList( - [ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)]) + [ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) self.norm = ParallelLlamaRMSNorm(config, megatron_config) # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask @@ -111,10 +111,12 @@ class ParallelLlamaModel(nn.Module): if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, - tgt_len=input_shape[-1]).to(inputs_embeds.device) - combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + - combined_attention_mask) + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) return combined_attention_mask @@ -157,7 +159,6 @@ class ParallelLlamaModel(nn.Module): class ParallelLlamaForCausalLM(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) @@ -166,15 +167,17 @@ class ParallelLlamaForCausalLM(nn.Module): column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() if megatron_config is not None: - assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + assert column_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) def forward( self, @@ -233,23 +236,26 @@ class ParallelLlamaModelRmPad(nn.Module): embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() self.megatron_config = megatron_config if megatron_config is not None: - assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig' + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, - embedding_dim=config.hidden_size, - **embedding_kwargs) + self.embed_tokens = 
tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) self.layers = nn.ModuleList( - [ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)]) + [ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) self.norm = ParallelLlamaRMSNorm(config, megatron_config) - def forward(self, - input_ids: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]: + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: """ Args: @@ -268,12 +274,14 @@ class ParallelLlamaModelRmPad(nn.Module): hidden_states = inputs_embeds for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer(hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch) + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) hidden_states = layer_outputs @@ -283,7 +291,6 @@ class ParallelLlamaModelRmPad(nn.Module): class ParallelLlamaForCausalLMRmPad(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) @@ -295,14 +302,16 @@ class ParallelLlamaForCausalLMRmPad(nn.Module): def _init_head(self, config): column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() if self.megatron_config is not None: - assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + assert column_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) def _forward_head(self, hidden_states): # all_gather from sequence parallel region is performed inside lm_head @@ -329,8 +338,9 @@ class ParallelLlamaForCausalLMRmPad(nn.Module): batch_size, sequence_length = input_ids.shape # remove padding here - input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), - attention_mask) # (total_nnz, 1) + input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) # pad input_ids to multiple of tp for all tp ranks # TODO: for better performance, the sp padding should be removed at each layer. 
Not sure the performance gap @@ -339,12 +349,14 @@ class ParallelLlamaForCausalLMRmPad(nn.Module): input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad) - outputs = self.model(input_ids=input_ids, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch) + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) hidden_states = outputs @@ -357,8 +369,9 @@ class ParallelLlamaForCausalLMRmPad(nn.Module): logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # add removed padding back - logits = pad_input(logits, indices, batch_size, - seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) return CausalLMOutputWithPast( loss=None, @@ -370,11 +383,10 @@ class ParallelLlamaForCausalLMRmPad(nn.Module): class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad): - def _init_head(self, config): column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() if self.megatron_config is not None: - assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + assert column_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) # lm_head is effectively the same as sequence parallel @@ -423,12 +435,12 @@ class ParallelLlamaModelRmPadPP(nn.Module): self.megatron_config = megatron_config embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() if megatron_config is not None: - assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig' + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) if pre_process: - self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, - embedding_dim=config.hidden_size, - **embedding_kwargs) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) else: self.embed_tokens = None @@ -442,9 +454,7 @@ class ParallelLlamaModelRmPadPP(nn.Module): self.layers = nn.ModuleList() self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size self.num_layer_this_model = self.num_layer_vpp_chunk - offset = vpp_rank * ( - config.num_hidden_layers // vpp_size) + \ - (pp_rank * self.num_layer_vpp_chunk) + offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk) else: self.num_layer_this_model = self.num_layer_per_pp offset = pp_rank * self.num_layer_per_pp @@ -452,7 +462,7 @@ class ParallelLlamaModelRmPadPP(nn.Module): self.layers = nn.ModuleList() for i in range(self.num_layer_this_model): layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=offset + i) - self.layers.add_module(f'{i}', layer) + self.layers.add_module(f"{i}", layer) if post_process: self.norm = ParallelLlamaRMSNorm(config, megatron_config) @@ -469,13 +479,15 @@ class ParallelLlamaModelRmPadPP(nn.Module): forward_step_func""" self.input_tensor = input_tensor - def forward(self, - input_ids: torch.Tensor, - position_ids: 
Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]: + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: """ Args: @@ -501,12 +513,14 @@ class ParallelLlamaModelRmPadPP(nn.Module): hidden_states = self.input_tensor for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer(hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch) + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) hidden_states = layer_outputs @@ -517,21 +531,23 @@ class ParallelLlamaModelRmPadPP(nn.Module): class ParallelLlamaForCausalLMRmPadPP(nn.Module): - - def __init__(self, - config: LlamaConfig, - megatron_config: ModelParallelConfig, - pre_process, - post_process, - share_embeddings_and_output_weights=False): + def __init__( + self, + config: LlamaConfig, + megatron_config: ModelParallelConfig, + pre_process, + post_process, + share_embeddings_and_output_weights=False, + ): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) self.megatron_config = megatron_config - self.model = ParallelLlamaModelRmPadPP(config, - megatron_config=megatron_config, - pre_process=pre_process, - post_process=post_process) - assert share_embeddings_and_output_weights == False, f'Llama Model not supports sharing embedding and output weights' + self.model = ParallelLlamaModelRmPadPP( + config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process + ) + assert share_embeddings_and_output_weights == False, ( + "Llama Model not supports sharing embedding and output weights" + ) self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.vocab_size = config.vocab_size self.pre_process = pre_process @@ -553,14 +569,16 @@ class ParallelLlamaForCausalLMRmPadPP(nn.Module): def _init_head(self, config): column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() if self.megatron_config is not None: - assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + assert column_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) def _forward_head(self, hidden_states): # all_gather from sequence parallel region is performed inside lm_head @@ -592,8 +610,9 @@ class ParallelLlamaForCausalLMRmPadPP(nn.Module): # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model batch_size, sequence_length = input_ids.shape # remove padding here - input_ids_rmpad, indices, cu_seqlens, 
max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), - attention_mask) # (total_nnz, 1) + input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) # pad input_ids to multiple of tp for all tp ranks # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap @@ -602,12 +621,14 @@ class ParallelLlamaForCausalLMRmPadPP(nn.Module): input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad) - outputs = self.model(input_ids=input_ids_rmpad, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch) + outputs = self.model( + input_ids=input_ids_rmpad, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) if self.post_process: hidden_states = outputs @@ -620,8 +641,9 @@ class ParallelLlamaForCausalLMRmPadPP(nn.Module): totol_nnz = cu_seqlens[-1] logits = logits[:totol_nnz] # (total_nnz_padded) # add removed padding back. If input is already rmpad, we let the caller pad_input - logits = pad_input(logits, indices, batch_size, - seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) return CausalLMOutputWithPast( loss=None, @@ -635,11 +657,10 @@ class ParallelLlamaForCausalLMRmPadPP(nn.Module): class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP): - def _init_head(self, config): column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() if self.megatron_config is not None: - assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + assert column_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) # lm_head is effectively the same as sequence parallel diff --git a/verl/models/mcore/loader.py b/verl/models/mcore/loader.py index fefabfeca..17afad6d8 100644 --- a/verl/models/mcore/loader.py +++ b/verl/models/mcore/loader.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import torch import time + +import torch import torch.distributed as dist + from .saver import _megatron_calc_global_rank @@ -26,7 +28,6 @@ def _megatron_calc_layer_map(config): mapping from the global layer index to a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) """ - import megatron from megatron.core import mpu pp_size = mpu.get_pipeline_model_parallel_world_size() @@ -38,8 +39,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + - pp_rank_idx * num_layers_per_model) + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -50,15 +52,14 @@ def _megatron_calc_layer_map(config): def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False): - """Load merged state_dict to sharded Megatron module in training. - """ - import megatron - from megatron.core import mpu - from verl.utils.megatron_utils import print_rank_0, unwrap_model - from megatron.core.transformer.module import Float16Module + """Load merged state_dict to sharded Megatron module in training.""" from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP + from verl.utils.megatron_utils import print_rank_0, unwrap_model + start_time = time.time() def _get_gpt_model(model): @@ -66,9 +67,9 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par def broadcast_params(module): for param in module.parameters(): - torch.distributed.broadcast(param.data, - src=mpu.get_data_parallel_src_rank(), - group=mpu.get_data_parallel_group()) + torch.distributed.broadcast( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) dp_rank = mpu.get_data_parallel_rank() pp_rank = mpu.get_pipeline_model_parallel_rank() @@ -167,8 +168,9 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par requires_grad=False, ) else: - assert (tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) for i in range(tp_size): @@ -213,8 +215,9 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par requires_grad=False, ) else: - assert (tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) for i in range(tp_size): @@ -234,16 +237,16 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par if torch.distributed.get_rank() == src_rank: gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty(config.intermediate_size * 2, - 
config.hidden_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device() + ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -266,9 +269,9 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par requires_grad=False, ) else: - assert ( - tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}" + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}" + ) sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) for i in range(tp_size): @@ -286,7 +289,7 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par tp_size = mpu.get_tensor_model_parallel_world_size() if torch.distributed.get_rank() == src_rank: - assert (q_name in state_dict and k_name in state_dict and v_name in state_dict) + assert q_name in state_dict and k_name in state_dict and v_name in state_dict full_weight_q = state_dict[q_name] full_weight_k = state_dict[k_name] full_weight_v = state_dict[v_name] @@ -302,18 +305,19 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par sizes.append(config.hidden_size) new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device()) for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp] - k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp] + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] num_query_groups_per_partition = models[0].config.num_query_groups // tp_size - new_weight_qkv_this_tp = new_weight_qkv[i * total_size:(i + 1) * total_size] + new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size] q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0) k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0) v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0) total_size_per_head = total_size // num_query_groups_per_partition for j in range(num_query_groups_per_partition): - new_weight_qkv_this_tp[j * total_size_per_head:(j + 1) * total_size_per_head].copy_( - torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0)) + new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * 
total_size_per_head].copy_( + torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0) + ) else: q_size_tp = config.hidden_size // tp_size @@ -324,19 +328,20 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par sizes.append(config.hidden_size) new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device()) for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp] + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head k_part = full_weight_k[start_idx:end_idx] v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv_this_tp = new_weight_qkv[i * total_size:(i + 1) * total_size] + new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size] q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0) k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0) v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0) total_size_per_head = total_size // config.num_attention_heads for j in range(config.num_attention_heads): - new_weight_qkv_this_tp[j * total_size_per_head:(j + 1) * total_size_per_head].copy_( - torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0)) + new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_( + torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0) + ) tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -359,8 +364,9 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par requires_grad=False, ) else: - assert (tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + ) sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) for i in range(tp_size): @@ -409,7 +415,8 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par f"{layer_name}.self_attn.q_proj.bias", f"{layer_name}.self_attn.k_proj.bias", f"{layer_name}.self_attn.v_proj.bias", - bias=True) + bias=True, + ) _broadcast_tp_shard_tensor( sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None, @@ -421,8 +428,11 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par f"{layer_name}.post_attention_layernorm.weight", ) - _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight") + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) _broadcast_tp_shard_tensor( sync_layer.mlp.linear_fc2.weight if dst_pp_rank == pp_rank else None, @@ -445,14 +455,14 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par if is_value_model: # if torch.distributed.get_rank() == src_rank: - if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: 
_broadcast_tensor(lm_head_weight, "lm_head.weight") - elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1: + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: _broadcast_tensor(lm_head_weight, "reward_head.weight") - print_rank_0('load lm_head from value_head weight') + print_rank_0("load lm_head from value_head weight") else: _broadcast_tensor(None, "lm_head.weight") - print_rank_0('fail to match lm_head in value_model') + print_rank_0("fail to match lm_head in value_model") # else: # _broadcast_tensor(lm_head_weight, "lm_head.weight") diff --git a/verl/models/mcore/saver.py b/verl/models/mcore/saver.py index 5598ab2d2..53a19820b 100644 --- a/verl/models/mcore/saver.py +++ b/verl/models/mcore/saver.py @@ -13,23 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from verl.utils.megatron_utils import print_rank_0, unwrap_model -from megatron.core import mpu -from megatron.core.transformer.module import Float16Module -from megatron.core.distributed import DistributedDataParallel as LocalDDP -from torch.nn.parallel import DistributedDataParallel as torchDDP -import torch import time import torch import torch.distributed as dist +from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as LocalDDP +from megatron.core.transformer.module import Float16Module +from torch.nn.parallel import DistributedDataParallel as torchDDP + +from verl.utils.megatron_utils import print_rank_0, unwrap_model -def _megatron_calc_global_rank(tp_rank: int = 0, - dp_rank: int = 0, - pp_rank: int = 0, - cp_rank: int = 0, - ep_rank: int = 0): +def _megatron_calc_global_rank( + tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0 +): """Calculate global rank with support for CP/EP parallelism""" # Get parallel sizes for each dimension @@ -41,8 +39,9 @@ def _megatron_calc_global_rank(tp_rank: int = 0, # Verify total GPU count matches (must be consistent with parallel_state.py) total_size = tp_size * dp_size * pp_size * cp_size - assert total_size == torch.distributed.get_world_size(), \ + assert total_size == torch.distributed.get_world_size(), ( f"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}" + ) # Core calculation logic (corresponds to RankGenerator order parameter) # Assumes default order is "tp-cp-ep-dp-pp" @@ -67,8 +66,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + - pp_rank_idx * num_layers_per_model) + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -121,9 +121,11 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F for i, wrapped_model in enumerate(wrapped_models): models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].decoder.layers - ) == num_layers_per_model, 'len model layers {} not equal to num_layers_per_model {}'.format( - len(models[i].decoder.layers), num_layers_per_model) + assert len(models[i].decoder.layers) == num_layers_per_model, ( + "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].decoder.layers), 
num_layers_per_model + ) + ) state_dict = dict() @@ -261,7 +263,7 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F gate_weight_list = [] up_weight_list = [] for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)] + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] gate_weight_list.append(gate_weight_tp) @@ -321,13 +323,13 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F total_size = q_size_tp + 2 * kv_size_tp for i in range(tp_size): num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size - qkv_part = full_tensor[i * total_size:(i + 1) * total_size] + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] q_size_chunk = q_size_tp // num_query_groups_per_partition kv_size_chunk = kv_size_tp // num_query_groups_per_partition for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk:q_size_chunk + kv_size_chunk] - v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk:] + k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] + v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] q_weight_list.append(q_part) k_weight_list.append(k_part) v_weight_list.append(v_part) @@ -337,13 +339,13 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F total_size = q_size_tp + 2 * kv_size_tp for i in range(tp_size): num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size - qkv_part = full_tensor[i * total_size:(i + 1) * total_size] + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] q_size_chunk = q_size_tp // num_query_groups_per_partition kv_size_chunk = kv_size_tp // num_query_groups_per_partition for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk:q_size_chunk + kv_size_chunk] - v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk:] + k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] + v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] q_weight_list.append(q_part) if i * config.num_key_value_heads % tp_size == 0: k_weight_list.append(k_part) @@ -393,8 +395,10 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F src_pp_rank=src_pp_rank, ) - if getattr(sync_layer.self_attention.linear_qkv, 'bias', - None) is not None and sync_layer.self_attention.linear_qkv.bias.numel() > 0: + if ( + getattr(sync_layer.self_attention.linear_qkv, "bias", None) is not None + and sync_layer.self_attention.linear_qkv.bias.numel() > 0 + ): _broadcast_tp_shard_tensor_qkv( sync_layer.self_attention.linear_qkv.bias, f"{layer_name}.self_attn.q_proj.bias", @@ -416,10 +420,12 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F src_pp_rank=src_pp_rank, ) - _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.linear_fc1.weight, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - src_pp_rank=src_pp_rank) + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.linear_fc1.weight, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + src_pp_rank=src_pp_rank, + ) _broadcast_tp_shard_tensor( 
sync_layer.mlp.linear_fc2.weight, @@ -439,7 +445,7 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F ) if tie_word_embeddings: - print_rank_0(f"tie word embedding skip load lm_head...") + print_rank_0("tie word embedding skip load lm_head...") else: print_rank_0("collecting lm_head...") @@ -459,7 +465,6 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F dist.barrier() torch.cuda.empty_cache() if torch.distributed.get_rank() == 0: - for k, v in state_dict.items(): if dtype != v.dtype: state_dict[k] = v.to(dtype) diff --git a/verl/models/mcore/util.py b/verl/models/mcore/util.py index fcf406d44..8b55d82b4 100644 --- a/verl/models/mcore/util.py +++ b/verl/models/mcore/util.py @@ -13,17 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from verl.utils.megatron import sequence_parallel as sp_utils -from verl.utils.megatron import tensor_parallel as tp_utils import torch -from megatron.core.packed_seq_params import PackedSeqParams from megatron.core import parallel_state as mpu -from verl.utils.megatron_utils import unwrap_model +from megatron.core.packed_seq_params import PackedSeqParams -def preprocess_packed_seqs(input_ids: torch.Tensor, - attention_mask: torch.Tensor, - pre_process: bool = True) -> tuple[torch.Tensor, PackedSeqParams]: +def preprocess_packed_seqs( + input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True +) -> tuple[torch.Tensor, PackedSeqParams]: """ Preprocess packed sequences CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 gets second and second last chunks, and so on), this is for load balancing with causal masking. @@ -55,42 +52,49 @@ def preprocess_packed_seqs(input_ids: torch.Tensor, for i in range(batch_size): if cp_size <= 1: seqlen = seqlens_in_batch[i] - input_ids_rmpad[cu_seqlens_padded[i]:cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]] + input_ids_rmpad[cu_seqlens_padded[i] : cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]] continue seqlen = seqlens_in_batch_padded[i] // cp_size half_seqlen = seqlen // 2 start_idx = cu_seqlens_padded[i] // cp_size # split to 2 chunks d = input_ids[i, attention_mask[i]] - input_ids_rmpad[start_idx:start_idx + half_seqlen] = d[half_seqlen * cp_rank:half_seqlen * (cp_rank + 1)] + input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[ + half_seqlen * cp_rank : half_seqlen * (cp_rank + 1) + ] remain_start = seqlens_in_batch_padded[i] - half_seqlen * (cp_rank + 1) remain_end = seqlens_in_batch_padded[i] - half_seqlen * cp_rank remain_end = min(remain_end, d.shape[0]) remain_len = remain_end - remain_start if remain_len > 0: - input_ids_rmpad[start_idx + half_seqlen:start_idx + half_seqlen + - remain_len] = d[remain_start:remain_end] + input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[ + remain_start:remain_end + ] - packed_seq_params = PackedSeqParams(qkv_format='thd', - cu_seqlens_q=cu_seqlens_padded, - max_seqlen_q=max_seqlen_in_batch, - cu_seqlens_kv=cu_seqlens_padded, - max_seqlen_kv=max_seqlen_in_batch, - cu_seqlens_q_padded=cu_seqlens_padded, - cu_seqlens_kv_padded=cu_seqlens_padded) + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + max_seqlen_q=max_seqlen_in_batch, + cu_seqlens_kv=cu_seqlens_padded, + max_seqlen_kv=max_seqlen_in_batch, + cu_seqlens_q_padded=cu_seqlens_padded, + 
cu_seqlens_kv_padded=cu_seqlens_padded, + ) if pre_process: return input_ids_rmpad.unsqueeze(0), packed_seq_params else: return input_ids, packed_seq_params -def postprocess_packed_seqs(output: torch.Tensor, - packed_seq_params: PackedSeqParams, - attention_mask: torch.Tensor, - batch_size: int, - seq_len: int, - post_process: bool = True) -> torch.Tensor: +def postprocess_packed_seqs( + output: torch.Tensor, + packed_seq_params: PackedSeqParams, + attention_mask: torch.Tensor, + batch_size: int, + seq_len: int, + post_process: bool = True, +) -> torch.Tensor: """ Postprocess packed sequences """ @@ -112,13 +116,13 @@ def postprocess_packed_seqs(output: torch.Tensor, for i in range(batch_size): if cp_size <= 1: s = attention_mask[i].sum().item() - output_new[i, - attention_mask[i]] = output[0][packed_seq_params. - cu_seqlens_q_padded[i]:packed_seq_params.cu_seqlens_q_padded[i] + - s] + output_new[i, attention_mask[i]] = output[0][ + packed_seq_params.cu_seqlens_q_padded[i] : packed_seq_params.cu_seqlens_q_padded[i] + s + ] continue - s_len_padded_chunk = (packed_seq_params.cu_seqlens_q_padded[i + 1] - - packed_seq_params.cu_seqlens_q_padded[i]) // cp_size + s_len_padded_chunk = ( + packed_seq_params.cu_seqlens_q_padded[i + 1] - packed_seq_params.cu_seqlens_q_padded[i] + ) // cp_size half_seqlen = s_len_padded_chunk // 2 s_len = attention_mask[i].sum().item() s_len_padded = s_len_padded_chunk * cp_size @@ -127,20 +131,24 @@ def postprocess_packed_seqs(output: torch.Tensor, o = output_list[j][0] # split to 2 chunks packed_start_idx = packed_seq_params.cu_seqlens_q_padded[i] // cp_size - o0, o1 = o[packed_start_idx:packed_start_idx + - half_seqlen], o[packed_start_idx + half_seqlen:packed_start_idx + s_len_padded_chunk] - tmp[j * half_seqlen:(j + 1) * half_seqlen] = o0 - tmp[s_len_padded - (j + 1) * half_seqlen:s_len_padded - j * half_seqlen] = o1 + o0, o1 = ( + o[packed_start_idx : packed_start_idx + half_seqlen], + o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk], + ) + tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0 + tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1 output_new[i, attention_mask[i]] = tmp[:s_len] return output_new -def remove_left_padding(input_ids: torch.Tensor, - attention_mask: torch.Tensor, - position_ids: torch.Tensor, - sequence_parallel: bool = False, - pre_process: bool = True): +def remove_left_padding( + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + sequence_parallel: bool = False, + pre_process: bool = True, +): """ Remove left padding from input_ids, attention_mask and position_ids return new_input_ids, new_attention_mask, new_position_ids @@ -148,39 +156,42 @@ def remove_left_padding(input_ids: torch.Tensor, assert attention_mask.ndim == 2 assert position_ids.ndim == 2 cp_size = mpu.get_context_parallel_world_size() - assert cp_size == 1, 'Context parallel size without seq_pack is not supported' + assert cp_size == 1, "Context parallel size without seq_pack is not supported" batch_size = input_ids.shape[0] shape = list(input_ids.shape) # batch_size, seq_len,... 
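Aside on the `preprocess_packed_seqs` hunk above: each padded sequence is cut into `2 * cp_size` chunks, and context-parallel rank `r` keeps chunk `r` plus chunk `2 * cp_size - 1 - r`, which balances causal-attention work across ranks. A minimal sketch of just that index arithmetic, independent of Megatron; the helper name `cp_balanced_slices` is invented for illustration and is not part of this patch:

```python
# Illustrative only: mirrors the chunk assignment used by preprocess_packed_seqs above.

def cp_balanced_slices(seqlen_padded: int, cp_size: int, cp_rank: int):
    """Return the two slices of a padded sequence owned by `cp_rank`.

    The sequence is cut into 2 * cp_size equal chunks; rank r takes chunk r
    and chunk (2 * cp_size - 1 - r), balancing work under causal masking.
    """
    half = seqlen_padded // (2 * cp_size)
    front = slice(half * cp_rank, half * (cp_rank + 1))
    back = slice(seqlen_padded - half * (cp_rank + 1), seqlen_padded - half * cp_rank)
    return front, back


if __name__ == "__main__":
    # seqlen 8, cp_size 2: rank 0 -> [0:2] and [6:8], rank 1 -> [2:4] and [4:6].
    for r in range(2):
        print(r, cp_balanced_slices(8, 2, r))
```

The real helper additionally maintains the `cu_seqlens` bookkeeping and the `cp_size <= 1` fast path visible in the diff.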
seq_lens = attention_mask.sum(dim=1) seq_len = seq_lens.max().item() if sequence_parallel: from megatron.core import parallel_state as mpu + sp_world_size = mpu.get_tensor_model_parallel_world_size() pad_size = (sp_world_size - seq_len % sp_world_size) % sp_world_size seq_len = seq_len + pad_size shape[1] = seq_len if pre_process: new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape) - new_attention_mask = torch.zeros(dtype=attention_mask.dtype, - device=attention_mask.device, - size=(batch_size, seq_len)) + new_attention_mask = torch.zeros( + dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len) + ) new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len)) for i in range(batch_size): if pre_process: - new_input_ids[i, :seq_lens[i]] = input_ids[i, attention_mask[i]] - new_attention_mask[i, :seq_lens[i]] = attention_mask[i, attention_mask[i]] - new_position_ids[i, :seq_lens[i]] = position_ids[i, attention_mask[i]] + new_input_ids[i, : seq_lens[i]] = input_ids[i, attention_mask[i]] + new_attention_mask[i, : seq_lens[i]] = attention_mask[i, attention_mask[i]] + new_position_ids[i, : seq_lens[i]] = position_ids[i, attention_mask[i]] if pre_process: return new_input_ids, new_attention_mask, new_position_ids else: return input_ids, new_attention_mask, new_position_ids -def recover_left_padding(result, - attention_mask: torch.Tensor, - original_attention_mask: torch.Tensor, - origin_seqlen: int, - post_process: bool = True): +def recover_left_padding( + result, + attention_mask: torch.Tensor, + original_attention_mask: torch.Tensor, + origin_seqlen: int, + post_process: bool = True, +): """ Recover left padding from result return result diff --git a/verl/models/qwen2/megatron/__init__.py b/verl/models/qwen2/megatron/__init__.py index 26ff01378..be95fbefc 100644 --- a/verl/models/qwen2/megatron/__init__.py +++ b/verl/models/qwen2/megatron/__init__.py @@ -13,12 +13,13 @@ # limitations under the License. from .modeling_qwen2_megatron import ( - # original model with megatron - ParallelQwen2Model, ParallelQwen2ForCausalLM, # rmpad with megatron ParallelQwen2ForCausalLMRmPad, - ParallelQwen2ForValueRmPad, # rmpad with megatron and pipeline parallelism ParallelQwen2ForCausalLMRmPadPP, - ParallelQwen2ForValueRmPadPP) + ParallelQwen2ForValueRmPad, + ParallelQwen2ForValueRmPadPP, + # original model with megatron + ParallelQwen2Model, +) diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py index 87e1561a4..d4397af17 100644 --- a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py +++ b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
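For context on `remove_left_padding` above: real tokens are moved to the front of each row, and when sequence parallelism is enabled the packed length is rounded up to a multiple of the tensor-parallel world size so the sequence can be scattered evenly. A small standalone sketch of that behaviour (the function name is invented for illustration; this is not verl API):

```python
import torch

# Illustrative sketch: move left padding to the right and, optionally, round the
# packed length up to a multiple of the sequence-parallel world size.

def left_to_right_pad(input_ids: torch.Tensor, attention_mask: torch.Tensor, sp_world_size: int = 1):
    seq_lens = attention_mask.sum(dim=1)                                   # real tokens per row
    seq_len = int(seq_lens.max().item())
    seq_len += (sp_world_size - seq_len % sp_world_size) % sp_world_size   # align for SP scatter
    out = torch.zeros(input_ids.shape[0], seq_len, dtype=input_ids.dtype)
    for i in range(input_ids.shape[0]):
        out[i, : seq_lens[i]] = input_ids[i, attention_mask[i].bool()]     # drop the left pad
    return out


ids = torch.tensor([[0, 0, 5, 6, 7], [0, 1, 2, 3, 4]])
mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]])
print(left_to_right_pad(ids, mask, sp_world_size=4))
# tensor([[5, 6, 7, 0],
#         [1, 2, 3, 4]])
```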
-import torch import time -from typing import Dict, Any, Callable, Optional + +import torch import torch.distributed as dist @@ -36,8 +36,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + - pp_rank_idx * num_layers_per_model) + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -47,20 +48,17 @@ def _megatron_calc_layer_map(config): return layer_map -def load_state_dict_to_megatron_qwen2(state_dict, - wrapped_models, - config, - params_dtype, - is_value_model=False, - tie_word_embeddings=False): - """Load merged state_dict to sharded Megatron module in training. - """ - from verl.utils.megatron_utils import print_rank_0, unwrap_model +def load_state_dict_to_megatron_qwen2( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP from megatron.core import mpu from megatron.core.transformer.module import Float16Module - from megatron.core import DistributedDataParallel as LocalDDP from torch.nn.parallel import DistributedDataParallel as torchDDP + from verl.utils.megatron_utils import print_rank_0, unwrap_model + start_time = time.time() def _get_gpt_model(model): @@ -68,9 +66,9 @@ def load_state_dict_to_megatron_qwen2(state_dict, def fetch_params(module): for param in module.parameters(): - torch.distributed.fetch(param.data, - src=mpu.get_data_parallel_src_rank(), - group=mpu.get_data_parallel_group()) + torch.distributed.fetch( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) dp_rank = mpu.get_data_parallel_rank() pp_rank = mpu.get_pipeline_model_parallel_rank() @@ -88,7 +86,9 @@ def load_state_dict_to_megatron_qwen2(state_dict, assert len(wrapped_models) == virtual_pp_size num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f'num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}' + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) models = [None] * len(wrapped_models) @@ -144,16 +144,16 @@ def load_state_dict_to_megatron_qwen2(state_dict, if gate_name in state_dict and up_name in state_dict: gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty(config.intermediate_size * 2, - config.hidden_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device() + ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp] - 
new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) if tensor is not None: @@ -167,7 +167,7 @@ def load_state_dict_to_megatron_qwen2(state_dict, nonlocal mp_group tp_rank = mpu.get_tensor_model_parallel_rank() tp_size = mpu.get_tensor_model_parallel_world_size() - assert (q_name in state_dict and k_name in state_dict and v_name in state_dict) + assert q_name in state_dict and k_name in state_dict and v_name in state_dict full_weight_q = state_dict[q_name] full_weight_k = state_dict[k_name] full_weight_v = state_dict[v_name] @@ -179,40 +179,38 @@ def load_state_dict_to_megatron_qwen2(state_dict, kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size total_size = q_size_tp + 2 * kv_size_tp if not bias: - new_weight_qkv = torch.empty(total_size * tp_size, - config.hidden_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device() + ) else: - new_weight_qkv = torch.empty(total_size * tp_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device() + ) for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp] - k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp] - new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) else: q_size_tp = config.hidden_size // tp_size kv_size_tp = hidden_size_per_head total_size = q_size_tp + 2 * kv_size_tp if not bias: - new_weight_qkv = torch.empty(total_size * tp_size, - config.hidden_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device() + ) else: - new_weight_qkv = torch.empty(total_size * tp_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device() + ) for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp] + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head k_part = full_weight_k[start_idx:end_idx] v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], 
dim=0)) tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) if tensor is not None: @@ -240,9 +238,9 @@ def load_state_dict_to_megatron_qwen2(state_dict, for vpp_rank in range(vpp_size): num_layer_vpp_chunk = num_layer_per_pp // vpp_size num_layer_this_model = num_layer_vpp_chunk - offset = vpp_rank * ( - config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + \ - (mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk) + offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + ( + mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk + ) layer_list.extend(list(range(offset, offset + num_layer_this_model))) else: num_layer_this_model = num_layer_per_pp @@ -255,7 +253,7 @@ def load_state_dict_to_megatron_qwen2(state_dict, dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] print( - f'{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}' + f"{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}" ) gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) @@ -273,11 +271,13 @@ def load_state_dict_to_megatron_qwen2(state_dict, f"{layer_name}.self_attn.v_proj.weight", ) - _fetch_tp_shard_tensor_qkv(sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.bias", - f"{layer_name}.self_attn.k_proj.bias", - f"{layer_name}.self_attn.v_proj.bias", - bias=True) + _fetch_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + bias=True, + ) _fetch_tp_shard_tensor( sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, @@ -290,8 +290,11 @@ def load_state_dict_to_megatron_qwen2(state_dict, f"{layer_name}.post_attention_layernorm.weight", ) - _fetch_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight") + _fetch_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) _fetch_tp_shard_tensor( sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, @@ -315,15 +318,15 @@ def load_state_dict_to_megatron_qwen2(state_dict, lm_head_weight = gpt_model_module.lm_head.weight if is_value_model: - if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: _fetch_tensor(lm_head_weight, "lm_head.weight") - print_rank_0('load lm_head from value_head weight') - elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1: + print_rank_0("load lm_head from value_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: _fetch_tensor(lm_head_weight, "reward_head.weight") - print_rank_0('load lm_head from value_head weight') + print_rank_0("load lm_head from value_head weight") else: _fetch_tensor(None, "lm_head.weight") - print_rank_0('fail to match lm_head in value_model') + print_rank_0("fail to match lm_head in value_model") else: 
_fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight") diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py index 3d14887cd..326768cdd 100644 --- a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py +++ b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch import time -from typing import Dict, Any, Callable, Optional + +import torch import torch.distributed as dist @@ -36,8 +36,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + - pp_rank_idx * num_layers_per_model) + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -47,20 +48,17 @@ def _megatron_calc_layer_map(config): return layer_map -def load_state_dict_to_megatron_qwen2(state_dict, - wrapped_models, - config, - params_dtype, - is_value_model=False, - tie_word_embeddings=False): - """Load merged state_dict to sharded Megatron module in training. - """ - from verl.utils.megatron_utils import print_rank_0, unwrap_model +def load_state_dict_to_megatron_qwen2( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP from megatron.core import mpu from megatron.core.transformer.module import Float16Module - from megatron.core import DistributedDataParallel as LocalDDP from torch.nn.parallel import DistributedDataParallel as torchDDP + from verl.utils.megatron_utils import print_rank_0, unwrap_model + start_time = time.time() def _get_gpt_model(model): @@ -68,9 +66,9 @@ def load_state_dict_to_megatron_qwen2(state_dict, def broadcast_params(module): for param in module.parameters(): - torch.distributed.broadcast(param.data, - src=mpu.get_data_parallel_src_rank(), - group=mpu.get_data_parallel_group()) + torch.distributed.broadcast( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) dp_rank = mpu.get_data_parallel_rank() pp_rank = mpu.get_pipeline_model_parallel_rank() @@ -88,7 +86,9 @@ def load_state_dict_to_megatron_qwen2(state_dict, assert len(wrapped_models) == virtual_pp_size num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f'num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}' + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) models = [None] * len(wrapped_models) @@ -167,8 +167,9 @@ def load_state_dict_to_megatron_qwen2(state_dict, requires_grad=False, ) else: - assert (tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != 
{chunk_shape}" + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) for i in range(tp_size): @@ -213,8 +214,9 @@ def load_state_dict_to_megatron_qwen2(state_dict, requires_grad=False, ) else: - assert (tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) for i in range(tp_size): @@ -234,16 +236,16 @@ def load_state_dict_to_megatron_qwen2(state_dict, if torch.distributed.get_rank() == 0: gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty(config.intermediate_size * 2, - config.hidden_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device() + ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -266,9 +268,9 @@ def load_state_dict_to_megatron_qwen2(state_dict, requires_grad=False, ) else: - assert ( - tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}" + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}" + ) sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) for i in range(tp_size): @@ -286,7 +288,7 @@ def load_state_dict_to_megatron_qwen2(state_dict, tp_size = mpu.get_tensor_model_parallel_world_size() if torch.distributed.get_rank() == 0: - assert (q_name in state_dict and k_name in state_dict and v_name in state_dict) + assert q_name in state_dict and k_name in state_dict and v_name in state_dict full_weight_q = state_dict[q_name] full_weight_k = state_dict[k_name] full_weight_v = state_dict[v_name] @@ -298,42 +300,42 @@ def load_state_dict_to_megatron_qwen2(state_dict, kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size total_size = q_size_tp + 2 * kv_size_tp if not bias: - new_weight_qkv = torch.empty(total_size * tp_size, - config.hidden_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device() + ) else: - new_weight_qkv = torch.empty(total_size * 
tp_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device() + ) for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp] - k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp] - new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], - dim=0)) + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) else: q_size_tp = config.hidden_size // tp_size kv_size_tp = hidden_size_per_head total_size = q_size_tp + 2 * kv_size_tp if not bias: - new_weight_qkv = torch.empty(total_size * tp_size, - config.hidden_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device() + ) else: - new_weight_qkv = torch.empty(total_size * tp_size, - dtype=params_dtype, - device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device() + ) for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp] + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head k_part = full_weight_k[start_idx:end_idx] v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], - dim=0)) + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -356,8 +358,9 @@ def load_state_dict_to_megatron_qwen2(state_dict, requires_grad=False, ) else: - assert (tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + ) sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) for i in range(tp_size): @@ -401,11 +404,13 @@ def load_state_dict_to_megatron_qwen2(state_dict, f"{layer_name}.self_attn.v_proj.weight", ) - _broadcast_tp_shard_tensor_qkv(sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.bias", - f"{layer_name}.self_attn.k_proj.bias", - f"{layer_name}.self_attn.v_proj.bias", - bias=True) + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + bias=True, + ) _broadcast_tp_shard_tensor( sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, @@ -418,8 +423,11 @@ def load_state_dict_to_megatron_qwen2(state_dict, f"{layer_name}.post_attention_layernorm.weight", ) - _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == 
pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight") + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) _broadcast_tp_shard_tensor( sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, @@ -444,15 +452,15 @@ def load_state_dict_to_megatron_qwen2(state_dict, lm_head_weight = gpt_model_module.lm_head.weight if is_value_model: - if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: _broadcast_tensor(lm_head_weight, "lm_head.weight") - print_rank_0('load lm_head from value_head weight') - elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1: + print_rank_0("load lm_head from value_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: _broadcast_tensor(lm_head_weight, "reward_head.weight") - print_rank_0('load lm_head from value_head weight') + print_rank_0("load lm_head from value_head weight") else: _broadcast_tensor(None, "lm_head.weight") - print_rank_0('fail to match lm_head in value_model') + print_rank_0("fail to match lm_head in value_model") else: _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") @@ -463,4 +471,4 @@ def load_state_dict_to_megatron_qwen2(state_dict, broadcast_params(wrapped_model) torch.cuda.empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") \ No newline at end of file + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py index 547a097b1..d669f716a 100644 --- a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py +++ b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
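A note on the QKV handling in the loader hunks above: HF checkpoints store `q_proj`, `k_proj`, and `v_proj` separately, while the fused Megatron `qkv_proj` expects one tensor whose rows are grouped per tensor-parallel rank as `[q_shard_i, k_shard_i, v_shard_i]`. A minimal sketch of that re-packing for the plain per-rank-grouped case (without the per-query-group interleaving used elsewhere); the helper name is invented for illustration and this is not verl API:

```python
import torch

# Illustrative sketch of the fused-QKV packing performed before TP chunking.

def pack_qkv_per_tp_rank(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp_size: int) -> torch.Tensor:
    q_size_tp = q.shape[0] // tp_size   # query rows owned by each TP rank
    kv_size_tp = k.shape[0] // tp_size  # key/value rows owned by each TP rank
    chunks = []
    for i in range(tp_size):
        chunks.append(q[i * q_size_tp : (i + 1) * q_size_tp])
        chunks.append(k[i * kv_size_tp : (i + 1) * kv_size_tp])
        chunks.append(v[i * kv_size_tp : (i + 1) * kv_size_tp])
    fused = torch.cat(chunks, dim=0)
    # Rank i then loads fused.chunk(tp_size, dim=0)[i] into its qkv_proj shard.
    return fused


q = torch.arange(8).reshape(8, 1).float()  # 8 query rows
k = torch.arange(4).reshape(4, 1).float()  # 4 key rows (GQA-style, fewer KV heads)
v = torch.arange(4).reshape(4, 1).float()
print(pack_qkv_per_tp_rank(q, k, v, tp_size=2).squeeze(-1))
# row order: q[0:4], k[0:2], v[0:2], q[4:8], k[2:4], v[2:4]
```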
-from verl.utils.megatron_utils import print_rank_0, unwrap_model -from megatron.core import mpu -from megatron.core.transformer.module import Float16Module -from megatron.core.distributed import DistributedDataParallel as LocalDDP -from torch.nn.parallel import DistributedDataParallel as torchDDP -import torch import time import torch import torch.distributed as dist +from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as LocalDDP +from megatron.core.transformer.module import Float16Module +from torch.nn.parallel import DistributedDataParallel as torchDDP + +from verl.utils.megatron_utils import print_rank_0, unwrap_model def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): @@ -30,8 +30,9 @@ def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int tp_size = mpu.get_tensor_model_parallel_world_size() dp_size = mpu.get_data_parallel_world_size() pp_size = mpu.get_pipeline_model_parallel_world_size() - assert (tp_size * dp_size * pp_size == torch.distributed.get_world_size() - ), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), ( + f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + ) # We only support TP-DP-PP grouping, for correctness when resharding return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank @@ -54,8 +55,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + - pp_rank_idx * num_layers_per_model) + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -107,9 +109,11 @@ def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=Fals for i, wrapped_model in enumerate(wrapped_models): models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].model.layers - ) == num_layers_per_model, 'len model layers {} not equal to num_layers_per_model {}'.format( - len(models[i].model.layers), num_layers_per_model) + assert len(models[i].model.layers) == num_layers_per_model, ( + "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].model.layers), num_layers_per_model + ) + ) state_dict = dict() @@ -247,7 +251,7 @@ def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=Fals gate_weight_list = [] up_weight_list = [] for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)] + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] gate_weight_list.append(gate_weight_tp) @@ -306,10 +310,10 @@ def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=Fals kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size total_size = q_size_tp + 2 * kv_size_tp for i in range(tp_size): - qkv_part = full_tensor[i * total_size:(i + 1) * total_size] + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] q_part = qkv_part[:q_size_tp] - k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp] - 
v_part = qkv_part[q_size_tp + kv_size_tp:total_size] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] q_weight_list.append(q_part) k_weight_list.append(k_part) v_weight_list.append(v_part) @@ -318,10 +322,10 @@ def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=Fals kv_size_tp = hidden_size_per_head total_size = q_size_tp + 2 * kv_size_tp for i in range(tp_size): - qkv_part = full_tensor[i * total_size:(i + 1) * total_size] + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] q_part = qkv_part[:q_size_tp] - k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp] - v_part = qkv_part[q_size_tp + kv_size_tp:total_size] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] q_weight_list.append(q_part) if i * config.num_key_value_heads % tp_size == 0: k_weight_list.append(k_part) @@ -392,10 +396,12 @@ def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=Fals src_pp_rank=src_pp_rank, ) - _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - src_pp_rank=src_pp_rank) + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + src_pp_rank=src_pp_rank, + ) _broadcast_tp_shard_tensor( sync_layer.mlp.down_proj.weight, @@ -415,18 +421,23 @@ def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=Fals ) if tie_word_embeddings: - print_rank_0(f"tie word embedding skip load lm_head...") + print_rank_0("tie word embedding skip load lm_head...") else: print_rank_0("collecting lm_head...") if is_value_model: - _broadcast_tensor(gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, - "lm_head.weight", - src_pp_rank=pp_size - 1) - _broadcast_tensor(gpt_model_module.reward_head.weight if pp_rank == pp_size - 1 and - getattr(gpt_model_module, "reward_weight", None) is not None else None, - "reward_head.weight", - src_pp_rank=pp_size - 1) + _broadcast_tensor( + gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + _broadcast_tensor( + gpt_model_module.reward_head.weight + if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None + else None, + "reward_head.weight", + src_pp_rank=pp_size - 1, + ) else: _broadcast_tp_shard_tensor( @@ -439,7 +450,6 @@ def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=Fals torch.cuda.empty_cache() if torch.distributed.get_rank() == 0: - for k, v in state_dict.items(): if dtype != v.dtype: state_dict[k] = v.to(dtype) diff --git a/verl/models/qwen2/megatron/layers/parallel_attention.py b/verl/models/qwen2/megatron/layers/parallel_attention.py index a4f65e892..f13d2d887 100644 --- a/verl/models/qwen2/megatron/layers/parallel_attention.py +++ b/verl/models/qwen2/megatron/layers/parallel_attention.py @@ -22,31 +22,29 @@ import math from typing import Optional, Tuple import torch +from megatron.core import ModelParallelConfig, tensor_parallel from megatron.core import parallel_state as mpu -from megatron.core import tensor_parallel -from megatron.core import ModelParallelConfig from torch import nn from transformers import Qwen2Config -from verl.models.qwen2.megatron.layers.parallel_linear import QKVParallelLinear +from 
verl.models.qwen2.megatron.layers.parallel_linear import QKVParallelLinear from verl.utils.megatron import tensor_parallel as tp_utils class Qwen2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype()) + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len @@ -99,9 +97,10 @@ class Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding): self.max_seq_len_cached = seq_len if seq_len > self.max_position_embeddings: - base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - - (self.scaling_factor - 1))**(self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) @@ -115,8 +114,8 @@ class Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding): def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @@ -157,24 +156,29 @@ class ParallelQwen2Attention(nn.Module): # assign values after tp tp_size = mpu.get_tensor_model_parallel_world_size() - assert self.num_heads % tp_size == 0, f'num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}' - assert self.num_key_value_heads % tp_size == 0, \ - f'num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}' + assert self.num_heads % tp_size == 0, ( + f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" + ) + assert self.num_key_value_heads % tp_size == 0, ( + f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}" + ) self.num_heads_per_tp = self.num_heads // tp_size self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size self.hidden_size_per_tp = self.hidden_size // tp_size if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads}).") + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() if megatron_config is not None: - assert column_kwargs.get('config', False), 'must have ModelParallelConfig' - assert row_kwargs.get('config', False), 'must have ModelParallelConfig' + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) @@ -188,7 +192,8 @@ class ParallelQwen2Attention(nn.Module): bias=True, gather_output=False, skip_bias_add=False, - **column_kwargs) + **column_kwargs, + ) self.q_size = self.num_heads_per_tp * self.head_dim self.k_size = self.num_key_value_heads_per_tp * self.head_dim @@ -201,7 +206,8 @@ class ParallelQwen2Attention(nn.Module): bias=False, input_is_parallel=True, skip_bias_add=False, - **row_kwargs) + **row_kwargs, + ) self._init_rope() @@ -241,12 +247,14 @@ class ParallelQwen2Attention(nn.Module): if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}") + f" {attn_weights.size()}" + ) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) attn_weights = attn_weights + attention_mask # upcast attention to fp32 @@ -256,7 +264,8 @@ class ParallelQwen2Attention(nn.Module): if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") + f" {attn_output.size()}" + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) @@ -270,10 +279,9 @@ Remove padding Attention - Compatible with sequence parallel """ -from transformers.utils import is_flash_attn_2_available import torch.nn.functional as F - from einops import rearrange +from transformers.utils import is_flash_attn_2_available if is_flash_attn_2_available(): from flash_attn import flash_attn_varlen_func @@ -302,40 +310,34 @@ from flash_attn.layers.rotary import apply_rotary_emb # use flash-attn rotary embeddings with rmpad # cos/sin shoudl be: (seq_length, rotary_dim / 2) def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): - q_embed = apply_rotary_emb(q, - cos, - sin, - interleaved=False, - inplace=False, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen) - k_embed = apply_rotary_emb(k, - cos, - sin, - interleaved=False, - inplace=False, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen) + q_embed = apply_rotary_emb( + q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + k_embed = apply_rotary_emb( + k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) return q_embed, k_embed class ParallelQwen2AttentionRmPad(ParallelQwen2Attention): - - def forward(self, - hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - 
cu_seqlens: torch.Tensor = None, - max_seqlen_in_batch: int = None): + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: torch.Tensor = None, + max_seqlen_in_batch: int = None, + ): total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel if self.megatron_config.sequence_parallel: total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() qkv = self.qkv_proj(hidden_states)[0] - query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], - dim=-1) # (total_nnz, 1, hidden_size) + query_states, key_states, value_states = qkv.split( + [self.q_size, self.k_size, self.v_size], dim=-1 + ) # (total_nnz, 1, hidden_size) if self.megatron_config.sequence_parallel: sequence_parallel_pad = total_nnz - cu_seqlens[-1] @@ -352,13 +354,10 @@ class ParallelQwen2AttentionRmPad(ParallelQwen2Attention): value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) - cos, sin = cos[:, :cos.shape[1] // 2], sin[:, :sin.shape[1] // 2] # flash attn only needs half - query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states, - key_states, - cos, - sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen_in_batch) + cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half + query_states, key_states = apply_rotary_pos_emb_rmpad_flash( + query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch + ) # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices, # It is recommended to use dropout with FA according to the docs diff --git a/verl/models/qwen2/megatron/layers/parallel_decoder.py b/verl/models/qwen2/megatron/layers/parallel_decoder.py index 84562b3bc..4217c2897 100644 --- a/verl/models/qwen2/megatron/layers/parallel_decoder.py +++ b/verl/models/qwen2/megatron/layers/parallel_decoder.py @@ -21,19 +21,18 @@ from typing import Optional, Tuple import torch +from megatron.core import ModelParallelConfig from torch import nn from transformers import Qwen2Config -from megatron.core import ModelParallelConfig + +from verl.utils.megatron_utils import TransformerConfig, convert_config from .parallel_attention import ParallelQwen2Attention, ParallelQwen2AttentionRmPad from .parallel_mlp import ParallelQwen2MLP from .parallel_rmsnorm import ParallelQwen2RMSNorm -from verl.utils.megatron_utils import TransformerConfig, convert_config - class ParallelQwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) @@ -101,7 +100,6 @@ class ParallelQwen2DecoderLayer(nn.Module): class ParallelQwen2DecoderLayerRmPad(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) @@ -120,7 +118,7 @@ class ParallelQwen2DecoderLayerRmPad(nn.Module): sequence_length: int = None, indices: torch.Tensor = None, cu_seqlens: int = None, - max_seqlen_in_batch: int = None + max_seqlen_in_batch: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states # 
(total_nnz // sp, 1, hidden_size) @@ -129,12 +127,14 @@ class ParallelQwen2DecoderLayerRmPad(nn.Module): # Self Attention # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size) # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size) - hidden_states = self.self_attn(hidden_states=hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch) + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) hidden_states = residual + hidden_states diff --git a/verl/models/qwen2/megatron/layers/parallel_linear.py b/verl/models/qwen2/megatron/layers/parallel_linear.py index bfe5cf4e6..e6d4a09f4 100644 --- a/verl/models/qwen2/megatron/layers/parallel_linear.py +++ b/verl/models/qwen2/megatron/layers/parallel_linear.py @@ -13,23 +13,23 @@ # limitations under the License. # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py -from typing import Optional, Tuple from megatron.core import tensor_parallel class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): - - def __init__(self, - input_size, - num_heads, - num_key_value_heads, - head_dim, - *, - bias=True, - gather_output=True, - skip_bias_add=False, - **kwargs): + def __init__( + self, + input_size, + num_heads, + num_key_value_heads, + head_dim, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): # Keep input parameters, and already restrict the head numbers self.input_size = input_size self.q_output_size = num_heads * head_dim @@ -41,34 +41,39 @@ class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): input_size = self.input_size output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim - super().__init__(input_size=input_size, - output_size=output_size, - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - **kwargs) + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): - - def __init__(self, - input_size, - gate_ouput_size, - up_output_size, - *, - bias=True, - gather_output=True, - skip_bias_add=False, - **kwargs): + def __init__( + self, + input_size, + gate_ouput_size, + up_output_size, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): # Keep input parameters, and already restrict the head numbers self.input_size = input_size self.output_size = gate_ouput_size + up_output_size self.gather_output = gather_output self.skip_bias_add = skip_bias_add - super().__init__(input_size=self.input_size, - output_size=self.output_size, - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - **kwargs) + super().__init__( + input_size=self.input_size, + output_size=self.output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) diff --git a/verl/models/qwen2/megatron/layers/parallel_mlp.py b/verl/models/qwen2/megatron/layers/parallel_mlp.py index 48b977119..672908a21 100644 --- a/verl/models/qwen2/megatron/layers/parallel_mlp.py +++ b/verl/models/qwen2/megatron/layers/parallel_mlp.py @@ -18,18 +18,16 @@ # See the License for the specific language governing permissions and # 
limitations under the License. +from megatron.core import ModelParallelConfig, tensor_parallel from megatron.core import parallel_state as mpu -from megatron.core import tensor_parallel -from megatron.core import ModelParallelConfig from torch import nn from transformers.activations import ACT2FN -from verl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear +from verl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear from verl.utils.megatron import tensor_parallel as tp_utils class ParallelQwen2MLP(nn.Module): - def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: super().__init__() self.config = config @@ -41,8 +39,8 @@ class ParallelQwen2MLP(nn.Module): row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() if megatron_config is not None: - assert column_kwargs.get('config', False), 'must have ModelParallelConfig' - assert row_kwargs.get('config', False), 'must have ModelParallelConfig' + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) @@ -59,12 +57,14 @@ class ParallelQwen2MLP(nn.Module): ) self.gate_size = self.intermediate_size // tp_size - self.down_proj = tensor_parallel.RowParallelLinear(input_size=self.intermediate_size, - output_size=self.hidden_size, - bias=False, - input_is_parallel=True, - skip_bias_add=False, - **row_kwargs) + self.down_proj = tensor_parallel.RowParallelLinear( + input_size=self.intermediate_size, + output_size=self.hidden_size, + bias=False, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) self.act_fn = ACT2FN[config.hidden_act] diff --git a/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py b/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py index 726eb7f89..2f4c90dd4 100644 --- a/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py +++ b/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py @@ -13,17 +13,17 @@ # limitations under the License. 
import numbers + import torch +from apex.normalization.fused_layer_norm import fused_rms_norm_affine from megatron.core import ModelParallelConfig from torch import nn from transformers import Qwen2Config -from apex.normalization.fused_layer_norm import fused_rms_norm_affine from verl.utils.megatron import sequence_parallel as sp_utils class ParallelQwen2RMSNorm(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): """ Qwen2RMSNorm is equivalent to T5LayerNorm @@ -39,8 +39,10 @@ class ParallelQwen2RMSNorm(nn.Module): sp_utils.mark_parameter_as_sequence_parallel(self.weight) def forward(self, hidden_states): - return fused_rms_norm_affine(input=hidden_states, - weight=self.weight, - normalized_shape=self.normalized_shape, - eps=self.variance_epsilon, - memory_efficient=True) \ No newline at end of file + return fused_rms_norm_affine( + input=hidden_states, + weight=self.weight, + normalized_shape=self.normalized_shape, + eps=self.variance_epsilon, + memory_efficient=True, + ) diff --git a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py index c15111bd6..b5a33f6b0 100644 --- a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py +++ b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py @@ -17,16 +17,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch Qwen2 model.""" +"""PyTorch Qwen2 model.""" from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint -from megatron.core import tensor_parallel, parallel_state -from megatron.core import ModelParallelConfig -from megatron.core import mpu - +from megatron.core import ModelParallelConfig, mpu, parallel_state, tensor_parallel from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.qwen2.configuration_qwen2 import Qwen2Config @@ -35,7 +32,9 @@ from transformers.models.qwen2.modeling_qwen2 import CausalLMOutputWithPast from verl.utils.megatron import sequence_parallel as sp_utils from verl.utils.megatron import tensor_parallel as tp_utils from verl.utils.megatron_utils import TransformerConfig, convert_config -from .layers import ParallelQwen2DecoderLayer, ParallelQwen2RMSNorm, ParallelQwen2DecoderLayerRmPad + +from .layers import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad, ParallelQwen2RMSNorm + """ TODO: 1. Add weight initialization. Here we need to be careful on TP weight init. 
@@ -87,14 +86,15 @@ class ParallelQwen2Model(nn.Module): self.vocab_size = config.vocab_size embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() if megatron_config is not None: - assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig' + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, - embedding_dim=config.hidden_size, - **embedding_kwargs) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) self.layers = nn.ModuleList( - [ParallelQwen2DecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)]) + [ParallelQwen2DecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) self.norm = ParallelQwen2RMSNorm(config, megatron_config) # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask @@ -111,10 +111,12 @@ class ParallelQwen2Model(nn.Module): if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, - tgt_len=input_shape[-1]).to(inputs_embeds.device) - combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + - combined_attention_mask) + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) return combined_attention_mask @@ -157,7 +159,6 @@ class ParallelQwen2Model(nn.Module): class ParallelQwen2ForCausalLM(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) @@ -166,15 +167,17 @@ class ParallelQwen2ForCausalLM(nn.Module): column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() if megatron_config is not None: - assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + assert column_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) def forward( self, @@ -233,23 +236,26 @@ class ParallelQwen2ModelRmPad(nn.Module): embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() self.megatron_config = megatron_config if megatron_config is not None: - assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig' + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, - embedding_dim=config.hidden_size, - **embedding_kwargs) + self.embed_tokens = 
tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) self.layers = nn.ModuleList( - [ParallelQwen2DecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)]) + [ParallelQwen2DecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) self.norm = ParallelQwen2RMSNorm(config, megatron_config) - def forward(self, - input_ids: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]: + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: """ Args: @@ -268,12 +274,14 @@ class ParallelQwen2ModelRmPad(nn.Module): hidden_states = inputs_embeds for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer(hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch) + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) hidden_states = layer_outputs @@ -283,7 +291,6 @@ class ParallelQwen2ModelRmPad(nn.Module): class ParallelQwen2ForCausalLMRmPad(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) @@ -295,14 +302,16 @@ class ParallelQwen2ForCausalLMRmPad(nn.Module): def _init_head(self, config: Qwen2Config): column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() if self.megatron_config is not None: - assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + assert column_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) def _forward_head(self, hidden_states): # all_gather from sequence parallel region is performed inside lm_head @@ -329,8 +338,9 @@ class ParallelQwen2ForCausalLMRmPad(nn.Module): batch_size, sequence_length = input_ids.shape # remove padding here - input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), - attention_mask) # (total_nnz, 1) + input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) # pad input_ids to multiple of tp for all tp ranks # TODO: for better performance, the sp padding should be removed at each layer. 
Not sure the performance gap @@ -339,12 +349,14 @@ class ParallelQwen2ForCausalLMRmPad(nn.Module): input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad) - outputs = self.model(input_ids=input_ids, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch) + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) hidden_states = outputs @@ -357,8 +369,9 @@ class ParallelQwen2ForCausalLMRmPad(nn.Module): logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # add removed padding back - logits = pad_input(logits, indices, batch_size, - seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) return CausalLMOutputWithPast( loss=None, @@ -370,11 +383,10 @@ class ParallelQwen2ForCausalLMRmPad(nn.Module): class ParallelQwen2ForValueRmPad(ParallelQwen2ForCausalLMRmPad): - def _init_head(self, config): column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() if self.megatron_config is not None: - assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + assert column_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) # lm_head is effectively the same as sequence parallel @@ -423,12 +435,12 @@ class ParallelQwen2ModelRmPadPP(nn.Module): self.megatron_config = megatron_config embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() if megatron_config is not None: - assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig' + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) if pre_process: - self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, - embedding_dim=config.hidden_size, - **embedding_kwargs) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) else: self.embed_tokens = None @@ -441,9 +453,7 @@ class ParallelQwen2ModelRmPadPP(nn.Module): if vpp_size is not None: self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size self.num_layer_this_model = self.num_layer_vpp_chunk - offset = vpp_rank * ( - config.num_hidden_layers // vpp_size) + \ - (pp_rank * self.num_layer_vpp_chunk) + offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk) else: self.num_layer_this_model = self.num_layer_per_pp offset = pp_rank * self.num_layer_per_pp @@ -451,7 +461,7 @@ class ParallelQwen2ModelRmPadPP(nn.Module): self.layers = nn.ModuleList() for i in range(self.num_layer_this_model): layer = ParallelQwen2DecoderLayerRmPad(config, megatron_config, layer_idx=i + offset) - self.layers.add_module(f'{i}', layer) + self.layers.add_module(f"{i}", layer) if post_process: self.norm = ParallelQwen2RMSNorm(config, megatron_config) @@ -468,13 +478,15 @@ class ParallelQwen2ModelRmPadPP(nn.Module): forward_step_func""" self.input_tensor = input_tensor - def forward(self, - input_ids: torch.Tensor, - position_ids: 
Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]: + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: """ Args: @@ -500,12 +512,14 @@ class ParallelQwen2ModelRmPadPP(nn.Module): hidden_states = self.input_tensor for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer(hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch) + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) hidden_states = layer_outputs @@ -516,16 +530,20 @@ class ParallelQwen2ModelRmPadPP(nn.Module): class ParallelQwen2ForCausalLMRmPadPP(nn.Module): - - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pre_process, post_process, - share_embeddings_and_output_weights): + def __init__( + self, + config: Qwen2Config, + megatron_config: ModelParallelConfig, + pre_process, + post_process, + share_embeddings_and_output_weights, + ): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) self.megatron_config = megatron_config - self.model = ParallelQwen2ModelRmPadPP(config, - megatron_config=megatron_config, - pre_process=pre_process, - post_process=post_process) + self.model = ParallelQwen2ModelRmPadPP( + config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process + ) self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.vocab_size = config.vocab_size self.pre_process = pre_process @@ -549,16 +567,17 @@ class ParallelQwen2ForCausalLMRmPadPP(nn.Module): def _init_head(self, config): column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() if self.megatron_config is not None: - assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + assert column_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - skip_weight_param_allocation=self.pre_process and - self.share_embeddings_and_output_weights, - **column_kwargs) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights, + **column_kwargs, + ) def setup_embeddings_and_output_layer(self) -> None: """Sets up embedding layer in first stage and output layer in last stage. 
@@ -640,8 +659,9 @@ class ParallelQwen2ForCausalLMRmPadPP(nn.Module): # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model batch_size, sequence_length = input_ids.shape # remove padding here - input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), - attention_mask) # (total_nnz, 1) + input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) # pad input_ids to multiple of tp for all tp ranks # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap @@ -650,12 +670,14 @@ class ParallelQwen2ForCausalLMRmPadPP(nn.Module): input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad) - outputs = self.model(input_ids=input_ids_rmpad, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch) + outputs = self.model( + input_ids=input_ids_rmpad, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) if self.post_process: hidden_states = outputs @@ -667,8 +689,9 @@ class ParallelQwen2ForCausalLMRmPadPP(nn.Module): totol_nnz = cu_seqlens[-1] logits = logits[:totol_nnz] # (total_nnz_padded) # add removed padding back. If input is already rmpad, we let the caller pad_input - logits = pad_input(logits, indices, batch_size, - seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) return CausalLMOutputWithPast( loss=None, @@ -682,11 +705,10 @@ class ParallelQwen2ForCausalLMRmPadPP(nn.Module): class ParallelQwen2ForValueRmPadPP(ParallelQwen2ForCausalLMRmPadPP): - def _init_head(self, config): column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() if self.megatron_config is not None: - assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + assert column_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) # lm_head is effectively the same as sequence parallel diff --git a/verl/models/registry.py b/verl/models/registry.py index dfe777f8b..6fa8effd4 100644 --- a/verl/models/registry.py +++ b/verl/models/registry.py @@ -20,18 +20,23 @@ import torch.nn as nn # Supported models in Megatron-LM # Architecture -> (module, class). 
_MODELS = { - "LlamaForCausalLM": - ("llama", ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad")), - "Qwen2ForCausalLM": - ("qwen2", ("ParallelQwen2ForCausalLMRmPadPP", "ParallelQwen2ForValueRmPadPP", "ParallelQwen2ForCausalLMRmPad")), - "MistralForCausalLM": ("mistral", ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP", - "ParallelMistralForCausalLMRmPad")) + "LlamaForCausalLM": ( + "llama", + ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad"), + ), + "Qwen2ForCausalLM": ( + "qwen2", + ("ParallelQwen2ForCausalLMRmPadPP", "ParallelQwen2ForValueRmPadPP", "ParallelQwen2ForCausalLMRmPad"), + ), + "MistralForCausalLM": ( + "mistral", + ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP", "ParallelMistralForCausalLMRmPad"), + ), } # return model class class ModelRegistry: - @staticmethod def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.Module]]: if model_arch not in _MODELS: diff --git a/verl/models/transformers/llama.py b/verl/models/transformers/llama.py index 886ccb67d..b172791df 100644 --- a/verl/models/transformers/llama.py +++ b/verl/models/transformers/llama.py @@ -12,35 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch -from typing import Optional, Tuple, Callable import sys -if sys.version_info >= (3, 11): - from typing import Unpack -else: - from typing_extensions import Unpack +from typing import Callable, Optional, Tuple + +import torch + +if sys.version_info >= (3, 11): + pass +else: + pass -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from transformers.cache_utils import Cache -from transformers.utils import logging from transformers.modeling_flash_attention_utils import _flash_attention_forward -from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, \ - get_ulysses_sequence_parallel_world_size, validate_ulysses_config +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb +from transformers.utils import logging + +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) logger = logging.get_logger(__name__) + def llama_flash_attn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Adapted from transformers 4.47.1 to support Ulysses sequence parallelism. 
@@ -83,7 +91,8 @@ def llama_flash_attn_forward( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory.") + "removed and `position_embeddings` will be mandatory." + ) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings @@ -121,7 +130,8 @@ def llama_flash_attn_forward( logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}.") + f" {target_dtype}." + ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) @@ -168,8 +178,9 @@ def llama_attn_forward( NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0. """ - from transformers.models.llama.modeling_llama import eager_attention_forward from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from transformers.models.llama.modeling_llama import eager_attention_forward + bsz, q_len, _ = hidden_states.shape query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) diff --git a/verl/models/transformers/monkey_patch.py b/verl/models/transformers/monkey_patch.py index 6a6f47e8c..e2c3c58ab 100644 --- a/verl/models/transformers/monkey_patch.py +++ b/verl/models/transformers/monkey_patch.py @@ -14,18 +14,19 @@ """ Apply monkey-patch function to models """ + import sys from typing import Optional import torch -from transformers.modeling_utils import PreTrainedModel from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_utils import PreTrainedModel from verl.utils.ulysses import ( gather_heads_scatter_seq, gather_seq_scatter_heads, - get_ulysses_sequence_parallel_world_size, get_ulysses_sequence_parallel_group, + get_ulysses_sequence_parallel_world_size, ) @@ -92,12 +93,9 @@ def _ulysses_flash_attention_forward( position_ids = torch.concat(position_ids_list, dim=-1) # (bsz, seq_len, n_head/n, head_dim) - attn_output = _flash_attention_forward(query_states, - key_states, - value_states, - *args, - position_ids=position_ids, - **kwargs) + attn_output = _flash_attention_forward( + query_states, key_states, value_states, *args, position_ids=position_ids, **kwargs + ) ########## AlltoAll for Ulysses ########## if ulysses_sp_size > 1: @@ -112,17 +110,20 @@ def apply_monkey_patch(model: PreTrainedModel, ulysses_sp_size: int): module = sys.modules[model.__module__] num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads - assert num_attention_heads % ulysses_sp_size == 0, \ + assert num_attention_heads % ulysses_sp_size == 0, ( f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}" + ) assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, ( f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}" f"or vise versa. Upon ulysses_sp_size % num_key_value_heads == 0," - f"kv heads are repeated to ensure correctness.") + f"kv heads are repeated to ensure correctness." 
+ ) # TODO: VLM models only, unify monkey patch to LLM models. if model.config.model_type in ("qwen2_vl", "qwen2_5_vl"): # patch remove padding for qwen2vl mrope - from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward - from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2 + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 + + from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward Qwen2_5_VLFlashAttention2.forward = ulysses_flash_attn_forward @@ -136,16 +137,18 @@ def apply_monkey_patch(model: PreTrainedModel, ulysses_sp_size: int): else: # transformers>=4.48.0 from transformers.integrations import flash_attention + flash_attention._flash_attention_forward = _ulysses_flash_attention_forward print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}") -from functools import lru_cache -from packaging import version import importlib.metadata +from functools import lru_cache + +from packaging import version -@lru_cache() +@lru_cache def is_transformers_version_in_range(min_version: str, max_version: str) -> bool: try: # Get the installed version of the transformers library diff --git a/verl/models/transformers/qwen2.py b/verl/models/transformers/qwen2.py index 63d9ae98b..3c4be2d35 100644 --- a/verl/models/transformers/qwen2.py +++ b/verl/models/transformers/qwen2.py @@ -12,30 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch -from typing import Optional, Tuple, Callable +from typing import Callable, Optional, Tuple -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv +import torch from transformers.cache_utils import Cache -from transformers.utils import logging from transformers.modeling_flash_attention_utils import _flash_attention_forward -from transformers.processing_utils import Unpack -from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, \ - get_ulysses_sequence_parallel_world_size, validate_ulysses_config +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv +from transformers.utils import logging + +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) logger = logging.get_logger(__name__) def qwen2_flash_attn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ): """ Adapted from transformers 4.47.1 to support Ulysses sequence parallelism. 
@@ -70,7 +74,8 @@ def qwen2_flash_attn_forward( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory.") + "removed and `position_embeddings` will be mandatory." + ) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings @@ -101,7 +106,8 @@ def qwen2_flash_attn_forward( logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}.") + f" {target_dtype}." + ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) @@ -112,8 +118,11 @@ def qwen2_flash_attn_forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and - self.layer_idx >= self.config.max_window_layers): + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): sliding_window = self.config.sliding_window else: sliding_window = None @@ -160,6 +169,7 @@ def qwen2_attn_forward( NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0. """ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + bsz, q_len, _ = hidden_states.shape hidden_shape = (bsz, q_len, -1, self.head_dim) @@ -189,11 +199,15 @@ def qwen2_attn_forward( key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) sliding_window = None - if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and - self.layer_idx >= self.config.max_window_layers): + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): sliding_window = self.config.sliding_window from transformers.models.qwen2.modeling_qwen2 import eager_attention_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py index 718b9ca6f..888a5272a 100644 --- a/verl/models/transformers/qwen2_vl.py +++ b/verl/models/transformers/qwen2_vl.py @@ -12,14 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Tuple import inspect -import torch import os -from transformers.utils import is_flash_attn_greater_or_equal +from typing import Optional, Tuple + +import torch from transformers.modeling_flash_attention_utils import _flash_attention_forward -from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, \ - get_ulysses_sequence_parallel_world_size, validate_ulysses_config +from transformers.utils import is_flash_attn_greater_or_equal + +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) try: from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -132,17 +138,20 @@ def get_rope_index( return position_ids -def prepare_fa2_from_position_ids(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - position_ids: torch.Tensor): +def prepare_fa2_from_position_ids( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor +): query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) value = value.view(-1, value.size(-2), value.size(-1)) position_ids = position_ids.flatten() indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) - cu_seqlens = torch.cat(( - indices_q[position_ids == 0], - torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), - )) + cu_seqlens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length)) @@ -169,8 +178,9 @@ def flash_attention_forward( causal = is_causal and query_length != 1 # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). 
- use_sliding_windows = (_flash_supports_window_size and sliding_window is not None and - key_states.shape[1] > sliding_window) + use_sliding_windows = ( + _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + ) flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} if is_flash_attn_greater_or_equal("2.4.1"): @@ -181,7 +191,8 @@ def flash_attention_forward( if position_ids is not None and query_length != 1 and not (torch.diff(position_ids[0], dim=-1) >= 0).all(): batch_size = query_states.size(0) query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( - query_states, key_states, value_states, position_ids[0]) # remove channel dimension + query_states, key_states, value_states, position_ids[0] + ) # remove channel dimension cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens attn_output = flash_attn_varlen_func( @@ -223,7 +234,7 @@ def ulysses_flash_attn_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.Tensor, None, None]: - from transformers.models.qwen2_vl.modeling_qwen2_vl import repeat_kv, apply_multimodal_rotary_pos_emb + from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb, repeat_kv bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size) @@ -255,8 +266,9 @@ def ulysses_flash_attn_forward( else: cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb(query_states, key_states, cos, sin, - self.rope_scaling["mrope_section"]) + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) dropout_rate = 0.0 if not self.training else self.attention_dropout # Reashape to the expected shape for Flash Attention @@ -264,8 +276,11 @@ def ulysses_flash_attn_forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and - self.layer_idx >= self.config.max_window_layers): + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): sliding_window = self.config.sliding_window else: sliding_window = None diff --git a/verl/models/weight_loader_registry.py b/verl/models/weight_loader_registry.py index a70076126..36302293d 100644 --- a/verl/models/weight_loader_registry.py +++ b/verl/models/weight_loader_registry.py @@ -14,29 +14,31 @@ def get_weight_loader(arch: str): - from verl.models.llama.megatron.checkpoint_utils.llama_loader import load_state_dict_to_megatron_llama - from verl.models.qwen2.megatron.checkpoint_utils.qwen2_loader import load_state_dict_to_megatron_qwen2 from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel + _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = { - 'LlamaForCausalLM': load_state_dict_to_megatron_gptmodel, - 'Qwen2ForCausalLM': load_state_dict_to_megatron_gptmodel, + "LlamaForCausalLM": load_state_dict_to_megatron_gptmodel, + "Qwen2ForCausalLM": load_state_dict_to_megatron_gptmodel, } if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY: return _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY[arch] - raise 
ValueError(f"Model architectures {arch} loader are not supported for now. " - f"Supported architectures: {_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys()}") + raise ValueError( + f"Model architectures {arch} loader are not supported for now. " + f"Supported architectures: {_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys()}" + ) def get_weight_saver(arch: str): - from verl.models.llama.megatron.checkpoint_utils.llama_saver import merge_megatron_ckpt_llama - from verl.models.qwen2.megatron.checkpoint_utils.qwen2_saver import merge_megatron_ckpt_qwen2 from verl.models.mcore.saver import merge_megatron_ckpt_gptmodel + _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY = { - 'LlamaForCausalLM': merge_megatron_ckpt_gptmodel, - 'Qwen2ForCausalLM': merge_megatron_ckpt_gptmodel, + "LlamaForCausalLM": merge_megatron_ckpt_gptmodel, + "Qwen2ForCausalLM": merge_megatron_ckpt_gptmodel, } if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY: return _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY[arch] - raise ValueError(f"Model architectures {arch} saver are not supported for now. " - f"Supported architectures: {_MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY.keys()}") + raise ValueError( + f"Model architectures {arch} saver are not supported for now. " + f"Supported architectures: {_MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY.keys()}" + ) diff --git a/verl/protocol.py b/verl/protocol.py index 481d53c0d..dcade9e00 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -16,22 +16,22 @@ Implement base data transfer protocol between any two functions, modules. We can subclass Protocol to define more detailed batch info with specific keys """ -import pickle -import numpy as np -import pandas as pd import copy +import pickle from dataclasses import dataclass, field from typing import Callable, Dict, List, Union -import torch +import numpy as np +import pandas as pd import tensordict +import torch from packaging import version from tensordict import TensorDict -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from verl.utils.py_functional import union_two_dict -__all__ = ['DataProto', 'union_tensor_dict'] +__all__ = ["DataProto", "union_tensor_dict"] try: tensordict.set_lazy_legacy(False).set() @@ -39,7 +39,7 @@ except: pass -def pad_dataproto_to_divisor(data: 'DataProto', size_divisor: int): +def pad_dataproto_to_divisor(data: "DataProto", size_divisor: int): """Pad a DataProto to size divisible by size_divisor Args: @@ -49,7 +49,7 @@ def pad_dataproto_to_divisor(data: 'DataProto', size_divisor: int): data: (DataProto): the padded DataProto pad_size (int) """ - assert isinstance(data, DataProto), 'data must be a DataProto' + assert isinstance(data, DataProto), "data must be a DataProto" if len(data) % size_divisor != 0: pad_size = size_divisor - len(data) % size_divisor padding_protos = [] @@ -65,7 +65,7 @@ def pad_dataproto_to_divisor(data: 'DataProto', size_divisor: int): return data_padded, pad_size -def unpad_dataproto(data: 'DataProto', pad_size): +def unpad_dataproto(data: "DataProto", pad_size): if pad_size != 0: data = data[:-pad_size] return data @@ -73,14 +73,16 @@ def unpad_dataproto(data: 'DataProto', pad_size): def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict: """Union two tensordicts.""" - assert tensor_dict1.batch_size == tensor_dict2.batch_size, \ - f'Two tensor dict must have identical batch size. 
Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}' + assert tensor_dict1.batch_size == tensor_dict2.batch_size, ( + f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}" + ) for key in tensor_dict2.keys(): if key not in tensor_dict1.keys(): tensor_dict1[key] = tensor_dict2[key] else: - assert tensor_dict1[key].equal(tensor_dict2[key]), \ - f'{key} in tensor_dict1 and tensor_dict2 are not the same object' + assert tensor_dict1[key].equal(tensor_dict2[key]), ( + f"{key} in tensor_dict1 and tensor_dict2 are not the same object" + ) return tensor_dict1 @@ -91,8 +93,9 @@ def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str assert isinstance(tensor_dict2[key], np.ndarray) assert isinstance(tensor_dict1[key], np.ndarray) # to properly deal with nan and object type - assert pd.DataFrame(tensor_dict2[key]).equals(pd.DataFrame(tensor_dict1[key])), \ - f'{key} in tensor_dict1 and tensor_dict2 are not the same object' + assert pd.DataFrame(tensor_dict2[key]).equals(pd.DataFrame(tensor_dict1[key])), ( + f"{key} in tensor_dict1 and tensor_dict2 are not the same object" + ) tensor_dict1[key] = val return tensor_dict1 @@ -110,7 +113,7 @@ def list_of_dict_to_dict_of_list(list_of_dict: list[dict]): return output -def fold_batch_dim(data: 'DataProto', new_batch_size): +def fold_batch_dim(data: "DataProto", new_batch_size): """ Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx] """ @@ -130,7 +133,7 @@ def fold_batch_dim(data: 'DataProto', new_batch_size): return DataProto(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info) -def unfold_batch_dim(data: 'DataProto', batch_dims=2): +def unfold_batch_dim(data: "DataProto", batch_dims=2): """ Unfold the first n dims as new batch dim """ @@ -149,7 +152,7 @@ def unfold_batch_dim(data: 'DataProto', batch_dims=2): return DataProto(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info) -def collate_fn(x: list['DataProtoItem']): +def collate_fn(x: list["DataProtoItem"]): batch = [] non_tensor_batch = [] for data in x: @@ -178,6 +181,7 @@ class DataProto: TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the same batch size should be put inside batch. """ + batch: TensorDict = None non_tensor_batch: Dict = field(default_factory=dict) meta_info: Dict = field(default_factory=dict) @@ -198,7 +202,7 @@ class DataProto: def __getitem__(self, item): """ Enhanced indexing for DataProto objects. 
- + Args: item: Can be one of: - int: A single index @@ -206,7 +210,7 @@ class DataProto: - list: A list of indices - numpy.ndarray: An array of indices - torch.Tensor: A tensor of indices - + Returns: DataProto: For all indexing types except single integers DataProtoItem: Only for single integer indices @@ -231,8 +235,9 @@ class DataProto: def __getstate__(self): import io + buffer = io.BytesIO() - if version.parse(tensordict.__version__) >= version.parse('0.5.0') and self.batch is not None: + if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None: self.batch = self.batch.contiguous() self.batch = self.batch.consolidate() torch.save(self.batch, buffer) @@ -241,22 +246,23 @@ class DataProto: def __setstate__(self, data): import io + batch_deserialized_bytes, non_tensor_batch, meta_info = data batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes) - batch = torch.load(batch_deserialized, - weights_only=False, - map_location='cpu' if not torch.cuda.is_available() else None) + batch = torch.load( + batch_deserialized, weights_only=False, map_location="cpu" if not torch.cuda.is_available() else None + ) self.batch = batch self.non_tensor_batch = non_tensor_batch self.meta_info = meta_info def save_to_disk(self, filepath): - with open(filepath, 'wb') as f: + with open(filepath, "wb") as f: pickle.dump(self, f) @staticmethod - def load_from_disk(filepath) -> 'DataProto': - with open(filepath, 'rb') as f: + def load_from_disk(filepath) -> "DataProto": + with open(filepath, "rb") as f: data = pickle.load(f) return data @@ -271,10 +277,10 @@ class DataProto: size_of_numpy_array /= 1024**3 size_of_tensordict /= 1024**3 - message = f'Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB' + message = f"Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB" if prefix: - message = f'{prefix}, ' + message + message = f"{prefix}, " + message print(message) def check_consistency(self): @@ -282,7 +288,7 @@ class DataProto: We expose this function as a public one so that user can call themselves directly """ if self.batch is not None: - assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1' + assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1" if self.non_tensor_batch is not None: for key, val in self.non_tensor_batch.items(): @@ -290,15 +296,16 @@ class DataProto: if self.batch is not None and len(self.non_tensor_batch) != 0: # TODO: we can actually lift this restriction if needed - assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1 when non_tensor_batch is not empty.' + assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1 when non_tensor_batch is not empty." 
batch_size = self.batch.batch_size[0] for key, val in self.non_tensor_batch.items(): - assert isinstance( - val, np.ndarray - ), f'data in the non_tensor_batch must be a numpy.array with dtype=object, but for {key=}, got {type(val)=}' - assert val.shape[ - 0] == batch_size, f'key {key} length {len(val)} is not equal to batch size {batch_size}' + assert isinstance(val, np.ndarray), ( + f"data in the non_tensor_batch must be a numpy.array with dtype=object, but for {key=}, got {type(val)=}" + ) + assert val.shape[0] == batch_size, ( + f"key {key} length {len(val)} is not equal to batch size {batch_size}" + ) @classmethod def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None): @@ -311,7 +318,7 @@ class DataProto: elif isinstance(val, np.ndarray): non_tensors[key] = val else: - raise ValueError(f'Unsupported type in data {type(val)}') + raise ValueError(f"Unsupported type in data {type(val)}") return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) @@ -321,10 +328,10 @@ class DataProto: 1. All the tensor in tensors have the same dim0 2. Only dim0 is the batch dim """ - assert len(tensors) > 0, 'tensors must not be empty' - assert num_batch_dims > 0, 'num_batch_dims must be greater than zero' + assert len(tensors) > 0, "tensors must not be empty" + assert num_batch_dims > 0, "num_batch_dims must be greater than zero" if non_tensors is not None: - assert num_batch_dims == 1, 'only support num_batch_dims=1 when non_tensors is not None.' + assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None." if meta_info is None: meta_info = {} @@ -342,8 +349,9 @@ class DataProto: pivot_key = key else: current_batch = tensor.shape[:num_batch_dims] - assert batch_size == current_batch, \ - f'Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. Got {pivot_key} has {batch_size}, {key} has {current_batch}' + assert batch_size == current_batch, ( + f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. Got {pivot_key} has {batch_size}, {key} has {current_batch}" + ) for key, val in non_tensors.items(): non_tensors[key] = np.array(val, dtype=object) @@ -351,7 +359,7 @@ class DataProto: tensor_dict = TensorDict(source=tensors, batch_size=batch_size) return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info) - def to(self, device) -> 'DataProto': + def to(self, device) -> "DataProto": """move the batch to device Args: @@ -365,7 +373,7 @@ class DataProto: self.batch = self.batch.to(device) return self - def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> 'DataProto': + def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> "DataProto": """Select a subset of the DataProto via batch_keys and meta_info_keys Args: @@ -403,10 +411,10 @@ class DataProto: def select_idxs(self, idxs): """ Select specific indices from the DataProto. 
- + Args: idxs (torch.Tensor or numpy.ndarray or list): Indices to select - + Returns: DataProto: A new DataProto containing only the selected indices """ @@ -422,10 +430,10 @@ class DataProto: if self.batch is not None: # Use TensorDict's built-in indexing capabilities - selected_batch = TensorDict(source={ - key: tensor[idxs_torch] for key, tensor in self.batch.items() - }, - batch_size=(idxs_torch.shape[0],)) + selected_batch = TensorDict( + source={key: tensor[idxs_torch] for key, tensor in self.batch.items()}, + batch_size=(idxs_torch.shape[0],), + ) else: selected_batch = None @@ -439,27 +447,27 @@ class DataProto: """ Slice the DataProto and return a new DataProto object. This is an improved version of direct slicing which returns a DataProtoItem. - + Args: start (int, optional): Start index. Defaults to None (start from beginning). end (int, optional): End index (exclusive). Defaults to None (go to end). step (int, optional): Step size. Defaults to None (step=1). - + Returns: DataProto: A new DataProto containing the sliced data - + Examples: # Using the slice method directly sliced_data = data_proto.slice(10, 20) - + # Using enhanced indexing (returns DataProto) sliced_data = data_proto[10:20] sliced_data = data_proto[::2] # Every other element - + # Using list indexing (returns DataProto) indices = [1, 5, 10] selected_data = data_proto[indices] - + # Single index still returns DataProtoItem single_item = data_proto[5] """ @@ -481,7 +489,7 @@ class DataProto: # Return a new DataProto object return DataProto(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info) - def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> 'DataProto': + def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> "DataProto": """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys` Args: @@ -513,7 +521,7 @@ class DataProto: meta_info[key] = self.meta_info.pop(key) return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) - def rename(self, old_keys=None, new_keys=None) -> 'DataProto': + def rename(self, old_keys=None, new_keys=None) -> "DataProto": """ Note that this function only rename the key in the batch """ @@ -525,7 +533,7 @@ class DataProto: elif isinstance(keys, list): pass else: - raise TypeError(f'keys must be a list or a string, but got {type(keys)}') + raise TypeError(f"keys must be a list or a string, but got {type(keys)}") return keys old_keys = validate_input(old_keys) @@ -533,13 +541,14 @@ class DataProto: if len(new_keys) != len(old_keys): raise ValueError( - f'new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}') + f"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}" + ) self.batch.rename_key_(tuple(old_keys), tuple(new_keys)) return self - def union(self, other: 'DataProto') -> 'DataProto': + def union(self, other: "DataProto") -> "DataProto": """Union with another DataProto. Union batch and meta_info separately. 
Throw an error if @@ -583,11 +592,9 @@ class DataProto: generator = None assert isinstance(dataloader_kwargs, Dict) - train_dataloader = DataLoader(dataset=self, - batch_size=mini_batch_size, - collate_fn=collate_fn, - generator=generator, - **dataloader_kwargs) + train_dataloader = DataLoader( + dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs + ) def get_data(): for _ in range(epochs): @@ -597,7 +604,7 @@ class DataProto: return iter(get_data()) - def chunk(self, chunks: int) -> List['DataProto']: + def chunk(self, chunks: int) -> List["DataProto"]: """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split. Args: @@ -606,8 +613,9 @@ class DataProto: Returns: List[DataProto]: a list of DataProto after splitting """ - assert len( - self) % chunks == 0, f'only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}.' + assert len(self) % chunks == 0, ( + f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}." + ) if self.batch is not None: batch_lst = self.batch.chunk(chunks=chunks, dim=0) @@ -625,12 +633,13 @@ class DataProto: output = [] for i in range(chunks): output.append( - DataProto(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info)) + DataProto(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info) + ) return output @staticmethod - def concat(data: List['DataProto']) -> 'DataProto': + def concat(data: List["DataProto"]) -> "DataProto": """Concat a list of DataProto. The batch is concatenated among dim=0. The meta_info is assumed to be identical and will use the first one. @@ -723,16 +732,17 @@ class DataProtoFuture: - DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any operation on the DataProtoFuture in driver. 
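Reviewing aid for the `DataProto.chunk` / `DataProto.concat` hunks above (the `DataProtoFuture` below mirrors the same split/merge contract, just over Ray futures): a minimal usage sketch, assuming the `verl.protocol.DataProto` API exactly as shown in this diff; the tensor contents are made up for illustration.

```python
# Illustrative only: assumes verl's DataProto API as reformatted in this diff.
import torch
from verl.protocol import DataProto

# Build a DataProto from a dict of tensors whose dim 0 is the batch dim.
data = DataProto.from_dict(tensors={"obs": torch.randn(8, 4)}, meta_info={"step": 0})

# chunk() asserts len(data) % chunks == 0, then splits along dim 0.
parts = data.chunk(chunks=4)              # 4 DataProto objects of batch size 2
assert all(len(p) == 2 for p in parts)

# concat() stitches the chunks back together along dim 0;
# meta_info is taken from the first element.
restored = DataProto.concat(parts)
assert len(restored) == len(data)
```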
""" + collect_fn: Callable futures: List[ray.ObjectRef] dispatch_fn: Callable = None @staticmethod - def concat(data: List[ray.ObjectRef]) -> 'DataProtoFuture': + def concat(data: List[ray.ObjectRef]) -> "DataProtoFuture": output = DataProtoFuture(collect_fn=DataProto.concat, futures=data) return output - def chunk(self, chunks: int) -> List['DataProtoFuture']: + def chunk(self, chunks: int) -> List["DataProtoFuture"]: from functools import partial arg_future_lst = [] @@ -741,9 +751,9 @@ class DataProtoFuture: def dispatch_fn(x, i, chunks): return x.chunk(chunks=chunks)[i] - arg_future = DataProtoFuture(collect_fn=self.collect_fn, - dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), - futures=self.futures) + arg_future = DataProtoFuture( + collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures + ) arg_future_lst.append(arg_future) return arg_future_lst @@ -757,9 +767,10 @@ class DataProtoFuture: return output -from verl.utils.torch_functional import allgather_dict_tensors import torch.distributed +from verl.utils.torch_functional import allgather_dict_tensors + def all_gather_data_proto(data: DataProto, process_group): # Note that this is an inplace operator just like torch.distributed.all_gather diff --git a/verl/single_controller/__init__.py b/verl/single_controller/__init__.py index cb11fa555..791f55637 100644 --- a/verl/single_controller/__init__.py +++ b/verl/single_controller/__init__.py @@ -17,10 +17,10 @@ import os version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) # Note(haibin.lin): single_controller.__version__ is deprecated -with open(os.path.join(os.path.join(version_folder, os.pardir), 'version/version')) as f: +with open(os.path.join(os.path.join(version_folder, os.pardir), "version/version")) as f: __version__ = f.read().strip() from . import base from .base import * -__all__ = base.__all__ \ No newline at end of file +__all__ = base.__all__ diff --git a/verl/single_controller/base/__init__.py b/verl/single_controller/base/__init__.py index d91f827df..b24bd9942 100644 --- a/verl/single_controller/base/__init__.py +++ b/verl/single_controller/base/__init__.py @@ -13,6 +13,6 @@ # limitations under the License. from .worker import Worker -from .worker_group import WorkerGroup, ClassWithInitArgs, ResourcePool +from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup -__all__ = ['Worker', 'WorkerGroup', 'ClassWithInitArgs', 'ResourcePool'] \ No newline at end of file +__all__ = ["Worker", "WorkerGroup", "ClassWithInitArgs", "ResourcePool"] diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py index 0f57f8e0a..9d95c77d4 100644 --- a/verl/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -14,12 +14,13 @@ from enum import Enum from functools import wraps -from typing import Dict, List, Tuple from types import FunctionType +from typing import Dict, List, Tuple + from verl.protocol import DataProtoFuture # here we add a magic number of avoid user-defined function already have this attribute -MAGIC_ATTR = 'attrs_3141562937' +MAGIC_ATTR = "attrs_3141562937" class Dispatch(Enum): @@ -44,6 +45,7 @@ class Execute(Enum): def _split_args_kwargs_data_proto(chunks, *args, **kwargs): from verl.protocol import DataProto, DataProtoFuture + splitted_args = [] for arg in args: assert isinstance(arg, (DataProto, DataProtoFuture)) @@ -76,8 +78,10 @@ def dispatch_megatron_compute(worker_group, *args, **kwargs): User passes in dp data. 
The data is dispatched to all tp/pp ranks with the same dp """ from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup - assert isinstance(worker_group, - MegatronWorkerGroup), f'worker_group must be MegatronWorkerGroup, Got {type(worker_group)}' + + assert isinstance(worker_group, MegatronWorkerGroup), ( + f"worker_group must be MegatronWorkerGroup, Got {type(worker_group)}" + ) all_args = [] for arg in args: @@ -105,6 +109,7 @@ def collect_megatron_compute(worker_group, output): Only collect the data from the tp=0 and pp=last and every dp ranks """ from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + assert isinstance(worker_group, MegatronWorkerGroup) output_in_dp = [] pp_size = worker_group.get_megatron_global_info().pp_size @@ -120,6 +125,7 @@ def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs): All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank """ from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + assert isinstance(worker_group, MegatronWorkerGroup) splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs) @@ -127,9 +133,10 @@ def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs): def _concat_data_proto_or_future(output: List): - from verl.protocol import DataProto, DataProtoFuture import ray + from verl.protocol import DataProto, DataProtoFuture + # make sure all the elements in output has the same type for o in output: assert type(o) == type(output[0]) @@ -148,9 +155,10 @@ def collect_megatron_compute_data_proto(worker_group, output): """ Each output must be a DataProto. We concat the dim=0 of output """ - from verl.protocol import DataProto import ray + from verl.protocol import DataProto + output = collect_megatron_compute(worker_group, output) for o in output: assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}" @@ -163,6 +171,7 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs): treat pp as dp. """ from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + assert isinstance(worker_group, MegatronWorkerGroup) pp_size = worker_group.pp_size @@ -196,7 +205,7 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs): all_kwargs = {} for k, v in kwargs.items(): - assert isinstance(v, (List, Tuple)) and len(v) == pp_dp_cp_size, f'expect len(v)=={pp_dp_cp_size}, got {len(v)}' + assert isinstance(v, (List, Tuple)) and len(v) == pp_dp_cp_size, f"expect len(v)=={pp_dp_cp_size}, got {len(v)}" transformed_v = [] for i in range(worker_group.world_size): local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank @@ -215,6 +224,7 @@ def collect_megatron_pp_as_dp(worker_group, output): treat pp as dp. Only collect data on tp=0 """ from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + assert isinstance(worker_group, MegatronWorkerGroup) output_in_dp = [] for global_rank in range(worker_group.world_size): @@ -229,6 +239,7 @@ def collect_megatron_pp_only(worker_group, output): Only collect output of megatron pp. 
This is useful when examine weight names as they are identical in tp/dp """ from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + assert isinstance(worker_group, MegatronWorkerGroup) output_in_pp = [] for global_rank in range(worker_group.world_size): @@ -240,6 +251,7 @@ def collect_megatron_pp_only(worker_group, output): def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs): from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + assert isinstance(worker_group, MegatronWorkerGroup) pp_dp_cp_size = worker_group.dp_size * worker_group.pp_size * worker_group.cp_size @@ -249,8 +261,8 @@ def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs): def collect_megatron_pp_as_dp_data_proto(worker_group, output): - from verl.protocol import DataProto from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + assert isinstance(worker_group, MegatronWorkerGroup) output = collect_megatron_pp_as_dp(worker_group, output) @@ -259,6 +271,7 @@ def collect_megatron_pp_as_dp_data_proto(worker_group, output): def dispatch_dp_compute(worker_group, *args, **kwargs): from verl.single_controller.base.worker_group import WorkerGroup + assert isinstance(worker_group, WorkerGroup) for arg in args: assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.world_size @@ -269,6 +282,7 @@ def dispatch_dp_compute(worker_group, *args, **kwargs): def collect_dp_compute(worker_group, output): from verl.single_controller.base.worker_group import WorkerGroup + assert isinstance(worker_group, WorkerGroup) assert len(output) == worker_group.world_size return output @@ -276,6 +290,7 @@ def collect_dp_compute(worker_group, output): def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs): from verl.single_controller.base.worker_group import WorkerGroup + assert isinstance(worker_group, WorkerGroup) splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs) return splitted_args, splitted_kwargs @@ -283,6 +298,7 @@ def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs): def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs): from verl.single_controller.base.worker_group import WorkerGroup + assert isinstance(worker_group, WorkerGroup) assert type(args[0]) == FunctionType # NOTE: The first one args is a function! 
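As a mental model for the `dispatch_dp_compute_data_proto` / `collect_dp_compute_data_proto` pair touched above: the dispatcher chunks every `DataProto` argument by the worker group's `world_size`, each rank works on its shard, and the collector concatenates the per-rank outputs along dim 0. A hedged sketch of that contract using only `DataProto` (no Ray and no real `WorkerGroup`; `world_size` and `per_rank_fn` are stand-ins):

```python
# Illustration of the dispatch/collect contract, not the verl implementation.
import torch
from verl.protocol import DataProto

world_size = 4  # stand-in for worker_group.world_size


def dispatch(data: DataProto):
    # dispatch_dp_compute_data_proto chunks each DataProto argument into world_size shards
    return data.chunk(chunks=world_size)


def per_rank_fn(shard: DataProto) -> DataProto:
    # whatever a worker method would compute on its shard; identity here
    return shard


def collect(outputs):
    # collect_dp_compute_data_proto concatenates the per-rank DataProto outputs
    return DataProto.concat(outputs)


data = DataProto.from_dict(tensors={"x": torch.arange(8.0).reshape(8, 1)})
result = collect([per_rank_fn(shard) for shard in dispatch(data)])
assert len(result) == 8
```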
@@ -292,9 +308,10 @@ def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs): def collect_dp_compute_data_proto(worker_group, output): - from verl.protocol import DataProto import ray + from verl.protocol import DataProto + for o in output: assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}" @@ -305,49 +322,40 @@ def collect_dp_compute_data_proto(worker_group, output): def get_predefined_dispatch_fn(dispatch_mode): predefined_dispatch_mode_fn = { Dispatch.ONE_TO_ALL: { - 'dispatch_fn': dispatch_one_to_all, - 'collect_fn': collect_all_to_all, + "dispatch_fn": dispatch_one_to_all, + "collect_fn": collect_all_to_all, }, Dispatch.ALL_TO_ALL: { - 'dispatch_fn': dispatch_all_to_all, - 'collect_fn': collect_all_to_all, + "dispatch_fn": dispatch_all_to_all, + "collect_fn": collect_all_to_all, }, Dispatch.MEGATRON_COMPUTE: { - 'dispatch_fn': dispatch_megatron_compute, - 'collect_fn': collect_megatron_compute, + "dispatch_fn": dispatch_megatron_compute, + "collect_fn": collect_megatron_compute, }, Dispatch.MEGATRON_PP_AS_DP: { - 'dispatch_fn': dispatch_megatron_pp_as_dp, - 'collect_fn': collect_megatron_pp_as_dp, - }, - Dispatch.MEGATRON_PP_ONLY: { - 'dispatch_fn': dispatch_one_to_all, - 'collect_fn': collect_megatron_pp_only + "dispatch_fn": dispatch_megatron_pp_as_dp, + "collect_fn": collect_megatron_pp_as_dp, }, + Dispatch.MEGATRON_PP_ONLY: {"dispatch_fn": dispatch_one_to_all, "collect_fn": collect_megatron_pp_only}, Dispatch.MEGATRON_COMPUTE_PROTO: { - 'dispatch_fn': dispatch_megatron_compute_data_proto, - 'collect_fn': collect_megatron_compute_data_proto + "dispatch_fn": dispatch_megatron_compute_data_proto, + "collect_fn": collect_megatron_compute_data_proto, }, Dispatch.MEGATRON_PP_AS_DP_PROTO: { - 'dispatch_fn': dispatch_megatron_pp_as_dp_data_proto, - 'collect_fn': collect_megatron_pp_as_dp_data_proto - }, - Dispatch.DP_COMPUTE: { - 'dispatch_fn': dispatch_dp_compute, - 'collect_fn': collect_dp_compute + "dispatch_fn": dispatch_megatron_pp_as_dp_data_proto, + "collect_fn": collect_megatron_pp_as_dp_data_proto, }, + Dispatch.DP_COMPUTE: {"dispatch_fn": dispatch_dp_compute, "collect_fn": collect_dp_compute}, Dispatch.DP_COMPUTE_PROTO: { - 'dispatch_fn': dispatch_dp_compute_data_proto, - 'collect_fn': collect_dp_compute_data_proto + "dispatch_fn": dispatch_dp_compute_data_proto, + "collect_fn": collect_dp_compute_data_proto, }, Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: { - 'dispatch_fn': dispatch_dp_compute_data_proto_with_func, - 'collect_fn': collect_dp_compute_data_proto + "dispatch_fn": dispatch_dp_compute_data_proto_with_func, + "collect_fn": collect_dp_compute_data_proto, }, - Dispatch.DP_COMPUTE_METRIC: { - 'dispatch_fn': dispatch_dp_compute_data_proto, - 'collect_fn': collect_dp_compute - } + Dispatch.DP_COMPUTE_METRIC: {"dispatch_fn": dispatch_dp_compute_data_proto, "collect_fn": collect_dp_compute}, } return predefined_dispatch_mode_fn[dispatch_mode] @@ -358,27 +366,24 @@ def get_predefined_execute_fn(execute_mode): Leave the choice of how these two functions handle argument 'blocking' to users """ predefined_execute_mode_fn = { - Execute.ALL: { - 'execute_fn_name': 'execute_all' - }, - Execute.RANK_ZERO: { - 'execute_fn_name': 'execute_rank_zero' - } + Execute.ALL: {"execute_fn_name": "execute_all"}, + Execute.RANK_ZERO: {"execute_fn_name": "execute_rank_zero"}, } return predefined_execute_mode_fn[execute_mode] def _check_dispatch_mode(dispatch_mode): - assert isinstance(dispatch_mode, - (Dispatch, Dict)), 
f'dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}' + assert isinstance(dispatch_mode, (Dispatch, Dict)), ( + f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}" + ) if isinstance(dispatch_mode, Dict): - necessary_keys = ['dispatch_fn', 'collect_fn'] + necessary_keys = ["dispatch_fn", "collect_fn"] for key in necessary_keys: - assert key in dispatch_mode, f'key {key} should be in dispatch_mode if it is a dictionary' + assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary" def _check_execute_mode(execute_mode): - assert isinstance(execute_mode, Execute), f'execute_mode must be a Execute. Got {execute_mode}' + assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}" def _materialize_futures(*args, **kwargs): @@ -401,14 +406,13 @@ def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocki _check_execute_mode(execute_mode=execute_mode) def decorator(func): - @wraps(func) def inner(*args, **kwargs): if materialize_futures: args, kwargs = _materialize_futures(*args, **kwargs) return func(*args, **kwargs) - attrs = {'dispatch_mode': dispatch_mode, 'execute_mode': execute_mode, 'blocking': blocking} + attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking} setattr(inner, MAGIC_ATTR, attrs) return inner diff --git a/verl/single_controller/base/megatron/worker.py b/verl/single_controller/base/megatron/worker.py index af9f612ea..5fc711281 100644 --- a/verl/single_controller/base/megatron/worker.py +++ b/verl/single_controller/base/megatron/worker.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from verl.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo +from verl.single_controller.base.worker import DistGlobalInfo, DistRankInfo, Worker class MegatronWorker(Worker): - def __init__(self, cuda_visible_devices=None) -> None: super().__init__(cuda_visible_devices) def get_megatron_global_info(self): from megatron.core import parallel_state as mpu + tp_size = mpu.get_tensor_model_parallel_world_size() dp_size = mpu.get_data_parallel_world_size() pp_size = mpu.get_pipeline_model_parallel_world_size() @@ -31,6 +31,7 @@ class MegatronWorker(Worker): def get_megatron_rank_info(self): from megatron.core import parallel_state as mpu + tp_rank = mpu.get_tensor_model_parallel_rank() dp_rank = mpu.get_data_parallel_rank() pp_rank = mpu.get_pipeline_model_parallel_rank() @@ -39,11 +40,12 @@ class MegatronWorker(Worker): return info def _init_hf_config_and_tf_config(self, model_path, dtype, override_model_config): - from verl.utils.model import print_model_size, update_model_config - from verl.utils.fs import copy_to_local - from verl.utils import hf_tokenizer from transformers import AutoConfig + from verl.models.mcore import hf_to_mcore_config + from verl.utils import hf_tokenizer + from verl.utils.fs import copy_to_local + from verl.utils.model import update_model_config # Step 1: initialize the tokenizer self.local_path = copy_to_local(model_path) @@ -54,29 +56,30 @@ class MegatronWorker(Worker): # Step 3: override the hf config override_config_kwargs = { - 'bos_token_id': self.tokenizer.bos_token_id, - 'eos_token_id': self.tokenizer.eos_token_id, - 'pad_token_id': self.tokenizer.pad_token_id, + "bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, } 
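The `register` decorator and `_check_dispatch_mode` reformatted above accept either a `Dispatch` enum member or a plain dict providing `dispatch_fn` and `collect_fn`. A hedged sketch of both spellings (the worker class and its methods are hypothetical; in real use the class would subclass verl's `Worker` and be wrapped by `ray.remote`):

```python
# Hypothetical worker methods; only the decorator contract comes from this diff.
from verl.single_controller.base.decorator import Dispatch, Execute, register


def my_dispatch(worker_group, *args, **kwargs):
    # A dispatch_fn must return (args, kwargs) with one entry per worker;
    # here every worker simply receives the same arguments (broadcast).
    n = worker_group.world_size
    return tuple([arg] * n for arg in args), {k: [v] * n for k, v in kwargs.items()}


def my_collect(worker_group, output):
    # A collect_fn receives the list of per-rank outputs.
    return output


class MyWorker:  # in real use this would subclass verl's Worker and be ray.remote-wrapped
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)  # predefined mode from the table above
    def compute(self, data):
        return data

    @register(dispatch_mode={"dispatch_fn": my_dispatch, "collect_fn": my_collect}, execute_mode=Execute.ALL)
    def debug_echo(self, *args, **kwargs):
        return args, kwargs
```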
override_config_kwargs.update(override_model_config) self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False) update_model_config(hf_config, override_config_kwargs=override_config_kwargs) self.architectures = getattr(hf_config, "architectures", None) if self.rank == 0: - print(f'Model config after override: {hf_config}') + print(f"Model config after override: {hf_config}") tf_config = hf_to_mcore_config(hf_config, dtype) def add_optimization_config_to_tf_config(tf_config, verl_model_config): # add optimization config to tf_config, e.g. checkpointing - if verl_model_config.get('enable_gradient_checkpointing', False): - gradient_checkpointing_cfg = dict(verl_model_config.get('gradient_checkpointing_kwargs', dict())) - tf_config.recompute_method = gradient_checkpointing_cfg.get('activations_checkpoint_method', 'full') - tf_config.recompute_granularity = gradient_checkpointing_cfg.get('activations_checkpoint_granularity', - 'full') - tf_config.recompute_num_layers = gradient_checkpointing_cfg.get('activations_checkpoint_num_layers', -1) + if verl_model_config.get("enable_gradient_checkpointing", False): + gradient_checkpointing_cfg = dict(verl_model_config.get("gradient_checkpointing_kwargs", dict())) + tf_config.recompute_method = gradient_checkpointing_cfg.get("activations_checkpoint_method", "full") + tf_config.recompute_granularity = gradient_checkpointing_cfg.get( + "activations_checkpoint_granularity", "full" + ) + tf_config.recompute_num_layers = gradient_checkpointing_cfg.get("activations_checkpoint_num_layers", -1) add_optimization_config_to_tf_config(tf_config, self.config.model) - print(f'TF config: {tf_config}') + print(f"TF config: {tf_config}") self.hf_config = hf_config self.tf_config = tf_config diff --git a/verl/single_controller/base/megatron/worker_group.py b/verl/single_controller/base/megatron/worker_group.py index 26477b067..04d211ffe 100644 --- a/verl/single_controller/base/megatron/worker_group.py +++ b/verl/single_controller/base/megatron/worker_group.py @@ -14,22 +14,22 @@ from typing import Dict -from .worker import DistRankInfo, DistGlobalInfo from verl.single_controller.base import ResourcePool, WorkerGroup +from .worker import DistGlobalInfo, DistRankInfo + class MegatronWorkerGroup(WorkerGroup): - def __init__(self, resource_pool: ResourcePool, **kwargs): super().__init__(resource_pool=resource_pool, **kwargs) self._megatron_rank_info = None self._megatron_global_info: DistGlobalInfo = None def init_megatron(self, default_megatron_kwargs: Dict = None): - raise NotImplementedError(f"MegatronWorkerGroup.init_megatron should be overwritten") + raise NotImplementedError("MegatronWorkerGroup.init_megatron should be overwritten") def get_megatron_rank_info(self, rank: int) -> DistRankInfo: - assert 0 <= rank < self.world_size, f'rank must be from [0, world_size), Got {rank}' + assert 0 <= rank < self.world_size, f"rank must be from [0, world_size), Got {rank}" return self._megatron_rank_info[rank] @property diff --git a/verl/single_controller/base/register_center/ray.py b/verl/single_controller/base/register_center/ray.py index 430290cf2..de7f702a8 100644 --- a/verl/single_controller/base/register_center/ray.py +++ b/verl/single_controller/base/register_center/ray.py @@ -17,7 +17,6 @@ import ray @ray.remote class WorkerGroupRegisterCenter: - def __init__(self, rank_zero_info): self.rank_zero_info = rank_zero_info diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index 0729d5b64..59ff599f0 
100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -14,10 +14,12 @@ """ the class for Worker """ + import os import socket from dataclasses import dataclass -from .decorator import register, Dispatch, Execute + +from .decorator import Dispatch, Execute, register @dataclass @@ -37,12 +39,11 @@ class DistGlobalInfo: class WorkerHelper: - def _get_node_ip(self): - def get_node_ip_by_sdk(): if os.getenv("WG_BACKEND", None) == "ray": import ray + return ray._private.services.get_node_ip_address() else: raise NotImplementedError("WG_BACKEND now just support ray mode.") @@ -57,7 +58,7 @@ class WorkerHelper: def _get_free_port(self): with socket.socket() as sock: - sock.bind(('', 0)) + sock.bind(("", 0)) return sock.getsockname()[1] def get_availale_master_addr_port(self): @@ -69,7 +70,13 @@ class WorkerHelper: class WorkerMeta: keys = [ - "WORLD_SIZE", "RANK", "LOCAL_WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT", "CUDA_VISIBLE_DEVICES" + "WORLD_SIZE", + "RANK", + "LOCAL_WORLD_SIZE", + "LOCAL_RANK", + "MASTER_ADDR", + "MASTER_PORT", + "CUDA_VISIBLE_DEVICES", ] def __init__(self, store) -> None: @@ -87,7 +94,7 @@ class Worker(WorkerHelper): instance = super().__new__(cls) # note that here we use int to distinguish - disable_worker_init = int(os.environ.get('DISABLE_WORKER_INIT', 0)) + disable_worker_init = int(os.environ.get("DISABLE_WORKER_INIT", 0)) if disable_worker_init: return instance @@ -95,7 +102,7 @@ class Worker(WorkerHelper): worker_group_prefix = os.environ.get("WG_PREFIX", None) # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init - if None not in [rank, worker_group_prefix] and 'ActorClass(' not in cls.__name__: + if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__: instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank)) return instance @@ -112,8 +119,10 @@ class Worker(WorkerHelper): if os.getenv("WG_BACKEND", None) == "ray": from verl.single_controller.base.register_center.ray import create_worker_group_register_center - self.register_center = create_worker_group_register_center(name=register_center_name, - info=rank_zero_info) + + self.register_center = create_worker_group_register_center( + name=register_center_name, info=rank_zero_info + ) os.environ.update(rank_zero_info) @@ -129,12 +138,12 @@ class Worker(WorkerHelper): ### # [SUPPORT AMD: torch] if "AMD" in torch.cuda.get_device_name(): - os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('ROCR_VISIBLE_DEVICES') - os.environ['LOCAL_RANK'] = os.environ.get('RAY_LOCAL_RANK') + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("ROCR_VISIBLE_DEVICES") + os.environ["LOCAL_RANK"] = os.environ.get("RAY_LOCAL_RANK") ### - world_size = int(os.environ['WORLD_SIZE']) - rank = int(os.environ['RANK']) + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) self._rank = rank self._world_size = world_size @@ -147,7 +156,7 @@ class Worker(WorkerHelper): ### # [SUPPORT AMD: torch] if "AMD" in torch.cuda.get_device_name(): - self.local_rank = int(os.environ['LOCAL_RANK']) + self.local_rank = int(os.environ["LOCAL_RANK"]) ### ### @@ -157,15 +166,15 @@ class Worker(WorkerHelper): ### store = { - '_world_size': world_size, - '_rank': rank, - '_local_world_size': local_world_size, - '_local_rank': local_rank, - '_master_addr': master_addr, - '_master_port': master_port + "_world_size": world_size, + "_rank": rank, + "_local_world_size": 
local_world_size, + "_local_rank": local_rank, + "_master_addr": master_addr, + "_master_port": master_port, } if cuda_visible_devices is not None: - store['_cuda_visible_devices'] = cuda_visible_devices + store["_cuda_visible_devices"] = cuda_visible_devices meta = WorkerMeta(store=store) self._configure_with_meta(meta=meta) @@ -189,14 +198,16 @@ class Worker(WorkerHelper): if val is not None: # print(f"set {key} to {val}") os.environ[key] = str(val) - os.environ["REDIS_STORE_SERVER_HOST"] = str(self._master_addr).replace("[", "").replace( - "]", "") if self._master_addr else "" + os.environ["REDIS_STORE_SERVER_HOST"] = ( + str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else "" + ) def get_master_addr_port(self): return self._master_addr, self._master_port def get_cuda_visible_devices(self): import os + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set") return cuda_visible_devices diff --git a/verl/single_controller/base/worker_group.py b/verl/single_controller/base/worker_group.py index 3da3db46a..917fab837 100644 --- a/verl/single_controller/base/worker_group.py +++ b/verl/single_controller/base/worker_group.py @@ -14,11 +14,12 @@ """ the class of WorkerGroup """ + import logging -import threading import signal +import threading import time -from typing import List, Any, Callable, Dict +from typing import Any, Callable, Dict, List from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn @@ -81,6 +82,7 @@ class ClassWithInitArgs: def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None: import time + while True: for worker in workers: if not is_alive(worker): @@ -110,7 +112,7 @@ class WorkerGroup: self._checker_thread: threading.Thread = None def _is_worker_alive(self, worker): - raise NotImplementedError(f"WorkerGroup._is_worker_alive called, should be implemented in derived class.") + raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.") def _block_until_all_workers_alive(self) -> None: while True: @@ -124,8 +126,9 @@ class WorkerGroup: # before starting checking worker aliveness, make sure all workers are already alive self._block_until_all_workers_alive() - self._checker_thread = threading.Thread(target=check_workers_alive, - args=(self._workers, self._is_worker_alive, every_n_seconds)) + self._checker_thread = threading.Thread( + target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds) + ) self._checker_thread.start() @property @@ -141,58 +144,59 @@ class WorkerGroup: """ for method_name in dir(user_defined_cls): - try: method = getattr(user_defined_cls, method_name) assert callable(method), f"{method_name} in {user_defined_cls} is not callable" - except Exception as e: + except Exception: # if it is a property, it will fail because Class doesn't have instance property continue if hasattr(method, MAGIC_ATTR): # this method is decorated by register attribute = getattr(method, MAGIC_ATTR) - assert isinstance(attribute, Dict), f'attribute must be a dictionary. Got {type(attribute)}' - assert 'dispatch_mode' in attribute, f'attribute must contain dispatch_mode in its key' + assert isinstance(attribute, Dict), f"attribute must be a dictionary. 
Got {type(attribute)}" + assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key" - dispatch_mode = attribute['dispatch_mode'] - execute_mode = attribute['execute_mode'] - blocking = attribute['blocking'] + dispatch_mode = attribute["dispatch_mode"] + execute_mode = attribute["execute_mode"] + blocking = attribute["blocking"] # get dispatch fn if isinstance(dispatch_mode, Dispatch): # get default dispatch fn fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode) - dispatch_fn = fn['dispatch_fn'] - collect_fn = fn['collect_fn'] + dispatch_fn = fn["dispatch_fn"] + collect_fn = fn["collect_fn"] else: assert isinstance(dispatch_mode, dict) - assert 'dispatch_fn' in dispatch_mode - assert 'collect_fn' in dispatch_mode - dispatch_fn = dispatch_mode['dispatch_fn'] - collect_fn = dispatch_mode['collect_fn'] + assert "dispatch_fn" in dispatch_mode + assert "collect_fn" in dispatch_mode + dispatch_fn = dispatch_mode["dispatch_fn"] + collect_fn = dispatch_mode["collect_fn"] # get execute_fn_name execute_mode = get_predefined_execute_fn(execute_mode=execute_mode) - wg_execute_fn_name = execute_mode['execute_fn_name'] + wg_execute_fn_name = execute_mode["execute_fn_name"] # get execute_fn from string try: execute_fn = getattr(self, wg_execute_fn_name) - assert callable(execute_fn), 'execute_fn must be callable' - except Exception as e: - print(f'execute_fn {wg_execute_fn_name} is invalid') + assert callable(execute_fn), "execute_fn must be callable" + except Exception: + print(f"execute_fn {wg_execute_fn_name} is invalid") raise # bind a new method to the RayWorkerGroup - func = func_generator(self, - method_name, - dispatch_fn=dispatch_fn, - collect_fn=collect_fn, - execute_fn=execute_fn, - blocking=blocking) + func = func_generator( + self, + method_name, + dispatch_fn=dispatch_fn, + collect_fn=collect_fn, + execute_fn=execute_fn, + blocking=blocking, + ) try: setattr(self, method_name, func) - except Exception as e: - raise ValueError(f'Fail to set method_name {method_name}') + except Exception: + raise ValueError(f"Fail to set method_name {method_name}") diff --git a/verl/single_controller/ray/__init__.py b/verl/single_controller/ray/__init__.py index 2b567d0cf..dd1636ccf 100644 --- a/verl/single_controller/ray/__init__.py +++ b/verl/single_controller/ray/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
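To summarize the `_bind_worker_method` hunk above: for every `@register`-decorated method it reads `MAGIC_ATTR`, resolves the dispatch/collect/execute functions, and binds a generated group-level method onto the worker group. A condensed, hedged sketch of the pipeline that generated method runs (the real `func_generator` appears in `ray/base.py` below; the `blocking` handling is an assumption, since that branch is not shown in this excerpt):

```python
# Condensed sketch of the group-level method that _bind_worker_method installs.
import ray


def make_group_method(worker_group, method_name, dispatch_fn, collect_fn, execute_fn, blocking):
    def group_method(*args, **kwargs):
        args, kwargs = dispatch_fn(worker_group, *args, **kwargs)  # split/broadcast per rank
        output = execute_fn(method_name, *args, **kwargs)          # run on every worker
        if blocking:                                               # assumption: wait on Ray futures
            output = ray.get(output)
        return collect_fn(worker_group, output)                    # merge per-rank outputs
    return group_method
```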
-from .base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls \ No newline at end of file +from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, create_colocated_worker_cls diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index 90a2f4da5..32557d511 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -14,28 +14,28 @@ import logging import time -from typing import Dict, List, Any, Tuple, Optional +from typing import Any, Dict, List, Optional, Tuple import ray -from ray.util import list_named_actors -from ray.util.placement_group import placement_group, PlacementGroup -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy from ray.experimental.state.api import get_actor +from ray.util import list_named_actors +from ray.util.placement_group import PlacementGroup, placement_group +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy -from verl.single_controller.base import WorkerGroup, ResourcePool, ClassWithInitArgs, Worker +from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup -__all__ = ['Worker'] +__all__ = ["Worker"] def get_random_string(length: int) -> str: import random import string + letters_digits = string.ascii_letters + string.digits - return ''.join(random.choice(letters_digits) for _ in range(length)) + return "".join(random.choice(letters_digits) for _ in range(length)) def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking): - def func(*args, **kwargs): args, kwargs = dispatch_fn(self, *args, **kwargs) output = execute_fn(method_name, *args, **kwargs) @@ -68,13 +68,14 @@ def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[Placement class RayResourcePool(ResourcePool): - - def __init__(self, - process_on_nodes: Optional[List[int]] = None, - use_gpu: bool = True, - name_prefix: str = "", - max_colocate_count: int = 10, - detached=False) -> None: + def __init__( + self, + process_on_nodes: Optional[List[int]] = None, + use_gpu: bool = True, + name_prefix: str = "", + max_colocate_count: int = 10, + detached=False, + ) -> None: super().__init__(process_on_nodes, max_colocate_count) self.use_gpu = use_gpu # print(f"in RayProcessDispatchConfiguration: name_prefix = {name_prefix}") @@ -86,17 +87,19 @@ class RayResourcePool(ResourcePool): if self.pgs is not None: return self.pgs - pg_name_prefix = name if name else \ - f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" + pg_name_prefix = ( + name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" + ) # print(f"pg_name_prefix = {pg_name_prefix}") - pg_scheme = [[{ - "CPU": self.max_colocate_count, - "GPU": 1 - } if self.use_gpu else { - "CPU": self.max_colocate_count - } for _ in range(process_count)] for process_count in self._store] + pg_scheme = [ + [ + {"CPU": self.max_colocate_count, "GPU": 1} if self.use_gpu else {"CPU": self.max_colocate_count} + for _ in range(process_count) + ] + for process_count in self._store + ] - lifetime = 'detached' if self.detached else None + lifetime = "detached" if self.detached else None pgs = [ placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime) @@ -109,11 +112,13 @@ class RayResourcePool(ResourcePool): return pgs -def 
extract_pg_from_exist(resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], - resource_pool: RayResourcePool) -> List: - +def extract_pg_from_exist( + resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], resource_pool: RayResourcePool +) -> List: src_pgs = [ - pg for role_name, resource_pool in resource_pools.items() for pg in resource_pool.get_placement_groups() + pg + for role_name, resource_pool in resource_pools.items() + for pg in resource_pool.get_placement_groups() if role_name in src_role_names ] @@ -124,8 +129,9 @@ def extract_pg_from_exist(resource_pools: Dict[str, RayResourcePool], src_role_n searching_idx = 0 for request_process, original_idx in sorted_process_on_nodes: assert searching_idx < len(sorted_src_pgs), f"no enough nodes for request: searching {searching_idx} th node" - assert request_process <= sorted_src_pgs[searching_idx].bundle_count, \ + assert request_process <= sorted_src_pgs[searching_idx].bundle_count, ( f"requesting {request_process} processes, bundle count cannot satisfy" + ) unsorted_pgs.append((original_idx, sorted_src_pgs[searching_idx])) searching_idx += 1 @@ -133,10 +139,10 @@ def extract_pg_from_exist(resource_pools: Dict[str, RayResourcePool], src_role_n def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool: - assert rp1.use_gpu == rp2.use_gpu, 'Both RayResourcePool must either use_gpu or not' - assert rp1.max_colocate_count == rp2.max_colocate_count, 'Both RayResourcePool must has the same max_colocate_count' - assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, 'Both RayResourcePool must has the same n_gpus_per_node' - assert rp1.detached == rp2.detached, 'Detached ResourcePool cannot be merged with non-detached ResourcePool' + assert rp1.use_gpu == rp2.use_gpu, "Both RayResourcePool must either use_gpu or not" + assert rp1.max_colocate_count == rp2.max_colocate_count, "Both RayResourcePool must has the same max_colocate_count" + assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, "Both RayResourcePool must has the same n_gpus_per_node" + assert rp1.detached == rp2.detached, "Detached ResourcePool cannot be merged with non-detached ResourcePool" new_store = rp1.store + rp2.store @@ -147,7 +153,6 @@ def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResour class RayClassWithInitArgs(ClassWithInitArgs): - def __init__(self, cls, *args, **kwargs) -> None: # self._options = kwargs.pop('options', dict()) super().__init__(cls, *args, **kwargs) @@ -160,24 +165,21 @@ class RayClassWithInitArgs(ClassWithInitArgs): def update_options(self, options: Dict): self._options.update(options) - def __call__(self, - placement_group, - placement_group_bundle_idx, - use_gpu: bool = True, - num_gpus=1, - sharing_with=None) -> Any: + def __call__( + self, placement_group, placement_group_bundle_idx, use_gpu: bool = True, num_gpus=1, sharing_with=None + ) -> Any: if sharing_with is not None: target_node_id = ray.get(sharing_with.get_node_id.remote()) cuda_visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote()) options = {"scheduling_strategy": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)} - return self.cls.options(**options).remote(*self.args, - cuda_visible_devices=cuda_visible_devices, - **self.kwargs) + return self.cls.options(**options).remote( + *self.args, cuda_visible_devices=cuda_visible_devices, **self.kwargs + ) options = { - "scheduling_strategy": - PlacementGroupSchedulingStrategy(placement_group=placement_group, - 
placement_group_bundle_index=placement_group_bundle_idx) + "scheduling_strategy": PlacementGroupSchedulingStrategy( + placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx + ) } options.update(self._options) @@ -195,16 +197,17 @@ class RayClassWithInitArgs(ClassWithInitArgs): class RayWorkerGroup(WorkerGroup): - - def __init__(self, - resource_pool: RayResourcePool = None, - ray_cls_with_init: RayClassWithInitArgs = None, - bin_pack: bool = True, - name_prefix: str = None, - detached=False, - worker_names=None, - ray_wait_register_center_timeout: int = 300, - **kwargs) -> None: + def __init__( + self, + resource_pool: RayResourcePool = None, + ray_cls_with_init: RayClassWithInitArgs = None, + bin_pack: bool = True, + name_prefix: str = None, + detached=False, + worker_names=None, + ray_wait_register_center_timeout: int = 300, + **kwargs, + ) -> None: super().__init__(resource_pool=resource_pool, **kwargs) self.ray_cls_with_init = ray_cls_with_init self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix @@ -217,10 +220,9 @@ class RayWorkerGroup(WorkerGroup): if self._is_init_with_detached_workers: self._init_with_detached_workers(worker_names=worker_names) else: - self._init_with_resource_pool(resource_pool=resource_pool, - ray_cls_with_init=ray_cls_with_init, - bin_pack=bin_pack, - detached=detached) + self._init_with_resource_pool( + resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, bin_pack=bin_pack, detached=detached + ) if ray_cls_with_init is not None: self._bind_worker_method(self.ray_cls_with_init.cls, func_generator) @@ -249,40 +251,39 @@ class RayWorkerGroup(WorkerGroup): rank = -1 local_world_size = resource_pool.store[0] for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)): - assert local_world_size <= pg.bundle_count, \ - f"when generating for {self.name_prefix}, for the " + assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the " for local_rank in range(local_world_size): rank += 1 # we pass in environment variable at option so that Worker can use environment variable to set env_vars = { - 'WORLD_SIZE': str(world_size), - 'RANK': str(rank), - 'WG_PREFIX': self.name_prefix, - 'WG_BACKEND': 'ray', - 'RAY_LOCAL_WORLD_SIZE': str(local_world_size), - 'RAY_LOCAL_RANK': str(local_rank), + "WORLD_SIZE": str(world_size), + "RANK": str(rank), + "WG_PREFIX": self.name_prefix, + "WG_BACKEND": "ray", + "RAY_LOCAL_WORLD_SIZE": str(local_world_size), + "RAY_LOCAL_RANK": str(local_rank), } if rank != 0: - env_vars['MASTER_ADDR'] = self._master_addr - env_vars['MASTER_PORT'] = self._master_port + env_vars["MASTER_ADDR"] = self._master_addr + env_vars["MASTER_PORT"] = self._master_port import re + cia_name = type(ray_cls_with_init.cls).__name__ match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)" cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj" name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}" # e.g. 
Worker_2:5 - ray_cls_with_init.update_options({'runtime_env': {'env_vars': env_vars}, 'name': name}) + ray_cls_with_init.update_options({"runtime_env": {"env_vars": env_vars}, "name": name}) if detached: - ray_cls_with_init.update_options({'lifetime': 'detached'}) + ray_cls_with_init.update_options({"lifetime": "detached"}) # create a worker - worker = ray_cls_with_init(placement_group=pg, - placement_group_bundle_idx=local_rank, - use_gpu=use_gpu, - num_gpus=num_gpus) + worker = ray_cls_with_init( + placement_group=pg, placement_group_bundle_idx=local_rank, use_gpu=use_gpu, num_gpus=num_gpus + ) self._workers.append(worker) self._worker_names.append(name) @@ -313,7 +314,7 @@ class RayWorkerGroup(WorkerGroup): ) rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote()) - self._master_addr, self._master_port = rank_zero_info['MASTER_ADDR'], rank_zero_info['MASTER_PORT'] + self._master_addr, self._master_port = rank_zero_info["MASTER_ADDR"], rank_zero_info["MASTER_PORT"] # print(f"rank_zero_info: {rank_zero_info}") # print(f"master_addr: {self._master_addr}, master_port: {self._master_port}") @@ -323,10 +324,9 @@ class RayWorkerGroup(WorkerGroup): @classmethod def from_detached(cls, worker_names=None, ray_cls_with_init=None): - worker_group = cls(resource_pool=None, - ray_cls_with_init=ray_cls_with_init, - name_prefix=None, - worker_names=worker_names) + worker_group = cls( + resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=None, worker_names=worker_names + ) return worker_group def spawn(self, prefix_set): @@ -339,7 +339,7 @@ class RayWorkerGroup(WorkerGroup): """ bind the method with actor_prefix to its original name """ - prefix: str = actor_name + '_' + prefix: str = actor_name + "_" for method_name in dir(worker_group): if method_name.startswith(prefix): # only valid when Python >= 3.9 @@ -349,8 +349,9 @@ class RayWorkerGroup(WorkerGroup): new_worker_group_dict = {} for prefix in prefix_set: - new_worker_group = self.from_detached(worker_names=self._worker_names, - ray_cls_with_init=self.ray_cls_with_init) + new_worker_group = self.from_detached( + worker_names=self._worker_names, ray_cls_with_init=self.ray_cls_with_init + ) _rebind_actor_methods(new_worker_group, prefix) new_worker_group_dict[prefix] = new_worker_group @@ -412,28 +413,28 @@ Utilities that enables creating workers inside the same ray.Actor, with code written in separate ray.Actors. """ -from unittest.mock import patch -from verl.single_controller.base.decorator import MAGIC_ATTR import os +from unittest.mock import patch + +from verl.single_controller.base.decorator import MAGIC_ATTR def _bind_workers_method_to_parent(cls, key, user_defined_cls): """ - Binds the methods of each worker to the WorkerDict. + Binds the methods of each worker to the WorkerDict. 
Note that we only bind public methods that are decorated by register """ for method_name in dir(user_defined_cls): try: method = getattr(user_defined_cls, method_name) assert callable(method), f"{method_name} in {user_defined_cls} is not callable" - except Exception as e: + except Exception: # if it is a property, it will fail because Class doesn't have instance property continue if hasattr(method, MAGIC_ATTR): def generate_function(name): - def func(self, *args, **kwargs): # dispatch to the actual worker return getattr(self.worker_dict[key], name)(*args, **kwargs) @@ -444,22 +445,22 @@ def _bind_workers_method_to_parent(cls, key, user_defined_cls): # pass MAGIC_ATTR for outer worker group setattr(func, MAGIC_ATTR, getattr(method, MAGIC_ATTR)) try: - method_name_with_prefix = key + '_' + method_name + method_name_with_prefix = key + "_" + method_name setattr(cls, method_name_with_prefix, func) # print(f'Binding {method_name_with_prefix}') - except Exception as e: - raise ValueError(f'Fail to set method_name {method_name}') + except Exception: + raise ValueError(f"Fail to set method_name {method_name}") def _unwrap_ray_remote(cls): - if hasattr(cls, '__ray_actor_class__'): + if hasattr(cls, "__ray_actor_class__"): cls = cls.__ray_actor_class__ return cls def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): """ - This function should return a class instance that delegates the calls to every + This function should return a class instance that delegates the calls to every cls in cls_dict """ cls_dict = {} @@ -469,16 +470,16 @@ def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): if worker_cls == None: worker_cls = cls.cls.__ray_actor_class__.__base__ else: - assert worker_cls == cls.cls.__ray_actor_class__.__base__, \ - 'the worker class should be the same when share the same process' + assert worker_cls == cls.cls.__ray_actor_class__.__base__, ( + "the worker class should be the same when share the same process" + ) cls_dict[key] = cls.cls - init_args_dict[key] = {'args': cls.args, 'kwargs': cls.kwargs} + init_args_dict[key] = {"args": cls.args, "kwargs": cls.kwargs} assert cls_dict.keys() == init_args_dict.keys() # TODO: create a class with customizable name class WorkerDict(worker_cls): - def __init__(self): super().__init__() self.worker_dict = {} @@ -486,9 +487,10 @@ def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): user_defined_cls = _unwrap_ray_remote(user_defined_cls) # directly instantiate the class without remote # in worker class, e.g. 
when DISABLE_WORKER_INIT == 1 it will return immediately - with patch.dict(os.environ, {'DISABLE_WORKER_INIT': '1'}): - self.worker_dict[key] = user_defined_cls(*init_args_dict[key].get('args', ()), - **init_args_dict[key].get('kwargs', {})) + with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}): + self.worker_dict[key] = user_defined_cls( + *init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {}) + ) # now monkey-patch the methods from inner class to WorkerDict for key, user_defined_cls in cls_dict.items(): diff --git a/verl/single_controller/ray/megatron.py b/verl/single_controller/ray/megatron.py index 3ccb23a15..868e532a0 100644 --- a/verl/single_controller/ray/megatron.py +++ b/verl/single_controller/ray/megatron.py @@ -16,10 +16,11 @@ from typing import Dict, Optional import ray -from .base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs -from verl.single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo +from verl.single_controller.base.megatron.worker import DistGlobalInfo, DistRankInfo from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup +from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup + # NOTE(sgm): for open-source megatron-core class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): @@ -30,9 +31,10 @@ class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs): super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs) - self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name='get_megatron_rank_info') + self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name="get_megatron_rank_info") self._megatron_global_info: DistGlobalInfo = ray.get( - self.execute_rank_zero_async(method_name='get_megatron_global_info')) + self.execute_rank_zero_async(method_name="get_megatron_global_info") + ) class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): @@ -41,22 +43,27 @@ class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): so that the dispatcher can use it to dispatch data. 
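Usage context for the `create_colocated_worker_cls` changes above: it builds a single `WorkerDict` actor class hosting several user-defined workers in one process, re-exposing each role's registered methods under a `{role}_` prefix; `RayWorkerGroup.spawn` then splits the combined group back into per-role groups. A hedged sketch (the two role classes are minimal stand-ins, and the construction mirrors how verl's trainers use these helpers rather than anything shown verbatim in this hunk):

```python
# Sketch under assumptions: the two role classes are minimal stand-ins and a Ray
# cluster with enough GPUs is available for the resource pool.
import ray

from verl.single_controller.base import Worker
from verl.single_controller.ray import (
    RayClassWithInitArgs,
    RayResourcePool,
    RayWorkerGroup,
    create_colocated_worker_cls,
)


class ActorRolloutWorker(Worker):  # hypothetical role
    pass


class CriticWorker(Worker):  # hypothetical role
    pass


class_dict = {
    "actor_rollout": RayClassWithInitArgs(cls=ray.remote(ActorRolloutWorker)),
    "critic": RayClassWithInitArgs(cls=ray.remote(CriticWorker)),
}
colocated_cls = create_colocated_worker_cls(class_dict=class_dict)

pool = RayResourcePool(process_on_nodes=[8], use_gpu=True)
wg_dict = RayWorkerGroup(resource_pool=pool, ray_cls_with_init=colocated_cls)

# spawn() splits the combined group into one WorkerGroup per role, rebinding the
# "actor_rollout_*" / "critic_*" prefixed methods back to their plain names.
role_groups = wg_dict.spawn(prefix_set=class_dict.keys())
actor_wg, critic_wg = role_groups["actor_rollout"], role_groups["critic"]
```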
""" - def __init__(self, - resource_pool: RayResourcePool, - ray_cls_with_init: RayClassWithInitArgs, - default_megatron_kwargs: Dict = None, - **kwargs): - super().__init__(resource_pool=resource_pool, - ray_cls_with_init=ray_cls_with_init, - default_megatron_kwargs=default_megatron_kwargs, - **kwargs) + def __init__( + self, + resource_pool: RayResourcePool, + ray_cls_with_init: RayClassWithInitArgs, + default_megatron_kwargs: Dict = None, + **kwargs, + ): + super().__init__( + resource_pool=resource_pool, + ray_cls_with_init=ray_cls_with_init, + default_megatron_kwargs=default_megatron_kwargs, + **kwargs, + ) self.init_megatron(default_megatron_kwargs=default_megatron_kwargs) - self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name='get_megatron_rank_info') + self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name="get_megatron_rank_info") self._megatron_global_info: DistGlobalInfo = ray.get( - self.execute_rank_zero_async(method_name='get_megatron_global_info')) + self.execute_rank_zero_async(method_name="get_megatron_global_info") + ) def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None): # after super, we will call init of each worker if not self._is_init_with_detached_workers: # only init_megatron if the WorkerGroup is created from scratch - self.execute_all_sync(method_name='init_megatron', default_megatron_kwargs=default_megatron_kwargs) + self.execute_all_sync(method_name="init_megatron", default_megatron_kwargs=default_megatron_kwargs) diff --git a/verl/third_party/sglang/__init__.py b/verl/third_party/sglang/__init__.py index 93d16900e..15593caaf 100644 --- a/verl/third_party/sglang/__init__.py +++ b/verl/third_party/sglang/__init__.py @@ -23,4 +23,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/verl/third_party/sglang/parallel_state.py b/verl/third_party/sglang/parallel_state.py index 30bbea3b3..0153139a7 100644 --- a/verl/third_party/sglang/parallel_state.py +++ b/verl/third_party/sglang/parallel_state.py @@ -4,18 +4,20 @@ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Model and data parallel groups.""" + import os from typing import Optional +import sglang.srt.distributed.parallel_state as ps import torch import torch.distributed -import sglang.srt.distributed.parallel_state as ps from sglang.srt.distributed.parallel_state import ( get_pp_group, get_world_group, init_distributed_environment, init_model_parallel_group, ) + """ This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron. - We assume the Megatron tp+dp+pp world is already established before calling this function. @@ -90,12 +92,14 @@ def ensure_model_parallel_initialized( assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( "tensor parallel group already initialized, but of unexpected size: " f"{get_tensor_model_parallel_world_size()=} vs. " - f"{tensor_model_parallel_size=}") + f"{tensor_model_parallel_size=}" + ) pp_world_size = get_pp_group().world_size assert pp_world_size == pipeline_model_parallel_size, ( "pipeline parallel group already initialized, but of unexpected size: " f"{pp_world_size=} vs. 
" - f"{pipeline_model_parallel_size=}") + f"{pipeline_model_parallel_size=}" + ) # TODO(sgm): deviate from the v0.5.4, not pp now diff --git a/verl/third_party/vllm/__init__.py b/verl/third_party/vllm/__init__.py index 34d61e714..ade033510 100644 --- a/verl/third_party/vllm/__init__.py +++ b/verl/third_party/vllm/__init__.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from importlib.metadata import version, PackageNotFoundError +from importlib.metadata import PackageNotFoundError, version + from packaging import version as vs + from verl.utils.import_utils import is_sglang_available @@ -24,36 +26,27 @@ def get_version(pkg): return None -package_name = 'vllm' +package_name = "vllm" package_version = get_version(package_name) vllm_version = None -if package_version == '0.3.1': - vllm_version = '0.3.1' - from .vllm_v_0_3_1.llm import LLM - from .vllm_v_0_3_1.llm import LLMEngine +if package_version == "0.3.1": + vllm_version = "0.3.1" from .vllm_v_0_3_1 import parallel_state -elif package_version == '0.4.2': - vllm_version = '0.4.2' - from .vllm_v_0_4_2.llm import LLM - from .vllm_v_0_4_2.llm import LLMEngine + from .vllm_v_0_3_1.llm import LLM, LLMEngine +elif package_version == "0.4.2": + vllm_version = "0.4.2" from .vllm_v_0_4_2 import parallel_state -elif package_version == '0.5.4': - vllm_version = '0.5.4' - from .vllm_v_0_5_4.llm import LLM - from .vllm_v_0_5_4.llm import LLMEngine + from .vllm_v_0_4_2.llm import LLM, LLMEngine +elif package_version == "0.5.4": + vllm_version = "0.5.4" from .vllm_v_0_5_4 import parallel_state -elif package_version == '0.6.3': - vllm_version = '0.6.3' - from .vllm_v_0_6_3.llm import LLM - from .vllm_v_0_6_3.llm import LLMEngine + from .vllm_v_0_5_4.llm import LLM, LLMEngine +elif package_version == "0.6.3" or package_version == "0.6.3+rocm624": + vllm_version = "0.6.3" from .vllm_v_0_6_3 import parallel_state -elif package_version == '0.6.3+rocm624': - vllm_version = '0.6.3' - from .vllm_v_0_6_3.llm import LLM - from .vllm_v_0_6_3.llm import LLMEngine - from .vllm_v_0_6_3 import parallel_state -elif vs.parse(package_version) >= vs.parse('0.7.0'): + from .vllm_v_0_6_3.llm import LLM, LLMEngine +elif vs.parse(package_version) >= vs.parse("0.7.0"): # From 0.6.6.post2 on, vllm supports SPMD inference # See https://github.com/vllm-project/vllm/pull/12071 @@ -62,5 +55,5 @@ elif vs.parse(package_version) >= vs.parse('0.7.0'): else: if not is_sglang_available(): raise ValueError( - f'vllm version {package_version} not supported and SGLang also not Found. Currently supported vllm versions are 0.3.1, 0.4.2, 0.5.4, 0.6.3 and 0.7.0+' + f"vllm version {package_version} not supported and SGLang also not Found. 
Currently supported vllm versions are 0.3.1, 0.4.2, 0.5.4, 0.6.3 and 0.7.0+" ) diff --git a/verl/third_party/vllm/vllm_v_0_3_1/arg_utils.py b/verl/third_party/vllm/vllm_v_0_3_1/arg_utils.py index 1ae8f3b8f..8e744ee4d 100644 --- a/verl/third_party/vllm/vllm_v_0_3_1/arg_utils.py +++ b/verl/third_party/vllm/vllm_v_0_3_1/arg_utils.py @@ -15,20 +15,21 @@ import argparse import dataclasses from dataclasses import dataclass -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple -import torch.nn as nn -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig) from transformers import PretrainedConfig +from vllm.config import CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig + from .config import ModelConfig @dataclass class EngineArgs: """Arguments for vLLM engine.""" + model_hf_config: PretrainedConfig = None - dtype: str = 'auto' - kv_cache_dtype: str = 'auto' + dtype: str = "auto" + kv_cache_dtype: str = "auto" seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False @@ -45,7 +46,7 @@ class EngineArgs: revision: Optional[str] = None tokenizer_revision: Optional[str] = None quantization: Optional[str] = None - load_format: str = 'model' + load_format: str = "model" enforce_eager: bool = False max_context_len_to_capture: int = 8192 disable_custom_all_reduce: bool = False @@ -53,129 +54,149 @@ class EngineArgs: max_loras: int = 1 max_lora_rank: int = 16 lora_extra_vocab_size: int = 256 - lora_dtype = 'auto' + lora_dtype = "auto" max_cpu_loras: Optional[int] = None - device: str = 'cuda' + device: str = "cuda" @staticmethod def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: """Shared CLI arguments for vLLM engine.""" # Model arguments # TODO(shengguangming): delete the unused args - parser.add_argument('--model', - type=str, - default='facebook/opt-125m', - help='name or path of the huggingface model to use') - parser.add_argument('--tokenizer', - type=str, - default=EngineArgs.tokenizer, - help='name or path of the huggingface tokenizer to use') - parser.add_argument('--revision', - type=str, - default=None, - help='the specific model version to use. It can be a branch ' - 'name, a tag name, or a commit id. If unspecified, will use ' - 'the default version.') - parser.add_argument('--tokenizer-revision', - type=str, - default=None, - help='the specific tokenizer version to use. It can be a branch ' - 'name, a tag name, or a commit id. If unspecified, will use ' - 'the default version.') - parser.add_argument('--tokenizer-mode', - type=str, - default=EngineArgs.tokenizer_mode, - choices=['auto', 'slow'], - help='tokenizer mode. "auto" will use the fast ' - 'tokenizer if available, and "slow" will ' - 'always use the slow tokenizer.') - parser.add_argument('--trust-remote-code', action='store_true', help='trust remote code from huggingface') - parser.add_argument('--download-dir', - type=str, - default=EngineArgs.download_dir, - help='directory to download and load the weights, ' - 'default to the default cache dir of ' - 'huggingface') - parser.add_argument('--load-format', - type=str, - default=EngineArgs.load_format, - choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], - help='The format of the model weights to load. ' - '"auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available. ' - '"pt" will load the weights in the pytorch bin format. 
' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading. ' - '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.') - parser.add_argument('--dtype', - type=str, - default=EngineArgs.dtype, - choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--max-model-len', - type=int, - default=None, - help='model context length. If unspecified, ' - 'will be automatically derived from the model.') + parser.add_argument( + "--model", type=str, default="facebook/opt-125m", help="name or path of the huggingface model to use" + ) + parser.add_argument( + "--tokenizer", + type=str, + default=EngineArgs.tokenizer, + help="name or path of the huggingface tokenizer to use", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="the specific model version to use. It can be a branch " + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", + ) + parser.add_argument( + "--tokenizer-revision", + type=str, + default=None, + help="the specific tokenizer version to use. It can be a branch " + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", + ) + parser.add_argument( + "--tokenizer-mode", + type=str, + default=EngineArgs.tokenizer_mode, + choices=["auto", "slow"], + help='tokenizer mode. "auto" will use the fast ' + 'tokenizer if available, and "slow" will ' + "always use the slow tokenizer.", + ) + parser.add_argument("--trust-remote-code", action="store_true", help="trust remote code from huggingface") + parser.add_argument( + "--download-dir", + type=str, + default=EngineArgs.download_dir, + help="directory to download and load the weights, default to the default cache dir of huggingface", + ) + parser.add_argument( + "--load-format", + type=str, + default=EngineArgs.load_format, + choices=["auto", "pt", "safetensors", "npcache", "dummy"], + help="The format of the model weights to load. " + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling.", + ) + parser.add_argument( + "--dtype", + type=str, + default=EngineArgs.dtype, + choices=["auto", "half", "float16", "bfloat16", "float", "float32"], + help="data type for model weights and activations. " + 'The "auto" option will use FP16 precision ' + "for FP32 and FP16 models, and BF16 precision " + "for BF16 models.", + ) + parser.add_argument( + "--max-model-len", + type=int, + default=None, + help="model context length. 
If unspecified, will be automatically derived from the model.", + ) # Parallel arguments - parser.add_argument('--worker-use-ray', - action='store_true', - help='use Ray for distributed serving, will be ' - 'automatically set when using more than 1 GPU') - parser.add_argument('--pipeline-parallel-size', - '-pp', - type=int, - default=EngineArgs.pipeline_parallel_size, - help='number of pipeline stages') - parser.add_argument('--tensor-parallel-size', - '-tp', - type=int, - default=EngineArgs.tensor_parallel_size, - help='number of tensor parallel replicas') + parser.add_argument( + "--worker-use-ray", + action="store_true", + help="use Ray for distributed serving, will be automatically set when using more than 1 GPU", + ) + parser.add_argument( + "--pipeline-parallel-size", + "-pp", + type=int, + default=EngineArgs.pipeline_parallel_size, + help="number of pipeline stages", + ) + parser.add_argument( + "--tensor-parallel-size", + "-tp", + type=int, + default=EngineArgs.tensor_parallel_size, + help="number of tensor parallel replicas", + ) # KV cache arguments - parser.add_argument('--block-size', - type=int, - default=EngineArgs.block_size, - choices=[8, 16, 32], - help='token block size') + parser.add_argument( + "--block-size", type=int, default=EngineArgs.block_size, choices=[8, 16, 32], help="token block size" + ) # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). - parser.add_argument('--seed', type=int, default=EngineArgs.seed, help='random seed') - parser.add_argument('--swap-space', - type=int, - default=EngineArgs.swap_space, - help='CPU swap space size (GiB) per GPU') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=EngineArgs.gpu_memory_utilization, - help='the percentage of GPU memory to be used for' - 'the model executor') - parser.add_argument('--max-num-batched-tokens', - type=int, - default=EngineArgs.max_num_batched_tokens, - help='maximum number of batched tokens per ' - 'iteration') - parser.add_argument('--max-num-seqs', - type=int, - default=EngineArgs.max_num_seqs, - help='maximum number of sequences per iteration') - parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') + parser.add_argument("--seed", type=int, default=EngineArgs.seed, help="random seed") + parser.add_argument( + "--swap-space", type=int, default=EngineArgs.swap_space, help="CPU swap space size (GiB) per GPU" + ) + parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=EngineArgs.gpu_memory_utilization, + help="the percentage of GPU memory to be used for the model executor", + ) + parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=EngineArgs.max_num_batched_tokens, + help="maximum number of batched tokens per iteration", + ) + parser.add_argument( + "--max-num-seqs", + type=int, + default=EngineArgs.max_num_seqs, + help="maximum number of sequences per iteration", + ) + parser.add_argument("--disable-log-stats", action="store_true", help="disable logging statistics") # Quantization settings.
- parser.add_argument('--quantization', - '-q', - type=str, - choices=['awq', None], - default=None, - help='Method used to quantize the weights') + parser.add_argument( + "--quantization", + "-q", + type=str, + choices=["awq", None], + default=None, + help="Method used to quantize the weights", + ) return parser @classmethod - def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': + def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs": # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. @@ -186,27 +207,53 @@ class EngineArgs: self, ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: device_config = DeviceConfig(self.device) - model_config = ModelConfig(self.model_hf_config, self.dtype, self.seed, self.load_format, self.revision, - self.tokenizer_revision, self.max_model_len, self.quantization, self.enforce_eager, - self.max_context_len_to_capture) - cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, - model_config.get_sliding_window()) - parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, - self.max_parallel_loading_workers, self.disable_custom_all_reduce) - scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, - self.max_paddings) - lora_config = LoRAConfig(max_lora_rank=self.max_lora_rank, - max_loras=self.max_loras, - lora_extra_vocab_size=self.lora_extra_vocab_size, - lora_dtype=self.lora_dtype, - max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else - None) if self.enable_lora else None + model_config = ModelConfig( + self.model_hf_config, + self.dtype, + self.seed, + self.load_format, + self.revision, + self.tokenizer_revision, + self.max_model_len, + self.quantization, + self.enforce_eager, + self.max_context_len_to_capture, + ) + cache_config = CacheConfig( + self.block_size, + self.gpu_memory_utilization, + self.swap_space, + self.kv_cache_dtype, + model_config.get_sliding_window(), + ) + parallel_config = ParallelConfig( + self.pipeline_parallel_size, + self.tensor_parallel_size, + self.worker_use_ray, + self.max_parallel_loading_workers, + self.disable_custom_all_reduce, + ) + scheduler_config = SchedulerConfig( + self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, self.max_paddings + ) + lora_config = ( + LoRAConfig( + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None, + ) + if self.enable_lora + else None + ) return (model_config, cache_config, parallel_config, scheduler_config, device_config, lora_config) @dataclass class AsyncEngineArgs(EngineArgs): """Arguments for asynchronous vLLM engine.""" + engine_use_ray: bool = False disable_log_requests: bool = False max_log_len: Optional[int] = None @@ -214,15 +261,16 @@ class AsyncEngineArgs(EngineArgs): @staticmethod def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser = EngineArgs.add_cli_args(parser) - parser.add_argument('--engine-use-ray', - action='store_true', - help='use Ray to start the LLM engine in a ' - 'separate process as the server process.') - parser.add_argument('--disable-log-requests', action='store_true', help='disable logging 
requests') - parser.add_argument('--max-log-len', - type=int, - default=None, - help='max number of prompt characters or prompt ' - 'ID numbers being printed in log. ' - 'Default: unlimited.') + parser.add_argument( + "--engine-use-ray", + action="store_true", + help="use Ray to start the LLM engine in a separate process as the server process.", + ) + parser.add_argument("--disable-log-requests", action="store_true", help="disable logging requests") + parser.add_argument( + "--max-log-len", + type=int, + default=None, + help="max number of prompt characters or prompt ID numbers being printed in log. Default: unlimited.", + ) return parser diff --git a/verl/third_party/vllm/vllm_v_0_3_1/config.py b/verl/third_party/vllm/vllm_v_0_3_1/config.py index 1e1fead86..201b6b9d3 100644 --- a/verl/third_party/vllm/vllm_v_0_3_1/config.py +++ b/verl/third_party/vllm/vllm_v_0_3_1/config.py @@ -13,15 +13,14 @@ # limitations under the License. # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py -from typing import Optional, Union, ClassVar from dataclasses import dataclass -import torch -from transformers import PretrainedConfig -from packaging.version import Version +from typing import ClassVar, Optional, Union +import torch +from packaging.version import Version +from transformers import PretrainedConfig from vllm.logger import init_logger -from vllm.transformers_utils.config import get_config -from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version +from vllm.utils import get_cpu_memory, get_nvcc_cuda_version, is_hip logger = init_logger(__name__) @@ -77,7 +76,7 @@ class ModelConfig: hf_config: PretrainedConfig, dtype: str, seed: int, - load_format: str = 'model', + load_format: str = "model", revision: Optional[str] = None, tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, @@ -109,8 +108,10 @@ class ModelConfig: def _verify_load_format(self) -> None: load_format = self.load_format.lower() if load_format not in ["auto", "pt", "safetensors", "npcache", "dummy", "model"]: - raise ValueError(f"Unknown load format: {self.load_format}. Must be one of " - "'auto', 'pt', 'safetensors', 'npcache', 'dummy' or 'model'.") + raise ValueError( + f"Unknown load format: {self.load_format}. Must be one of " + "'auto', 'pt', 'safetensors', 'npcache', 'dummy' or 'model'." + ) self.load_format = load_format # def _verify_tokenizer_mode(self) -> None: @@ -134,30 +135,33 @@ class ModelConfig: if self.quantization is None: self.quantization = hf_quant_method elif self.quantization != hf_quant_method: - raise ValueError("Quantization method specified in the model config " - f"({hf_quant_method}) does not match the quantization " - f"method specified in the `quantization` argument " - f"({self.quantization}).") + raise ValueError( + "Quantization method specified in the model config " + f"({hf_quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization})." + ) if self.quantization is not None: if self.quantization not in supported_quantization: - raise ValueError(f"Unknown quantization method: {self.quantization}. Must " - f"be one of {supported_quantization}.") + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must be one of {supported_quantization}." 
+ ) if is_hip() and self.quantization in rocm_not_supported_quantization: - raise ValueError(f"{self.quantization} quantization is currently not supported " - f"in ROCm.") - logger.warning(f"{self.quantization} quantization is not fully " - "optimized yet. The speed can be slower than " - "non-quantized models.") + raise ValueError(f"{self.quantization} quantization is currently not supported in ROCm.") + logger.warning( + f"{self.quantization} quantization is not fully " + "optimized yet. The speed can be slower than " + "non-quantized models." + ) def _verify_cuda_graph(self) -> None: if self.max_context_len_to_capture is None: self.max_context_len_to_capture = self.max_model_len self.max_context_len_to_capture = min(self.max_context_len_to_capture, self.max_model_len) - if (self.quantization in ["gptq", "squeezellm"] and not self.enforce_eager): + if self.quantization in ["gptq", "squeezellm"] and not self.enforce_eager: # Related issue: https://github.com/vllm-project/vllm/issues/2147 - logger.warning(f"{self.quantization} does not support CUDA graph " - "yet. Disabling CUDA graph.") + logger.warning(f"{self.quantization} does not support CUDA graph yet. Disabling CUDA graph.") self.enforce_eager = True def verify_with_parallel_config( @@ -167,16 +171,20 @@ class ModelConfig: total_num_attention_heads = self.hf_config.num_attention_heads tensor_parallel_size = parallel_config.tensor_parallel_size if total_num_attention_heads % tensor_parallel_size != 0: - raise ValueError(f"Total number of attention heads ({total_num_attention_heads})" - " must be divisible by tensor parallel size " - f"({tensor_parallel_size}).") + raise ValueError( + f"Total number of attention heads ({total_num_attention_heads})" + " must be divisible by tensor parallel size " + f"({tensor_parallel_size})." + ) total_num_hidden_layers = self.hf_config.num_hidden_layers pipeline_parallel_size = parallel_config.pipeline_parallel_size if total_num_hidden_layers % pipeline_parallel_size != 0: - raise ValueError(f"Total number of hidden layers ({total_num_hidden_layers}) " - "must be divisible by pipeline parallel size " - f"({pipeline_parallel_size}).") + raise ValueError( + f"Total number of hidden layers ({total_num_hidden_layers}) " + "must be divisible by pipeline parallel size " + f"({pipeline_parallel_size})." + ) def get_sliding_window(self) -> Optional[int]: return getattr(self.hf_config, "sliding_window", None) @@ -198,8 +206,9 @@ class ModelConfig: # multi_query flag is ignored and we use n_head_kv for the number of # KV heads. falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] - new_decoder_arch_falcon = (self.hf_config.model_type in falcon_model_types and - getattr(self.hf_config, "new_decoder_architecture", False)) + new_decoder_arch_falcon = self.hf_config.model_type in falcon_model_types and getattr( + self.hf_config, "new_decoder_architecture", False + ) if not new_decoder_arch_falcon and getattr(self.hf_config, "multi_query", False): # Multi-query attention, only one KV head. # Currently, tensor parallelism is not supported in this case. @@ -270,8 +279,7 @@ class CacheConfig: def _verify_args(self) -> None: if self.gpu_memory_utilization > 1.0: - raise ValueError("GPU memory utilization must be less than 1.0. Got " - f"{self.gpu_memory_utilization}.") + raise ValueError(f"GPU memory utilization must be less than 1.0. 
Got {self.gpu_memory_utilization}.") def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": @@ -283,11 +291,13 @@ class CacheConfig: device_name = torch.cuda.get_device_name() if "AMD" in device_name: raise NotImplementedError("FP8_E5M2 KV Cache on AMD GPU has not been supported yet.") - logger.info("Using fp8_e5m2 data type to store kv cache. It reduces " - "the GPU memory footprint and boosts the performance. " - "But it may cause slight accuracy drop. " - "Currently we only support fp8 without scaling factors and " - "make e5m2 as a default format.") + logger.info( + "Using fp8_e5m2 data type to store kv cache. It reduces " + "the GPU memory footprint and boosts the performance. " + "But it may cause slight accuracy drop. " + "Currently we only support fp8 without scaling factors and " + "make e5m2 as a default format." + ) else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") @@ -301,9 +311,11 @@ class CacheConfig: num_gpus_per_node = parallel_config.tensor_parallel_size cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node - msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of " - f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is " - "allocated for the swap space.") + msg = ( + f"{cpu_memory_usage / _GB:.2f} GiB out of " + f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is " + "allocated for the swap space." + ) if cpu_memory_usage > 0.7 * total_cpu_memory: raise ValueError("Too large swap space. " + msg) elif cpu_memory_usage > 0.4 * total_cpu_memory: @@ -351,20 +363,22 @@ class ParallelConfig: if not self.disable_custom_all_reduce and self.world_size > 1: if is_hip(): self.disable_custom_all_reduce = True - logger.info("Disabled the custom all-reduce kernel because it is not " - "supported on AMD GPUs.") + logger.info("Disabled the custom all-reduce kernel because it is not supported on AMD GPUs.") elif self.pipeline_parallel_size > 1: self.disable_custom_all_reduce = True - logger.info("Disabled the custom all-reduce kernel because it is not " - "supported with pipeline parallelism.") + logger.info( + "Disabled the custom all-reduce kernel because it is not supported with pipeline parallelism." + ) # FIXME(woosuk): Fix the stability issues and re-enable the custom # all-reduce kernel. if not self.disable_custom_all_reduce and self.world_size > 1: self.disable_custom_all_reduce = True - logger.info("Custom all-reduce kernels are temporarily disabled due to " - "stability issues. We will re-enable them once the issues are " - "resolved.") + logger.info( + "Custom all-reduce kernels are temporarily disabled due to " + "stability issues. We will re-enable them once the issues are " + "resolved." + ) class SchedulerConfig: @@ -400,20 +414,23 @@ class SchedulerConfig: def _verify_args(self) -> None: if self.max_num_batched_tokens < self.max_model_len: - raise ValueError(f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " - f"smaller than max_model_len ({self.max_model_len}). " - "This effectively limits the maximum sequence length to " - "max_num_batched_tokens and makes vLLM reject longer " - "sequences. Please increase max_num_batched_tokens or " - "decrease max_model_len.") + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " + f"smaller than max_model_len ({self.max_model_len}). " + "This effectively limits the maximum sequence length to " + "max_num_batched_tokens and makes vLLM reject longer " + "sequences. Please increase max_num_batched_tokens or " + "decrease max_model_len." 
+ ) if self.max_num_batched_tokens < self.max_num_seqs: - raise ValueError(f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " - "be greater than or equal to max_num_seqs " - f"({self.max_num_seqs}).") + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " + "be greater than or equal to max_num_seqs " + f"({self.max_num_seqs})." + ) class DeviceConfig: - def __init__(self, device: str = "cuda") -> None: self.device = torch.device(device) @@ -433,18 +450,17 @@ class LoRAConfig: possible_max_ranks = (8, 16, 32, 64) possible_lora_extra_vocab_size = (0, 256, 512) if self.max_lora_rank not in possible_max_ranks: - raise ValueError(f"max_lora_rank ({self.max_lora_rank}) must be one of " - f"{possible_max_ranks}.") + raise ValueError(f"max_lora_rank ({self.max_lora_rank}) must be one of {possible_max_ranks}.") if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: - raise ValueError(f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " - f"must be one of {possible_lora_extra_vocab_size}.") + raise ValueError( + f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) must be one of {possible_lora_extra_vocab_size}." + ) if self.max_loras < 1: raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") if self.max_cpu_loras is None: self.max_cpu_loras = self.max_loras elif self.max_cpu_loras < self.max_loras: - raise ValueError(f"max_cpu_loras ({self.max_cpu_loras}) must be >= " - f"max_loras ({self.max_loras})") + raise ValueError(f"max_cpu_loras ({self.max_cpu_loras}) must be >= max_loras ({self.max_loras})") def verify_with_model_config(self, model_config: ModelConfig): if self.lora_dtype in (None, "auto"): @@ -456,9 +472,11 @@ class LoRAConfig: def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): if scheduler_config.max_num_batched_tokens > 65528: - raise ValueError("Due to limitations of the custom LoRA CUDA kernel, " - "max_num_batched_tokens must be <= 65528 when " - "LoRA is enabled.") + raise ValueError( + "Due to limitations of the custom LoRA CUDA kernel, " + "max_num_batched_tokens must be <= 65528 when " + "LoRA is enabled." + ) _STR_DTYPE_TO_TORCH_DTYPE = { @@ -504,8 +522,7 @@ def _get_and_verify_dtype( rocm_supported_dtypes = [ k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() if (k not in _ROCM_NOT_SUPPORTED_DTYPE) ] - raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. " - f"Supported dtypes are {rocm_supported_dtypes}") + raise ValueError(f"dtype '{dtype}' is not supported in ROCm. Supported dtypes are {rocm_supported_dtypes}") # Verify the dtype. if torch_dtype != config_dtype: @@ -552,10 +569,12 @@ def _get_and_verify_max_len( return max_model_len default_max_len = 2048 - logger.warning("The model's config.json does not contain any of the following " - "keys to determine the original maximum length of the model: " - f"{possible_keys}. Assuming the model's maximum length is " - f"{default_max_len}.") + logger.warning( + "The model's config.json does not contain any of the following " + "keys to determine the original maximum length of the model: " + f"{possible_keys}. Assuming the model's maximum length is " + f"{default_max_len}." 
+ ) derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) @@ -569,9 +588,11 @@ def _get_and_verify_max_len( if max_model_len is None: max_model_len = derived_max_model_len elif max_model_len > derived_max_model_len: - raise ValueError(f"User-specified max_model_len ({max_model_len}) is greater than " - f"the derived max_model_len ({max_len_key}={derived_max_model_len}" - " in model's config.json). This may lead to incorrect model " - "outputs or CUDA errors. Make sure the value is correct and " - "within the model context size.") + raise ValueError( + f"User-specified max_model_len ({max_model_len}) is greater than " + f"the derived max_model_len ({max_len_key}={derived_max_model_len}" + " in model's config.json). This may lead to incorrect model " + "outputs or CUDA errors. Make sure the value is correct and " + "within the model context size." + ) return int(max_model_len) diff --git a/verl/third_party/vllm/vllm_v_0_3_1/llm.py b/verl/third_party/vllm/vllm_v_0_3_1/llm.py index 8d2475998..a2fbb38f0 100644 --- a/verl/third_party/vllm/vllm_v_0_3_1/llm.py +++ b/verl/third_party/vllm/vllm_v_0_3_1/llm.py @@ -15,20 +15,21 @@ from typing import Dict, List, Optional, Tuple, Union -from tqdm import tqdm -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast -from transformers import PretrainedConfig +import torch import torch.nn as nn -from .arg_utils import EngineArgs -from .llm_engine_sp import LLMEngine +from torch.nn.utils.rnn import pad_sequence +from tqdm import tqdm +from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.utils import Counter -import torch -from torch.nn.utils.rnn import pad_sequence + from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer +from .arg_utils import EngineArgs +from .llm_engine_sp import LLMEngine + class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -86,7 +87,7 @@ class LLM: def __init__( self, - model: Union[nn.Module, Dict], # model itself or its parameter dict + model: Union[nn.Module, Dict], # model itself or its parameter dict tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer], model_hf_config: PretrainedConfig, tokenizer_mode: str = "auto", @@ -173,15 +174,13 @@ class LLM: completions in the same order as the input prompts. """ if prompts is None and prompt_token_ids is None: - raise ValueError("Either prompts or prompt_token_ids must be " - "provided.") + raise ValueError("Either prompts or prompt_token_ids must be provided.") if isinstance(prompts, str): # Convert a single prompt to a list. prompts = [prompts] if prompts is not None and prompt_token_ids is not None: if len(prompts) != len(prompt_token_ids): - raise ValueError("The lengths of prompts and prompt_token_ids " - "must be the same.") + raise ValueError("The lengths of prompts and prompt_token_ids must be the same.") if sampling_params is None: # Use default sampling params. 
sampling_params = SamplingParams() @@ -207,12 +206,9 @@ class LLM: prefix_pos: Optional[int] = None, ) -> None: request_id = str(next(self.request_counter)) - self.llm_engine.add_request(request_id, - prompt, - sampling_params, - prompt_token_ids, - lora_request=lora_request, - prefix_pos=prefix_pos) + self.llm_engine.add_request( + request_id, prompt, sampling_params, prompt_token_ids, lora_request=lora_request, prefix_pos=prefix_pos + ) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. @@ -241,7 +237,11 @@ class LLM: # TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding. def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]: # remove the left padding in the prompt token_id - pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + pad_token_id = ( + self.llm_engine.tokenizer.pad_token_id + if self.llm_engine.tokenizer.pad_token_id is not None + else self.llm_engine.tokenizer.eos_token_id + ) non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] token_ids = prompt_token_ids[non_pad_index:].tolist() return token_ids @@ -262,7 +262,11 @@ class LLM: logprob.append(logprobs_dict[id]) logprobs.append(torch.tensor(logprob)) - pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + pad_token_id = ( + self.llm_engine.tokenizer.pad_token_id + if self.llm_engine.tokenizer.pad_token_id is not None + else self.llm_engine.tokenizer.eos_token_id + ) output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id) if len(logprobs) > 0: logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id) diff --git a/verl/third_party/vllm/vllm_v_0_3_1/llm_engine_sp.py b/verl/third_party/vllm/vllm_v_0_3_1/llm_engine_sp.py index e264a8585..a0e0c09ca 100644 --- a/verl/third_party/vllm/vllm_v_0_3_1/llm_engine_sp.py +++ b/verl/third_party/vllm/vllm_v_0_3_1/llm_engine_sp.py @@ -16,21 +16,21 @@ import os import socket import time -import torch -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union -from vllm.lora.request import LoRARequest -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig) +import torch +import torch.nn as nn +from vllm.config import CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig from vllm.core.scheduler import Scheduler, SchedulerOutputs +from vllm.engine.metrics import StatLogger, Stats from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, SequenceGroupMetadata, SequenceGroupOutput, - SequenceOutput, SequenceStatus) +from vllm.sequence import SamplerOutput, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus from vllm.transformers_utils.tokenizer import detokenize_incrementally -from vllm.engine.metrics import StatLogger, Stats from vllm.utils import Counter -import torch.nn as nn + from .arg_utils import EngineArgs from .tokenizer import TokenizerGroup @@ -69,7 +69,7 @@ class LLMEngine: def __init__( self, - model: Union[nn.Module, Dict], # model itself 
or its parameter dict + model: Union[nn.Module, Dict], # model itself or its parameter dict tokenizer: nn.Module, model_config: ModelConfig, cache_config: CacheConfig, @@ -81,21 +81,23 @@ class LLMEngine: placement_group: Optional[None], log_stats: bool, ) -> None: - logger.info("Initializing an LLM engine with config: " - f"model={model_config.model!r}, " - f"tokenizer={model_config.tokenizer!r}, " - # f"tokenizer_mode={model_config.tokenizer_mode}, " - f"revision={model_config.revision}, " - f"tokenizer_revision={model_config.tokenizer_revision}, " - # f"trust_remote_code={model_config.trust_remote_code}, " - f"dtype={model_config.dtype}, " - f"max_seq_len={model_config.max_model_len}, " - # f"download_dir={model_config.download_dir!r}, " - # f"load_format={model_config.load_format}, " - f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, " - f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " - f"quantization={model_config.quantization}, " - f"seed={model_config.seed})") + logger.info( + "Initializing an LLM engine with config: " + f"model={model_config.model!r}, " + f"tokenizer={model_config.tokenizer!r}, " + # f"tokenizer_mode={model_config.tokenizer_mode}, " + f"revision={model_config.revision}, " + f"tokenizer_revision={model_config.tokenizer_revision}, " + # f"trust_remote_code={model_config.trust_remote_code}, " + f"dtype={model_config.dtype}, " + f"max_seq_len={model_config.max_model_len}, " + # f"download_dir={model_config.download_dir!r}, " + # f"load_format={model_config.load_format}, " + f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, " + f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " + f"quantization={model_config.quantization}, " + f"seed={model_config.seed})" + ) # TODO(woosuk): Print more configs in debug mode. self.model_config = model_config # TODO: currently is hfconfig @@ -136,9 +138,9 @@ class LLMEngine: self.num_generation_tokens: List[Tuple[float, int]] = [] def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs): - init_kwargs = dict(enable_lora=bool(self.lora_config), - max_num_seqs=self.scheduler_config.max_num_seqs, - max_input_length=None) + init_kwargs = dict( + enable_lora=bool(self.lora_config), max_num_seqs=self.scheduler_config.max_num_seqs, max_input_length=None + ) init_kwargs.update(tokenizer_init_kwargs) self.tokenizer: TokenizerGroup = TokenizerGroup(tokenizer, **init_kwargs) @@ -149,7 +151,7 @@ class LLMEngine: def _init_workers_sp(self, model, distributed_init_method: str): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker - from .worker import Worker # pylint: disable=import-outside-toplevel + from .worker import Worker rank = int(os.getenv("RANK")) @@ -189,21 +191,24 @@ class LLMEngine: num_cpu_blocks = num_blocks[1] # FIXME(woosuk): Change to debug log. - logger.info(f"# GPU blocks: {num_gpu_blocks}, " - f"# CPU blocks: {num_cpu_blocks}") + logger.info(f"# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}") if num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") + raise ValueError( + "No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine." 
+ ) max_seq_len = self.cache_config.block_size * num_gpu_blocks if self.model_config.max_model_len > max_seq_len: - raise ValueError(f"The model's max seq len ({self.model_config.max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine." + ) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks @@ -227,12 +232,14 @@ class LLMEngine: # Initialize the cluster. distributed_init_method, placement_group = initialize_cluster(parallel_config) # Create the LLM engine. - engine = cls(model, - tokenizer, - *engine_configs, - distributed_init_method, - placement_group, - log_stats=not engine_args.disable_log_stats) + engine = cls( + model, + tokenizer, + *engine_configs, + distributed_init_method, + placement_group, + log_stats=not engine_args.disable_log_stats, + ) return engine def add_request( @@ -291,8 +298,7 @@ class LLMEngine: >>> ... """ if lora_request is not None and not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") + raise ValueError(f"Got lora_request {lora_request} but LoRA is not enabled!") if arrival_time is None: arrival_time = time.monotonic() if prompt_token_ids is None: @@ -305,8 +311,13 @@ class LLMEngine: seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, lora_request) # Check whether the input specifies prefix - prefix = self.scheduler.prefix_pool.add_or_get_prefix(prompt_token_ids[:prefix_pos], lora_request.lora_int_id if - lora_request else 0) if prefix_pos is not None else None + prefix = ( + self.scheduler.prefix_pool.add_or_get_prefix( + prompt_token_ids[:prefix_pos], lora_request.lora_int_id if lora_request else 0 + ) + if prefix_pos is not None + else None + ) # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, arrival_time, lora_request, prefix) @@ -357,33 +368,37 @@ class LLMEngine: if early_stopping is True: return True - current_worst_score = (current_worst_seq.get_beam_search_score( - length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(current_worst_seq).eos_token_id)) + current_worst_score = current_worst_seq.get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(current_worst_seq).eos_token_id + ) if early_stopping is False: - highest_attainable_score = (best_running_seq.get_beam_search_score( - length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(best_running_seq).eos_token_id)) + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(best_running_seq).eos_token_id + ) else: assert early_stopping == "never" if length_penalty > 0.0: # If length_penalty > 0.0, beam search will prefer longer # sequences. The highest attainable score calculation is # based on the longest possible sequence length in this case. 
- max_possible_length = max(best_running_seq.get_prompt_len() + sampling_params.max_tokens, - self.scheduler_config.max_model_len) - highest_attainable_score = (best_running_seq.get_beam_search_score( + max_possible_length = max( + best_running_seq.get_prompt_len() + sampling_params.max_tokens, self.scheduler_config.max_model_len + ) + highest_attainable_score = best_running_seq.get_beam_search_score( length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(best_running_seq).eos_token_id, - seq_len=max_possible_length)) + seq_len=max_possible_length, + ) else: # Otherwise, beam search will prefer shorter sequences. The # highest attainable score calculation is based on the current # sequence length. - highest_attainable_score = (best_running_seq.get_beam_search_score( + highest_attainable_score = best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.get_tokenizer_for_seq(best_running_seq).eos_token_id)) + eos_token_id=self.get_tokenizer_for_seq(best_running_seq).eos_token_id, + ) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput) -> None: - # Process prompt logprobs prompt_logprobs = outputs.prompt_logprobs if prompt_logprobs is not None: @@ -460,9 +475,12 @@ class LLMEngine: new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs if seq.is_finished()] all_finished_seqs = existing_finished_seqs + new_finished_seqs # Sort the finished sequences by their scores. - all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id), - reverse=True) + all_finished_seqs.sort( + key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id + ), + reverse=True, + ) for seq, parent, is_new in all_finished_seqs[:beam_width]: if is_new: # A newly generated child sequence finishes and has a high @@ -486,9 +504,12 @@ class LLMEngine: # search. running_child_seqs = [(seq, parent) for seq, parent in child_seqs if not seq.is_finished()] # Sort the running sequences by their scores. - running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id), - reverse=True) + running_child_seqs.sort( + key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id + ), + reverse=True, + ) # Check if we can stop the beam search. if len(running_child_seqs) == 0: @@ -501,9 +522,9 @@ class LLMEngine: # Check the early stopping criteria best_running_seq = running_child_seqs[0][0] current_worst_seq = all_finished_seqs[beam_width - 1][0] - stop_beam_search = self._check_beam_search_early_stopping(seq_group.sampling_params.early_stopping, - seq_group.sampling_params, best_running_seq, - current_worst_seq) + stop_beam_search = self._check_beam_search_early_stopping( + seq_group.sampling_params.early_stopping, seq_group.sampling_params, best_running_seq, current_worst_seq + ) if stop_beam_search: # Stop the beam search and remove all the running sequences from @@ -562,7 +583,7 @@ class LLMEngine: # Update prefix state, now all the uncomputed prefixes are computed. 
for seq_group in scheduled_seq_groups: - if (seq_group.prefix is not None and seq_group.prefix.allocated and not seq_group.prefix.computed): + if seq_group.prefix is not None and seq_group.prefix.allocated and not seq_group.prefix.computed: seq_group.prefix.computed = True # Log stats. @@ -583,10 +604,11 @@ class LLMEngine: seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() if not scheduler_outputs.is_empty(): output = self.worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, # TODO: check this input - blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, - blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, - blocks_to_copy=scheduler_outputs.blocks_to_copy,) + seq_group_metadata_list=seq_group_metadata_list, # TODO: check this input + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + ) else: return [RequestOutput.from_seq_group(seq_group) for seq_group in scheduler_outputs.ignored_seq_groups] @@ -607,7 +629,7 @@ class LLMEngine: gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu) num_total_cpu = self.cache_config.num_cpu_blocks - cpu_cache_usage = 0. + cpu_cache_usage = 0.0 if num_total_cpu > 0: num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks() cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu) @@ -702,8 +724,7 @@ class LLMEngine: return # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) and - seq.get_last_token_id() == self.get_tokenizer_for_seq(seq).eos_token_id): + if (not sampling_params.ignore_eos) and seq.get_last_token_id() == self.get_tokenizer_for_seq(seq).eos_token_id: seq.status = SequenceStatus.FINISHED_STOPPED return @@ -711,7 +732,7 @@ class LLMEngine: if not sampling_params.include_stop_str_in_output and stop_string: # Truncate the output text so that the stop string is # not included in the output. - seq.output_text = seq.output_text[:-len(stop_string)] + seq.output_text = seq.output_text[: -len(stop_string)] def add_lora(self, lora_request: LoRARequest) -> bool: assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." diff --git a/verl/third_party/vllm/vllm_v_0_3_1/model_loader.py b/verl/third_party/vllm/vllm_v_0_3_1/model_loader.py index 450e2f4b4..fc76c0843 100644 --- a/verl/third_party/vllm/vllm_v_0_3_1/model_loader.py +++ b/verl/third_party/vllm/vllm_v_0_3_1/model_loader.py @@ -13,25 +13,33 @@ # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader """Utilities for selecting and loading models.""" + import contextlib -from typing import Dict, Type, Union +from typing import Dict, Optional, Type, Union import torch import torch.nn as nn -from transformers import PretrainedConfig, PreTrainedModel from megatron.core.tensor_parallel.utils import VocabUtility - +from transformers import PretrainedConfig, PreTrainedModel +from vllm.config import DeviceConfig, LoRAConfig +from vllm.model_executor.layers.sampler import ( + Sampler, + _apply_logits_processors, + _apply_min_p, + _apply_penalties, + _apply_top_k_top_p, + _build_sampler_output, + _get_logprobs, + _prune_hidden_states, + _sample, +) from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) +from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors +from vllm.model_executor.weight_utils import get_quant_config, initialize_dummy_weights +from vllm.sequence import SamplerOutput from .config import ModelConfig -from vllm.config import DeviceConfig, LoRAConfig from .weight_loaders import * -from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors -from vllm.sequence import SamplerOutput -from typing import Optional -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.sampler import _prune_hidden_states, _apply_logits_processors, _apply_penalties, _apply_top_k_top_p, _apply_min_p, _apply_penalties, _sample, _get_logprobs, _build_sampler_output @contextlib.contextmanager @@ -49,13 +57,14 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: model_cls = ModelRegistry.load_model_cls(arch) if model_cls is not None: return model_cls - raise ValueError(f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") + raise ValueError( + f"Model architectures {architectures} are not supported for now. 
" + f"Supported architectures: {ModelRegistry.get_supported_archs()}" + ) from vllm.model_executor.layers.linear import * -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead -from vllm.model_executor.layers.activation import ScaledActivation +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding __LAYER_WEIGHT_LOADER_REGISTRY__ = { ColumnParallelLinear: parallel_weight_loader, @@ -63,7 +72,7 @@ __LAYER_WEIGHT_LOADER_REGISTRY__ = { QKVParallelLinear: parallel_weight_loader, RowParallelLinear: parallel_weight_loader, VocabParallelEmbedding: parallel_weight_loader, - ParallelLMHead: parallel_weight_loader + ParallelLMHead: parallel_weight_loader, # "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights # "default_weight_loader": default_weight_loader } @@ -73,10 +82,10 @@ for layer_class, weight_loader in __LAYER_WEIGHT_LOADER_REGISTRY__.items(): layer_class.weight_loader = weight_loader __MODEL_WEIGHT_LOADER_REGISTRY__ = { - 'GPT2LMHeadModel': gpt2_weight_loader, - 'LlamaForCausalLM': llama_weight_loader, - 'LLaMAForCausalLM': llama_weight_loader, - 'MistralForCausalLM': mistral_weight_loader, + "GPT2LMHeadModel": gpt2_weight_loader, + "LlamaForCausalLM": llama_weight_loader, + "LLaMAForCausalLM": llama_weight_loader, + "MistralForCausalLM": mistral_weight_loader, } # FIXME(shengguangming): the vLLM vocab will pad to 64, which may incur out of bounds @@ -84,12 +93,14 @@ __MODEL_WEIGHT_LOADER_REGISTRY__ = { DEFAULT_VOCAB_PADDING_SIZE = 64 -def vocab_init(self, - num_embeddings: int, - embedding_dim: int, - params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): +def vocab_init( + self, + num_embeddings: int, + embedding_dim: int, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, +): super(VocabParallelEmbedding, self).__init__() # Keep the input dimensions. @@ -105,15 +116,18 @@ def vocab_init(self, self.tp_size = get_tensor_model_parallel_world_size() # Divide the weight matrix along the vocaburaly dimension. - self.vocab_start_index, self.vocab_end_index = (VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, get_tensor_model_parallel_rank(), self.tp_size)) - self.num_embeddings_per_partition = (self.vocab_end_index - self.vocab_start_index) + self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, get_tensor_model_parallel_rank(), self.tp_size + ) + self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index self.weight = Parameter( torch.empty( self.num_embeddings_per_partition, self.embedding_dim, # device=torch.cuda.current_device(), - dtype=params_dtype)) + dtype=params_dtype, + ) + ) set_weight_attrs(self.weight, {"parallel_dim": 0, "weight_loader": self.weight_loader}) @@ -123,34 +137,43 @@ VocabParallelEmbedding.__init__ = vocab_init def _get_model_weight_loader(arch: str): if arch in __MODEL_WEIGHT_LOADER_REGISTRY__: return __MODEL_WEIGHT_LOADER_REGISTRY__[arch] - raise ValueError(f"Model architectures {arch} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") + raise ValueError( + f"Model architectures {arch} are not supported for now. 
" + f"Supported architectures: {ModelRegistry.get_supported_archs()}" + ) -def get_model(actor_model: Union[PreTrainedModel, Dict], - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig] = None) -> nn.Module: +def get_model( + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig] = None, +) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) # Get the quantization config. linear_method = None quant_config = None if model_config.quantization is not None: - quant_config = get_quant_config(model_config.quantization, model_config.model, model_config.hf_config, - model_config.download_dir) + quant_config = get_quant_config( + model_config.quantization, model_config.model, model_config.hf_config, model_config.download_dir + ) capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] if capability < quant_config.get_min_capability(): - raise ValueError(f"The quantization method {model_config.quantization} is not " - "supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") + raise ValueError( + f"The quantization method {model_config.quantization} is not " + "supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}." + ) supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: - raise ValueError(f"{model_config.dtype} is not supported for quantization " - f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}") + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}" + ) linear_method = quant_config.get_linear_method() with _set_default_torch_dtype(model_config.dtype): @@ -165,7 +188,7 @@ def get_model(actor_model: Union[PreTrainedModel, Dict], # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) - elif model_config.load_format == 'model' or model_config.load_format == 'auto': + elif model_config.load_format == "model" or model_config.load_format == "auto": # NOTE(shengguangming) Load the weights from the actor model if isinstance(actor_model, nn.Module): load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) @@ -190,8 +213,9 @@ def load_weights(actor_weights: Dict, vllm_model: nn.Module): # as they use ray, the sampler result will only need to return to the driver node, # therefore gather is enough. However, we use SPMD instead of a central scheduler, # all_gather is required (aligned with v0.2.6) -def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, - embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: +def _get_logits( + self, hidden_states: torch.Tensor, embedding: torch.Tensor, embedding_bias: Optional[torch.Tensor] +) -> torch.Tensor: # Get the logits for the next tokens. logits = torch.matmul(hidden_states, embedding.t()) if embedding_bias is not None: @@ -199,7 +223,7 @@ def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, logits = tensor_model_parallel_all_gather(logits) # Remove paddings in vocab (if any). 
if logits is not None: - logits = logits[:, :self.org_vocab_size] + logits = logits[:, : self.org_vocab_size] return logits @@ -232,14 +256,20 @@ def forward( logits = _apply_logits_processors(logits, sampling_metadata) # Prepare sampling tensors with pinned memory to avoid blocking. - (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = SamplingTensors.from_sampling_metadata(sampling_metadata, vocab_size, logits.device, logits.dtype) + (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype + ) # Apply presence and frequency penalties. if do_penalties: - logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties) + logits = _apply_penalties( + logits, + sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties, + ) # Apply temperature scaling. # Use in-place division to avoid creating a new tensor. @@ -269,7 +299,5 @@ def forward( return _build_sampler_output(sample_results, sampling_metadata, prompt_logprobs, sample_logprobs) -from vllm.model_executor.layers.sampler import Sampler - Sampler._get_logits = _get_logits Sampler.forward = forward diff --git a/verl/third_party/vllm/vllm_v_0_3_1/model_runner.py b/verl/third_party/vllm/vllm_v_0_3_1/model_runner.py index 4acf3422d..605b44c68 100644 --- a/verl/third_party/vllm/vllm_v_0_3_1/model_runner.py +++ b/verl/third_party/vllm/vllm_v_0_3_1/model_runner.py @@ -13,23 +13,20 @@ # limitations under the License. # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py -from typing import Dict, List, Optional, Tuple, Set, Union -import contextlib -import time -import numpy as np +from typing import Dict, List, Optional, Set, Tuple, Union + import torch import torch.nn as nn - -from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig) +from vllm.config import DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import InputMetadata, SamplingMetadata from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest from vllm.utils import in_wsl -from vllm.worker.model_runner import ModelRunner, CUDAGraphRunner, _async_h2d +from vllm.worker.model_runner import CUDAGraphRunner, ModelRunner, _async_h2d from .model_loader import get_model @@ -44,10 +41,9 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] class ModelRunner(ModelRunner): - def __init__( self, - model: Union[nn.Module, Dict], # model itself or its parameter dict + model: Union[nn.Module, Dict], # model itself or its parameter dict model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, @@ -62,9 +58,9 @@ class ModelRunner(ModelRunner): # model_config can be None in tests/samplers/test_sampler.py. 
# FIXME(woosuk): This is a hack to make the tests work. Refactor this. - self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None) + self.sliding_window = model_config.get_sliding_window() if model_config is not None else None - self.device_config = (device_config if device_config is not None else DeviceConfig()) + self.device_config = device_config if device_config is not None else DeviceConfig() self.device = self.device_config.device self.model = model # this will be replaced by get_model() @@ -74,8 +70,9 @@ class ModelRunner(ModelRunner): self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_memory_pool = None # Set during graph capture. - self.max_context_len_to_capture = (self.model_config.max_context_len_to_capture - if self.model_config is not None else 0) + self.max_context_len_to_capture = ( + self.model_config.max_context_len_to_capture if self.model_config is not None else 0 + ) # When using CUDA graph, the input block tables must be padded to # max_context_len_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table @@ -88,22 +85,29 @@ class ModelRunner(ModelRunner): self.kv_cache_dtype = kv_cache_dtype def load_model(self) -> None: - self.model = get_model(actor_model=self.model, - model_config=self.model_config, - device_config=self.device_config, - lora_config=self.lora_config) + self.model = get_model( + actor_model=self.model, + model_config=self.model_config, + device_config=self.device_config, + lora_config=self.lora_config, + ) vocab_size = self.model.config.vocab_size if self.lora_config: - assert hasattr( - self.model, - "supported_lora_modules") and self.model.supported_lora_modules, "Model does not support LoRA" + assert hasattr(self.model, "supported_lora_modules") and self.model.supported_lora_modules, ( + "Model does not support LoRA" + ) assert hasattr(self.model, "embedding_modules"), "Model does not have embedding_modules" assert hasattr(self.model, "embedding_padding_modules"), "Model does not have embedding_padding_modules" self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_paddings, vocab_size, - self.lora_config, self.device, self.model.embedding_modules, self.model.embedding_padding_modules) + self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_paddings, + vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + ) self.model = self.lora_manager.create_lora_manager(self.model) def _prepare_sample( @@ -137,7 +141,8 @@ class ModelRunner(ModelRunner): if sampling_params.prompt_logprobs is not None: selected_token_indices.extend( - range(selected_token_start_idx, selected_token_start_idx + subquery_len - 1)) + range(selected_token_start_idx, selected_token_start_idx + subquery_len - 1) + ) selected_token_indices.append(selected_token_start_idx + subquery_len - 1) selected_token_start_idx += max_subquery_len else: @@ -146,13 +151,13 @@ class ModelRunner(ModelRunner): selected_token_start_idx += num_seqs categorized_sample_indices[sampling_params.sampling_type].extend( - range(categorized_sample_indices_start_idx, categorized_sample_indices_start_idx + num_seqs)) + range(categorized_sample_indices_start_idx, categorized_sample_indices_start_idx + num_seqs) + ) categorized_sample_indices_start_idx += num_seqs - selected_token_indices = 
_async_h2d(selected_token_indices, - dtype=torch.long, - target_device=self.device, - pin_memory=not self.in_wsl) + selected_token_indices = _async_h2d( + selected_token_indices, dtype=torch.long, target_device=self.device, pin_memory=not self.in_wsl + ) categorized_sample_indices = { t: _async_h2d(seq_ids, dtype=torch.int, target_device=self.device, pin_memory=not self.in_wsl) for t, seq_ids in categorized_sample_indices.items() @@ -180,11 +185,20 @@ class ModelRunner(ModelRunner): is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. if is_prompt: - (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, - lora_prompt_mapping, lora_requests) = self._prepare_prompt(seq_group_metadata_list) + ( + input_tokens, + input_positions, + input_metadata, + prompt_lens, + subquery_lens, + lora_index_mapping, + lora_prompt_mapping, + lora_requests, + ) = self._prepare_prompt(seq_group_metadata_list) else: - (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, - lora_requests) = self._prepare_decode(seq_group_metadata_list) + (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) = ( + self._prepare_decode(seq_group_metadata_list) + ) prompt_lens = [] subquery_lens = None sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens) @@ -205,8 +219,9 @@ class ModelRunner(ModelRunner): seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, - lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list) + (input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping) = ( + self.prepare_input_tensors(seq_group_metadata_list) + ) if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) @@ -265,7 +280,7 @@ class ModelRunner(ModelRunner): # number of tokens equal to max_num_batched_tokens. seqs: List[SequenceGroupMetadata] = [] for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) + seq_len = max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs) seq_data = SequenceData([0] * seq_len) seq = SequenceGroupMetadata( request_id=str(group_id), diff --git a/verl/third_party/vllm/vllm_v_0_3_1/parallel_state.py b/verl/third_party/vllm/vllm_v_0_3_1/parallel_state.py index c3b7a45c8..32648ee4f 100644 --- a/verl/third_party/vllm/vllm_v_0_3_1/parallel_state.py +++ b/verl/third_party/vllm/vllm_v_0_3_1/parallel_state.py @@ -7,8 +7,8 @@ import torch import torch.distributed - import vllm.model_executor.parallel_utils.parallel_state as ps + """ This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron. - We assume the Megatron tp+dp+pp world is already established before calling this function. @@ -24,10 +24,11 @@ _MICRO_DATA_PARALLEL_GROUP = None def initialize_model_parallel_from_megatron( - tensor_model_parallel_size=None # we set None for backward compatibility to set infer_tp = train_tp + tensor_model_parallel_size=None, # we set None for backward compatibility to set infer_tp = train_tp ) -> None: from megatron.core import parallel_state as mpu from megatron.distributed import new_group + # Get world size and rank. 
Ensure some consistencies. assert torch.distributed.is_initialized() @@ -37,10 +38,11 @@ def initialize_model_parallel_from_megatron( assert isinstance(tensor_model_parallel_size, int) # Build the tensor model-parallel groups. - assert ps._TENSOR_MODEL_PARALLEL_GROUP is None, ("tensor model parallel group is already initialized") + assert ps._TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized" - assert tensor_model_parallel_size <= mpu.get_tensor_model_parallel_world_size( - ), 'Not implemented for infer_tp > train_tp' + assert tensor_model_parallel_size <= mpu.get_tensor_model_parallel_world_size(), ( + "Not implemented for infer_tp > train_tp" + ) global _TENSOR_MODEL_PARALLEL_GROUP global _MICRO_DATA_PARALLEL_GROUP @@ -56,10 +58,10 @@ def initialize_model_parallel_from_megatron( rank = torch.distributed.get_rank() # Build the micro dp groups. - assert _MICRO_DATA_PARALLEL_GROUP is None, ("micro data parallel group is already initialized") + assert _MICRO_DATA_PARALLEL_GROUP is None, "micro data parallel group is already initialized" for i in range(num_micro_dp_groups): ranks = range(i * micro_dp_size, (i + 1) * micro_dp_size) - group = new_group(rank=rank, ranks=ranks, group_type='micro_dp') + group = new_group(rank=rank, ranks=ranks, group_type="micro_dp") if rank in ranks: _MICRO_DATA_PARALLEL_GROUP = group @@ -78,7 +80,7 @@ def initialize_model_parallel_from_megatron( train_tp = mpu.get_tensor_model_parallel_world_size() num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size - assert _TENSOR_MODEL_PARALLEL_GROUP is None, ("tensor model parallel group is already initialized") + assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized" for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): start = train_tp * i end = train_tp * (i + 1) @@ -87,7 +89,7 @@ def initialize_model_parallel_from_megatron( for i in range(len(ranks)): ranks[i] += j # group = torch.distributed.new_group(ranks) - group = new_group(rank=rank, ranks=ranks, group_type='infer_tp') + group = new_group(rank=rank, ranks=ranks, group_type="infer_tp") if rank in ranks: _TENSOR_MODEL_PARALLEL_GROUP = group ps._TENSOR_MODEL_PARALLEL_GROUP = _TENSOR_MODEL_PARALLEL_GROUP @@ -107,7 +109,7 @@ Tensor model parallel utilities def get_tensor_model_parallel_group(): """Get the tensor model parallel group the caller rank belongs to.""" - assert _TENSOR_MODEL_PARALLEL_GROUP is not None, ("tensor model parallel group is not initialized") + assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "tensor model parallel group is not initialized" return _TENSOR_MODEL_PARALLEL_GROUP diff --git a/verl/third_party/vllm/vllm_v_0_3_1/tokenizer.py b/verl/third_party/vllm/vllm_v_0_3_1/tokenizer.py index b8de24afb..b2cdc3f85 100644 --- a/verl/third_party/vllm/vllm_v_0_3_1/tokenizer.py +++ b/verl/third_party/vllm/vllm_v_0_3_1/tokenizer.py @@ -13,20 +13,20 @@ # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py -from typing import List, Optional, Tuple, Union - -from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) +from typing import List, Optional +from transformers import PreTrainedTokenizer from vllm.lora.request import LoRARequest -from vllm.utils import make_async, LRUCache from vllm.transformers_utils.tokenizers import * +from vllm.utils import LRUCache class TokenizerGroup: """A group of tokenizers that can be used for LoRA adapters.""" - def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int]): + def __init__( + self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, max_input_length: Optional[int] + ): self.enable_lora = enable_lora self.max_input_length = max_input_length self.tokenizer = tokenizer @@ -35,17 +35,15 @@ class TokenizerGroup: else: self.lora_tokenizers = None - def encode(self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + def encode( + self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None + ) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) return tokenizer.encode(prompt) - async def encode_async(self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + async def encode_async( + self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None + ) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) return tokenizer.encode(prompt) diff --git a/verl/third_party/vllm/vllm_v_0_3_1/weight_loaders.py b/verl/third_party/vllm/vllm_v_0_3_1/weight_loaders.py index 72aa26d06..ac3cd4172 100644 --- a/verl/third_party/vllm/vllm_v_0_3_1/weight_loaders.py +++ b/verl/third_party/vllm/vllm_v_0_3_1/weight_loaders.py @@ -14,6 +14,7 @@ # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models from typing import Dict + import torch import torch.nn as nn @@ -21,10 +22,14 @@ import torch.nn as nn # NOTE(shengguangming): replace the origin weight loader function in the class def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Parallel Linear weight loader.""" - assert param.size() == loaded_weight.size( - ), 'the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}'.format( - param.size(), loaded_weight.size()) - assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same" + assert param.size() == loaded_weight.size(), ( + "the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}".format( + param.size(), loaded_weight.size() + ) + ) + assert param.data.dtype == loaded_weight.data.dtype, ( + "if we want to shared weights, the data type should also be the same" + ) param.data = loaded_weight.data @@ -32,7 +37,9 @@ def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tenso def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" assert param.size() == loaded_weight.size() - assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same" + assert param.data.dtype == 
loaded_weight.data.dtype, ( + "if we want to shared weights, the data type should also be the same" + ) param.data = loaded_weight.data @@ -67,11 +74,11 @@ def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: def llama_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: # NOTE(shengguangming): the megatron llama may have this prefix - prefix = '0.module.module.' + prefix = "0.module.module." params_dict = dict(vllm_model.named_parameters()) for name, loaded_weight in actor_weights.items(): - if name[:len(prefix)] == prefix: - name = name[len(prefix):] + if name[: len(prefix)] == prefix: + name = name[len(prefix) :] if "rotary_emb.inv_freq" in name: continue else: @@ -82,11 +89,11 @@ def llama_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module def mistral_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: # TODO: need to implement a general way to deal with prefix - prefix = '0.module.module.' + prefix = "0.module.module." params_dict = dict(vllm_model.named_parameters()) for name, loaded_weight in actor_weights.items(): - if name[:len(prefix)] == prefix: - name = name[len(prefix):] + if name[: len(prefix)] == prefix: + name = name[len(prefix) :] if "rotary_emb.inv_freq" in name: continue else: diff --git a/verl/third_party/vllm/vllm_v_0_3_1/worker.py b/verl/third_party/vllm/vllm_v_0_3_1/worker.py index 50eebd70b..2725d79ff 100644 --- a/verl/third_party/vllm/vllm_v_0_3_1/worker.py +++ b/verl/third_party/vllm/vllm_v_0_3_1/worker.py @@ -13,27 +13,25 @@ # limitations under the License. # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py """A GPU worker class.""" -import os + import gc -from typing import Dict, List, Tuple, Optional, Union, Set +import os +from typing import Dict, List, Optional, Set, Tuple, Union import torch import torch.distributed import torch.nn as nn - -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig) -from vllm.model_executor import InputMetadata, set_random_seed -from vllm.model_executor.parallel_utils.parallel_state import (initialize_model_parallel) -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.worker.cache_engine import CacheEngine -from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar -from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_group - -from .model_runner import ModelRunner -from .model_loader import load_weights -from .parallel_state import initialize_model_parallel_from_megatron +from vllm.config import CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig from vllm.lora.request import LoRARequest +from vllm.model_executor import set_random_seed +from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar +from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_group, initialize_model_parallel +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.worker.cache_engine import CacheEngine + +from .model_loader import load_weights +from .model_runner import ModelRunner +from .parallel_state import initialize_model_parallel_from_megatron class Worker: @@ -46,7 +44,7 @@ class Worker: def __init__( self, - model: Union[nn.Module, Dict], # model itself or its parameter dict + model: Union[nn.Module, Dict], # model 
itself or its parameter dict model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, @@ -140,8 +138,9 @@ class Worker: free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() peak_memory = total_gpu_memory - free_gpu_memory - cache_block_size = CacheEngine.get_cache_block_size(block_size, cache_dtype, self.model_config, - self.parallel_config) + cache_block_size = CacheEngine.get_cache_block_size( + block_size, cache_dtype, self.model_config, self.parallel_config + ) # NOTE(sgm) use the remaining memory num_gpu_blocks = int((free_gpu_memory * gpu_memory_utilization) // cache_block_size) # num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size) @@ -153,14 +152,14 @@ class Worker: gc.collect() torch.cuda.empty_cache() # Synchronize number of blocks with all the rank - num_gpu_blocks = torch.tensor([num_gpu_blocks], device='cuda') - num_cpu_blocks = torch.tensor([num_cpu_blocks], device='cuda') - torch.distributed.all_reduce(num_gpu_blocks, - op=torch.distributed.ReduceOp.MIN, - group=get_tensor_model_parallel_group()) - torch.distributed.all_reduce(num_cpu_blocks, - op=torch.distributed.ReduceOp.MIN, - group=get_tensor_model_parallel_group()) + num_gpu_blocks = torch.tensor([num_gpu_blocks], device="cuda") + num_cpu_blocks = torch.tensor([num_cpu_blocks], device="cuda") + torch.distributed.all_reduce( + num_gpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group() + ) + torch.distributed.all_reduce( + num_cpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group() + ) num_gpu_blocks = num_gpu_blocks.item() num_cpu_blocks = num_cpu_blocks.item() return num_gpu_blocks, num_cpu_blocks @@ -251,7 +250,7 @@ class Worker: if self.cpu_model == None: self.cpu_model = {} for name, params in self.model_runner.model.named_parameters(): - self.cpu_model[name] = torch.empty_like(params, device='cpu') + self.cpu_model[name] = torch.empty_like(params, device="cpu") params.data = self.cpu_model[name] else: for name, params in self.model_runner.model.named_parameters(): @@ -274,10 +273,9 @@ def _init_distributed_environment( ) -> None: """Initialize the distributed environment.""" if torch.distributed.is_initialized(): - print('The distributed environment has been initialized before vLLM') + print("The distributed environment has been initialized before vLLM") elif not distributed_init_method: - raise ValueError("distributed_init_method must be set if torch.distributed " - "is not already initialized") + raise ValueError("distributed_init_method must be set if torch.distributed is not already initialized") else: torch.distributed.init_process_group( backend="nccl", @@ -309,6 +307,8 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): compute_capability = torch.cuda.get_device_capability() if compute_capability[0] < 8: gpu_name = torch.cuda.get_device_name() - raise ValueError("Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU has compute capability " - f"{compute_capability[0]}.{compute_capability[1]}.") + raise ValueError( + "Bfloat16 is only supported on GPUs with compute capability " + f"of at least 8.0. Your {gpu_name} GPU has compute capability " + f"{compute_capability[0]}.{compute_capability[1]}." 
+ ) diff --git a/verl/third_party/vllm/vllm_v_0_4_2/arg_utils.py b/verl/third_party/vllm/vllm_v_0_4_2/arg_utils.py index 089bbd748..7ae98c862 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/arg_utils.py +++ b/verl/third_party/vllm/vllm_v_0_4_2/arg_utils.py @@ -13,22 +13,29 @@ # limitations under the License. # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py -import os import argparse import dataclasses +import os from dataclasses import dataclass from typing import List, Optional, Union -import torch.nn as nn - from transformers import PretrainedConfig -from .config import ModelConfig, LoadConfig - -from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoRAConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, TokenizerPoolConfig, VisionLanguageConfig) -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.config import ( + CacheConfig, + DecodingConfig, + DeviceConfig, + EngineConfig, + LoRAConfig, + ParallelConfig, + SchedulerConfig, + SpeculativeConfig, + TokenizerPoolConfig, + VisionLanguageConfig, +) from vllm.utils import str_to_int_tuple +from .config import LoadConfig, ModelConfig + def nullable_str(val: str): if not val or val == "None": @@ -39,13 +46,14 @@ def nullable_str(val: str): @dataclass class EngineArgs: """Arguments for vLLM engine.""" + model_hf_config: PretrainedConfig = None skip_tokenizer_init: bool = False served_model_name: Optional[Union[str, List[str]]] = None # TODO download_dir: Optional[str] = None - load_format: str = 'auto' - dtype: str = 'auto' - kv_cache_dtype: str = 'auto' + load_format: str = "auto" + dtype: str = "auto" + kv_cache_dtype: str = "auto" quantization_param_path: Optional[str] = None seed: int = 0 max_model_len: Optional[int] = None @@ -78,9 +86,9 @@ class EngineArgs: max_lora_rank: int = 16 fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 - lora_dtype = 'auto' + lora_dtype = "auto" max_cpu_loras: Optional[int] = None - device: str = 'auto' + device: str = "auto" ray_workers_use_nsight: bool = False num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 @@ -94,7 +102,7 @@ class EngineArgs: scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False - guided_decoding_backend: str = 'outlines' + guided_decoding_backend: str = "outlines" # Speculative decoding configuration. speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None @@ -107,120 +115,140 @@ class EngineArgs: """Shared CLI arguments for vLLM engine.""" # Model arguments # TODO(shengguangming): delete the unused args - parser.add_argument('--model', - type=str, - default='facebook/opt-125m', - help='name or path of the huggingface model to use') - parser.add_argument('--tokenizer', - type=str, - default=EngineArgs.tokenizer, - help='name or path of the huggingface tokenizer to use') - parser.add_argument('--revision', - type=str, - default=None, - help='the specific model version to use. It can be a branch ' - 'name, a tag name, or a commit id. If unspecified, will use ' - 'the default version.') - parser.add_argument('--tokenizer-revision', - type=str, - default=None, - help='the specific tokenizer version to use. It can be a branch ' - 'name, a tag name, or a commit id. If unspecified, will use ' - 'the default version.') - parser.add_argument('--tokenizer-mode', - type=str, - default=EngineArgs.tokenizer_mode, - choices=['auto', 'slow'], - help='tokenizer mode. 
"auto" will use the fast ' - 'tokenizer if available, and "slow" will ' - 'always use the slow tokenizer.') - parser.add_argument('--trust-remote-code', action='store_true', help='trust remote code from huggingface') - parser.add_argument('--download-dir', - type=str, - default=EngineArgs.download_dir, - help='directory to download and load the weights, ' - 'default to the default cache dir of ' - 'huggingface') - parser.add_argument('--load-format', - type=str, - default=EngineArgs.load_format, - choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], - help='The format of the model weights to load. ' - '"auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available. ' - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading. ' - '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.') - parser.add_argument('--dtype', - type=str, - default=EngineArgs.dtype, - choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--max-model-len', - type=int, - default=None, - help='model context length. If unspecified, ' - 'will be automatically derived from the model.') + parser.add_argument( + "--model", type=str, default="facebook/opt-125m", help="name or path of the huggingface model to use" + ) + parser.add_argument( + "--tokenizer", + type=str, + default=EngineArgs.tokenizer, + help="name or path of the huggingface tokenizer to use", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="the specific model version to use. It can be a branch " + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", + ) + parser.add_argument( + "--tokenizer-revision", + type=str, + default=None, + help="the specific tokenizer version to use. It can be a branch " + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", + ) + parser.add_argument( + "--tokenizer-mode", + type=str, + default=EngineArgs.tokenizer_mode, + choices=["auto", "slow"], + help='tokenizer mode. "auto" will use the fast ' + 'tokenizer if available, and "slow" will ' + "always use the slow tokenizer.", + ) + parser.add_argument("--trust-remote-code", action="store_true", help="trust remote code from huggingface") + parser.add_argument( + "--download-dir", + type=str, + default=EngineArgs.download_dir, + help="directory to download and load the weights, default to the default cache dir of huggingface", + ) + parser.add_argument( + "--load-format", + type=str, + default=EngineArgs.load_format, + choices=["auto", "pt", "safetensors", "npcache", "dummy"], + help="The format of the model weights to load. " + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. 
" + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling.", + ) + parser.add_argument( + "--dtype", + type=str, + default=EngineArgs.dtype, + choices=["auto", "half", "float16", "bfloat16", "float", "float32"], + help="data type for model weights and activations. " + 'The "auto" option will use FP16 precision ' + "for FP32 and FP16 models, and BF16 precision " + "for BF16 models.", + ) + parser.add_argument( + "--max-model-len", + type=int, + default=None, + help="model context length. If unspecified, will be automatically derived from the model.", + ) # Parallel arguments - parser.add_argument('--worker-use-ray', - action='store_true', - help='use Ray for distributed serving, will be ' - 'automatically set when using more than 1 GPU') - parser.add_argument('--pipeline-parallel-size', - '-pp', - type=int, - default=EngineArgs.pipeline_parallel_size, - help='number of pipeline stages') - parser.add_argument('--tensor-parallel-size', - '-tp', - type=int, - default=EngineArgs.tensor_parallel_size, - help='number of tensor parallel replicas') + parser.add_argument( + "--worker-use-ray", + action="store_true", + help="use Ray for distributed serving, will be automatically set when using more than 1 GPU", + ) + parser.add_argument( + "--pipeline-parallel-size", + "-pp", + type=int, + default=EngineArgs.pipeline_parallel_size, + help="number of pipeline stages", + ) + parser.add_argument( + "--tensor-parallel-size", + "-tp", + type=int, + default=EngineArgs.tensor_parallel_size, + help="number of tensor parallel replicas", + ) # KV cache arguments - parser.add_argument('--block-size', - type=int, - default=EngineArgs.block_size, - choices=[8, 16, 32], - help='token block size') + parser.add_argument( + "--block-size", type=int, default=EngineArgs.block_size, choices=[8, 16, 32], help="token block size" + ) # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). 
- parser.add_argument('--seed', type=int, default=EngineArgs.seed, help='random seed') - parser.add_argument('--swap-space', - type=int, - default=EngineArgs.swap_space, - help='CPU swap space size (GiB) per GPU') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=EngineArgs.gpu_memory_utilization, - help='the percentage of GPU memory to be used for' - 'the model executor') - parser.add_argument('--max-num-batched-tokens', - type=int, - default=EngineArgs.max_num_batched_tokens, - help='maximum number of batched tokens per ' - 'iteration') - parser.add_argument('--max-num-seqs', - type=int, - default=EngineArgs.max_num_seqs, - help='maximum number of sequences per iteration') - parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') + parser.add_argument("--seed", type=int, default=EngineArgs.seed, help="random seed") + parser.add_argument( + "--swap-space", type=int, default=EngineArgs.swap_space, help="CPU swap space size (GiB) per GPU" + ) + parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=EngineArgs.gpu_memory_utilization, + help="the percentage of GPU memory to be used forthe model executor", + ) + parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=EngineArgs.max_num_batched_tokens, + help="maximum number of batched tokens per iteration", + ) + parser.add_argument( + "--max-num-seqs", + type=int, + default=EngineArgs.max_num_seqs, + help="maximum number of sequences per iteration", + ) + parser.add_argument("--disable-log-stats", action="store_true", help="disable logging statistics") # Quantization settings. - parser.add_argument('--quantization', - '-q', - type=str, - choices=['awq', None], - default=None, - help='Method used to quantize the weights') + parser.add_argument( + "--quantization", + "-q", + type=str, + choices=["awq", None], + default=None, + help="Method used to quantize the weights", + ) return parser @classmethod - def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': + def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs": # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. 
@@ -232,22 +260,45 @@ class EngineArgs: ) -> EngineConfig: device_config = DeviceConfig(self.device) # NOTE(sgm): we only modify ModelConfig, other configs are import from vllm - model_config = ModelConfig(self.model_hf_config, self.dtype, self.seed, self.revision, self.code_revision, - self.tokenizer_revision, self.max_model_len, self.quantization, - self.quantization_param_path, self.enforce_eager, self.max_context_len_to_capture, - self.max_seq_len_to_capture, self.max_logprobs, self.skip_tokenizer_init, - self.served_model_name) - cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, - self.swap_space, self.kv_cache_dtype, self.num_gpu_blocks_override, - model_config.get_sliding_window(), self.enable_prefix_caching) + model_config = ModelConfig( + self.model_hf_config, + self.dtype, + self.seed, + self.revision, + self.code_revision, + self.tokenizer_revision, + self.max_model_len, + self.quantization, + self.quantization_param_path, + self.enforce_eager, + self.max_context_len_to_capture, + self.max_seq_len_to_capture, + self.max_logprobs, + self.skip_tokenizer_init, + self.served_model_name, + ) + cache_config = CacheConfig( + self.block_size, + self.gpu_memory_utilization, + self.swap_space, + self.kv_cache_dtype, + self.num_gpu_blocks_override, + model_config.get_sliding_window(), + self.enable_prefix_caching, + ) parallel_config = ParallelConfig( - self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, - self.max_parallel_loading_workers, self.disable_custom_all_reduce, + self.pipeline_parallel_size, + self.tensor_parallel_size, + self.worker_use_ray, + self.max_parallel_loading_workers, + self.disable_custom_all_reduce, TokenizerPoolConfig.create_config( self.tokenizer_pool_size, self.tokenizer_pool_type, self.tokenizer_pool_extra_config, - ), self.ray_workers_use_nsight) + ), + self.ray_workers_use_nsight, + ) # Use the world_size set by TORCHRUN world_size = int(os.getenv("WORLD_SIZE", "-1")) @@ -273,19 +324,25 @@ class EngineArgs: self.max_num_seqs, model_config.max_model_len, self.use_v2_block_manager, - num_lookahead_slots=(self.num_lookahead_slots - if speculative_config is None else speculative_config.num_lookahead_slots), + num_lookahead_slots=( + self.num_lookahead_slots if speculative_config is None else speculative_config.num_lookahead_slots + ), delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, ) - lora_config = LoRAConfig(max_lora_rank=self.max_lora_rank, - max_loras=self.max_loras, - fully_sharded_loras=self.fully_sharded_loras, - lora_extra_vocab_size=self.lora_extra_vocab_size, - lora_dtype=self.lora_dtype, - max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else - None) if self.enable_lora else None + lora_config = ( + LoRAConfig( + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + fully_sharded_loras=self.fully_sharded_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None, + ) + if self.enable_lora + else None + ) load_config = LoadConfig( load_format=self.load_format, @@ -294,9 +351,11 @@ class EngineArgs: ) if self.image_input_type: - if (not self.image_token_id or not self.image_input_shape or not self.image_feature_size): - raise ValueError('Specify `image_token_id`, `image_input_shape` and ' - '`image_feature_size` together with `image_input_type`.') + if not self.image_token_id or not 
self.image_input_shape or not self.image_feature_size: + raise ValueError( + "Specify `image_token_id`, `image_input_shape` and " + "`image_feature_size` together with `image_input_type`." + ) vision_language_config = VisionLanguageConfig( image_input_type=VisionLanguageConfig.get_image_input_enum_type(self.image_input_type), image_token_id=self.image_token_id, @@ -308,13 +367,15 @@ class EngineArgs: decoding_config = DecodingConfig(guided_decoding_backend=self.guided_decoding_backend) - return EngineConfig(model_config=model_config, - cache_config=cache_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - lora_config=lora_config, - vision_language_config=vision_language_config, - speculative_config=speculative_config, - load_config=load_config, - decoding_config=decoding_config) + return EngineConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + vision_language_config=vision_language_config, + speculative_config=speculative_config, + load_config=load_config, + decoding_config=decoding_config, + ) diff --git a/verl/third_party/vllm/vllm_v_0_4_2/config.py b/verl/third_party/vllm/vllm_v_0_4_2/config.py index 6af04417b..3a77584c2 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/config.py +++ b/verl/third_party/vllm/vllm_v_0_4_2/config.py @@ -15,17 +15,18 @@ import enum import json +from dataclasses import dataclass, field from typing import List, Optional, Union -from dataclasses import dataclass, field, fields from transformers import PretrainedConfig -from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import get_quantization_config -from vllm.transformers_utils.config import get_hf_text_config -from vllm.utils import is_hip # Add for verl from vllm.config import ModelConfig, _get_and_verify_dtype, _get_and_verify_max_len +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import get_quantization_config +from vllm.model_executor.model_loader import BaseModelLoader +from vllm.transformers_utils.config import get_hf_text_config +from vllm.utils import is_hip GPTQMarlinConfig = get_quantization_config("gptq_marlin") @@ -90,8 +91,8 @@ class ModelConfig(ModelConfig): skip_tokenizer_init: If true, skip initialization of tokenizer and detokenizer. served_model_name: The model name used in metrics tag `model_name`, - matches the model name exposed via the APIs. If multiple model - names provided, the first name will be used. If not specified, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, the model name will be the same as `model`. """ @@ -124,9 +125,8 @@ class ModelConfig(ModelConfig): self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture if self.max_context_len_to_capture is not None: - raise ValueError("`max_context_len_to_capture` is deprecated. " - "Use `max_seq_len_to_capture` instead.") - self.max_seq_len_to_capture = (max_seq_len_to_capture or max_context_len_to_capture) + raise ValueError("`max_context_len_to_capture` is deprecated. 
Use `max_seq_len_to_capture` instead.") + self.max_seq_len_to_capture = max_seq_len_to_capture or max_context_len_to_capture self.max_logprobs = max_logprobs self.skip_tokenizer_init = skip_tokenizer_init @@ -145,35 +145,35 @@ class ModelConfig(ModelConfig): class LoadFormat(str, enum.Enum): - AUTO = 'auto' + AUTO = "auto" MEGATRON = "megatron" HF = "hf" - DTENSOR = 'dtensor' - DUMMY_HF = 'dummy_hf' - DUMMY_MEGATRON = 'dummy_megatron' - DUMMY_DTENSOR = 'dummy_dtensor' + DTENSOR = "dtensor" + DUMMY_HF = "dummy_hf" + DUMMY_MEGATRON = "dummy_megatron" + DUMMY_DTENSOR = "dummy_dtensor" @dataclass class LoadConfig: """ - download_dir: Directory to download and load the weights, default to the - default cache directory of huggingface. - load_format: The format of the model weights to load: - "auto" will try to load the weights in the safetensors format and - fall back to the pytorch bin format if safetensors format is - not available. - "pt" will load the weights in the pytorch bin format. - "safetensors" will load the weights in the safetensors format. - "npcache" will load the weights in pytorch format and store - a numpy cache to speed up the loading. - "dummy" will initialize the weights with random values, which is - mainly for profiling. - "tensorizer" will use CoreWeave's tensorizer library for - fast weight loading. + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. """ - load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO + load_format: Union[str, LoadFormat, BaseModelLoader] = LoadFormat.AUTO download_dir: Optional[str] = None model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict) @@ -195,6 +195,8 @@ class LoadConfig: rocm_supported_load_format = [ f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format) ] - raise ValueError(f"load format '{load_format}' is not supported in ROCm. " - f"Supported load formats are " - f"{rocm_supported_load_format}") + raise ValueError( + f"load format '{load_format}' is not supported in ROCm. " + f"Supported load formats are " + f"{rocm_supported_load_format}" + ) diff --git a/verl/third_party/vllm/vllm_v_0_4_2/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_4_2/dtensor_weight_loaders.py index 6668b7509..186f74616 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/dtensor_weight_loaders.py +++ b/verl/third_party/vllm/vllm_v_0_4_2/dtensor_weight_loaders.py @@ -13,13 +13,11 @@ # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models -from typing import Dict, Iterable, Tuple -import torch -import torch.nn as nn -from torch.distributed._tensor import DTensor, Shard, Replicate +from typing import Dict +import torch.nn as nn +from torch.distributed._tensor import DTensor from vllm.model_executor.layers.linear import * -from vllm.model_executor.models import ModelRegistry from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -35,7 +33,7 @@ def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> n params_dict = dict(vllm_model.named_parameters()) for name, loaded_weight in actor_weights.items(): - for (param_name, shard_name, shard_id) in stacked_params_mapping: + for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue stacked_name = name.replace(shard_name, param_name) @@ -99,7 +97,7 @@ def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -130,7 +128,7 @@ def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> n for name, loaded_weight in actor_weights.items(): if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue @@ -139,7 +137,7 @@ def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> n # processed with quantization, LoRA, fine-tuning, etc. if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -176,7 +174,7 @@ def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> n continue if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -205,11 +203,13 @@ def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None): param_name = _process_parameter_names(name=param_name) if parallelize_plan is not None: - assert param_name in parallelize_plan.keys(), \ + assert param_name in parallelize_plan, ( f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}" + ) placement = parallelize_plan[param_name] - local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh, - placements=placement).to_local() + local_loaded_weights = loaded_weights.redistribute( + device_mesh=loaded_weights.device_mesh, placements=placement + ).to_local() else: local_loaded_weights = loaded_weights.full_tensor() return local_loaded_weights @@ -222,9 +222,9 @@ def _process_parameter_names(name): # Remove 'model.layers.x.' or 'model.' 
prefix if "model.layers" in name: - parts = name.split('.') + parts = name.split(".") # Reconstruct the string without 'model.layers.x.' - name = '.'.join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x' + name = ".".join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x' elif name.startswith("model."): name = name[6:] # Remove 'model.' @@ -232,18 +232,18 @@ def _process_parameter_names(name): __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = { - 'GPT2LMHeadModel': gpt2_dtensor_weight_loader, - 'LlamaForCausalLM': llama_dtensor_weight_loader, - 'LLaMAForCausalLM': llama_dtensor_weight_loader, - 'MistralForCausalLM': llama_dtensor_weight_loader, # mistral is the same as llama in vLLM - 'InternLMForCausalLM': llama_dtensor_weight_loader, - 'AquilaModel': llama_dtensor_weight_loader, - 'AquilaForCausalLM': llama_dtensor_weight_loader, - 'Phi3ForCausalLM': llama_dtensor_weight_loader, - 'GemmaForCausalLM': gemma_dtensor_weight_loader, - 'GPTBigCodeForCausalLM': gptbigcode_dtensor_load_weights, - 'Starcoder2ForCausalLM': starcoder2_dtensor_load_weights, - 'Qwen2ForCausalLM': qwen2_dtensor_weight_loader + "GPT2LMHeadModel": gpt2_dtensor_weight_loader, + "LlamaForCausalLM": llama_dtensor_weight_loader, + "LLaMAForCausalLM": llama_dtensor_weight_loader, + "MistralForCausalLM": llama_dtensor_weight_loader, # mistral is the same as llama in vLLM + "InternLMForCausalLM": llama_dtensor_weight_loader, + "AquilaModel": llama_dtensor_weight_loader, + "AquilaForCausalLM": llama_dtensor_weight_loader, + "Phi3ForCausalLM": llama_dtensor_weight_loader, + "GemmaForCausalLM": gemma_dtensor_weight_loader, + "GPTBigCodeForCausalLM": gptbigcode_dtensor_load_weights, + "Starcoder2ForCausalLM": starcoder2_dtensor_load_weights, + "Qwen2ForCausalLM": qwen2_dtensor_weight_loader, } @@ -260,8 +260,10 @@ def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module): def _get_model_weight_loader(arch: str): if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__: return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch] - raise ValueError(f"Model architectures {arch} are not supported for now. " - f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}") + raise ValueError( + f"Model architectures {arch} are not supported for now. " + f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}" + ) # NOTE(sgm): we use per-parameter weight loader in each vllm sub diff --git a/verl/third_party/vllm/vllm_v_0_4_2/hf_weight_loader.py b/verl/third_party/vllm/vllm_v_0_4_2/hf_weight_loader.py index 0d562e596..fcbd525f2 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/hf_weight_loader.py +++ b/verl/third_party/vllm/vllm_v_0_4_2/hf_weight_loader.py @@ -13,17 +13,17 @@ # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models -from typing import Dict, Union, Optional, Iterable, Tuple +from typing import Dict, Iterable, Tuple import torch import torch.nn as nn - from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader def update_hf_weight_loader(): from vllm.model_executor.models.gemma import GemmaForCausalLM + GemmaForCausalLM.load_weights = gemma_load_weights @@ -39,7 +39,7 @@ def gemma_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) loaded_params = set() for name, loaded_weight in weights: - for (param_name, shard_name, shard_id) in stacked_params_mapping: + for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue name = name.replace(shard_name, param_name) @@ -72,8 +72,7 @@ def gemma_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_params.add(name) unloaded_params = params_dict.keys() - loaded_params if unloaded_params: - raise RuntimeError("Some weights are not initialized from checkpoints: " - f"{unloaded_params}") + raise RuntimeError(f"Some weights are not initialized from checkpoints: {unloaded_params}") def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module): diff --git a/verl/third_party/vllm/vllm_v_0_4_2/llm.py b/verl/third_party/vllm/vllm_v_0_4_2/llm.py index 9e5970697..17622cc0a 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/llm.py +++ b/verl/third_party/vllm/vllm_v_0_4_2/llm.py @@ -13,24 +13,24 @@ # limitations under the License. # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py -from typing import Dict, List, Optional, Tuple, Union, Iterable +from typing import Dict, Iterable, List, Optional, Tuple, Union -from tqdm import tqdm -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast -from transformers import PretrainedConfig +import torch import torch.nn as nn -from .arg_utils import EngineArgs -from .llm_engine_sp import LLMEngine +from torch.nn.utils.rnn import pad_sequence +from tqdm import tqdm +from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import MultiModalData -from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter -import torch -from torch.nn.utils.rnn import pad_sequence + from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer +from .arg_utils import EngineArgs +from .llm_engine_sp import LLMEngine + class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -88,7 +88,7 @@ class LLM: def __init__( self, - model: Union[nn.Module, Dict], # model itself or its parameter dict + model: Union[nn.Module, Dict], # model itself or its parameter dict tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer], model_hf_config: PretrainedConfig, tokenizer_mode: str = "auto", @@ -104,7 +104,7 @@ class LLM: enforce_eager: bool = False, max_context_len_to_capture: int = None, disable_custom_all_reduce: bool = False, - load_format = 'auto', + load_format="auto", **kwargs, ) -> None: if "disable_log_stats" not in kwargs: @@ -167,9 +167,9 @@ class LLM: Args: prompts: A list of prompts to generate completions for. 
sampling_params: The sampling parameters for text generation. If - None, we use the default sampling parameters. - When it is a single value, it is applied to every prompt. - When it is a list, the list must have the same length as the + None, we use the default sampling parameters. + When it is a single value, it is applied to every prompt. + When it is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. prompt_token_ids: A list of token IDs for the prompts. If None, we use the tokenizer to convert the prompts to token IDs. @@ -182,18 +182,14 @@ class LLM: completions in the same order as the input prompts. """ if prompts is None and prompt_token_ids is None: - raise ValueError("Either prompts or prompt_token_ids must be " - "provided.") - if self.llm_engine.model_config.skip_tokenizer_init \ - and prompts is not None: - raise ValueError("prompts must be None if skip_tokenizer_init " - "is True") + raise ValueError("Either prompts or prompt_token_ids must be provided.") + if self.llm_engine.model_config.skip_tokenizer_init and prompts is not None: + raise ValueError("prompts must be None if skip_tokenizer_init is True") if isinstance(prompts, str): # Convert a single prompt to a list. prompts = [prompts] - if (prompts is not None and prompt_token_ids is not None and len(prompts) != len(prompt_token_ids)): - raise ValueError("The lengths of prompts and prompt_token_ids " - "must be the same.") + if prompts is not None and prompt_token_ids is not None and len(prompts) != len(prompt_token_ids): + raise ValueError("The lengths of prompts and prompt_token_ids must be the same.") if prompts is not None: num_requests = len(prompts) @@ -206,8 +202,7 @@ class LLM: sampling_params = SamplingParams() elif isinstance(sampling_params, list) and len(sampling_params) != num_requests: - raise ValueError("The lengths of prompts and sampling_params " - "must be the same.") + raise ValueError("The lengths of prompts and sampling_params must be the same.") if multi_modal_data: multi_modal_data.data = multi_modal_data.data.to(torch.float16) @@ -225,7 +220,8 @@ class LLM: lora_request=lora_request, # Get ith image while maintaining the batch dim. multi_modal_data=MultiModalData(type=multi_modal_data.type, data=multi_modal_data.data[i].unsqueeze(0)) - if multi_modal_data else None, + if multi_modal_data + else None, ) return self._run_engine(use_tqdm) @@ -238,12 +234,14 @@ class LLM: multi_modal_data: Optional[MultiModalData] = None, ) -> None: request_id = str(next(self.request_counter)) - self.llm_engine.add_request(request_id, - prompt, - sampling_params, - prompt_token_ids, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + self.llm_engine.add_request( + request_id, + prompt, + sampling_params, + prompt_token_ids, + lora_request=lora_request, + multi_modal_data=multi_modal_data, + ) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. @@ -272,7 +270,11 @@ class LLM: # TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding. 
def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]: # remove the left padding in the prompt token_id - pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + pad_token_id = ( + self.llm_engine.tokenizer.pad_token_id + if self.llm_engine.tokenizer.pad_token_id is not None + else self.llm_engine.tokenizer.eos_token_id + ) non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] token_ids = prompt_token_ids[non_pad_index:].tolist() return token_ids @@ -293,7 +295,11 @@ class LLM: logprob.append(logprobs_dict[id].logprob) logprobs.append(torch.tensor(logprob)) - pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + pad_token_id = ( + self.llm_engine.tokenizer.pad_token_id + if self.llm_engine.tokenizer.pad_token_id is not None + else self.llm_engine.tokenizer.eos_token_id + ) output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id) if len(logprobs) > 0: logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id) diff --git a/verl/third_party/vllm/vllm_v_0_4_2/llm_engine_sp.py b/verl/third_party/vllm/vllm_v_0_4_2/llm_engine_sp.py index 2471ce168..21ac7bfe6 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/llm_engine_sp.py +++ b/verl/third_party/vllm/vllm_v_0_4_2/llm_engine_sp.py @@ -13,28 +13,34 @@ # limitations under the License. # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py -import torch -from typing import Dict, Optional, Union, Type, Iterable +from typing import Dict, Iterable, Optional, Type, Union +import torch.nn as nn import vllm -from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) +from vllm.config import ( + CacheConfig, + DecodingConfig, + DeviceConfig, + LoRAConfig, + ParallelConfig, + SchedulerConfig, + SpeculativeConfig, + VisionLanguageConfig, +) from vllm.core.scheduler import Scheduler -from vllm.engine.output_processor.interfaces import (SequenceGroupOutputProcessor) +from vllm.engine.llm_engine import LLMEngine, _load_generation_config_dict +from vllm.engine.metrics import StatLogger +from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.engine.metrics import StatLogger -from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) +from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message from vllm.utils import Counter -from vllm.engine.llm_engine import _load_generation_config_dict -from vllm.engine.llm_engine import LLMEngine -import torch.nn as nn from .arg_utils import EngineArgs +from .config import LoadConfig, ModelConfig from .tokenizer import TokenizerGroup -from .config import ModelConfig, LoadConfig logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -74,7 +80,7 @@ class LLMEngine(LLMEngine): def __init__( self, # NOTE(sgm): first two arguments are added for verl - model: Union[nn.Module, Dict], # model itself or its parameter dict + model: Union[nn.Module, Dict], # model itself or its parameter dict 
tokenizer: nn.Module, # NOTE(sgm): vllm original arguments model_config: ModelConfig, @@ -154,7 +160,7 @@ class LLMEngine(LLMEngine): self.generation_config_fields = _load_generation_config_dict(model_config) self.model_executor = executor_class( - model=model, # add for spmd_gpu_executor + model=model, # add for spmd_gpu_executor model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, @@ -171,7 +177,8 @@ class LLMEngine(LLMEngine): # If usage stat is enabled, collect relevant info. if is_usage_stats_enabled(): - from vllm.model_executor.model_loader import (get_architecture_class_name) + from vllm.model_executor.model_loader import get_architecture_class_name + usage_message.report_usage( get_architecture_class_name(model_config), usage_context, @@ -181,17 +188,16 @@ class LLMEngine(LLMEngine): "tensor_parallel_size": parallel_config.tensor_parallel_size, "block_size": cache_config.block_size, "gpu_memory_utilization": cache_config.gpu_memory_utilization, - # Quantization "quantization": model_config.quantization, "kv_cache_dtype": cache_config.cache_dtype, - # Feature flags "enable_lora": bool(lora_config), "enable_prefix_caching": cache_config.enable_prefix_caching, "enforce_eager": model_config.enforce_eager, "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce, - }) + }, + ) if self.tokenizer: # Ping the tokenizer to ensure liveness if it runs in a @@ -206,14 +212,16 @@ class LLMEngine(LLMEngine): # Metric Logging. if self.log_stats: - self.stat_logger = StatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.served_model_name), - max_model_len=self.model_config.max_model_len) + self.stat_logger = StatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict(model_name=model_config.served_model_name), + max_model_len=self.model_config.max_model_len, + ) self.stat_logger.info("cache_config", self.cache_config) # Create sequence output processor, e.g. for beam search or # speculative decoding. - self.output_processor = (SequenceGroupOutputProcessor.create_output_processor( + self.output_processor = SequenceGroupOutputProcessor.create_output_processor( self.scheduler_config, self.detokenizer, self.scheduler, @@ -223,13 +231,13 @@ class LLMEngine(LLMEngine): self.scheduler_config.max_model_len, self.get_tokenizer_for_seq, ), - )) + ) # TODO(sgm): add for verl but we may not tokenizer in Rollout def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs): - init_kwargs = dict(enable_lora=bool(self.lora_config), - max_num_seqs=self.scheduler_config.max_num_seqs, - max_input_length=None) + init_kwargs = dict( + enable_lora=bool(self.lora_config), max_num_seqs=self.scheduler_config.max_num_seqs, max_input_length=None + ) init_kwargs.update(tokenizer_init_kwargs) self.tokenizer: TokenizerGroup = TokenizerGroup(tokenizer, **init_kwargs) @@ -256,13 +264,15 @@ class LLMEngine(LLMEngine): engine_config = engine_args.create_engine_config() # Initialize the cluster and specify the executor class. - assert engine_config.device_config.device_type == "cuda", \ + assert engine_config.device_config.device_type == "cuda", ( "Currently, the vllm in verl only support running on GPU" + ) if engine_config.parallel_config.world_size == 1: engine_config.load_config.load_format = "dummy_hf" from .spmd_gpu_executor import SPMDGPUExecutor + executor_class = SPMDGPUExecutor # Create the LLM engine. 
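The megatron_weight_loaders.py hunk below, like the vllm_v_0_3_1 weight_loaders.py and dtensor loader hunks earlier in this patch, reformats loaders whose core job is to remap Megatron-style parameter names onto vLLM module names before handing the tensors to vLLM's per-parameter weight loaders. A minimal sketch of that renaming pattern is shown here; the `strip_and_remap` / `remap_state_dict` helper names and the shortened mapping table are hypothetical, with the substitution pairs and the `0.module.module.` prefix taken from the hunks in this diff:

```python
# Illustrative sketch only: mirrors the prefix-stripping and name-substitution
# pattern used by the Megatron/dtensor weight loaders touched in this patch.
# The helper names and the abbreviated mapping table are hypothetical.
from typing import Dict, List, Tuple

import torch

# Subset of the (megatron_name, vllm_name) substitution pairs seen in the diff.
PARAMS_MAPPING: List[Tuple[str, str]] = [
    ("self_attention.linear_proj", "self_attn.o_proj"),
    ("mlp.linear_fc2", "mlp.down_proj"),
    ("decoder.final_layernorm", "model.norm"),
    ("output_layer", "lm_head"),
]

# Megatron checkpoints may carry this wrapper prefix (see llama_weight_loader).
MEGATRON_PREFIX = "0.module.module."


def strip_and_remap(name: str) -> str:
    """Strip the Megatron wrapper prefix and rewrite the name into vLLM's layout."""
    if name.startswith(MEGATRON_PREFIX):
        name = name[len(MEGATRON_PREFIX):]
    for megatron_name, vllm_name in PARAMS_MAPPING:
        if megatron_name in name:
            name = name.replace(megatron_name, vllm_name)
    return name


def remap_state_dict(actor_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Return a dict keyed by vLLM-style parameter names, skipping rotary caches."""
    remapped = {}
    for name, tensor in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue  # not a learnable weight; the real loaders skip it as well
        remapped[strip_and_remap(name)] = tensor
    return remapped


if __name__ == "__main__":
    fake_weights = {
        "0.module.module.decoder.final_layernorm.weight": torch.zeros(8),
        "0.module.module.output_layer.weight": torch.zeros(8, 8),
        "0.module.module.layers.0.self_attention.rotary_emb.inv_freq": torch.zeros(4),
    }
    print(list(remap_state_dict(fake_weights)))
    # ['model.norm.weight', 'lm_head.weight']
```

After renaming, the real loaders look each name up in `dict(vllm_model.named_parameters())` and call the weight loader attached to that parameter (for example the `parallel_weight_loader` and `default_weight_loader` functions reformatted in this diff), which is where the size and dtype asserts come from.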
diff --git a/verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py index 4820c2c07..97373d697 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py +++ b/verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py @@ -14,22 +14,25 @@ # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models from typing import Dict, Iterable + import torch import torch.nn as nn - from vllm.model_executor.layers.linear import * -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead -from vllm.model_executor.layers.activation import ScaledActivation +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding from vllm.model_executor.models import ModelRegistry # NOTE(shengguangming): replace the origin weight loader function in the class def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Parallel Linear weight loader.""" - assert param.size() == loaded_weight.size( - ), 'the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}'.format( - param.size(), loaded_weight.size()) - assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same" + assert param.size() == loaded_weight.size(), ( + "the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}".format( + param.size(), loaded_weight.size() + ) + ) + assert param.data.dtype == loaded_weight.data.dtype, ( + "if we want to shared weights, the data type should also be the same" + ) param.data = loaded_weight.data @@ -37,7 +40,9 @@ def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tenso def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" assert param.size() == loaded_weight.size() - assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same" + assert param.data.dtype == loaded_weight.data.dtype, ( + "if we want to shared weights, the data type should also be the same" + ) param.data = loaded_weight.data @@ -90,20 +95,20 @@ def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Mod ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), ("self_attention.linear_qkv", "self_attn.qkv_proj"), ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", 'self_attn.o_proj'), - ('pre_mlp_layernorm', 'post_attention_layernorm'), - ('mlp.linear_fc1.layer_norm_weight', 'post_attention_layernorm.weight'), - ('mlp.linear_fc1.layer_norm_bias', 'post_attention_layernorm.bias'), - ('mlp.linear_fc1', 'mlp.gate_up_proj'), - ('mlp.linear_fc2', 'mlp.down_proj'), - ('decoder.final_layernorm', 'model.norm'), - ('output_layer', 'lm_head'), + ("self_attention.linear_proj", "self_attn.o_proj"), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), ] # NOTE(shengguangming): the megatron llama may have this prefix params_dict = dict(vllm_model.named_parameters()) for name, loaded_weight 
in actor_weights.items(): name = _replace_name(name, params_mapping) - if name.endswith('.bias') and name not in params_dict: + if name.endswith(".bias") and name not in params_dict: continue if "rotary_emb.inv_freq" in name: continue @@ -118,22 +123,22 @@ def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module # (megatron core gpt model name, vllm model name) ("embedding.word_embeddings", "model.embed_tokens"), ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", 'self_attn.o_proj'), + ("self_attention.linear_proj", "self_attn.o_proj"), ( - 'input_layernorm', - 'input_layernorm', + "input_layernorm", + "input_layernorm", ), - ('pre_mlp_layernorm', 'post_attention_layernorm'), - ('mlp.linear_fc1', 'mlp.gate_up_proj'), - ('mlp.linear_fc2', 'mlp.down_proj'), - ('decoder.final_layernorm', 'model.norm'), - ('output_layer', 'lm_head'), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), ] # NOTE(shengguangming): the megatron llama may have this prefix params_dict = dict(vllm_model.named_parameters()) for name, loaded_weight in actor_weights.items(): name = _replace_name(name, params_mapping) - if name.endswith('.bias') and name not in params_dict: + if name.endswith(".bias") and name not in params_dict: continue if "rotary_emb.inv_freq" in name: continue @@ -147,19 +152,19 @@ def _replace_name(megatron_name, name_mapping): for m_name, v_name in name_mapping: if m_name not in megatron_name: continue - if 'layers' in megatron_name: # deal with decoder layers - megatron_name = megatron_name.replace('decoder', 'model') - megatron_name_list = megatron_name.split('.') - if 'layer_norm_weight' in megatron_name_list or 'layer_norm_bias' in megatron_name_list: + if "layers" in megatron_name: # deal with decoder layers + megatron_name = megatron_name.replace("decoder", "model") + megatron_name_list = megatron_name.split(".") + if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: param_name_list = megatron_name_list[:3] param_name_list.append(v_name) - param_name = '.'.join(param_name_list) + param_name = ".".join(param_name_list) else: param_name_list = megatron_name_list[:3] weight_or_bias = megatron_name_list[-1] param_name_list.append(v_name) param_name_list.append(weight_or_bias) - param_name = '.'.join(param_name_list) + param_name = ".".join(param_name_list) return param_name else: param_name = megatron_name.replace(m_name, v_name) @@ -174,20 +179,20 @@ def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Mod ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), ("self_attention.linear_qkv", "self_attn.qkv_proj"), ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", 'self_attn.o_proj'), - ('pre_mlp_layernorm', 'post_attention_layernorm'), - ('mlp.linear_fc1.layer_norm_weight', 'post_attention_layernorm.weight'), - ('mlp.linear_fc1.layer_norm_bias', 'post_attention_layernorm.bias'), - ('mlp.linear_fc1', 'mlp.gate_up_proj'), - ('mlp.linear_fc2', 'mlp.down_proj'), - ('decoder.final_layernorm', 'model.norm'), - ('output_layer', 'lm_head'), + ("self_attention.linear_proj", "self_attn.o_proj"), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", 
"post_attention_layernorm.bias"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), ] # NOTE(shengguangming): the megatron llama may have this prefix params_dict = dict(vllm_model.named_parameters()) for name, loaded_weight in actor_weights.items(): name = _replace_name(name, params_mapping) - if name.endswith('.bias') and name not in params_dict: + if name.endswith(".bias") and name not in params_dict: continue if "rotary_emb.inv_freq" in name: continue @@ -202,22 +207,22 @@ def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module # (megatron core gpt model name, vllm model name) ("embedding.word_embeddings", "model.embed_tokens"), ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", 'self_attn.o_proj'), + ("self_attention.linear_proj", "self_attn.o_proj"), ( - 'input_layernorm', - 'input_layernorm', + "input_layernorm", + "input_layernorm", ), - ('pre_mlp_layernorm', 'post_attention_layernorm'), - ('mlp.linear_fc1', 'mlp.gate_up_proj'), - ('mlp.linear_fc2', 'mlp.down_proj'), - ('decoder.final_layernorm', 'model.norm'), - ('output_layer', 'lm_head'), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), ] # NOTE(shengguangming): the megatron llama may have this prefix params_dict = dict(vllm_model.named_parameters()) for name, loaded_weight in actor_weights.items(): name = _replace_name(name, params_mapping) - if name.endswith('.bias') and name not in params_dict: + if name.endswith(".bias") and name not in params_dict: continue if "rotary_emb.inv_freq" in name: continue @@ -245,7 +250,7 @@ __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = { QKVParallelLinear: parallel_weight_loader, RowParallelLinear: parallel_weight_loader, VocabParallelEmbedding: parallel_weight_loader, - ParallelLMHead: parallel_weight_loader + ParallelLMHead: parallel_weight_loader, # "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights # "default_weight_loader": default_weight_loader } @@ -255,10 +260,10 @@ __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = { # layer_class.weight_loader = weight_loader __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = { - 'GPT2LMHeadModel': gpt2_weight_loader, - 'LlamaForCausalLM': llama_megatron_core_te_weight_loader, # use te backend for open-source megatron - 'LLaMAForCausalLM': llama_megatron_core_te_weight_loader, - 'MistralForCausalLM': mistral_megatron_weight_loader, + "GPT2LMHeadModel": gpt2_weight_loader, + "LlamaForCausalLM": llama_megatron_core_te_weight_loader, # use te backend for open-source megatron + "LLaMAForCausalLM": llama_megatron_core_te_weight_loader, + "MistralForCausalLM": mistral_megatron_weight_loader, } @@ -275,8 +280,10 @@ def load_megatron_weights(actor_weights: Iterable, vllm_model: nn.Module): def _get_model_weight_loader(arch: str): if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__: return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch] - raise ValueError(f"Model architectures {arch} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") + raise ValueError( + f"Model architectures {arch} are not supported for now. 
" + f"Supported architectures: {ModelRegistry.get_supported_archs()}" + ) def update_megatron_weight_loader(): @@ -290,12 +297,14 @@ def update_megatron_weight_loader(): DEFAULT_VOCAB_PADDING_SIZE = 64 -def vocab_init(self, - num_embeddings: int, - embedding_dim: int, - params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): +def vocab_init( + self, + num_embeddings: int, + embedding_dim: int, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, +): super(VocabParallelEmbedding, self).__init__() # Keep the input dimensions. @@ -313,13 +322,17 @@ def vocab_init(self, # TODO: remove dependencies from megatron from megatron.core.tensor_parallel.utils import VocabUtility - self.vocab_start_index, self.vocab_end_index = (VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, get_tensor_model_parallel_rank(), self.tp_size)) - self.num_embeddings_per_partition = (self.vocab_end_index - self.vocab_start_index) + + self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, get_tensor_model_parallel_rank(), self.tp_size + ) + self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index self.weight = Parameter( torch.empty( self.num_embeddings_per_partition, self.embedding_dim, # device=torch.cuda.current_device(), - dtype=params_dtype)) + dtype=params_dtype, + ) + ) set_weight_attrs(self.weight, {"parallel_dim": 0, "weight_loader": self.weight_loader}) diff --git a/verl/third_party/vllm/vllm_v_0_4_2/model_loader.py b/verl/third_party/vllm/vllm_v_0_4_2/model_loader.py index 5f4013451..22c73af8d 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/model_loader.py +++ b/verl/third_party/vllm/vllm_v_0_4_2/model_loader.py @@ -13,43 +13,54 @@ # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader """Utilities for selecting and loading models.""" -from typing import Dict, Union, Optional, Iterable, Tuple + +from typing import Dict, Optional, Union import torch import torch.nn as nn from transformers import PreTrainedModel - -from vllm.config import (DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.config import DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig +from vllm.distributed.communication_op import tensor_model_parallel_all_gather from vllm.model_executor.model_loader import BaseModelLoader from vllm.model_executor.model_loader.loader import _initialize_model from vllm.model_executor.model_loader.utils import set_default_torch_dtype -from vllm.distributed.communication_op import tensor_model_parallel_all_gather -from .config import ModelConfig, LoadFormat, LoadConfig -from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader +from .config import LoadConfig, LoadFormat, ModelConfig from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader from .hf_weight_loader import update_hf_weight_loader +from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader -def get_model(actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig, load_config: LoadConfig, - device_config: DeviceConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, - lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: +def get_model( + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + load_config: LoadConfig, + device_config: DeviceConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], +) -> nn.Module: loader = get_model_loader(load_config) - if load_config.load_format.startswith('dummy'): - return loader.load_model(model_config=model_config, - device_config=device_config, - lora_config=lora_config, - vision_language_config=vision_language_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config) + if load_config.load_format.startswith("dummy"): + return loader.load_model( + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + vision_language_config=vision_language_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + ) else: - return loader.load_model(actor_model=actor_model, - model_config=model_config, - device_config=device_config, - lora_config=lora_config, - vision_language_config=vision_language_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config) + return loader.load_model( + actor_model=actor_model, + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + vision_language_config=vision_language_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + ) def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: @@ -87,8 +98,11 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: update_dtensor_weight_loader() return DummyModelLoader(load_config) - raise ValueError('load format not supported in verl: {}, only support {} and {}'.format( - load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF)) + raise ValueError( + "load format not 
supported in verl: {}, only support {} and {}".format( + load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF + ) + ) class DummyModelLoader(BaseModelLoader): @@ -97,12 +111,18 @@ class DummyModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") - def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + ) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config) @@ -118,8 +138,7 @@ class MegatronLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): # NOTE(shengguangming) Load the weights from the actor model @@ -130,18 +149,25 @@ class MegatronLoader(BaseModelLoader): # load_weights(actor_weights=actor_model, vllm_model=model) # return actor_model - def load_model(self, actor_model: Union[PreTrainedModel, - Dict], model_config: ModelConfig, device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], - parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module: + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + ) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config) # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm if isinstance(actor_model, nn.Module): - load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), - vllm_model=model) + load_megatron_weights( + actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model + ) else: load_megatron_weights(actor_weights=actor_model, vllm_model=model) @@ -164,8 +190,7 @@ class HFLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError(f"Model loader extra config is not 
supported for load format {load_config.load_format}") def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]): if isinstance(actor_model, Dict): @@ -173,12 +198,18 @@ class HFLoader(BaseModelLoader): elif isinstance(actor_model, nn.Module): return dict(actor_model.named_parameters()).items() else: - raise ValueError(f'actor model should be Dict or nn.Module, but get {type(actor_model)}') + raise ValueError(f"actor model should be Dict or nn.Module, but get {type(actor_model)}") - def load_model(self, actor_model: Union[PreTrainedModel, - Dict], model_config: ModelConfig, device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], - parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module: + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + ) -> nn.Module: with set_default_torch_dtype(model_config.dtype): # with torch.device(device_config.device): # NOTE(sgm): init the model in cpu @@ -203,8 +234,7 @@ class DTensorLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): # NOTE(shengguangming) Load the weights from the actor model @@ -215,18 +245,25 @@ class DTensorLoader(BaseModelLoader): # load_weights(actor_weights=actor_model, vllm_model=model) # return actor_model - def load_model(self, actor_model: Union[PreTrainedModel, - Dict], model_config: ModelConfig, device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], - parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module: + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + ) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config) # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm if isinstance(actor_model, nn.Module): - load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), - vllm_model=model) + load_dtensor_weights( + actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model + ) else: load_dtensor_weights(actor_weights=actor_model, vllm_model=model) @@ -247,8 +284,9 @@ class DTensorLoader(BaseModelLoader): # as they use ray, the _get_logits result will only need to return to the driver node, # therefore gather is enough. 
However, we use SPMD instead of a central scheduler, # all_gather is required (aligned with v0.2.6) -def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, - embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: +def _get_logits( + self, hidden_states: torch.Tensor, embedding: torch.Tensor, embedding_bias: Optional[torch.Tensor] +) -> torch.Tensor: # Get the logits for the next tokens. logits = torch.matmul(hidden_states, embedding.t()) if embedding_bias is not None: @@ -256,7 +294,7 @@ def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, logits = tensor_model_parallel_all_gather(logits) # Remove paddings in vocab (if any). if logits is not None: - logits = logits[:, :self.org_vocab_size] + logits = logits[:, : self.org_vocab_size] return logits diff --git a/verl/third_party/vllm/vllm_v_0_4_2/model_runner.py b/verl/third_party/vllm/vllm_v_0_4_2/model_runner.py index 1604b0363..90108d624 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/model_runner.py +++ b/verl/third_party/vllm/vllm_v_0_4_2/model_runner.py @@ -13,24 +13,24 @@ # limitations under the License. # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py -import torch -import torch.nn as nn from enum import IntEnum from typing import Dict, List, Optional, Set, Tuple, Union -from vllm.attention import (AttentionMetadata, get_attn_backend) -from vllm.config import (DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) +import torch +import torch.nn as nn +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata -from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) -from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available) -from vllm.worker.model_runner import ModelRunner, CUDAGraphRunner +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import CudaMemoryProfiler, is_hip, is_pin_memory_available +from vllm.worker.model_runner import CUDAGraphRunner, ModelRunner +from .config import LoadConfig, ModelConfig from .model_loader import get_model -from .config import ModelConfig, LoadConfig logger = init_logger(__name__) @@ -46,10 +46,9 @@ class BatchType(IntEnum): class ModelRunner(ModelRunner): - def __init__( self, - model: Union[nn.Module, Dict], # model itself or its parameter dict + model: Union[nn.Module, Dict], # model itself or its parameter dict model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, @@ -67,8 +66,8 @@ class ModelRunner(ModelRunner): # model_config can be None in tests/samplers/test_sampler.py. # FIXME(woosuk): This is a hack to make the tests work. Refactor this. 
- self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None) - self.device_config = (device_config if device_config is not None else DeviceConfig()) + self.sliding_window = model_config.get_sliding_window() if model_config is not None else None + self.device_config = device_config if device_config is not None else DeviceConfig() self.device = self.device_config.device # NOTE(sgm): add for verl @@ -80,7 +79,7 @@ class ModelRunner(ModelRunner): self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_memory_pool: Optional[Tuple[int, int]] = None # Set during graph capture. - self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture if self.model_config is not None else 0) + self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture if self.model_config is not None else 0 self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype @@ -104,26 +103,34 @@ class ModelRunner(ModelRunner): # NOTE(sgm): initialize model using the actor model def load_model(self) -> None: with CudaMemoryProfiler() as m: - self.model = get_model(actor_model=self.model, - model_config=self.model_config, - device_config=self.device_config, - lora_config=self.lora_config, - load_config=self.load_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - vision_language_config=self.vision_language_config) + self.model = get_model( + actor_model=self.model, + model_config=self.model_config, + device_config=self.device_config, + lora_config=self.lora_config, + load_config=self.load_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + vision_language_config=self.vision_language_config, + ) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) if self.lora_config: assert hasattr(self.model, "supported_lora_modules") and self.model.supported_lora_modules, ( - "Model does not support LoRA") + "Model does not support LoRA" + ) assert hasattr(self.model, "embedding_modules"), "Model does not have embedding_modules" assert hasattr(self.model, "embedding_padding_modules"), "Model does not have embedding_padding_modules" - self.lora_manager = LRUCacheWorkerLoRAManager(self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, self.vocab_size, - self.lora_config, self.device, self.model.embedding_modules, - self.model.embedding_padding_modules) + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + ) self.model = self.lora_manager.create_lora_manager(self.model) if self.kv_cache_dtype == "fp8" and is_hip(): @@ -134,21 +141,28 @@ class ModelRunner(ModelRunner): else: raise RuntimeError( "Using FP8 KV cache and scaling factors provided but " - "model %s does not support loading scaling factors.", self.model.__class__) + "model %s does not support loading scaling factors.", + self.model.__class__, + ) else: - logger.warning("Using FP8 KV cache but no scaling factors " - "provided. Defaulting to scaling factors of 1.0. " - "This may lead to less accurate results!") + logger.warning( + "Using FP8 KV cache but no scaling factors " + "provided. Defaulting to scaling factors of 1.0. " + "This may lead to less accurate results!" 
+ ) elif self.model_config.quantization_param_path is not None: - logger.warning("KV cache scaling factors provided, " - "but the KV cache data type is not FP8. " - "KV cache scaling factors will not be used.") + logger.warning( + "KV cache scaling factors provided, " + "but the KV cache data type is not FP8. " + "KV cache scaling factors will not be used." + ) def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, Set[LoRARequest], LoRAMapping, - torch.Tensor]: + ) -> Tuple[ + torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, Set[LoRARequest], LoRAMapping, torch.Tensor + ]: # NOTE(sgm): all workers prepare the input in the same way prefill_reqs = [] decode_reqs = [] @@ -180,8 +194,9 @@ class ModelRunner(ModelRunner): decode_lora_requests, decode_slot_mapping, ) = self._prepare_decode(decode_reqs) - sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, seq_lens, query_lens, self.device, - self.pin_memory) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, self.pin_memory + ) if not self.scheduler_config.chunked_prefill_enabled: assert (len(prefill_reqs) and len(decode_reqs)) == 0 @@ -214,7 +229,7 @@ class ModelRunner(ModelRunner): # Broadcast the metadata. # If batch contains both prefill and decode, it sends 2 broadcasts. # If it only contains 1 type, it triggers a single broadcast. - if (prefill_attn_metadata is not None and decode_attn_metadata is not None): + if prefill_attn_metadata is not None and decode_attn_metadata is not None: batch_type = BatchType.MIXED elif prefill_attn_metadata is not None: batch_type = BatchType.PREFILL @@ -231,8 +246,15 @@ class ModelRunner(ModelRunner): kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, - multi_modal_input) + return ( + input_tokens, + input_positions, + attn_metadata, + sampling_metadata, + lora_requests, + lora_mapping, + multi_modal_input, + ) @torch.inference_mode() def execute_model( @@ -240,8 +262,15 @@ class ModelRunner(ModelRunner): seq_group_metadata_list: List[SequenceGroupMetadata], kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, - multi_modal_input) = self.prepare_input_tensors(seq_group_metadata_list) + ( + input_tokens, + input_positions, + attn_metadata, + sampling_metadata, + lora_requests, + lora_mapping, + multi_modal_input, + ) = self.prepare_input_tensors(seq_group_metadata_list) if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) diff --git a/verl/third_party/vllm/vllm_v_0_4_2/parallel_state.py b/verl/third_party/vllm/vllm_v_0_4_2/parallel_state.py index 9336089ed..a13b08dcc 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/parallel_state.py +++ b/verl/third_party/vllm/vllm_v_0_4_2/parallel_state.py @@ -4,17 +4,15 @@ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
"""Model and data parallel groups.""" + import os -import torch -import torch.distributed from typing import Optional +import torch +import torch.distributed import vllm.distributed.parallel_state as ps - -import vllm.envs as envs -from vllm.logger import init_logger - from torch.distributed.device_mesh import init_device_mesh +from vllm.logger import init_logger logger = init_logger(__name__) """ @@ -57,8 +55,10 @@ def initialize_parallel_state( ps.init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) if torch.distributed.get_world_size() > 1: # NOTE: build a sepearate inference group with infer tp & micro dp - initialize_model_parallel_for_vllm(tensor_model_parallel_size=tensor_model_parallel_size, - num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp) + initialize_model_parallel_for_vllm( + tensor_model_parallel_size=tensor_model_parallel_size, + num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp, + ) else: initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) @@ -78,10 +78,11 @@ def ensure_model_parallel_initialized( initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) return - assert (get_tensor_model_parallel_world_size() == tensor_model_parallel_size), ( + assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( "tensor parallel group already initialized, but of unexpected size: " f"{get_tensor_model_parallel_world_size()=} vs. " - f"{tensor_model_parallel_size=}") + f"{tensor_model_parallel_size=}" + ) # assert (get_pipeline_model_parallel_world_size( # ) == pipeline_model_parallel_size), ( # "pipeline parallel group already initialized, but of unexpected size: " @@ -91,13 +92,13 @@ def ensure_model_parallel_initialized( def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" - return (ps._TP_DEVICE_GROUP is not None) + return ps._TP_DEVICE_GROUP is not None # and _PIPELINE_MODEL_PARALLEL_GROUP is not None) -def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int, - num_tensor_model_parallel_groups_per_train_tp: int = 1) -> None: - from torch.distributed import new_group +def initialize_model_parallel_for_vllm( + tensor_model_parallel_size: int, num_tensor_model_parallel_groups_per_train_tp: int = 1 +) -> None: # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() @@ -107,7 +108,7 @@ def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int, # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group # Build the tensor model-parallel groups. 
- assert ps._TP_DEVICE_GROUP is None, ("tensor model parallel group is already initialized") + assert ps._TP_DEVICE_GROUP is None, "tensor model parallel group is already initialized" global _TP_DEVICE_GROUP global _TP_CPU_GROUP @@ -144,7 +145,7 @@ def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int, # train_tp = train_tensor_parallel_size train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size - assert _TP_DEVICE_GROUP is None, ("tensor model parallel group is already initialized") + assert _TP_DEVICE_GROUP is None, "tensor model parallel group is already initialized" for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): start = train_tp * i end = train_tp * (i + 1) @@ -153,7 +154,7 @@ def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int, for i in range(len(ranks)): ranks[i] += j group = torch.distributed.new_group(ranks) - cpu_group = torch.distributed.new_group(ranks, backend='gloo') + cpu_group = torch.distributed.new_group(ranks, backend="gloo") if rank in ranks: _TP_DEVICE_GROUP = group _TP_CPU_GROUP = cpu_group @@ -177,7 +178,7 @@ def initialize_model_parallel( """ NOTE: This method is a hack from the open-sourced version without asertion of world_size = tp * pp - + Initialize model parallel groups. Arguments: @@ -208,14 +209,17 @@ def initialize_model_parallel( # NOTE(sgm) we don't assert world_size == tp * pp # DP is not managed by vllm but by the verl WorkerGroup - num_tensor_model_parallel_groups: int = (world_size // tensor_model_parallel_size) - num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size) + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size rank = torch.distributed.get_rank() # Build device mesh for TP if num_tensor_model_parallel_groups > 1: - device_mesh = init_device_mesh("cuda", (num_tensor_model_parallel_groups, tensor_model_parallel_size), - mesh_dim_names=("replicate", "tp_shard")) + device_mesh = init_device_mesh( + "cuda", + (num_tensor_model_parallel_groups, tensor_model_parallel_size), + mesh_dim_names=("replicate", "tp_shard"), + ) else: device_mesh = init_device_mesh("cuda", (tensor_model_parallel_size,), mesh_dim_names=["tp_shard"]) shard_group = device_mesh.get_group(mesh_dim="tp_shard") @@ -223,8 +227,8 @@ def initialize_model_parallel( # Build the tensor model-parallel groups. global _TP_DEVICE_GROUP, _TP_CPU_GROUP global _DEVICE_MESH - assert _TP_DEVICE_GROUP is None, ("tensor model parallel group is already initialized") - assert _DEVICE_MESH is None, ("device mesh in vllm is already initialized") + assert _TP_DEVICE_GROUP is None, "tensor model parallel group is already initialized" + assert _DEVICE_MESH is None, "device mesh in vllm is already initialized" _DEVICE_MESH = device_mesh # for i in range(num_tensor_model_parallel_groups): @@ -246,7 +250,7 @@ def initialize_model_parallel( # TODO: init using device mesh # Build the pipeline model-parallel groups. 
- assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized") + assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized" for i in range(num_pipeline_model_parallel_groups): ranks = range(i, world_size, num_pipeline_model_parallel_groups) group = torch.distributed.new_group(ranks, backend=backend) @@ -261,7 +265,7 @@ Device mesh utilities def get_device_mesh(): - assert _DEVICE_MESH is not None, ("device mesh is not initialized") + assert _DEVICE_MESH is not None, "device mesh is not initialized" return _DEVICE_MESH @@ -272,7 +276,7 @@ Tensor model parallel utilities def get_tensor_model_parallel_group(): """Get the tensor model parallel group the caller rank belongs to.""" - assert _TP_DEVICE_GROUP is not None, ("tensor model parallel group is not initialized") + assert _TP_DEVICE_GROUP is not None, "tensor model parallel group is not initialized" return _TP_DEVICE_GROUP diff --git a/verl/third_party/vllm/vllm_v_0_4_2/spmd_gpu_executor.py b/verl/third_party/vllm/vllm_v_0_4_2/spmd_gpu_executor.py index c669cf590..075b47049 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/spmd_gpu_executor.py +++ b/verl/third_party/vllm/vllm_v_0_4_2/spmd_gpu_executor.py @@ -14,18 +14,24 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py import os import socket -from typing import Any, Dict, List, Optional, Set, Tuple, Iterable +from typing import Iterable, List, Optional, Set, Tuple import torch -import vllm.envs as envs -from vllm.executor.executor_base import ExecutorBase, ExecutorAsyncBase +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + ParallelConfig, + SchedulerConfig, + SpeculativeConfig, + VisionLanguageConfig, +) +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, ExecuteModelRequest +from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) -from .config import ModelConfig, LoadConfig +from .config import LoadConfig, ModelConfig logger = init_logger(__name__) @@ -35,7 +41,7 @@ class SPMDGPUExecutor(ExecutorBase): def __init__( self, - model, # pytorch model itself or its parameter dict + model, # pytorch model itself or its parameter dict model_config: ModelConfig, cache_config: CacheConfig, parallel_config: ParallelConfig, @@ -61,7 +67,7 @@ class SPMDGPUExecutor(ExecutorBase): # TODO(sgm): verl not support speculative decode now def _init_executor(self, model, distributed_init_method) -> None: - assert (not self.speculative_config), "Speculative decoding not yet supported for multi-GPU backend." + assert not self.speculative_config, "Speculative decoding not yet supported for multi-GPU backend." # Create the parallel worker for each GPU. 
self._init_workers_sp(model, distributed_init_method) @@ -69,11 +75,11 @@ class SPMDGPUExecutor(ExecutorBase): def _init_workers_sp(self, model, distributed_init_method: str): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker - from .worker import Worker # pylint: disable=import-outside-toplevel + from .worker import Worker rank = int(os.getenv("RANK")) local_rank = int(os.getenv("LOCAL_RANK")) - print(f'local rank {local_rank}') + print(f"local rank {local_rank}") self.worker = Worker( model, @@ -115,8 +121,7 @@ class SPMDGPUExecutor(ExecutorBase): return num_gpu_blocks, num_cpu_blocks def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: - """Initialize the KV cache in all workers. - """ + """Initialize the KV cache in all workers.""" # NOTE: We log here to avoid multiple logs when number of workers is # greater than one. We could log in the engine, but not all executors @@ -128,12 +133,12 @@ class SPMDGPUExecutor(ExecutorBase): if torch.distributed.get_rank() == 0: print( - f'before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB' + f"before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB" ) self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) if torch.distributed.get_rank() == 0: print( - f'after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB' + f"after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB" ) # NOTE(sgm): This will not profile & capture the model(CUDAGraph) when rebuilding KVCache @@ -195,7 +200,7 @@ def initialize_cluster( # We need to setup the distributed init method to make sure # the distributed megatron code (e.g., get world size) works correctly. # distributed_init_method = f"tcp://localhost:{port}" - distributed_init_method = 'env://' + distributed_init_method = "env://" return distributed_init_method @@ -207,7 +212,6 @@ def get_open_port(): # TODO(sgm): not implemented async executor yet class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase): - async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Executes one model step on the given sequences.""" raise NotImplementedError diff --git a/verl/third_party/vllm/vllm_v_0_4_2/tokenizer.py b/verl/third_party/vllm/vllm_v_0_4_2/tokenizer.py index aa625a033..d938e6a05 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/tokenizer.py +++ b/verl/third_party/vllm/vllm_v_0_4_2/tokenizer.py @@ -13,20 +13,20 @@ # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py -from typing import List, Optional, Tuple, Union - -from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) +from typing import List, Optional +from transformers import PreTrainedTokenizer from vllm.lora.request import LoRARequest -from vllm.utils import make_async, LRUCache from vllm.transformers_utils.tokenizers import * +from vllm.utils import LRUCache class TokenizerGroup: """A group of tokenizers that can be used for LoRA adapters.""" - def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int]): + def __init__( + self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, max_input_length: Optional[int] + ): self.enable_lora = enable_lora self.max_input_length = max_input_length self.tokenizer = tokenizer @@ -40,17 +40,15 @@ class TokenizerGroup: """Get the maximum input length for the LoRA request.""" return self.max_input_length - def encode(self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + def encode( + self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None + ) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) return tokenizer.encode(prompt) - async def encode_async(self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + async def encode_async( + self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None + ) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) return tokenizer.encode(prompt) diff --git a/verl/third_party/vllm/vllm_v_0_4_2/worker.py b/verl/third_party/vllm/vllm_v_0_4_2/worker.py index 1fab3e41f..877bd82f3 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/worker.py +++ b/verl/third_party/vllm/vllm_v_0_4_2/worker.py @@ -13,30 +13,29 @@ # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py """A GPU worker class.""" -import os + import gc -from typing import Dict, List, Tuple, Optional, Union +import os +from typing import Dict, List, Optional, Tuple, Union import torch import torch.distributed import torch.nn as nn +from vllm.config import CacheConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.model_executor import set_random_seed -from vllm.sequence import SamplerOutput, ExecuteModelRequest -from vllm.worker.cache_engine import CacheEngine -from vllm.distributed.device_communicators import pynccl_utils -from vllm.distributed.device_communicators.custom_all_reduce import (init_custom_ar) # TODO(sgm): check why vllm has similar file in vllm.model_executor.parallel_utils.parallel_state -from vllm.distributed import get_tensor_model_parallel_cpu_group, init_distributed_environment, get_tensor_model_parallel_group +from vllm.distributed import get_tensor_model_parallel_group, init_distributed_environment +from vllm.model_executor import set_random_seed +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.worker.cache_engine import CacheEngine from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype -from .model_runner import ModelRunner -from .megatron_weight_loaders import load_megatron_weights -from .hf_weight_loader import load_hf_weights +from .config import LoadConfig, LoadFormat, ModelConfig from .dtensor_weight_loaders import load_dtensor_weights -from .parallel_state import (ensure_model_parallel_initialized) -from .config import ModelConfig, LoadConfig, LoadFormat +from .hf_weight_loader import load_hf_weights +from .megatron_weight_loaders import load_megatron_weights +from .model_runner import ModelRunner +from .parallel_state import ensure_model_parallel_initialized class Worker(Worker): @@ -49,7 +48,7 @@ class Worker(Worker): def __init__( self, - model: Union[nn.Module, Dict], # model itself or its parameter dict + model: Union[nn.Module, Dict], # model itself or its parameter dict model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, @@ -80,7 +79,7 @@ class Worker(Worker): self.vision_language_config = vision_language_config if self.vision_language_config: - assert not self.lora_config, ("To be tested: vision language model with LoRA settings.") + assert not self.lora_config, "To be tested: vision language model with LoRA settings." self.model_runner = ModelRunner( model, @@ -132,8 +131,9 @@ class Worker(Worker): raise RuntimeError(f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, - self.local_rank) + init_worker_distributed_environment( + self.parallel_config, self.rank, self.distributed_init_method, self.local_rank + ) # Set random seed. set_random_seed(self.model_config.seed) # self.model = get_model(actor_model=self.model, model_config=self.model_config) @@ -166,8 +166,10 @@ class Worker(Worker): free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() peak_memory = total_gpu_memory - free_gpu_memory - assert peak_memory > 0, ("Error in memory profiling. 
This happens when the GPU memory was " - "not properly cleaned up before initializing the vLLM instance.") + assert peak_memory > 0, ( + "Error in memory profiling. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance." + ) cache_block_size = self.get_cache_block_size_bytes() @@ -182,14 +184,14 @@ class Worker(Worker): self.model_runner.remove_all_loras() # NOTE(sgm): Add for verl, synchronize number of blocks with all the rank - num_gpu_blocks = torch.tensor([num_gpu_blocks], device='cuda') - num_cpu_blocks = torch.tensor([num_cpu_blocks], device='cuda') - torch.distributed.all_reduce(num_gpu_blocks, - op=torch.distributed.ReduceOp.MIN, - group=get_tensor_model_parallel_group()) - torch.distributed.all_reduce(num_cpu_blocks, - op=torch.distributed.ReduceOp.MIN, - group=get_tensor_model_parallel_group()) + num_gpu_blocks = torch.tensor([num_gpu_blocks], device="cuda") + num_cpu_blocks = torch.tensor([num_cpu_blocks], device="cuda") + torch.distributed.all_reduce( + num_gpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group() + ) + torch.distributed.all_reduce( + num_cpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group() + ) num_gpu_blocks = num_gpu_blocks.item() num_cpu_blocks = num_cpu_blocks.item() gc.collect() @@ -207,7 +209,6 @@ class Worker(Worker): @torch.inference_mode() def execute_model(self, execute_model_req: Optional[ExecuteModelRequest] = None) -> List[SamplerOutput]: - if execute_model_req is None: seq_group_metadata_list = None else: @@ -247,7 +248,7 @@ class Worker(Worker): if self.cpu_model == None: self.cpu_model = {} for name, params in self.model_runner.model.named_parameters(): - self.cpu_model[name] = torch.empty_like(params, device='cpu') + self.cpu_model[name] = torch.empty_like(params, device="cpu") params.data = self.cpu_model[name] else: for name, params in self.model_runner.model.named_parameters(): @@ -264,8 +265,10 @@ def init_worker_distributed_environment( # NOTE(sgm) use tcp://localhost:xxxx will hang in HF setting without megatron init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) - ensure_model_parallel_initialized(tensor_model_parallel_size=parallel_config.tensor_parallel_size, - pipeline_model_parallel_size=parallel_config.pipeline_parallel_size) + ensure_model_parallel_initialized( + tensor_model_parallel_size=parallel_config.tensor_parallel_size, + pipeline_model_parallel_size=parallel_config.pipeline_parallel_size, + ) # TODO(sgm): check whether need this # if pynccl_utils.is_initialized(): diff --git a/verl/third_party/vllm/vllm_v_0_5_4/arg_utils.py b/verl/third_party/vllm/vllm_v_0_5_4/arg_utils.py index 6c577277b..33fca600c 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/arg_utils.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/arg_utils.py @@ -13,29 +13,34 @@ # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py -import os import argparse import dataclasses -import json +import os from dataclasses import dataclass from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union -import torch.nn as nn - from transformers import PretrainedConfig -from .config import ModelConfig, LoadConfig - -from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoRAConfig, MultiModalConfig, - ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig, - TokenizerPoolConfig) +from vllm.config import ( + CacheConfig, + DecodingConfig, + DeviceConfig, + EngineConfig, + LoRAConfig, + MultiModalConfig, + ObservabilityConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, + TokenizerPoolConfig, +) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger -from vllm.utils import FlexibleArgumentParser -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -from vllm.utils import str_to_int_tuple + +from .config import LoadConfig, ModelConfig if TYPE_CHECKING: - from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (BaseTokenizerGroup) + from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import BaseTokenizerGroup logger = init_logger(__name__) @@ -49,16 +54,17 @@ def nullable_str(val: str): @dataclass class EngineArgs: """Arguments for vLLM engine.""" + model_hf_config: PretrainedConfig = None # for verl served_model_name = None # TODO(sgm): check this # tokenizer: Optional[str] = None # TODO(sgm): check this skip_tokenizer_init: bool = False - tokenizer_mode: str = 'auto' + tokenizer_mode: str = "auto" trust_remote_code: bool = False download_dir: Optional[str] = None - load_format: str = 'auto' - dtype: str = 'auto' - kv_cache_dtype: str = 'auto' + load_format: str = "auto" + dtype: str = "auto" + kv_cache_dtype: str = "auto" quantization_param_path: Optional[str] = None seed: int = 0 max_model_len: Optional[int] = None @@ -106,9 +112,9 @@ class EngineArgs: fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None - lora_dtype: str = 'auto' + lora_dtype: str = "auto" max_cpu_loras: Optional[int] = None - device: str = 'auto' + device: str = "auto" ray_workers_use_nsight: bool = False num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 @@ -119,7 +125,7 @@ class EngineArgs: scheduler_delay_factor: float = 0.0 enable_chunked_prefill: Optional[bool] = None - guided_decoding_backend: str = 'outlines' + guided_decoding_backend: str = "outlines" # Speculative decoding configuration. 
speculative_model: Optional[str] = None speculative_draft_tensor_parallel_size: Optional[int] = None @@ -128,7 +134,7 @@ class EngineArgs: speculative_disable_by_batch_size: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None - spec_decoding_acceptance_method: str = 'rejection_sampler' + spec_decoding_acceptance_method: str = "rejection_sampler" typical_acceptance_sampler_posterior_threshold: Optional[float] = None typical_acceptance_sampler_posterior_alpha: Optional[float] = None qlora_adapter_name_or_path: Optional[str] = None @@ -141,120 +147,140 @@ class EngineArgs: """Shared CLI arguments for vLLM engine.""" # Model arguments # TODO(shengguangming): delete the unused args - parser.add_argument('--model', - type=str, - default='facebook/opt-125m', - help='name or path of the huggingface model to use') - parser.add_argument('--tokenizer', - type=str, - default=EngineArgs.tokenizer, - help='name or path of the huggingface tokenizer to use') - parser.add_argument('--revision', - type=str, - default=None, - help='the specific model version to use. It can be a branch ' - 'name, a tag name, or a commit id. If unspecified, will use ' - 'the default version.') - parser.add_argument('--tokenizer-revision', - type=str, - default=None, - help='the specific tokenizer version to use. It can be a branch ' - 'name, a tag name, or a commit id. If unspecified, will use ' - 'the default version.') - parser.add_argument('--tokenizer-mode', - type=str, - default=EngineArgs.tokenizer_mode, - choices=['auto', 'slow'], - help='tokenizer mode. "auto" will use the fast ' - 'tokenizer if available, and "slow" will ' - 'always use the slow tokenizer.') - parser.add_argument('--trust-remote-code', action='store_true', help='trust remote code from huggingface') - parser.add_argument('--download-dir', - type=str, - default=EngineArgs.download_dir, - help='directory to download and load the weights, ' - 'default to the default cache dir of ' - 'huggingface') - parser.add_argument('--load-format', - type=str, - default=EngineArgs.load_format, - choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], - help='The format of the model weights to load. ' - '"auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available. ' - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading. ' - '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.') - parser.add_argument('--dtype', - type=str, - default=EngineArgs.dtype, - choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--max-model-len', - type=int, - default=None, - help='model context length. 
If unspecified, ' - 'will be automatically derived from the model.') + parser.add_argument( + "--model", type=str, default="facebook/opt-125m", help="name or path of the huggingface model to use" + ) + parser.add_argument( + "--tokenizer", + type=str, + default=EngineArgs.tokenizer, + help="name or path of the huggingface tokenizer to use", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="the specific model version to use. It can be a branch " + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", + ) + parser.add_argument( + "--tokenizer-revision", + type=str, + default=None, + help="the specific tokenizer version to use. It can be a branch " + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", + ) + parser.add_argument( + "--tokenizer-mode", + type=str, + default=EngineArgs.tokenizer_mode, + choices=["auto", "slow"], + help='tokenizer mode. "auto" will use the fast ' + 'tokenizer if available, and "slow" will ' + "always use the slow tokenizer.", + ) + parser.add_argument("--trust-remote-code", action="store_true", help="trust remote code from huggingface") + parser.add_argument( + "--download-dir", + type=str, + default=EngineArgs.download_dir, + help="directory to download and load the weights, default to the default cache dir of huggingface", + ) + parser.add_argument( + "--load-format", + type=str, + default=EngineArgs.load_format, + choices=["auto", "pt", "safetensors", "npcache", "dummy"], + help="The format of the model weights to load. " + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling.", + ) + parser.add_argument( + "--dtype", + type=str, + default=EngineArgs.dtype, + choices=["auto", "half", "float16", "bfloat16", "float", "float32"], + help="data type for model weights and activations. " + 'The "auto" option will use FP16 precision ' + "for FP32 and FP16 models, and BF16 precision " + "for BF16 models.", + ) + parser.add_argument( + "--max-model-len", + type=int, + default=None, + help="model context length. 
If unspecified, will be automatically derived from the model.",
+        )
         # Parallel arguments
-        parser.add_argument('--worker-use-ray',
-                            action='store_true',
-                            help='use Ray for distributed serving, will be '
-                            'automatically set when using more than 1 GPU')
-        parser.add_argument('--pipeline-parallel-size',
-                            '-pp',
-                            type=int,
-                            default=EngineArgs.pipeline_parallel_size,
-                            help='number of pipeline stages')
-        parser.add_argument('--tensor-parallel-size',
-                            '-tp',
-                            type=int,
-                            default=EngineArgs.tensor_parallel_size,
-                            help='number of tensor parallel replicas')
+        parser.add_argument(
+            "--worker-use-ray",
+            action="store_true",
+            help="use Ray for distributed serving, will be automatically set when using more than 1 GPU",
+        )
+        parser.add_argument(
+            "--pipeline-parallel-size",
+            "-pp",
+            type=int,
+            default=EngineArgs.pipeline_parallel_size,
+            help="number of pipeline stages",
+        )
+        parser.add_argument(
+            "--tensor-parallel-size",
+            "-tp",
+            type=int,
+            default=EngineArgs.tensor_parallel_size,
+            help="number of tensor parallel replicas",
+        )
         # KV cache arguments
-        parser.add_argument('--block-size',
-                            type=int,
-                            default=EngineArgs.block_size,
-                            choices=[8, 16, 32],
-                            help='token block size')
+        parser.add_argument(
+            "--block-size", type=int, default=EngineArgs.block_size, choices=[8, 16, 32], help="token block size"
+        )
         # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-        parser.add_argument('--seed', type=int, default=EngineArgs.seed, help='random seed')
-        parser.add_argument('--swap-space',
-                            type=int,
-                            default=EngineArgs.swap_space,
-                            help='CPU swap space size (GiB) per GPU')
-        parser.add_argument('--gpu-memory-utilization',
-                            type=float,
-                            default=EngineArgs.gpu_memory_utilization,
-                            help='the percentage of GPU memory to be used for'
-                            'the model executor')
-        parser.add_argument('--max-num-batched-tokens',
-                            type=int,
-                            default=EngineArgs.max_num_batched_tokens,
-                            help='maximum number of batched tokens per '
-                            'iteration')
-        parser.add_argument('--max-num-seqs',
-                            type=int,
-                            default=EngineArgs.max_num_seqs,
-                            help='maximum number of sequences per iteration')
-        parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
+        parser.add_argument("--seed", type=int, default=EngineArgs.seed, help="random seed")
+        parser.add_argument(
+            "--swap-space", type=int, default=EngineArgs.swap_space, help="CPU swap space size (GiB) per GPU"
+        )
+        parser.add_argument(
+            "--gpu-memory-utilization",
+            type=float,
+            default=EngineArgs.gpu_memory_utilization,
+            help="the percentage of GPU memory to be used for the model executor",
+        )
+        parser.add_argument(
+            "--max-num-batched-tokens",
+            type=int,
+            default=EngineArgs.max_num_batched_tokens,
+            help="maximum number of batched tokens per iteration",
+        )
+        parser.add_argument(
+            "--max-num-seqs",
+            type=int,
+            default=EngineArgs.max_num_seqs,
+            help="maximum number of sequences per iteration",
+        )
+        parser.add_argument("--disable-log-stats", action="store_true", help="disable logging statistics")
         # Quantization settings.
- parser.add_argument('--quantization', - '-q', - type=str, - choices=['awq', None], - default=None, - help='Method used to quantize the weights') + parser.add_argument( + "--quantization", + "-q", + type=str, + choices=["awq", None], + default=None, + help="Method used to quantize the weights", + ) return parser @classmethod - def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': + def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs": # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. @@ -266,45 +292,50 @@ class EngineArgs: ) -> EngineConfig: # bitsandbytes quantization needs a specific model loader # so we make sure the quant method and the load format are consistent - if (self.quantization == "bitsandbytes" or - self.qlora_adapter_name_or_path is not None) and \ - self.load_format != "bitsandbytes": - raise ValueError("BitsAndBytes quantization and QLoRA adapter only support " - f"'bitsandbytes' load format, but got {self.load_format}") + if ( + self.quantization == "bitsandbytes" or self.qlora_adapter_name_or_path is not None + ) and self.load_format != "bitsandbytes": + raise ValueError( + "BitsAndBytes quantization and QLoRA adapter only support " + f"'bitsandbytes' load format, but got {self.load_format}" + ) - if (self.load_format == "bitsandbytes" or - self.qlora_adapter_name_or_path is not None) and \ - self.quantization != "bitsandbytes": - raise ValueError("BitsAndBytes load format and QLoRA adapter only support " - f"'bitsandbytes' quantization, but got {self.quantization}") + if ( + self.load_format == "bitsandbytes" or self.qlora_adapter_name_or_path is not None + ) and self.quantization != "bitsandbytes": + raise ValueError( + "BitsAndBytes load format and QLoRA adapter only support " + f"'bitsandbytes' quantization, but got {self.quantization}" + ) - assert self.cpu_offload_gb >= 0, ("CPU offload space must be non-negative" - f", but got {self.cpu_offload_gb}") + assert self.cpu_offload_gb >= 0, f"CPU offload space must be non-negative, but got {self.cpu_offload_gb}" multimodal_config = MultiModalConfig() device_config = DeviceConfig(self.device) # NOTE(sgm): we only modify ModelConfig, other configs are import from vllm - model_config = ModelConfig(hf_config=self.model_hf_config, - tokenizer_mode=self.tokenizer_mode, - trust_remote_code=self.trust_remote_code, - dtype=self.dtype, - seed=self.seed, - revision=self.revision, - code_revision=self.code_revision, - rope_scaling=self.rope_scaling, - rope_theta=self.rope_theta, - tokenizer_revision=self.tokenizer_revision, - max_model_len=self.max_model_len, - quantization=self.quantization, - quantization_param_path=self.quantization_param_path, - enforce_eager=self.enforce_eager, - max_context_len_to_capture=self.max_context_len_to_capture, - max_seq_len_to_capture=self.max_seq_len_to_capture, - max_logprobs=self.max_logprobs, - disable_sliding_window=self.disable_sliding_window, - skip_tokenizer_init=self.skip_tokenizer_init, - served_model_name=self.served_model_name, - multimodal_config=multimodal_config) + model_config = ModelConfig( + hf_config=self.model_hf_config, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + 
quantization=self.quantization, + quantization_param_path=self.quantization_param_path, + enforce_eager=self.enforce_eager, + max_context_len_to_capture=self.max_context_len_to_capture, + max_seq_len_to_capture=self.max_seq_len_to_capture, + max_logprobs=self.max_logprobs, + disable_sliding_window=self.disable_sliding_window, + skip_tokenizer_init=self.skip_tokenizer_init, + served_model_name=self.served_model_name, + multimodal_config=multimodal_config, + ) cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, @@ -315,18 +346,20 @@ class EngineArgs: enable_prefix_caching=self.enable_prefix_caching, cpu_offload_gb=self.cpu_offload_gb, ) - parallel_config = ParallelConfig(pipeline_parallel_size=self.pipeline_parallel_size, - tensor_parallel_size=self.tensor_parallel_size, - worker_use_ray=self.worker_use_ray, - max_parallel_loading_workers=self.max_parallel_loading_workers, - disable_custom_all_reduce=self.disable_custom_all_reduce, - tokenizer_pool_config=TokenizerPoolConfig.create_config( - self.tokenizer_pool_size, - self.tokenizer_pool_type, - self.tokenizer_pool_extra_config, - ), - ray_workers_use_nsight=self.ray_workers_use_nsight, - distributed_executor_backend=self.distributed_executor_backend) + parallel_config = ParallelConfig( + pipeline_parallel_size=self.pipeline_parallel_size, + tensor_parallel_size=self.tensor_parallel_size, + worker_use_ray=self.worker_use_ray, + max_parallel_loading_workers=self.max_parallel_loading_workers, + disable_custom_all_reduce=self.disable_custom_all_reduce, + tokenizer_pool_config=TokenizerPoolConfig.create_config( + self.tokenizer_pool_size, + self.tokenizer_pool_type, + self.tokenizer_pool_extra_config, + ), + ray_workers_use_nsight=self.ray_workers_use_nsight, + distributed_executor_backend=self.distributed_executor_backend, + ) # NOTE[VERL]: Use the world_size set by TORCHRUN world_size = int(os.getenv("WORLD_SIZE", "-1")) @@ -341,18 +374,26 @@ class EngineArgs: # initial memory profiling phase. if use_long_context: is_gpu = device_config.device_type == "cuda" - use_sliding_window = (model_config.get_sliding_window() is not None) + use_sliding_window = model_config.get_sliding_window() is not None use_spec_decode = self.speculative_model is not None - has_seqlen_agnostic_layers = (model_config.contains_seqlen_agnostic_layers(parallel_config)) - if (is_gpu and not use_sliding_window and not use_spec_decode and not self.enable_lora and - not self.enable_prompt_adapter and not self.enable_prefix_caching and - not has_seqlen_agnostic_layers): + has_seqlen_agnostic_layers = model_config.contains_seqlen_agnostic_layers(parallel_config) + if ( + is_gpu + and not use_sliding_window + and not use_spec_decode + and not self.enable_lora + and not self.enable_prompt_adapter + and not self.enable_prefix_caching + and not has_seqlen_agnostic_layers + ): self.enable_chunked_prefill = True - logger.warning("Chunked prefill is enabled by default for models with " - "max_model_len > 32K. Currently, chunked prefill might " - "not work with some features or models. If you " - "encounter any issues, please disable chunked prefill " - "by setting --enable-chunked-prefill=False.") + logger.warning( + "Chunked prefill is enabled by default for models with " + "max_model_len > 32K. Currently, chunked prefill might " + "not work with some features or models. If you " + "encounter any issues, please disable chunked prefill " + "by setting --enable-chunked-prefill=False." 
+ ) if self.enable_chunked_prefill is None: self.enable_chunked_prefill = False @@ -361,7 +402,9 @@ class EngineArgs: "The model has a long context length (%s). This may cause OOM " "errors during the initial memory profiling phase, or result " "in low performance due to small KV cache space. Consider " - "setting --max-model-len to a smaller value.", max_model_len) + "setting --max-model-len to a smaller value.", + max_model_len, + ) # TODO: spec config speculative_config = SpeculativeConfig.maybe_create_spec_config( @@ -369,23 +412,18 @@ class EngineArgs: target_parallel_config=parallel_config, target_dtype=self.dtype, speculative_model=self.speculative_model, - speculative_draft_tensor_parallel_size = \ - self.speculative_draft_tensor_parallel_size, + speculative_draft_tensor_parallel_size=self.speculative_draft_tensor_parallel_size, num_speculative_tokens=self.num_speculative_tokens, - speculative_disable_by_batch_size=self. - speculative_disable_by_batch_size, + speculative_disable_by_batch_size=self.speculative_disable_by_batch_size, speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, use_v2_block_manager=self.use_v2_block_manager, disable_log_stats=self.disable_log_stats, ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, - draft_token_acceptance_method=\ - self.spec_decoding_acceptance_method, - typical_acceptance_sampler_posterior_threshold=self. - typical_acceptance_sampler_posterior_threshold, - typical_acceptance_sampler_posterior_alpha=self. - typical_acceptance_sampler_posterior_alpha, + draft_token_acceptance_method=self.spec_decoding_acceptance_method, + typical_acceptance_sampler_posterior_threshold=self.typical_acceptance_sampler_posterior_threshold, + typical_acceptance_sampler_posterior_alpha=self.typical_acceptance_sampler_posterior_alpha, disable_logprobs=self.disable_logprobs_during_spec_decoding, ) @@ -394,24 +432,29 @@ class EngineArgs: max_num_seqs=self.max_num_seqs, max_model_len=model_config.max_model_len, use_v2_block_manager=self.use_v2_block_manager, - num_lookahead_slots=(self.num_lookahead_slots - if speculative_config is None else speculative_config.num_lookahead_slots), + num_lookahead_slots=( + self.num_lookahead_slots if speculative_config is None else speculative_config.num_lookahead_slots + ), delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, embedding_mode=model_config.embedding_mode, preemption_mode=self.preemption_mode, ) - lora_config = LoRAConfig(max_lora_rank=self.max_lora_rank, - max_loras=self.max_loras, - fully_sharded_loras=self.fully_sharded_loras, - lora_extra_vocab_size=self.lora_extra_vocab_size, - long_lora_scaling_factors=self.long_lora_scaling_factors, - lora_dtype=self.lora_dtype, - max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else - None) if self.enable_lora else None + lora_config = ( + LoRAConfig( + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + fully_sharded_loras=self.fully_sharded_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + long_lora_scaling_factors=self.long_lora_scaling_factors, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None, + ) + if self.enable_lora + else None + ) - if self.qlora_adapter_name_or_path is not None and \ - self.qlora_adapter_name_or_path != "": + if self.qlora_adapter_name_or_path is not None and 
self.qlora_adapter_name_or_path != "": if self.model_loader_extra_config is None: self.model_loader_extra_config = {} self.model_loader_extra_config["qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path @@ -423,19 +466,27 @@ class EngineArgs: ignore_patterns=self.ignore_patterns, ) - prompt_adapter_config = PromptAdapterConfig( - max_prompt_adapters=self.max_prompt_adapters, - max_prompt_adapter_token=self.max_prompt_adapter_token) \ - if self.enable_prompt_adapter else None + prompt_adapter_config = ( + PromptAdapterConfig( + max_prompt_adapters=self.max_prompt_adapters, max_prompt_adapter_token=self.max_prompt_adapter_token + ) + if self.enable_prompt_adapter + else None + ) decoding_config = DecodingConfig(guided_decoding_backend=self.guided_decoding_backend) observability_config = ObservabilityConfig(otlp_traces_endpoint=self.otlp_traces_endpoint) - if (model_config.get_sliding_window() is not None and scheduler_config.chunked_prefill_enabled and - not scheduler_config.use_v2_block_manager): - raise ValueError("Chunked prefill is not supported with sliding window. " - "Set --disable-sliding-window to disable sliding window.") + if ( + model_config.get_sliding_window() is not None + and scheduler_config.chunked_prefill_enabled + and not scheduler_config.use_v2_block_manager + ): + raise ValueError( + "Chunked prefill is not supported with sliding window. " + "Set --disable-sliding-window to disable sliding window." + ) return EngineConfig( model_config=model_config, diff --git a/verl/third_party/vllm/vllm_v_0_5_4/config.py b/verl/third_party/vllm/vllm_v_0_5_4/config.py index 5fc61e6fe..03d5a39dd 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/config.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/config.py @@ -15,18 +15,25 @@ import enum import json +from dataclasses import dataclass, field from typing import List, Optional, Union -from dataclasses import dataclass, field, fields import torch from transformers import PretrainedConfig +# Add for verl +from vllm.config import ( + ModelConfig, + MultiModalConfig, + _get_and_verify_dtype, + _get_and_verify_max_len, + get_served_model_name, +) from vllm.logger import init_logger from vllm.model_executor.layers.quantization import get_quantization_config +from vllm.model_executor.model_loader import BaseModelLoader from vllm.transformers_utils.config import get_hf_text_config from vllm.utils import is_hip, print_warning_once -# Add for verl -from vllm.config import ModelConfig, _get_and_verify_dtype, _get_and_verify_max_len, get_served_model_name GPTQMarlinConfig = get_quantization_config("gptq_marlin") @@ -91,8 +98,8 @@ class ModelConfig(ModelConfig): skip_tokenizer_init: If true, skip initialization of tokenizer and detokenizer. served_model_name: The model name used in metrics tag `model_name`, - matches the model name exposed via the APIs. If multiple model - names provided, the first name will be used. If not specified, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, the model name will be the same as `model`. 
""" @@ -118,7 +125,7 @@ class ModelConfig(ModelConfig): disable_sliding_window: bool = False, skip_tokenizer_init: bool = False, served_model_name: Optional[Union[str, List[str]]] = None, - multimodal_config: Optional["MultiModalConfig"] = None, + multimodal_config: Optional[MultiModalConfig] = None, ) -> None: self.model = hf_config._name_or_path self.tokenizer = hf_config._name_or_path @@ -139,8 +146,7 @@ class ModelConfig(ModelConfig): self.quantization_param_path = quantization_param_path self.enforce_eager = enforce_eager if max_context_len_to_capture is not None: - raise ValueError("`max_context_len_to_capture` is deprecated. " - "Use `max_seq_len_to_capture` instead.") + raise ValueError("`max_context_len_to_capture` is deprecated. Use `max_seq_len_to_capture` instead.") self.max_seq_len_to_capture = max_seq_len_to_capture self.max_logprobs = max_logprobs self.disable_sliding_window = disable_sliding_window @@ -154,21 +160,29 @@ class ModelConfig(ModelConfig): # served_model_name) # self._verify_load_format() # self._verify_tokenizer_mode() - if (not self.disable_sliding_window and self.hf_text_config.model_type == "gemma2" and - self.hf_text_config.sliding_window is not None): - print_warning_once("Gemma 2 uses sliding window attention for every odd layer, " - "which is currently not supported by vLLM. Disabling sliding " - "window and capping the max length to the sliding window size " - f"({self.hf_text_config.sliding_window}).") + if ( + not self.disable_sliding_window + and self.hf_text_config.model_type == "gemma2" + and self.hf_text_config.sliding_window is not None + ): + print_warning_once( + "Gemma 2 uses sliding window attention for every odd layer, " + "which is currently not supported by vLLM. Disabling sliding " + "window and capping the max length to the sliding window size " + f"({self.hf_text_config.sliding_window})." + ) self.disable_sliding_window = True - self.max_model_len = _get_and_verify_max_len(hf_config=self.hf_text_config, - max_model_len=max_model_len, - disable_sliding_window=self.disable_sliding_window, - sliding_window_len=self.get_hf_config_sliding_window()) + self.max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window(), + ) self.served_model_name = get_served_model_name( self.model, # str - served_model_name) + served_model_name, + ) self.multimodal_config = multimodal_config if not self.skip_tokenizer_init: @@ -179,41 +193,41 @@ class ModelConfig(ModelConfig): class LoadFormat(str, enum.Enum): - AUTO = 'auto' + AUTO = "auto" MEGATRON = "megatron" HF = "hf" - DTENSOR = 'dtensor' - DUMMY_HF = 'dummy_hf' - DUMMY_MEGATRON = 'dummy_megatron' - DUMMY_DTENSOR = 'dummy_dtensor' + DTENSOR = "dtensor" + DUMMY_HF = "dummy_hf" + DUMMY_MEGATRON = "dummy_megatron" + DUMMY_DTENSOR = "dummy_dtensor" # TODO: check whether this is necessary @dataclass class LoadConfig: """ - download_dir: Directory to download and load the weights, default to the - default cache directory of huggingface. - load_format: The format of the model weights to load: - "auto" will try to load the weights in the safetensors format and - fall back to the pytorch bin format if safetensors format is - not available. - "pt" will load the weights in the pytorch bin format. - "safetensors" will load the weights in the safetensors format. - "npcache" will load the weights in pytorch format and store - a numpy cache to speed up the loading. 
- "dummy" will initialize the weights with random values, which is - mainly for profiling. - "tensorizer" will use CoreWeave's tensorizer library for - fast weight loading. - "bitsandbytes" will load nf4 type weights. - ignore_patterns: The list of patterns to ignore when loading the model. - Default to "original/**/*" to avoid repeated loading of llama's - checkpoints. - + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. + "bitsandbytes" will load nf4 type weights. + ignore_patterns: The list of patterns to ignore when loading the model. + Default to "original/**/*" to avoid repeated loading of llama's + checkpoints. + """ - load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO + load_format: Union[str, LoadFormat, BaseModelLoader] = LoadFormat.AUTO download_dir: Optional[str] = None model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict) ignore_patterns: Optional[Union[List[str], str]] = None @@ -241,6 +255,8 @@ class LoadConfig: rocm_supported_load_format = [ f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format) ] - raise ValueError(f"load format '{load_format}' is not supported in ROCm. " - f"Supported load formats are " - f"{rocm_supported_load_format}") + raise ValueError( + f"load format '{load_format}' is not supported in ROCm. " + f"Supported load formats are " + f"{rocm_supported_load_format}" + ) diff --git a/verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loaders.py index 732b543db..6175c56d4 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loaders.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loaders.py @@ -13,13 +13,11 @@ # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models -from typing import Dict, Iterable, Tuple -import torch -import torch.nn as nn -from torch.distributed._tensor import DTensor, Shard, Replicate +from typing import Dict +import torch.nn as nn +from torch.distributed._tensor import DTensor from vllm.model_executor.layers.linear import * -from vllm.model_executor.models import ModelRegistry from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import is_pp_missing_parameter @@ -35,7 +33,7 @@ def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> n ] params_dict = dict(vllm_model.named_parameters()) for name, loaded_weight in actor_weights.items(): - for (param_name, shard_name, shard_id) in stacked_params_mapping: + for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue stacked_name = name.replace(shard_name, param_name) @@ -89,7 +87,7 @@ def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -120,7 +118,7 @@ def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> n for name, loaded_weight in actor_weights.items(): if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue @@ -129,7 +127,7 @@ def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> n # processed with quantization, LoRA, fine-tuning, etc. 
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -166,7 +164,7 @@ def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> n continue if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -200,16 +198,18 @@ def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping(ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=vllm_model.config.n_routed_experts) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=vllm_model.config.n_routed_experts, + ) params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) for name, loaded_weight in actor_weights.items(): if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -219,7 +219,7 @@ def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. @@ -247,11 +247,13 @@ def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) param = params_dict[name] local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, - local_loaded_weight.to(dtype=param.dtype), - weight_name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + local_loaded_weight.to(dtype=param.dtype), + weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. 
@@ -274,11 +276,13 @@ def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None): param_name = _process_parameter_names(name=param_name) if parallelize_plan is not None: - assert param_name in parallelize_plan.keys(), \ + assert param_name in parallelize_plan, ( f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}" + ) placement = parallelize_plan[param_name] - local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh, - placements=placement).to_local() + local_loaded_weights = loaded_weights.redistribute( + device_mesh=loaded_weights.device_mesh, placements=placement + ).to_local() else: local_loaded_weights = loaded_weights.full_tensor() return local_loaded_weights @@ -291,9 +295,9 @@ def _process_parameter_names(name): # Remove 'model.layers.x.' or 'model.' prefix if "model.layers" in name: - parts = name.split('.') + parts = name.split(".") # Reconstruct the string without 'model.layers.x.' - name = '.'.join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x' + name = ".".join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x' elif name.startswith("model."): name = name[6:] # Remove 'model.' @@ -301,20 +305,20 @@ def _process_parameter_names(name): __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = { - 'GPT2LMHeadModel': gpt2_dtensor_weight_loader, - 'LlamaForCausalLM': llama_dtensor_weight_loader, - 'LLaMAForCausalLM': llama_dtensor_weight_loader, - 'MistralForCausalLM': llama_dtensor_weight_loader, # mistral is the same as llama in vLLM - 'InternLMForCausalLM': llama_dtensor_weight_loader, - 'AquilaModel': llama_dtensor_weight_loader, - 'AquilaForCausalLM': llama_dtensor_weight_loader, - 'Phi3ForCausalLM': llama_dtensor_weight_loader, - 'GemmaForCausalLM': gemma_dtensor_weight_loader, - 'Gemma2ForCausalLM': gemma_dtensor_weight_loader, - 'GPTBigCodeForCausalLM': gptbigcode_dtensor_load_weights, - 'Starcoder2ForCausalLM': starcoder2_dtensor_load_weights, - 'Qwen2ForCausalLM': qwen2_dtensor_weight_loader, - 'DeepseekV2ForCausalLM': deepseekv2_dtensor_weight_loader + "GPT2LMHeadModel": gpt2_dtensor_weight_loader, + "LlamaForCausalLM": llama_dtensor_weight_loader, + "LLaMAForCausalLM": llama_dtensor_weight_loader, + "MistralForCausalLM": llama_dtensor_weight_loader, # mistral is the same as llama in vLLM + "InternLMForCausalLM": llama_dtensor_weight_loader, + "AquilaModel": llama_dtensor_weight_loader, + "AquilaForCausalLM": llama_dtensor_weight_loader, + "Phi3ForCausalLM": llama_dtensor_weight_loader, + "GemmaForCausalLM": gemma_dtensor_weight_loader, + "Gemma2ForCausalLM": gemma_dtensor_weight_loader, + "GPTBigCodeForCausalLM": gptbigcode_dtensor_load_weights, + "Starcoder2ForCausalLM": starcoder2_dtensor_load_weights, + "Qwen2ForCausalLM": qwen2_dtensor_weight_loader, + "DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader, } @@ -331,8 +335,10 @@ def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module): def _get_model_weight_loader(arch: str): if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__: return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch] - raise ValueError(f"Model architectures {arch} are not supported for now. " - f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}") + raise ValueError( + f"Model architectures {arch} are not supported for now. 
" + f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}" + ) # NOTE(sgm): we use per-parameter weight loader in each vllm sub diff --git a/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py b/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py index 7af4953f3..0de56a008 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py @@ -13,24 +13,21 @@ # limitations under the License. # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models -from typing import Dict, Union, Optional, Iterable, Tuple +from typing import Dict -import torch import torch.nn as nn - from vllm.model_executor.model_loader.utils import set_default_torch_dtype -from vllm.model_executor.model_loader.weight_utils import default_weight_loader def update_hf_weight_loader(): - print('no hf weight loader need to be updated') + print("no hf weight loader need to be updated") return def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module): assert isinstance(actor_weights, Dict) with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO - if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys(): + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights: del actor_weights["lm_head.weight"] vllm_model.load_weights(actor_weights.items()) for _, module in vllm_model.named_modules(): diff --git a/verl/third_party/vllm/vllm_v_0_5_4/llm.py b/verl/third_party/vllm/vllm_v_0_5_4/llm.py index 80e6f906f..4ec3d2aac 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/llm.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/llm.py @@ -13,31 +13,21 @@ # limitations under the License. # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py -from contextlib import contextmanager -from typing import ClassVar, List, Optional, Sequence, Union, cast, overload, Dict, Tuple, Iterable +from typing import Dict, Iterable, List, Optional, Tuple, Union -from tqdm import tqdm -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast -from transformers import PretrainedConfig +import torch import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence +from tqdm import tqdm +from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast +from vllm import LLM +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.utils import Counter + +from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer + from .arg_utils import EngineArgs from .llm_engine_sp import LLMEngine -from vllm import LLM -from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt, parse_and_batch_prompt) -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.guided_decoding import (GuidedDecodingRequest, get_local_guided_decoding_logits_processor) -from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions -from vllm.outputs import EmbeddingRequestOutput, RequestOutput -from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer import get_cached_tokenizer -from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter, deprecate_kwargs -import torch -from torch.nn.utils.rnn import pad_sequence -from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer class LLM(LLM): 
@@ -96,7 +86,7 @@ class LLM(LLM): def __init__( self, - model: Union[nn.Module, Dict], # model itself or its parameter dict + model: Union[nn.Module, Dict], # model itself or its parameter dict tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer], model_hf_config: PretrainedConfig, tokenizer_mode: str = "auto", @@ -115,7 +105,7 @@ class LLM(LLM): max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, - load_format = 'auto', + load_format="auto", **kwargs, ) -> None: if "disable_log_stats" not in kwargs: @@ -171,8 +161,7 @@ class LLM(LLM): total=num_requests, desc="Processed prompts", dynamic_ncols=True, - postfix=(f"est. speed input: {0:.2f} toks/s, " - f"output: {0:.2f} toks/s"), + postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"), ) # Run the engine. outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] @@ -190,8 +179,7 @@ class LLM(LLM): in_spd = total_in_toks / pbar.format_dict["elapsed"] total_out_toks += sum(len(stp.token_ids) for stp in output.outputs) out_spd = total_out_toks / pbar.format_dict["elapsed"] - pbar.postfix = (f"est. speed input: {in_spd:.2f} toks/s, " - f"output: {out_spd:.2f} toks/s") + pbar.postfix = f"est. speed input: {in_spd:.2f} toks/s, output: {out_spd:.2f} toks/s" pbar.update(1) if use_tqdm: pbar.close() @@ -226,7 +214,11 @@ class LLM(LLM): logprob.append(logprobs_dict[id].logprob) logprobs.append(torch.tensor(logprob)) - pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + pad_token_id = ( + self.llm_engine.tokenizer.pad_token_id + if self.llm_engine.tokenizer.pad_token_id is not None + else self.llm_engine.tokenizer.eos_token_id + ) output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id) if len(logprobs) > 0: logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id) diff --git a/verl/third_party/vllm/vllm_v_0_5_4/llm_engine_sp.py b/verl/third_party/vllm/vllm_v_0_5_4/llm_engine_sp.py index d532f59a9..b0f9dcf4a 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/llm_engine_sp.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/llm_engine_sp.py @@ -13,31 +13,39 @@ # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py -import torch -from typing import Dict, Optional, Union, Type, Iterable +from typing import Dict, Iterable, Optional, Type, Union -import vllm.envs as envs -from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoRAConfig, MultiModalConfig, - ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) +from torch import nn +from vllm.config import ( + CacheConfig, + DecodingConfig, + DeviceConfig, + EngineConfig, + LoRAConfig, + MultiModalConfig, + ObservabilityConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) from vllm.core.scheduler import Scheduler -from vllm.engine.output_processor.interfaces import (SequenceGroupOutputProcessor) +from vllm.engine.llm_engine import LLMEngine, _load_generation_config_dict +from vllm.engine.metrics import LoggingStatLogger, PrometheusStatLogger, StatLoggerBase +from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker from vllm.executor.executor_base import ExecutorBase -from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs +from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger +from vllm.tracing import init_tracer from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger, StatLoggerBase, Stats) -from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) -from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) +from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message from vllm.utils import Counter -from vllm.engine.llm_engine import _load_generation_config_dict -from vllm.engine.llm_engine import LLMEngine from vllm.version import __version__ as VLLM_VERSION -import torch.nn as nn from .arg_utils import EngineArgs +from .config import LoadConfig, ModelConfig from .tokenizer import TokenizerGroup -from .config import ModelConfig, LoadConfig logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -77,7 +85,7 @@ class LLMEngine(LLMEngine): def __init__( self, # NOTE(sgm): first two arguments are added for verl - model: Union[nn.Module, Dict], # model itself or its parameter dict + model: Union[nn.Module, Dict], # model itself or its parameter dict tokenizer: nn.Module, # NOTE(sgm): vllm original arguments model_config: ModelConfig, @@ -171,7 +179,7 @@ class LLMEngine(LLMEngine): self.input_processor = INPUT_REGISTRY.create_input_processor(self.model_config) self.model_executor = executor_class( - model=model, # add for spmd_gpu_executor + model=model, # add for spmd_gpu_executor model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, @@ -190,7 +198,8 @@ class LLMEngine(LLMEngine): # If usage stat is enabled, collect relevant info. 
if is_usage_stats_enabled(): - from vllm.model_executor.model_loader import (get_architecture_class_name) + from vllm.model_executor.model_loader import get_architecture_class_name + usage_message.report_usage( get_architecture_class_name(model_config), usage_context, @@ -200,18 +209,17 @@ class LLMEngine(LLMEngine): "tensor_parallel_size": parallel_config.tensor_parallel_size, "block_size": cache_config.block_size, "gpu_memory_utilization": cache_config.gpu_memory_utilization, - # Quantization "quantization": model_config.quantization, "kv_cache_dtype": str(cache_config.cache_dtype), - # Feature flags "enable_lora": bool(lora_config), "enable_prompt_adapter": bool(prompt_adapter_config), "enable_prefix_caching": cache_config.enable_prefix_caching, "enforce_eager": model_config.enforce_eager, "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce, - }) + }, + ) if self.tokenizer: # Ping the tokenizer to ensure liveness if it runs in a @@ -232,12 +240,12 @@ class LLMEngine(LLMEngine): self.stat_loggers = stat_loggers else: self.stat_loggers = { - "logging": - LoggingStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC), - "prometheus": - PrometheusStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.served_model_name), - max_model_len=self.model_config.max_model_len), + "logging": LoggingStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC), + "prometheus": PrometheusStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict(model_name=model_config.served_model_name), + max_model_len=self.model_config.max_model_len, + ), } self.stat_loggers["prometheus"].info("cache_config", self.cache_config) @@ -247,7 +255,7 @@ class LLMEngine(LLMEngine): # Create sequence output processor, e.g. for beam search or # speculative decoding. - self.output_processor = (SequenceGroupOutputProcessor.create_output_processor( + self.output_processor = SequenceGroupOutputProcessor.create_output_processor( self.scheduler_config, self.detokenizer, self.scheduler, @@ -257,13 +265,13 @@ class LLMEngine(LLMEngine): self.scheduler_config.max_model_len, self.get_tokenizer_for_seq, ), - )) + ) # TODO(sgm): add for verl but we may not tokenizer in Rollout def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs): - init_kwargs = dict(enable_lora=bool(self.lora_config), - max_num_seqs=self.scheduler_config.max_num_seqs, - max_input_length=None) + init_kwargs = dict( + enable_lora=bool(self.lora_config), max_num_seqs=self.scheduler_config.max_num_seqs, max_input_length=None + ) init_kwargs.update(tokenizer_init_kwargs) return TokenizerGroup(tokenizer, **init_kwargs) @@ -279,13 +287,15 @@ class LLMEngine(LLMEngine): # The GPUExecutor remove the Ray dependency @classmethod def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[ExecutorBase]: - assert engine_config.device_config.device_type == "cuda", \ + assert engine_config.device_config.device_type == "cuda", ( "Currently, the vllm in verl only support running on GPU" + ) if engine_config.parallel_config.world_size == 1: engine_config.load_config.load_format = "dummy_hf" from .spmd_gpu_executor import SPMDGPUExecutor + executor_class = SPMDGPUExecutor return executor_class @@ -303,10 +313,12 @@ class LLMEngine(LLMEngine): engine_config = engine_args.create_engine_config() executor_class = cls._get_executor_cls(engine_config) # Initialize the cluster and specify the executor class. 
- assert engine_config.device_config.device_type == "cuda", \ + assert engine_config.device_config.device_type == "cuda", ( "Currently, the vllm in verl only support running on GPU" + ) from .spmd_gpu_executor import SPMDGPUExecutor + executor_class = SPMDGPUExecutor # Create the LLM engine. diff --git a/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py index 98a96c499..d5916f4e2 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py @@ -14,22 +14,25 @@ # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models from typing import Dict, Iterable + import torch import torch.nn as nn - from vllm.model_executor.layers.linear import * -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead -from vllm.model_executor.layers.activation import ScaledActivation +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding from vllm.model_executor.models import ModelRegistry # NOTE(shengguangming): replace the origin weight loader function in the class def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Parallel Linear weight loader.""" - assert param.size() == loaded_weight.size( - ), 'the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}'.format( - param.size(), loaded_weight.size()) - assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same" + assert param.size() == loaded_weight.size(), ( + "the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}".format( + param.size(), loaded_weight.size() + ) + ) + assert param.data.dtype == loaded_weight.data.dtype, ( + "if we want to shared weights, the data type should also be the same" + ) param.data = loaded_weight.data @@ -37,7 +40,9 @@ def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tenso def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" assert param.size() == loaded_weight.size() - assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same" + assert param.data.dtype == loaded_weight.data.dtype, ( + "if we want to shared weights, the data type should also be the same" + ) param.data = loaded_weight.data @@ -90,20 +95,20 @@ def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Mod ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), ("self_attention.linear_qkv", "self_attn.qkv_proj"), ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", 'self_attn.o_proj'), - ('pre_mlp_layernorm', 'post_attention_layernorm'), - ('mlp.linear_fc1.layer_norm_weight', 'post_attention_layernorm.weight'), - ('mlp.linear_fc1.layer_norm_bias', 'post_attention_layernorm.bias'), - ('mlp.linear_fc1', 'mlp.gate_up_proj'), - ('mlp.linear_fc2', 'mlp.down_proj'), - ('decoder.final_layernorm', 'model.norm'), - ('output_layer', 'lm_head'), + ("self_attention.linear_proj", "self_attn.o_proj"), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), + 
("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), ] # NOTE(shengguangming): the megatron llama may have this prefix params_dict = dict(vllm_model.named_parameters()) for name, loaded_weight in actor_weights.items(): name = _replace_name(name, params_mapping) - if name.endswith('.bias') and name not in params_dict: + if name.endswith(".bias") and name not in params_dict: continue if "rotary_emb.inv_freq" in name: continue @@ -118,22 +123,22 @@ def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module # (megatron core gpt model name, vllm model name) ("embedding.word_embeddings", "model.embed_tokens"), ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", 'self_attn.o_proj'), + ("self_attention.linear_proj", "self_attn.o_proj"), ( - 'input_layernorm', - 'input_layernorm', + "input_layernorm", + "input_layernorm", ), - ('pre_mlp_layernorm', 'post_attention_layernorm'), - ('mlp.linear_fc1', 'mlp.gate_up_proj'), - ('mlp.linear_fc2', 'mlp.down_proj'), - ('decoder.final_layernorm', 'model.norm'), - ('output_layer', 'lm_head'), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), ] # NOTE(shengguangming): the megatron llama may have this prefix params_dict = dict(vllm_model.named_parameters()) for name, loaded_weight in actor_weights.items(): name = _replace_name(name, params_mapping) - if name.endswith('.bias') and name not in params_dict: + if name.endswith(".bias") and name not in params_dict: continue if "rotary_emb.inv_freq" in name: continue @@ -147,19 +152,19 @@ def _replace_name(megatron_name, name_mapping): for m_name, v_name in name_mapping: if m_name not in megatron_name: continue - if 'layers' in megatron_name: # deal with decoder layers - megatron_name = megatron_name.replace('decoder', 'model') - megatron_name_list = megatron_name.split('.') - if 'layer_norm_weight' in megatron_name_list or 'layer_norm_bias' in megatron_name_list: + if "layers" in megatron_name: # deal with decoder layers + megatron_name = megatron_name.replace("decoder", "model") + megatron_name_list = megatron_name.split(".") + if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: param_name_list = megatron_name_list[:3] param_name_list.append(v_name) - param_name = '.'.join(param_name_list) + param_name = ".".join(param_name_list) else: param_name_list = megatron_name_list[:3] weight_or_bias = megatron_name_list[-1] param_name_list.append(v_name) param_name_list.append(weight_or_bias) - param_name = '.'.join(param_name_list) + param_name = ".".join(param_name_list) return param_name else: param_name = megatron_name.replace(m_name, v_name) @@ -174,20 +179,20 @@ def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Mod ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), ("self_attention.linear_qkv", "self_attn.qkv_proj"), ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", 'self_attn.o_proj'), - ('pre_mlp_layernorm', 'post_attention_layernorm'), - ('mlp.linear_fc1.layer_norm_weight', 'post_attention_layernorm.weight'), - ('mlp.linear_fc1.layer_norm_bias', 'post_attention_layernorm.bias'), - ('mlp.linear_fc1', 'mlp.gate_up_proj'), - ('mlp.linear_fc2', 'mlp.down_proj'), - 
('decoder.final_layernorm', 'model.norm'), - ('output_layer', 'lm_head'), + ("self_attention.linear_proj", "self_attn.o_proj"), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), ] # NOTE(shengguangming): the megatron llama may have this prefix params_dict = dict(vllm_model.named_parameters()) for name, loaded_weight in actor_weights.items(): name = _replace_name(name, params_mapping) - if name.endswith('.bias') and name not in params_dict: + if name.endswith(".bias") and name not in params_dict: continue if "rotary_emb.inv_freq" in name: continue @@ -202,22 +207,22 @@ def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module # (megatron core gpt model name, vllm model name) ("embedding.word_embeddings", "model.embed_tokens"), ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", 'self_attn.o_proj'), + ("self_attention.linear_proj", "self_attn.o_proj"), ( - 'input_layernorm', - 'input_layernorm', + "input_layernorm", + "input_layernorm", ), - ('pre_mlp_layernorm', 'post_attention_layernorm'), - ('mlp.linear_fc1', 'mlp.gate_up_proj'), - ('mlp.linear_fc2', 'mlp.down_proj'), - ('decoder.final_layernorm', 'model.norm'), - ('output_layer', 'lm_head'), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), ] # NOTE(shengguangming): the megatron llama may have this prefix params_dict = dict(vllm_model.named_parameters()) for name, loaded_weight in actor_weights.items(): name = _replace_name(name, params_mapping) - if name.endswith('.bias') and name not in params_dict: + if name.endswith(".bias") and name not in params_dict: continue if "rotary_emb.inv_freq" in name: continue @@ -245,7 +250,7 @@ __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = { QKVParallelLinear: parallel_weight_loader, RowParallelLinear: parallel_weight_loader, VocabParallelEmbedding: parallel_weight_loader, - ParallelLMHead: parallel_weight_loader + ParallelLMHead: parallel_weight_loader, # "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights # "default_weight_loader": default_weight_loader } @@ -255,10 +260,10 @@ __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = { # layer_class.weight_loader = weight_loader __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = { - 'GPT2LMHeadModel': gpt2_weight_loader, - 'LlamaForCausalLM': llama_megatron_weight_loader, # use te backend for open-source megatron - 'LLaMAForCausalLM': llama_megatron_weight_loader, - 'MistralForCausalLM': mistral_megatron_weight_loader, + "GPT2LMHeadModel": gpt2_weight_loader, + "LlamaForCausalLM": llama_megatron_weight_loader, # use te backend for open-source megatron + "LLaMAForCausalLM": llama_megatron_weight_loader, + "MistralForCausalLM": mistral_megatron_weight_loader, } @@ -275,8 +280,10 @@ def load_megatron_weights(actor_weights: Iterable, vllm_model: nn.Module): def _get_model_weight_loader(arch: str): if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__: return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch] - raise ValueError(f"Model architectures {arch} are not supported for 
now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") + raise ValueError( + f"Model architectures {arch} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}" + ) def update_megatron_weight_loader(): diff --git a/verl/third_party/vllm/vllm_v_0_5_4/model_loader.py b/verl/third_party/vllm/vllm_v_0_5_4/model_loader.py index 1b675bb79..f54cab9f3 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/model_loader.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/model_loader.py @@ -13,52 +13,65 @@ # limitations under the License. # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader -from typing import Dict, Union, Optional, Iterable, Tuple +from typing import Dict, Optional, Union import torch import torch.nn as nn from transformers import PreTrainedModel - -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoadConfig, + LoRAConfig, + ModelConfig, + MultiModalConfig, + ParallelConfig, + SchedulerConfig, +) +from vllm.distributed.communication_op import tensor_model_parallel_all_gather from vllm.model_executor.model_loader import BaseModelLoader from vllm.model_executor.model_loader.loader import _initialize_model from vllm.model_executor.model_loader.utils import set_default_torch_dtype -from vllm.distributed.communication_op import tensor_model_parallel_all_gather -from .config import ModelConfig, LoadFormat, LoadConfig -from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader +from .config import LoadConfig, LoadFormat, ModelConfig from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader from .hf_weight_loader import update_hf_weight_loader +from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader -def get_model(actor_model: Union[PreTrainedModel, Dict], - model_config: ModelConfig, - load_config: LoadConfig, - device_config: DeviceConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - cache_config: CacheConfig = None) -> nn.Module: +def get_model( + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + load_config: LoadConfig, + device_config: DeviceConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + cache_config: CacheConfig = None, +) -> nn.Module: loader = get_model_loader(load_config) - if load_config.load_format.startswith('dummy'): - return loader.load_model(model_config=model_config, - device_config=device_config, - lora_config=lora_config, - multimodal_config=multimodal_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - cache_config=cache_config) + if load_config.load_format.startswith("dummy"): + return loader.load_model( + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + multimodal_config=multimodal_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config, + ) else: - return loader.load_model(actor_model=actor_model, - model_config=model_config, - device_config=device_config, - lora_config=lora_config, - multimodal_config=multimodal_config, - parallel_config=parallel_config, - 
scheduler_config=scheduler_config, - cache_config=cache_config) + return loader.load_model( + actor_model=actor_model, + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + multimodal_config=multimodal_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config, + ) def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: @@ -96,8 +109,11 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: update_dtensor_weight_loader() return DummyModelLoader(load_config) - raise ValueError('load format not supported in verl: {}, only support {} and {}'.format( - load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF)) + raise ValueError( + "load format not supported in verl: {}, only support {} and {}".format( + load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF + ) + ) class DummyModelLoader(BaseModelLoader): @@ -106,16 +122,24 @@ class DummyModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") - def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, - scheduler_config) + model = _initialize_model( + model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config + ) # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. 
# initialize_dummy_weights(model) @@ -128,8 +152,7 @@ class MegatronLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): # NOTE(shengguangming) Load the weights from the actor model @@ -140,19 +163,28 @@ class MegatronLoader(BaseModelLoader): # load_weights(actor_weights=actor_model, vllm_model=model) # return actor_model - def load_model(self, actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig, - device_config: DeviceConfig, lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, - scheduler_config) + model = _initialize_model( + model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config + ) # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm if isinstance(actor_model, nn.Module): - load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), - vllm_model=model) + load_megatron_weights( + actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model + ) else: load_megatron_weights(actor_weights=actor_model, vllm_model=model) @@ -175,8 +207,7 @@ class HFLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]): if isinstance(actor_model, Dict): @@ -184,17 +215,25 @@ class HFLoader(BaseModelLoader): elif isinstance(actor_model, nn.Module): return dict(actor_model.named_parameters()).items() else: - raise ValueError(f'actor model should be Dict or nn.Module, but get {type(actor_model)}') + raise ValueError(f"actor model should be Dict or nn.Module, but get {type(actor_model)}") - def load_model(self, actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig, - device_config: DeviceConfig, lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + parallel_config: 
ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: with set_default_torch_dtype(model_config.dtype): # with torch.device(device_config.device): # NOTE(sgm): init the model in cpu - model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, - scheduler_config) + model = _initialize_model( + model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config + ) model.load_weights(self._get_weights_iterator(actor_model)) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) @@ -215,8 +254,7 @@ class DTensorLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): # NOTE(shengguangming) Load the weights from the actor model @@ -227,19 +265,28 @@ class DTensorLoader(BaseModelLoader): # load_weights(actor_weights=actor_model, vllm_model=model) # return actor_model - def load_model(self, actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig, - device_config: DeviceConfig, lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, - scheduler_config) + model = _initialize_model( + model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config + ) # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm if isinstance(actor_model, nn.Module): - load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), - vllm_model=model) + load_dtensor_weights( + actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model + ) else: load_dtensor_weights(actor_weights=actor_model, vllm_model=model) @@ -260,8 +307,9 @@ class DTensorLoader(BaseModelLoader): # as they use ray, the _get_logits result will only need to return to the driver node, # therefore gather is enough. However, we use SPMD instead of a central scheduler, # all_gather is required (aligned with v0.2.6) -def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, - embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: +def _get_logits( + self, hidden_states: torch.Tensor, embedding: torch.Tensor, embedding_bias: Optional[torch.Tensor] +) -> torch.Tensor: # Get the logits for the next tokens. 
logits = torch.matmul(hidden_states, embedding.t()) if embedding_bias is not None: @@ -269,19 +317,21 @@ def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, logits = tensor_model_parallel_all_gather(logits) # Remove paddings in vocab (if any). if logits is not None: - logits = logits[:, :self.org_vocab_size] + logits = logits[:, : self.org_vocab_size] return logits from vllm.model_executor.layers.logits_processor import LogitsProcessor -def logitsprocessor_init(self, - vocab_size: int, - org_vocab_size: Optional[int] = None, - scale: float = 1.0, - logits_as_input: bool = False, - soft_cap: Optional[float] = None) -> None: +def logitsprocessor_init( + self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None, +) -> None: """ Args: scale: A scaling factor to apply to the logits. diff --git a/verl/third_party/vllm/vllm_v_0_5_4/model_runner.py b/verl/third_party/vllm/vllm_v_0_5_4/model_runner.py index d6ab23255..ec9eb671b 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/model_runner.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/model_runner.py @@ -13,28 +13,31 @@ # limitations under the License. # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py +import warnings +from enum import IntEnum +from typing import Dict, Optional, Union + import torch import torch.nn as nn -from enum import IntEnum -from typing import Dict, List, Optional, Set, Tuple, Union -import warnings - import vllm.envs as envs -from vllm.attention import (AttentionMetadata, get_attn_backend) -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, - SchedulerConfig) +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + MultiModalConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, +) from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.models.interfaces import (supports_lora, supports_vision) -from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available) -from vllm.worker.model_runner import ModelRunner, CUDAGraphRunner -from vllm.prompt_adapter.worker_manager import (LRUCacheWorkerPromptAdapterManager) +from vllm.model_executor.models.interfaces import supports_lora, supports_vision +from vllm.prompt_adapter.worker_manager import LRUCacheWorkerPromptAdapterManager +from vllm.utils import CudaMemoryProfiler, is_hip +from vllm.worker.model_runner import ModelRunner +from .config import LoadConfig, ModelConfig from .model_loader import get_model -from .config import ModelConfig, LoadConfig logger = init_logger(__name__) @@ -50,10 +53,9 @@ class BatchType(IntEnum): class ModelRunner(ModelRunner): - def __init__( self, - model: Union[nn.Module, Dict], # [verl] model itself or its parameter dict + model: Union[nn.Module, Dict], # [verl] model itself or its parameter dict model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, @@ -67,7 +69,6 @@ class ModelRunner(ModelRunner): multimodal_config: Optional[MultiModalConfig] = None, return_hidden_states: bool = False, ): - super().__init__( model_config, parallel_config, @@ -80,7 +81,8 @@ class ModelRunner(ModelRunner): is_driver_worker=True, # a hack 
prompt_adapter_config=prompt_adapter_config, multimodal_config=multimodal_config, - return_hidden_states=return_hidden_states) + return_hidden_states=return_hidden_states, + ) # NOTE(sgm): add for verl self.model = model # this will be replaced by get_model() @@ -89,15 +91,17 @@ class ModelRunner(ModelRunner): def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with CudaMemoryProfiler() as m: - self.model = get_model(actor_model=self.model, - model_config=self.model_config, - device_config=self.device_config, - lora_config=self.lora_config, - load_config=self.load_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - multimodal_config=self.multimodal_config, - cache_config=self.cache_config) + self.model = get_model( + actor_model=self.model, + model_config=self.model_config, + device_config=self.device_config, + lora_config=self.lora_config, + load_config=self.load_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + multimodal_config=self.multimodal_config, + cache_config=self.cache_config, + ) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) @@ -119,9 +123,12 @@ class ModelRunner(ModelRunner): if self.prompt_adapter_config: self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( - self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, self.device, - self.prompt_adapter_config) - self.model = (self.prompt_adapter_manager.create_prompt_adapter_manager(self.model)) + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.device, + self.prompt_adapter_config, + ) + self.model = self.prompt_adapter_manager.create_prompt_adapter_manager(self.model) if self.kv_cache_dtype == "fp8" and is_hip(): # Currently only ROCm accepts kv-cache scaling factors @@ -134,17 +141,22 @@ class ModelRunner(ModelRunner): "deprecated and will be removed. Please include " "kv cache scaling factors in the model checkpoint.", FutureWarning, - stacklevel=2) + stacklevel=2, + ) self.model.load_kv_cache_scales(self.model_config.quantization_param_path) logger.info("Loaded KV cache scaling factors from %s", self.model_config.quantization_param_path) else: raise RuntimeError( "Using FP8 KV cache and scaling factors provided but " - "model %s does not support loading scaling factors.", self.model.__class__) + "model %s does not support loading scaling factors.", + self.model.__class__, + ) else: - logger.warning("Using FP8 KV cache but no scaling factors " - "provided. Defaulting to scaling factors of 1.0. " - "This may lead to less accurate results!") + logger.warning( + "Using FP8 KV cache but no scaling factors " + "provided. Defaulting to scaling factors of 1.0. " + "This may lead to less accurate results!" + ) if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE: self.model = torch.compile(self.model, fullgraph=True, backend="eager") diff --git a/verl/third_party/vllm/vllm_v_0_5_4/parallel_state.py b/verl/third_party/vllm/vllm_v_0_5_4/parallel_state.py index 4d853ba75..cb16b65aa 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/parallel_state.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/parallel_state.py @@ -4,19 +4,21 @@ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
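The `model_runner.py` hunks above illustrate how `ruff format` lays out long calls: every keyword argument goes on its own line with a trailing comma, and the closing parenthesis gets its own line. A minimal sketch of the before/after, using a toy `get_model` stand-in rather than the real loader:

```python
def get_model(*, actor_model, model_config, device_config):
    """Toy stand-in for the loader call reformatted above."""
    return actor_model, model_config, device_config


# yapf-era layout: continuation lines aligned under the opening parenthesis.
model = get_model(actor_model="actor",
                  model_config="cfg",
                  device_config="cuda")

# ruff-format layout: one argument per line plus a trailing ("magic") comma,
# which also tells the formatter to keep the call exploded on future runs.
model = get_model(
    actor_model="actor",
    model_config="cfg",
    device_config="cuda",
)
```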
"""Model and data parallel groups.""" + import os -import torch -import torch.distributed from typing import Optional +import torch +import torch.distributed import vllm.distributed.parallel_state as ps -from vllm.distributed.parallel_state import get_pp_group, get_world_group, init_distributed_environment, init_model_parallel_group - -import vllm.envs as envs +from vllm.distributed.parallel_state import ( + get_pp_group, + get_world_group, + init_distributed_environment, + init_model_parallel_group, +) from vllm.logger import init_logger -from torch.distributed.device_mesh import init_device_mesh - logger = init_logger(__name__) """ This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron. @@ -59,8 +61,10 @@ def initialize_parallel_state( init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) if torch.distributed.get_world_size() > 1: # NOTE: build a sepearate inference group with infer tp & micro dp - initialize_model_parallel_for_vllm(tensor_model_parallel_size=tensor_model_parallel_size, - num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp) + initialize_model_parallel_for_vllm( + tensor_model_parallel_size=tensor_model_parallel_size, + num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp, + ) else: initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) @@ -80,28 +84,31 @@ def ensure_model_parallel_initialized( initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) return - assert (get_tensor_model_parallel_world_size() == tensor_model_parallel_size), ( + assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( "tensor parallel group already initialized, but of unexpected size: " f"{get_tensor_model_parallel_world_size()=} vs. " - f"{tensor_model_parallel_size=}") + f"{tensor_model_parallel_size=}" + ) pp_world_size = get_pp_group().world_size - assert (pp_world_size == pipeline_model_parallel_size), ( + assert pp_world_size == pipeline_model_parallel_size, ( "pipeline parallel group already initialized, but of unexpected size: " f"{pp_world_size=} vs. " - f"{pipeline_model_parallel_size=}") + f"{pipeline_model_parallel_size=}" + ) # TODO(sgm): deviate from the v0.5.4, not pp now def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" - return (ps._TP is not None) + return ps._TP is not None # and _PIPELINE_MODEL_PARALLEL_GROUP is not None) -def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int, - num_tensor_model_parallel_groups_per_train_tp: int = 1, - pipeline_model_parallel_size: int = 1) -> None: - from torch.distributed import new_group +def initialize_model_parallel_for_vllm( + tensor_model_parallel_size: int, + num_tensor_model_parallel_groups_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +) -> None: # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() @@ -111,7 +118,7 @@ def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int, # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group # Build the tensor model-parallel groups. 
- assert ps._TP is None, ("tensor model parallel group is already initialized") + assert ps._TP is None, "tensor model parallel group is already initialized" global _TP @@ -126,7 +133,7 @@ def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int, if num_tensor_model_parallel_groups_per_train_tp == 1: # if tensor_model_parallel_size == train_tensor_parallel_size: # using the same tp group as Megatron/vllm - assert _TP is None, ("tensor model parallel group is already initialized") + assert _TP is None, "tensor model parallel group is already initialized" group_ranks = [] for i in range(num_tensor_model_parallel_groups): ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) @@ -136,7 +143,8 @@ def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int, local_rank=get_world_group().local_rank, backend=backend, use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer - use_message_queue_broadcaster=True) + use_message_queue_broadcaster=True, + ) ps._TP = _TP # _MICRO_DATA_PARALLEL_GROUP is move to hybrid engine else: @@ -148,7 +156,7 @@ def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int, # train_tp = train_tensor_parallel_size train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size - assert _TP is None, ("tensor model parallel group is already initialized") + assert _TP is None, "tensor model parallel group is already initialized" group_ranks = [] for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): start = train_tp * i @@ -163,7 +171,8 @@ def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int, local_rank=get_world_group().local_rank, backend=backend, use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer - use_message_queue_broadcaster=True) + use_message_queue_broadcaster=True, + ) ps._TP = _TP # Build the pipeline model-parallel groups. @@ -176,9 +185,9 @@ def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int, # TODO: init using device mesh (not support hybrid engine now) # Build the pipeline model-parallel groups. - num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size) + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size global _PP - assert _PP is None, ("pipeline model parallel group is already initialized") + assert _PP is None, "pipeline model parallel group is already initialized" group_ranks = [] for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) @@ -196,7 +205,7 @@ def initialize_model_parallel( """ NOTE: This method is a hack from the open-sourced version without asertion of world_size = tp * pp - + Initialize model parallel groups. 
Arguments: @@ -232,10 +241,10 @@ def initialize_model_parallel( # f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " # f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") - num_tensor_model_parallel_groups: int = (world_size // tensor_model_parallel_size) + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size rank = torch.distributed.get_rank() global _TP - assert _TP is None, ("tensor model parallel group is already initialized") + assert _TP is None, "tensor model parallel group is already initialized" group_ranks = [] for i in range(num_tensor_model_parallel_groups): ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) @@ -247,14 +256,15 @@ def initialize_model_parallel( get_world_group().local_rank, backend, use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer - use_message_queue_broadcaster=True) + use_message_queue_broadcaster=True, + ) ps._TP = _TP # TODO: init using device mesh (not support hybrid engine now) # Build the pipeline model-parallel groups. - num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size) + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size global _PP - assert _PP is None, ("pipeline model parallel group is already initialized") + assert _PP is None, "pipeline model parallel group is already initialized" group_ranks = [] for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) @@ -270,7 +280,7 @@ Device mesh utilities def get_device_mesh(): - assert _DEVICE_MESH is not None, ("device mesh is not initialized") + assert _DEVICE_MESH is not None, "device mesh is not initialized" return _DEVICE_MESH @@ -281,7 +291,7 @@ Tensor model parallel utilities def get_tensor_model_parallel_group(): """Get the tensor model parallel group the caller rank belongs to.""" - assert _TP is not None, ("tensor model parallel group is not initialized") + assert _TP is not None, "tensor model parallel group is not initialized" return _TP.device_group diff --git a/verl/third_party/vllm/vllm_v_0_5_4/spmd_gpu_executor.py b/verl/third_party/vllm/vllm_v_0_5_4/spmd_gpu_executor.py index 31e71aceb..f92d881b1 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/spmd_gpu_executor.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/spmd_gpu_executor.py @@ -15,18 +15,25 @@ import os import socket -from typing import Any, Dict, List, Optional, Set, Tuple, Iterable +from typing import Iterable, List, Optional, Set, Tuple import torch -import vllm.envs as envs -from vllm.executor.executor_base import ExecutorBase, ExecutorAsyncBase +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + MultiModalConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, ExecuteModelRequest +from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, - SchedulerConfig, SpeculativeConfig) -from .config import ModelConfig, LoadConfig +from .config import LoadConfig, ModelConfig logger = init_logger(__name__) @@ -36,7 +43,7 @@ class SPMDGPUExecutor(ExecutorBase): def __init__( self, - model, # pytorch model itself or its parameter 
dict + model, # pytorch model itself or its parameter dict model_config: ModelConfig, cache_config: CacheConfig, parallel_config: ParallelConfig, @@ -64,7 +71,7 @@ class SPMDGPUExecutor(ExecutorBase): # TODO(sgm): verl not support speculative decode now def _init_executor(self, model, distributed_init_method) -> None: - assert (not self.speculative_config), "Speculative decoding not yet supported for multi-GPU backend." + assert not self.speculative_config, "Speculative decoding not yet supported for multi-GPU backend." # Create the parallel worker for each GPU. self._init_workers_sp(model, distributed_init_method) @@ -72,14 +79,14 @@ class SPMDGPUExecutor(ExecutorBase): def _init_workers_sp(self, model, distributed_init_method: str): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker - from .worker import Worker # pylint: disable=import-outside-toplevel + from .worker import Worker rank = int(os.getenv("RANK")) local_rank = int(os.getenv("LOCAL_RANK")) - print(f'local rank {local_rank}') + print(f"local rank {local_rank}") # see https://github.com/NVIDIA/nccl/issues/1234 - os.environ['NCCL_CUMEM_ENABLE'] = '0' + os.environ["NCCL_CUMEM_ENABLE"] = "0" self.worker = Worker( model, @@ -125,8 +132,7 @@ class SPMDGPUExecutor(ExecutorBase): return num_gpu_blocks, num_cpu_blocks def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: - """Initialize the KV cache in all workers. - """ + """Initialize the KV cache in all workers.""" # NOTE: We log here to avoid multiple logs when number of workers is # greater than one. We could log in the engine, but not all executors @@ -138,12 +144,12 @@ class SPMDGPUExecutor(ExecutorBase): if torch.distributed.get_rank() == 0: print( - f'before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB' + f"before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB" ) self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) if torch.distributed.get_rank() == 0: print( - f'after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB' + f"after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB" ) # NOTE(sgm): This will not profile & capture the model(CUDAGraph) when rebuilding KVCache @@ -181,8 +187,7 @@ class SPMDGPUExecutor(ExecutorBase): from vllm.prompt_adapter.request import PromptAdapterRequest def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool: - assert prompt_adapter_request.prompt_adapter_id > 0, \ - "prompt_adapter_id must be greater than 0." + assert prompt_adapter_request.prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." return self.worker.add_prompt_adapter(prompt_adapter_request) def list_prompt_adapters(self) -> Set[int]: @@ -193,13 +198,11 @@ class SPMDGPUExecutor(ExecutorBase): return self.worker.pin_lora(lora_id) def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - assert prompt_adapter_id > 0, \ - "prompt_adapter_id must be greater than 0." + assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." return self.worker.pin_prompt_adapter(prompt_adapter_id) def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - assert prompt_adapter_id > 0, \ - "prompt_adapter_id must be greater than 0." 
+ assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." return self.worker.remove_prompt_adapter(prompt_adapter_id) # NOTE(sgm): add for verl @@ -230,7 +233,7 @@ def initialize_cluster( # We need to setup the distributed init method to make sure # the distributed megatron code (e.g., get world size) works correctly. # distributed_init_method = f"tcp://localhost:{port}" - distributed_init_method = 'env://' + distributed_init_method = "env://" return distributed_init_method @@ -242,7 +245,6 @@ def get_open_port(): # TODO(sgm): not implemented async executor yet class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase): - async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Executes one model step on the given sequences.""" raise NotImplementedError diff --git a/verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py b/verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py index aa625a033..d938e6a05 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py @@ -13,20 +13,20 @@ # limitations under the License. # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py -from typing import List, Optional, Tuple, Union - -from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) +from typing import List, Optional +from transformers import PreTrainedTokenizer from vllm.lora.request import LoRARequest -from vllm.utils import make_async, LRUCache from vllm.transformers_utils.tokenizers import * +from vllm.utils import LRUCache class TokenizerGroup: """A group of tokenizers that can be used for LoRA adapters.""" - def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int]): + def __init__( + self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, max_input_length: Optional[int] + ): self.enable_lora = enable_lora self.max_input_length = max_input_length self.tokenizer = tokenizer @@ -40,17 +40,15 @@ class TokenizerGroup: """Get the maximum input length for the LoRA request.""" return self.max_input_length - def encode(self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + def encode( + self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None + ) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) return tokenizer.encode(prompt) - async def encode_async(self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + async def encode_async( + self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None + ) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) return tokenizer.encode(prompt) diff --git a/verl/third_party/vllm/vllm_v_0_5_4/worker.py b/verl/third_party/vllm/vllm_v_0_5_4/worker.py index a5deb675a..dda09ab25 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/worker.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/worker.py @@ -13,32 +13,42 @@ # limitations under the License. 
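Two other mechanical rewrites recur in the `spmd_gpu_executor.py` hunks above: single quotes become double quotes, and backslash-continued `assert` statements are collapsed onto one line (or wrapped in parentheses when they exceed the line limit). A small sketch with a placeholder `adapter_id` value, not a name from the patch:

```python
adapter_id = 3

# Before: single quotes and a backslash line continuation.
assert adapter_id > 0, \
    'adapter_id must be greater than 0.'

# After ruff format: double quotes, no backslash continuation.
assert adapter_id > 0, "adapter_id must be greater than 0."
```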
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py """A GPU worker class.""" -import os + import gc -from typing import Dict, List, Tuple, Optional, Union, Type +import os +from typing import Dict, List, Optional, Tuple, Type, Union import torch import torch.distributed import torch.nn as nn +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + MultiModalConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, - SchedulerConfig, SpeculativeConfig) -from vllm.model_executor import set_random_seed -from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SamplerOutput) -from vllm.worker.cache_engine import CacheEngine # TODO(sgm): check why vllm has similar file in vllm.model_executor.parallel_utils.parallel_state -from vllm.distributed import (init_distributed_environment, set_custom_all_reduce, get_tensor_model_parallel_group) -from vllm.worker.worker_base import WorkerInput -from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype -from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase +from vllm.distributed import get_tensor_model_parallel_group, init_distributed_environment, set_custom_all_reduce +from vllm.model_executor import set_random_seed +from vllm.sequence import ExecuteModelRequest, IntermediateTensors, SamplerOutput +from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.model_runner import GPUModelRunnerBase -from .model_runner import ModelRunner -from .megatron_weight_loaders import load_megatron_weights -from .hf_weight_loader import load_hf_weights +from vllm.worker.model_runner_base import ModelRunnerInputBase +from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype +from vllm.worker.worker_base import WorkerInput + +from .config import LoadConfig, LoadFormat, ModelConfig from .dtensor_weight_loaders import load_dtensor_weights -from .parallel_state import (ensure_model_parallel_initialized) -from .config import ModelConfig, LoadConfig, LoadFormat +from .hf_weight_loader import load_hf_weights +from .megatron_weight_loaders import load_megatron_weights +from .model_runner import ModelRunner +from .parallel_state import ensure_model_parallel_initialized class Worker(Worker): @@ -51,7 +61,7 @@ class Worker(Worker): def __init__( self, - model: Union[nn.Module, Dict], # model itself or its parameter dict + model: Union[nn.Module, Dict], # model itself or its parameter dict model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, @@ -88,17 +98,19 @@ class Worker(Worker): if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() self.multimodal_config = multimodal_config # Return hidden states from target model if the draft model is an # mlp_speculator - speculative_args = {} if speculative_config is None \ - or (speculative_config.draft_model_config.model == - model_config.model) \ - or (speculative_config.draft_model_config.hf_config.model_type - not in ["medusa", "mlp_speculator"]) \ - else {"return_hidden_states": True} + speculative_args = ( + {} + if speculative_config is None + or (speculative_config.draft_model_config.model == model_config.model) + or 
(speculative_config.draft_model_config.hf_config.model_type not in ["medusa", "mlp_speculator"]) + else {"return_hidden_states": True} + ) # TODO(sgm): set correct model runner class ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner @@ -107,7 +119,7 @@ class Worker(Worker): elif self.model_config.embedding_mode: ModelRunnerClass = EmbeddingModelRunner self.model_runner: GPUModelRunnerBase = ModelRunnerClass( - model, # [VERL]: add for verl + model, # [VERL]: add for verl model_config, parallel_config, scheduler_config, @@ -161,8 +173,9 @@ class Worker(Worker): raise RuntimeError(f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, - self.local_rank) + init_worker_distributed_environment( + self.parallel_config, self.rank, self.distributed_init_method, self.local_rank + ) # Set random seed. set_random_seed(self.model_config.seed) # self.model = get_model(actor_model=self.model, model_config=self.model_config) @@ -195,8 +208,10 @@ class Worker(Worker): free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() peak_memory = total_gpu_memory - free_gpu_memory - assert peak_memory > 0, ("Error in memory profiling. This happens when the GPU memory was " - "not properly cleaned up before initializing the vLLM instance.") + assert peak_memory > 0, ( + "Error in memory profiling. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance." + ) cache_block_size = self.get_cache_block_size_bytes() @@ -211,15 +226,15 @@ class Worker(Worker): self.model_runner.remove_all_loras() # NOTE(sgm): Add for [VERL], synchronize number of blocks with all the rank - num_gpu_blocks = torch.tensor([num_gpu_blocks], device='cuda') - num_cpu_blocks = torch.tensor([num_cpu_blocks], device='cuda') + num_gpu_blocks = torch.tensor([num_gpu_blocks], device="cuda") + num_cpu_blocks = torch.tensor([num_cpu_blocks], device="cuda") - torch.distributed.all_reduce(num_gpu_blocks, - op=torch.distributed.ReduceOp.MIN, - group=get_tensor_model_parallel_group().device_group) - torch.distributed.all_reduce(num_cpu_blocks, - op=torch.distributed.ReduceOp.MIN, - group=get_tensor_model_parallel_group().device_group) + torch.distributed.all_reduce( + num_gpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group().device_group + ) + torch.distributed.all_reduce( + num_cpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group().device_group + ) num_gpu_blocks = num_gpu_blocks.item() num_cpu_blocks = num_cpu_blocks.item() gc.collect() @@ -236,19 +251,21 @@ class Worker(Worker): self.gpu_cache = None # NOTE(sgm): [VERL]: adapt from _execute_model_spmd() - def execute_model(self, - execute_model_req: ExecuteModelRequest, - intermediate_tensors: Optional[IntermediateTensors] = None) -> Optional[List[SamplerOutput]]: + def execute_model( + self, execute_model_req: ExecuteModelRequest, intermediate_tensors: Optional[IntermediateTensors] = None + ) -> Optional[List[SamplerOutput]]: """ Execute model in Single Program Multiple Data (SPMD) fashion. All workers take the same request, prepare the input and execute the model. 
""" - assert execute_model_req is not None, ("_execute_model_spmd() requires each worker to take in an " - "ExecuteModelRequest") + assert execute_model_req is not None, ( + "_execute_model_spmd() requires each worker to take in an ExecuteModelRequest" + ) worker_input: WorkerInput = self.prepare_worker_input(execute_model_req=execute_model_req) - model_input: ModelRunnerInputBase = (self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list)) + model_input: ModelRunnerInputBase = self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list + ) # verl.worker.workerbase.WorkerBase # swap cache @@ -259,8 +276,10 @@ class Worker(Worker): return [] return self.model_runner.execute_model( - model_input, self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, - intermediate_tensors) + model_input, + self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, + intermediate_tensors, + ) # assume the input is .state_dict() def sync_model_weights(self, actor_weights: Dict, load_format: str): @@ -276,7 +295,7 @@ class Worker(Worker): if self.cpu_model == None: self.cpu_model = {} for name, params in self.model_runner.model.named_parameters(): - self.cpu_model[name] = torch.empty_like(params, device='cpu') + self.cpu_model[name] = torch.empty_like(params, device="cpu") params.data = self.cpu_model[name] else: for name, params in self.model_runner.model.named_parameters(): @@ -295,8 +314,10 @@ def init_worker_distributed_environment( # NOTE(sgm) use tcp://localhost:xxxx will hang in HF setting without megatron init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) - ensure_model_parallel_initialized(tensor_model_parallel_size=parallel_config.tensor_parallel_size, - pipeline_model_parallel_size=parallel_config.pipeline_parallel_size) + ensure_model_parallel_initialized( + tensor_model_parallel_size=parallel_config.tensor_parallel_size, + pipeline_model_parallel_size=parallel_config.pipeline_parallel_size, + ) # TODO(sgm): check whether need this # if pynccl_utils.is_initialized(): diff --git a/verl/third_party/vllm/vllm_v_0_6_3/config.py b/verl/third_party/vllm/vllm_v_0_6_3/config.py index d7cee4514..469d28a27 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/config.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/config.py @@ -42,7 +42,6 @@ class LoadFormat(str, enum.Enum): class ModelConfig(ModelConfig): - def __init__(self, hf_config: PretrainedConfig, *args, **kwargs) -> None: super().__init__(model=hf_config._name_or_path, tokenizer=hf_config._name_or_path, *args, **kwargs) self.hf_config = hf_config @@ -100,6 +99,8 @@ class LoadConfig: rocm_supported_load_format = [ f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format) ] - raise ValueError(f"load format '{load_format}' is not supported in ROCm. " - f"Supported load formats are " - f"{rocm_supported_load_format}") + raise ValueError( + f"load format '{load_format}' is not supported in ROCm. 
" + f"Supported load formats are " + f"{rocm_supported_load_format}" + ) diff --git a/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py index a3042cabc..0e4d14d26 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py @@ -312,12 +312,13 @@ def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None): param_name = _process_parameter_names(name=param_name) if parallelize_plan is not None: - assert ( - param_name - in parallelize_plan.keys()), f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}" + assert param_name in parallelize_plan, ( + f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}" + ) placement = parallelize_plan[param_name] - local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh, - placements=placement).to_local() + local_loaded_weights = loaded_weights.redistribute( + device_mesh=loaded_weights.device_mesh, placements=placement + ).to_local() else: local_loaded_weights = loaded_weights.full_tensor() return local_loaded_weights @@ -371,8 +372,10 @@ def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module): def _get_model_weight_loader(arch: str): if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__: return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch] - raise ValueError(f"Model architectures {arch} are not supported for now. " - f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}") + raise ValueError( + f"Model architectures {arch} are not supported for now. " + f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}" + ) # NOTE(sgm): we use per-parameter weight loader in each vllm sub diff --git a/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py b/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py index a3e5b22b2..23304298b 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py @@ -27,7 +27,7 @@ def update_hf_weight_loader(): def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module): assert isinstance(actor_weights, Dict) with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO - if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys(): + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights: del actor_weights["lm_head.weight"] vllm_model.load_weights(actor_weights.items()) for _, module in vllm_model.named_modules(): diff --git a/verl/third_party/vllm/vllm_v_0_6_3/llm.py b/verl/third_party/vllm/vllm_v_0_6_3/llm.py index 4af860e5e..bcea33f22 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/llm.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/llm.py @@ -13,17 +13,18 @@ # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py -from typing import Dict, List, Optional, Tuple, Union, Iterable +from typing import Dict, Iterable, List, Optional, Tuple, Union import torch import torch.nn as nn from torch.nn.utils.rnn import pad_sequence from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast -from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer from vllm import LLM from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.utils import Counter +from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer + from .arg_utils import EngineArgs from .llm_engine_sp import LLMEngine @@ -186,8 +187,11 @@ class LLM(LLM): logprob.append(logprobs_dict[id].logprob) logprobs.append(torch.tensor(logprob)) - pad_token_id = (self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None - else self.llm_engine.tokenizer.eos_token_id) + pad_token_id = ( + self.llm_engine.tokenizer.pad_token_id + if self.llm_engine.tokenizer.pad_token_id is not None + else self.llm_engine.tokenizer.eos_token_id + ) output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id) if len(logprobs) > 0: logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id) diff --git a/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py b/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py index 573423911..b815170b7 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py @@ -14,9 +14,8 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py from functools import partial -from typing import Callable, Dict, Optional, Type, Union, Iterable +from typing import Callable, Dict, Iterable, Optional, Type, Union -import torch import torch.nn as nn from vllm.config import ( CacheConfig, @@ -198,7 +197,7 @@ class LLMEngine(LLMEngine): # Ensure that the function doesn't contain a reference to self, # to avoid engine GC issues def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: - assert tokenizer_group, "tokenizer_group cannot be None, " "make sure skip_tokenizer_init is False" + assert tokenizer_group, "tokenizer_group cannot be None, make sure skip_tokenizer_init is False" return tokenizer_group.get_lora_tokenizer(sequence.lora_request) self.seq_counter = Counter() @@ -289,7 +288,8 @@ class LLMEngine(LLMEngine): lora_config, parallel_config.pipeline_parallel_size, self.async_callbacks[v_id] if model_config.use_async_output_proc else None, - ) for v_id in range(parallel_config.pipeline_parallel_size) + ) + for v_id in range(parallel_config.pipeline_parallel_size) ] # Metric Logging. 
@@ -304,14 +304,12 @@ class LLMEngine(LLMEngine): from vllm.engine.metrics import LoggingStatLogger, PrometheusStatLogger self.stat_loggers = { - "logging": - LoggingStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC), - "prometheus": - PrometheusStatLogger( - local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.served_model_name), - max_model_len=self.model_config.max_model_len, - ), + "logging": LoggingStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC), + "prometheus": PrometheusStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict(model_name=model_config.served_model_name), + max_model_len=self.model_config.max_model_len, + ), } self.stat_loggers["prometheus"].info("cache_config", self.cache_config) @@ -335,9 +333,9 @@ class LLMEngine(LLMEngine): # TODO(sgm): add for verl but we may not tokenizer in Rollout def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs): - init_kwargs = dict(enable_lora=bool(self.lora_config), - max_num_seqs=self.scheduler_config.max_num_seqs, - max_input_length=None) + init_kwargs = dict( + enable_lora=bool(self.lora_config), max_num_seqs=self.scheduler_config.max_num_seqs, max_input_length=None + ) init_kwargs.update(tokenizer_init_kwargs) return TokenizerGroup(tokenizer, **init_kwargs) @@ -355,8 +353,9 @@ class LLMEngine(LLMEngine): def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[ExecutorBase]: distributed_executor_backend = engine_config.parallel_config.distributed_executor_backend # Initialize the cluster and specify the executor class.] - assert (engine_config.device_config.device_type == "cuda" - ), "Currently, the vllm in verl only support running on GPU" + assert engine_config.device_config.device_type == "cuda", ( + "Currently, the vllm in verl only support running on GPU" + ) # print('Waiting for debugger'); import os,debugpy; debugpy.listen(('localhost', 5678 + int(os.getenv('RANK', '0')))); debugpy.wait_for_client() if engine_config.parallel_config.world_size == 1: @@ -382,8 +381,9 @@ class LLMEngine(LLMEngine): engine_config = engine_args.create_engine_config() executor_class = cls._get_executor_cls(engine_config) # Initialize the cluster and specify the executor class. 
- assert (engine_config.device_config.device_type == "cuda" - ), "Currently, the vllm in verl only support running on GPU" + assert engine_config.device_config.device_type == "cuda", ( + "Currently, the vllm in verl only support running on GPU" + ) from .spmd_gpu_executor import SPMDGPUExecutor diff --git a/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py index 91d99c9e6..9a2e4aeaa 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py @@ -25,11 +25,14 @@ from vllm.model_executor.models import ModelRegistry # NOTE(shengguangming): replace the origin weight loader function in the class def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Parallel Linear weight loader.""" - assert (param.size() == loaded_weight.size( - )), "the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}".format( - param.size(), loaded_weight.size()) - assert (param.data.dtype == loaded_weight.data.dtype - ), "if we want to shared weights, the data type should also be the same" + assert param.size() == loaded_weight.size(), ( + "the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}".format( + param.size(), loaded_weight.size() + ) + ) + assert param.data.dtype == loaded_weight.data.dtype, ( + "if we want to shared weights, the data type should also be the same" + ) param.data = loaded_weight.data @@ -37,8 +40,9 @@ def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tenso def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" assert param.size() == loaded_weight.size() - assert (param.data.dtype == loaded_weight.data.dtype - ), "if we want to shared weights, the data type should also be the same" + assert param.data.dtype == loaded_weight.data.dtype, ( + "if we want to shared weights, the data type should also be the same" + ) param.data = loaded_weight.data @@ -235,7 +239,7 @@ __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = { "LlamaForCausalLM": megatron_core_te_weight_loader, # use te backend for open-source megatron "LLaMAForCausalLM": megatron_core_te_weight_loader, "MistralForCausalLM": mistral_megatron_weight_loader, - 'Qwen2ForCausalLM': megatron_core_te_weight_loader, + "Qwen2ForCausalLM": megatron_core_te_weight_loader, } @@ -252,8 +256,10 @@ def load_megatron_weights(actor_weights: Iterable, vllm_model: nn.Module): def _get_model_weight_loader(arch: str): if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__: return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch] - raise ValueError(f"Model architectures {arch} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") + raise ValueError( + f"Model architectures {arch} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}" + ) def update_megatron_weight_loader(): diff --git a/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py b/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py index f146a0eae..6791fd5f3 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py @@ -13,6 +13,7 @@ # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models """Utilities for selecting and loading models.""" + from typing import Dict, Optional, Union import torch @@ -97,8 +98,11 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: update_dtensor_weight_loader() return DummyModelLoader(load_config) - raise ValueError("load format not supported in verl: {}, only support {} and {}".format( - load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF)) + raise ValueError( + "load format not supported in verl: {}, only support {} and {}".format( + load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF + ) + ) class DummyModelLoader(BaseModelLoader): @@ -107,8 +111,7 @@ class DummyModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") def download_model(self, model_config: ModelConfig) -> None: pass @@ -138,8 +141,7 @@ class MegatronLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") def download_model(self, model_config: ModelConfig) -> None: pass # Nothing to download @@ -169,8 +171,9 @@ class MegatronLoader(BaseModelLoader): # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm if isinstance(actor_model, nn.Module): - load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), - vllm_model=model) + load_megatron_weights( + actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model + ) else: load_megatron_weights(actor_weights=actor_model, vllm_model=model) @@ -193,8 +196,7 @@ class HFLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") def download_model(self, model_config: ModelConfig) -> None: pass # Nothing to download @@ -241,8 +243,7 @@ class DTensorLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") def download_model(self, model_config: ModelConfig) -> None: pass # Nothing to download @@ -272,8 +273,9 @@ class DTensorLoader(BaseModelLoader): # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm if isinstance(actor_model, nn.Module): - load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), - vllm_model=model) + load_dtensor_weights( + actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), 
vllm_model=model + ) else: load_dtensor_weights(actor_weights=actor_model, vllm_model=model) @@ -294,8 +296,9 @@ class DTensorLoader(BaseModelLoader): # as they use ray, the _get_logits result will only need to return to the driver node, # therefore gather is enough. However, we use SPMD instead of a central scheduler, # all_gather is required (aligned with v0.2.6) -def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, - embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: +def _get_logits( + self, hidden_states: torch.Tensor, embedding: torch.Tensor, embedding_bias: Optional[torch.Tensor] +) -> torch.Tensor: # Get the logits for the next tokens. logits = torch.matmul(hidden_states, embedding.t()) if embedding_bias is not None: @@ -303,7 +306,7 @@ def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, logits = tensor_model_parallel_all_gather(logits) # Remove paddings in vocab (if any). if logits is not None: - logits = logits[:, :self.org_vocab_size] + logits = logits[:, : self.org_vocab_size] return logits diff --git a/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py b/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py index b0cceffb5..a8ca1b6b7 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py @@ -58,7 +58,6 @@ class BatchType(IntEnum): class ModelRunner(ModelRunner): - def __init__( self, model: Union[nn.Module, Dict], # [verl] model itself or its parameter dict @@ -77,7 +76,6 @@ class ModelRunner(ModelRunner): input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): - super().__init__( model_config, parallel_config, @@ -118,9 +116,10 @@ class ModelRunner(ModelRunner): if self.lora_config: assert supports_lora(self.model), f"{self.model.__class__.__name__} does not support LoRA yet." - if supports_multimodal(self.model): - logger.warning("Regarding multimodal models, vLLM currently " - "only supports adding LoRA to language model.") + # if supports_multimodal(self.model): + # logger.warning( + # "Regarding multimodal models, vLLM currently only supports adding LoRA to language model." + # ) # It's necessary to distinguish between the max_position_embeddings # of VLMs and LLMs. if hasattr(self.model.config, "max_position_embeddings"): @@ -171,9 +170,11 @@ class ModelRunner(ModelRunner): self.model.__class__, ) else: - logger.warning("Using FP8 KV cache but no scaling factors " - "provided. Defaulting to scaling factors of 1.0. " - "This may lead to less accurate results!") + logger.warning( + "Using FP8 KV cache but no scaling factors " + "provided. Defaulting to scaling factors of 1.0. " + "This may lead to less accurate results!" + ) if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): from vllm.plugins import get_torch_compile_backend diff --git a/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py b/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py index 0150c1c67..c7a2863b6 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py @@ -4,6 +4,7 @@ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
"""Model and data parallel groups.""" + import os from typing import Optional @@ -86,12 +87,14 @@ def ensure_model_parallel_initialized( assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( "tensor parallel group already initialized, but of unexpected size: " f"{get_tensor_model_parallel_world_size()=} vs. " - f"{tensor_model_parallel_size=}") + f"{tensor_model_parallel_size=}" + ) pp_world_size = get_pp_group().world_size assert pp_world_size == pipeline_model_parallel_size, ( "pipeline parallel group already initialized, but of unexpected size: " f"{pp_world_size=} vs. " - f"{pipeline_model_parallel_size=}") + f"{pipeline_model_parallel_size=}" + ) # TODO(sgm): deviate from the v0.5.4, not pp now diff --git a/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py b/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py index cb6a1f336..074b92d1c 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py @@ -15,7 +15,7 @@ import os import socket -from typing import Dict, List, Optional, Set, Tuple, Iterable +from typing import Iterable, List, Optional, Set, Tuple import torch from vllm.config import ( @@ -80,7 +80,7 @@ class SPMDGPUExecutor(ExecutorBase): def _init_workers_sp(self, model, distributed_init_method: str): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker - from .worker import Worker # pylint: disable=import-outside-toplevel + from .worker import Worker rank = int(os.getenv("RANK")) local_rank = int(os.getenv("LOCAL_RANK")) @@ -245,7 +245,6 @@ def get_open_port(): # TODO(sgm): not implemented async executor yet class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase): - async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Executes one model step on the given sequences.""" raise NotImplementedError diff --git a/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py b/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py index b0b4d0e27..e17f15993 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py @@ -23,8 +23,9 @@ from vllm.utils import LRUCache class TokenizerGroup(TokenizerGroup): """A group of tokenizers that can be used for LoRA adapters.""" - def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int]): + def __init__( + self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, max_input_length: Optional[int] + ): self.enable_lora = enable_lora self.max_input_length = max_input_length self.tokenizer = tokenizer diff --git a/verl/third_party/vllm/vllm_v_0_6_3/worker.py b/verl/third_party/vllm/vllm_v_0_6_3/worker.py index 76e216103..8dae34d1e 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/worker.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/worker.py @@ -13,9 +13,10 @@ # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py """A GPU worker class.""" + import gc import os -from typing import Dict, List, Optional, Tuple, Type, Union, Iterable +from typing import Dict, Iterable, List, Optional, Tuple, Type, Union import torch import torch.distributed @@ -102,10 +103,12 @@ class Worker(Worker): # Return hidden states from target model if the draft model is an # mlp_speculator speculative_args = ( - {} if speculative_config is None or (speculative_config.draft_model_config.model == model_config.model) or - (speculative_config.draft_model_config.hf_config.model_type not in ["medusa", "mlp_speculator"]) else { - "return_hidden_states": True - }) + {} + if speculative_config is None + or (speculative_config.draft_model_config.model == model_config.model) + or (speculative_config.draft_model_config.hf_config.model_type not in ["medusa", "mlp_speculator"]) + else {"return_hidden_states": True} + ) # TODO(sgm): set correct model runner class ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner @@ -167,8 +170,9 @@ class Worker(Worker): raise RuntimeError(f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, - self.local_rank) + init_worker_distributed_environment( + self.parallel_config, self.rank, self.distributed_init_method, self.local_rank + ) # Set random seed. set_random_seed(self.model_config.seed) # self.model = get_model(actor_model=self.model, model_config=self.model_config) @@ -201,8 +205,10 @@ class Worker(Worker): free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() peak_memory = total_gpu_memory - free_gpu_memory - assert peak_memory > 0, ("Error in memory profiling. This happens when the GPU memory was " - "not properly cleaned up before initializing the vLLM instance.") + assert peak_memory > 0, ( + "Error in memory profiling. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance." + ) cache_block_size = self.get_cache_block_size_bytes() @@ -220,12 +226,12 @@ class Worker(Worker): num_gpu_blocks = torch.tensor([num_gpu_blocks], device="cuda") num_cpu_blocks = torch.tensor([num_cpu_blocks], device="cuda") - torch.distributed.all_reduce(num_gpu_blocks, - op=torch.distributed.ReduceOp.MIN, - group=get_tensor_model_parallel_group().device_group) - torch.distributed.all_reduce(num_cpu_blocks, - op=torch.distributed.ReduceOp.MIN, - group=get_tensor_model_parallel_group().device_group) + torch.distributed.all_reduce( + num_gpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group().device_group + ) + torch.distributed.all_reduce( + num_cpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group().device_group + ) num_gpu_blocks = num_gpu_blocks.item() num_cpu_blocks = num_cpu_blocks.item() gc.collect() @@ -242,19 +248,21 @@ class Worker(Worker): self.gpu_cache = None # NOTE(sgm): [VERL]: adapt from _execute_model_spmd() - def execute_model(self, - execute_model_req: ExecuteModelRequest, - intermediate_tensors: Optional[IntermediateTensors] = None) -> Optional[List[SamplerOutput]]: + def execute_model( + self, execute_model_req: ExecuteModelRequest, intermediate_tensors: Optional[IntermediateTensors] = None + ) -> Optional[List[SamplerOutput]]: """ Execute model in Single Program Multiple Data (SPMD) fashion. 
All workers take the same request, prepare the input and execute the model. """ - assert execute_model_req is not None, ("_execute_model_spmd() requires each worker to take in an " - "ExecuteModelRequest") + assert execute_model_req is not None, ( + "_execute_model_spmd() requires each worker to take in an ExecuteModelRequest" + ) worker_input: WorkerInput = self.prepare_worker_input(execute_model_req=execute_model_req) model_input: ModelRunnerInputBase = self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list) + execute_model_req.seq_group_metadata_list + ) # verl.worker.workerbase.WorkerBase # swap cache diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index fd8b15305..9a68789c5 100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -20,45 +20,47 @@ TODO(zhangchi.usc1992) import os -os.environ['NCCL_DEBUG'] = 'WARN' -os.environ['TOKENIZERS_PARALLELISM'] = 'true' +os.environ["NCCL_DEBUG"] = "WARN" +os.environ["TOKENIZERS_PARALLELISM"] = "true" import logging import re from contextlib import nullcontext + import torch import torch.distributed -from torch import nn, optim -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, CPUOffload -from tqdm import tqdm -from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoConfig -from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup +from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input +from peft import LoraConfig, TaskType, get_peft_model from tensordict import TensorDict -from torch.utils.data import DataLoader, DistributedSampler -from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis - -from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager -from verl.utils.dataset import SFTDataset -from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset -from verl.utils.fs import copy_to_local -from verl.utils.tracking import Tracking -from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group +from torch import nn, optim from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.utils.data import DataLoader, DistributedSampler +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel import verl.utils.hdfs_io as hdfs_io +from verl.utils.dataset import SFTDataset +from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset from verl.utils.debug import log_gpu_memory_usage -from peft import LoraConfig, TaskType, get_peft_model - +from verl.utils.fs import copy_to_local +from verl.utils.fsdp_utils import get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn +from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup +from verl.utils.tracking import Tracking +from verl.utils.ulysses import ( + gather_outpus_and_unpad, + get_ulysses_sequence_parallel_world_size, + ulysses_pad_and_slice_inputs, +) from verl.workers.sharding_manager import FSDPUlyssesShardingManager -from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad -from verl import DataProto logger = logging.getLogger(__file__) 
-logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN')) +logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) def extract_step(path): - match = re.search(r'global_step_(\d+)', path) + match = re.search(r"global_step_(\d+)", path) if match: return int(match.group(1)) return None @@ -66,7 +68,8 @@ def extract_step(path): def convert_to_regular_types(obj): """Convert Hydra configs and other special types to regular Python types.""" - from omegaconf import ListConfig, DictConfig + from omegaconf import DictConfig, ListConfig + if isinstance(obj, (ListConfig, DictConfig)): return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj) elif isinstance(obj, (list, tuple)): @@ -76,8 +79,7 @@ def convert_to_regular_types(obj): return obj -class FSDPSFTTrainer(object): - +class FSDPSFTTrainer: def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh): self.config = config self.device_mesh = device_mesh @@ -86,19 +88,20 @@ class FSDPSFTTrainer(object): # build tokenizer first local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True) from verl.utils import hf_tokenizer + self.tokenizer = hf_tokenizer(local_model_path, trust_remote_code=self.config.model.trust_remote_code) if self.config.data.chat_template is not None: - raise ValueError('Apply Chat template from config is not supported yet.') + raise ValueError("Apply Chat template from config is not supported yet.") # normalize dp size self._normalize_config_bsz() # Set sequence parallel size - self.config.ulysses_sequence_parallel_size = getattr(self.config, 'ulysses_sequence_parallel_size', 1) - self.use_remove_padding = getattr(self.config, 'use_remove_padding', False) + self.config.ulysses_sequence_parallel_size = getattr(self.config, "ulysses_sequence_parallel_size", 1) + self.use_remove_padding = getattr(self.config, "use_remove_padding", False) if self.device_mesh.get_rank() == 0: - print(f'Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}') - print(f'Using remove padding: {self.use_remove_padding}') + print(f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}") + print(f"Using remove padding: {self.use_remove_padding}") self._build_dataloader() # build model @@ -111,9 +114,11 @@ class FSDPSFTTrainer(object): def _normalize_config_bsz(self): dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0) if self.device_mesh.get_rank() == 0: - print(f'Normalize batch size by dp {dp_size}') + print(f"Normalize batch size by dp {dp_size}") - assert self.config.data.train_batch_size % dp_size == 0, f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}" + assert self.config.data.train_batch_size % dp_size == 0, ( + f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}" + ) self.config.data.train_batch_size //= dp_size @@ -128,59 +133,59 @@ class FSDPSFTTrainer(object): if config.data.custom_cls.get("path", None): dataset_cls = load_extern_type(config.data.custom_cls.path, config.data.custom_cls.name) # Then check if multi-turn dataset should be used - elif config.data.get('multiturn', {}).get('enable', False): + elif config.data.get("multiturn", {}).get("enable", False): dataset_cls = MultiTurnSFTDataset # Default to single-turn dataset else: dataset_cls = SFTDataset # Create datasets based on the selected class - self.train_dataset = 
dataset_cls(parquet_files=config.data.train_files, - tokenizer=self.tokenizer, - config=config.data) - self.val_dataset = dataset_cls(parquet_files=config.data.val_files, - tokenizer=self.tokenizer, - config=config.data) + self.train_dataset = dataset_cls( + parquet_files=config.data.train_files, tokenizer=self.tokenizer, config=config.data + ) + self.val_dataset = dataset_cls( + parquet_files=config.data.val_files, tokenizer=self.tokenizer, config=config.data + ) # build dataloader # Use data parallel rank and size instead of global rank and world size # If doing SP, we need to use the local rank and size if self.config.ulysses_sequence_parallel_size > 1: - rank = self.ulysses_device_mesh.get_local_rank('dp') + rank = self.ulysses_device_mesh.get_local_rank("dp") world_size = self.ulysses_device_mesh.size(0) if self.ulysses_device_mesh.get_rank() == 0: - print(f'Using SP rank {rank} and size {world_size} for data distribution') - print(f'Each SP rank gets different data, but the same data WITHIN the same rank') + print(f"Using SP rank {rank} and size {world_size} for data distribution") + print("Each SP rank gets different data, but the same data WITHIN the same rank") else: rank = self.device_mesh.get_rank() world_size = self.device_mesh.size() if self.device_mesh.get_rank() == 0: - print(f'Using FSDP rank {rank} and size {world_size} for data distribution') + print(f"Using FSDP rank {rank} and size {world_size} for data distribution") - self.train_sampler = DistributedSampler(self.train_dataset, - shuffle=True, - num_replicas=world_size, - rank=rank, - drop_last=True) - self.train_dataloader = DataLoader(dataset=self.train_dataset, - batch_size=config.data.train_batch_size, - sampler=self.train_sampler, - num_workers=8, - pin_memory=True, - drop_last=True) + self.train_sampler = DistributedSampler( + self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True + ) + self.train_dataloader = DataLoader( + dataset=self.train_dataset, + batch_size=config.data.train_batch_size, + sampler=self.train_sampler, + num_workers=8, + pin_memory=True, + drop_last=True, + ) - self.val_sampler = DistributedSampler(self.val_dataset, - shuffle=False, - num_replicas=world_size, - rank=rank, - drop_last=True) - self.val_dataloader = DataLoader(dataset=self.val_dataset, - batch_size=config.data.micro_batch_size_per_gpu, - sampler=self.val_sampler, - num_workers=8, - pin_memory=True, - drop_last=True) + self.val_sampler = DistributedSampler( + self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True + ) + self.val_dataloader = DataLoader( + dataset=self.val_dataset, + batch_size=config.data.micro_batch_size_per_gpu, + sampler=self.val_sampler, + num_workers=8, + pin_memory=True, + drop_last=True, + ) def _build_model_optimizer(self): # TODO (zhangchi.usc1992): @@ -188,12 +193,13 @@ class FSDPSFTTrainer(object): # 2. 
support init directly from sharded weights local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True) - if self.config.model.get('external_lib', None) is not None: + if self.config.model.get("external_lib", None) is not None: # This is used to import external_lib into the huggingface systems import importlib + importlib.import_module(self.config.model.external_lib) - log_gpu_memory_usage('Before model allocation', logger=logger) + log_gpu_memory_usage("Before model allocation", logger=logger) trust_remote_code = self.config.model.trust_remote_code # load config first @@ -202,49 +208,56 @@ class FSDPSFTTrainer(object): assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled" # This may be very large - init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings, - mesh=self.device_mesh) + init_context = get_init_weight_context_manager( + use_meta_tensor=not config.tie_word_embeddings, mesh=self.device_mesh + ) with init_context(): - self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path, - config=config, - torch_dtype=torch.float32, - attn_implementation='flash_attention_2', - trust_remote_code=trust_remote_code) + self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + local_model_path, + config=config, + torch_dtype=torch.float32, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1: from verl.models.transformers.monkey_patch import apply_monkey_patch + apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size) # Apply Liger kernel if use_liger is enabled - if self.config.model.get('use_liger', False): + if self.config.model.get("use_liger", False): from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance + _apply_liger_kernel_to_instance(model=self.model) - if self.config.model.get('lora_rank', 0) > 0: + if self.config.model.get("lora_rank", 0) > 0: self.model.enable_input_require_grads() # Convert config to regular Python types before creating PEFT model lora_config = { - 'task_type': TaskType.CAUSAL_LM, - 'r': self.config.model.lora_rank, - 'lora_alpha': self.config.model.lora_alpha, - 'target_modules': convert_to_regular_types(self.config.model.target_modules), - 'bias': "none" + "task_type": TaskType.CAUSAL_LM, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "bias": "none", } self.model = get_peft_model(self.model, LoraConfig(**lora_config)) if self.config.model.enable_gradient_checkpointing: - self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False}) + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) - log_gpu_memory_usage('After model allocation', logger=logger) + log_gpu_memory_usage("After model allocation", logger=logger) - mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, - reduce_dtype=torch.float32, - buffer_dtype=torch.float32) + mixed_precision = MixedPrecision( + param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32 + ) - auto_wrap_policy = get_fsdp_wrap_policy(self.model, - config=self.config.model.fsdp_config.wrap_policy, - is_lora=self.config.model.get('lora_rank', 0) > 0) + auto_wrap_policy = get_fsdp_wrap_policy( 
+ self.model, + config=self.config.model.fsdp_config.wrap_policy, + is_lora=self.config.model.get("lora_rank", 0) > 0, + ) if self.device_mesh.get_rank() == 0: print(auto_wrap_policy) @@ -253,69 +266,72 @@ class FSDPSFTTrainer(object): else: cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params) - self.fsdp_model = FSDP(module=self.model, - auto_wrap_policy=auto_wrap_policy, - param_init_fn=init_fn, - sharding_strategy=ShardingStrategy.FULL_SHARD, - mixed_precision=mixed_precision, - device_mesh=self.device_mesh, - sync_module_states=True, - device_id=torch.cuda.current_device(), - cpu_offload=cpu_offload, - use_orig_params=False) + self.fsdp_model = FSDP( + module=self.model, + auto_wrap_policy=auto_wrap_policy, + param_init_fn=init_fn, + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mixed_precision, + device_mesh=self.device_mesh, + sync_module_states=True, + device_id=torch.cuda.current_device(), + cpu_offload=cpu_offload, + use_orig_params=False, + ) - log_gpu_memory_usage('After FSDP wrapping', logger=logger) + log_gpu_memory_usage("After FSDP wrapping", logger=logger) - self.optimizer = optim.AdamW(self.fsdp_model.parameters(), - lr=self.config.optim.lr, - betas=self.config.optim.betas, - weight_decay=self.config.optim.weight_decay) + self.optimizer = optim.AdamW( + self.fsdp_model.parameters(), + lr=self.config.optim.lr, + betas=self.config.optim.betas, + weight_decay=self.config.optim.weight_decay, + ) - log_gpu_memory_usage('After initialize optimizer', logger=logger) + log_gpu_memory_usage("After initialize optimizer", logger=logger) self.steps_per_epoch = len(self.train_dataloader) self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs if self.device_mesh.get_rank() == 0: print( - f'Number of steps/epoch {self.steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {self.total_steps}' + f"Number of steps/epoch {self.steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {self.total_steps}" ) num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio) - if not hasattr(self.config.optim, 'lr_scheduler') or self.config.optim.lr_scheduler == 'cosine': - self.lr_scheduler = get_cosine_schedule_with_warmup(optimizer=self.optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=self.total_steps) - elif self.config.optim.lr_scheduler == 'wsd': - self.lr_scheduler = get_wsd_schedule_with_warmup(optimizer=self.optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=self.total_steps) + if not hasattr(self.config.optim, "lr_scheduler") or self.config.optim.lr_scheduler == "cosine": + self.lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps + ) + elif self.config.optim.lr_scheduler == "wsd": + self.lr_scheduler = get_wsd_schedule_with_warmup( + optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps + ) else: - raise ValueError(f'Unknown lr scheduler: {self.config.optim.lr_scheduler}') + raise ValueError(f"Unknown lr scheduler: {self.config.optim.lr_scheduler}") def _compute_loss_and_backward(self, batch, do_backward=True): """Compute loss with optional sequence parallelism and remove padding features""" use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 # Move inputs to GPU and prepare loss mask - input_ids = batch['input_ids'].cuda() - 
attention_mask = batch['attention_mask'].cuda() - position_ids = batch['position_ids'].cuda() - loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).cuda() - loss_fct = nn.CrossEntropyLoss(reduction='none') + input_ids = batch["input_ids"].cuda() + attention_mask = batch["attention_mask"].cuda() + position_ids = batch["position_ids"].cuda() + loss_mask = batch.pop("loss_mask")[:, :-1].reshape(-1).cuda() + loss_fct = nn.CrossEntropyLoss(reduction="none") # Context manager for sequence parallel if needed context = self.sharding_manager if use_sp else nullcontext() with context: - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): if not use_sp: # Standard forward pass without sequence parallel labels = input_ids[:, 1:].contiguous() - output = self.fsdp_model(input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - use_cache=False) + output = self.fsdp_model( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + ) logits = output.logits shift_logits = logits[..., :-1, :].contiguous() @@ -336,21 +352,25 @@ class FSDPSFTTrainer(object): batch_size, seqlen = input_ids.shape # Remove padding - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), - attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # Unpad position_ids to align rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), - indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), indices + ).transpose(0, 1) # Pad and slice inputs for sequence parallelism input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()) + input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() + ) # For computing loss input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( - input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size()) + input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size() + ) input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) # Forward pass @@ -358,7 +378,8 @@ class FSDPSFTTrainer(object): input_ids=input_ids_rmpad_sliced, attention_mask=None, # Not needed with flash attention varlen position_ids=position_ids_rmpad_padded, - use_cache=False) + use_cache=False, + ) # Compute loss locally then aggregate logits_rmpad = output.logits.squeeze(0) @@ -368,10 +389,9 @@ class FSDPSFTTrainer(object): loss = gather_outpus_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size) # This is the loss collected from all ulysses ranks - full_loss = pad_input(hidden_states=loss.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen) + full_loss = pad_input( + hidden_states=loss.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ) full_loss = full_loss.squeeze(-1)[:, :-1] # Remove last token's loss full_loss = full_loss.reshape(-1) loss_mask = loss_mask.to(full_loss.device) @@ -381,7 +401,7 @@ class FSDPSFTTrainer(object): if self.config.data.balance_dp_token: torch.distributed.all_reduce(valid_token_this_rank) - dp_size = self.ulysses_device_mesh.size('dp') if use_sp else torch.distributed.get_world_size() + dp_size = self.ulysses_device_mesh.size("dp") if use_sp else torch.distributed.get_world_size() else: dp_size = 1 @@ -394,11 +414,11 @@ class FSDPSFTTrainer(object): def training_step(self, batch: TensorDict): self.fsdp_model.train() - log_gpu_memory_usage('Before optimizer zero_grad', logger=logger) + log_gpu_memory_usage("Before optimizer zero_grad", logger=logger) self.optimizer.zero_grad() - log_gpu_memory_usage('After optimizer zero_grad', logger=logger) + log_gpu_memory_usage("After optimizer zero_grad", logger=logger) micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu) n_micro_batches = len(micro_batches) @@ -409,7 +429,7 @@ class FSDPSFTTrainer(object): grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad) - log_gpu_memory_usage('Before optimizer step', logger=logger) + log_gpu_memory_usage("Before optimizer step", logger=logger) # if grad_norm is not finite, skip the update if not torch.isfinite(grad_norm): @@ -418,18 +438,18 @@ class FSDPSFTTrainer(object): else: self.optimizer.step() - log_gpu_memory_usage('After optimizer step', logger=logger) + log_gpu_memory_usage("After optimizer step", logger=logger) self.lr_scheduler.step() # reduce loss across dp ranks lr = self.lr_scheduler.get_last_lr()[0] - log_gpu_memory_usage('After offload weights', logger=logger) + log_gpu_memory_usage("After offload weights", logger=logger) step_loss = torch.tensor(step_loss).cuda() torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) - return {'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3} + return {"train/loss": 
step_loss.detach().item(), "train/lr(1e-3)": lr * 1e3} def validation_step(self, batch: TensorDict): self.fsdp_model.eval() @@ -441,11 +461,12 @@ class FSDPSFTTrainer(object): def save_checkpoint(self, step): # save checkpoint from torch.distributed.fsdp import FullStateDictConfig, StateDictType + cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg): state_dict = self.fsdp_model.state_dict() - path = os.path.join(self.config.trainer.default_local_dir, f'global_step_{step}') + path = os.path.join(self.config.trainer.default_local_dir, f"global_step_{step}") # save huggingface model if self.device_mesh.get_rank() == 0: os.makedirs(path, exist_ok=True) @@ -461,9 +482,11 @@ class FSDPSFTTrainer(object): # TODO: add a unified tracking if rank == 0: - tracking = Tracking(project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger) + tracking = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + ) global_step = 0 # compute the total training steps. @@ -474,15 +497,17 @@ class FSDPSFTTrainer(object): total_training_steps = self.config.trainer.total_training_steps self.total_training_steps = total_training_steps - print(f'Total training steps: {self.total_training_steps}') + print(f"Total training steps: {self.total_training_steps}") # TODO (zhangchi.usc1992) add back checkpoint manager. Currently, it blocks when uploading to hdfs. So very slow. for epoch in range(self.config.trainer.total_epochs): self.train_sampler.set_epoch(epoch=epoch) - for data in tqdm(self.train_dataloader, - total=self.steps_per_epoch, - desc=f"Epoch {epoch+1}/{self.config.trainer.total_epochs}"): + for data in tqdm( + self.train_dataloader, + total=self.steps_per_epoch, + desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}", + ): global_step += 1 data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda() metric = self.training_step(data) @@ -499,7 +524,7 @@ class FSDPSFTTrainer(object): val_losses.append(val_loss) if rank == 0: avg_val_loss = torch.mean(torch.stack(val_losses)) - metric = {'val/loss': avg_val_loss.detach().item()} + metric = {"val/loss": avg_val_loss.detach().item()} tracking.log(data=metric, step=global_step) torch.distributed.barrier() @@ -515,7 +540,7 @@ class FSDPSFTTrainer(object): val_losses.append(val_loss) if rank == 0: val_loss = torch.mean(torch.stack(val_losses)) - metric = {'val/loss': val_loss.detach().item()} + metric = {"val/loss": val_loss.detach().item()} tracking.log(data=metric, step=global_step) torch.distributed.barrier() @@ -523,26 +548,25 @@ class FSDPSFTTrainer(object): self.save_checkpoint(step=global_step) -from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer import hydra - from torch.distributed.device_mesh import init_device_mesh +from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer from verl.utils.distributed import initialize_global_process_group -@hydra.main(config_path='config', config_name='sft_trainer', version_base=None) +@hydra.main(config_path="config", config_name="sft_trainer", version_base=None) def main(config): local_rank, rank, world_size = initialize_global_process_group() - device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',)) + device_mesh = init_device_mesh(device_type="cuda", 
mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) dp_size = world_size // config.ulysses_sequence_parallel_size - ulysses_device_mesh = init_device_mesh(device_type='cuda', - mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), - mesh_dim_names=('dp', 'sp')) + ulysses_device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp") + ) trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh) trainer.fit() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/verl/trainer/main_eval.py b/verl/trainer/main_eval.py index 3868e5a61..475d40749 100644 --- a/verl/trainer/main_eval.py +++ b/verl/trainer/main_eval.py @@ -17,17 +17,22 @@ The input is a parquet file that contains N generated sequences and (optional) t """ -import hydra -from verl.utils.fs import copy_to_local -import pandas as pd -import numpy as np -from tqdm import tqdm from collections import defaultdict + +import hydra +import numpy as np +import pandas as pd import ray +from tqdm import tqdm + +from verl.utils.fs import copy_to_local def get_custom_reward_fn(config): - import importlib.util, os, sys + import importlib.util + import os + import sys + reward_fn_config = config.get("custom_reward_function") or {} file_path = reward_fn_config.get("path") if not file_path: @@ -61,12 +66,12 @@ def get_custom_reward_fn(config): @ray.remote def process_item(reward_fn, data_source, response_lst, reward_data): - ground_truth = reward_data['ground_truth'] + ground_truth = reward_data["ground_truth"] score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst] return data_source, np.mean(score_lst) -@hydra.main(config_path='config', config_name='evaluation', version_base=None) +@hydra.main(config_path="config", config_name="evaluation", version_base=None) def main(config): local_path = copy_to_local(config.data.path) dataset = pd.read_parquet(local_path) @@ -102,10 +107,10 @@ def main(config): metric_dict = {} for data_source, rewards in data_source_reward.items(): - metric_dict[f'test_score/{data_source}'] = np.mean(rewards) + metric_dict[f"test_score/{data_source}"] = np.mean(rewards) print(metric_dict) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py index 50be66aab..1f41ea2cf 100644 --- a/verl/trainer/main_generation.py +++ b/verl/trainer/main_generation.py @@ -14,39 +14,41 @@ """ Generate responses given a dataset of prompts """ -import ray -import numpy as np -import hydra + import os -os.environ['NCCL_DEBUG'] = 'WARN' -os.environ['TOKENIZERS_PARALLELISM'] = 'true' +import hydra +import numpy as np +import ray + +os.environ["NCCL_DEBUG"] = "WARN" +os.environ["TOKENIZERS_PARALLELISM"] = "true" # os.environ['TORCH_COMPILE_DISABLE'] = '1' -import pandas as pd from pprint import pprint + +import pandas as pd from omegaconf import OmegaConf from verl import DataProto -from verl.utils.fs import copy_to_local -from verl.workers.fsdp_workers import ActorRolloutRefWorker -from verl.utils.hdfs_io import makedirs -from verl.utils import hf_tokenizer from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto -from verl.utils.model import compute_position_id_with_mask from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_to_local +from verl.utils.hdfs_io import makedirs 
+from verl.utils.model import compute_position_id_with_mask +from verl.workers.fsdp_workers import ActorRolloutRefWorker -@hydra.main(config_path='config', config_name='generation', version_base=None) +@hydra.main(config_path="config", config_name="generation", version_base=None) def main(config): run_generation(config) def run_generation(config) -> None: - if not ray.is_initialized(): # this is for local ray cluster - ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}}) + ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}) ray.get(main_task.remote(config)) @@ -57,11 +59,11 @@ def main_task(config): OmegaConf.resolve(config) local_path = copy_to_local(config.model.path) - trust_remote_code = config.data.get('trust_remote_code', False) + trust_remote_code = config.data.get("trust_remote_code", False) tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) - if config.rollout.temperature == 0.: - assert config.data.n_samples == 1, 'When temperature=0, n_samples must be 1.' + if config.rollout.temperature == 0.0: + assert config.data.n_samples == 1, "When temperature=0, n_samples must be 1." assert config.data.n_samples >= 1, "n_samples should always >= 1" # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary) @@ -70,11 +72,11 @@ def main_task(config): chat_lst = [chat.tolist() for chat in chat_lst] - tokenizer.padding_side = 'left' + tokenizer.padding_side = "left" if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role='rollout') + ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout") resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) wg.init_model() @@ -85,26 +87,28 @@ def main_task(config): output_lst = [[] for _ in range(config.data.n_samples)] for batch_idx in range(num_batch): - print(f'[{batch_idx+1}/{num_batch}] Start to process.') - batch_chat_lst = chat_lst[batch_idx * config_batch_size:(batch_idx + 1) * config_batch_size] - inputs = tokenizer.apply_chat_template(batch_chat_lst, - add_generation_prompt=True, - padding=True, - truncation=True, - max_length=config.rollout.prompt_length, - return_tensors='pt', - return_dict=True, - tokenize=True) - input_ids = inputs['input_ids'] - attention_mask = inputs['attention_mask'] + print(f"[{batch_idx + 1}/{num_batch}] Start to process.") + batch_chat_lst = chat_lst[batch_idx * config_batch_size : (batch_idx + 1) * config_batch_size] + inputs = tokenizer.apply_chat_template( + batch_chat_lst, + add_generation_prompt=True, + padding=True, + truncation=True, + max_length=config.rollout.prompt_length, + return_tensors="pt", + return_dict=True, + tokenize=True, + ) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] position_ids = compute_position_id_with_mask(attention_mask) - batch_dict = {'input_ids': input_ids, 'attention_mask': attention_mask, 'position_ids': position_ids} + batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids} data = DataProto.from_dict(batch_dict) data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size) # START TO GENERATE FOR n_samples TIMES - 
print(f'[{batch_idx+1}/{num_batch}] Start to generate.') + print(f"[{batch_idx + 1}/{num_batch}] Start to generate.") for n_sample in range(config.data.n_samples): output_padded = wg.generate_sequences(data_padded) output = unpad_dataproto(output_padded, pad_size=pad_size) @@ -112,9 +116,9 @@ def main_task(config): output_texts = [] for i in range(len(output)): data_item = output[i] - prompt_length = data_item.batch['prompts'].shape[-1] - valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum() - valid_response_ids = data_item.batch['responses'][:valid_response_length] + prompt_length = data_item.batch["prompts"].shape[-1] + valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() + valid_response_ids = data_item.batch["responses"][:valid_response_length] response_str = tokenizer.decode(valid_response_ids, skip_special_tokens=True) output_texts.append(response_str) @@ -125,7 +129,7 @@ def main_task(config): output_lst = np.transpose(output_lst, axes=(1, 0)).tolist() # add to the data frame - dataset['responses'] = output_lst + dataset["responses"] = output_lst # write to a new parquet output_dir = os.path.dirname(config.data.output_path) @@ -133,5 +137,5 @@ def main_task(config): dataset.to_parquet(config.data.output_path) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 773f230aa..6ac3b288f 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -14,15 +14,19 @@ """ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ -from verl.trainer.ppo.ray_trainer import RayPPOTrainer import os -import ray + import hydra +import ray + +from verl.trainer.ppo.ray_trainer import RayPPOTrainer def get_custom_reward_fn(config): - import importlib.util, sys + import importlib.util + import sys + reward_fn_config = config.get("custom_reward_function") or {} file_path = reward_fn_config.get("path") if not file_path: @@ -54,7 +58,7 @@ def get_custom_reward_fn(config): return wrapped_fn -@hydra.main(config_path='config', config_name='ppo_trainer', version_base=None) +@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) def main(config): run_ppo(config) @@ -62,16 +66,14 @@ def main(config): def run_ppo(config) -> None: # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices # isolation, will solve in the future - os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get('CUDA_VISIBLE_DEVICES', '') + os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "") if not ray.is_initialized(): # this is for local ray cluster - ray.init(runtime_env={ - 'env_vars': { - 'TOKENIZERS_PARALLELISM': 'true', - 'NCCL_DEBUG': 'WARN', - 'VLLM_LOGGING_LEVEL': 'WARN' + ray.init( + runtime_env={ + "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} } - }) + ) runner = TaskRunner.remote() ray.get(runner.run.remote(config)) @@ -79,12 +81,14 @@ def run_ppo(config) -> None: @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head class TaskRunner: - def run(self, config): - from verl.utils.fs import copy_to_local # print initial config from pprint import pprint + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values OmegaConf.resolve(config) @@ -92,22 +96,25 @@ class TaskRunner: 
local_path = copy_to_local(config.actor_rollout_ref.model.path) # instantiate tokenizer - from verl.utils import hf_tokenizer, hf_processor - trust_remote_code = config.data.get('trust_remote_code', False) + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none # define worker classes - if config.actor_rollout_ref.actor.strategy == 'fsdp': + if config.actor_rollout_ref.actor.strategy == "fsdp": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + ray_worker_group_cls = RayWorkerGroup - elif config.actor_rollout_ref.actor.strategy == 'megatron': + elif config.actor_rollout_ref.actor.strategy == "megatron": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + ray_worker_group_cls = NVMegatronRayWorkerGroup else: @@ -120,7 +127,7 @@ class TaskRunner: Role.Critic: ray.remote(CriticWorker), } - global_pool_id = 'global_pool' + global_pool_id = "global_pool" resource_pool_spec = { global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, } @@ -136,63 +143,69 @@ class TaskRunner: # - finally, we combine all the rewards together # - The reward type depends on the tag of the data if config.reward_model.enable: - if config.reward_model.strategy == 'fsdp': + if config.reward_model.strategy == "fsdp": from verl.workers.fsdp_workers import RewardModelWorker - elif config.reward_model.strategy == 'megatron': + elif config.reward_model.strategy == "megatron": from verl.workers.megatron_workers import RewardModelWorker else: raise NotImplementedError role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) mapping[Role.RewardModel] = global_pool_id - #use reference model + # use reference model if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) mapping[Role.RefPolicy] = global_pool_id reward_manager_name = config.reward_model.get("reward_manager", "naive") - if reward_manager_name == 'naive': + if reward_manager_name == "naive": from verl.workers.reward_manager import NaiveRewardManager + reward_manager_cls = NaiveRewardManager - elif reward_manager_name == 'prime': + elif reward_manager_name == "prime": from verl.workers.reward_manager import PrimeRewardManager + reward_manager_cls = PrimeRewardManager - elif reward_manager_name == 'batch': + elif reward_manager_name == "batch": from verl.workers.reward_manager import BatchRewardManager + reward_manager_cls = BatchRewardManager - elif reward_manager_name == 'dapo': + elif reward_manager_name == "dapo": from verl.workers.reward_manager import DAPORewardManager + reward_manager_cls = DAPORewardManager else: - raise NotImplementedError compute_score = get_custom_reward_fn(config) reward_kwargs = dict(config.reward_model.get("reward_kwargs", {})) - reward_fn = reward_manager_cls(tokenizer=tokenizer, - num_examine=0, - 
compute_score=compute_score, - reward_fn_key=config.data.reward_fn_key, - **reward_kwargs) + reward_fn = reward_manager_cls( + tokenizer=tokenizer, + num_examine=0, + compute_score=compute_score, + reward_fn_key=config.data.reward_fn_key, + **reward_kwargs, + ) # Note that we always use function-based RM for validation - val_reward_fn = reward_manager_cls(tokenizer=tokenizer, - num_examine=1, - compute_score=compute_score, - reward_fn_key=config.data.reward_fn_key) + val_reward_fn = reward_manager_cls( + tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key + ) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - trainer = RayPPOTrainer(config=config, - tokenizer=tokenizer, - processor=processor, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn) + trainer = RayPPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + ) trainer.init_workers() trainer.fit() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 697a63c52..64998dc72 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -18,9 +18,10 @@ The function implemented in this file should be used by trainer with different d implement PPO """ +from collections import defaultdict + import numpy as np import torch -from collections import defaultdict import verl.utils.torch_functional as verl_F @@ -54,17 +55,22 @@ class FixedKLController: def get_kl_controller(kl_ctrl): - if kl_ctrl.type == 'fixed': + if kl_ctrl.type == "fixed": return FixedKLController(kl_coef=kl_ctrl.kl_coef) - elif kl_ctrl.type == 'adaptive': - assert kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {kl_ctrl.horizon}' + elif kl_ctrl.type == "adaptive": + assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}" return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon) else: raise NotImplementedError -def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torch.Tensor, response_mask: torch.Tensor, - gamma: torch.Tensor, lam: torch.Tensor): +def compute_gae_advantage_return( + token_level_rewards: torch.Tensor, + values: torch.Tensor, + response_mask: torch.Tensor, + gamma: torch.Tensor, + lam: torch.Tensor, +): """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py Args: @@ -104,19 +110,18 @@ def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torc # NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar. -def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor, - response_mask: torch.Tensor, - index: np.ndarray, - epsilon: float = 1e-6): +def compute_grpo_outcome_advantage( + token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6 +): """ - Compute advantage for GRPO, operating only on Outcome reward + Compute advantage for GRPO, operating only on Outcome reward (with only one scalar reward for each response). 
Args: token_level_rewards: `(torch.Tensor)` shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) - + Returns: advantages: `(torch.Tensor)` shape: (bs, response_length) @@ -149,19 +154,18 @@ def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor, return scores, scores -def compute_reinforce_plus_plus_baseline_outcome_advantage(token_level_rewards: torch.Tensor, - response_mask: torch.Tensor, - index: torch.Tensor, - epsilon: float = 1e-6): +def compute_reinforce_plus_plus_baseline_outcome_advantage( + token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6 +): """ - Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward + Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward (with only one scalar reward for each response). Args: token_level_rewards: `(torch.Tensor)` shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) - + Returns: advantages: `(torch.Tensor)` shape: (bs, response_length) @@ -194,10 +198,9 @@ def compute_reinforce_plus_plus_baseline_outcome_advantage(token_level_rewards: return scores, scores -def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor, - response_mask: torch.Tensor, - index: np.ndarray, - epsilon: float = 1e-6): +def compute_rloo_outcome_advantage( + token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6 +): """ Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740 Args: @@ -231,24 +234,26 @@ def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor, for i in range(bsz): response_num = len(id2score[index[i]]) if response_num > 1: - scores[i] = scores[i] * response_num / (response_num - - 1) - id2mean[index[i]] * response_num / (response_num - 1) + scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / ( + response_num - 1 + ) scores = scores.unsqueeze(-1) * response_mask return scores, scores -def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, - gamma: torch.Tensor): +def compute_reinforce_plus_plus_outcome_advantage( + token_level_rewards: torch.Tensor, response_mask: torch.Tensor, gamma: torch.Tensor +): """ - Compute advantage for REINFORCE++. + Compute advantage for REINFORCE++. This implementation is based on the paper: https://arxiv.org/abs/2501.03262 Args: token_level_rewards: `(torch.Tensor)` shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) - + Returns: advantages: `(torch.Tensor)` shape: (bs, response_length) @@ -272,10 +277,11 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten return advantages, returns -def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, - response_mask: torch.Tensor): +def compute_remax_outcome_advantage( + token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor +): """ - Compute advantage for ReMax, operating only on Outcome reward + Compute advantage for ReMax, operating only on Outcome reward This implementation is based on the paper: https://arxiv.org/abs/2310.10505 (with only one scalar reward for each response). 
@@ -286,7 +292,7 @@ def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_ba shape: (bs,) response_mask: `(torch.Tensor)` shape: (bs, response_length) - + Returns: advantages: `(torch.Tensor)` shape: (bs, response_length) @@ -334,15 +340,17 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str return loss -def compute_policy_loss(old_log_prob, - log_prob, - advantages, - response_mask, - cliprange=None, - cliprange_low=None, - cliprange_high=None, - clip_ratio_c=3.0, - loss_agg_mode="token-mean"): +def compute_policy_loss( + old_log_prob, + log_prob, + advantages, + response_mask, + cliprange=None, + cliprange_low=None, + cliprange_high=None, + clip_ratio_c=3.0, + loss_agg_mode="token-mean", +): """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 Args: old_log_prob: `(torch.Tensor)` @@ -362,7 +370,7 @@ def compute_policy_loss(old_log_prob, clip_ratio_c: (float) default: 3.0 The lower bound of the ratio for dual-clip PPO, See https://arxiv.org/pdf/1912.09729 loss_agg_mode: (str) choices: "token-mean" / "seq-mean-token-sum" / "seq-mean-token-mean" - "token-mean" is the default behavior + "token-mean" is the default behavior Returns: pg_loss: `a scalar torch.Tensor` @@ -374,7 +382,9 @@ def compute_policy_loss(old_log_prob, pg_clipfrac_lower: (float) the fraction of policy gradient loss being clipped when the advantage is negative """ - assert clip_ratio_c > 1.0, f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but get the value: {clip_ratio_c}." + assert clip_ratio_c > 1.0, ( + f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but get the value: {clip_ratio_c}." + ) negative_approx_kl = log_prob - old_log_prob ratio = torch.exp(negative_approx_kl) @@ -385,16 +395,19 @@ def compute_policy_loss(old_log_prob, cliprange_low = cliprange if cliprange_high is None: cliprange_high = cliprange - pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, - 1 + cliprange_high) # - clip(ratio, 1-cliprange, 1+cliprange) * A - clip_pg_losses1 = torch.maximum(pg_losses1, - pg_losses2) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) + pg_losses2 = -advantages * torch.clamp( + ratio, 1 - cliprange_low, 1 + cliprange_high + ) # - clip(ratio, 1-cliprange, 1+cliprange) * A + clip_pg_losses1 = torch.maximum( + pg_losses1, pg_losses2 + ) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask) pg_losses3 = -advantages * clip_ratio_c clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) pg_clipfrac_lower = verl_F.masked_mean( - torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask) + torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask + ) pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) @@ -440,8 +453,8 @@ def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value): """ vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value) - vf_losses1 = (vpreds - returns)**2 - vf_losses2 = (vpredclipped - returns)**2 + vf_losses1 = (vpreds - returns) ** 2 + vf_losses2 = (vpredclipped - returns) ** 2 vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), response_mask) vf_clipfrac = 
verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask) return vf_loss, vf_clipfrac @@ -469,7 +482,7 @@ def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_pe # J. Schulman. Approximating kl divergence, 2020. # # URL http://joschu.net/blog/kl-approx.html. - if kl_penalty == 'low_var_kl': + if kl_penalty == "low_var_kl": kl = ref_logprob - logprob ratio = torch.exp(kl) kld = (ratio - kl - 1).contiguous() diff --git a/verl/trainer/ppo/metric_utils.py b/verl/trainer/ppo/metric_utils.py index 8f59e7989..42d5a284a 100644 --- a/verl/trainer/ppo/metric_utils.py +++ b/verl/trainer/ppo/metric_utils.py @@ -15,12 +15,14 @@ Metrics related to the PPO trainer. """ -import torch -from typing import Any, Dict, List, Callable -import numpy as np -from verl import DataProto -from collections import Counter, defaultdict +from collections import defaultdict from functools import partial +from typing import Any, Callable, Dict, List + +import numpy as np +import torch + +from verl import DataProto def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]: @@ -30,10 +32,10 @@ def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]: def _compute_response_info(batch: DataProto) -> Dict[str, Any]: - response_length = batch.batch['responses'].shape[-1] + response_length = batch.batch["responses"].shape[-1] - prompt_mask = batch.batch['attention_mask'][:, :-response_length] - response_mask = batch.batch['attention_mask'][:, -response_length:] + prompt_mask = batch.batch["attention_mask"][:, :-response_length] + response_mask = batch.batch["attention_mask"][:, -response_length:] prompt_length = prompt_mask.sum(-1).float() response_length = response_mask.sum(-1).float() # (batch_size,) @@ -47,134 +49,117 @@ def _compute_response_info(batch: DataProto) -> Dict[str, Any]: def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, Any]: # TODO: add response length - sequence_score = batch.batch['token_level_scores'].sum(-1) - sequence_reward = batch.batch['token_level_rewards'].sum(-1) + sequence_score = batch.batch["token_level_scores"].sum(-1) + sequence_reward = batch.batch["token_level_rewards"].sum(-1) - advantages = batch.batch['advantages'] - returns = batch.batch['returns'] + advantages = batch.batch["advantages"] + returns = batch.batch["returns"] - max_response_length = batch.batch['responses'].shape[-1] + max_response_length = batch.batch["responses"].shape[-1] - prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool() - response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool() + prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() + response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool() max_prompt_length = prompt_mask.size(-1) response_info = _compute_response_info(batch) - prompt_length = response_info['prompt_length'] - response_length = response_info['response_length'] + prompt_length = response_info["prompt_length"] + response_length = response_info["response_length"] valid_adv = torch.masked_select(advantages, response_mask) valid_returns = torch.masked_select(returns, response_mask) if use_critic: - values = batch.batch['values'] + values = batch.batch["values"] valid_values = torch.masked_select(values, response_mask) return_diff_var = torch.var(valid_returns - valid_values) return_var = torch.var(valid_returns) metrics = { # score - 'critic/score/mean': - torch.mean(sequence_score).detach().item(), - 'critic/score/max': - 
torch.max(sequence_score).detach().item(), - 'critic/score/min': - torch.min(sequence_score).detach().item(), + "critic/score/mean": torch.mean(sequence_score).detach().item(), + "critic/score/max": torch.max(sequence_score).detach().item(), + "critic/score/min": torch.min(sequence_score).detach().item(), # reward - 'critic/rewards/mean': - torch.mean(sequence_reward).detach().item(), - 'critic/rewards/max': - torch.max(sequence_reward).detach().item(), - 'critic/rewards/min': - torch.min(sequence_reward).detach().item(), + "critic/rewards/mean": torch.mean(sequence_reward).detach().item(), + "critic/rewards/max": torch.max(sequence_reward).detach().item(), + "critic/rewards/min": torch.min(sequence_reward).detach().item(), # adv - 'critic/advantages/mean': - torch.mean(valid_adv).detach().item(), - 'critic/advantages/max': - torch.max(valid_adv).detach().item(), - 'critic/advantages/min': - torch.min(valid_adv).detach().item(), + "critic/advantages/mean": torch.mean(valid_adv).detach().item(), + "critic/advantages/max": torch.max(valid_adv).detach().item(), + "critic/advantages/min": torch.min(valid_adv).detach().item(), # returns - 'critic/returns/mean': - torch.mean(valid_returns).detach().item(), - 'critic/returns/max': - torch.max(valid_returns).detach().item(), - 'critic/returns/min': - torch.min(valid_returns).detach().item(), - **({ - # values - 'critic/values/mean': torch.mean(valid_values).detach().item(), - 'critic/values/max': torch.max(valid_values).detach().item(), - 'critic/values/min': torch.min(valid_values).detach().item(), - # vf explained var - 'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), - } if use_critic else {}), - + "critic/returns/mean": torch.mean(valid_returns).detach().item(), + "critic/returns/max": torch.max(valid_returns).detach().item(), + "critic/returns/min": torch.min(valid_returns).detach().item(), + **( + { + # values + "critic/values/mean": torch.mean(valid_values).detach().item(), + "critic/values/max": torch.max(valid_values).detach().item(), + "critic/values/min": torch.min(valid_values).detach().item(), + # vf explained var + "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + } + if use_critic + else {} + ), # response length - 'response_length/mean': - torch.mean(response_length).detach().item(), - 'response_length/max': - torch.max(response_length).detach().item(), - 'response_length/min': - torch.min(response_length).detach().item(), - 'response_length/clip_ratio': - torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(), + "response_length/mean": torch.mean(response_length).detach().item(), + "response_length/max": torch.max(response_length).detach().item(), + "response_length/min": torch.min(response_length).detach().item(), + "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) + .detach() + .item(), # prompt length - 'prompt_length/mean': - torch.mean(prompt_length).detach().item(), - 'prompt_length/max': - torch.max(prompt_length).detach().item(), - 'prompt_length/min': - torch.min(prompt_length).detach().item(), - 'prompt_length/clip_ratio': - torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + "prompt_length/mean": torch.mean(prompt_length).detach().item(), + "prompt_length/max": torch.max(prompt_length).detach().item(), + "prompt_length/min": torch.min(prompt_length).detach().item(), + "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, 
max_prompt_length).float()).detach().item(), } return metrics def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]: response_info = _compute_response_info(batch) - num_prompt_tokens = torch.sum(response_info['prompt_length']).item() - num_response_tokens = torch.sum(response_info['response_length']).item() + num_prompt_tokens = torch.sum(response_info["prompt_length"]).item() + num_response_tokens = torch.sum(response_info["response_length"]).item() num_overall_tokens = num_prompt_tokens + num_response_tokens num_tokens_of_section = { - 'gen': num_response_tokens, - **{ - name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor'] - }, + "gen": num_response_tokens, + **{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]}, } return { + **{f"timing_s/{name}": value for name, value in timing_raw.items()}, **{ - f'timing_s/{name}': value for name, value in timing_raw.items() - }, - **{ - f'timing_per_token_ms/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys( - )) & set(timing_raw.keys()) + f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] + for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()) }, } def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]: - total_num_tokens = sum(batch.meta_info['global_token_num']) - time = timing_raw['step'] + total_num_tokens = sum(batch.meta_info["global_token_num"]) + time = timing_raw["step"] # estimated_flops, promised_flops = flops_function.estimate_flops(num_tokens, time) # f'Actual TFLOPs/s/GPU​': estimated_flops/(n_gpus), # f'Theoretical TFLOPs/s/GPU​': promised_flops, return { - 'perf/total_num_tokens': total_num_tokens, - 'perf/time_per_step': time, - 'perf/throughput': total_num_tokens / (time * n_gpus), + "perf/total_num_tokens": total_num_tokens, + "perf/time_per_step": time, + "perf/throughput": total_num_tokens / (time * n_gpus), } -def bootstrap_metric(data: list[Any], - subset_size: int, - reduce_fns: list[Callable[[np.ndarray], float]], - n_bootstrap: int = 1000, - seed: int = 42) -> list[tuple[float, float]]: +def bootstrap_metric( + data: list[Any], + subset_size: int, + reduce_fns: list[Callable[[np.ndarray], float]], + n_bootstrap: int = 1000, + seed: int = 42, +) -> list[tuple[float, float]]: np.random.seed(seed) bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))] @@ -202,17 +187,16 @@ def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> flo return maj_val -def process_validation_metrics(data_sources: list[str], - sample_inputs: list[str], - infos_dict: dict[str, list[Any]], - seed: int = 42) -> dict[str, dict[str, dict[str, float]]]: +def process_validation_metrics( + data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42 +) -> dict[str, dict[str, dict[str, float]]]: """Process validation metrics into a structured format. 
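For readers new to the validation metrics, the best@N / worst@N numbers produced by `bootstrap_metric` come from resampling with replacement and reducing each resample. A minimal sketch under that assumption — illustrative only, not part of this patch; `bootstrap_best_worst` is a hypothetical helper built on stock NumPy:

```python
# Hypothetical sketch: bootstrap estimation of best@N / worst@N. Resample N values
# with replacement, reduce each resample (max for best@N, min for worst@N), and
# report the mean and std of the reduced values across bootstrap rounds.
import numpy as np


def bootstrap_best_worst(values, subset_size, n_bootstrap=1000, seed=42):
    rng = np.random.default_rng(seed)
    best, worst = [], []
    for _ in range(n_bootstrap):
        sample = rng.choice(values, size=subset_size, replace=True)
        best.append(np.max(sample))
        worst.append(np.min(sample))
    return (np.mean(best), np.std(best)), (np.mean(worst), np.std(worst))


if __name__ == "__main__":
    vals = [0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0]
    (bon_mean, bon_std), (won_mean, won_std) = bootstrap_best_worst(vals, subset_size=4)
    print(f"best@4: {bon_mean:.3f}±{bon_std:.3f}  worst@4: {won_mean:.3f}±{won_std:.3f}")
```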
- + Args: data_sources: Array of data source identifiers for each sample sample_inputs: List of input prompts infos_dict: variable name -> list of values for each sample - + Returns: dict[str, dict[str, dict[str, float]]]: data source -> variable name -> metric value """ @@ -245,20 +229,20 @@ def process_validation_metrics(data_sources: list[str], for n in ns: # Best/Worst-of-N - [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(data=var_vals, - subset_size=n, - reduce_fns=[np.max, np.min], - seed=seed) + [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric( + data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed + ) metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std # Majority voting if var2vals.get("pred", None) is not None: vote_data = [{"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"])] - [(maj_n_mean, maj_n_std) - ] = bootstrap_metric(data=vote_data, - subset_size=n, - reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")], - seed=seed) + [(maj_n_mean, maj_n_std)] = bootstrap_metric( + data=vote_data, + subset_size=n, + reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")], + seed=seed, + ) metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std data_src2prompt2var2metric[data_source][prompt][var_name] = metric diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 6898af2b9..c72e6280d 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -18,36 +18,40 @@ This trainer supports model-agonistic model initialization with huggingface import os import uuid -import warnings +from collections import defaultdict from contextlib import contextmanager +from copy import deepcopy from dataclasses import dataclass, field from enum import Enum from pprint import pprint -from typing import Type, Dict -from copy import deepcopy -from collections import defaultdict -from functools import partial -from tqdm import tqdm +from typing import Dict, Type -import ray import numpy as np +import ray from codetiming import Timer from omegaconf import OmegaConf, open_dict +from torch.utils.data import Dataset, RandomSampler, SequentialSampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + from verl import DataProto from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto from verl.single_controller.base import Worker -from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.ppo import core_algos from verl.trainer.ppo.core_algos import agg_loss -from verl.utils.py_functional import append_to_dict -from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics, bootstrap_metric, calc_maj_val, process_validation_metrics -from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + process_validation_metrics, + reduce_metrics, +) from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn +from 
verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance from verl.utils.tracking import ValidationGenerationsLogger -from torch.utils.data import Dataset, RandomSampler, SequentialSampler -from torchdata.stateful_dataloader import StatefulDataLoader WorkerType = Type[Worker] @@ -56,6 +60,7 @@ class Role(Enum): """ To create more roles dynamically, you can subclass Role and add new members """ + Actor = 0 Rollout = 1 ActorRollout = 2 @@ -69,12 +74,13 @@ class AdvantageEstimator(str, Enum): """ Using an enumeration class to avoid spelling errors in adv_estimator """ - GAE = 'gae' - GRPO = 'grpo' - REINFORCE_PLUS_PLUS = 'reinforce_plus_plus' - REINFORCE_PLUS_PLUS_BASELINE = 'reinforce_plus_plus_baseline' - REMAX = 'remax' - RLOO = 'rloo' + + GAE = "gae" + GRPO = "grpo" + REINFORCE_PLUS_PLUS = "reinforce_plus_plus" + REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline" + REMAX = "remax" + RLOO = "rloo" @dataclass @@ -83,6 +89,7 @@ class ResourcePoolManager: Define a resource pool specification. Resource pool will be initialized first. Mapping """ + resource_pool_spec: dict[str, list[int]] mapping: dict[Role, str] resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) @@ -92,10 +99,9 @@ class ResourcePoolManager: # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different WorkerGroup for differnt models - resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, - use_gpu=True, - max_colocate_count=1, - name_prefix=resource_pool_name) + resource_pool = RayResourcePool( + process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name + ) self.resource_pool_dict[resource_pool_name] = resource_pool self._check_resource_available() @@ -111,15 +117,17 @@ class ResourcePoolManager: def _check_resource_available(self): """Check if the resource pool can be satisfied in this ray cluster.""" node_available_resources = ray.state.available_resources_per_node() - node_available_gpus = {node: node_info.get('GPU', 0) for node, node_info in node_available_resources.items()} + node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()} # check total required gpus can be satisfied total_available_gpus = sum(node_available_gpus.values()) total_required_gpus = sum( - [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) + [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] + ) if total_available_gpus < total_required_gpus: raise ValueError( - f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}") + f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" + ) # check each resource pool can be satisfied, O(#resource_pools * #nodes) for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): @@ -137,21 +145,23 @@ class ResourcePoolManager: import torch + from verl.utils.torch_functional import masked_mean -def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty='kl'): - responses = data.batch['responses'] +def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): + 
responses = data.batch["responses"] response_length = responses.size(1) - token_level_scores = data.batch['token_level_scores'] + token_level_scores = data.batch["token_level_scores"] batch_size = data.batch.batch_size[0] - attention_mask = data.batch['attention_mask'] + attention_mask = data.batch["attention_mask"] response_mask = attention_mask[:, -response_length:] # compute kl between ref_policy and current policy # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. - kld = core_algos.kl_penalty(data.batch['old_log_probs'], data.batch['ref_log_prob'], - kl_penalty=kl_penalty) # (batch_size, response_length) + kld = core_algos.kl_penalty( + data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty + ) # (batch_size, response_length) kld = kld * response_mask beta = kl_ctrl.value @@ -162,72 +172,78 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) - data.batch['token_level_rewards'] = token_level_rewards + data.batch["token_level_rewards"] = token_level_rewards - metrics = {'actor/reward_kl_penalty': current_kl, 'actor/reward_kl_penalty_coeff': beta} + metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta} return data, metrics def compute_response_mask(data: DataProto): - responses = data.batch['responses'] + responses = data.batch["responses"] response_length = responses.size(1) - attention_mask = data.batch['attention_mask'] + attention_mask = data.batch["attention_mask"] return attention_mask[:, -response_length:] def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1): # Back-compatible with trainers that do not compute response mask in fit if "response_mask" not in data.batch.keys(): - data.batch['response_mask'] = compute_response_mask(data) + data.batch["response_mask"] = compute_response_mask(data) # prepare response group # TODO: add other ways to estimate advantages if adv_estimator == AdvantageEstimator.GAE: - values = data.batch['values'] + values = data.batch["values"] advantages, returns = core_algos.compute_gae_advantage_return( - token_level_rewards=data.batch['token_level_rewards'], - values=data.batch['values'], - response_mask=data.batch['response_mask'], + token_level_rewards=data.batch["token_level_rewards"], + values=data.batch["values"], + response_mask=data.batch["response_mask"], gamma=gamma, - lam=lam) - data.batch['advantages'] = advantages - data.batch['returns'] = returns + lam=lam, + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns elif adv_estimator == AdvantageEstimator.GRPO: advantages, returns = core_algos.compute_grpo_outcome_advantage( - token_level_rewards=data.batch['token_level_rewards'], - response_mask=data.batch['response_mask'], - index=data.non_tensor_batch['uid']) - data.batch['advantages'] = advantages - data.batch['returns'] = returns + token_level_rewards=data.batch["token_level_rewards"], + response_mask=data.batch["response_mask"], + index=data.non_tensor_batch["uid"], + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE: advantages, returns = core_algos.compute_reinforce_plus_plus_baseline_outcome_advantage( - token_level_rewards=data.batch['token_level_rewards'], - 
response_mask=data.batch['response_mask'], - index=data.non_tensor_batch['uid']) - data.batch['advantages'] = advantages - data.batch['returns'] = returns + token_level_rewards=data.batch["token_level_rewards"], + response_mask=data.batch["response_mask"], + index=data.non_tensor_batch["uid"], + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS: advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards=data.batch['token_level_rewards'], - response_mask=data.batch['response_mask'], - gamma=gamma) - data.batch['advantages'] = advantages - data.batch['returns'] = returns + token_level_rewards=data.batch["token_level_rewards"], + response_mask=data.batch["response_mask"], + gamma=gamma, + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns elif adv_estimator == AdvantageEstimator.REMAX: advantages, returns = core_algos.compute_remax_outcome_advantage( - token_level_rewards=data.batch['token_level_rewards'], - reward_baselines=data.batch['reward_baselines'], - response_mask=data.batch['response_mask']) + token_level_rewards=data.batch["token_level_rewards"], + reward_baselines=data.batch["reward_baselines"], + response_mask=data.batch["response_mask"], + ) - data.batch['advantages'] = advantages - data.batch['returns'] = returns + data.batch["advantages"] = advantages + data.batch["returns"] = returns elif adv_estimator == AdvantageEstimator.RLOO: advantages, returns = core_algos.compute_rloo_outcome_advantage( - token_level_rewards=data.batch['token_level_rewards'], - response_mask=data.batch['response_mask'], - index=data.non_tensor_batch['uid']) - data.batch['advantages'] = advantages - data.batch['returns'] = returns + token_level_rewards=data.batch["token_level_rewards"], + response_mask=data.batch["response_mask"], + index=data.non_tensor_batch["uid"], + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns else: raise NotImplementedError return data @@ -242,23 +258,24 @@ def _timer(name: str, timing_raw: Dict[str, float]): timing_raw[name] += timer.last -class RayPPOTrainer(object): +class RayPPOTrainer: """ Note that this trainer runs on the driver process on a single CPU/GPU node. 
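The `_timer` context manager reformatted above accumulates per-section wall time via `codetiming.Timer`. A stdlib-only sketch of the same pattern — `simple_timer` is a hypothetical stand-in, not part of this patch:

```python
# Hypothetical sketch: accumulate elapsed seconds per named section into a dict,
# in the spirit of _timer but using time.perf_counter instead of codetiming.
import time
from contextlib import contextmanager


@contextmanager
def simple_timer(name, timing_raw):
    start = time.perf_counter()
    try:
        yield
    finally:
        timing_raw[name] = timing_raw.get(name, 0.0) + (time.perf_counter() - start)


if __name__ == "__main__":
    timings = {}
    with simple_timer("step", timings):
        with simple_timer("gen", timings):
            time.sleep(0.01)
    print(timings)
```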
""" # TODO: support each role have individual ray_worker_group_cls, # i.e., support different backend of different role - def __init__(self, - config, - tokenizer, - role_worker_mapping: dict[Role, WorkerType], - resource_pool_manager: ResourcePoolManager, - ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, - processor=None, - reward_fn=None, - val_reward_fn=None): - + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + processor=None, + reward_fn=None, + val_reward_fn=None, + ): # assert torch.cuda.is_available(), 'cuda must be available on driver' self.tokenizer = tokenizer @@ -268,10 +285,10 @@ class RayPPOTrainer(object): self.val_reward_fn = val_reward_fn self.hybrid_engine = config.actor_rollout_ref.hybrid_engine - assert self.hybrid_engine, 'Currently, only support hybrid engine' + assert self.hybrid_engine, "Currently, only support hybrid engine" if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f'{role_worker_mapping.keys()=}' + assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" self.role_worker_mapping = role_worker_mapping self.resource_pool_manager = resource_pool_manager @@ -288,8 +305,11 @@ class RayPPOTrainer(object): if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: self.use_critic = True elif self.config.algorithm.adv_estimator in [ - AdvantageEstimator.GRPO, AdvantageEstimator.REINFORCE_PLUS_PLUS, AdvantageEstimator.REMAX, - AdvantageEstimator.RLOO, AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE + AdvantageEstimator.GRPO, + AdvantageEstimator.REINFORCE_PLUS_PLUS, + AdvantageEstimator.REMAX, + AdvantageEstimator.RLOO, + AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE, ]: self.use_critic = False else: @@ -305,8 +325,9 @@ class RayPPOTrainer(object): # 1. Check total batch size for data correctness real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n - assert real_train_batch_size % n_gpus == 0, \ + assert real_train_batch_size % n_gpus == 0, ( f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})." + ) # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". @@ -325,7 +346,8 @@ class RayPPOTrainer(object): if mbs is None and mbs_per_gpu is None: raise ValueError( - f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.") + f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'." + ) if mbs is not None and mbs_per_gpu is not None: raise ValueError( @@ -335,30 +357,38 @@ class RayPPOTrainer(object): if not config.actor_rollout_ref.actor.use_dynamic_bsz: # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu - check_mutually_exclusive(config.actor_rollout_ref.actor.ppo_micro_batch_size, - config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, - "actor_rollout_ref.actor") + check_mutually_exclusive( + config.actor_rollout_ref.actor.ppo_micro_batch_size, + config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, + "actor_rollout_ref.actor", + ) if self.use_reference_policy: # reference: log_prob_micro_batch_size vs. 
log_prob_micro_batch_size_per_gpu - check_mutually_exclusive(config.actor_rollout_ref.ref.log_prob_micro_batch_size, - config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.ref") + check_mutually_exclusive( + config.actor_rollout_ref.ref.log_prob_micro_batch_size, + config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.ref", + ) # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive(config.actor_rollout_ref.rollout.log_prob_micro_batch_size, - config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.rollout") + check_mutually_exclusive( + config.actor_rollout_ref.rollout.log_prob_micro_batch_size, + config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.rollout", + ) if self.use_critic and not config.critic.use_dynamic_bsz: # Check for critic micro-batch size conflicts - check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, - "critic") + check_mutually_exclusive( + config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic" + ) # Check for reward model micro-batch size conflicts if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: - check_mutually_exclusive(config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, - "reward_model") + check_mutually_exclusive( + config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model" + ) # Actor # check if train_batch_size is larger than ppo_mini_batch_size @@ -367,58 +397,72 @@ class RayPPOTrainer(object): # ppo_micro_batch_size * sequence_parallel_size >= n_gpus if not config.actor_rollout_ref.actor.use_dynamic_bsz: assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size - sp_size = config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1) + sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: - assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0 + assert ( + config.actor_rollout_ref.actor.ppo_mini_batch_size + % config.actor_rollout_ref.actor.ppo_micro_batch_size + == 0 + ) assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus assert config.actor_rollout_ref.actor.loss_agg_mode in [ - "token-mean", "seq-mean-token-sum", "seq-mean-token-mean" + "token-mean", + "seq-mean-token-sum", + "seq-mean-token-mean", ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}" if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: - print(f"NOTICE: You have both enabled in-reward kl and kl loss.") + print("NOTICE: You have both enabled in-reward kl and kl loss.") # critic if self.use_critic and not config.critic.use_dynamic_bsz: assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size - sp_size = config.critic.get('ulysses_sequence_parallel_size', 1) + sp_size = config.critic.get("ulysses_sequence_parallel_size", 1) if config.critic.ppo_micro_batch_size is not None: assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus # Check if use_remove_padding is enabled when using sequence parallelism for fsdp - if 
config.actor_rollout_ref.actor.strategy == 'fsdp': - if config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1) > 1 or \ - config.actor_rollout_ref.ref.get('ulysses_sequence_parallel_size', 1) > 1: - assert config.actor_rollout_ref.model.use_remove_padding, \ + if config.actor_rollout_ref.actor.strategy == "fsdp": + if ( + config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 + or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1 + ): + assert config.actor_rollout_ref.model.use_remove_padding, ( "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." + ) - if self.use_critic and config.critic.strategy == 'fsdp': - if config.critic.get('ulysses_sequence_parallel_size', 1) > 1: - assert config.critic.model.use_remove_padding, \ + if self.use_critic and config.critic.strategy == "fsdp": + if config.critic.get("ulysses_sequence_parallel_size", 1) > 1: + assert config.critic.model.use_remove_padding, ( "When using sequence parallelism for critic, you must enable `use_remove_padding`." + ) - if config.data.get('val_batch_size', None) is not None: + if config.data.get("val_batch_size", None) is not None: print( - f"WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves." + "WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves." ) # check eval config if config.actor_rollout_ref.rollout.val_kwargs.do_sample: - assert config.actor_rollout_ref.rollout.temperature > 0, \ + assert config.actor_rollout_ref.rollout.temperature > 0, ( "validation gen temperature should be greater than 0 when enabling do_sample" + ) print("[validate_config] All configuration checks passed successfully!") def _create_dataloader(self): # TODO: we have to make sure the batch size is divisible by the dp size from verl.utils.import_utils import load_extern_type + if "custom_cls" in self.config.data and self.config.data.custom_cls.get("path", None) is not None: dataset_cls = load_extern_type(self.config.data.custom_cls.path, self.config.data.custom_cls.name) if not issubclass(dataset_cls, Dataset): - raise TypeError(f"The custom dataset class '{self.config.data.custom_cls.name}' from " - f"'{self.config.data.custom_cls.path}' must inherit from torch.utils.data.Dataset") + raise TypeError( + f"The custom dataset class '{self.config.data.custom_cls.name}' from " + f"'{self.config.data.custom_cls.path}' must inherit from torch.utils.data.Dataset" + ) else: dataset_cls = RLHFDataset @@ -432,18 +476,19 @@ class RayPPOTrainer(object): # use sampler for better ckpt resume if self.config.data.shuffle: train_dataloader_generator = torch.Generator() - train_dataloader_generator.manual_seed(self.config.data.get('seed', 1)) + train_dataloader_generator.manual_seed(self.config.data.get("seed", 1)) sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator) else: sampler = SequentialSampler(data_source=self.train_dataset) - self.train_dataloader = StatefulDataLoader(dataset=self.train_dataset, - batch_size=self.config.data.get('gen_batch_size', - self.config.data.train_batch_size), - num_workers=8, - drop_last=True, - collate_fn=collate_fn, - sampler=sampler) + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + 
num_workers=8, + drop_last=True, + collate_fn=collate_fn, + sampler=sampler, + ) self.val_dataset = dataset_cls( data_files=self.config.data.val_files, @@ -459,14 +504,15 @@ class RayPPOTrainer(object): num_workers=8, shuffle=False, drop_last=False, - collate_fn=collate_fn) + collate_fn=collate_fn, + ) assert len(self.train_dataloader) >= 1 - assert len( - self.val_dataloader - ) == 1, "Validation dataloader must have a single batch, which inference engines will schedule the memory themselves." + assert len(self.val_dataloader) == 1, ( + "Validation dataloader must have a single batch, which inference engines will schedule the memory themselves." + ) - print(f'Size of train dataloader: {len(self.train_dataloader)}') + print(f"Size of train dataloader: {len(self.train_dataloader)}") # inject total_training_steps to actor/critic optim_config. This is hacky. total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs @@ -475,7 +521,7 @@ class RayPPOTrainer(object): total_training_steps = self.config.trainer.total_training_steps self.total_training_steps = total_training_steps - print(f'Total training steps: {self.total_training_steps}') + print(f"Total training steps: {self.total_training_steps}") OmegaConf.set_struct(self.config, True) with open_dict(self.config): @@ -519,38 +565,39 @@ class RayPPOTrainer(object): test_batch = DataProto.from_single_dict(test_data) # repeat test batch - test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, - interleave=True) + test_batch = test_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True + ) # we only do validation on rule-based rm - if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model': + if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": return {} # Store original inputs - input_ids = test_batch.batch['input_ids'] + input_ids = test_batch.batch["input_ids"] # TODO: Can we keep special tokens except for padding tokens? 
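The dataloader setup reformatted here seeds a `RandomSampler` so shuffling is reproducible across resumes, and derives `total_training_steps` from the dataloader length. A minimal sketch of that pattern with a plain `DataLoader` in place of `StatefulDataLoader` — the dataset and numbers are made up, not part of this patch:

```python
# Hypothetical sketch: seeded RandomSampler + DataLoader, with training steps
# derived from loader length times epochs.
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(32))
generator = torch.Generator()
generator.manual_seed(1)  # analogous to config.data.get("seed", 1)
sampler = RandomSampler(dataset, generator=generator)
loader = DataLoader(dataset, batch_size=8, sampler=sampler, drop_last=True)

total_epochs = 3
total_training_steps = len(loader) * total_epochs
print(len(loader), total_training_steps)  # 4 batches per epoch -> 12 steps
```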
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] sample_inputs.extend(input_texts) - if 'multi_modal_inputs' in test_batch.non_tensor_batch.keys(): + if "multi_modal_inputs" in test_batch.non_tensor_batch.keys(): test_gen_batch = test_batch.pop( - batch_keys=['input_ids', 'attention_mask', 'position_ids'], - non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs'], + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"], ) else: test_gen_batch = test_batch.pop( - batch_keys=['input_ids', 'attention_mask', 'position_ids'], - non_tensor_batch_keys=['raw_prompt_ids'], + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids"], ) test_gen_batch.meta_info = { - 'eos_token_id': self.tokenizer.eos_token_id, - 'pad_token_id': self.tokenizer.pad_token_id, - 'recompute_log_prob': False, - 'do_sample': self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, - 'validate': True, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + "recompute_log_prob": False, + "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, + "validate": True, } - print(f'test_gen_batch meta info: {test_gen_batch.meta_info}') + print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") # pad to be divisible by dp_size test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) @@ -558,10 +605,10 @@ class RayPPOTrainer(object): # unpad test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) - print('validation generation end') + print("validation generation end") # Store generated outputs - output_ids = test_output_gen_batch.batch['responses'] + output_ids = test_output_gen_batch.batch["responses"] output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] sample_outputs.extend(output_texts) @@ -578,7 +625,7 @@ class RayPPOTrainer(object): for key, lst in result["reward_extra_info"].items(): reward_extra_infos_dict[key].extend(lst) - data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0])) + data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) @@ -594,9 +641,11 @@ class RayPPOTrainer(object): for var_name, metric2val in var2metric2val.items(): n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) for metric_name, metric_val in metric2val.items(): - if (var_name == core_var) and any( - metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) and (f"@{n_max}" - in metric_name): + if ( + (var_name == core_var) + and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) + and (f"@{n_max}" in metric_name) + ): metric_sec = "val-core" else: metric_sec = "val-aux" @@ -614,10 +663,12 @@ class RayPPOTrainer(object): # create actor and rollout if self.hybrid_engine: resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout], - config=self.config.actor_rollout_ref, - role='actor_rollout') - self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls + actor_rollout_cls = 
RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.ActorRollout], + config=self.config.actor_rollout_ref, + role="actor_rollout", + ) + self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls else: raise NotImplementedError @@ -625,22 +676,22 @@ class RayPPOTrainer(object): if self.use_critic: resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) - self.resource_pool_to_cls[resource_pool]['critic'] = critic_cls + self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls # create reference policy if needed if self.use_reference_policy: resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) - ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], - config=self.config.actor_rollout_ref, - role='ref') - self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref" + ) + self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls # create a reward model if reward_fn is None if self.use_rm: # we create a RM here resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) - self.resource_pool_to_cls[resource_pool]['rm'] = rm_cls + self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls # initialize WorkerGroup # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, @@ -654,83 +705,91 @@ class RayPPOTrainer(object): for resource_pool, class_dict in self.resource_pool_to_cls.items(): worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, - ray_cls_with_init=worker_dict_cls, - **wg_kwargs) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, **wg_kwargs + ) spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) all_wg.update(spawn_wg) # keep the referece of WorkerDict to support ray >= 2.31. 
Ref: https://github.com/ray-project/ray/pull/45699 self.wg_dicts.append(wg_dict) if self.use_critic: - self.critic_wg = all_wg['critic'] + self.critic_wg = all_wg["critic"] self.critic_wg.init_model() if self.use_reference_policy: - self.ref_policy_wg = all_wg['ref'] + self.ref_policy_wg = all_wg["ref"] self.ref_policy_wg.init_model() if self.use_rm: - self.rm_wg = all_wg['rm'] + self.rm_wg = all_wg["rm"] self.rm_wg.init_model() # we should create rollout at the end so that vllm can have a better estimation of kv cache memory - self.actor_rollout_wg = all_wg['actor_rollout'] + self.actor_rollout_wg = all_wg["actor_rollout"] self.actor_rollout_wg.init_model() def _save_checkpoint(self): # path: given_path + `/global_step_{global_steps}` + `/actor` - local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, - f'global_step_{self.global_steps}') + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" + ) - print(f'local_global_step_folder: {local_global_step_folder}') - actor_local_path = os.path.join(local_global_step_folder, 'actor') + print(f"local_global_step_folder: {local_global_step_folder}") + actor_local_path = os.path.join(local_global_step_folder, "actor") - actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( - self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor') + actor_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + ) - remove_previous_ckpt_in_save = self.config.trainer.get('remove_previous_ckpt_in_save', False) + remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) if remove_previous_ckpt_in_save: print( - 'Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead' + "Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead" ) - max_actor_ckpt_to_keep = self.config.trainer.get('max_actor_ckpt_to_keep', - None) if not remove_previous_ckpt_in_save else 1 - max_critic_ckpt_to_keep = self.config.trainer.get('max_critic_ckpt_to_keep', - None) if not remove_previous_ckpt_in_save else 1 + max_actor_ckpt_to_keep = ( + self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + max_critic_ckpt_to_keep = ( + self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) - self.actor_rollout_wg.save_checkpoint(actor_local_path, - actor_remote_path, - self.global_steps, - max_ckpt_to_keep=max_actor_ckpt_to_keep) + self.actor_rollout_wg.save_checkpoint( + actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep + ) if self.use_critic: - critic_local_path = os.path.join(local_global_step_folder, 'critic') - critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( - self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'critic') - self.critic_wg.save_checkpoint(critic_local_path, - critic_remote_path, - self.global_steps, - max_ckpt_to_keep=max_critic_ckpt_to_keep) + critic_local_path = os.path.join(local_global_step_folder, "critic") + critic_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, 
f"global_step_{self.global_steps}", "critic") + ) + self.critic_wg.save_checkpoint( + critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep + ) # save dataloader - dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt') + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") dataloader_state_dict = self.train_dataloader.state_dict() torch.save(dataloader_state_dict, dataloader_local_path) # latest checkpointed iteration tracker (for atomic usage) - local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir, - 'latest_checkpointed_iteration.txt') - with open(local_latest_checkpointed_iteration, 'w') as f: + local_latest_checkpointed_iteration = os.path.join( + self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" + ) + with open(local_latest_checkpointed_iteration, "w") as f: f.write(str(self.global_steps)) def _load_checkpoint(self): - if self.config.trainer.resume_mode == 'disable': + if self.config.trainer.resume_mode == "disable": return 0 # load from hdfs if self.config.trainer.default_hdfs_dir is not None: - raise NotImplementedError('load from hdfs is not implemented yet') + raise NotImplementedError("load from hdfs is not implemented yet") else: checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path if not os.path.isabs(checkpoint_folder): @@ -739,59 +798,63 @@ class RayPPOTrainer(object): global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest # find global_step_folder - if self.config.trainer.resume_mode == 'auto': + if self.config.trainer.resume_mode == "auto": if global_step_folder is None: - print('Training from scratch') + print("Training from scratch") return 0 else: if self.config.trainer.resume_mode == "resume_path": assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" - assert 'global_step_' in self.config.trainer.resume_from_path, "resume ckpt must specify the global_steps" + assert "global_step_" in self.config.trainer.resume_from_path, ( + "resume ckpt must specify the global_steps" + ) global_step_folder = self.config.trainer.resume_from_path if not os.path.isabs(global_step_folder): working_dir = os.getcwd() global_step_folder = os.path.join(working_dir, global_step_folder) - print(f'Load from checkpoint folder: {global_step_folder}') + print(f"Load from checkpoint folder: {global_step_folder}") # set global step - self.global_steps = int(global_step_folder.split('global_step_')[-1]) + self.global_steps = int(global_step_folder.split("global_step_")[-1]) - print(f'Setting global step to {self.global_steps}') - print(f'Resuming from {global_step_folder}') + print(f"Setting global step to {self.global_steps}") + print(f"Resuming from {global_step_folder}") - actor_path = os.path.join(global_step_folder, 'actor') - critic_path = os.path.join(global_step_folder, 'critic') + actor_path = os.path.join(global_step_folder, "actor") + critic_path = os.path.join(global_step_folder, "critic") # load actor - self.actor_rollout_wg.load_checkpoint(actor_path, - del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + self.actor_rollout_wg.load_checkpoint( + actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) # load critic if self.use_critic: - self.critic_wg.load_checkpoint(critic_path, - del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + self.critic_wg.load_checkpoint( + critic_path, 
del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) # load dataloader, # TODO: from remote not implemented yet - dataloader_local_path = os.path.join(global_step_folder, 'data.pt') + dataloader_local_path = os.path.join(global_step_folder, "data.pt") if os.path.exists(dataloader_local_path): dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) self.train_dataloader.load_state_dict(dataloader_state_dict) else: print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") - def _balance_batch(self, batch: DataProto, metrics, logging_prefix='global_seqlen'): + def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): """Reorder the data on single controller such that each dp rank gets similar total tokens""" - attention_mask = batch.batch['attention_mask'] + attention_mask = batch.batch["attention_mask"] batch_size = attention_mask.shape[0] - global_seqlen_lst = batch.batch['attention_mask'].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) world_size = self.actor_rollout_wg.world_size - global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst, - k_partitions=world_size, - equal_size=True) + global_partition_lst = get_seqlen_balanced_partitions( + global_seqlen_lst, k_partitions=world_size, equal_size=True + ) # reorder based on index. The data will be automatically equally partitioned by dispatch function global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) batch.reorder(global_idx) - global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst, - partitions=global_partition_lst, - prefix=logging_prefix) + global_balance_stats = log_seqlen_unbalance( + seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix + ) metrics.update(global_balance_stats) def fit(self): @@ -800,13 +863,16 @@ class RayPPOTrainer(object): The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. The light-weight advantage computation is done on the driver process. """ - from verl.utils.tracking import Tracking from omegaconf import OmegaConf - logger = Tracking(project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True)) + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) self.global_steps = 0 @@ -815,11 +881,11 @@ class RayPPOTrainer(object): # perform validation before training # currently, we only support validation using the reward_function. 
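`_balance_batch`, reformatted above, reorders the batch so each dp rank receives a similar total token count. A greedy sketch of that balancing idea — illustrative only, not part of this patch; unlike `get_seqlen_balanced_partitions(..., equal_size=True)`, this toy version does not enforce equal partition sizes:

```python
# Hypothetical sketch: assign each sequence (longest first) to the partition with
# the smallest running token total, so partitions end up with similar workloads.
import heapq


def greedy_balanced_partitions(seqlens, k_partitions):
    heap = [(0, p) for p in range(k_partitions)]  # (current_total_tokens, partition)
    heapq.heapify(heap)
    partitions = [[] for _ in range(k_partitions)]
    for idx in sorted(range(len(seqlens)), key=lambda i: -seqlens[i]):
        total, p = heapq.heappop(heap)
        partitions[p].append(idx)
        heapq.heappush(heap, (total + seqlens[idx], p))
    return partitions


if __name__ == "__main__":
    lens = [512, 128, 300, 64, 256, 900, 700, 50]
    parts = greedy_balanced_partitions(lens, k_partitions=2)
    print(parts, [sum(lens[i] for i in p) for p in parts])
```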
- if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True): + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): val_metrics = self._validate() - pprint(f'Initial validation metrics: {val_metrics}') + pprint(f"Initial validation metrics: {val_metrics}") logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get('val_only', False): + if self.config.trainer.get("val_only", False): return # add tqdm @@ -837,28 +903,28 @@ class RayPPOTrainer(object): batch: DataProto = DataProto.from_single_dict(batch_dict) # pop those keys for generation - if 'multi_modal_inputs' in batch.non_tensor_batch.keys(): + if "multi_modal_inputs" in batch.non_tensor_batch.keys(): gen_batch = batch.pop( - batch_keys=['input_ids', 'attention_mask', 'position_ids'], - non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs'], + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"], ) else: gen_batch = batch.pop( - batch_keys=['input_ids', 'attention_mask', 'position_ids'], - non_tensor_batch_keys=['raw_prompt_ids'], + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids"], ) is_last_step = self.global_steps >= self.total_training_steps - with _timer('step', timing_raw): + with _timer("step", timing_raw): # generate a batch - with _timer('gen', timing_raw): + with _timer("gen", timing_raw): gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - with _timer('gen_max', timing_raw): + with _timer("gen_max", timing_raw): gen_baseline_batch = deepcopy(gen_batch) - gen_baseline_batch.meta_info['do_sample'] = False + gen_baseline_batch.meta_info["do_sample"] = False gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) batch = batch.union(gen_baseline_output) @@ -867,17 +933,18 @@ class RayPPOTrainer(object): batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - batch.batch['reward_baselines'] = reward_baseline_tensor + batch.batch["reward_baselines"] = reward_baseline_tensor del gen_baseline_batch, gen_baseline_output - batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], - dtype=object) + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) # repeat to align with repeated responses in rollout batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) - batch.batch['response_mask'] = compute_response_mask(batch) + batch.batch["response_mask"] = compute_response_mask(batch) # balance the number of valid tokens on each dp rank. # Note that this breaks the order of data inside the batch. 
# Please take care when you implement group based adv computation such as GRPO and rloo @@ -885,35 +952,35 @@ class RayPPOTrainer(object): self._balance_batch(batch, metrics=metrics) # compute global_valid tokens - batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() # recompute old_log_probs - with _timer('old_log_prob', timing_raw): + with _timer("old_log_prob", timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - entropys = old_log_prob.batch['entropys'] - response_masks = batch.batch['response_mask'] + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_loss = agg_loss(loss_mat=entropys, - loss_mask=response_masks, - loss_agg_mode=loss_agg_mode) + entropy_loss = agg_loss( + loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode + ) old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()} metrics.update(old_log_prob_metrics) - old_log_prob.batch.pop('entropys') + old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) if self.use_reference_policy: # compute reference log_prob - with _timer('ref', timing_raw): + with _timer("ref", timing_raw): ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) # compute values if self.use_critic: - with _timer('values', timing_raw): + with _timer("values", timing_raw): values = self.critic_wg.compute_values(batch) batch = batch.union(values) - with _timer('adv', timing_raw): + with _timer("adv", timing_raw): # compute scores. Support both model and function-based. # We first compute the scores using reward model. Then, we call reward_fn to combine # the results from reward model and rule-based results. @@ -926,62 +993,68 @@ class RayPPOTrainer(object): reward_extra_infos_dict: dict[str, list] try: reward_result = self.reward_fn(batch, return_dict=True) - reward_tensor = reward_result['reward_tensor'] - reward_extra_infos_dict = reward_result['reward_extra_info'] + reward_tensor = reward_result["reward_tensor"] + reward_extra_infos_dict = reward_result["reward_extra_info"] except Exception as e: - print(f'Error in reward_fn: {e}') + print(f"Error in reward_fn: {e}") reward_tensor = self.reward_fn(batch) reward_extra_infos_dict = {} - batch.batch['token_level_scores'] = reward_tensor + batch.batch["token_level_scores"] = reward_tensor - print(f'{list(reward_extra_infos_dict.keys())=}') + print(f"{list(reward_extra_infos_dict.keys())=}") if reward_extra_infos_dict: batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) # compute rewards. 
apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: - batch, kl_metrics = apply_kl_penalty(batch, - kl_ctrl=self.kl_ctrl_in_reward, - kl_penalty=self.config.algorithm.kl_penalty) + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) metrics.update(kl_metrics) else: - batch.batch['token_level_rewards'] = batch.batch['token_level_scores'] + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] # compute advantages, executed on the driver process - batch = compute_advantage(batch, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n) + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + ) # update critic if self.use_critic: - with _timer('update_critic', timing_raw): + with _timer("update_critic", timing_raw): critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) metrics.update(critic_output_metrics) # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: # update actor - with _timer('update_actor', timing_raw): + with _timer("update_actor", timing_raw): actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) # validate - if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \ - (is_last_step or self.global_steps % self.config.trainer.test_freq == 0): - with _timer('testing', timing_raw): + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with _timer("testing", timing_raw): val_metrics: dict = self._validate() if is_last_step: last_val_metrics = val_metrics metrics.update(val_metrics) - if self.config.trainer.save_freq > 0 and ( is_last_step or \ - self.global_steps % self.config.trainer.save_freq == 0): - with _timer('save_checkpoint', timing_raw): + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with _timer("save_checkpoint", timing_raw): self._save_checkpoint() # collect metrics @@ -995,7 +1068,7 @@ class RayPPOTrainer(object): logger.log(data=metrics, step=self.global_steps) if is_last_step: - pprint(f'Final validation metrics: {last_val_metrics}') + pprint(f"Final validation metrics: {last_val_metrics}") progress_bar.close() return diff --git a/verl/utils/__init__.py b/verl/utils/__init__.py index bc781029d..430ff65bd 100644 --- a/verl/utils/__init__.py +++ b/verl/utils/__init__.py @@ -13,6 +13,6 @@ # limitations under the License. from . 
import tokenizer -from .tokenizer import hf_tokenizer, hf_processor +from .tokenizer import hf_processor, hf_tokenizer -__all__ = tokenizer.__all__ \ No newline at end of file +__all__ = tokenizer.__all__ diff --git a/verl/utils/checkpoint/__init__.py b/verl/utils/checkpoint/__init__.py index 7a7aadbc9..1ce90c5eb 100644 --- a/verl/utils/checkpoint/__init__.py +++ b/verl/utils/checkpoint/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/verl/utils/checkpoint/checkpoint_manager.py b/verl/utils/checkpoint/checkpoint_manager.py index af32e85d9..0687dd7f4 100644 --- a/verl/utils/checkpoint/checkpoint_manager.py +++ b/verl/utils/checkpoint/checkpoint_manager.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import random import shutil -from filelock import FileLock import tempfile from typing import Union + +import numpy as np import torch import torch.distributed +from filelock import FileLock from transformers import PreTrainedTokenizer, ProcessorMixin -import numpy as np -import random -import re class BaseCheckpointManager: @@ -39,12 +39,14 @@ class BaseCheckpointManager: - huggingface tokenizer and config for ckpt merge """ - def __init__(self, - model, - optimizer: torch.optim.Optimizer, - lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None, - processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None, - checkpoint_contents: list = ['model', 'optimizer', 'extra']): + def __init__( + self, + model, + optimizer: torch.optim.Optimizer, + lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None, + processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None, + checkpoint_contents: list = ["model", "optimizer", "extra"], + ): self.previous_global_step = None self.previous_saved_paths = [] @@ -60,11 +62,9 @@ class BaseCheckpointManager: def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False): raise NotImplementedError - def save_checkpoint(self, - local_path: str, - hdfs_path: str = None, - global_step: int = 0, - max_ckpt_to_keep: int = None): + def save_checkpoint( + self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep: int = None + ): raise NotImplementedError @staticmethod @@ -77,7 +77,7 @@ class BaseCheckpointManager: path = [path] for p in path: abs_path = os.path.abspath(p) - print(f'Checkpoint manager remove previous save local path: {abs_path}') + print(f"Checkpoint manager remove previous save local path: {abs_path}") if not os.path.exists(abs_path): continue shutil.rmtree(abs_path, ignore_errors=True) @@ -106,19 +106,19 @@ class BaseCheckpointManager: @staticmethod def get_rng_state(): rng_state = { - 'cpu': torch.get_rng_state(), - 'cuda': torch.cuda.get_rng_state(), - 'numpy': np.random.get_state(), - 'random': random.getstate(), + "cpu": torch.get_rng_state(), + "cuda": torch.cuda.get_rng_state(), + "numpy": np.random.get_state(), + "random": random.getstate(), } return rng_state @staticmethod def load_rng_state(rng_state): - torch.set_rng_state(rng_state['cpu']) - torch.cuda.set_rng_state(rng_state['cuda']) - np.random.set_state(rng_state['numpy']) - random.setstate(rng_state['random']) + 
torch.set_rng_state(rng_state["cpu"]) + torch.cuda.set_rng_state(rng_state["cuda"]) + np.random.set_state(rng_state["numpy"]) + random.setstate(rng_state["random"]) def find_latest_ckpt_path(path, directory_format="global_step_{}"): diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py index c59f844df..466db5695 100644 --- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -12,20 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import ray import os - import warnings from typing import Union + import torch import torch.distributed -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType -from torch.distributed.fsdp import ShardedStateDictConfig, ShardedOptimStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType +from transformers import PreTrainedTokenizer, ProcessorMixin from verl.utils.fs import copy_to_local, is_non_local -from transformers import PreTrainedTokenizer, ProcessorMixin - from .checkpoint_manager import BaseCheckpointManager @@ -38,43 +36,49 @@ class FSDPCheckpointManager(BaseCheckpointManager): - extra_states in a SPMD way. - We save + We save - sharded model states and optimizer states - full lr_scheduler states - huggingface tokenizer/processor and config for ckpt merge """ - def __init__(self, - model: FSDP, - optimizer: torch.optim.Optimizer, - lr_scheduler: torch.optim.lr_scheduler.LRScheduler, - processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None, - checkpoint_contents: list = ['model', 'optimizer', 'extra'], - **kwargs): - + def __init__( + self, + model: FSDP, + optimizer: torch.optim.Optimizer, + lr_scheduler: torch.optim.lr_scheduler.LRScheduler, + processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None, + checkpoint_contents: list = ["model", "optimizer", "extra"], + **kwargs, + ): if processing_class is None: assert "tokenizer" in kwargs, "tokenizer or processor must be provided" warnings.warn("`tokenizer` is deprecated. 
use `processing_class` instead.", DeprecationWarning) processing_class = kwargs.pop("tokenizer") - assert "model" in checkpoint_contents and "optimizer" in checkpoint_contents and "extra" in checkpoint_contents, f"FSDPCheckpointManager must include ['model', 'optimizer', 'extra'], got {checkpoint_contents}" + assert ( + "model" in checkpoint_contents and "optimizer" in checkpoint_contents and "extra" in checkpoint_contents + ), f"FSDPCheckpointManager must include ['model', 'optimizer', 'extra'], got {checkpoint_contents}" - super().__init__(model, - optimizer, - lr_scheduler=lr_scheduler, - processing_class=processing_class, - checkpoint_contents=checkpoint_contents) + super().__init__( + model, + optimizer, + lr_scheduler=lr_scheduler, + processing_class=processing_class, + checkpoint_contents=checkpoint_contents, + ) def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False): if local_path is None: return # every rank download its own checkpoint - remote_model_path = os.path.join(local_path, f'model_world_size_{self.world_size}_rank_{self.rank}.pt') - remote_optim_path = os.path.join(local_path, f'optim_world_size_{self.world_size}_rank_{self.rank}.pt') - remote_extra_state_path = os.path.join(local_path, - f'extra_state_world_size_{self.world_size}_rank_{self.rank}.pt') + remote_model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") + remote_optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") + remote_extra_state_path = os.path.join( + local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt" + ) print( - f'[rank-{self.rank}]: Loading from {remote_model_path} and {remote_optim_path} and {remote_extra_state_path}' + f"[rank-{self.rank}]: Loading from {remote_model_path} and {remote_optim_path} and {remote_extra_state_path}" ) local_model_path = copy_to_local(remote_model_path) local_optim_path = copy_to_local(remote_optim_path) @@ -91,10 +95,10 @@ class FSDPCheckpointManager(BaseCheckpointManager): os.remove(local_extra_state_path) if is_non_local(local_extra_state_path) else None except Exception as e: print( - f'[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored' + f"[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored" ) - lr_scheduler_state_dict = extra_state_dict['lr_scheduler'] + lr_scheduler_state_dict = extra_state_dict["lr_scheduler"] state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True) optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True) @@ -103,9 +107,9 @@ class FSDPCheckpointManager(BaseCheckpointManager): if self.optimizer is not None: self.optimizer.load_state_dict(optimizer_state_dict) # recover random state - if 'rng' in extra_state_dict: + if "rng" in extra_state_dict: # 'rng' may not exist for backward compatibility - self.load_rng_state(extra_state_dict['rng']) + self.load_rng_state(extra_state_dict["rng"]) if self.lr_scheduler is not None: self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) @@ -118,8 +122,12 @@ class FSDPCheckpointManager(BaseCheckpointManager): self.previous_global_step = global_step # remove previous local_path - if max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len( - self.previous_saved_paths) >= max_ckpt_to_keep: + if ( + max_ckpt_to_keep + and isinstance(max_ckpt_to_keep, int) + and max_ckpt_to_keep > 0 + and len(self.previous_saved_paths) >= 
max_ckpt_to_keep + ): keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) self.previous_saved_paths = self.previous_saved_paths[keep_start:] @@ -144,16 +152,16 @@ class FSDPCheckpointManager(BaseCheckpointManager): lr_scheduler_state_dict = None extra_state_dict = { - 'lr_scheduler': lr_scheduler_state_dict, - 'rng': self.get_rng_state(), + "lr_scheduler": lr_scheduler_state_dict, + "rng": self.get_rng_state(), } - model_path = os.path.join(local_path, f'model_world_size_{self.world_size}_rank_{self.rank}.pt') - optim_path = os.path.join(local_path, f'optim_world_size_{self.world_size}_rank_{self.rank}.pt') - extra_path = os.path.join(local_path, f'extra_state_world_size_{self.world_size}_rank_{self.rank}.pt') + model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") + optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") + extra_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") - print(f'[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}') - print(f'[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}') - print(f'[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}') + print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}") + print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}") + print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}") torch.save(model_state_dict, model_path) torch.save(optimizer_state_dict, optim_path) # TODO: address optimizer is None torch.save(extra_state_dict, extra_path) @@ -163,7 +171,7 @@ class FSDPCheckpointManager(BaseCheckpointManager): torch.distributed.barrier() if self.rank == 0: - hf_local_path = os.path.join(local_path, 'huggingface') + hf_local_path = os.path.join(local_path, "huggingface") os.makedirs(hf_local_path, exist_ok=True) self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path) self.processing_class.save_pretrained(hf_local_path) diff --git a/verl/utils/checkpoint/megatron_checkpoint_manager.py b/verl/utils/checkpoint/megatron_checkpoint_manager.py index 2f26763e5..c7b59cea0 100644 --- a/verl/utils/checkpoint/megatron_checkpoint_manager.py +++ b/verl/utils/checkpoint/megatron_checkpoint_manager.py @@ -12,30 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
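Both the FSDP and Megatron checkpoint managers express the `max_ckpt_to_keep` pruning as the same multi-line condition after this reformat. A rough sketch of that retention window, with a hypothetical standalone helper in place of the managers' method:

```python
# Illustrative sketch of the checkpoint-retention arithmetic that the hunks above merely
# re-wrap. `previous_saved_paths` / `max_ckpt_to_keep` mirror the attributes in the diff;
# `prune_checkpoints` itself is a hypothetical helper for demonstration only.
def prune_checkpoints(previous_saved_paths: list[str], max_ckpt_to_keep: int) -> list[str]:
    if (
        max_ckpt_to_keep
        and isinstance(max_ckpt_to_keep, int)
        and max_ckpt_to_keep > 0
        and len(previous_saved_paths) >= max_ckpt_to_keep
    ):
        # Leave room for the checkpoint about to be written: drop the oldest entries so
        # that at most (max_ckpt_to_keep - 1) previously saved checkpoints remain.
        keep_start = len(previous_saved_paths) - max_ckpt_to_keep + 1
        return previous_saved_paths[keep_start:]
    return previous_saved_paths


# With max_ckpt_to_keep=3 and four existing checkpoints, the two oldest are dropped.
assert prune_checkpoints(["step_10", "step_20", "step_30", "step_40"], 3) == ["step_30", "step_40"]
```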
-import ray import os import random -import numpy as np -import warnings -from typing import Union +import numpy as np import torch import torch.distributed -from torch.nn.parallel import DistributedDataParallel as torchDDP - -from verl.utils.fs import copy_to_local, is_non_local -from verl.models.weight_loader_registry import get_weight_saver -from verl.models.weight_loader_registry import get_weight_loader -from verl.utils.model import load_megatron_model_weights -from verl.utils.megatron_utils import TransformerConfig, get_model_checkpoint_path, get_hf_model_checkpoint_path, get_optimizer_checkpoint_path, get_rng_states_checkpoint_path, unwrap_model - -from .checkpoint_manager import BaseCheckpointManager -from transformers import AutoModelForCausalLM - from megatron.core import mpu, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedObject -from megatron.core.transformer.module import Float16Module -from megatron.core.distributed import DistributedDataParallel as LocalDDP + +from verl.models.weight_loader_registry import get_weight_saver +from verl.utils.fs import is_non_local +from verl.utils.megatron_utils import ( + get_hf_model_checkpoint_path, + get_model_checkpoint_path, + get_optimizer_checkpoint_path, + get_rng_states_checkpoint_path, +) + +from .checkpoint_manager import BaseCheckpointManager class MegatronCheckpointManager(BaseCheckpointManager): @@ -47,32 +42,35 @@ class MegatronCheckpointManager(BaseCheckpointManager): - extra_states in a SPMD way. - We save + We save - sharded model states and optimizer states - full lr_scheduler states - huggingface tokenizer/processor and config for ckpt merge """ - def __init__(self, - config, - model_config, - role, - model: torch.nn.ModuleList, - arch: str, - hf_config, - param_dtype: torch.dtype, - share_embeddings_and_output_weights: bool, - tokenizer, - optimizer, - use_distributed_optimizer: bool, - checkpoint_contents: list = ['model', 'optimizer', 'extra'], - **kwargs): - - super().__init__(model, - optimizer=optimizer, - lr_scheduler=None, - processing_class=tokenizer, - checkpoint_contents=checkpoint_contents) + def __init__( + self, + config, + model_config, + role, + model: torch.nn.ModuleList, + arch: str, + hf_config, + param_dtype: torch.dtype, + share_embeddings_and_output_weights: bool, + tokenizer, + optimizer, + use_distributed_optimizer: bool, + checkpoint_contents: list = ["model", "optimizer", "extra"], + **kwargs, + ): + super().__init__( + model, + optimizer=optimizer, + lr_scheduler=None, + processing_class=tokenizer, + checkpoint_contents=checkpoint_contents, + ) self.arch = arch self.config = config self.role = role @@ -91,20 +89,18 @@ class MegatronCheckpointManager(BaseCheckpointManager): self.weight_saver = get_weight_saver(self.arch) def get_rng_state(self, use_dist_ckpt: bool = False, data_parallel_random_init: bool = False): - """ collect rng state across data parallel ranks """ + """collect rng state across data parallel ranks""" rng_state = { - 'random_rng_state': random.getstate(), - 'np_rng_state': np.random.get_state(), - 'torch_rng_state': torch.get_rng_state(), - 'cuda_rng_state': torch.cuda.get_rng_state(), - 'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states() + "random_rng_state": random.getstate(), + "np_rng_state": np.random.get_state(), + "torch_rng_state": torch.get_rng_state(), + "cuda_rng_state": torch.cuda.get_rng_state(), + "rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(), } rng_state_list = None - if 
torch.distributed.is_initialized() and \ - mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init: - rng_state_list = \ - [None for i in range(mpu.get_data_parallel_world_size())] + if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init: + rng_state_list = [None for i in range(mpu.get_data_parallel_world_size())] torch.distributed.all_gather_object(rng_state_list, rng_state, group=mpu.get_data_parallel_group()) else: rng_state_list = [rng_state] @@ -116,26 +112,32 @@ class MegatronCheckpointManager(BaseCheckpointManager): tp_size = mpu.get_tensor_model_parallel_world_size() cp_rank = mpu.get_context_parallel_rank() cp_size = mpu.get_context_parallel_world_size() - rng_state_list = ShardedObject('rng_state', - rng_state_list, (pp_size, tp_size, cp_size), (pp_rank, tp_rank, cp_rank), - replica_id=mpu.get_data_parallel_rank(with_context_parallel=True)) + rng_state_list = ShardedObject( + "rng_state", + rng_state_list, + (pp_size, tp_size, cp_size), + (pp_rank, tp_rank, cp_rank), + replica_id=mpu.get_data_parallel_rank(with_context_parallel=True), + ) return rng_state_list - def get_checkpoint_name(self, - checkpoints_path, - pipeline_parallel=None, - tensor_rank=None, - pipeline_rank=None, - cp_rank=None, - expert_parallel=None, - expert_rank=None, - return_base_dir=True, - basename="model.pt"): + def get_checkpoint_name( + self, + checkpoints_path, + pipeline_parallel=None, + tensor_rank=None, + pipeline_rank=None, + cp_rank=None, + expert_parallel=None, + expert_rank=None, + return_base_dir=True, + basename="model.pt", + ): """Determine the directory name for this rank's checkpoint.""" # Use both the tensor and pipeline MP rank. if pipeline_parallel is None: - pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1) + pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1 if tensor_rank is None: tensor_rank = mpu.get_tensor_model_parallel_rank() if pipeline_rank is None: @@ -143,7 +145,7 @@ class MegatronCheckpointManager(BaseCheckpointManager): if cp_rank is None: cp_rank = mpu.get_context_parallel_rank() if expert_parallel is None: - expert_parallel = (mpu.get_expert_model_parallel_world_size() > 1) + expert_parallel = mpu.get_expert_model_parallel_world_size() > 1 if expert_rank is None: expert_rank = mpu.get_expert_model_parallel_rank() @@ -153,12 +155,12 @@ class MegatronCheckpointManager(BaseCheckpointManager): # due to the fact that models are identical across cp ranks, cp rank is not used in the checkpoint path if not pipeline_parallel: - common_path = os.path.join(checkpoints_path, f'mp_rank_{tensor_rank:02d}') + common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}") else: - common_path = os.path.join(checkpoints_path, f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}') + common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}") if expert_parallel: - common_path = common_path + f'_{expert_rank:03d}' + common_path = common_path + f"_{expert_rank:03d}" os.makedirs(common_path, exist_ok=True) @@ -182,33 +184,34 @@ class MegatronCheckpointManager(BaseCheckpointManager): rng_state = rng_state[mpu.get_data_parallel_rank()] else: rng_state = rng_state[0] - random.setstate(rng_state['random_rng_state']) - np.random.set_state(rng_state['np_rng_state']) - torch.set_rng_state(rng_state['torch_rng_state']) - torch.cuda.set_rng_state(rng_state['cuda_rng_state']) + random.setstate(rng_state["random_rng_state"]) + 
np.random.set_state(rng_state["np_rng_state"]) + torch.set_rng_state(rng_state["torch_rng_state"]) + torch.cuda.set_rng_state(rng_state["cuda_rng_state"]) # Check for empty states array - if not rng_state['rng_tracker_states']: + if not rng_state["rng_tracker_states"]: raise KeyError - tensor_parallel.get_cuda_rng_tracker().set_states(rng_state['rng_tracker_states']) + tensor_parallel.get_cuda_rng_tracker().set_states(rng_state["rng_tracker_states"]) def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False): if local_path is None: return - if 'model' in self.checkpoint_contents: + if "model" in self.checkpoint_contents: model_path = get_model_checkpoint_path(local_path) ckpt_name = self.get_checkpoint_name(model_path, return_base_dir=False) state_dicts = torch.load(os.path.join(ckpt_name), weights_only=False) - assert len(state_dicts) == len( - self.model), f'state_dicts length: {len(state_dicts)} mismatch with model length: {len(self.model)}' + assert len(state_dicts) == len(self.model), ( + f"state_dicts length: {len(state_dicts)} mismatch with model length: {len(self.model)}" + ) for vpp_rank, (state_dict, model) in enumerate(zip(state_dicts, self.model)): model.load_state_dict(state_dict) - print(f'Loaded sharded model checkpoint from {model_path}') + print(f"Loaded sharded model checkpoint from {model_path}") - if 'optimizer' in self.checkpoint_contents: + if "optimizer" in self.checkpoint_contents: self.load_optimizer(local_path) - if 'extra' in self.checkpoint_contents: + if "extra" in self.checkpoint_contents: self.load_rng_states(local_path) if del_local_after_load: @@ -216,7 +219,7 @@ class MegatronCheckpointManager(BaseCheckpointManager): os.remove(local_path) if is_non_local(local_path) else None except Exception as e: print( - f'[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored' + f"[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored" ) def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None): @@ -224,8 +227,12 @@ class MegatronCheckpointManager(BaseCheckpointManager): self.previous_global_step = global_step # remove previous local_path - if max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len( - self.previous_saved_paths) >= max_ckpt_to_keep: + if ( + max_ckpt_to_keep + and isinstance(max_ckpt_to_keep, int) + and max_ckpt_to_keep > 0 + and len(self.previous_saved_paths) >= max_ckpt_to_keep + ): keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) self.previous_saved_paths = self.previous_saved_paths[keep_start:] @@ -233,62 +240,71 @@ class MegatronCheckpointManager(BaseCheckpointManager): local_path = self.local_mkdir(local_path) # Save Model - if 'model' in self.checkpoint_contents and mpu.get_data_parallel_rank() == 0: + if "model" in self.checkpoint_contents and mpu.get_data_parallel_rank() == 0: state_dicts = [] for vpp_rank, model in enumerate(self.model): state_dict = model.state_dict() state_dicts.append(state_dict) - print(f'Saving sharded model checkpoint to {local_path}') + print(f"Saving sharded model checkpoint to {local_path}") model_ckpt_path = get_model_checkpoint_path(local_path) hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) ckpt_name = self.get_checkpoint_name(model_ckpt_path, return_base_dir=False) torch.save(state_dicts, 
os.path.join(ckpt_name)) self.processing_class.save_pretrained(hf_model_ckpt_path) # tokenizer will be saved to hf_model_ckpt_path - print(f'Saved checkpoint to {model_ckpt_path}') + print(f"Saved checkpoint to {model_ckpt_path}") if hdfs_path is not None: - print(f'Uploading checkpoint to {hdfs_path}') + print(f"Uploading checkpoint to {hdfs_path}") from verl.utils import hdfs_io + hdfs_io.makedirs(hdfs_path, exist_ok=True) hdfs_io.copy(src=model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True) - if 'hf_model' in self.checkpoint_contents: + if "hf_model" in self.checkpoint_contents: # wait for everyone to dump to local - state_dict = self.weight_saver(self.model, - self.hf_config, - dtype=self.param_dtype, - is_value_model=self.is_value_model, - tie_word_embeddings=self.share_embeddings_and_output_weights) + state_dict = self.weight_saver( + self.model, + self.hf_config, + dtype=self.param_dtype, + is_value_model=self.is_value_model, + tie_word_embeddings=self.share_embeddings_and_output_weights, + ) torch.distributed.barrier() - print(f'self.param_dtype: {self.param_dtype}') + print(f"self.param_dtype: {self.param_dtype}") for key in state_dict.keys(): - print(f'state_dict[key].dtype: {key} {state_dict[key].dtype}') + print(f"state_dict[key].dtype: {key} {state_dict[key].dtype}") torch.distributed.barrier() if self.rank == 0: hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) - from accelerate import init_empty_weights import warnings + + from accelerate import init_empty_weights + with init_empty_weights(), warnings.catch_warnings(): warnings.simplefilter("ignore") - if 'mistral7b-rm' in self.config.model.path: + if "mistral7b-rm" in self.config.model.path: from transformers import MistralForSequenceClassification + model = MistralForSequenceClassification.from_pretrained( - self.config.model.path) # use score head instead of lm_head - state_dict['score.weight'] = state_dict['score.weight'] + self.config.model.path + ) # use score head instead of lm_head + state_dict["score.weight"] = state_dict["score.weight"] else: from transformers import AutoModelForCausalLM + model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype="auto") model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict) if hdfs_path is not None: - print(f'Uploading checkpoint to {hdfs_path}') + print(f"Uploading checkpoint to {hdfs_path}") from verl.utils import hdfs_io + hdfs_io.makedirs(hdfs_path, exist_ok=True) hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True) # Save Optimizer - if 'optimizer' in self.checkpoint_contents: + if "optimizer" in self.checkpoint_contents: torch.distributed.barrier() optimizer_path = get_optimizer_checkpoint_path(local_path) @@ -297,7 +313,7 @@ class MegatronCheckpointManager(BaseCheckpointManager): print(f"saving optimizer state to {optimizer_path}") # Save RNG States - if 'extra' in self.checkpoint_contents: + if "extra" in self.checkpoint_contents: torch.distributed.barrier() rng_state_path = get_rng_states_checkpoint_path(local_path, only_rank0_save=False) diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py index 8ff29cb0a..f4fbc0a31 100644 --- a/verl/utils/dataset/multiturn_sft_dataset.py +++ b/verl/utils/dataset/multiturn_sft_dataset.py @@ -22,9 +22,8 @@ import torch from torch.utils.data import Dataset from transformers import PreTrainedTokenizer -from verl.utils.fs import copy_local_path_from_hdfs -from verl.utils.model import compute_position_id_with_mask from verl.utils 
import hf_tokenizer +from verl.utils.fs import copy_local_path_from_hdfs class MultiTurnSFTDataset(Dataset): @@ -35,13 +34,13 @@ class MultiTurnSFTDataset(Dataset): def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config=None): # Set defaults and extract parameters from config if provided config = config or {} - self.truncation = config.get('truncation', 'error') - self.max_length = config.get('max_length', 1024) + self.truncation = config.get("truncation", "error") + self.max_length = config.get("max_length", 1024) # Get messages_key from the new multiturn config structure - multiturn_config = config.get('multiturn', {}) - self.messages_key = multiturn_config.get('messages_key', 'messages') + multiturn_config = config.get("multiturn", {}) + self.messages_key = multiturn_config.get("messages_key", "messages") - assert self.truncation in ['error', 'left', 'right'] + assert self.truncation in ["error", "left", "right"] if not isinstance(parquet_files, List): parquet_files = [parquet_files] @@ -59,9 +58,10 @@ class MultiTurnSFTDataset(Dataset): self.parquet_files[i] = copy_local_path_from_hdfs(parquet_file, verbose=True) def _read_files_and_process(self): - def series_to_item(ls): - import pandas, numpy + import numpy + import pandas + while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1: ls = ls[0] return ls @@ -83,10 +83,9 @@ class MultiTurnSFTDataset(Dataset): messages = self.messages[item] # First, get the full conversation tokens - full_tokens = tokenizer.apply_chat_template(messages, - tokenize=True, - return_tensors='pt', - add_generation_prompt=False) + full_tokens = tokenizer.apply_chat_template( + messages, tokenize=True, return_tensors="pt", add_generation_prompt=False + ) input_ids = full_tokens[0] # The output is already a tensor attention_mask = torch.ones_like(input_ids) @@ -97,22 +96,26 @@ class MultiTurnSFTDataset(Dataset): current_length = 0 for i, msg in enumerate(messages): # Get tokens for messages up to this point to find the start position - prefix_messages = messages[:i + 1] - prefix_tokens = tokenizer.apply_chat_template(prefix_messages, - tokenize=True, - return_tensors='pt', - add_generation_prompt=False) + prefix_messages = messages[: i + 1] + prefix_tokens = tokenizer.apply_chat_template( + prefix_messages, tokenize=True, return_tensors="pt", add_generation_prompt=False + ) # Get tokens for messages up to previous point - prev_tokens = tokenizer.apply_chat_template( - messages[:i], tokenize=True, return_tensors='pt', add_generation_prompt=False) if i > 0 else None + prev_tokens = ( + tokenizer.apply_chat_template( + messages[:i], tokenize=True, return_tensors="pt", add_generation_prompt=False + ) + if i > 0 + else None + ) # Calculate start and end positions start_pos = prev_tokens[0].shape[0] if prev_tokens is not None else 0 end_pos = prefix_tokens[0].shape[0] # If this is an assistant message, set loss mask - if msg['role'] == 'assistant': + if msg["role"] == "assistant": loss_mask[start_pos:end_pos] = 1 # Handle sequence length @@ -120,8 +123,9 @@ class MultiTurnSFTDataset(Dataset): if sequence_length < self.max_length: # Pad sequences pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 - padded_input_ids = torch.ones(size=(self.max_length - sequence_length,), - dtype=input_ids.dtype) * pad_token_id + padded_input_ids = ( + torch.ones(size=(self.max_length - sequence_length,), dtype=input_ids.dtype) * pad_token_id + ) padded_attention_mask = 
torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype) padded_loss_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=loss_mask.dtype) @@ -129,18 +133,18 @@ class MultiTurnSFTDataset(Dataset): attention_mask = torch.cat((attention_mask, padded_attention_mask)) loss_mask = torch.cat((loss_mask, padded_loss_mask)) elif sequence_length > self.max_length: - if self.truncation == 'left': - input_ids = input_ids[-self.max_length:] - attention_mask = attention_mask[-self.max_length:] - loss_mask = loss_mask[-self.max_length:] - elif self.truncation == 'right': - input_ids = input_ids[:self.max_length] - attention_mask = attention_mask[:self.max_length] - loss_mask = loss_mask[:self.max_length] - elif self.truncation == 'error': - raise ValueError(f'{sequence_length=} is larger than {self.max_length=}') + if self.truncation == "left": + input_ids = input_ids[-self.max_length :] + attention_mask = attention_mask[-self.max_length :] + loss_mask = loss_mask[-self.max_length :] + elif self.truncation == "right": + input_ids = input_ids[: self.max_length] + attention_mask = attention_mask[: self.max_length] + loss_mask = loss_mask[: self.max_length] + elif self.truncation == "error": + raise ValueError(f"{sequence_length=} is larger than {self.max_length=}") else: - raise ValueError(f'Unknown truncation method {self.truncation}') + raise ValueError(f"Unknown truncation method {self.truncation}") # Create position IDs position_ids = torch.arange(len(input_ids), dtype=torch.long) @@ -148,8 +152,8 @@ class MultiTurnSFTDataset(Dataset): position_ids = position_ids * attention_mask return { - 'input_ids': input_ids, - 'attention_mask': attention_mask, - 'position_ids': position_ids, - 'loss_mask': loss_mask + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "loss_mask": loss_mask, } diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index 079f108d8..730631849 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
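The `MultiTurnSFTDataset` hunk above re-wraps, without changing, the prefix-tokenization loop that marks assistant tokens for the loss. As a hedged sketch of that idea (the `tokenize` callable below stands in for `tokenizer.apply_chat_template(..., tokenize=True)` and is assumed for illustration):

```python
from typing import Callable

# Sketch of the assistant-only loss mask built in the multiturn SFT dataset above.
# `tokenize` maps a message-list prefix to its token ids; it is a stand-in, not verl's API.
def build_loss_mask(messages: list[dict], tokenize: Callable[[list[dict]], list[int]]) -> list[int]:
    loss_mask = [0] * len(tokenize(messages))
    prev_len = 0
    for i, msg in enumerate(messages):
        cur_len = len(tokenize(messages[: i + 1]))
        if msg["role"] == "assistant":
            # Only tokens contributed by assistant turns receive loss.
            loss_mask[prev_len:cur_len] = [1] * (cur_len - prev_len)
        prev_len = cur_len
    return loss_mask
```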
+import copy import os import re -from typing import List, Union, Optional -import copy -import datasets from collections import defaultdict +from typing import List, Optional, Union -import torch +import datasets import numpy as np +import torch +from omegaconf import DictConfig, ListConfig from torch.utils.data import Dataset from transformers import PreTrainedTokenizer, ProcessorMixin -from omegaconf import ListConfig, DictConfig -from verl.utils.model import compute_position_id_with_mask import verl.utils.torch_functional as verl_F +from verl.utils.model import compute_position_id_with_mask def collate_fn(data_list: list[dict]) -> dict: @@ -76,8 +76,8 @@ class RLHFDataset(Dataset): self.video_key = config.get("video_key", "videos") self.max_prompt_length = config.get("max_prompt_length", 1024) - self.return_raw_chat = config.get('return_raw_chat', False) - self.truncation = config.get('truncation', 'error') + self.return_raw_chat = config.get("return_raw_chat", False) + self.truncation = config.get("truncation", "error") self.filter_overlong_prompts = config.get("filter_overlong_prompts", True) self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)) @@ -91,6 +91,7 @@ class RLHFDataset(Dataset): def _download(self, use_origin_parquet=False): from verl.utils.fs import copy_to_local + data_files = self.data_files if not use_origin_parquet else self.original_data_files for i, parquet_file in enumerate(data_files): self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir) @@ -103,28 +104,29 @@ class RLHFDataset(Dataset): dataframes.append(dataframe) self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) - print(f'dataset len: {len(self.dataframe)}') + print(f"dataset len: {len(self.dataframe)}") # filter out too long prompts if self.filter_overlong_prompts: tokenizer = self.tokenizer prompt_key = self.prompt_key self.dataframe = self.dataframe.filter( - lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True) - ) <= self.max_prompt_length, + lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) + <= self.max_prompt_length, num_proc=self.num_workers, - desc=f"Filtering prompts longer than {self.max_prompt_length} tokens") + desc=f"Filtering prompts longer than {self.max_prompt_length} tokens", + ) - print(f'filter dataset len: {len(self.dataframe)}') + print(f"filter dataset len: {len(self.dataframe)}") def resume_dataset_state(self): - self.serialize_dataset = False if hasattr(self, 'original_data_files') else True + self.serialize_dataset = False if hasattr(self, "original_data_files") else True # resume dataframe if not it's serialized in data.pt if not self.serialize_dataset: self._download(use_origin_parquet=True) # download and resume from original parquet files self._read_files_and_tokenize() else: - print(r'old dataloader ckpt file is used, please train from scratch for better ckpt performance') + print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance") def __len__(self): return len(self.dataframe) @@ -189,16 +191,18 @@ class RLHFDataset(Dataset): else: raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) - model_inputs = self.tokenizer(raw_prompt, return_tensors='pt', add_special_tokens=False) - input_ids = model_inputs.pop('input_ids') - attention_mask = model_inputs.pop('attention_mask') + model_inputs = self.tokenizer(raw_prompt, 
return_tensors="pt", add_special_tokens=False) + input_ids = model_inputs.pop("input_ids") + attention_mask = model_inputs.pop("attention_mask") - input_ids, attention_mask = verl_F.postprocess_data(input_ids=input_ids, - attention_mask=attention_mask, - max_length=self.max_prompt_length, - pad_token_id=self.tokenizer.pad_token_id, - left_pad=True, - truncation=self.truncation) + input_ids, attention_mask = verl_F.postprocess_data( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=self.max_prompt_length, + pad_token_id=self.tokenizer.pad_token_id, + left_pad=True, + truncation=self.truncation, + ) if self.processor is not None and self.processor.image_processor.__class__.__name__ == "Qwen2VLImageProcessor": from verl.models.transformers.qwen2_vl import get_rope_index @@ -217,23 +221,23 @@ class RLHFDataset(Dataset): else: position_ids = compute_position_id_with_mask(attention_mask) - row_dict['input_ids'] = input_ids[0] - row_dict['attention_mask'] = attention_mask[0] - row_dict['position_ids'] = position_ids[0] + row_dict["input_ids"] = input_ids[0] + row_dict["attention_mask"] = attention_mask[0] + row_dict["position_ids"] = position_ids[0] raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False) if len(raw_prompt_ids) > self.max_prompt_length: if self.truncation == "left": - raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length:] + raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :] elif self.truncation == "right": - raw_prompt_ids = raw_prompt_ids[:self.max_prompt_length] + raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length] elif self.truncation == "error": raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.") - row_dict['raw_prompt_ids'] = raw_prompt_ids + row_dict["raw_prompt_ids"] = raw_prompt_ids # encode prompts without chat template if self.return_raw_chat: - row_dict['raw_prompt'] = messages + row_dict["raw_prompt"] = messages # add index for each prompt index = row_dict.get("extra_info", {}).get("index", 0) @@ -245,8 +249,8 @@ class RLHFDataset(Dataset): if not self.serialize_dataset: state = self.__dict__.copy() - if 'dataframe' in state: - del state['dataframe'] + if "dataframe" in state: + del state["dataframe"] return state return self.__dict__.copy() diff --git a/verl/utils/dataset/rm_dataset.py b/verl/utils/dataset/rm_dataset.py index be137895b..b8ebb5ea0 100644 --- a/verl/utils/dataset/rm_dataset.py +++ b/verl/utils/dataset/rm_dataset.py @@ -16,16 +16,15 @@ import os from typing import List, Union import pandas as pd - import torch from torch.utils.data import Dataset -from transformers import AutoTokenizer from verl.utils import hf_tokenizer def download_files_distributed(download_fn): import torch.distributed + if torch.distributed.is_initialized(): if torch.distributed.get_rank() == 0: # download files @@ -38,16 +37,17 @@ def download_files_distributed(download_fn): class RMDataset(Dataset): - - def __init__(self, - parquet_files: Union[str, List[str]], - tokenizer, - prompt_key='prompt', - chosen_key='chosen', - rejected_key='rejected', - max_length=1024, - add_eos=True, - cache_dir='~/.cache/verl/rm'): + def __init__( + self, + parquet_files: Union[str, List[str]], + tokenizer, + prompt_key="prompt", + chosen_key="chosen", + rejected_key="rejected", + max_length=1024, + add_eos=True, + cache_dir="~/.cache/verl/rm", + ): if not isinstance(parquet_files, List): parquet_files = [parquet_files] @@ -68,9 +68,9 @@ class RMDataset(Dataset): self._read_files_and_tokenize() 
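Several dataset hunks in this patch re-wrap the same `left` / `right` / `error` truncation convention without changing it. A minimal hedged sketch of that convention, generic over any token list rather than tied to verl's dataset classes:

```python
# Sketch of the shared truncation convention ("left" / "right" / "error") seen in the
# dataset hunks above; `tokens` is any list of token ids, `max_length` the budget.
def truncate(tokens: list[int], max_length: int, truncation: str = "error") -> list[int]:
    if len(tokens) <= max_length:
        return tokens
    if truncation == "left":
        return tokens[-max_length:]   # keep the most recent tokens
    if truncation == "right":
        return tokens[:max_length]    # keep the beginning of the sequence
    if truncation == "error":
        raise ValueError(f"sequence length {len(tokens)} is larger than {max_length}")
    raise ValueError(f"Unknown truncation method {truncation}")
```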
def _download(self): - def _download_files(): from verl.utils.fs import copy, is_non_local + os.makedirs(self.cache_dir, exist_ok=True) assert os.path.exists(self.cache_dir) for i, parquet_file in enumerate(self.parquet_files): @@ -101,13 +101,14 @@ class RMDataset(Dataset): if curr_length < self.max_length: input_ids = torch.cat( - (input_ids, torch.zeros(size=(self.max_length - curr_length,), dtype=input_ids.dtype)), dim=-1) + (input_ids, torch.zeros(size=(self.max_length - curr_length,), dtype=input_ids.dtype)), dim=-1 + ) attention_mask = torch.cat( - (attention_mask, torch.zeros(size=(self.max_length - curr_length,), dtype=attention_mask.dtype)), - dim=-1) + (attention_mask, torch.zeros(size=(self.max_length - curr_length,), dtype=attention_mask.dtype)), dim=-1 + ) elif curr_length > self.max_length: - input_ids = input_ids[:self.max_length] - attention_mask = attention_mask[:self.max_length] + input_ids = input_ids[: self.max_length] + attention_mask = attention_mask[: self.max_length] return input_ids, attention_mask @@ -116,14 +117,15 @@ class RMDataset(Dataset): chosen_response = self.chosen_responses[item] rejected_response = self.rejected_responses[item] - prompt_ids = self.tokenizer(prompt, return_tensors='pt')['input_ids'][0] - chosen_response_ids = self.tokenizer(chosen_response, return_tensors='pt')['input_ids'][0] - rejected_response_ids = self.tokenizer(rejected_response, return_tensors='pt')['input_ids'][0] + prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0] + chosen_response_ids = self.tokenizer(chosen_response, return_tensors="pt")["input_ids"][0] + rejected_response_ids = self.tokenizer(rejected_response, return_tensors="pt")["input_ids"][0] if self.add_eos: chosen_response_ids = torch.cat((chosen_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1) - rejected_response_ids = torch.cat((rejected_response_ids, torch.tensor([self.tokenizer.eos_token_id])), - dim=-1) + rejected_response_ids = torch.cat( + (rejected_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1 + ) chosen_input_ids = torch.cat((prompt_ids, chosen_response_ids), dim=-1) chosen_attention_mask = torch.ones_like(chosen_input_ids) @@ -138,6 +140,6 @@ class RMDataset(Dataset): attention_mask = torch.stack((rejected_input_ids, rejected_attention_mask), dim=0) return { - 'input_ids': input_ids, - 'attention_mask': attention_mask, - } \ No newline at end of file + "input_ids": input_ids, + "attention_mask": attention_mask, + } diff --git a/verl/utils/dataset/sft_dataset.py b/verl/utils/dataset/sft_dataset.py index 7407239df..082a19e10 100644 --- a/verl/utils/dataset/sft_dataset.py +++ b/verl/utils/dataset/sft_dataset.py @@ -21,14 +21,13 @@ Each parquet file contains from typing import List, Union import pandas as pd - import torch from torch.utils.data import Dataset -from transformers import AutoTokenizer, PreTrainedTokenizer +from transformers import PreTrainedTokenizer +from verl.utils import hf_tokenizer from verl.utils.fs import copy_to_local from verl.utils.model import compute_position_id_with_mask -from verl.utils import hf_tokenizer class SFTDataset(Dataset): @@ -40,15 +39,14 @@ class SFTDataset(Dataset): """ def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config): + prompt_key = config.get("prompt_key", "prompt") + prompt_dict_keys = config.get("prompt_dict_keys", None) + response_key = config.get("response_key", "response") + response_dict_keys = config.get("response_dict_keys", None) + max_length = config.get("max_length", 
1024) + truncation = config.get("truncation", "error") - prompt_key = config.get('prompt_key', 'prompt') - prompt_dict_keys = config.get('prompt_dict_keys', None) - response_key = config.get('response_key', 'response') - response_dict_keys = config.get('response_dict_keys', None) - max_length = config.get('max_length', 1024) - truncation = config.get('truncation', 'error') - - assert truncation in ['error', 'left', 'right'] + assert truncation in ["error", "left", "right"] self.truncation = truncation if not isinstance(parquet_files, List): @@ -74,9 +72,10 @@ class SFTDataset(Dataset): self.parquet_files[i] = copy_to_local(parquet_file, verbose=True) def _read_files_and_tokenize(self): - def series_to_item(ls): - import pandas, numpy + import numpy + import pandas + while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1: ls = ls[0] return ls @@ -95,7 +94,7 @@ class SFTDataset(Dataset): try: self.prompts = self.prompts.apply(lambda x: series_to_item(x)[key], axis=1) except Exception: - print(f'self.prompts={self.prompts}') + print(f"self.prompts={self.prompts}") raise self.prompts = self.prompts.tolist() self.responses = self.dataframe[self.response_key] @@ -103,7 +102,7 @@ class SFTDataset(Dataset): try: self.responses = self.responses.apply(lambda x: series_to_item(x)[key], axis=1) except Exception: - print(f'self.responses={self.responses}') + print(f"self.responses={self.responses}") raise self.responses = self.responses.tolist() @@ -117,20 +116,20 @@ class SFTDataset(Dataset): response = self.responses[item] # apply chat template - prompt_chat = [{'role': 'user', 'content': prompt}] + prompt_chat = [{"role": "user", "content": prompt}] # string prompt_chat_str = tokenizer.apply_chat_template(prompt_chat, add_generation_prompt=True, tokenize=False) response_chat_str = response + tokenizer.eos_token # tokenize - prompt_ids_output = tokenizer(prompt_chat_str, return_tensors='pt', add_special_tokens=False) - prompt_ids = prompt_ids_output['input_ids'][0] - prompt_attention_mask = prompt_ids_output['attention_mask'][0] + prompt_ids_output = tokenizer(prompt_chat_str, return_tensors="pt", add_special_tokens=False) + prompt_ids = prompt_ids_output["input_ids"][0] + prompt_attention_mask = prompt_ids_output["attention_mask"][0] - response_ids_output = tokenizer(response_chat_str, return_tensors='pt', add_special_tokens=False) - response_ids = response_ids_output['input_ids'][0] - response_attention_mask = response_ids_output['attention_mask'][0] + response_ids_output = tokenizer(response_chat_str, return_tensors="pt", add_special_tokens=False) + response_ids = response_ids_output["input_ids"][0] + response_attention_mask = response_ids_output["attention_mask"][0] prompt_length = prompt_ids.shape[0] response_length = response_ids.shape[0] @@ -141,37 +140,39 @@ class SFTDataset(Dataset): # padding to max length sequence_length = input_ids.shape[0] if sequence_length < self.max_length: - padded_input_ids = torch.ones(size=(self.max_length - sequence_length,), - dtype=input_ids.dtype) * self.tokenizer.pad_token_id + padded_input_ids = ( + torch.ones(size=(self.max_length - sequence_length,), dtype=input_ids.dtype) + * self.tokenizer.pad_token_id + ) padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype) input_ids = torch.cat((input_ids, padded_input_ids)) attention_mask = torch.cat((attention_mask, padded_attention_mask)) elif sequence_length > self.max_length: - if self.truncation == 'left': + if self.truncation == 
"left": # actually, left truncation may not be reasonable - input_ids = input_ids[-self.max_length:] - attention_mask = attention_mask[-self.max_length:] - elif self.truncation == 'right': - input_ids = input_ids[:self.max_length] - attention_mask = attention_mask[:self.max_length] - elif self.truncation == 'error': - raise NotImplementedError(f'{sequence_length=} is larger than {self.max_length=}') + input_ids = input_ids[-self.max_length :] + attention_mask = attention_mask[-self.max_length :] + elif self.truncation == "right": + input_ids = input_ids[: self.max_length] + attention_mask = attention_mask[: self.max_length] + elif self.truncation == "error": + raise NotImplementedError(f"{sequence_length=} is larger than {self.max_length=}") else: - raise NotImplementedError(f'Unknown truncation method {self.truncation}') + raise NotImplementedError(f"Unknown truncation method {self.truncation}") position_ids = compute_position_id_with_mask(attention_mask) loss_mask = attention_mask.clone() if prompt_length > 1: # mask out prompt for SFT. - loss_mask[:min(prompt_length, loss_mask.size(0)) - 1] = 0 + loss_mask[: min(prompt_length, loss_mask.size(0)) - 1] = 0 # mask out the last token in response loss_mask[min(prompt_length + response_length, loss_mask.size(0)) - 1] = 0 return { - 'input_ids': input_ids, - 'attention_mask': attention_mask, - 'position_ids': position_ids, - 'loss_mask': loss_mask + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "loss_mask": loss_mask, } diff --git a/verl/utils/debug/__init__.py b/verl/utils/debug/__init__.py index 0d0b3432e..13e712374 100644 --- a/verl/utils/debug/__init__.py +++ b/verl/utils/debug/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .performance import log_gpu_memory_usage \ No newline at end of file +from .performance import log_gpu_memory_usage diff --git a/verl/utils/debug/performance.py b/verl/utils/debug/performance.py index 615475a66..ca0a44f7f 100644 --- a/verl/utils/debug/performance.py +++ b/verl/utils/debug/performance.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + import torch import torch.distributed as dist -import logging def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0): @@ -22,7 +23,7 @@ def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging memory_allocated = torch.cuda.memory_allocated() / 1024**3 memory_reserved = torch.cuda.memory_reserved() / 1024**3 - message = f'{head}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}' + message = f"{head}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}" if logger is None: print(message) diff --git a/verl/utils/debug/trajectory_tracker.py b/verl/utils/debug/trajectory_tracker.py index 33b254685..73afb8540 100644 --- a/verl/utils/debug/trajectory_tracker.py +++ b/verl/utils/debug/trajectory_tracker.py @@ -17,29 +17,30 @@ The results will be dump to hdfs for offline comparison. 
Each process will have a client that first move all the tensors to CPU """ -from verl.utils.hdfs_io import makedirs, copy -import torch -import os -import ray import io +import os import tempfile - from collections import deque +import ray +import torch + +from verl.utils.hdfs_io import copy, makedirs + remote_copy = ray.remote(copy) @ray.remote def save_to_hdfs(data: io.BytesIO, name, hdfs_dir, verbose): - filename = name + '.pth' + filename = name + ".pth" with tempfile.TemporaryDirectory() as tmpdirname: local_filepath = os.path.join(tmpdirname, filename) - with open(local_filepath, 'wb') as f: + with open(local_filepath, "wb") as f: f.write(data.getbuffer()) # upload to hdfs if verbose: - print(f'Saving {local_filepath} to {hdfs_dir}') + print(f"Saving {local_filepath} to {hdfs_dir}") try: copy(local_filepath, hdfs_dir) except Exception as e: @@ -47,8 +48,7 @@ def save_to_hdfs(data: io.BytesIO, name, hdfs_dir, verbose): @ray.remote -class TrajectoryTracker(): - +class TrajectoryTracker: def __init__(self, hdfs_dir, verbose) -> None: self.hdfs_dir = hdfs_dir makedirs(hdfs_dir) @@ -67,7 +67,7 @@ class TrajectoryTracker(): def dump_data(data, name): - enable = os.getenv('VERL_ENABLE_TRACKER', '0') == '1' + enable = os.getenv("VERL_ENABLE_TRACKER", "0") == "1" if not enable: return buffer = io.BytesIO() @@ -77,23 +77,24 @@ def dump_data(data, name): def get_trajectory_tracker(): - hdfs_dir = os.getenv('VERL_TRACKER_HDFS_DIR', default=None) - verbose = os.getenv('VERL_TRACKER_VERBOSE', default='0') == '1' + hdfs_dir = os.getenv("VERL_TRACKER_HDFS_DIR", default=None) + verbose = os.getenv("VERL_TRACKER_VERBOSE", default="0") == "1" assert hdfs_dir is not None - tracker = TrajectoryTracker.options(name="global_tracker", get_if_exists=True, - lifetime="detached").remote(hdfs_dir, verbose) + tracker = TrajectoryTracker.options(name="global_tracker", get_if_exists=True, lifetime="detached").remote( + hdfs_dir, verbose + ) return tracker -if __name__ == '__main__': +if __name__ == "__main__": # testing - os.environ['VERL_ENABLE_TRACKER'] = '1' - os.environ['VERL_TRACKER_HDFS_DIR'] = '~/debug/test' + os.environ["VERL_ENABLE_TRACKER"] = "1" + os.environ["VERL_TRACKER_HDFS_DIR"] = "~/debug/test" @ray.remote def process(iter): - data = {'obs': torch.randn(10, 20)} - dump_data(data, f'process_{iter}_obs') + data = {"obs": torch.randn(10, 20)} + dump_data(data, f"process_{iter}_obs") ray.init() diff --git a/verl/utils/distributed.py b/verl/utils/distributed.py index 6fea5a29c..7aa30c168 100644 --- a/verl/utils/distributed.py +++ b/verl/utils/distributed.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
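The trajectory-tracker hunk above re-wraps the creation of a detached, named Ray actor. A minimal sketch of that Ray pattern under illustrative names (the `Collector` actor below is not part of verl):

```python
import ray

# Minimal sketch of the detached, named actor pattern used by the tracker above.
@ray.remote
class Collector:
    def __init__(self) -> None:
        self.items = []

    def add(self, item) -> int:
        self.items.append(item)
        return len(self.items)


ray.init(ignore_reinit_error=True)

# get_if_exists=True returns the already-registered actor if one with this name exists;
# lifetime="detached" keeps the actor alive after the creating job exits.
collector = Collector.options(name="global_collector", get_if_exists=True, lifetime="detached").remote()
print(ray.get(collector.add.remote({"step": 1})))
```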
"""Utilities for distributed training.""" + import os def initialize_global_process_group(timeout_second=36000): - import torch.distributed from datetime import timedelta - torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second)) + + import torch.distributed + + torch.distributed.init_process_group("nccl", timeout=timedelta(seconds=timeout_second)) local_rank = int(os.environ["LOCAL_RANK"]) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) diff --git a/verl/utils/flops_counter.py b/verl/utils/flops_counter.py index 9bcebc851..4fb11943c 100644 --- a/verl/utils/flops_counter.py +++ b/verl/utils/flops_counter.py @@ -19,7 +19,6 @@ VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "deepseek_v3"} def get_device_flops(unit="T"): - def unit_convert(number, level): units = ["B", "K", "M", "G", "T", "P"] if number <= 0: @@ -62,16 +61,18 @@ class FlopsCounter: """ def __init__(self, config: PretrainedConfig): - if not config.model_type in VALID_CONFIG_TYPE: - print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {config.model_type}. " - f"MFU will always be zero.") + if config.model_type not in VALID_CONFIG_TYPE: + print( + f"Only support config type of {VALID_CONFIG_TYPE}, but got {config.model_type}. " + f"MFU will always be zero." + ) self.estimate_func = { - 'qwen2': self._estimate_qwen2_flops, - 'llama': self._estimate_qwen2_flops, - 'qwen2_vl': self._estimate_qwen2_flops, - 'qwen2_5_vl': self._estimate_qwen2_flops, - 'deepseek_v3': self._estimate_deepseek_v3_flops, + "qwen2": self._estimate_qwen2_flops, + "llama": self._estimate_qwen2_flops, + "qwen2_vl": self._estimate_qwen2_flops, + "qwen2_5_vl": self._estimate_qwen2_flops, + "deepseek_v3": self._estimate_deepseek_v3_flops, } self.config = config @@ -138,14 +139,19 @@ class FlopsCounter: attn_linear_N += num_query_heads * q_head_dim * self.config.q_lora_rank attn_linear_N += hidden_size * (self.config.kv_lora_rank + self.config.qk_rope_head_dim) - attn_linear_N += (num_query_heads * (q_head_dim - self.config.qk_rope_head_dim + self.config.v_head_dim) * - self.config.kv_lora_rank) + attn_linear_N += ( + num_query_heads + * (q_head_dim - self.config.qk_rope_head_dim + self.config.v_head_dim) + * self.config.kv_lora_rank + ) attn_linear_N += num_query_heads * self.config.v_head_dim * hidden_size emd_and_lm_head_N = vocab_size * hidden_size * 2 # non-attn all_layer parm - moe_N = ((moe_gata_N + moe_expertmlp_N + attn_linear_N) * (num_hidden_layers - first_k_dense_replace) + - (hidden_size * self.config.intermediate_size * 3 + attn_linear_N) * first_k_dense_replace + - emd_and_lm_head_N) + moe_N = ( + (moe_gata_N + moe_expertmlp_N + attn_linear_N) * (num_hidden_layers - first_k_dense_replace) + + (hidden_size * self.config.intermediate_size * 3 + attn_linear_N) * first_k_dense_replace + + emd_and_lm_head_N + ) # non-attn all_layer & all_token fwd & bwd flops dense_N_flops = 6 * moe_N * tokens_sum diff --git a/verl/utils/fs.py b/verl/utils/fs.py index 769b7e48e..52e5c6961 100644 --- a/verl/utils/fs.py +++ b/verl/utils/fs.py @@ -15,14 +15,15 @@ # -*- coding: utf-8 -*- """File-system agnostic IO APIs""" + +import hashlib import os import tempfile -import hashlib try: - from hdfs_io import copy, makedirs, exists # for internal use only + from hdfs_io import copy, exists, makedirs # for internal use only except ImportError: - from .hdfs_io import copy, makedirs, exists + from .hdfs_io import copy, exists, makedirs __all__ = ["copy", "exists", "makedirs"] @@ -55,7 +56,7 @@ 
def get_local_temp_path(hdfs_path: str, cache_dir: str) -> str: return dst -def copy_to_local(src: str, cache_dir=None, filelock='.file.lock', verbose=False) -> str: +def copy_to_local(src: str, cache_dir=None, filelock=".file.lock", verbose=False) -> str: """Copy src from hdfs to local if src is on hdfs or directly return src. If cache_dir is None, we will use the default cache dir of the system. Note that this may cause conflicts if the src name is the same between calls @@ -69,11 +70,11 @@ def copy_to_local(src: str, cache_dir=None, filelock='.file.lock', verbose=False return copy_local_path_from_hdfs(src, cache_dir, filelock, verbose) -def copy_local_path_from_hdfs(src: str, cache_dir=None, filelock='.file.lock', verbose=False) -> str: +def copy_local_path_from_hdfs(src: str, cache_dir=None, filelock=".file.lock", verbose=False) -> str: """Deprecated. Please use copy_to_local instead.""" from filelock import FileLock - assert src[-1] != '/', f'Make sure the last char in src is not / because it will cause error. Got {src}' + assert src[-1] != "/", f"Make sure the last char in src is not / because it will cause error. Got {src}" if is_non_local(src): # download from hdfs to local @@ -84,12 +85,12 @@ def copy_local_path_from_hdfs(src: str, cache_dir=None, filelock='.file.lock', v assert os.path.exists(cache_dir) local_path = get_local_temp_path(src, cache_dir) # get a specific lock - filelock = md5_encode(src) + '.lock' + filelock = md5_encode(src) + ".lock" lock_file = os.path.join(cache_dir, filelock) with FileLock(lock_file=lock_file): if not os.path.exists(local_path): if verbose: - print(f'Copy from {src} to {local_path}') + print(f"Copy from {src} to {local_path}") copy(src, local_path) return local_path else: diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py index 4929c6764..af3694b44 100644 --- a/verl/utils/fsdp_utils.py +++ b/verl/utils/fsdp_utils.py @@ -12,21 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict import functools +import itertools import json import math -import itertools import os from contextlib import contextmanager +from typing import Dict + +import torch +import torch.distributed as dist +import torch.nn as nn from torch.distributed import DeviceMesh -from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp._runtime_utils import _lazy_init +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy from transformers.trainer_pt_utils import get_module_class_from_name -import torch -import torch.nn as nn -import torch.distributed as dist def init_fn(x: torch.nn.Module): @@ -38,7 +39,8 @@ def init_fn(x: torch.nn.Module): def get_init_weight_context_manager(use_meta_tensor=True, mesh: DeviceMesh = None): from accelerate import init_empty_weights - cpu_init_weights = lambda: torch.device('cpu') + + cpu_init_weights = lambda: torch.device("cpu") if use_meta_tensor: if mesh is None: init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights @@ -53,7 +55,7 @@ def get_init_weight_context_manager(use_meta_tensor=True, mesh: DeviceMesh = Non # Adapted from https://github.com/huggingface/transformers/src/transformers/trainer.py def get_fsdp_wrap_policy(module, config=None, is_lora=False): """Get FSDP wrap policy for the module. 
- + Args: module: The module to get wrap policy for config: Configuration for wrap policy @@ -62,25 +64,29 @@ def get_fsdp_wrap_policy(module, config=None, is_lora=False): if config is None: config = {} - if config.get('disable', False): + if config.get("disable", False): return None default_transformer_cls_names_to_wrap = getattr(module, "_no_split_modules", None) - fsdp_transformer_layer_cls_to_wrap = config.get("transformer_layer_cls_to_wrap", - default_transformer_cls_names_to_wrap) - min_num_params = config.get('min_num_params', 0) + fsdp_transformer_layer_cls_to_wrap = config.get( + "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap + ) + min_num_params = config.get("min_num_params", 0) auto_wrap_policy = None policies = [] - from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy + from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy # Add lambda policy for LoRA modules if is_lora is True if is_lora: def lambda_policy_fn(module): - if (len(list(module.named_children())) == 0 and getattr(module, "weight", None) is not None and - module.weight.requires_grad): + if ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ): return True return False @@ -116,14 +122,16 @@ def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True): assert isinstance(model, FSDP) # lazy init FSDP model _lazy_init(model, model) - assert model._is_root, f"Only support root model offloading to CPU" + assert model._is_root, "Only support root model offloading to CPU" for handle in model._all_handles: if handle._offload_params: continue flat_param = handle.flat_param - assert flat_param.data.data_ptr() == flat_param._local_shard.data_ptr() and \ - id(flat_param.data) != id(flat_param._local_shard) and \ - flat_param.data.size() == flat_param._local_shard.size() + assert ( + flat_param.data.data_ptr() == flat_param._local_shard.data_ptr() + and id(flat_param.data) != id(flat_param._local_shard) + and flat_param.data.size() == flat_param._local_shard.size() + ) handle.flat_param_to(torch.device("cpu"), non_blocking=True) # the following still keeps id(._local_shard) != id(.data) flat_param._local_shard = flat_param.data @@ -137,7 +145,7 @@ def load_fsdp_model_to_gpu(model: FSDP): assert isinstance(model, FSDP) # lazy init FSDP model _lazy_init(model, model) - assert model._is_root, f"Only support root model loading to GPU" + assert model._is_root, "Only support root model loading to GPU" device_id = torch.cuda.current_device() for handle in model._all_handles: if handle._offload_params: @@ -153,7 +161,7 @@ def offload_fsdp_optimizer(optimizer): if not optimizer.state: return for param_group in optimizer.param_groups: - for param in param_group['params']: + for param in param_group["params"]: state = optimizer.state[param] for key, value in state.items(): if isinstance(value, torch.Tensor): @@ -165,7 +173,7 @@ def load_fsdp_optimizer(optimizer, device_id): if not optimizer.state: return for param_group in optimizer.param_groups: - for param in param_group['params']: + for param in param_group["params"]: state = optimizer.state[param] for key, value in state.items(): if isinstance(value, torch.Tensor): @@ -242,7 +250,7 @@ def parallel_load_safetensors(filepath): ckpt_chunks = sorted(safetensors2param.keys()) world_size = dist.get_world_size() size = int(math.ceil(total_files / world_size)) - ckpt_chunks = [ckpt_chunks[rank * size:rank * size 
+ size] for rank in range(world_size)] + ckpt_chunks = [ckpt_chunks[rank * size : rank * size + size] for rank in range(world_size)] shard_states = {} device = torch.cuda.current_device() @@ -274,8 +282,9 @@ def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, tor """ state2fqn = {} - for name, state in itertools.chain(module.named_parameters(remove_duplicate=False), - module.named_buffers(remove_duplicate=False)): + for name, state in itertools.chain( + module.named_parameters(remove_duplicate=False), module.named_buffers(remove_duplicate=False) + ): state2fqn.setdefault(state, []).append(name) # remove standalone parameters and buffers shared = {s for s, names in state2fqn.items() if len(names) > 1} @@ -314,7 +323,8 @@ def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, tor if state.is_meta: raise RuntimeError( f"find a non-persistent buffer ({fqn}) initiated with device meta. " - "Such buffer is not saved in checkpoint and user should guarantee to init in CPU / GPU device.") + "Such buffer is not saved in checkpoint and user should guarantee to init in CPU / GPU device." + ) continue # for shared parameter, we get it from the first time it is created if state in shared: diff --git a/verl/utils/hdfs_io.py b/verl/utils/hdfs_io.py index 08c4ecb9a..31edda1f6 100644 --- a/verl/utils/hdfs_io.py +++ b/verl/utils/hdfs_io.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import shutil -import logging logger = logging.getLogger(__file__) -logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN')) +logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) _HDFS_PREFIX = "hdfs://" -_HDFS_BIN_PATH = shutil.which('hdfs') +_HDFS_BIN_PATH = shutil.which("hdfs") def exists(path: str, **kwargs) -> bool: @@ -41,7 +41,7 @@ def exists(path: str, **kwargs) -> bool: def _exists(file_path: str): - """ hdfs capable to check whether a file_path is exists """ + """hdfs capable to check whether a file_path is exists""" if file_path.startswith("hdfs"): return _run_cmd(_hdfs_cmd(f"-test -e {file_path}")) == 0 return os.path.exists(file_path) @@ -118,8 +118,13 @@ def _copy(from_path: str, to_path: str, timeout: int = None) -> bool: returncode = _run_cmd(_hdfs_cmd(f"-put -f {from_path} {to_path}"), timeout=timeout) else: if from_path.startswith("hdfs"): - returncode = _run_cmd(_hdfs_cmd(f"-get \ - {from_path} {to_path}"), timeout=timeout) + returncode = _run_cmd( + _hdfs_cmd( + f"-get \ + {from_path} {to_path}" + ), + timeout=timeout, + ) else: try: shutil.copy(from_path, to_path) diff --git a/verl/utils/import_utils.py b/verl/utils/import_utils.py index 740affe3a..fa21887d4 100644 --- a/verl/utils/import_utils.py +++ b/verl/utils/import_utils.py @@ -16,15 +16,15 @@ Utilities to check if packages are available. We assume package availability won't change during runtime. 
""" +import importlib from functools import cache from typing import List, Optional -import importlib @cache def is_megatron_core_available(): try: - mcore_spec = importlib.util.find_spec('megatron.core') + mcore_spec = importlib.util.find_spec("megatron.core") except ModuleNotFoundError: mcore_spec = None return mcore_spec is not None @@ -33,7 +33,7 @@ def is_megatron_core_available(): @cache def is_vllm_available(): try: - vllm_spec = importlib.util.find_spec('vllm') + vllm_spec = importlib.util.find_spec("vllm") except ModuleNotFoundError: vllm_spec = None return vllm_spec is not None @@ -42,7 +42,7 @@ def is_vllm_available(): @cache def is_sglang_available(): try: - sglang_spec = importlib.util.find_spec('sglang') + sglang_spec = importlib.util.find_spec("sglang") except ModuleNotFoundError: sglang_spec = None return sglang_spec is not None @@ -54,13 +54,15 @@ def import_external_libs(external_libs=None): if not isinstance(external_libs, List): external_libs = [external_libs] import importlib + for external_lib in external_libs: importlib.import_module(external_lib) def load_extern_type(file_path: Optional[str], type_name: Optional[str]): """Load a external data type based on the file path and type name""" - import importlib.util, os + import importlib.util + import os if not file_path: return None @@ -78,4 +80,4 @@ def load_extern_type(file_path: Optional[str], type_name: Optional[str]): if not hasattr(module, type_name): raise AttributeError(f"Custom type '{type_name}' not found in '{file_path}'.") - return getattr(module, type_name) \ No newline at end of file + return getattr(module, type_name) diff --git a/verl/utils/logger/aggregate_logger.py b/verl/utils/logger/aggregate_logger.py index ac57cf58e..ec83edb17 100644 --- a/verl/utils/logger/aggregate_logger.py +++ b/verl/utils/logger/aggregate_logger.py @@ -14,29 +14,29 @@ """ A Ray logger will receive logging info from different processes. """ + import numbers from typing import Dict def concat_dict_to_str(dict: Dict, step): - output = [f'step:{step}'] + output = [f"step:{step}"] for k, v in dict.items(): if isinstance(v, numbers.Number): - output.append(f'{k}:{v:.3f}') - output_str = ' - '.join(output) + output.append(f"{k}:{v:.3f}") + output_str = " - ".join(output) return output_str class LocalLogger: - def __init__(self, remote_logger=None, enable_wandb=False, print_to_console=False): self.print_to_console = print_to_console if print_to_console: - print('Using LocalLogger is deprecated. The constructor API will change ') + print("Using LocalLogger is deprecated. The constructor API will change ") def flush(self): pass def log(self, data, step): if self.print_to_console: - print(concat_dict_to_str(data, step=step), flush=True) \ No newline at end of file + print(concat_dict_to_str(data, step=step), flush=True) diff --git a/verl/utils/logging_utils.py b/verl/utils/logging_utils.py index 329835347..13fa9170b 100644 --- a/verl/utils/logging_utils.py +++ b/verl/utils/logging_utils.py @@ -14,6 +14,7 @@ import logging import os + import torch @@ -21,11 +22,11 @@ def set_basic_config(level): """ This function sets the global logging format and level. 
It will be called when import verl """ - logging.basicConfig(format='%(levelname)s:%(asctime)s:%(message)s', level=level) + logging.basicConfig(format="%(levelname)s:%(asctime)s:%(message)s", level=level) def log_to_file(string): print(string) - if os.path.isdir('logs'): - with open(f'logs/log_{torch.distributed.get_rank()}', 'a+') as f: - f.write(string + '\n') + if os.path.isdir("logs"): + with open(f"logs/log_{torch.distributed.get_rank()}", "a+") as f: + f.write(string + "\n") diff --git a/verl/utils/megatron/memory.py b/verl/utils/megatron/memory.py index 5e8570ed4..ee912e878 100644 --- a/verl/utils/megatron/memory.py +++ b/verl/utils/megatron/memory.py @@ -16,15 +16,13 @@ import torch class MemoryBuffer: - def __init__(self, numel, numel_padded, dtype): self.numel = numel self.numel_padded = numel_padded self.dtype = dtype - self.data = torch.zeros(self.numel_padded, - dtype=self.dtype, - device=torch.cuda.current_device(), - requires_grad=False) + self.data = torch.zeros( + self.numel_padded, dtype=self.dtype, device=torch.cuda.current_device(), requires_grad=False + ) def zero(self): """Reset the buffer to zero.""" @@ -34,8 +32,7 @@ class MemoryBuffer: """Return a tensor with the input `shape` as a view into the 1-D data starting at `start_index`.""" end_index = start_index + shape.numel() - assert end_index <= self.numel, \ - 'requested tensor is out of the buffer range.' + assert end_index <= self.numel, "requested tensor is out of the buffer range." buffer_tensor = self.data[start_index:end_index] buffer_tensor = buffer_tensor.view(shape) return buffer_tensor diff --git a/verl/utils/megatron/optimizer.py b/verl/utils/megatron/optimizer.py index 1d936d9ed..eacdbc36d 100644 --- a/verl/utils/megatron/optimizer.py +++ b/verl/utils/megatron/optimizer.py @@ -13,29 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -import importlib -from packaging.version import Version - -from apex.optimizers import FusedAdam as Adam -from apex.optimizers import FusedSGD as SGD from megatron.core.optimizer import OptimizerConfig - from megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native def get_megatron_optimizer( - model, - config: OptimizerConfig, - no_weight_decay_cond=None, - scale_lr_cond=None, - lr_mult=1.0, - check_for_nan_in_loss_and_grad=False, - overlap_param_gather=False # add for verl + model, + config: OptimizerConfig, + no_weight_decay_cond=None, + scale_lr_cond=None, + lr_mult=1.0, + check_for_nan_in_loss_and_grad=False, + overlap_param_gather=False, # add for verl ): # Base optimizer. 
- return get_megatron_optimizer_native(config=config, - model_chunks=model, - no_weight_decay_cond=no_weight_decay_cond, - scale_lr_cond=scale_lr_cond, - lr_mult=lr_mult) + return get_megatron_optimizer_native( + config=config, + model_chunks=model, + no_weight_decay_cond=no_weight_decay_cond, + scale_lr_cond=scale_lr_cond, + lr_mult=lr_mult, + ) diff --git a/verl/utils/megatron/pipeline_parallel.py b/verl/utils/megatron/pipeline_parallel.py index 3a3790bb1..b7e272763 100644 --- a/verl/utils/megatron/pipeline_parallel.py +++ b/verl/utils/megatron/pipeline_parallel.py @@ -21,22 +21,28 @@ from .sequence_parallel import pad_to_sequence_parallel def compute_transformers_input_shapes(batches, meta_info): from flash_attn.bert_padding import unpad_input # flash 2 is a must for Megatron + # pre-compute input shapes for each micro-batch at each pp stage input_shapes = [] for model_inputs in batches: - input_ids = model_inputs['input_ids'] - attention_mask = model_inputs['attention_mask'] + input_ids = model_inputs["input_ids"] + attention_mask = model_inputs["attention_mask"] input_ids_rmpad = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)[0] # (total_nnz, 1) - if meta_info['sequence_parallel']: + if meta_info["sequence_parallel"]: input_ids_rmpad = pad_to_sequence_parallel(input_ids_rmpad) # compute shapes for model_inputs input_shapes.append( - torch.Size([ - input_ids_rmpad.shape[0] // mpu.get_tensor_model_parallel_world_size(), 1, meta_info['hidden_size'] - ])) + torch.Size( + [ + input_ids_rmpad.shape[0] // mpu.get_tensor_model_parallel_world_size(), + 1, + meta_info["hidden_size"], + ] + ) + ) else: # compute shapes for model_inputs - input_shapes.append(torch.Size([input_ids_rmpad.shape[0], 1, meta_info['hidden_size']])) + input_shapes.append(torch.Size([input_ids_rmpad.shape[0], 1, meta_info["hidden_size"]])) return input_shapes diff --git a/verl/utils/megatron/sequence_parallel.py b/verl/utils/megatron/sequence_parallel.py index 4b76cb295..e979bd4d2 100644 --- a/verl/utils/megatron/sequence_parallel.py +++ b/verl/utils/megatron/sequence_parallel.py @@ -19,11 +19,11 @@ from megatron.core import parallel_state as mpu def mark_parameter_as_sequence_parallel(parameter): - setattr(parameter, 'sequence_parallel', True) + parameter.sequence_parallel = True def is_sequence_parallel_param(param): - return hasattr(param, 'sequence_parallel') and param.sequence_parallel + return hasattr(param, "sequence_parallel") and param.sequence_parallel def pad_to_sequence_parallel(unpad_tokens: torch.Tensor): @@ -49,6 +49,6 @@ def pad_to_sequence_parallel(unpad_tokens: torch.Tensor): elif unpad_tokens.ndim == 2: unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size)) else: - raise NotImplementedError(f'Padding dim {unpad_tokens.ndim()} is not supported') + raise NotImplementedError(f"Padding dim {unpad_tokens.ndim()} is not supported") return unpad_tokens diff --git a/verl/utils/megatron/tensor_parallel.py b/verl/utils/megatron/tensor_parallel.py index 623c4129d..c7b73233e 100644 --- a/verl/utils/megatron/tensor_parallel.py +++ b/verl/utils/megatron/tensor_parallel.py @@ -15,26 +15,28 @@ """ Utilities for using tensor_parallel in megatron """ + from typing import Dict + import torch -from torch.nn import init import torch.distributed as dist -from megatron.core import ModelParallelConfig -from megatron.core import parallel_state as mpu, tensor_parallel +from megatron.core import ModelParallelConfig, tensor_parallel +from megatron.core import parallel_state as mpu +from torch.nn import init def 
update_kwargs_with_config(dictionary: Dict, config: ModelParallelConfig): - dictionary['config'] = config + dictionary["config"] = config return dictionary def get_default_kwargs_for_model_parallel_config(): model_parallel_config_kwargs = { - 'params_dtype': torch.float32, - 'use_cpu_initialization': False, - 'perform_initialization': True, - 'gradient_accumulation_fusion': False, - 'sequence_parallel': False, + "params_dtype": torch.float32, + "use_cpu_initialization": False, + "perform_initialization": True, + "gradient_accumulation_fusion": False, + "sequence_parallel": False, } return model_parallel_config_kwargs @@ -46,10 +48,10 @@ def get_default_model_parallel_config(): def get_common_default_kwargs_for_parallel_linear(): default_model_parallel_config = get_default_model_parallel_config() common_default_kwargs = { - 'init_method': init.xavier_normal_, - 'stride': 1, - 'keep_master_weight_for_test': False, - 'config': default_model_parallel_config, + "init_method": init.xavier_normal_, + "stride": 1, + "keep_master_weight_for_test": False, + "config": default_model_parallel_config, } return common_default_kwargs @@ -57,11 +59,11 @@ def get_common_default_kwargs_for_parallel_linear(): def get_default_kwargs_for_column_parallel_linear(): model_parallel_config_kwargs = get_default_kwargs_for_model_parallel_config() column_parallel_config_kwargs = { - 'async_tensor_model_parallel_allreduce': False, + "async_tensor_model_parallel_allreduce": False, } model_parallel_config_kwargs.update(column_parallel_config_kwargs) column_default_kwargs = { - 'config': ModelParallelConfig(**model_parallel_config_kwargs), + "config": ModelParallelConfig(**model_parallel_config_kwargs), } common_default_kwargs = get_common_default_kwargs_for_parallel_linear() common_default_kwargs.update(column_default_kwargs) @@ -76,14 +78,14 @@ def get_default_kwargs_for_row_parallel_linear(): def get_default_kwargs_for_parallel_embedding(): model_parallel_config_kwargs = get_default_kwargs_for_model_parallel_config() embedding_default_kwargs = { - 'init_method': init.xavier_normal_, - 'config': ModelParallelConfig(**model_parallel_config_kwargs), + "init_method": init.xavier_normal_, + "config": ModelParallelConfig(**model_parallel_config_kwargs), } return embedding_default_kwargs def is_tensor_parallel_param(param): - return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) + return hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel def get_tensor_parallel_partition_dim(param): @@ -97,10 +99,8 @@ def get_tensor_parallel_partition_stride(param): class _VocabParallelEntropy(torch.autograd.Function): - @staticmethod def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor: - @torch.compile(dynamic=True) def mul_reduce(a, b): return (a * b).sum(dim=-1, keepdim=True) @@ -133,12 +133,12 @@ class _VocabParallelEntropy(torch.autograd.Function): def vocab_parallel_entropy(vocab_parallel_logits: torch.Tensor) -> torch.Tensor: """Compute entropy when the logits are sharded in tp ranks - + Args: vocab_parallel_logits: (total_nnz, vocab_size // tp_size) Returns: (total_nnz,) - + """ return _VocabParallelEntropy.apply(vocab_parallel_logits) @@ -165,11 +165,11 @@ def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mas input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask) input_ids_rmpad = input_ids_rmpad.squeeze(-1) input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0) - 
full_log_probs_rmpad = vocab_parallel_log_probs_from_logits(logits=logits_rmpad, - labels=input_ids_rmpad_rolled) # (total_nnz,) - full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen) - output = full_output.squeeze(-1)[:, -response_length - 1:-1] # [batch_size, response_length] + full_log_probs_rmpad = vocab_parallel_log_probs_from_logits( + logits=logits_rmpad, labels=input_ids_rmpad_rolled + ) # (total_nnz,) + full_output = pad_input( + hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ) + output = full_output.squeeze(-1)[:, -response_length - 1 : -1] # [batch_size, response_length] return output diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py index 16e38d31f..55c0f8c41 100644 --- a/verl/utils/megatron_utils.py +++ b/verl/utils/megatron_utils.py @@ -13,13 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """Pretrain utilities.""" + import os +import warnings from typing import Any, Dict + import torch import torch.nn as nn import torch.nn.functional as F -from megatron.core import ModelParallelConfig -from megatron.core import mpu, tensor_parallel +from megatron.core import ModelParallelConfig, mpu, tensor_parallel from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.enums import ModelType @@ -28,26 +30,27 @@ from megatron.core.transformer import TransformerConfig from megatron.core.transformer.enums import AttnBackend from megatron.core.transformer.module import Float16Module from megatron.core.utils import get_attr_wrapped_model -from omegaconf import DictConfig from verl.utils.memory_buffer import build_memory_reference_from_module from verl.utils.torch_dtypes import PrecisionType def get_model_config(model): - return get_attr_wrapped_model(model, 'config', allow_none=False) + return get_attr_wrapped_model(model, "config", allow_none=False) -def get_model(model_provider_func, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=True, - use_distributed_optimizer=True): +def get_model( + model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True, use_distributed_optimizer=True +): """Build the model.""" # Build model. 
- if mpu.get_pipeline_model_parallel_world_size() > 1 and \ - mpu.get_virtual_pipeline_model_parallel_world_size() is not None: - assert model_type != ModelType.encoder_and_decoder, \ + if ( + mpu.get_pipeline_model_parallel_world_size() > 1 + and mpu.get_virtual_pipeline_model_parallel_world_size() is not None + ): + assert model_type != ModelType.encoder_and_decoder, ( "Interleaved schedule not supported for model with both encoder and decoder" + ) model = [] for i in range(mpu.get_virtual_pipeline_model_parallel_world_size()): mpu.set_virtual_pipeline_model_parallel_rank(i) @@ -64,8 +67,9 @@ def get_model(model_provider_func, add_decoder = True if model_type == ModelType.encoder_and_decoder: if mpu.get_pipeline_model_parallel_world_size() > 1: - assert mpu.get_pipeline_model_parallel_split_rank() is not None, \ + assert mpu.get_pipeline_model_parallel_split_rank() is not None, ( "Split rank needs to be specified for model with both encoder and decoder" + ) rank = mpu.get_pipeline_model_parallel_rank() split_rank = mpu.get_pipeline_model_parallel_split_rank() world_size = mpu.get_pipeline_model_parallel_world_size() @@ -73,10 +77,9 @@ def get_model(model_provider_func, post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1)) add_encoder = mpu.is_pipeline_stage_before_split() add_decoder = mpu.is_pipeline_stage_after_split() - model = model_provider_func(pre_process=pre_process, - post_process=post_process, - add_encoder=add_encoder, - add_decoder=add_decoder) + model = model_provider_func( + pre_process=pre_process, post_process=post_process, add_encoder=add_encoder, add_decoder=add_decoder + ) else: model = model_provider_func(pre_process=pre_process, post_process=post_process) model.model_type = model_type @@ -94,11 +97,14 @@ def get_model(model_provider_func, # Print number of parameters. if mpu.get_data_parallel_rank() == 0: - print(' > number of parameters on (tensor, pipeline) ' - 'model parallel rank ({}, {}): {}'.format( - mpu.get_tensor_model_parallel_rank(), mpu.get_pipeline_model_parallel_rank(), - sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])), - flush=True) + print( + " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( + mpu.get_tensor_model_parallel_rank(), + mpu.get_pipeline_model_parallel_rank(), + sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]), + ), + flush=True, + ) # GPU allocation. for model_module in model: @@ -122,7 +128,8 @@ def get_model(model_provider_func, overlap_grad_reduce=False, use_distributed_optimizer=use_distributed_optimizer, grad_reduce_in_fp32=True, # [old] accumulate_allreduce_grads_in_fp32=True, - )) + ), + ) ddp_models.append(ddp_model) model = ddp_models # # Broadcast params from data parallel src rank to other data parallel ranks. 
@@ -154,15 +161,17 @@ from transformers import PretrainedConfig def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig: - print(f'megatron config {megatron_config}') + print(f"megatron config {megatron_config}") dt = PrecisionType.to_dtype(megatron_config.params_dtype) - print(f'pipeline_dtype=megatron_config {dt}') + print(f"pipeline_dtype=megatron_config {dt}") if "Qwen2ForCausalLM" in hf_config.architectures: qkv_bias = True else: - qkv_bias = getattr(hf_config, 'attention_bias', False) - overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size( - ) is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 + qkv_bias = getattr(hf_config, "attention_bias", False) + overlap_p2p_comm = ( + mpu.get_virtual_pipeline_model_parallel_world_size() is not None + and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 + ) batch_p2p_comm = False transformer_config = TransformerConfig( num_layers=hf_config.num_hidden_layers, @@ -172,7 +181,7 @@ def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerC ffn_hidden_size=hf_config.intermediate_size, # max_position_embeddings=hf_config.max_position_embeddings, activation_func=F.silu, - normalization='RMSNorm', + normalization="RMSNorm", # rotary_percent=False, # default, gated_linear_unit=True, # for llama use_cpu_initialization=True, @@ -191,19 +200,20 @@ def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerC masked_softmax_fusion=True, moe_token_dispatcher_type="alltoall", attention_dropout=hf_config.attention_dropout, - hidden_dropout=getattr(hf_config, 'hidden_dropout', 0.0), + hidden_dropout=getattr(hf_config, "hidden_dropout", 0.0), add_qkv_bias=qkv_bias, attention_backend=AttnBackend.flash, - bf16=dt is torch.bfloat16) + bf16=dt is torch.bfloat16, + ) return transformer_config def init_megatron_optim_config(optim_config: Dict) -> OptimizerConfig: config = OptimizerConfig( - optimizer='adam', - lr=optim_config.get('lr'), - clip_grad=optim_config.get('clip_grad'), + optimizer="adam", + lr=optim_config.get("lr"), + clip_grad=optim_config.get("clip_grad"), weight_decay=1e-2, bf16=True, params_dtype=torch.bfloat16, @@ -222,7 +232,8 @@ def mcore_model_parallel_config( "Code should not reach this point. This function is deprecated and will be removed. 
" "Please use hf_to_mcore_config_dense() from verl.models.mcore.config_converter instead.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) return ModelParallelConfig( tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(), pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(), @@ -233,19 +244,20 @@ def mcore_model_parallel_config( pipeline_dtype=params_dtype, bf16=True, fp16=False, - timers=None) + timers=None, + ) def offload_megatron_param_and_grad(module_list: nn.ModuleList, offload_grad=False, hybrid_engine=None): if hybrid_engine is not None: pp_rank = mpu.get_pipeline_model_parallel_rank() for buffer in hybrid_engine.memory_buffers[pp_rank].values(): - buffer.data = buffer.data.to('cpu', non_blocking=True) + buffer.data = buffer.data.to("cpu", non_blocking=True) build_memory_reference_from_module(module_list, hybrid_engine.memory_buffers[pp_rank], maintain_weight=True) else: for module in module_list: for _, param in module.named_parameters(): - param.data = param.data.to('cpu', non_blocking=True) + param.data = param.data.to("cpu", non_blocking=True) if offload_grad and param.grad is not None: param.grad = param.grad.to("cpu", non_blocking=True) torch.cuda.empty_cache() @@ -293,29 +305,30 @@ def get_optimizer_checkpoint_path(checkpoint_path, use_distributed_optimizer=Tru tp_rank = mpu.get_tensor_model_parallel_rank() cp_rank = mpu.get_context_parallel_rank() dp_rank = mpu.get_data_parallel_rank() - #TODO: support ep - return os.path.join(checkpoint_path, f"optim", f"distrib_optim_pp{pp_rank}_tp{tp_rank}_cp{cp_rank}_dp{dp_rank}.pt") + # TODO: support ep + return os.path.join(checkpoint_path, "optim", f"distrib_optim_pp{pp_rank}_tp{tp_rank}_cp{cp_rank}_dp{dp_rank}.pt") def get_rng_states_checkpoint_path(checkpoint_path, only_rank0_save=True): # save rng states cause interrupts os.makedirs(os.path.join(checkpoint_path, "rng_states"), exist_ok=True) if only_rank0_save: - return os.path.join(checkpoint_path, f'rng_states', "rng_states.pt") + return os.path.join(checkpoint_path, "rng_states", "rng_states.pt") dp_rank = mpu.get_data_parallel_rank() pp_rank = mpu.get_pipeline_model_parallel_rank() tp_rank = mpu.get_tensor_model_parallel_rank() cp_rank = mpu.get_context_parallel_rank() - return os.path.join(checkpoint_path, f'rng_states', - f"rng_states_pp{pp_rank}_tp{tp_rank}_cp{cp_rank}_dp{dp_rank}.pt") + return os.path.join(checkpoint_path, "rng_states", f"rng_states_pp{pp_rank}_tp{tp_rank}_cp{cp_rank}_dp{dp_rank}.pt") -def convert_megatron_model_to_transformers_model(name, - param, - config: PretrainedConfig, - tp_size: int, - num_query_groups: int, - convert_qkv_gate_up_by_trunk_concat=False): +def convert_megatron_model_to_transformers_model( + name, + param, + config: PretrainedConfig, + tp_size: int, + num_query_groups: int, + convert_qkv_gate_up_by_trunk_concat=False, +): """Convert megatron model to transformers model.""" new_params = {} @@ -335,13 +348,13 @@ def convert_megatron_model_to_transformers_model(name, total_size = q_size_tp + 2 * kv_size_tp for i in range(tp_size): num_query_groups_per_partition = num_query_groups // tp_size - qkv_part = full_tensor[i * total_size:(i + 1) * total_size] + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] q_size_chunk = q_size_tp // num_query_groups_per_partition kv_size_chunk = kv_size_tp // num_query_groups_per_partition for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk:q_size_chunk + 
kv_size_chunk] - v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk:] + k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] + v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] q_shard_list.append(q_part) k_shard_list.append(k_part) v_shard_list.append(v_part) @@ -351,13 +364,13 @@ def convert_megatron_model_to_transformers_model(name, total_size = q_size_tp + 2 * kv_size_tp for i in range(tp_size): num_query_groups_per_partition = num_query_groups // tp_size - qkv_part = full_tensor[i * total_size:(i + 1) * total_size] + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] q_size_chunk = q_size_tp // num_query_groups_per_partition kv_size_chunk = kv_size_tp // num_query_groups_per_partition for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk:q_size_chunk + kv_size_chunk] - v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk:] + k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] + v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] q_shard_list.append(q_part) if i * config.num_key_value_heads % tp_size == 0: k_shard_list.append(k_part) @@ -375,7 +388,7 @@ def convert_megatron_model_to_transformers_model(name, gate_weight_list = [] up_weight_list = [] for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)] + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] gate_weight_list.append(gate_weight_tp) @@ -384,54 +397,60 @@ def convert_megatron_model_to_transformers_model(name, new_params[gate_name] = torch.cat(gate_weight_list, dim=0) new_params[up_name] = torch.cat(up_weight_list, dim=0) - if name == 'embedding.word_embeddings.weight': - new_params['model.embed_tokens.weight'] = param - elif 'self_attention' in name: - splitted_name = name.split('.') + if name == "embedding.word_embeddings.weight": + new_params["model.embed_tokens.weight"] = param + elif "self_attention" in name: + splitted_name = name.split(".") layer_number = splitted_name[2] component = splitted_name[4] param_type = splitted_name[5] - if component == 'linear_proj': - new_params[f'model.layers.{layer_number}.self_attn.o_proj.weight'] = param - elif component == 'linear_qkv' and not isinstance(param, list): - if param_type == 'layer_norm_weight': - new_params[f'model.layers.{layer_number}.input_layernorm.weight'] = param + if component == "linear_proj": + new_params[f"model.layers.{layer_number}.self_attn.o_proj.weight"] = param + elif component == "linear_qkv" and not isinstance(param, list): + if param_type == "layer_norm_weight": + new_params[f"model.layers.{layer_number}.input_layernorm.weight"] = param else: if convert_qkv_gate_up_by_trunk_concat: - convert_qkv_shard(param, f'model.layers.{layer_number}.self_attn.q_proj.{param_type}', - f'model.layers.{layer_number}.self_attn.k_proj.{param_type}', - f'model.layers.{layer_number}.self_attn.v_proj.{param_type}') + convert_qkv_shard( + param, + f"model.layers.{layer_number}.self_attn.q_proj.{param_type}", + f"model.layers.{layer_number}.self_attn.k_proj.{param_type}", + f"model.layers.{layer_number}.self_attn.v_proj.{param_type}", + ) else: - new_params[f'model.layers.{layer_number}.self_attn.qkv_proj.{param_type}'] = param + new_params[f"model.layers.{layer_number}.self_attn.qkv_proj.{param_type}"] 
= param else: assert isinstance(param, list) and len(param) == 3 - assert param_type == 'weight' or param_type == 'bias' - new_params[f'model.layers.{layer_number}.self_attn.q_proj.{param_type}'] = param[0] - new_params[f'model.layers.{layer_number}.self_attn.k_proj.{param_type}'] = param[1] - new_params[f'model.layers.{layer_number}.self_attn.v_proj.{param_type}'] = param[2] - elif 'mlp' in name: - splitted_name = name.split('.') + assert param_type == "weight" or param_type == "bias" + new_params[f"model.layers.{layer_number}.self_attn.q_proj.{param_type}"] = param[0] + new_params[f"model.layers.{layer_number}.self_attn.k_proj.{param_type}"] = param[1] + new_params[f"model.layers.{layer_number}.self_attn.v_proj.{param_type}"] = param[2] + elif "mlp" in name: + splitted_name = name.split(".") layer_number = splitted_name[2] component = splitted_name[4] param_type = splitted_name[5] - if component == 'linear_fc1' and not isinstance(param, list): - if param_type == 'layer_norm_weight': - new_params[f'model.layers.{layer_number}.post_attention_layernorm.weight'] = param - elif param_type == 'weight': + if component == "linear_fc1" and not isinstance(param, list): + if param_type == "layer_norm_weight": + new_params[f"model.layers.{layer_number}.post_attention_layernorm.weight"] = param + elif param_type == "weight": if convert_qkv_gate_up_by_trunk_concat: - convert_gate_up_shard(param, f'model.layers.{layer_number}.mlp.gate_proj.weight', - f'model.layers.{layer_number}.mlp.up_proj.weight') + convert_gate_up_shard( + param, + f"model.layers.{layer_number}.mlp.gate_proj.weight", + f"model.layers.{layer_number}.mlp.up_proj.weight", + ) else: - new_params[f'model.layers.{layer_number}.mlp.gate_up_proj.weight'] = param - elif component == 'linear_fc1' and isinstance(param, list): + new_params[f"model.layers.{layer_number}.mlp.gate_up_proj.weight"] = param + elif component == "linear_fc1" and isinstance(param, list): assert len(param) == 2 - assert param_type == 'weight' or param_type == 'bias' - new_params[f'model.layers.{layer_number}.mlp.gate_proj.weight'] = param[0] - new_params[f'model.layers.{layer_number}.mlp.up_proj.weight'] = param[1] - elif component == 'linear_fc2': - new_params[f'model.layers.{layer_number}.mlp.down_proj.weight'] = param + assert param_type == "weight" or param_type == "bias" + new_params[f"model.layers.{layer_number}.mlp.gate_proj.weight"] = param[0] + new_params[f"model.layers.{layer_number}.mlp.up_proj.weight"] = param[1] + elif component == "linear_fc2": + new_params[f"model.layers.{layer_number}.mlp.down_proj.weight"] = param elif name == "decoder.final_layernorm.weight": - new_params['model.norm.weight'] = param + new_params["model.norm.weight"] = param elif name == "output_layer.weight": new_params["lm_head.weight"] = param else: @@ -444,15 +463,15 @@ def broadcast_from_megatron_pp(tensor: torch.Tensor): if tensor is not None: shape = tensor.shape dtype = tensor.dtype - tensor_parallel = getattr(tensor, 'tensor_model_parallel', None) - partition_dim = getattr(tensor, 'partition_dim', None) + tensor_parallel = getattr(tensor, "tensor_model_parallel", None) + partition_dim = getattr(tensor, "partition_dim", None) tensor_spec = (shape, dtype, tensor_parallel, partition_dim) else: tensor_spec = None tensor_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() - torch.distributed.all_gather_object(object_list=tensor_spec_output, - obj=tensor_spec, - group=mpu.get_pipeline_model_parallel_group()) + torch.distributed.all_gather_object( + 
object_list=tensor_spec_output, obj=tensor_spec, group=mpu.get_pipeline_model_parallel_group() + ) # find the src rank target_tensor_spec = None src_rank = None @@ -461,13 +480,13 @@ def broadcast_from_megatron_pp(tensor: torch.Tensor): if target_tensor_spec is None: target_tensor_spec = tensor_spec else: - raise ValueError('A tensor exists on two pp ranks') + raise ValueError("A tensor exists on two pp ranks") src_rank = rank assert target_tensor_spec is not None if tensor is None: - tensor = torch.empty(size=target_tensor_spec[0], - dtype=target_tensor_spec[1], - device=torch.cuda.current_device()) + tensor = torch.empty( + size=target_tensor_spec[0], dtype=target_tensor_spec[1], device=torch.cuda.current_device() + ) if target_tensor_spec[2] is not None: tensor.tensor_model_parallel = target_tensor_spec[2] if target_tensor_spec[3] is not None: @@ -487,7 +506,7 @@ def broadcast_str_from_megatron_pp(obj: Any): for rank, item in enumerate(obj_output): if item is not None: if target_obj is not None: - raise ValueError('An object exists on two pp ranks') + raise ValueError("An object exists on two pp ranks") target_obj = item src_rank = rank @@ -497,8 +516,8 @@ def broadcast_str_from_megatron_pp(obj: Any): obj_output = [None] * torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group()) obj_output[0] = target_obj - torch.distributed.broadcast_object_list(object_list=obj_output, - src=global_rank, - group=mpu.get_pipeline_model_parallel_group()) + torch.distributed.broadcast_object_list( + object_list=obj_output, src=global_rank, group=mpu.get_pipeline_model_parallel_group() + ) return obj_output[0] diff --git a/verl/utils/memory_buffer.py b/verl/utils/memory_buffer.py index 07d9f23f3..a5bb0081f 100644 --- a/verl/utils/memory_buffer.py +++ b/verl/utils/memory_buffer.py @@ -34,7 +34,7 @@ class MemoryBuffer: if source is not None: self.data = source else: - self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device='cuda', requires_grad=False) + self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device="cuda", requires_grad=False) def zero(self): """Reset the buffer to zero.""" @@ -44,8 +44,7 @@ class MemoryBuffer: """Return a tensor with the input `shape` as a view into the 1-D data starting at `start_index`.""" end_index = start_index + shape.numel() - assert end_index <= self.numel, \ - 'requested tensor is out of the buffer range.' + assert end_index <= self.numel, "requested tensor is out of the buffer range." 
buffer_tensor = self.data[start_index:end_index] buffer_tensor = buffer_tensor.view(shape) return buffer_tensor @@ -64,7 +63,7 @@ def get_weight_buffer_meta_from_module(module: nn.Module) -> Dict[str, Dict]: """ weight_buffer_meta = {} for name, param in sorted(module.named_parameters()): - weight_buffer_meta[name] = {'shape': param.shape, 'dtype': param.dtype} + weight_buffer_meta[name] = {"shape": param.shape, "dtype": param.dtype} return weight_buffer_meta @@ -80,8 +79,8 @@ def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype memory_buffers = {} total_numel_map = {} # map from dtype to the total numel for name, meta_info in sorted(weight_buffer_meta.items()): - shape = meta_info['shape'] - dtype = meta_info['dtype'] + shape = meta_info["shape"] + dtype = meta_info["dtype"] assert isinstance(shape, torch.Size) assert isinstance(dtype, torch.dtype) @@ -97,11 +96,11 @@ def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype return memory_buffers -def build_memory_reference_from_module(module: torch.nn.Module, - memory_buffers: Dict[torch.dtype, MemoryBuffer], - maintain_weight=True): +def build_memory_reference_from_module( + module: torch.nn.Module, memory_buffers: Dict[torch.dtype, MemoryBuffer], maintain_weight=True +): start_index = {} - for dtype in memory_buffers.keys(): + for dtype in memory_buffers: start_index[dtype] = 0 for name, param in sorted(module.named_parameters()): memory_buffer = memory_buffers[param.dtype] @@ -126,12 +125,12 @@ def build_memory_reference(weight_buffer_meta: Dict[str, Dict], memory_buffers: """ start_idx = {} weight_buffers = {} - for dtype in memory_buffers.keys(): + for dtype in memory_buffers: start_idx[dtype] = 0 for name, meta_info in sorted(weight_buffer_meta.items()): - shape = meta_info['shape'] - dtype = meta_info['dtype'] + shape = meta_info["shape"] + dtype = meta_info["dtype"] buffer = memory_buffers[dtype].get(shape, start_index=start_idx[dtype]) start_idx[dtype] += calc_padded_numel(shape, dtype) @@ -160,7 +159,7 @@ class MemoryBufferModuleWrapper: return self.weight_buffer_meta -class MegatronMemoryBufferForRollout(object): +class MegatronMemoryBufferForRollout: """ We assume that - inference engine has tp + dp diff --git a/verl/utils/model.py b/verl/utils/model.py index eb2729bf1..26893fe34 100644 --- a/verl/utils/model.py +++ b/verl/utils/model.py @@ -14,19 +14,26 @@ """ Utilities to create common models from huggingface """ + import os import warnings -from typing import Dict, Type, Optional +from typing import Dict, Optional, Type import numpy as np import torch from torch import nn -from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, MistralForSequenceClassification, GenerationConfig +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + GenerationConfig, + MistralForSequenceClassification, + PretrainedConfig, +) + from verl.models.registry import ModelRegistry class LambdaLayer(nn.Module): - def __init__(self, fn): super().__init__() self.fn = fn @@ -47,8 +54,9 @@ def update_model_config(module_config, override_config_kwargs): def get_huggingface_actor_config(model_name: str, override_config_kwargs=None, trust_remote_code=False) -> Dict: if override_config_kwargs is None: override_config_kwargs = {} - assert isinstance(override_config_kwargs, Dict), \ - f'override_config_kwargs must be a dict, got {type(override_config_kwargs)}' + assert isinstance(override_config_kwargs, Dict), ( + f"override_config_kwargs must be a dict, got 
{type(override_config_kwargs)}" + ) module_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) update_model_config(module_config, override_config_kwargs) @@ -86,11 +94,12 @@ def create_huggingface_actor(model_name: str, override_config_kwargs=None, autom override_config_kwargs = {} if automodel_kwargs is None: automodel_kwargs = {} - assert isinstance(override_config_kwargs, Dict), \ - f'override_config_kwargs must be a dict, got {type(override_config_kwargs)}' - module_config = get_huggingface_actor_config(model_name, - override_config_kwargs, - trust_remote_code=automodel_kwargs.get('trust_remote_code', False)) + assert isinstance(override_config_kwargs, Dict), ( + f"override_config_kwargs must be a dict, got {type(override_config_kwargs)}" + ) + module_config = get_huggingface_actor_config( + model_name, override_config_kwargs, trust_remote_code=automodel_kwargs.get("trust_remote_code", False) + ) module: nn.Module = AutoModelForCausalLM.from_config(module_config, **automodel_kwargs) return module @@ -105,55 +114,58 @@ def create_huggingface_critic(model_name: str, override_config_kwargs=None, auto Returns: """ - critic_module: nn.Module = create_huggingface_actor(model_name, - override_config_kwargs=override_config_kwargs, - automodel_kwargs=automodel_kwargs) + critic_module: nn.Module = create_huggingface_actor( + model_name, override_config_kwargs=override_config_kwargs, automodel_kwargs=automodel_kwargs + ) if automodel_kwargs is None: automodel_kwargs = {} - torch_dtype = automodel_kwargs.get('torch_dtype', torch.float32) - critic_module.lm_head = nn.Sequential(nn.Linear(critic_module.config.hidden_size, 1, dtype=torch_dtype), - LambdaLayer(fn=squeeze)) + torch_dtype = automodel_kwargs.get("torch_dtype", torch.float32) + critic_module.lm_head = nn.Sequential( + nn.Linear(critic_module.config.hidden_size, 1, dtype=torch_dtype), LambdaLayer(fn=squeeze) + ) return critic_module -def get_model_size(model: nn.Module, scale='auto'): +def get_model_size(model: nn.Module, scale="auto"): n_params = sum(p.numel() for p in model.parameters()) - if scale == 'auto': + if scale == "auto": if n_params > 1e9: - scale = 'B' + scale = "B" elif n_params > 1e6: - scale = 'M' + scale = "M" elif n_params > 1e3: - scale = 'K' + scale = "K" else: - scale = '' + scale = "" - if scale == 'B': + if scale == "B": n_params = n_params / 1e9 - elif scale == 'M': + elif scale == "M": n_params = n_params / 1e6 - elif scale == 'K': + elif scale == "K": n_params = n_params / 1e3 - elif scale == '': + elif scale == "": pass else: - raise NotImplemented(f'Unknown scale {scale}') + raise NotImplementedError(f"Unknown scale {scale}") return n_params, scale def print_model_size(model: nn.Module, name: str = None): - n_params, scale = get_model_size(model, scale='auto') + n_params, scale = get_model_size(model, scale="auto") if name is None: name = model.__class__.__name__ - print(f'{name} contains {n_params:.2f}{scale} parameters') + print(f"{name} contains {n_params:.2f}{scale} parameters") -def create_random_mask(input_ids: torch.Tensor, - max_ratio_of_valid_token: float, - max_ratio_of_left_padding: float, - min_ratio_of_valid_token: float = 0): +def create_random_mask( + input_ids: torch.Tensor, + max_ratio_of_valid_token: float, + max_ratio_of_left_padding: float, + min_ratio_of_valid_token: float = 0, +): """Create a random mask given input_ids. Support left padding and right padding. 
Process: - Sample valid token length @@ -167,8 +179,8 @@ def create_random_mask(input_ids: torch.Tensor, Returns: """ - assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token <= 1. - assert max_ratio_of_left_padding >= 0 and max_ratio_of_left_padding < 1. + assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token <= 1.0 + assert max_ratio_of_left_padding >= 0 and max_ratio_of_left_padding < 1.0 assert min_ratio_of_valid_token <= max_ratio_of_valid_token batch_size, sequence_length = input_ids.shape @@ -211,22 +223,22 @@ def normalize_model_name(name, pp_rank, vpp_rank, pp_size, vpp_size, num_layers, layer_offset = layers_per_pp * pp_rank if layer_name in name: # belong to an intermediate layer - split_name = name.split('.') + split_name = name.split(".") # find the num next to split_name for i, name in enumerate(split_name): if name == layer_name: break layer_num_idx = i + 1 # check the name - assert len(split_name) >= layer_num_idx + 1, f'split_name = {split_name}' - assert split_name[layer_num_idx].isdigit(), f'split_name = {split_name}' + assert len(split_name) >= layer_num_idx + 1, f"split_name = {split_name}" + assert split_name[layer_num_idx].isdigit(), f"split_name = {split_name}" # increment layer_num_idx by layer_offset split_name[layer_num_idx] = str(int(split_name[layer_num_idx]) + layer_offset) - name = '.'.join(split_name) # weight name in inference_tp_model + name = ".".join(split_name) # weight name in inference_tp_model return name -def normalize_pp_vpp_params(params, num_hidden_layers, layer_name='layers'): +def normalize_pp_vpp_params(params, num_hidden_layers, layer_name="layers"): """ Normalize the pp vpp params into a complete named parameters. This is useful when gather parameters from pp ranks and passed to a model without pp @@ -241,31 +253,27 @@ def normalize_pp_vpp_params(params, num_hidden_layers, layer_name='layers'): vpp_size = len(params[pp_rank]) for vpp_rank in range(vpp_size): for name, param in params[pp_rank][vpp_rank].items(): - normalized_name = normalize_model_name(name, - pp_rank, - vpp_rank, - pp_size, - vpp_size, - num_hidden_layers, - layer_name=layer_name) + normalized_name = normalize_model_name( + name, pp_rank, vpp_rank, pp_size, vpp_size, num_hidden_layers, layer_name=layer_name + ) yield normalized_name, param -def get_parallel_model_from_config(config, - megatron_config, - pre_process=None, - post_process=None, - share_embeddings_and_output_weights=False, - value=False): +def get_parallel_model_from_config( + config, megatron_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False +): from megatron.core import ModelParallelConfig + assert isinstance(megatron_config, ModelParallelConfig) model_class = _get_parallel_model_architecture_from_config(config, value) - model = model_class(config, - megatron_config, - pre_process=pre_process, - post_process=post_process, - share_embeddings_and_output_weights=share_embeddings_and_output_weights) + model = model_class( + config, + megatron_config, + pre_process=pre_process, + post_process=post_process, + share_embeddings_and_output_weights=share_embeddings_and_output_weights, + ) return model @@ -273,18 +281,21 @@ def _get_parallel_model_architecture_from_config(config: PretrainedConfig, value architectures = getattr(config, "architectures", []) for arch in architectures: model_cls = ModelRegistry.load_model_cls(arch, value) - print(f'after load model cls') + print("after load model cls") if model_cls is not None: return model_cls - raise 
ValueError(f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}" + ) def _load_hf_model(config, model_config, is_value_model, local_cache_path): """Helper function containing the loading hf model logic""" - from megatron.core import parallel_state as mpu - from verl.models.mcore.saver import _megatron_calc_global_rank from accelerate import init_empty_weights + from megatron.core import parallel_state as mpu + + from verl.models.mcore.saver import _megatron_calc_global_rank assert hasattr(model_config, "architectures"), "architectures cannot be empty when load weight!" architectures = getattr(model_config, "architectures", []) @@ -292,20 +303,21 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path): if config.model.path.startswith("hdfs:"): from verl.utils.fs import copy_to_local - print(f'start download from {config.model.path}') + + print(f"start download from {config.model.path}") local_model_path = copy_to_local(src=config.model.path, cache_dir=local_cache_path) - print('finish download') + print("finish download") else: local_model_path = config.model.path print(f"load from local dir {local_model_path}") src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=mpu.get_context_parallel_rank()) - cpu_init_weights = lambda: torch.device('cpu') + cpu_init_weights = lambda: torch.device("cpu") init_context = init_empty_weights if torch.distributed.get_rank() != src_rank else cpu_init_weights with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") # TODO: to find a better way to load mistral7b-rm lm_head - if 'mistral7b-rm' in config.model.path: + if "mistral7b-rm" in config.model.path: model = MistralForSequenceClassification.from_pretrained( local_model_path, torch_dtype="auto", @@ -313,9 +325,10 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path): # low_cpu_mem_usage=True ) # use score head instead of lm_head state_dict = model.state_dict() - state_dict['lm_head.weight'] = state_dict['score.weight'] - state_dict['model.embed_tokens.weight'] = state_dict[ - 'model.embed_tokens.weight'][:32000] # workaround, 32001 -> 32000 + state_dict["lm_head.weight"] = state_dict["score.weight"] + state_dict["model.embed_tokens.weight"] = state_dict["model.embed_tokens.weight"][ + :32000 + ] # workaround, 32001 -> 32000 is_value_model = True else: model = AutoModelForCausalLM.from_pretrained( @@ -329,45 +342,46 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path): return architectures, model, state_dict, is_value_model -def load_megatron_model_weights(config, - model_config, - parallel_model, - params_dtype, - is_value_model=False, - local_cache_path='~/.cache/verl/rlhf'): +def load_megatron_model_weights( + config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path="~/.cache/verl/rlhf" +): """Load weights for verl customized model.""" - architectures, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model, - local_cache_path) + architectures, model, state_dict, is_value_model = _load_hf_model( + config, model_config, is_value_model, local_cache_path + ) from verl.models.weight_loader_registry import get_weight_loader - print(f'before weight loader: architectures = {architectures}...') 
+ + print(f"before weight loader: architectures = {architectures}...") for arch in architectures: - print(f'call weight loader arch = {arch}, model config = {model.config}') + print(f"call weight loader arch = {arch}, model config = {model.config}") weight_loader = get_weight_loader(arch) - weight_loader(state_dict=state_dict, - wrapped_models=parallel_model, - config=model.config, - params_dtype=params_dtype, - is_value_model=is_value_model, - tie_word_embeddings=model_config.tie_word_embeddings) + weight_loader( + state_dict=state_dict, + wrapped_models=parallel_model, + config=model.config, + params_dtype=params_dtype, + is_value_model=is_value_model, + tie_word_embeddings=model_config.tie_word_embeddings, + ) return model.config -def load_megatron_gptmodel_weights(config, - model_config, - parallel_model, - params_dtype, - is_value_model=False, - local_cache_path='~/.cache/verl/rlhf'): +def load_megatron_gptmodel_weights( + config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path="~/.cache/verl/rlhf" +): """Load weights for mcore GPT model.""" _, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model, local_cache_path) from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel - load_state_dict_to_megatron_gptmodel(state_dict=state_dict, - wrapped_models=parallel_model, - config=model.config, - params_dtype=params_dtype, - is_value_model=is_value_model) + + load_state_dict_to_megatron_gptmodel( + state_dict=state_dict, + wrapped_models=parallel_model, + config=model.config, + params_dtype=params_dtype, + is_value_model=is_value_model, + ) del state_dict, model @@ -400,7 +414,7 @@ def pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batc elif unpad_tokens.ndim == 2: unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size)) else: - raise NotImplementedError(f'Padding dim {unpad_tokens.ndim()} is not supported') + raise NotImplementedError(f"Padding dim {unpad_tokens.ndim()} is not supported") cu_seqlens = F.pad(cu_seqlens, (0, 1), value=pad_size + cu_seqlens[-1]) max_seqlen_in_batch = max(max_seqlen_in_batch, pad_size) @@ -425,37 +439,36 @@ def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=Fal return -def get_parallel_gptmodel_from_config(tfconfig, - hf_config, - pre_process=None, - post_process=None, - share_embeddings_and_output_weights=False, - value=False): - from megatron.core.models.gpt.gpt_model import GPTModel +def get_parallel_gptmodel_from_config( + tfconfig, hf_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False +): from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec - from megatron.core import parallel_state as mpu - from megatron.core import tensor_parallel + from megatron.core.models.gpt.gpt_model import GPTModel + use_te = True - assert tfconfig.normalization == "RMSNorm", 'only RMSNorm is supported for now' + assert tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" transformer_layer_spec = get_gpt_decoder_block_spec(tfconfig, use_transformer_engine=use_te) rope_scaling_args = {} if hf_config.rope_scaling is not None: - assert hf_config.rope_scaling['type'] == 'linear', "only linear scaling is supported for now" - rope_scaling_args['seq_len_interpolation_factor'] = hf_config.rope_scaling['factor'] - parallel_model = GPTModel(config=tfconfig, - transformer_layer_spec=transformer_layer_spec, - vocab_size=hf_config.vocab_size, - 
max_sequence_length=hf_config.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - share_embeddings_and_output_weights=share_embeddings_and_output_weights, - position_embedding_type='rope', - rotary_base=hf_config.rope_theta, - **rope_scaling_args) + assert hf_config.rope_scaling["type"] == "linear", "only linear scaling is supported for now" + rope_scaling_args["seq_len_interpolation_factor"] = hf_config.rope_scaling["factor"] + parallel_model = GPTModel( + config=tfconfig, + transformer_layer_spec=transformer_layer_spec, + vocab_size=hf_config.vocab_size, + max_sequence_length=hf_config.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + share_embeddings_and_output_weights=share_embeddings_and_output_weights, + position_embedding_type="rope", + rotary_base=hf_config.rope_theta, + **rope_scaling_args, + ) # # for layer in parallel_model.decoder.layers: layer.self_attention.core_attention.flash_attention.softmax_scale = None if post_process and value: from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer - parallel_model.output_layer = LinearForLastLayer(input_size=tfconfig.hidden_size, - output_size=1, - config=tfconfig) + + parallel_model.output_layer = LinearForLastLayer( + input_size=tfconfig.hidden_size, output_size=1, config=tfconfig + ) return parallel_model diff --git a/verl/utils/py_functional.py b/verl/utils/py_functional.py index 8f5a0e176..b96eb0cd1 100644 --- a/verl/utils/py_functional.py +++ b/verl/utils/py_functional.py @@ -15,8 +15,8 @@ Contain small python utility functions """ -from typing import Dict from types import SimpleNamespace +from typing import Dict def union_two_dict(dict1: Dict, dict2: Dict): @@ -31,8 +31,7 @@ def union_two_dict(dict1: Dict, dict2: Dict): """ for key, val in dict2.items(): if key in dict1: - assert dict2[key] == dict1[key], \ - f'{key} in meta_dict1 and meta_dict2 are not the same object' + assert dict2[key] == dict1[key], f"{key} in meta_dict1 and meta_dict2 are not the same object" dict1[key] = val return dict1 @@ -46,7 +45,6 @@ def append_to_dict(data: Dict, new_data: Dict): class NestedNamespace(SimpleNamespace): - def __init__(self, dictionary, **kwargs): super().__init__(**kwargs) for key, value in dictionary.items(): diff --git a/verl/utils/ray_utils.py b/verl/utils/ray_utils.py index 9a75df6c3..49b60ef45 100644 --- a/verl/utils/ray_utils.py +++ b/verl/utils/ray_utils.py @@ -15,13 +15,12 @@ Contains commonly used utilities for ray """ -import ray - import concurrent.futures +import ray + def parallel_put(data_list, max_workers=None): - def put_data(index, data): return index, ray.put(data) diff --git a/verl/utils/rendezvous/ray_backend.py b/verl/utils/rendezvous/ray_backend.py index c0d2bd906..7eba70b73 100644 --- a/verl/utils/rendezvous/ray_backend.py +++ b/verl/utils/rendezvous/ray_backend.py @@ -15,15 +15,13 @@ import logging import time -from cupy.cuda.nccl import NcclCommunicator, get_unique_id - import ray +from cupy.cuda.nccl import NcclCommunicator, get_unique_id from ray.util import list_named_actors @ray.remote class NCCLIDStore: - def __init__(self, nccl_id): self._nccl_id = nccl_id @@ -44,11 +42,9 @@ def get_nccl_id_store_by_name(name): return None -def create_nccl_communicator_in_ray(rank: int, - world_size: int, - group_name: str, - max_retries: int = 100, - interval_s: int = 5): +def create_nccl_communicator_in_ray( + rank: int, world_size: int, group_name: str, max_retries: int = 100, interval_s: int = 5 +): if rank == 0: nccl_id = 
get_unique_id() nccl_id_store = NCCLIDStore.options(name=group_name).remote(nccl_id) @@ -73,5 +69,5 @@ def create_nccl_communicator_in_ray(rank: int, rank=rank, ) return communicator - logging.info(f"failed to get nccl_id for {i+1} time, sleep for {interval_s} seconds") + logging.info(f"failed to get nccl_id for {i + 1} time, sleep for {interval_s} seconds") time.sleep(interval_s) diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py index 70346be7f..fe101cc9e 100644 --- a/verl/utils/reward_score/__init__.py +++ b/verl/utils/reward_score/__init__.py @@ -15,11 +15,13 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None): - if data_source == 'openai/gsm8k': + if data_source == "openai/gsm8k": from . import gsm8k + res = gsm8k.compute_score(solution_str, ground_truth) - elif data_source in ['lighteval/MATH', 'DigitalLearningGmbH/MATH-lighteval']: + elif data_source in ["lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval"]: from . import math + res = math.compute_score(solution_str, ground_truth) # [Optional] Math-Verify Integration # For enhanced accuracy, consider utilizing Math-Verify (https://github.com/huggingface/Math-Verify). @@ -28,20 +30,28 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N # from . import math_verify # res = math_verify.compute_score(solution_str, ground_truth) - elif data_source == 'math_dapo' or data_source.startswith("aime"): + elif data_source == "math_dapo" or data_source.startswith("aime"): from . import math_dapo + res = math_dapo.compute_score(solution_str, ground_truth) elif data_source in [ - 'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12', - 'numina_olympiads' + "numina_aops_forum", + "numina_synthetic_math", + "numina_amc_aime", + "numina_synthetic_amc", + "numina_cn_k12", + "numina_olympiads", ]: from . import prime_math + res = prime_math.compute_score(solution_str, ground_truth) - elif data_source in ['codecontests', 'apps', 'codeforces', 'taco']: + elif data_source in ["codecontests", "apps", "codeforces", "taco"]: from . import prime_code + res = prime_code.compute_score(solution_str, ground_truth, continuous=True) - elif data_source in ['hiyouga/geometry3k']: + elif data_source in ["hiyouga/geometry3k"]: from . import geo3k + res = geo3k.compute_score(solution_str, ground_truth) else: raise NotImplementedError(f"Reward function is not implemented for {data_source=}") diff --git a/verl/utils/reward_score/geo3k.py b/verl/utils/reward_score/geo3k.py index 92c455469..699445cd7 100644 --- a/verl/utils/reward_score/geo3k.py +++ b/verl/utils/reward_score/geo3k.py @@ -13,11 +13,12 @@ # limitations under the License. 
import re + from mathruler.grader import extract_boxed_content, grade_answer def format_reward(predict_str: str) -> float: - pattern = re.compile(r'.*.*\\boxed\{.*\}.*', re.DOTALL) + pattern = re.compile(r".*.*\\boxed\{.*\}.*", re.DOTALL) match_result = re.fullmatch(pattern, predict_str) return 1.0 if match_result else 0.0 diff --git a/verl/utils/reward_score/gsm8k.py b/verl/utils/reward_score/gsm8k.py index 709103764..f5d4c1585 100644 --- a/verl/utils/reward_score/gsm8k.py +++ b/verl/utils/reward_score/gsm8k.py @@ -15,25 +15,25 @@ import re -def extract_solution(solution_str, method='strict'): - assert method in ['strict', 'flexible'] +def extract_solution(solution_str, method="strict"): + assert method in ["strict", "flexible"] - if method == 'strict': + if method == "strict": # this also tests the formatting of the model solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) if solution is None: final_answer = None else: final_answer = solution.group(0) - final_answer = final_answer.split('#### ')[1].replace(',', '').replace('$', '') - elif method == 'flexible': + final_answer = final_answer.split("#### ")[1].replace(",", "").replace("$", "") + elif method == "flexible": answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str) final_answer = None if len(answer) == 0: # no reward is there is no answer pass else: - invalid_str = ['', '.'] + invalid_str = ["", "."] # find the last number that is not '.' for final_answer in reversed(answer): if final_answer not in invalid_str: @@ -41,7 +41,7 @@ def extract_solution(solution_str, method='strict'): return final_answer -def compute_score(solution_str, ground_truth, method='strict', format_score=0., score=1.): +def compute_score(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0): """The scoring function for GSM8k. Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024. @@ -60,4 +60,4 @@ def compute_score(solution_str, ground_truth, method='strict', format_score=0., if answer == ground_truth: return score else: - return format_score \ No newline at end of file + return format_score diff --git a/verl/utils/reward_score/math.py b/verl/utils/reward_score/math.py index 50792aa6e..a768b1fef 100644 --- a/verl/utils/reward_score/math.py +++ b/verl/utils/reward_score/math.py @@ -15,13 +15,13 @@ def compute_score(solution_str, ground_truth) -> float: - retval = 0. + retval = 0.0 try: string_in_last_boxed = last_boxed_only_string(solution_str) if string_in_last_boxed is not None: answer = remove_boxed(string_in_last_boxed) if is_equiv(answer, ground_truth): - retval = 1. 
+ retval = 1.0 except Exception as e: print(e) @@ -49,15 +49,15 @@ def is_equiv(str1, str2, verbose=False): def remove_boxed(s): if "\\boxed " in s: left = "\\boxed " - assert s[:len(left)] == left - return s[len(left):] + assert s[: len(left)] == left + return s[len(left) :] left = "\\boxed{" - assert s[:len(left)] == left + assert s[: len(left)] == left assert s[-1] == "}" - return s[len(left):-1] + return s[len(left) : -1] def last_boxed_only_string(string): @@ -85,7 +85,7 @@ def last_boxed_only_string(string): if right_brace_idx is None: retval = None else: - retval = string[idx:right_brace_idx + 1] + retval = string[idx : right_brace_idx + 1] return retval diff --git a/verl/utils/reward_score/math_dapo.py b/verl/utils/reward_score/math_dapo.py index 624651ad0..fdb8ddd42 100644 --- a/verl/utils/reward_score/math_dapo.py +++ b/verl/utils/reward_score/math_dapo.py @@ -20,10 +20,10 @@ from typing import Optional def last_boxed_only_string(string: str) -> Optional[str]: """Extract the last LaTeX boxed expression from a string. - + Args: string: Input string containing LaTeX code - + Returns: The last boxed expression or None if not found """ @@ -45,26 +45,25 @@ def last_boxed_only_string(string: str) -> Optional[str]: break i += 1 - return string[idx:right_brace_idx + 1] if right_brace_idx is not None else None + return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None def remove_boxed(s: str) -> str: """Remove the LaTeX boxed command from a string. - + Args: s: String with format "\\boxed{content}" - + Returns: The content inside the boxed command """ left = "\\boxed{" - assert s[:len(left)] == left, f"box error: {s}" + assert s[: len(left)] == left, f"box error: {s}" assert s[-1] == "}", f"box error: {s}" - return s[len(left):-1] + return s[len(left) : -1] class timeout: - def __init__(self, seconds=1, error_message="Timeout"): self.seconds = seconds self.error_message = error_message @@ -141,10 +140,10 @@ REMOVED_EXPRESSIONS = [ def normalize_final_answer(final_answer: str) -> str: """Normalize a final answer to a quantitative reasoning question. - + Args: final_answer: The answer string to normalize - + Returns: Normalized answer string """ @@ -180,18 +179,17 @@ def normalize_final_answer(final_answer: str) -> str: return final_answer.strip() -def is_correct_minerva(solution_str: str, - gt: str, - gt_need_extract: bool = False, - answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)") -> tuple[bool, str]: +def is_correct_minerva( + solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)" +) -> tuple[bool, str]: """Check if the solution is correct according to Minerva criteria. - + Args: solution_str: The solution string to check gt: The ground truth answer gt_need_extract: Whether the ground truth needs extraction answer_pattern: Regex pattern to extract the answer - + Returns: Tuple of (is_correct, normalized_prediction) """ @@ -209,23 +207,23 @@ def is_correct_minerva(solution_str: str, return (pred == gt), pred -def is_correct_strict_box(pred: str, - gt: str, - pause_tokens_index: Optional[list[int]] = None) -> tuple[int, Optional[str]]: +def is_correct_strict_box( + pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None +) -> tuple[int, Optional[str]]: """Check if the prediction is correct using strict boxed answer criteria. 
- + Args: pred: The prediction string gt: The ground truth answer pause_tokens_index: Indices of pause tokens - + Returns: Tuple of (score, extracted_prediction) """ # Extract the relevant part of the prediction if pause_tokens_index is not None: assert len(pause_tokens_index) == 4 - pred = pred[pause_tokens_index[-1] - 100:] + pred = pred[pause_tokens_index[-1] - 100 :] else: pred = pred[-100:] @@ -236,18 +234,17 @@ def is_correct_strict_box(pred: str, return 1 if (extracted_pred == gt) else -1, extracted_pred -def verify(solution_str: str, - answer: str, - strict_box_verify: bool = False, - pause_tokens_index: Optional[list[int]] = None) -> bool: +def verify( + solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None +) -> bool: """Verify if the solution is correct. - + Args: solution_str: The solution string to verify answer: The ground truth answer strict_box_verify: Whether to use strict box verification pause_tokens_index: Indices of pause tokens - + Returns: True if the solution is correct, False otherwise """ @@ -259,18 +256,20 @@ def verify(solution_str: str, return correct, pred -def compute_score(solution_str: str, - ground_truth: str, - strict_box_verify: bool = False, - pause_tokens_index: Optional[list[int]] = None) -> float: +def compute_score( + solution_str: str, + ground_truth: str, + strict_box_verify: bool = False, + pause_tokens_index: Optional[list[int]] = None, +) -> float: """Compute the reward score for a solution. - + Args: solution_str: The solution string ground_truth: The ground truth answer config: Configuration object containing reward model settings pause_tokens_index: Indices of pause tokens - + Returns: Reward score (1.0 for correct, -1.0 for incorrect) """ diff --git a/verl/utils/reward_score/math_verify.py b/verl/utils/reward_score/math_verify.py index 8a995c111..c1ce7c1a4 100644 --- a/verl/utils/reward_score/math_verify.py +++ b/verl/utils/reward_score/math_verify.py @@ -13,9 +13,9 @@ # limitations under the License. try: - from math_verify.metric import math_metric - from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig from math_verify.errors import TimeoutException + from math_verify.metric import math_metric + from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig except ImportError: print("To use Math-Verify, please install it first by running `pip install math-verify`.") @@ -25,13 +25,13 @@ def compute_score(model_output: str, ground_truth: str, timeout_score: float = 0 gold_extraction_target=(LatexExtractionConfig(),), pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()), ) - ret_score = 0. + ret_score = 0.0 # Wrap the ground truth in \boxed{} format for verification ground_truth_boxed = "\\boxed{" + ground_truth + "}" try: ret_score, _ = verify_func([ground_truth_boxed], [model_output]) - except Exception as e: + except Exception: pass except TimeoutException: ret_score = timeout_score diff --git a/verl/utils/reward_score/prime_code/__init__.py b/verl/utils/reward_score/prime_code/__init__.py index 49b0d30f5..4a39be777 100644 --- a/verl/utils/reward_score/prime_code/__init__.py +++ b/verl/utils/reward_score/prime_code/__init__.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .utils import check_correctness as apps_check_correctness import json import re import traceback +from .utils import check_correctness as apps_check_correctness + def compute_score(completion, test_cases, continuous=False): # try to get code solution from completion. if the completion is pure code, this will not take effect. - solution = completion.split('```python')[-1].split('```')[0] + solution = completion.split("```python")[-1].split("```")[0] try: try: if not isinstance(test_cases, dict): @@ -35,7 +36,7 @@ def compute_score(completion, test_cases, continuous=False): success = all(map(lambda x: x == True, res)) if success: return success, metadata - except Exception as e: + except Exception: pass test_cases_list = [] @@ -53,7 +54,7 @@ def compute_score(completion, test_cases, continuous=False): res, metadata = apps_check_correctness(in_outs=test_case, generation=solution, timeout=5, debug=False) try: metadata = dict(enumerate(metadata))[0] # metadata can be empty occasionally - except Exception as e: + except Exception: metadata = {} metadata["test_case"] = {} metadata["test_case"]["input"] = str(test_case["inputs"][0]) @@ -66,7 +67,7 @@ def compute_score(completion, test_cases, continuous=False): break res_count = len(res_list) if len(res_list) > 0 else 1 success = sum(map(lambda x: x == True, res_list)) / res_count - except Exception as e: + except Exception: traceback.print_exc(10) success = False metadata_list = None diff --git a/verl/utils/reward_score/prime_code/testing_util.py b/verl/utils/reward_score/prime_code/testing_util.py index 34845d07e..9597ea2e0 100644 --- a/verl/utils/reward_score/prime_code/testing_util.py +++ b/verl/utils/reward_score/prime_code/testing_util.py @@ -13,38 +13,35 @@ # limitations under the License. import ast -import json -import sys import faulthandler +import json import platform -# used for debugging to time steps -from datetime import datetime - # to run the solution files we're using a timing based approach import signal +import sys +import traceback -import numpy as np +# used for debugging to time steps +from datetime import datetime +from enum import Enum # for capturing the stdout from io import StringIO # used for testing the code that reads from input -from unittest.mock import patch, mock_open +from unittest.mock import mock_open, patch +import numpy as np from pyext import RuntimeModule -from enum import Enum - -import traceback - def truncatefn(s, length=300): assert isinstance(s, str) if len(s) <= length: return s - return s[:length // 2] + "...(truncated) ..." + s[-length // 2:] + return s[: length // 2] + "...(truncated) ..." 
+ s[-length // 2 :] class CODE_TYPE(Enum): @@ -72,7 +69,6 @@ signal.signal(signal.SIGALRM, timeout_handler) # from https://stackoverflow.com/a/16571630/6416660 # alternative use redirect_stdout() from contextlib class Capturing(list): - def __enter__(self): self._stdout = sys.stdout sys.stdout = self._stringio = StringIO() @@ -99,7 +95,7 @@ def combined_int_check(val): def clean_traceback(error_traceback): - file_start = error_traceback.find('File \"\"') + file_start = error_traceback.find('File ""') # print(file_start) error_traceback = "Traceback (most recent call last):\n " + error_traceback[file_start:] return error_traceback @@ -137,7 +133,6 @@ def run_test(in_outs, test=None, debug=False, timeout=15): print(f"loading test code = {datetime.now().time()}") if which_type == CODE_TYPE.call_based: - sol += test if debug: print(f"sol = {sol}") @@ -172,7 +167,7 @@ def run_test(in_outs, test=None, debug=False, timeout=15): if isinstance(last_block, ast.If): condition = last_block.test if ast.unparse(condition).strip() == "__name__ == '__main__'": - test = (ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body)) + test = ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body) except: pass @@ -249,7 +244,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15): truncate_line_size = 300 // (raw_inputs.count("\n") + 1) raw_inputs = "\n".join( - [truncatefn(line, truncate_line_size) for line in raw_inputs.strip().split("\n")]) + [truncatefn(line, truncate_line_size) for line in raw_inputs.strip().split("\n")] + ) raw_outputs = truncatefn(raw_outputs, 200) else: raw_inputs = truncatefn(raw_inputs) @@ -290,7 +286,7 @@ def run_test(in_outs, test=None, debug=False, timeout=15): output = list(output) tmp_result = output == in_outs["outputs"][index] - if (isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]): + if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]: tmp_result = tmp_result or (output == in_outs["outputs"][index][0]) # ground truth sequences are not tuples @@ -390,7 +386,7 @@ def run_test(in_outs, test=None, debug=False, timeout=15): nl = "\n" if not isinstance(inputs, list): print( - f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" ) else: print( @@ -459,7 +455,7 @@ def run_test(in_outs, test=None, debug=False, timeout=15): nl = "\n" if not isinstance(inputs, list): print( - f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}" + f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}" ) else: print( @@ -488,18 +484,22 @@ def run_test(in_outs, test=None, debug=False, timeout=15): try: all_ints = all( combined_int_check(e1) and combined_int_check(e2) - for e1, e2 in zip(output, in_outs["outputs"][index])) + for e1, e2 in zip(output, in_outs["outputs"][index]) + ) if not all_ints: if debug: - print([ - combined_int_check(e1) and combined_int_check(e2) - for e1, e2 in zip(output, in_outs["outputs"][index]) - ]) + print( + [ + combined_int_check(e1) and 
combined_int_check(e2) + for e1, e2 in zip(output, in_outs["outputs"][index]) + ] + ) output_float = [float(e) for e in output] gt_float = [float(e) for e in in_outs["outputs"][index]] - tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and - np.allclose(output_float, gt_float)) - except Exception as e: + tmp_result = tmp_result or ( + (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float) + ) + except Exception: pass if debug: @@ -509,13 +509,15 @@ def run_test(in_outs, test=None, debug=False, timeout=15): if isinstance(output[0], list): all_ints = all( combined_int_check(e1) and combined_int_check(e2) - for e1, e2 in zip(output[0], in_outs["outputs"][index])) + for e1, e2 in zip(output[0], in_outs["outputs"][index]) + ) if not all_ints: output_float = [float(e) for e in output[0]] gt_float = [float(e) for e in in_outs["outputs"][index][0]] - tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and - np.allclose(output_float, gt_float)) - except Exception as e: + tmp_result = tmp_result or ( + (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float) + ) + except Exception: pass if tmp_result == True: @@ -580,7 +582,7 @@ def run_test(in_outs, test=None, debug=False, timeout=15): nl = "\n" if not isinstance(inputs, list): print( - f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" ) else: print( @@ -593,7 +595,6 @@ def run_test(in_outs, test=None, debug=False, timeout=15): def custom_compare_(output, ground_truth): - if isinstance(output, list): output_1 = "\n".join(output) if stripped_string_compare(output_1, ground_truth): @@ -615,7 +616,6 @@ def stripped_string_compare(s1, s2): def call_method(method, inputs): - if isinstance(inputs, list): inputs = "\n".join(inputs) @@ -633,7 +633,7 @@ def call_method(method, inputs): def _inner_call_method(_method): try: return _method() - except SystemExit as e: + except SystemExit: pass finally: pass diff --git a/verl/utils/reward_score/prime_code/utils.py b/verl/utils/reward_score/prime_code/utils.py index 07f3db88e..91232659e 100644 --- a/verl/utils/reward_score/prime_code/utils.py +++ b/verl/utils/reward_score/prime_code/utils.py @@ -15,25 +15,26 @@ # Borrowed from: https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/utils.py import multiprocessing -from typing import Dict, Optional -from datasets import load_dataset -from .testing_util import run_test +import os +import sys import traceback -import os, sys +from typing import Optional + +from .testing_util import run_test def _temp_run(sample, generation, debug, result, metadata_list, timeout): - with open(os.devnull, 'w') as devnull: + with open(os.devnull, "w") as devnull: sys.stdout = devnull sys.stderr = devnull try: res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout) result.append(res) metadata_list.append(metadata) - except Exception as e: + except Exception: # print(e) # some tracebacks are extremely long. 
traceback.print_exc(10) - result.append([-1 for i in range(len(sample['inputs']))]) + result.append([-1 for i in range(len(sample["inputs"]))]) metadata_list.append({}) @@ -55,5 +56,5 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=Tru # consider that all tests failed result = [[-1 for i in range(len(in_outs["inputs"]))]] if debug: - print(f"global timeout") + print("global timeout") return result[0], metadata_list diff --git a/verl/utils/reward_score/prime_math/__init__.py b/verl/utils/reward_score/prime_math/__init__.py index 49c6acbd5..4bb002188 100644 --- a/verl/utils/reward_score/prime_math/__init__.py +++ b/verl/utils/reward_score/prime_math/__init__.py @@ -18,11 +18,14 @@ Call grade_answer(given_answer: str, ground_truth: str). FROM: https://github.com/openai/prm800k/blob/main/prm800k/grading/grader.py """ + +import os import re + import sympy from pylatexenc import latex2text from sympy.parsing import sympy_parser -import os + from . import math_normalize from .grader import math_equal @@ -40,7 +43,6 @@ def timeout(timeout_seconds: int = 8): import signal def decorator(func): - def handler(signum, frame): raise TimeoutError("Operation timed out!") @@ -166,26 +168,26 @@ def _normalize(expr: str) -> str: expr = expr.replace("trillion", "*10^12") for unit in [ - "degree", - "cm", - "centimeter", - "meter", - "mile", - "second", - "minute", - "hour", - "day", - "week", - "month", - "year", - "foot", - "feet", - "inch", - "yard", - "liter", + "degree", + "cm", + "centimeter", + "meter", + "mile", + "second", + "minute", + "hour", + "day", + "week", + "month", + "year", + "foot", + "feet", + "inch", + "yard", + "liter", ]: expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) - expr = re.sub(f"\^ *\\\\circ", "", expr) + expr = re.sub("\^ *\\\\circ", "", expr) if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": expr = expr[1:-1] @@ -258,8 +260,12 @@ def split_tuple(expr: str): expr = _strip_properly_formatted_commas(expr) if len(expr) == 0: return [] - if (len(expr) > 2 and expr[0] in TUPLE_CHARS and expr[-1] in TUPLE_CHARS and - all([ch not in expr[1:-1] for ch in TUPLE_CHARS])): + if ( + len(expr) > 2 + and expr[0] in TUPLE_CHARS + and expr[-1] in TUPLE_CHARS + and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) + ): elems = [elem.strip() for elem in expr[1:-1].split(",")] else: elems = [expr] @@ -298,10 +304,11 @@ def grade_answer(given_answer: str, ground_truth: str) -> bool: ground_truth_elems = split_tuple(ground_truth_normalized) given_elems = split_tuple(given_normalized) - if len(ground_truth_elems) > 1 and (ground_truth_normalized[0] != given_normalized[0] or - ground_truth_normalized[-1] != given_normalized[-1]): - is_correct = False - elif len(ground_truth_elems) != len(given_elems): + if ( + len(ground_truth_elems) > 1 + and (ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1]) + or len(ground_truth_elems) != len(given_elems) + ): is_correct = False else: for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems): @@ -323,9 +330,9 @@ def grade_answer(given_answer: str, ground_truth: str) -> bool: def remove_boxed(s): left = "\\boxed{" try: - assert s[:len(left)] == left + assert s[: len(left)] == left assert s[-1] == "}" - return s[len(left):-1] + return s[len(left) : -1] except: return None @@ -357,16 +364,16 @@ def _last_boxed_only_string(string): if left_brace_idx is None or right_brace_idx is None: return None - return string[left_brace_idx + 
1:right_brace_idx].strip() + return string[left_brace_idx + 1 : right_brace_idx].strip() def match_answer(response): is_matched = False - for ans_marker in ['answer:', "answer is", "answers are"]: + for ans_marker in ["answer:", "answer is", "answers are"]: ans_idx = response.lower().rfind(ans_marker) if ans_idx != -1: is_matched = True - response = response[ans_idx + len(ans_marker):].strip() + response = response[ans_idx + len(ans_marker) :].strip() if response.endswith("\n"): response = response[:-2] @@ -389,11 +396,11 @@ def match_answer(response): if dot_idx != -1: response = response[:dot_idx].strip() - for ans_marker in ['be ', "is ", "are ", "=", ": ", "get ", 'be\n', "is\n", "are\n", ":\n", "get\n"]: + for ans_marker in ["be ", "is ", "are ", "=", ": ", "get ", "be\n", "is\n", "are\n", ":\n", "get\n"]: ans_idx = response.lower().rfind(ans_marker) if ans_idx != -1: is_matched = True - response = response[ans_idx + len(ans_marker):].strip() + response = response[ans_idx + len(ans_marker) :].strip() if response.endswith("\n"): response = response[:-2] diff --git a/verl/utils/reward_score/prime_math/grader.py b/verl/utils/reward_score/prime_math/grader.py index e3e1686f1..9a371bc4e 100644 --- a/verl/utils/reward_score/prime_math/grader.py +++ b/verl/utils/reward_score/prime_math/grader.py @@ -93,9 +93,9 @@ This logic is largely copied from the Hendrycks' MATH release (math_equivalence) """ import contextlib +import math import re import signal -import math from math import isclose from typing import Union @@ -118,12 +118,13 @@ def is_digit(s): def normalize(answer, pi) -> str: # checking if answer is $ and removing $ in that case to compare - if isinstance(answer, str) and bool(re.match(r'\$\d+(\.\d+)?', answer)): + if isinstance(answer, str) and bool(re.match(r"\$\d+(\.\d+)?", answer)): return answer[1:] # checking if answer is % or \\% and removing % - if isinstance(answer, str) and (bool(re.match(r'^\d+(\.\d+)?%$', answer)) or - bool(re.match(r'^\d+(\.\d+)?\\%$', answer))): + if isinstance(answer, str) and ( + bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer)) + ): return answer.replace("\\%", "").replace("%", "") # handle base @@ -151,13 +152,12 @@ def handle_pi(string, pi): # Iterate over the string and find all occurrences of "\pi" with a valid previous character while idx != -1: - if idx > 0 and string[idx - 1].isdigit(): # Replace "\pi" with "*math.pi" if the previous character is a digit - string = string[:idx] + f"*{pi}" + string[idx + 3:] + string = string[:idx] + f"*{pi}" + string[idx + 3 :] else: # Replace "\pi" with "1*math.pi" if the previous character is not a digit - string = string[:idx] + f"1*{pi}" + string[idx + 3:] + string = string[:idx] + f"1*{pi}" + string[idx + 3 :] # Find the next occurrence of "\pi" idx = string.find("\pi", idx + 1) @@ -171,12 +171,14 @@ def handle_pi(string, pi): return string -def math_equal(prediction: Union[bool, float, str], - reference: Union[float, str], - include_percentage: bool = True, - tolerance: float = 1e-4, - timeout: float = 10.0, - pi: float = math.pi) -> bool: +def math_equal( + prediction: Union[bool, float, str], + reference: Union[float, str], + include_percentage: bool = True, + tolerance: float = 1e-4, + timeout: float = 10.0, + pi: float = math.pi, +) -> bool: """ Exact match of math if and only if: 1. 
numerical equal: both can convert to float and are equal @@ -226,9 +228,9 @@ def math_equal(prediction: Union[bool, float, str], prediction = format_intervals(prediction) pred_str, ref_str = prediction, reference - if (prediction.startswith("[") and prediction.endswith("]") and - not reference.startswith("(")) or (prediction.startswith("(") and prediction.endswith(")") and - not reference.startswith("[")): + if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or ( + prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[") + ): pred_str = pred_str.strip("[]()") ref_str = ref_str.strip("[]()") for s in ["{", "}", "(", ")"]: @@ -238,15 +240,23 @@ def math_equal(prediction: Union[bool, float, str], return True ## [a, b] vs. [c, d], return a==c and b==d - if (prediction and reference and prediction[0] in "([" and prediction[-1] in ")]" and - prediction[0] == reference[0] and prediction[-1] == reference[-1]): + if ( + prediction + and reference + and prediction[0] in "([" + and prediction[-1] in ")]" + and prediction[0] == reference[0] + and prediction[-1] == reference[-1] + ): pred_parts = prediction[1:-1].split(",") ref_parts = reference[1:-1].split(",") if len(pred_parts) == len(ref_parts): - if all([ + if all( + [ math_equal(pred_pt, ref_pt, include_percentage, tolerance) for pred_pt, ref_pt in zip(pred_parts, ref_parts) - ]): + ] + ): return True if "," in prediction and "," in reference: @@ -254,23 +264,24 @@ def math_equal(prediction: Union[bool, float, str], ref_parts = [item.strip() for item in reference.split(",")] if len(pred_parts) == len(ref_parts): - if all([ - math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) - for i in range(len(pred_parts)) - ]): + if all( + [math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) for i in range(len(pred_parts))] + ): return True else: return False # if we have point == tuple of values if prediction.startswith("Point") and reference[0] == "(" and reference[-1] == ")": - pred_parts = prediction[prediction.find("(") + 1:-1].split(",") + pred_parts = prediction[prediction.find("(") + 1 : -1].split(",") ref_parts = reference[1:-1].split(",") if len(pred_parts) == len(ref_parts): - if all([ + if all( + [ math_equal(pred_pt, ref_pt, include_percentage, tolerance) for pred_pt, ref_pt in zip(pred_parts, ref_parts) - ]): + ] + ): return True # if reference is a matrix @@ -279,10 +290,12 @@ def math_equal(prediction: Union[bool, float, str], pred_matrix = parse_expr(prediction) ref_matrix_items = reference.split()[1:-1:2] if len(pred_matrix) == len(ref_matrix_items): - if all([ + if all( + [ math_equal(pred, ref, include_percentage, tolerance) for ref, pred in zip(ref_matrix_items, pred_matrix) - ]): + ] + ): return True except Exception: pass @@ -291,15 +304,21 @@ def math_equal(prediction: Union[bool, float, str], try: pred_matrix = eval(prediction) # ref_matrix_items = reference.split()[1:-1:2] - ref_matrix_items = reference.lstrip("\\begin{pmatrix}").lstrip("\begin{pmatrix}").rstrip( - "\\end{pmatrix}").rstrip("\end{pmatrix}") + ref_matrix_items = ( + reference.lstrip("\\begin{pmatrix}") + .lstrip("\begin{pmatrix}") + .rstrip("\\end{pmatrix}") + .rstrip("\end{pmatrix}") + ) ref_matrix_items = ref_matrix_items.split("\\") ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items] if len(pred_matrix) == len(ref_matrix_items): - if all([ + if all( + [ math_equal(pred, ref, include_percentage, tolerance) for ref, pred 
in zip(ref_matrix_items, pred_matrix) - ]): + ] + ): return True except Exception: pass @@ -308,7 +327,6 @@ def math_equal(prediction: Union[bool, float, str], def symbolic_equal(a, b, tolerance, timeout=10.0): - def _parse(s): for f in [parse_expr, parse_latex]: try: @@ -343,7 +361,6 @@ class TimeoutException(Exception): @contextlib.contextmanager def time_limit(seconds: float): - def signal_handler(signum, frame): raise TimeoutException("Timed out!") diff --git a/verl/utils/reward_score/prime_math/math_normalize.py b/verl/utils/reward_score/prime_math/math_normalize.py index e9921a5ad..acc05d4d4 100644 --- a/verl/utils/reward_score/prime_math/math_normalize.py +++ b/verl/utils/reward_score/prime_math/math_normalize.py @@ -36,6 +36,7 @@ This logic is largely copied from the Hendrycks' MATH release (math_equivalence) From: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py """ + import re from typing import Optional @@ -188,4 +189,4 @@ def _strip_string(string): # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y string = _fix_a_slash_b(string) - return string \ No newline at end of file + return string diff --git a/verl/utils/seqlen_balancing.py b/verl/utils/seqlen_balancing.py index fee45da0d..cd9fbd9cb 100644 --- a/verl/utils/seqlen_balancing.py +++ b/verl/utils/seqlen_balancing.py @@ -12,20 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Callable +import copy import heapq +from typing import List, Tuple import torch -from torch import distributed as dist - from tensordict import TensorDict -import copy +from torch import distributed as dist def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool): # see: https://en.wikipedia.org/wiki/Largest_differencing_method class Set: - def __init__(self) -> None: self.sum = 0 self.items = [] @@ -47,7 +45,6 @@ def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool): return self.items < other.items class State: - def __init__(self, items: List[Tuple[int, int]], k: int) -> None: self.k = k # sets should always be decreasing order @@ -125,8 +122,9 @@ def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool): partitions = final_state.get_partitions() if equal_size: for i, partition in enumerate(partitions): - assert len(partition) * \ - k_partitions == len(seqlen_list), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" + assert len(partition) * k_partitions == len(seqlen_list), ( + f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" + ) return partitions @@ -144,13 +142,14 @@ def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool partition_sums[min_idx] += seqlen if equal_size: for i, partition in enumerate(partitions): - assert len(partition) * \ - k_partitions == len(seqlen_list), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" + assert len(partition) * k_partitions == len(seqlen_list), ( + f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" + ) return partitions def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool): - """ get order of seq lengths to make partitions balanced, this is + """get order of seq lengths to make partitions balanced, this is used in balacing sum of seqlength across dp ranks and microbatches Parameters: seqlen_list (List[int]): @@ -192,7 +191,7 @@ def log_seqlen_unbalance(seqlen_list: 
List[int], partitions: List[List[int]], pr max_sum_seqlen = None total_sum_seqlen = 0 for offset in range(0, len(seqlen_list), batch_size): - cur_sum_seqlen = sum(seqlen_list[offset:offset + batch_size]) + cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size]) if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen: min_sum_seqlen = cur_sum_seqlen if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen: @@ -208,12 +207,12 @@ def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], pr max_sum_seqlen_balanced = max(balanced_sum_seqlen_list) return { - f'{prefix}/min': min_sum_seqlen, - f'{prefix}/max': max_sum_seqlen, - f'{prefix}/minmax_diff': max_sum_seqlen - min_sum_seqlen, - f'{prefix}/balanced_min': min_sum_seqlen_balanced, - f'{prefix}/balanced_max': max_sum_seqlen_balanced, - f'{prefix}/mean': total_sum_seqlen / len(partitions) + f"{prefix}/min": min_sum_seqlen, + f"{prefix}/max": max_sum_seqlen, + f"{prefix}/minmax_diff": max_sum_seqlen - min_sum_seqlen, + f"{prefix}/balanced_min": min_sum_seqlen_balanced, + f"{prefix}/balanced_max": max_sum_seqlen_balanced, + f"{prefix}/mean": total_sum_seqlen / len(partitions), } @@ -226,15 +225,16 @@ def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None): and the number of valid tokens in each micro batch is well balanced. """ # this is per local micro_bsz - max_seq_len = batch['attention_mask'].shape[-1] - assert max_token_len >= max_seq_len, \ - f'max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}' + max_seq_len = batch["attention_mask"].shape[-1] + assert max_token_len >= max_seq_len, ( + f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}" + ) - seq_len_effective: torch.Tensor = batch['attention_mask'].sum(dim=1) + seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1) total_seqlen = seq_len_effective.sum().item() num_micro_batches = ceildiv(total_seqlen, max_token_len) if dist.is_initialized(): - num_micro_batches = torch.tensor([num_micro_batches], device='cuda') + num_micro_batches = torch.tensor([num_micro_batches], device="cuda") dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group) num_micro_batches = num_micro_batches.cpu().item() @@ -248,7 +248,7 @@ def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None): for partition in micro_bsz_idx: curr_micro_batch = [] for idx in partition: - curr_micro_batch.append(batch[idx:idx + 1]) + curr_micro_batch.append(batch[idx : idx + 1]) curr_micro_batch = torch.cat(curr_micro_batch) micro_batches.append(curr_micro_batch) diff --git a/verl/utils/tokenizer.py b/verl/utils/tokenizer.py index b82d87a39..c82c57ae5 100644 --- a/verl/utils/tokenizer.py +++ b/verl/utils/tokenizer.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utils for tokenization.""" + import warnings -__all__ = ['hf_tokenizer', 'hf_processor'] +__all__ = ["hf_tokenizer", "hf_processor"] def set_pad_token_id(tokenizer): @@ -26,10 +27,10 @@ def set_pad_token_id(tokenizer): """ if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id - warnings.warn(f'tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}') + warnings.warn(f"tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}") if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - warnings.warn(f'tokenizer.pad_token is None. 
Now set to {tokenizer.eos_token}') + warnings.warn(f"tokenizer.pad_token is None. Now set to {tokenizer.eos_token}") def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs): @@ -47,12 +48,13 @@ def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kw """ from transformers import AutoTokenizer - if correct_gemma2 and isinstance(name_or_path, str) and 'gemma-2-2b-it' in name_or_path: + + if correct_gemma2 and isinstance(name_or_path, str) and "gemma-2-2b-it" in name_or_path: # the EOS token in gemma2 is ambiguious, which may worsen RL performance. # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a - warnings.warn('Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to and 107.') - kwargs['eos_token'] = '' - kwargs['eos_token_id'] = 107 + warnings.warn("Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to and 107.") + kwargs["eos_token"] = "" + kwargs["eos_token_id"] = 107 tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs) if correct_pad_token: set_pad_token_id(tokenizer) @@ -69,6 +71,7 @@ def hf_processor(name_or_path, **kwargs): transformers.ProcessorMixin: The pretrained processor. """ from transformers import AutoProcessor + try: processor = AutoProcessor.from_pretrained(name_or_path, **kwargs) except Exception: diff --git a/verl/utils/torch_dtypes.py b/verl/utils/torch_dtypes.py index 242c8fb37..015dae5a1 100644 --- a/verl/utils/torch_dtypes.py +++ b/verl/utils/torch_dtypes.py @@ -15,16 +15,16 @@ Adapted from Cruise. """ -import torch - from typing import Union +import torch + HALF_LIST = [16, "16", "fp16", "float16", torch.float16] FLOAT_LIST = [32, "32", "fp32", "float32", torch.float32] BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16] -class PrecisionType(object): +class PrecisionType: """Type of precision used. >>> PrecisionType.HALF == 16 @@ -73,10 +73,10 @@ class PrecisionType(object): @staticmethod def to_str(precision): if precision == torch.float16: - return 'fp16' + return "fp16" elif precision == torch.float32: - return 'fp32' + return "fp32" elif precision == torch.bfloat16: - return 'bf16' + return "bf16" else: raise RuntimeError(f"unexpected precision: {precision}") diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py index b9d2a1c28..bbabd77de 100644 --- a/verl/utils/torch_functional.py +++ b/verl/utils/torch_functional.py @@ -15,7 +15,7 @@ Contain small torch utilities """ -from typing import Dict, Union, List, Optional +from typing import Dict, List, Optional, Union import torch import torch.distributed @@ -25,6 +25,7 @@ from torch import nn try: from flash_attn.ops.triton.cross_entropy import cross_entropy_loss + FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True except ImportError: FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False @@ -63,8 +64,9 @@ def logprobs_from_logits(logits, labels, inplace_backward=True): def logprobs_from_logits_flash_attn(logits, labels, inplace_backward=True): output = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward) - assert isinstance( - output, tuple), "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]." + assert isinstance(output, tuple), ( + "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]." 
+ ) return -output[0] @@ -148,9 +150,9 @@ def masked_whiten(values, mask, shift_mean=True): def get_response_mask(response_id: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64): - ''' + """ end of sentence token can be int or list: 1 or [1, 2] - e.g. + e.g. response_id = torch.tensor([[20, 10, 34, 1, 0, 0, 0], [78, 0, 76, 2, 1, 0, 0], [23, 98, 1, 0, 0, 0, 0], @@ -165,7 +167,7 @@ def get_response_mask(response_id: torch.Tensor, eos_token: Union[int, List[int] [1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0]]) - ''' + """ eos_mask = torch.isin(response_id, torch.tensor(eos_token, device=response_id.device)).int() return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype) @@ -223,8 +225,9 @@ def allgather_dict_tensors(tensors: Union[Dict[str, torch.Tensor], TensorDict], def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> List[TensorDict]: - assert tensors.batch_size[0] % batch_size == 0, \ - f'input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}' + assert tensors.batch_size[0] % batch_size == 0, ( + f"input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}" + ) return tensors.split(batch_size) @@ -252,49 +255,49 @@ def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False): if tensors.shape[-1] >= max_seq_len: return tensors pad_tuple = (max_seq_len - tensors.shape[-1], 0) if left_pad else (0, max_seq_len - tensors.shape[-1]) - return F.pad(tensors, pad_tuple, 'constant', pad_token_id) + return F.pad(tensors, pad_tuple, "constant", pad_token_id) -def postprocess_data(input_ids: torch.Tensor, - attention_mask: torch.Tensor, - max_length: int, - pad_token_id: int, - left_pad=True, - truncation='error'): +def postprocess_data( + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + max_length: int, + pad_token_id: int, + left_pad=True, + truncation="error", +): """ input_data is the output from tokenizer. """ - assert truncation in ['left', 'right', 'error'] + assert truncation in ["left", "right", "error"] assert input_ids.ndim == 2 sequence_length = input_ids.shape[-1] if sequence_length < max_length: - input_ids = pad_sequence_to_length(input_ids, - max_seq_len=max_length, - pad_token_id=pad_token_id, - left_pad=left_pad) - attention_mask = pad_sequence_to_length(attention_mask, - max_seq_len=max_length, - pad_token_id=0, - left_pad=left_pad) + input_ids = pad_sequence_to_length( + input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad + ) + attention_mask = pad_sequence_to_length( + attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad + ) elif sequence_length > max_length: - if truncation == 'left': + if truncation == "left": # actually, left truncation may not be reasonable input_ids = input_ids[:, -max_length:] attention_mask = attention_mask[:, -max_length:] - elif truncation == 'right': + elif truncation == "right": input_ids = input_ids[:, :max_length] attention_mask = attention_mask[:, :max_length] - elif truncation == 'error': - raise NotImplementedError(f'{sequence_length=} is larger than {max_length=}') + elif truncation == "error": + raise NotImplementedError(f"{sequence_length=} is larger than {max_length=}") else: - raise NotImplementedError(f'Unknown truncation method {truncation}') + raise NotImplementedError(f"Unknown truncation method {truncation}") return input_ids, attention_mask def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor): - """ Remove the pad token. 
+ """Remove the pad token. Args: input_ids shape: [bs, seq_length] @@ -304,7 +307,7 @@ def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor): """ no_padding_batch = [] for ids, mask in zip(input_ids, attention_mask): - no_padding_batch.append((ids[len(ids) - mask.sum():]).cpu().numpy().tolist()) + no_padding_batch.append((ids[len(ids) - mask.sum() :]).cpu().numpy().tolist()) return no_padding_batch @@ -318,7 +321,7 @@ def log_probs_from_logits_response(input_ids, logits, response_length): Returns: response_log_prob: """ - response_logits = logits[:, -response_length - 1:-1] + response_logits = logits[:, -response_length - 1 : -1] response = input_ids[:, -response_length:] response_log_prob = logprobs_from_logits(logits=response_logits, labels=response) return response_log_prob @@ -344,11 +347,10 @@ def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad input_ids_rmpad = input_ids_rmpad.squeeze(-1) input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0) full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,) - full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen) - output = full_output.squeeze(-1)[:, -response_length - 1:-1] # [batch_size, response_length] + full_output = pad_input( + hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ) + output = full_output.squeeze(-1)[:, -response_length - 1 : -1] # [batch_size, response_length] return output @@ -368,23 +370,20 @@ def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices, batc response_length: int """ from flash_attn.bert_padding import pad_input + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # transpose back to [total_nnz, 1] input_ids_rmpad = input_ids_rmpad.squeeze(-1) input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0) full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,) - full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen) - output = full_output.squeeze(-1)[:, -response_length - 1:-1] # [batch_size, response_length] + full_output = pad_input( + hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ) + output = full_output.squeeze(-1)[:, -response_length - 1 : -1] # [batch_size, response_length] return output -from transformers.generation.logits_process import (TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper) - - def post_process_logits(input_ids, logits, temperature, top_k, top_p): - if temperature != 1.: + if temperature != 1.0: logits = logits.div_(temperature) # inplace operation to avoid OOM # TODO: add them back # if top_k is not None and top_k > 0: @@ -398,9 +397,10 @@ def post_process_logits(input_ids, logits, temperature, top_k, top_p): Optimizer related """ +import math + from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR -import math def get_cosine_schedule_with_warmup( @@ -432,7 +432,7 @@ def get_cosine_schedule_with_warmup( Return: :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. """ - assert min_lr_ratio >= 0 and min_lr_ratio <= 1. 
+ assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0 coef = (1 - min_lr_ratio) * 0.5 intercept = (1 + min_lr_ratio) * 0.5 @@ -451,7 +451,6 @@ def get_constant_schedule_with_warmup( num_warmup_steps: int, last_epoch: int = -1, ): - def lr_lambda(current_step): return min(1, float(current_step) / float(max(1, num_warmup_steps))) @@ -471,10 +470,12 @@ def prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds): if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, - tgt_len=input_shape[-1]).to(inputs_embeds.device) - combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + - combined_attention_mask) + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) return combined_attention_mask diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py index dd8c07cd9..8d753df64 100644 --- a/verl/utils/tracking.py +++ b/verl/utils/tracking.py @@ -14,37 +14,41 @@ """ A unified tracking interface that supports logging data to different backend """ + import dataclasses from enum import Enum from functools import partial from pathlib import Path -from typing import List, Union, Dict, Any +from typing import Any, Dict, List, Union -class Tracking(object): +class Tracking: supported_backend = ["wandb", "mlflow", "swanlab", "vemlp_wandb", "tensorboard", "console"] - def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = 'console', config=None): + def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = "console", config=None): if isinstance(default_backend, str): default_backend = [default_backend] for backend in default_backend: - if backend == 'tracking': + if backend == "tracking": import warnings + warnings.warn("`tracking` logger is deprecated. 
use `wandb` instead.", DeprecationWarning) else: - assert backend in self.supported_backend, f'{backend} is not supported' + assert backend in self.supported_backend, f"{backend} is not supported" self.logger = {} - if 'tracking' in default_backend or 'wandb' in default_backend: + if "tracking" in default_backend or "wandb" in default_backend: import wandb - wandb.init(project=project_name, name=experiment_name, config=config) - self.logger['wandb'] = wandb - if 'mlflow' in default_backend: - import mlflow + wandb.init(project=project_name, name=experiment_name, config=config) + self.logger["wandb"] = wandb + + if "mlflow" in default_backend: import os + import mlflow + MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", None) if MLFLOW_TRACKING_URI: mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) @@ -54,31 +58,33 @@ class Tracking(object): experiment = mlflow.set_experiment(project_name) mlflow.start_run(experiment_id=experiment.experiment_id, run_name=experiment_name) mlflow.log_params(_compute_mlflow_params_from_objects(config)) - self.logger['mlflow'] = _MlflowLoggingAdapter() + self.logger["mlflow"] = _MlflowLoggingAdapter() if "swanlab" in default_backend: - import swanlab import os + import swanlab + SWANLAB_API_KEY = os.environ.get("SWANLAB_API_KEY", None) SWANLAB_LOG_DIR = os.environ.get("SWANLAB_LOG_DIR", "swanlog") SWANLAB_MODE = os.environ.get("SWANLAB_MODE", "cloud") if SWANLAB_API_KEY: swanlab.login(SWANLAB_API_KEY) # NOTE: previous login information will be overwritten - swanlab.init(project=project_name, - experiment_name=experiment_name, - config={ - "FRAMEWORK": "veRL", - **config - }, - logdir=SWANLAB_LOG_DIR, - mode=SWANLAB_MODE) + swanlab.init( + project=project_name, + experiment_name=experiment_name, + config={"FRAMEWORK": "veRL", **config}, + logdir=SWANLAB_LOG_DIR, + mode=SWANLAB_MODE, + ) self.logger["swanlab"] = swanlab - if 'vemlp_wandb' in default_backend: + if "vemlp_wandb" in default_backend: import os + import volcengine_ml_platform from volcengine_ml_platform import wandb as vemlp_wandb + volcengine_ml_platform.init( ak=os.environ["VOLC_ACCESS_KEY_ID"], sk=os.environ["VOLC_SECRET_ACCESS_KEY"], @@ -91,15 +97,16 @@ class Tracking(object): config=config, sync_tensorboard=True, ) - self.logger['vemlp_wandb'] = vemlp_wandb + self.logger["vemlp_wandb"] = vemlp_wandb - if 'tensorboard' in default_backend: - self.logger['tensorboard'] = _TensorboardAdapter() + if "tensorboard" in default_backend: + self.logger["tensorboard"] = _TensorboardAdapter() - if 'console' in default_backend: + if "console" in default_backend: from verl.utils.logger.aggregate_logger import LocalLogger + self.console_logger = LocalLogger(print_to_console=True) - self.logger['console'] = self.console_logger + self.logger["console"] = self.console_logger def log(self, data, step, backend=None): for default_backend, logger_instance in self.logger.items(): @@ -107,21 +114,22 @@ class Tracking(object): logger_instance.log(data=data, step=step) def __del__(self): - if 'wandb' in self.logger: - self.logger['wandb'].finish(exit_code=0) - if 'swanlab' in self.logger: - self.logger['swanlab'].finish() - if 'vemlp_wandb' in self.logger: - self.logger['vemlp_wandb'].finish(exit_code=0) - if 'tensorboard' in self.logger: - self.logger['tensorboard'].finish() + if "wandb" in self.logger: + self.logger["wandb"].finish(exit_code=0) + if "swanlab" in self.logger: + self.logger["swanlab"].finish() + if "vemlp_wandb" in self.logger: + self.logger["vemlp_wandb"].finish(exit_code=0) + if "tensorboard" in 
self.logger: + self.logger["tensorboard"].finish() class _TensorboardAdapter: - def __init__(self): - from torch.utils.tensorboard import SummaryWriter import os + + from torch.utils.tensorboard import SummaryWriter + tensorboard_dir = os.environ.get("TENSORBOARD_DIR", "tensorboard_log") os.makedirs(tensorboard_dir, exist_ok=True) print(f"Saving tensorboard log to {tensorboard_dir}.") @@ -136,10 +144,10 @@ class _TensorboardAdapter: class _MlflowLoggingAdapter: - def log(self, data, step): import mlflow - results = {k.replace('@', '_at_'): v for k, v in data.items()} + + results = {k.replace("@", "_at_"): v for k, v in data.items()} mlflow.log_metrics(metrics=results, step=step) @@ -147,7 +155,7 @@ def _compute_mlflow_params_from_objects(params) -> Dict[str, Any]: if params is None: return {} - return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep='/') + return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep="/") def _transform_params_to_json_serializable(x, convert_list_to_dict: bool): @@ -159,7 +167,7 @@ def _transform_params_to_json_serializable(x, convert_list_to_dict: bool): return {k: _transform(v) for k, v in x.items()} if isinstance(x, list): if convert_list_to_dict: - return {'list_len': len(x)} | {f'{i}': _transform(v) for i, v in enumerate(x)} + return {"list_len": len(x)} | {f"{i}": _transform(v) for i, v in enumerate(x)} else: return [_transform(v) for v in x] if isinstance(x, Path): @@ -172,20 +180,20 @@ def _transform_params_to_json_serializable(x, convert_list_to_dict: bool): def _flatten_dict(raw: Dict[str, Any], *, sep: str) -> Dict[str, Any]: import pandas as pd - ans = pd.json_normalize(raw, sep=sep).to_dict(orient='records')[0] + + ans = pd.json_normalize(raw, sep=sep).to_dict(orient="records")[0] assert isinstance(ans, dict) return ans @dataclasses.dataclass class ValidationGenerationsLogger: - def log(self, loggers, samples, step): - if 'wandb' in loggers: + if "wandb" in loggers: self.log_generations_to_wandb(samples, step) - if 'swanlab' in loggers: + if "swanlab" in loggers: self.log_generations_to_swanlab(samples, step) - if 'mlflow' in loggers: + if "mlflow" in loggers: self.log_generations_to_mlflow(samples, step) def log_generations_to_wandb(self, samples, step): @@ -193,9 +201,11 @@ class ValidationGenerationsLogger: import wandb # Create column names for all samples - columns = ["step"] + sum([[f"input_{i+1}", f"output_{i+1}", f"score_{i+1}"] for i in range(len(samples))], []) + columns = ["step"] + sum( + [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], [] + ) - if not hasattr(self, 'validation_table'): + if not hasattr(self, "validation_table"): # Initialize the table on first call self.validation_table = wandb.Table(columns=columns) @@ -232,18 +242,20 @@ class ValidationGenerationsLogger: score: {sample[2]} """ - swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i+1}")) + swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}")) # Log to swanlab swanlab.log({"val/generations": swanlab_text_list}, step=step) def log_generations_to_mlflow(self, samples, step): """Log validation generation to mlflow as artifacts""" - #https://mlflow.org/docs/latest/api_reference/python_api/mlflow.html?highlight=log_artifact#mlflow.log_artifact + # https://mlflow.org/docs/latest/api_reference/python_api/mlflow.html?highlight=log_artifact#mlflow.log_artifact + + import json + import tempfile import mlflow - import 
tempfile - import json + try: with tempfile.TemporaryDirectory() as tmp_dir: validation_gen_step_file = Path(tmp_dir, f"val_step{step}.json") diff --git a/verl/utils/ulysses.py b/verl/utils/ulysses.py index 67aa150c8..f3a9e637c 100644 --- a/verl/utils/ulysses.py +++ b/verl/utils/ulysses.py @@ -16,11 +16,12 @@ Utilities for DeepSpeed Ulysses Sequence Parallelism. DeepSpeed Ulysses Paper: https://arxiv.org/abs/2309.14509 Inspired from: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py """ -from typing import Any, Optional, List, Tuple + +from typing import Any, Optional, Tuple import torch -from torch import Tensor import torch.distributed as dist +from torch import Tensor from torch.distributed import ProcessGroup _ULYSSES_SEQUENCE_PARALLEL_GROUP = None @@ -162,7 +163,6 @@ def all_gather_tensor(local_tensor: Tensor, group: Optional[dist.ProcessGroup] = class SeqAllToAll(torch.autograd.Function): - @staticmethod def forward( ctx: Any, @@ -195,14 +195,15 @@ class SeqAllToAll(torch.autograd.Function): class Gather(torch.autograd.Function): - @staticmethod - def forward(ctx: Any, - group: dist.ProcessGroup, - local_tensor: Tensor, - gather_dim: int, - grad_scaler: bool = True, - async_op=False) -> Tensor: + def forward( + ctx: Any, + group: dist.ProcessGroup, + local_tensor: Tensor, + gather_dim: int, + grad_scaler: bool = True, + async_op=False, + ) -> Tensor: ctx.group = group ctx.gather_dim = gather_dim ctx.grad_scaler = grad_scaler @@ -226,32 +227,40 @@ class Gather(torch.autograd.Function): def backward(ctx: Any, grad_output: Tensor) -> Any: if ctx.grad_scaler: grad_output = grad_output * ctx.sp_world_size - return (None, grad_output.split(ctx.part_size, - dim=ctx.gather_dim)[ctx.sp_rank].contiguous(), None, None, None, None) + return ( + None, + grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ctx.sp_rank].contiguous(), + None, + None, + None, + None, + ) -def gather_outpus_and_unpad(x: Tensor, - gather_dim: int, - unpad_dim: int = None, - padding_size: int = 0, - grad_scaler: bool = True, - group: Optional[dist.ProcessGroup] = None): +def gather_outpus_and_unpad( + x: Tensor, + gather_dim: int, + unpad_dim: int = None, + padding_size: int = 0, + grad_scaler: bool = True, + group: Optional[dist.ProcessGroup] = None, +): group = get_ulysses_sequence_parallel_group() if group is None else group sp_size = get_ulysses_sequence_parallel_world_size() if group == None: return x x = Gather.apply(group, x, gather_dim, grad_scaler) if unpad_dim is not None: - assert isinstance(padding_size, int), 'padding size is not given or is not an integer' + assert isinstance(padding_size, int), "padding size is not given or is not an integer" if padding_size == 0: return x x = _unpad_tensor(x, unpad_dim, padding_size) return x -def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor, - position_ids_rmpad: Optional[torch.Tensor] = None, - sp_size: int = 1): +def ulysses_pad_and_slice_inputs( + input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1 +): """ Pad and slice input_ids to be divisible by sp_size Pad position_ids to be divisible by sp_size. 
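For reference, a usage sketch of the `Tracking` facade whose quoting and import ordering were normalized in the `verl/utils/tracking.py` hunks above. The `console` backend keeps the sketch free of external services; the project/experiment names and the config dict are placeholders.

```python
# Sketch only: logging through the Tracking facade with the console backend.
from verl.utils.tracking import Tracking

tracker = Tracking(
    project_name="verl-demo",
    experiment_name="ruff-migration-smoke-test",
    default_backend="console",           # a single backend name is also accepted
    config={"trainer": {"total_epochs": 1}},
)
tracker.log(data={"actor/pg_loss": 0.12, "critic/vf_loss": 0.34}, step=1)
```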
@@ -268,7 +277,7 @@ def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor, Returns: torch.Tensor: padded and sliced input_ids torch.Tensor: padded and sliced position_ids - int: pad size + int: pad size """ if position_ids_rmpad is not None: assert position_ids_rmpad.size(0) == 1 @@ -290,5 +299,6 @@ def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor, def validate_ulysses_config(num_heads, ulysses_sequence_size): if ulysses_sequence_size > 1: - assert num_heads % ulysses_sequence_size == 0,\ + assert num_heads % ulysses_sequence_size == 0, ( f"num_heads ({num_heads}) must be divisible by ulysses sequence size({ulysses_sequence_size})" + ) diff --git a/verl/workers/actor/base.py b/verl/workers/actor/base.py index 144f0b90e..430c21858 100644 --- a/verl/workers/actor/base.py +++ b/verl/workers/actor/base.py @@ -14,17 +14,18 @@ """ The base class for Actor """ -from abc import ABC, abstractmethod -from typing import Iterable, Dict -from verl import DataProto +from abc import ABC, abstractmethod +from typing import Dict + import torch -__all__ = ['BasePPOActor'] +from verl import DataProto + +__all__ = ["BasePPOActor"] class BasePPOActor(ABC): - def __init__(self, config): """The base class for PPO actor diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 9be6641d2..0281b20dd 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -16,101 +16,107 @@ Single Process Actor """ import itertools -from typing import Iterable, Tuple +from typing import Tuple import torch +from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from verl import DataProto -from verl.trainer.ppo.core_algos import compute_policy_loss, kl_penalty, agg_loss -from verl.workers.actor import BasePPOActor -from verl.utils.py_functional import append_to_dict -from verl.utils.torch_functional import logprobs_from_logits, masked_mean -from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad -from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, kl_penalty +from verl.utils.py_functional import append_to_dict +from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.torch_functional import logprobs_from_logits +from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs +from verl.workers.actor import BasePPOActor -from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis - -__all__ = ['DataParallelPPOActor'] +__all__ = ["DataParallelPPOActor"] class DataParallelPPOActor(BasePPOActor): - def __init__(self, config, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None): """When optimizer is None, it is Reference Policy""" super().__init__(config) self.actor_module = actor_module self.actor_optimizer = actor_optimizer - self.use_remove_padding = self.config.get('use_remove_padding', False) - print(f'Actor use_remove_padding={self.use_remove_padding}') + self.use_remove_padding = self.config.get("use_remove_padding", False) + print(f"Actor use_remove_padding={self.use_remove_padding}") self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 
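The Ulysses helpers reformatted above enforce two simple constraints before per-rank slicing; a toy sketch written out with plain numbers, where the pad-size arithmetic is the obvious "round up to a multiple of `sp_size`" computation and not a quote of the function body:

```python
# (1) attention heads must divide evenly across the SP group;
# (2) the rmpad token dimension is padded up to a multiple of sp_size before slicing.
from verl.utils.ulysses import validate_ulysses_config

num_heads, sp_size = 32, 8
validate_ulysses_config(num_heads, sp_size)       # passes: 32 % 8 == 0

total_nnz = 1021                                  # tokens left after removing padding (toy value)
pad_size = (sp_size - total_nnz % sp_size) % sp_size
assert (total_nnz + pad_size) % sp_size == 0      # each rank then holds (total_nnz + pad_size) // sp_size tokens
```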
self.compute_entropy_from_logits = ( torch.compile(verl_F.entropy_from_logits, dynamic=True) - if self.config.get('use_torch_compile', True) # use torch compile by default - else verl_F.entropy_from_logits) + if self.config.get("use_torch_compile", True) # use torch compile by default + else verl_F.entropy_from_logits + ) - def _forward_micro_batch(self, - micro_batch, - temperature, - calculate_entropy=False) -> Tuple[torch.Tensor, torch.Tensor]: + def _forward_micro_batch( + self, micro_batch, temperature, calculate_entropy=False + ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Returns: + Returns: entropy: # (bs, response_len) log_probs: # (bs, response_len) """ - response_length = micro_batch['responses'].size(-1) + response_length = micro_batch["responses"].size(-1) multi_modal_inputs = {} - if 'multi_modal_inputs' in micro_batch: - for key in micro_batch['multi_modal_inputs'][0].keys(): - multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch['multi_modal_inputs']], - dim=0) + if "multi_modal_inputs" in micro_batch: + for key in micro_batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = torch.cat( + [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0 + ) - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): - input_ids = micro_batch['input_ids'] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape - attention_mask = micro_batch['attention_mask'] - position_ids = micro_batch['position_ids'] + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] entropy = None if position_ids.dim() == 3: # qwen2vl mrope position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), - attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary if position_ids.dim() == 3: - position_ids_rmpad = index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), - indices).transpose(0, 1).unsqueeze( - 1) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) + position_ids_rmpad = ( + index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + .transpose(0, 1) + .unsqueeze(1) + ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) else: - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), - indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), indices + ).transpose(0, 1) # for compute the log_prob input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) # pad and slice the inputs if sp > 1 if self.use_ulysses_sp: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \ - position_ids_rmpad, \ - sp_size=self.ulysses_sequence_parallel_size) - input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None, - self.ulysses_sequence_parallel_size) + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + ) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size + ) input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.actor_module(input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - **multi_modal_inputs, - use_cache=False) # prevent model thinks we are generating + output = self.actor_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + **multi_modal_inputs, + use_cache=False, + ) # prevent model thinks we are generating logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) logits_rmpad.div_(temperature) @@ -119,9 +125,9 @@ class DataParallelPPOActor(BasePPOActor): inplace_backward = True if calculate_entropy: inplace_backward = False - log_probs = logprobs_from_logits(logits=logits_rmpad, - labels=input_ids_rmpad_rolled, - inplace_backward=inplace_backward) + log_probs = logprobs_from_logits( + logits=logits_rmpad, labels=input_ids_rmpad_rolled, inplace_backward=inplace_backward + ) # compute entropy if calculate_entropy: @@ -132,36 +138,35 @@ class DataParallelPPOActor(BasePPOActor): # gather and unpad for the ulysses sp log_probs = gather_outpus_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size) if calculate_entropy: - entropy_rmpad = gather_outpus_and_unpad(entropy_rmpad, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size) + entropy_rmpad = gather_outpus_and_unpad( + entropy_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) # pad back to (bsz, seqlen) if calculate_entropy: - full_entropy = pad_input(hidden_states=entropy_rmpad.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen) - full_log_probs = pad_input(hidden_states=log_probs.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen) + full_entropy = pad_input( + hidden_states=entropy_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ) + full_log_probs = pad_input( + hidden_states=log_probs.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ) # only return response part: if calculate_entropy: - entropy = full_entropy.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length) - log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length) + entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) + log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) else: # not using rmpad and no ulysses sp - output = self.actor_module(input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - **multi_modal_inputs, - use_cache=False) # prevent model thinks we are 
generating + output = self.actor_module( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + use_cache=False, + ) # prevent model thinks we are generating logits = output.logits logits.div_(temperature) - logits = logits[:, -response_length - 1:-1, :] # (bsz, response_length, vocab_size) - log_probs = logprobs_from_logits(logits, micro_batch['responses']) + logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size) + log_probs = logprobs_from_logits(logits, micro_batch["responses"]) if calculate_entropy: entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) @@ -204,21 +209,21 @@ class DataParallelPPOActor(BasePPOActor): # set to eval self.actor_module.eval() - micro_batch_size = data.meta_info['micro_batch_size'] - temperature = data.meta_info['temperature'] # temperature must be in the data.meta_info to avoid slient error - use_dynamic_bsz = data.meta_info['use_dynamic_bsz'] + micro_batch_size = data.meta_info["micro_batch_size"] + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid slient error + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids'] + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] batch = data.select(batch_keys=select_keys).batch - has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys() + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() if has_multi_modal_inputs: num_micro_batches = data.batch.batch_size[0] // micro_batch_size - non_tensor_select_keys = ['multi_modal_inputs'] + non_tensor_select_keys = ["multi_modal_inputs"] micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) elif use_dynamic_bsz: # split using dynamic bsz - max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) else: micro_batches = batch.split(micro_batch_size) @@ -229,11 +234,11 @@ class DataParallelPPOActor(BasePPOActor): if isinstance(micro_batch, DataProto): micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} - response_mask = micro_batch['attention_mask'][:, -micro_batch['responses'].size(-1):] + response_mask = micro_batch["attention_mask"][:, -micro_batch["responses"].size(-1) :] with torch.no_grad(): - entropy, log_probs = self._forward_micro_batch(micro_batch, - temperature=temperature, - calculate_entropy=calculate_entropy) + entropy, log_probs = self._forward_micro_batch( + micro_batch, temperature=temperature, calculate_entropy=calculate_entropy + ) log_probs_lst.append(log_probs) if calculate_entropy: entropy_lst.append(entropy) @@ -254,19 +259,19 @@ class DataParallelPPOActor(BasePPOActor): # make sure we are in training mode self.actor_module.train() - temperature = data.meta_info['temperature'] # temperature must be in the data.meta_info to avoid slient error + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid slient error - select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages'] + select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"] if self.config.use_kl_loss: - 
select_keys.append('ref_log_prob') + select_keys.append("ref_log_prob") batch = data.select(batch_keys=select_keys).batch - has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys() + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() # Split to make minibatch iterator for updating the actor # See PPO paper for details. https://arxiv.org/abs/1707.06347 if has_multi_modal_inputs: num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size - non_tensor_select_keys = ['multi_modal_inputs'] + non_tensor_select_keys = ["multi_modal_inputs"] dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) else: dataloader = batch.split(self.config.ppo_mini_batch_size) @@ -277,14 +282,18 @@ class DataParallelPPOActor(BasePPOActor): # split batch into micro_batches mini_batch = data if has_multi_modal_inputs: - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + self.gradient_accumulation = ( + self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + ) num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) elif self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) else: - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + self.gradient_accumulation = ( + self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + ) # split batch into micro_batches micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) @@ -296,17 +305,21 @@ class DataParallelPPOActor(BasePPOActor): data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} else: data = data.to(torch.cuda.current_device()) # actor device is cpu when using offload - responses = data['responses'] + responses = data["responses"] response_length = responses.size(1) - attention_mask = data['attention_mask'] + attention_mask = data["attention_mask"] response_mask = attention_mask[:, -response_length:] - old_log_prob = data['old_log_probs'] - advantages = data['advantages'] + old_log_prob = data["old_log_probs"] + advantages = data["advantages"] clip_ratio = self.config.clip_ratio - clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio - clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio - clip_ratio_c = self.config.get('clip_ratio_c', 3.0) + clip_ratio_low = ( + self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio + ) + clip_ratio_high = ( + self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio + ) + clip_ratio_c = self.config.get("clip_ratio_c", 3.0) entropy_coeff = self.config.entropy_coeff loss_agg_mode = self.config.loss_agg_mode @@ -314,9 +327,9 @@ class DataParallelPPOActor(BasePPOActor): calculate_entropy = False if entropy_coeff != 0: calculate_entropy = True - entropy, log_prob = self._forward_micro_batch(micro_batch=data, - temperature=temperature, - calculate_entropy=calculate_entropy) + entropy, log_prob = self._forward_micro_batch( + micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy + ) pg_loss, 
pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( old_log_prob=old_log_prob, @@ -327,7 +340,8 @@ class DataParallelPPOActor(BasePPOActor): cliprange_low=clip_ratio_low, cliprange_high=clip_ratio_high, clip_ratio_c=clip_ratio_c, - loss_agg_mode=loss_agg_mode) + loss_agg_mode=loss_agg_mode, + ) if entropy_coeff != 0: entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) @@ -338,18 +352,18 @@ class DataParallelPPOActor(BasePPOActor): policy_loss = pg_loss if self.config.use_kl_loss: - ref_log_prob = data['ref_log_prob'] + ref_log_prob = data["ref_log_prob"] # compute kl loss - kld = kl_penalty(logprob=log_prob, - ref_logprob=ref_log_prob, - kl_penalty=self.config.kl_loss_type) - kl_loss = agg_loss(loss_mat=kld, - loss_mask=response_mask, - loss_agg_mode=self.config.loss_agg_mode) + kld = kl_penalty( + logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type + ) + kl_loss = agg_loss( + loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode + ) policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef - metrics['actor/kl_loss'] = kl_loss.detach().item() - metrics['actor/kl_coef'] = self.config.kl_loss_coef + metrics["actor/kl_loss"] = kl_loss.detach().item() + metrics["actor/kl_coef"] = self.config.kl_loss_coef if self.config.use_dynamic_bsz: # relative to the dynamic bsz @@ -359,15 +373,15 @@ class DataParallelPPOActor(BasePPOActor): loss.backward() data = { - 'actor/pg_loss': pg_loss.detach().item(), - 'actor/pg_clipfrac': pg_clipfrac.detach().item(), - 'actor/ppo_kl': ppo_kl.detach().item(), - 'actor/pg_clipfrac_lower': pg_clipfrac_lower.detach().item(), + "actor/pg_loss": pg_loss.detach().item(), + "actor/pg_clipfrac": pg_clipfrac.detach().item(), + "actor/ppo_kl": ppo_kl.detach().item(), + "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), } append_to_dict(metrics, data) grad_norm = self._optimizer_step() - data = {'actor/grad_norm': grad_norm.detach().item()} + data = {"actor/grad_norm": grad_norm.detach().item()} append_to_dict(metrics, data) self.actor_optimizer.zero_grad() return metrics diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 6a8250906..58e74826e 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -19,41 +19,42 @@ In megatron actor, the differences are: Note that our model doesn't have to be `MegatronModule` because we don't share embedding in the last layer """ -import importlib from functools import partial -from packaging.version import Version -from typing import Iterable, Dict +from typing import Dict, Iterable import torch -from torch import nn import torch.distributed -from megatron.core.optimizer import OptimizerConfig from megatron.core import parallel_state as mpu -from megatron.core import ModelParallelConfig -from verl.utils.megatron_utils import get_model_config -from megatron.core.pipeline_parallel import get_forward_backward_func - from megatron.core.distributed import finalize_model_grads + # from megatron.core.optimizer import DistributedOptimizer - from megatron.core.optimizer import DistributedOptimizer - +from megatron.core.pipeline_parallel import get_forward_backward_func from omegaconf import OmegaConf -from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits -from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator) -from verl import DataProto -from 
verl.trainer.ppo.core_algos import compute_policy_loss, kl_penalty, agg_loss -from verl.workers.actor import BasePPOActor -from verl.utils.py_functional import append_to_dict -from verl.utils.torch_functional import logprobs_from_logits, masked_mean, broadcast_dict_tensor, split_dict_tensor_into_batches +from torch import nn -__all__ = ['MegatronPPOActor'] +from verl import DataProto +from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, kl_penalty +from verl.utils.megatron.pipeline_parallel import compute_transformers_input_shapes, make_batch_generator +from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits +from verl.utils.megatron_utils import get_model_config +from verl.utils.py_functional import append_to_dict +from verl.utils.torch_functional import broadcast_dict_tensor, split_dict_tensor_into_batches +from verl.workers.actor import BasePPOActor + +__all__ = ["MegatronPPOActor"] class MegatronPPOActor(BasePPOActor): - - def __init__(self, config, model_config, hf_config, tf_config, actor_module: nn.ModuleList, - actor_optimizer: DistributedOptimizer): + def __init__( + self, + config, + model_config, + hf_config, + tf_config, + actor_module: nn.ModuleList, + actor_optimizer: DistributedOptimizer, + ): """MeagtronPPOActor class. This class implements the simple PPO logics when the model is built with Megatron. Args: @@ -108,17 +109,19 @@ class MegatronPPOActor(BasePPOActor): self.actor_module = actor_module self.actor_optimizer: DistributedOptimizer = actor_optimizer - self.optimizer_step_args = OmegaConf.create({ - 'skip_grad': None, - 'overlap_dp_param_comm': False, - 'overlap_dp_grad_comm': False, - 'gradient_accumulation_steps': 1, - 'sequence_parallel': self.tf_config.sequence_parallel, - 'DDP_impl': 'local', - 'layernorm_allreduce_bucket_threshold': 0, - 'pipeline_model_parallel_split_rank': None, - 'reduce_grads_use_alltoall': False - }) + self.optimizer_step_args = OmegaConf.create( + { + "skip_grad": None, + "overlap_dp_param_comm": False, + "overlap_dp_grad_comm": False, + "gradient_accumulation_steps": 1, + "sequence_parallel": self.tf_config.sequence_parallel, + "DDP_impl": "local", + "layernorm_allreduce_bucket_threshold": 0, + "pipeline_model_parallel_split_rank": None, + "reduce_grads_use_alltoall": False, + } + ) config = get_model_config(self.actor_module[0]) print(config) @@ -126,11 +129,11 @@ class MegatronPPOActor(BasePPOActor): def _validate_config(self, config) -> None: """Validate config options not implemented for Megatron backend""" - assert config.get('ulysses_sequence_parallel_size', 1) == 1 - if config.get('shuffle', False): - assert config.data_loader_seed is not None, f'If shuffle dataloader, seed must be manually set' + assert config.get("ulysses_sequence_parallel_size", 1) == 1 + if config.get("shuffle", False): + assert config.data_loader_seed is not None, "If shuffle dataloader, seed must be manually set" if config.megatron.tensor_model_parallel_size == 1: - print(f'[Warining] Because actor tp size == 1, set sp to False') + print("[Warining] Because actor tp size == 1, set sp to False") config.megatron.sequence_parallel = False self.config = config @@ -155,61 +158,64 @@ class MegatronPPOActor(BasePPOActor): data.batch = data.batch.contiguous() def compute_logprobs_fn(output, data): - response = data['responses'] + response = data["responses"] response_length = response.size(1) logits = output - logits = logits[:, -response_length - 1:-1].contiguous() + logits = logits[:, 
-response_length - 1 : -1].contiguous() log_probs = vocab_parallel_log_probs_from_logits(logits, response) - return {'log_probs': log_probs} + return {"log_probs": log_probs} # We make recompute_old_log_prob by default here. # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be handled by user outside - recompute_old_log_prob = self.config.get('recompute_old_log_prob', True) + recompute_old_log_prob = self.config.get("recompute_old_log_prob", True) entropys = torch.Tensor() if recompute_old_log_prob: - select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids'] + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] batch = data.select(batch_keys=select_keys).batch - input_ids = batch['input_ids'] + input_ids = batch["input_ids"] batch_size = input_ids.size(0) - response = batch['responses'] + response = batch["responses"] response_length = response.size(1) with torch.no_grad(): - output = self.forward_backward_batch(data, - forward_only=True, - post_process_fn=compute_logprobs_fn, - calculate_entropy=calculate_entropy) + output = self.forward_backward_batch( + data, forward_only=True, post_process_fn=compute_logprobs_fn, calculate_entropy=calculate_entropy + ) if mpu.is_pipeline_last_stage(ignore_virtual=True): # only on last rank. It should be on every tp rank if calculate_entropy: - log_probs = torch.cat([o[0]['log_probs'] for o in output], dim=0) # (bs, seq_size) + log_probs = torch.cat([o[0]["log_probs"] for o in output], dim=0) # (bs, seq_size) else: - log_probs = torch.cat([o['log_probs'] for o in output], dim=0) # (bs, seq_size) + log_probs = torch.cat([o["log_probs"] for o in output], dim=0) # (bs, seq_size) log_probs = log_probs.to(torch.float32) else: - log_probs = torch.empty(size=(batch_size, response_length), - dtype=torch.float32, - device=input_ids.device) + log_probs = torch.empty( + size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device + ) # broadcast across pp ranks - torch.distributed.broadcast(tensor=log_probs, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group(), - async_op=False) + torch.distributed.broadcast( + tensor=log_probs, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + async_op=False, + ) if calculate_entropy: # Note that o[0] is metrics, o[1] is entropy if mpu.is_pipeline_last_stage(ignore_virtual=True): entropys = torch.cat([o[1] for o in output], dim=0) entropys = entropys.to(torch.float32) else: - entropys = torch.empty(size=(batch_size, response_length), - dtype=torch.float32, - device=input_ids.device) + entropys = torch.empty( + size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device + ) # broadcast across pp ranks - torch.distributed.broadcast(tensor=entropys, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group(), - async_op=False) + torch.distributed.broadcast( + tensor=entropys, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + async_op=False, + ) # add empty cache after each compute torch.cuda.empty_cache() @@ -238,20 +244,20 @@ class MegatronPPOActor(BasePPOActor): Returns: """ - select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages'] + select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"] if self.config.use_kl_loss: - 
select_keys.append('ref_log_prob') + select_keys.append("ref_log_prob") data = data.select(batch_keys=select_keys) - return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size, - epochs=self.config.ppo_epochs, - seed=self.config.data_loader_seed, - dataloader_kwargs={'shuffle': self.config.shuffle}) + return data.make_iterator( + mini_batch_size=self.config.ppo_mini_batch_size, + epochs=self.config.ppo_epochs, + seed=self.config.data_loader_seed, + dataloader_kwargs={"shuffle": self.config.shuffle}, + ) - def forward_backward_batch(self, - data: DataProto, - forward_only=False, - post_process_fn=None, - calculate_entropy=False): + def forward_backward_batch( + self, data: DataProto, forward_only=False, post_process_fn=None, calculate_entropy=False + ): """ We assume: - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input @@ -259,25 +265,27 @@ class MegatronPPOActor(BasePPOActor): """ # broadcast from last pp rank to all other pp ranks # TODO: actually, we just need to control the sampling order. - broadcast_dict_tensor(data.batch, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group()) + broadcast_dict_tensor( + data.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group() + ) # split into micro-batches - data.batch['attention_mask'] = data.batch['attention_mask'].to(bool) + data.batch["attention_mask"] = data.batch["attention_mask"].to(bool) - if data.meta_info.get('micro_batch_size', None) is not None: - batch_size = data.meta_info['micro_batch_size'] + if data.meta_info.get("micro_batch_size", None) is not None: + batch_size = data.meta_info["micro_batch_size"] else: batch_size = self.config.ppo_micro_batch_size_per_gpu batches = split_dict_tensor_into_batches(data.batch, batch_size=batch_size) # compute input shapes for pp stages - input_shapes = compute_transformers_input_shapes(batches, - meta_info={ - 'sequence_parallel': self.tf_config.sequence_parallel, - 'hidden_size': self.model_config.hidden_size - }) + input_shapes = compute_transformers_input_shapes( + batches, + meta_info={ + "sequence_parallel": self.tf_config.sequence_parallel, + "hidden_size": self.model_config.hidden_size, + }, + ) n_micro_batch = len(batches) - seq_len = batches[0]['input_ids'].shape[1] + seq_len = batches[0]["input_ids"].shape[1] forward_backward_func = get_forward_backward_func() @@ -287,47 +295,49 @@ class MegatronPPOActor(BasePPOActor): metrics = {} if forward_only: if post_process_fn is None: - metrics['logits'] = output + metrics["logits"] = output else: stats = post_process_fn(output, data) metrics.update(stats) if not calculate_entropy: return 1.0, metrics - responses = data['responses'] + responses = data["responses"] response_length = responses.size(1) - attention_mask = data['attention_mask'] + attention_mask = data["attention_mask"] response_mask = attention_mask[:, -response_length:] loss_agg_mode = self.config.loss_agg_mode # compute policy loss logits = output - logits = logits[:, -response_length - 1:-1].contiguous() + logits = logits[:, -response_length - 1 : -1].contiguous() ret_entropy = None if not forward_only: - old_log_prob = data['old_log_probs'] - advantages = data['advantages'] + old_log_prob = data["old_log_probs"] + advantages = data["advantages"] - clip_ratio = meta_info['clip_ratio'] + clip_ratio = meta_info["clip_ratio"] clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio 
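Both the FSDP and Megatron actors above funnel into the same `compute_policy_loss` call. A toy invocation with placeholder tensors; the clip-ratio fallbacks mirror the config handling in the surrounding hunks, and the `loss_agg_mode` string is an assumed value rather than one read from this diff.

```python
# Toy call with placeholder tensors; keyword names follow the call sites above.
import torch
from verl.trainer.ppo.core_algos import compute_policy_loss

bs, resp_len = 2, 5
old_log_prob = torch.randn(bs, resp_len)
log_prob = old_log_prob + 0.05 * torch.randn(bs, resp_len)
advantages = torch.randn(bs, resp_len)
response_mask = torch.ones(bs, resp_len)

clip_ratio = 0.2  # cliprange_low / cliprange_high fall back to this when unset in the config
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
    old_log_prob=old_log_prob,
    log_prob=log_prob,
    advantages=advantages,
    response_mask=response_mask,
    cliprange=clip_ratio,
    cliprange_low=clip_ratio,
    cliprange_high=clip_ratio,
    clip_ratio_c=3.0,
    loss_agg_mode="token-mean",  # assumed aggregation mode
)
```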
clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio - clip_ratio_c = meta_info['clip_ratio_c'] + clip_ratio_c = meta_info["clip_ratio_c"] log_prob = vocab_parallel_log_probs_from_logits(logits, responses) - pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=response_mask, - cliprange=clip_ratio, - cliprange_low=clip_ratio_low, - cliprange_high=clip_ratio_high, - clip_ratio_c=clip_ratio_c, - loss_agg_mode=loss_agg_mode) + pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + cliprange=clip_ratio, + cliprange_low=clip_ratio_low, + cliprange_high=clip_ratio_high, + clip_ratio_c=clip_ratio_c, + loss_agg_mode=loss_agg_mode, + ) policy_loss = pg_loss if calculate_entropy: entropy = vocab_parallel_entropy(logits) if not forward_only: entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) - entropy_coeff = meta_info['entropy_coeff'] + entropy_coeff = meta_info["entropy_coeff"] policy_loss = pg_loss - entropy_coeff * entropy_loss else: ret_entropy = entropy @@ -337,46 +347,47 @@ class MegatronPPOActor(BasePPOActor): policy_loss = 1.0 else: if self.config.use_kl_loss: - ref_log_prob = data['ref_log_prob'] + ref_log_prob = data["ref_log_prob"] # compute kl loss kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type) kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode) policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef - metrics['actor/kl_loss'] = kl_loss.detach().item() - metrics['actor/kl_coef'] = self.config.kl_loss_coef + metrics["actor/kl_loss"] = kl_loss.detach().item() + metrics["actor/kl_coef"] = self.config.kl_loss_coef # return loss and stats - stats.update({ - 'actor/pg_loss': pg_loss.detach().item(), - 'actor/pg_clipfrac': pg_clipfrac.detach().item(), - 'actor/ppo_kl': ppo_kl.detach().item(), - 'actor/pg_clipfrac_lower': pg_clipfrac_lower.detach().item() - }) + stats.update( + { + "actor/pg_loss": pg_loss.detach().item(), + "actor/pg_clipfrac": pg_clipfrac.detach().item(), + "actor/ppo_kl": ppo_kl.detach().item(), + "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), + } + ) append_to_dict(metrics, stats) return policy_loss, [metrics, ret_entropy] def forward_step(batch_iter, model): batch = next(batch_iter) - input_ids = batch['input_ids'] - attention_mask = batch['attention_mask'] - position_ids = batch['position_ids'] + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + position_ids = batch["position_ids"] from verl.models.mcore import get_mcore_forward_fn + forward_fn = get_mcore_forward_fn(self.hf_config) - output = forward_fn(model, - input_ids, - attention_mask, - position_ids, - sequence_parallel=self.tf_config.sequence_parallel) + output = forward_fn( + model, input_ids, attention_mask, position_ids, sequence_parallel=self.tf_config.sequence_parallel + ) if forward_only: meta_info = None else: - clip_ratio_c = self.config.get('clip_ratio_c', 3.0) + clip_ratio_c = self.config.get("clip_ratio_c", 3.0) meta_info = { - 'clip_ratio': self.config.clip_ratio, - 'entropy_coeff': self.config.entropy_coeff, - 'clip_ratio_c': clip_ratio_c + "clip_ratio": self.config.clip_ratio, + "entropy_coeff": self.config.entropy_coeff, + 
"clip_ratio_c": clip_ratio_c, } return output, partial(loss_func, data=batch, meta_info=meta_info) diff --git a/verl/workers/critic/base.py b/verl/workers/critic/base.py index 9d1055df4..8201758f3 100644 --- a/verl/workers/critic/base.py +++ b/verl/workers/critic/base.py @@ -14,17 +14,17 @@ """ Base class for a critic """ + from abc import ABC, abstractmethod import torch from verl import DataProto -__all__ = ['BasePPOCritic'] +__all__ = ["BasePPOCritic"] class BasePPOCritic(ABC): - def __init__(self, config): super().__init__() self.config = config diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index d100425b3..85abd96df 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -14,102 +14,107 @@ """ Implement a multiprocess PPOCritic """ + import itertools -from typing import Iterable import torch import torch.distributed +from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input from torch import nn, optim - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from verl import DataProto from verl.trainer.ppo import core_algos -from verl.workers.critic import BasePPOCritic from verl.utils.py_functional import append_to_dict +from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import masked_mean -from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad -from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx +from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs +from verl.workers.critic import BasePPOCritic -from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis - -__all__ = ['DataParallelPPOCritic'] +__all__ = ["DataParallelPPOCritic"] class DataParallelPPOCritic(BasePPOCritic): - def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Optimizer): super().__init__(config=config) self.critic_module = critic_module self.critic_optimizer = critic_optimizer - self.use_remove_padding = self.config.model.get('use_remove_padding', False) - print(f'Critic use_remove_padding={self.use_remove_padding}') + self.use_remove_padding = self.config.model.get("use_remove_padding", False) + print(f"Critic use_remove_padding={self.use_remove_padding}") - self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) def _forward_micro_batch(self, micro_batch): - response_length = micro_batch['responses'].size(-1) + response_length = micro_batch["responses"].size(-1) multi_modal_inputs = {} - if 'multi_modal_inputs' in micro_batch: - for key in micro_batch['multi_modal_inputs'][0].keys(): - multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch['multi_modal_inputs']], - dim=0) + if "multi_modal_inputs" in micro_batch: + for key in micro_batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = torch.cat( + [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0 + ) - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): - input_ids = micro_batch['input_ids'] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + input_ids = micro_batch["input_ids"] batch, seqlen = input_ids.shape - attention_mask = micro_batch['attention_mask'] - position_ids = micro_batch['position_ids'] + attention_mask = micro_batch["attention_mask"] + 
position_ids = micro_batch["position_ids"] if position_ids.dim() == 3: # qwen2vl mrope position_ids = position_ids.transpose(0, 1) if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), - attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary if position_ids.dim() == 3: - position_ids_rmpad = index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), - indices).transpose(0, 1).unsqueeze( - 1) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) + position_ids_rmpad = ( + index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + .transpose(0, 1) + .unsqueeze(1) + ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) else: - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), - indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \ - position_ids_rmpad, \ - sp_size=self.ulysses_sequence_parallel_size) + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + ) # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.critic_module(input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - **multi_modal_inputs, - use_cache=False) # prevent model thinks we are generating + output = self.critic_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + **multi_modal_inputs, + use_cache=False, + ) # prevent model thinks we are generating values_rmpad = output.logits values_rmpad = values_rmpad.squeeze(0) # (total_nnz) # gather output if sp > 1 if self.ulysses_sequence_parallel_size > 1: - values_rmpad = gather_outpus_and_unpad(values_rmpad, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size) + values_rmpad = gather_outpus_and_unpad( + values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) # pad it back values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1) - values = values[:, -response_length - 1:-1] + values = values[:, -response_length - 1 : -1] else: - output = self.critic_module(input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - **multi_modal_inputs, - use_cache=False) # prevent model thinks we are generating + output = self.critic_module( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + use_cache=False, + ) # prevent model thinks we are generating values = output.logits - values = values[:, -response_length - 1:-1].squeeze(-1) + values = values[:, -response_length - 1 : -1].squeeze(-1) return values def _optimizer_step(self): @@ -130,19 +135,19 @@ class DataParallelPPOCritic(BasePPOCritic): def compute_values(self, data: DataProto) -> torch.Tensor: self.critic_module.eval() - micro_batch_size = data.meta_info['micro_batch_size'] - select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids'] + micro_batch_size = 
data.meta_info["micro_batch_size"] + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] batch = data.select(batch_keys=select_keys).batch - use_dynamic_bsz = data.meta_info['use_dynamic_bsz'] - has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys() + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() if has_multi_modal_inputs: num_micro_batches = data.batch.batch_size[0] // micro_batch_size - non_tensor_select_keys = ['multi_modal_inputs'] + non_tensor_select_keys = ["multi_modal_inputs"] micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) elif use_dynamic_bsz: # split using dynamic bsz - max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) else: micro_batches = batch.split(micro_batch_size) @@ -156,10 +161,10 @@ class DataParallelPPOCritic(BasePPOCritic): values = self._forward_micro_batch(micro_batch) values_lst.append(values) values = torch.concat(values_lst, dim=0) - responses = data.batch['responses'] - attention_mask = data.batch['attention_mask'] + responses = data.batch["responses"] + attention_mask = data.batch["attention_mask"] response_length = responses.size(1) - values = values * attention_mask[:, -response_length - 1:-1] + values = values * attention_mask[:, -response_length - 1 : -1] if use_dynamic_bsz: indices = list(itertools.chain.from_iterable(indices)) @@ -174,15 +179,15 @@ class DataParallelPPOCritic(BasePPOCritic): self.critic_module.train() metrics = {} - select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns'] + select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"] batch = data.select(batch_keys=select_keys).batch - has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys() + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() # Split to make minibatch iterator for updating the actor # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 if has_multi_modal_inputs: num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size - non_tensor_select_keys = ['multi_modal_inputs'] + non_tensor_select_keys = ["multi_modal_inputs"] dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) else: dataloader = batch.split(self.config.ppo_mini_batch_size) @@ -199,35 +204,39 @@ class DataParallelPPOCritic(BasePPOCritic): micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) else: micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + self.gradient_accumulation = ( + self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + ) self.critic_optimizer.zero_grad() for data in micro_batches: - #Support all devices + # Support all devices if isinstance(data, DataProto): data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} else: data = data.to(torch.cuda.current_device()) # critic device is cpu when using offload - input_ids = data['input_ids'] - responses = data['responses'] - attention_mask = data['attention_mask'] - position_ids = data['position_ids'] - values = data['values'] - returns = data['returns'] + input_ids = data["input_ids"] + responses = data["responses"] + attention_mask = data["attention_mask"] + position_ids = data["position_ids"] + values = data["values"] + returns = data["returns"] response_length = responses.size(1) - response_mask = attention_mask[:, -response_length - 1:-1] + response_mask = attention_mask[:, -response_length - 1 : -1] vpreds = self._forward_micro_batch(data) # assert not torch.any(torch.isnan(vpreds)).item() - vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds, - values=values, - returns=returns, - response_mask=response_mask, - cliprange_value=self.config.cliprange_value) + vf_loss, vf_clipfrac = core_algos.compute_value_loss( + vpreds=vpreds, + values=values, + returns=returns, + response_mask=response_mask, + cliprange_value=self.config.cliprange_value, + ) if self.config.use_dynamic_bsz: # relative to the dynamic bsz loss = vf_loss * (len(data) / self.config.ppo_mini_batch_size) @@ -237,15 +246,15 @@ class DataParallelPPOCritic(BasePPOCritic): loss.backward() data = { - 'critic/vf_loss': vf_loss.detach().item(), - 'critic/vf_clipfrac': vf_clipfrac.detach().item(), - 'critic/vpred_mean': masked_mean(vpreds, response_mask).detach().item(), + "critic/vf_loss": vf_loss.detach().item(), + "critic/vf_clipfrac": vf_clipfrac.detach().item(), + "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(), } append_to_dict(metrics, data) grad_norm = self._optimizer_step() - data = {'critic/grad_norm': grad_norm.detach().item()} + data = {"critic/grad_norm": grad_norm.detach().item()} append_to_dict(metrics, data) self.critic_optimizer.zero_grad() return metrics diff --git a/verl/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py index 9f5b08b62..dbd06d6d4 100644 --- a/verl/workers/critic/megatron_critic.py +++ b/verl/workers/critic/megatron_critic.py @@ -15,35 +15,36 @@ Implement a multiprocess PPOCritic """ -import importlib from functools import partial -from packaging.version import Version from typing import Iterable import torch import torch.distributed +from megatron.core import parallel_state as mpu +from megatron.core.optimizer import DistributedOptimizer, 
OptimizerConfig +from megatron.core.pipeline_parallel import get_forward_backward_func from omegaconf import OmegaConf from torch import nn from verl import DataProto from verl.trainer.ppo import core_algos -from verl.workers.critic import BasePPOCritic -from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator) +from verl.utils.megatron.pipeline_parallel import compute_transformers_input_shapes, make_batch_generator from verl.utils.py_functional import append_to_dict -from verl.utils.torch_dtypes import PrecisionType -from verl.utils.torch_functional import masked_mean, broadcast_dict_tensor, split_dict_tensor_into_batches -from verl.utils.megatron import sequence_parallel as sp_utils -from megatron.core.optimizer import OptimizerConfig - -from megatron.core import parallel_state as mpu -from megatron.core.pipeline_parallel import get_forward_backward_func -from megatron.core.optimizer import DistributedOptimizer +from verl.utils.torch_functional import broadcast_dict_tensor, masked_mean, split_dict_tensor_into_batches +from verl.workers.critic import BasePPOCritic class MegatronPPOCritic(BasePPOCritic): - - def __init__(self, config, model_config, hf_config, tf_config, critic_module: nn.ModuleList, - critic_optimizer: DistributedOptimizer, critic_optimizer_config: OptimizerConfig): + def __init__( + self, + config, + model_config, + hf_config, + tf_config, + critic_module: nn.ModuleList, + critic_optimizer: DistributedOptimizer, + critic_optimizer_config: OptimizerConfig, + ): super().__init__(config=config) self._validate_config(config) self.model_config = model_config @@ -55,51 +56,55 @@ class MegatronPPOCritic(BasePPOCritic): self.critic_optimizer_config = critic_optimizer_config # we create a separate nametuple for optimizer step so that global args won't affect it. 
- self.optimizer_step_args = OmegaConf.create({ - 'skip_grad': None, - 'overlap_dp_param_comm': False, - 'overlap_dp_grad_comm': False, - 'gradient_accumulation_steps': 1, - 'sequence_parallel': self.tf_config.sequence_parallel, - 'DDP_impl': 'local', - 'layernorm_allreduce_bucket_threshold': 0, - 'pipeline_model_parallel_split_rank': None, - 'reduce_grads_use_alltoall': False - }) + self.optimizer_step_args = OmegaConf.create( + { + "skip_grad": None, + "overlap_dp_param_comm": False, + "overlap_dp_grad_comm": False, + "gradient_accumulation_steps": 1, + "sequence_parallel": self.tf_config.sequence_parallel, + "DDP_impl": "local", + "layernorm_allreduce_bucket_threshold": 0, + "pipeline_model_parallel_split_rank": None, + "reduce_grads_use_alltoall": False, + } + ) def _validate_config(self, config) -> None: """Validate config options not implemented for Megatron backend""" - assert config.get('ulysses_sequence_parallel_size', 1) == 1 + assert config.get("ulysses_sequence_parallel_size", 1) == 1 if config.shuffle: - assert config.data_loader_seed is not None, f'If shuffle dataloader, seed must be manually set' + assert config.data_loader_seed is not None, "If shuffle dataloader, seed must be manually set" if config.megatron.tensor_model_parallel_size == 1: - print(f'[Warining] Because critic tp size == 1, set sp to False') + print("[Warning] Because critic tp size == 1, set sp to False") config.megatron.sequence_parallel = False self.config = config def compute_values(self, data: DataProto) -> DataProto: # data.batch = data.batch.to(self.critic_module.module.device) - responses = data.batch['responses'] - attention_mask = data.batch['attention_mask'] + responses = data.batch["responses"] + attention_mask = data.batch["attention_mask"] response_length = responses.size(1) with torch.no_grad(): output = self.forward_backward_batch(data=data, forward_only=True) if mpu.is_pipeline_last_stage(ignore_virtual=True): # only on last rank.
It should be on every tp rank - values = torch.cat([o['vpreds'] for o in output], dim=0) # (bs, seq_size, vocal_size) + values = torch.cat([o["vpreds"] for o in output], dim=0) # (bs, seq_size, vocab_size) values = values.to(torch.float32) else: values = torch.empty_like(attention_mask, dtype=torch.float32) # each tp ranks should contain the same value values = values * attention_mask - values = values[:, -response_length - 1:-1] + values = values[:, -response_length - 1 : -1] values = values.contiguous() # sync among pp ranks - torch.distributed.broadcast(tensor=values, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group()) + torch.distributed.broadcast( + tensor=values, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + ) # add empty cache after each compute torch.cuda.empty_cache() @@ -107,42 +112,46 @@ class MegatronPPOCritic(BasePPOCritic): return values def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: - select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns'] + select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"] data = data.select(batch_keys=select_keys) - return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size, - epochs=self.config.ppo_epochs, - seed=self.config.data_loader_seed, - dataloader_kwargs={'shuffle': self.config.shuffle}) + return data.make_iterator( + mini_batch_size=self.config.ppo_mini_batch_size, + epochs=self.config.ppo_epochs, + seed=self.config.data_loader_seed, + dataloader_kwargs={"shuffle": self.config.shuffle}, + ) def forward_backward_batch(self, data: DataProto, forward_only=False): # broadcast from last pp rank to all other pp ranks data.batch = data.batch.contiguous() - broadcast_dict_tensor(data.batch, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group()) + broadcast_dict_tensor( + data.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group() + ) # split into micro-batches - data.batch['attention_mask'] = data.batch['attention_mask'].to(bool) + data.batch["attention_mask"] = data.batch["attention_mask"].to(bool) batches = split_dict_tensor_into_batches(data.batch, batch_size=self.config.ppo_micro_batch_size_per_gpu) n_micro_batch = len(batches) - seq_len = batches[0]['input_ids'].shape[1] + seq_len = batches[0]["input_ids"].shape[1] # compute input shapes for pp stages - input_shapes = compute_transformers_input_shapes(batches, - meta_info={ - 'sequence_parallel': self.tf_config.sequence_parallel, - 'hidden_size': self.model_config.hidden_size - }) + input_shapes = compute_transformers_input_shapes( + batches, + meta_info={ + "sequence_parallel": self.tf_config.sequence_parallel, + "hidden_size": self.model_config.hidden_size, + }, + ) forward_backward_func = get_forward_backward_func() def loss_func(output, data, meta_info): if forward_only: - return 1.0, {'vpreds': output} + return 1.0, {"vpreds": output} - responses = data['responses'] - attention_mask = data['attention_mask'] - values = data['values'] - returns = data['returns'] + responses = data["responses"] + attention_mask = data["attention_mask"] + values = data["values"] + returns = data["returns"] response_length = responses.size(1) response_mask = attention_mask[:, -response_length:] @@ -150,35 +159,40 @@ class MegatronPPOCritic(BasePPOCritic): cliprange_value =
self.config.cliprange_value vpreds = output # (bs, sequence_length) - vpreds = vpreds[:, -response_length - 1:-1] + vpreds = vpreds[:, -response_length - 1 : -1] - vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds, - values=values, - returns=returns, - response_mask=response_mask, - cliprange_value=cliprange_value) + vf_loss, vf_clipfrac = core_algos.compute_value_loss( + vpreds=vpreds, + values=values, + returns=returns, + response_mask=response_mask, + cliprange_value=cliprange_value, + ) stats = { - 'critic/vf_loss': vf_loss.detach().item(), - 'critic/vf_clipfrac': vf_clipfrac.detach().item(), - 'critic/vpred_mean': masked_mean(vpreds, response_mask).detach().item(), + "critic/vf_loss": vf_loss.detach().item(), + "critic/vf_clipfrac": vf_clipfrac.detach().item(), + "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(), } return vf_loss, stats def forward_step(batch_iter, model): batch = next(batch_iter) - input_ids = batch['input_ids'] - attention_mask = batch['attention_mask'] - position_ids = batch['position_ids'] + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + position_ids = batch["position_ids"] from verl.models.mcore import get_mcore_forward_fn + forward_fn = get_mcore_forward_fn(self.hf_config) - output = forward_fn(model, - input_ids, - attention_mask, - position_ids, - sequence_parallel=self.tf_config.sequence_parallel, - value_model=True) + output = forward_fn( + model, + input_ids, + attention_mask, + position_ids, + sequence_parallel=self.tf_config.sequence_parallel, + value_model=True, + ) return output, partial(loss_func, data=batch, meta_info={}) diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 5ab5724e4..979186bf5 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -18,46 +18,53 @@ The main entry point to run the PPO algorithm import logging import os import warnings -import psutil +import psutil import torch import torch.distributed -from torch.distributed.device_mesh import init_device_mesh -import verl.utils.torch_functional as verl_F +from codetiming import Timer from omegaconf import DictConfig, open_dict +from torch.distributed.device_mesh import init_device_mesh + +import verl.utils.torch_functional as verl_F from verl import DataProto from verl.single_controller.base import Worker -from verl.single_controller.base.decorator import register, Dispatch -from verl.utils import hf_tokenizer, hf_processor +from verl.single_controller.base.decorator import Dispatch, register +from verl.utils import hf_processor, hf_tokenizer +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.utils.debug import log_gpu_memory_usage +from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local -from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager -from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_model_to_cpu, load_fsdp_optimizer, \ - load_fsdp_model_to_gpu +from verl.utils.fsdp_utils import ( + get_fsdp_wrap_policy, + get_init_weight_context_manager, + init_fn, + load_fsdp_model_to_gpu, + load_fsdp_optimizer, + offload_fsdp_model_to_cpu, + offload_fsdp_optimizer, +) from verl.utils.import_utils import import_external_libs from verl.utils.model import compute_position_id_with_mask -from verl.utils.flops_counter import FlopsCounter -from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from 
verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager -from codetiming import Timer - logger = logging.getLogger(__file__) -logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) +logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) def create_device_mesh(world_size, fsdp_size): if fsdp_size < 0 or fsdp_size >= world_size: - device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp']) + device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) else: - device_mesh = init_device_mesh('cuda', - mesh_shape=(world_size // fsdp_size, fsdp_size), - mesh_dim_names=['ddp', 'fsdp']) + device_mesh = init_device_mesh( + "cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] + ) return device_mesh def get_sharding_strategy(device_mesh): from torch.distributed.fsdp import ShardingStrategy + if device_mesh.ndim == 1: sharding_strategy = ShardingStrategy.FULL_SHARD elif device_mesh.ndim == 2: @@ -77,6 +84,7 @@ class ActorRolloutRefWorker(Worker): super().__init__() self.config = config import torch.distributed + if not torch.distributed.is_initialized(): torch.distributed.init_process_group() @@ -87,76 +95,85 @@ class ActorRolloutRefWorker(Worker): # build device mesh for Ulysses Sequence Parallel self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.actor.get('ulysses_sequence_parallel_size', 1) + self.ulysses_sequence_parallel_size = self.config.actor.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh('cuda', - mesh_shape=(dp, self.ulysses_sequence_parallel_size), - mesh_dim_names=['dp', 'sp']) + self.ulysses_device_mesh = init_device_mesh( + "cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) self.role = role - assert self.role in ['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref'] + assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] - self._is_actor = self.role in ['actor', 'actor_rollout', 'actor_rollout_ref'] - self._is_rollout = self.role in ['rollout', 'actor_rollout', 'actor_rollout_ref'] - self._is_ref = self.role in ['ref', 'actor_rollout_ref'] + self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] + self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] + self._is_ref = self.role in ["ref", "actor_rollout_ref"] self._is_offload_param = False self._is_offload_optimizer = False if self._is_actor: - self._is_offload_param = self.config.actor.fsdp_config.get('param_offload', False) - self._is_offload_optimizer = self.config.actor.fsdp_config.get('optimizer_offload', False) + self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False) + self._is_offload_optimizer = self.config.actor.fsdp_config.get("optimizer_offload", False) elif self._is_ref: # TODO: it seems that manual offload is slowly than FSDP offload - self._is_offload_param = self.config.ref.fsdp_config.get('param_offload', False) + self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False) # normalize config if self._is_actor: self.config.actor.ppo_mini_batch_size *= self.config.rollout.n - self.config.actor.ppo_mini_batch_size //= (self.device_mesh.size() // 
self.ulysses_sequence_parallel_size) - assert self.config.actor.ppo_mini_batch_size > 0, f'ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization' + self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size + assert self.config.actor.ppo_mini_batch_size > 0, ( + f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization" + ) # micro bsz if self.config.actor.ppo_micro_batch_size is not None: - self.config.actor.ppo_micro_batch_size //= (self.device_mesh.size() // - self.ulysses_sequence_parallel_size) + self.config.actor.ppo_micro_batch_size //= ( + self.device_mesh.size() // self.ulysses_sequence_parallel_size + ) self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size - assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, \ - f'normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}' - assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, \ - f'normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}' + assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, ( + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + ) + assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, ( + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + ) # normalize rollout config if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None: - self.config.rollout.log_prob_micro_batch_size //= (self.device_mesh.size() // - self.ulysses_sequence_parallel_size) + self.config.rollout.log_prob_micro_batch_size //= ( + self.device_mesh.size() // self.ulysses_sequence_parallel_size + ) self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size # normalize ref config if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None: - self.config.ref.log_prob_micro_batch_size //= (self.device_mesh.size() // - self.ulysses_sequence_parallel_size) + self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size - def _build_model_optimizer(self, - model_path, - fsdp_config, - optim_config, - override_model_config, - use_remove_padding=False, - enable_gradient_checkpointing=False, - trust_remote_code=False, - use_liger=False, - role='actor'): - from verl.utils.model import print_model_size, update_model_config, get_generation_config - from verl.utils.torch_dtypes import PrecisionType - from transformers import AutoModelForCausalLM, AutoConfig, AutoModelForVision2Seq - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, CPUOffload + def _build_model_optimizer( + self, + model_path, + fsdp_config, + optim_config, + override_model_config, + 
use_remove_padding=False, + enable_gradient_checkpointing=False, + trust_remote_code=False, + use_liger=False, + role="actor", + ): from torch import optim + from torch.distributed.fsdp import CPUOffload, MixedPrecision + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq - assert role in ['actor', 'ref'] + from verl.utils.model import get_generation_config, print_model_size, update_model_config + from verl.utils.torch_dtypes import PrecisionType - log_gpu_memory_usage('Before init from HF AutoModel', logger=logger) + assert role in ["actor", "ref"] + + log_gpu_memory_usage("Before init from HF AutoModel", logger=logger) local_path = copy_to_local(model_path) # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect @@ -164,7 +181,7 @@ class ActorRolloutRefWorker(Worker): self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code) - torch_dtype = fsdp_config.get('model_dtype', None) + torch_dtype = fsdp_config.get("model_dtype", None) if torch_dtype is None: torch_dtype = torch.float32 if self._is_actor else torch.bfloat16 else: @@ -176,18 +193,19 @@ class ActorRolloutRefWorker(Worker): self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code) override_config_kwargs = { - 'bos_token_id': self.tokenizer.bos_token_id, - 'eos_token_id': self.tokenizer.eos_token_id, - 'pad_token_id': self.tokenizer.pad_token_id, + "bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, } override_config_kwargs.update(override_model_config) update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs) if self.rank == 0: - print(f'Model config after override: {actor_model_config}') + print(f"Model config after override: {actor_model_config}") # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang - init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings, - mesh=self.device_mesh) + init_context = get_init_weight_context_manager( + use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh + ) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -196,39 +214,43 @@ class ActorRolloutRefWorker(Worker): else: actor_module_class = AutoModelForCausalLM - actor_module = actor_module_class.from_pretrained(pretrained_model_name_or_path=local_path, - torch_dtype=torch_dtype, - config=actor_model_config, - attn_implementation='flash_attention_2', - trust_remote_code=trust_remote_code) + actor_module = actor_module_class.from_pretrained( + pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=actor_model_config, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) if use_remove_padding or self.ulysses_sequence_parallel_size > 1: from verl.models.transformers.monkey_patch import apply_monkey_patch + apply_monkey_patch(model=actor_module, ulysses_sp_size=self.ulysses_sequence_parallel_size) # Apply Liger kernel to the model if use_liger is set to True if use_liger: from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance + _apply_liger_kernel_to_instance(model=actor_module) # some parameters may not in torch_dtype. 
TODO(zhangchi.usc1992) remove this after we switch to fsdp2 actor_module.to(torch_dtype) if enable_gradient_checkpointing: - actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False}) + actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) torch.distributed.barrier() if self.rank == 0: print_model_size(actor_module) - log_gpu_memory_usage('After init from HF AutoModel', logger=logger) + log_gpu_memory_usage("After init from HF AutoModel", logger=logger) # We wrap FSDP for rollout as well - mixed_precision_config = fsdp_config.get('mixed_precision', None) + mixed_precision_config = fsdp_config.get("mixed_precision", None) if mixed_precision_config is not None: - param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16')) - reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32')) - buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32')) + param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) + reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) + buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) else: param_dtype = torch.bfloat16 reduce_dtype = torch.float32 @@ -236,13 +258,13 @@ class ActorRolloutRefWorker(Worker): mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) - auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get('wrap_policy', None)) + auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get("wrap_policy", None)) - if self._is_rollout and self.config.rollout.name == 'hf': + if self._is_rollout and self.config.rollout.name == "hf": # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma auto_wrap_policy = None - print(f'wrap_policy: {auto_wrap_policy}') + print(f"wrap_policy: {auto_wrap_policy}") fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) @@ -250,7 +272,7 @@ class ActorRolloutRefWorker(Worker): # TODO: add transformer policy # We force reference policy to use CPUOffload to save memory. 
# We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation - cpu_offload = None if role == 'actor' else CPUOffload(offload_params=True) + cpu_offload = None if role == "actor" else CPUOffload(offload_params=True) actor_module_fsdp = FSDP( actor_module, cpu_offload=cpu_offload, @@ -262,125 +284,149 @@ class ActorRolloutRefWorker(Worker): mixed_precision=mixed_precision, sync_module_states=True, device_mesh=self.device_mesh, - forward_prefetch=False) + forward_prefetch=False, + ) - log_gpu_memory_usage('After Actor FSDP init', logger=logger) + log_gpu_memory_usage("After Actor FSDP init", logger=logger) # TODO: add more optimizer args into config - if role == 'actor' and optim_config is not None: + if role == "actor" and optim_config is not None: from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup - actor_optimizer = optim.AdamW(actor_module_fsdp.parameters(), - lr=optim_config.lr, - betas=optim_config.get('betas', (0.9, 0.999)), - weight_decay=optim_config.get('weight_decay', 1e-2)) - total_steps = optim_config.get('total_training_steps', 0) - num_warmup_steps = int(optim_config.get('lr_warmup_steps', -1)) - warmup_style = optim_config.get('warmup_style', 'constant') + actor_optimizer = optim.AdamW( + actor_module_fsdp.parameters(), + lr=optim_config.lr, + betas=optim_config.get("betas", (0.9, 0.999)), + weight_decay=optim_config.get("weight_decay", 1e-2), + ) + + total_steps = optim_config.get("total_training_steps", 0) + num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1)) + warmup_style = optim_config.get("warmup_style", "constant") if num_warmup_steps < 0: - num_warmup_steps_ratio = optim_config.get('lr_warmup_steps_ratio', 0.) + num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) - print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}') + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") - if warmup_style == 'constant': - actor_lr_scheduler = get_constant_schedule_with_warmup(optimizer=actor_optimizer, - num_warmup_steps=num_warmup_steps) - elif warmup_style == 'cosine': - actor_lr_scheduler = get_cosine_schedule_with_warmup(optimizer=actor_optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=total_steps) + if warmup_style == "constant": + actor_lr_scheduler = get_constant_schedule_with_warmup( + optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps + ) + elif warmup_style == "cosine": + actor_lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps + ) else: - raise NotImplementedError(f'Warmup style {warmup_style} is not supported') + raise NotImplementedError(f"Warmup style {warmup_style} is not supported") else: actor_optimizer = None actor_lr_scheduler = None - log_gpu_memory_usage('After actor optimizer init', logger=logger) + log_gpu_memory_usage("After actor optimizer init", logger=logger) return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config def _build_rollout(self, trust_remote_code=False): from torch.distributed.device_mesh import init_device_mesh + # TODO(sgm): support FSDP hybrid shard for larger model infer_tp = self.config.rollout.tensor_model_parallel_size dp = self.world_size // infer_tp - assert self.world_size % infer_tp == 0, f'rollout world_size: {self.world_size} is not divisible by infer_tp: 
{infer_tp}' - rollout_device_mesh = init_device_mesh('cuda', mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp']) + assert self.world_size % infer_tp == 0, ( + f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" + ) + rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]) rollout_name = self.config.rollout.name - if rollout_name == 'hf': + if rollout_name == "hf": from verl.workers.rollout import HFRollout from verl.workers.sharding_manager import BaseShardingManager + rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout) rollout_sharding_manager = BaseShardingManager() # TODO: a sharding manager that do nothing? - elif rollout_name == 'vllm': - from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode + elif rollout_name == "vllm": + from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout from verl.workers.sharding_manager import FSDPVLLMShardingManager - log_gpu_memory_usage(f'Before building {rollout_name} rollout', logger=None) + + log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None) local_path = copy_to_local(self.config.model.path) - if vllm_mode == 'customized': - rollout = vLLMRollout(actor_module=self.actor_module_fsdp, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config) - elif vllm_mode == 'spmd': - rollout = vLLMRollout(model_path=local_path, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - device_mesh=rollout_device_mesh, - trust_remote_code=trust_remote_code) + if vllm_mode == "customized": + rollout = vLLMRollout( + actor_module=self.actor_module_fsdp, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config, + ) + elif vllm_mode == "spmd": + rollout = vLLMRollout( + model_path=local_path, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config, + device_mesh=rollout_device_mesh, + trust_remote_code=trust_remote_code, + ) else: raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'") - log_gpu_memory_usage(f'After building {rollout_name} rollout', logger=None) + log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None) if torch.distributed.get_world_size() == 1: - self.config.rollout.load_format = 'dummy_hf' - rollout_sharding_manager = FSDPVLLMShardingManager(module=self.actor_module_fsdp, - inference_engine=rollout.inference_engine, - model_config=self.actor_model_config, - full_params='hf' in self.config.rollout.load_format, - device_mesh=rollout_device_mesh) - log_gpu_memory_usage('After building sharding manager', logger=None) + self.config.rollout.load_format = "dummy_hf" + rollout_sharding_manager = FSDPVLLMShardingManager( + module=self.actor_module_fsdp, + inference_engine=rollout.inference_engine, + model_config=self.actor_model_config, + full_params="hf" in self.config.rollout.load_format, + device_mesh=rollout_device_mesh, + ) + log_gpu_memory_usage("After building sharding manager", logger=None) - elif rollout_name == 'sglang': + elif rollout_name == "sglang": from verl.workers.rollout.sglang_rollout import SGLangRollout + # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's model_runner would check CUDA device capability. 
# However, due to veRL's setting, the main process of ray can not find any CUDA device, which would potentially lead to: # "RuntimeError: No CUDA GPUs are available". # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and we import it here use the abs path. # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager - log_gpu_memory_usage(f'Before building {rollout_name} rollout', logger=None) + + log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None) local_path = copy_to_local(self.config.model.path) - rollout = SGLangRollout(actor_module=local_path, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config) - log_gpu_memory_usage(f'After building {rollout_name} rollout', logger=None) + rollout = SGLangRollout( + actor_module=local_path, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config, + ) + log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None) if torch.distributed.get_world_size() == 1: - self.config.rollout.load_format = 'dummy_hf' - rollout_sharding_manager = FSDPSGLangShardingManager(module=self.actor_module_fsdp, - inference_engine=rollout.inference_engine, - model_config=self.actor_model_config, - full_params='hf' in self.config.rollout.load_format, - device_mesh=rollout_device_mesh) - log_gpu_memory_usage('After building sharding manager', logger=None) + self.config.rollout.load_format = "dummy_hf" + rollout_sharding_manager = FSDPSGLangShardingManager( + module=self.actor_module_fsdp, + inference_engine=rollout.inference_engine, + model_config=self.actor_model_config, + full_params="hf" in self.config.rollout.load_format, + device_mesh=rollout_device_mesh, + ) + log_gpu_memory_usage("After building sharding manager", logger=None) return rollout, rollout_sharding_manager @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): from verl.workers.actor import DataParallelPPOActor + # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get('external_lib', None)) + import_external_libs(self.config.model.get("external_lib", None)) from omegaconf import OmegaConf - override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create())) - use_remove_padding = self.config.model.get('use_remove_padding', False) + override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + + use_remove_padding = self.config.model.get("use_remove_padding", False) if self._is_actor or self._is_rollout: # we need the model for actor and rollout @@ -390,46 +436,51 @@ class ActorRolloutRefWorker(Worker): else: optim_config = None fsdp_config = OmegaConf.create() - self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer( - model_path=self.config.model.path, - fsdp_config=fsdp_config, - optim_config=optim_config, - override_model_config=override_model_config, - use_remove_padding=use_remove_padding, - enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False), - trust_remote_code=self.config.model.get('trust_remote_code', False), - use_liger=self.config.model.get('use_liger', False), - role='actor') + self.actor_module_fsdp, 
self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = ( + self._build_model_optimizer( + model_path=self.config.model.path, + fsdp_config=fsdp_config, + optim_config=optim_config, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="actor", + ) + ) # get the original unwrapped module self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) - log_gpu_memory_usage('After offload actor optimizer during init', logger=logger) + log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) # load from checkpoint if self._is_actor: OmegaConf.set_struct(self.config.actor, True) with open_dict(self.config.actor): self.config.actor.use_remove_padding = use_remove_padding - self.actor = DataParallelPPOActor(config=self.config.actor, - actor_module=self.actor_module_fsdp, - actor_optimizer=self.actor_optimizer) + self.actor = DataParallelPPOActor( + config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer + ) if self._is_rollout: self.rollout, self.rollout_sharding_manager = self._build_rollout( - trust_remote_code=self.config.model.get('trust_remote_code', False)) + trust_remote_code=self.config.model.get("trust_remote_code", False) + ) if self._is_ref: - self.ref_module_fsdp = self._build_model_optimizer(model_path=self.config.model.path, - fsdp_config=self.config.ref.fsdp_config, - optim_config=None, - override_model_config=override_model_config, - use_remove_padding=use_remove_padding, - trust_remote_code=self.config.model.get( - 'trust_remote_code', False), - use_liger=self.config.model.get('use_liger', False), - role='ref')[0] + self.ref_module_fsdp = self._build_model_optimizer( + model_path=self.config.model.path, + fsdp_config=self.config.ref.fsdp_config, + optim_config=None, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="ref", + )[0] OmegaConf.set_struct(self.config.ref, True) with open_dict(self.config.ref): self.config.ref.use_remove_padding = use_remove_padding @@ -442,7 +493,8 @@ class ActorRolloutRefWorker(Worker): optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_contents=self.config.actor.checkpoint.contents) + checkpoint_contents=self.config.actor.checkpoint.contents, + ) @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_actor(self, data: DataProto): @@ -455,33 +507,34 @@ class ActorRolloutRefWorker(Worker): if self._is_offload_optimizer: load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()) - log_gpu_memory_usage('Before update policy', logger=logger) + log_gpu_memory_usage("Before update policy", logger=logger) with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) # perform training - with Timer(name='update_policy', logger=None) as timer: + with Timer(name="update_policy", logger=None) as timer: metrics = self.actor.update_policy(data=data) delta_time = timer.last - 
global_num_tokens = data.meta_info['global_token_num'] + global_num_tokens = data.meta_info["global_token_num"] estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics[ - 'perf/mfu/actor'] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size - metrics['perf/max_memory_allocated_gb'] = torch.cuda.max_memory_allocated() / (1024**3) - metrics['perf/max_memory_reserved_gb'] = torch.cuda.max_memory_reserved() / (1024**3) - metrics['perf/cpu_memory_used_gb'] = psutil.virtual_memory().used / (1024**3) + metrics["perf/mfu/actor"] = ( + estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size + ) + metrics["perf/max_memory_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3) + metrics["perf/max_memory_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3) + metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) self.actor_lr_scheduler.step() lr = self.actor_lr_scheduler.get_last_lr()[0] - metrics['actor/lr'] = lr + metrics["actor/lr"] = lr - log_gpu_memory_usage('After update policy', logger=logger) + log_gpu_memory_usage("After update policy", logger=logger) # TODO: here, we should return all metrics - output = DataProto(meta_info={'metrics': metrics}) + output = DataProto(meta_info={"metrics": metrics}) output = self.ulysses_sharding_manager.postprocess_data(data=output) - output = output.to('cpu') + output = output.to("cpu") if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) @@ -500,34 +553,33 @@ class ActorRolloutRefWorker(Worker): load_fsdp_model_to_gpu(self.actor_module_fsdp) meta_info = { - 'eos_token_id': - self.generation_config.eos_token_id - if self.generation_config is not None else self.tokenizer.eos_token_id, - 'pad_token_id': - self.generation_config.pad_token_id - if self.generation_config is not None else self.tokenizer.pad_token_id, + "eos_token_id": self.generation_config.eos_token_id + if self.generation_config is not None + else self.tokenizer.eos_token_id, + "pad_token_id": self.generation_config.pad_token_id + if self.generation_config is not None + else self.tokenizer.pad_token_id, } prompts.meta_info.update(meta_info) with self.rollout_sharding_manager: - # after parameters sync with rollout, offload actor model to CPU if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) - log_gpu_memory_usage('After entering rollout sharding manager', logger=logger) + log_gpu_memory_usage("After entering rollout sharding manager", logger=logger) prompts = self.rollout_sharding_manager.preprocess_data(prompts) output = self.rollout.generate_sequences(prompts=prompts) - log_gpu_memory_usage('After rollout generation', logger=logger) + log_gpu_memory_usage("After rollout generation", logger=logger) output = self.rollout_sharding_manager.postprocess_data(output) - output = output.to('cpu') + output = output.to("cpu") # clear kv cache - log_gpu_memory_usage('After generate_sequences', logger=logger) + log_gpu_memory_usage("After generate_sequences", logger=logger) return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) @@ -539,22 +591,21 @@ class ActorRolloutRefWorker(Worker): # Support all hardwares data = data.to(torch.cuda.current_device()) # we should always recompute old_log_probs when it is HybridEngine - data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu - 
data.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu - data.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz - data.meta_info['temperature'] = self.config.rollout.temperature + data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu + data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz + data.meta_info["temperature"] = self.config.rollout.temperature # perform recompute log_prob with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data) output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True) - output = DataProto.from_dict(tensors={ - 'old_log_probs': output, - 'entropys': entropys - }, - meta_info={'temperature': self.config.rollout.temperature}) + output = DataProto.from_dict( + tensors={"old_log_probs": output, "entropys": entropys}, + meta_info={"temperature": self.config.rollout.temperature}, + ) output = self.ulysses_sharding_manager.postprocess_data(output) - output = output.to('cpu') + output = output.to("cpu") # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module @@ -564,7 +615,7 @@ class ActorRolloutRefWorker(Worker): if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) - log_gpu_memory_usage('After compute_log_prob', logger=logger) + log_gpu_memory_usage("After compute_log_prob", logger=logger) return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) @@ -575,17 +626,17 @@ class ActorRolloutRefWorker(Worker): data = data.to(torch.cuda.current_device()) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu - data.meta_info['micro_batch_size'] = micro_batch_size - data.meta_info['temperature'] = self.config.rollout.temperature - data.meta_info['max_token_len'] = self.config.ref.log_prob_max_token_len_per_gpu - data.meta_info['use_dynamic_bsz'] = self.config.ref.log_prob_use_dynamic_bsz + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["temperature"] = self.config.rollout.temperature + data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data) output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) - output = DataProto.from_dict(tensors={'ref_log_prob': output}) + output = DataProto.from_dict(tensors={"ref_log_prob": output}) output = self.ulysses_sharding_manager.postprocess_data(output) - output = output.to('cpu') + output = output.to("cpu") # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module @@ -599,13 +650,13 @@ class ActorRolloutRefWorker(Worker): # only support save and load ckpt for actor assert self._is_actor import torch + if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) - self.checkpoint_manager.save_checkpoint(local_path=local_path, - hdfs_path=hdfs_path, - global_step=global_step, - max_ckpt_to_keep=max_ckpt_to_keep) + self.checkpoint_manager.save_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + ) torch.distributed.barrier() if self._is_offload_param: @@ -616,9 +667,9 @@ class ActorRolloutRefWorker(Worker): if self._is_offload_param: 
load_fsdp_model_to_gpu(self.actor_module_fsdp) - self.checkpoint_manager.load_checkpoint(local_path=local_path, - hdfs_path=hdfs_path, - del_local_after_load=del_local_after_load) + self.checkpoint_manager.load_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + ) if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) @@ -628,10 +679,10 @@ class ActorRolloutRefWorker(Worker): class CriticWorker(Worker): - def __init__(self, config): super().__init__() import torch.distributed + if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl") self.config = config @@ -644,12 +695,12 @@ class CriticWorker(Worker): self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh('cuda', - mesh_shape=(dp, self.ulysses_sequence_parallel_size), - mesh_dim_names=['dp', 'sp']) + self.ulysses_device_mesh = init_device_mesh( + "cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) @@ -659,89 +710,99 @@ class CriticWorker(Worker): # normalize config self.config.ppo_mini_batch_size *= self.config.rollout_n - self.config.ppo_mini_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size) + self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size if self.config.ppo_micro_batch_size is not None: - self.config.ppo_micro_batch_size //= (torch.distributed.get_world_size() // - self.ulysses_sequence_parallel_size) - self.config.forward_micro_batch_size //= (torch.distributed.get_world_size() // - self.ulysses_sequence_parallel_size) + self.config.ppo_micro_batch_size //= ( + torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + ) + self.config.forward_micro_batch_size //= ( + torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + ) self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size - assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, \ - f'normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}' - assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, \ - f'normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}' + assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, ( + f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + ) + assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, ( + f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu 
{self.config.ppo_micro_batch_size_per_gpu}" + ) def _build_critic_model_optimizer(self, config): # the following line is necessary - from verl.utils.model import LambdaLayer, print_model_size, squeeze - from verl.utils.torch_dtypes import PrecisionType - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision from torch import optim + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import MixedPrecision + + from verl.utils.model import print_model_size + from verl.utils.torch_dtypes import PrecisionType local_path = copy_to_local(config.model.path) # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info # using random initialized model from any architecture. May not be the same as Actor. tokenizer_path = copy_to_local(config.model.tokenizer_path) - self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False)) - self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False)) + self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) + self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) from omegaconf import OmegaConf - override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create())) + + override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) override_config_kwargs = { - 'bos_token_id': self.tokenizer.bos_token_id, - 'eos_token_id': self.tokenizer.eos_token_id, - 'pad_token_id': self.tokenizer.pad_token_id, + "bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, } override_config_kwargs.update(override_config) if self.rank == 0: - print(f'Critic overriding config {override_config_kwargs}') + print(f"Critic overriding config {override_config_kwargs}") - torch_dtype = self.config.model.fsdp_config.get('model_dtype', 'fp32') + torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") torch_dtype = PrecisionType.to_dtype(torch_dtype) from transformers import AutoConfig, AutoModelForTokenClassification - from torch import nn trust_remote_code = False critic_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) critic_model_config.num_labels = 1 - init_context = get_init_weight_context_manager(use_meta_tensor=not critic_model_config.tie_word_embeddings, - mesh=self.device_mesh) + init_context = get_init_weight_context_manager( + use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh + ) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") - setattr(critic_model_config, 'classifier_dropout', 0.) 
- setattr(critic_model_config, 'hidden_dropout', '0') - critic_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path, - torch_dtype=torch_dtype, - config=critic_model_config, - attn_implementation='flash_attention_2', - trust_remote_code=trust_remote_code) + critic_model_config.classifier_dropout = 0.0 + critic_model_config.hidden_dropout = "0" + critic_module = AutoModelForTokenClassification.from_pretrained( + pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=critic_model_config, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) - use_remove_padding = config.model.get('use_remove_padding', False) + use_remove_padding = config.model.get("use_remove_padding", False) if use_remove_padding or self.ulysses_sequence_parallel_size > 1: from verl.models.transformers.monkey_patch import apply_monkey_patch + apply_monkey_patch(model=critic_module, ulysses_sp_size=self.ulysses_sequence_parallel_size) # some parameters may not in torch_dtype critic_module.to(torch_dtype) - if config.model.get('enable_gradient_checkpointing', False): - critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False}) + if config.model.get("enable_gradient_checkpointing", False): + critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) if self.rank == 0: print_model_size(critic_module) self.critic_model_config = critic_model_config fsdp_config = self.config.model.fsdp_config - mixed_precision_config = fsdp_config.get('mixed_precision', None) + mixed_precision_config = fsdp_config.get("mixed_precision", None) if mixed_precision_config is not None: - param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16')) - reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32')) - buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32')) + param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) + reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) + buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) else: param_dtype = torch.bfloat16 reduce_dtype = torch.float32 @@ -751,70 +812,78 @@ class CriticWorker(Worker): auto_wrap_policy = get_fsdp_wrap_policy(module=critic_module, config=self.config.model.fsdp_config.wrap_policy) - log_gpu_memory_usage('Before critic FSDP', logger=None) + log_gpu_memory_usage("Before critic FSDP", logger=None) fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation - critic_module = FSDP(critic_module, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), - sharding_strategy=sharding_strategy, - mixed_precision=mixed_precision, - sync_module_states=True, - forward_prefetch=False, - device_mesh=self.device_mesh, - cpu_offload=None) + critic_module = FSDP( + critic_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + sync_module_states=True, + forward_prefetch=False, + device_mesh=self.device_mesh, + cpu_offload=None, + ) - log_gpu_memory_usage('After critic 
FSDP', logger=None) + log_gpu_memory_usage("After critic FSDP", logger=None) - critic_optimizer = optim.AdamW(critic_module.parameters(), - lr=config.optim.lr, - betas=config.optim.get('betas', (0.9, 0.999)), - weight_decay=config.optim.get('weight_decay', 1e-2)) + critic_optimizer = optim.AdamW( + critic_module.parameters(), + lr=config.optim.lr, + betas=config.optim.get("betas", (0.9, 0.999)), + weight_decay=config.optim.get("weight_decay", 1e-2), + ) - total_steps = config.optim.get('total_training_steps', 0) - num_warmup_steps = int(config.optim.get('lr_warmup_steps', -1)) - warmup_style = config.optim.get('warmup_style', 'constant') + total_steps = config.optim.get("total_training_steps", 0) + num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1)) + warmup_style = config.optim.get("warmup_style", "constant") if num_warmup_steps < 0: - num_warmup_steps_ratio = config.optim.get('lr_warmup_steps_ratio', 0.) + num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) - print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}') + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup - if warmup_style == 'constant': - critic_lr_scheduler = get_constant_schedule_with_warmup(optimizer=critic_optimizer, - num_warmup_steps=num_warmup_steps) - elif warmup_style == 'cosine': - critic_lr_scheduler = get_cosine_schedule_with_warmup(optimizer=critic_optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=total_steps) + + if warmup_style == "constant": + critic_lr_scheduler = get_constant_schedule_with_warmup( + optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps + ) + elif warmup_style == "cosine": + critic_lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps + ) else: - raise NotImplementedError(f'Warmup style {warmup_style} is not supported') + raise NotImplementedError(f"Warmup style {warmup_style} is not supported") return critic_module, critic_optimizer, critic_lr_scheduler @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get('external_lib', None)) + import_external_libs(self.config.model.get("external_lib", None)) from verl.workers.critic import DataParallelPPOCritic + self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer( - self.config) + self.config + ) if self._is_offload_param: offload_fsdp_model_to_cpu(self.critic_module) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.critic_optimizer) - self.critic = DataParallelPPOCritic(config=self.config, - critic_module=self.critic_module, - critic_optimizer=self.critic_optimizer) + self.critic = DataParallelPPOCritic( + config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer + ) self.flops_counter = FlopsCounter(self.critic_model_config) self.checkpoint_manager = FSDPCheckpointManager( @@ -822,28 +891,28 @@ class CriticWorker(Worker): optimizer=self.critic_optimizer, lr_scheduler=self.critic_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_contents=self.config.checkpoint.contents) + 
checkpoint_contents=self.config.checkpoint.contents, + ) @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_values(self, data: DataProto): - # Support all hardwares data = data.to(torch.cuda.current_device()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) micro_batch_size = self.config.forward_micro_batch_size_per_gpu - data.meta_info['micro_batch_size'] = micro_batch_size - data.meta_info['max_token_len'] = self.config.forward_max_token_len_per_gpu - data.meta_info['use_dynamic_bsz'] = self.config.use_dynamic_bsz + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz # perform forward computation with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) values = self.critic.compute_values(data=data) - output = DataProto.from_dict(tensors={'values': values}) + output = DataProto.from_dict(tensors={"values": values}) output = self.ulysses_sharding_manager.postprocess_data(data=output) - output = output.to('cpu') + output = output.to("cpu") if self._is_offload_param: offload_fsdp_model_to_cpu(self.critic_module) return output @@ -861,19 +930,19 @@ class CriticWorker(Worker): with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) - with Timer(name='update_critic', logger=None) as timer: + with Timer(name="update_critic", logger=None) as timer: metrics = self.critic.update_critic(data=data) delta_time = timer.last - global_num_tokens = data.meta_info['global_token_num'] + global_num_tokens = data.meta_info["global_token_num"] estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics['perf/mfu/critic'] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size self.critic_lr_scheduler.step() lr = self.critic_lr_scheduler.get_last_lr()[0] - metrics['critic/lr'] = lr + metrics["critic/lr"] = lr - output = DataProto(batch=None, meta_info={'metrics': metrics}) + output = DataProto(batch=None, meta_info={"metrics": metrics}) output = self.ulysses_sharding_manager.postprocess_data(data=output) if self._is_offload_param: @@ -881,19 +950,19 @@ class CriticWorker(Worker): if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.critic_optimizer) - output = output.to('cpu') + output = output.to("cpu") return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): import torch + if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) - self.checkpoint_manager.save_checkpoint(local_path=local_path, - hdfs_path=hdfs_path, - global_step=global_step, - max_ckpt_to_keep=max_ckpt_to_keep) + self.checkpoint_manager.save_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + ) torch.distributed.barrier() if self._is_offload_param: @@ -902,12 +971,13 @@ class CriticWorker(Worker): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True): import torch + if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) - self.checkpoint_manager.load_checkpoint(local_path=local_path, - hdfs_path=hdfs_path, - 
del_local_after_load=del_local_after_load) + self.checkpoint_manager.load_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + ) torch.distributed.barrier() if self._is_offload_param: @@ -926,6 +996,7 @@ class RewardModelWorker(Worker): def __init__(self, config): super().__init__() import torch.distributed + if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl") self.config = config @@ -938,16 +1009,16 @@ class RewardModelWorker(Worker): self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh('cuda', - mesh_shape=(dp, self.ulysses_sequence_parallel_size), - mesh_dim_names=['dp', 'sp']) + self.ulysses_device_mesh = init_device_mesh( + "cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) - self.use_remove_padding = self.config.model.get('use_remove_padding', False) + self.use_remove_padding = self.config.model.get("use_remove_padding", False) # normalize config if self.config.micro_batch_size is not None: @@ -956,8 +1027,9 @@ class RewardModelWorker(Worker): def _build_model(self, config): # the following line is necessary - from transformers import AutoModelForTokenClassification, AutoConfig - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, CPUOffload + from torch.distributed.fsdp import CPUOffload + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from transformers import AutoConfig, AutoModelForTokenClassification # download the checkpoint from hdfs local_path = copy_to_local(config.model.path) @@ -967,29 +1039,34 @@ class RewardModelWorker(Worker): else: self._do_switch_chat_template = True input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer) - self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, - trust_remote_code=config.model.get('trust_remote_code', False)) - self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get('trust_remote_code', False)) + self.input_tokenizer = hf_tokenizer( + input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False) + ) + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False)) - trust_remote_code = config.model.get('trust_remote_code', False) + trust_remote_code = config.model.get("trust_remote_code", False) model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) model_config.num_labels = 1 # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect - init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings, - mesh=self.device_mesh) + init_context = get_init_weight_context_manager( + use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh + ) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") - setattr(model_config, 'classifier_dropout', 0.) 
- reward_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path, - config=model_config, - torch_dtype=torch.bfloat16, - attn_implementation='flash_attention_2', - trust_remote_code=trust_remote_code) + model_config.classifier_dropout = 0.0 + reward_module = AutoModelForTokenClassification.from_pretrained( + pretrained_model_name_or_path=local_path, + config=model_config, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) - if config.model.get('use_remove_padding', False) or self.ulysses_sequence_parallel_size > 1: + if config.model.get("use_remove_padding", False) or self.ulysses_sequence_parallel_size > 1: from verl.models.transformers.monkey_patch import apply_monkey_patch + apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size) reward_module.to(torch.bfloat16) @@ -1009,63 +1086,64 @@ class RewardModelWorker(Worker): sync_module_states=True, cpu_offload=CPUOffload(offload_params=True), forward_prefetch=False, - device_mesh=self.device_mesh) + device_mesh=self.device_mesh, + ) return reward_module @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get('external_lib', None)) + import_external_libs(self.config.model.get("external_lib", None)) self.reward_module = self._build_model(config=self.config) def _forward_micro_batch(self, micro_batch): - from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis, rearrange - from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad + from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input - with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16): - input_ids = micro_batch['input_ids'] + from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs + + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape - attention_mask = micro_batch['attention_mask'] - position_ids = micro_batch['position_ids'] + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), - attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), - indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), indices + ).transpose(0, 1) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \ - position_ids_rmpad, \ - sp_size=self.ulysses_sequence_parallel_size) + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + ) # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.reward_module(input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - use_cache=False) # prevent model thinks we are generating + output = self.reward_module( + input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False + ) # prevent model thinks we are generating reward_rmpad = output.logits reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) # gather output if sp > 1 if self.ulysses_sequence_parallel_size > 1: - reward_rmpad = gather_outpus_and_unpad(reward_rmpad, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size) + reward_rmpad = gather_outpus_and_unpad( + reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) # pad it back rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1) else: - output = self.reward_module(input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - use_cache=False) + output = self.reward_module( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + ) rm_score = output.logits # (batch_size, seq_len, 1) rm_score = rm_score.squeeze(-1) @@ -1077,9 +1155,9 @@ class RewardModelWorker(Worker): def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): batch_size = data.batch.batch_size[0] # expand as token_level_reward - attention_mask = data.batch['attention_mask'] - position_ids = data.batch['position_ids'] - response_length = data.batch['responses'].shape[-1] + attention_mask = data.batch["attention_mask"] + position_ids = data.batch["position_ids"] + response_length = data.batch["responses"].shape[-1] eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores @@ -1090,7 +1168,7 @@ class RewardModelWorker(Worker): return token_level_scores def _switch_chat_template(self, data: DataProto): - src_max_length = data.batch['attention_mask'].shape[-1] + src_max_length = data.batch["attention_mask"].shape[-1] src_tokenizer = self.input_tokenizer target_tokenizer = self.tokenizer @@ -1100,44 +1178,45 @@ class RewardModelWorker(Worker): for i in range(data.batch.batch_size[0]): # extract raw prompt - if isinstance(data.non_tensor_batch['raw_prompt'][i], list): - chat: list = data.non_tensor_batch['raw_prompt'][i] + if isinstance(data.non_tensor_batch["raw_prompt"][i], list): + chat: list = data.non_tensor_batch["raw_prompt"][i] else: - chat: list = data.non_tensor_batch['raw_prompt'][i].tolist() + chat: list = data.non_tensor_batch["raw_prompt"][i].tolist() # extract response - response_ids = data.batch['responses'][i] + response_ids = data.batch["responses"][i] response_length = response_ids.shape[-1] - valid_response_length = data.batch['attention_mask'][i][-response_length:].sum() + valid_response_length = 
data.batch["attention_mask"][i][-response_length:].sum() valid_response_ids = response_ids[:valid_response_length] # decode response = src_tokenizer.decode(valid_response_ids) # remove bos and eos - response = response.replace(src_tokenizer.eos_token, '') + response = response.replace(src_tokenizer.eos_token, "") - chat.append({'role': 'assistant', 'content': response}) + chat.append({"role": "assistant", "content": response}) - prompt_with_chat_template = target_tokenizer.apply_chat_template(chat, - add_generation_prompt=False, - tokenize=False) + prompt_with_chat_template = target_tokenizer.apply_chat_template( + chat, add_generation_prompt=False, tokenize=False + ) if self.rank == 0 and i == 0: # for debugging purpose - print(f'Switch template. chat: {prompt_with_chat_template}') + print(f"Switch template. chat: {prompt_with_chat_template}") # the maximum length is actually determined by the reward model itself - max_length = self.config.get('max_length', src_max_length) + max_length = self.config.get("max_length", src_max_length) if max_length is None: max_length = src_max_length - model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors='pt', add_special_tokens=False) + model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False) input_ids, attention_mask = verl_F.postprocess_data( - input_ids=model_inputs['input_ids'], - attention_mask=model_inputs['attention_mask'], + input_ids=model_inputs["input_ids"], + attention_mask=model_inputs["attention_mask"], max_length=max_length, pad_token_id=target_tokenizer.pad_token_id, left_pad=False, # right padding - truncation=self.config.get('truncation', 'right')) # truncate from the right + truncation=self.config.get("truncation", "right"), + ) # truncate from the right rm_input_ids.append(input_ids) rm_attention_mask.append(attention_mask) @@ -1147,26 +1226,28 @@ class RewardModelWorker(Worker): rm_position_ids = compute_position_id_with_mask(rm_attention_mask) - rm_inputs = {'input_ids': rm_input_ids, 'attention_mask': rm_attention_mask, 'position_ids': rm_position_ids} + rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids} return DataProto.from_dict(rm_inputs) @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_rm_score(self, data: DataProto): import itertools - from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx + + from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches + # Support all hardwares data = data.to(torch.cuda.current_device()) if self._do_switch_chat_template: rm_data = self._switch_chat_template(data) else: - rm_input_ids = data.batch['input_ids'] - rm_attention_mask = data.batch['attention_mask'] - rm_position_ids = data.batch['position_ids'] + rm_input_ids = data.batch["input_ids"] + rm_attention_mask = data.batch["attention_mask"] + rm_position_ids = data.batch["position_ids"] rm_inputs = { - 'input_ids': rm_input_ids, - 'attention_mask': rm_attention_mask, - 'position_ids': rm_position_ids + "input_ids": rm_input_ids, + "attention_mask": rm_attention_mask, + "position_ids": rm_position_ids, } rm_data = DataProto.from_dict(rm_inputs) @@ -1198,12 +1279,12 @@ class RewardModelWorker(Worker): token_level_scores = self._expand_to_token_level(data, scores) # Note that this is only the scores, may not be the final rewards used to train RL - output = DataProto.from_dict(tensors={'rm_scores': token_level_scores}) + output = 
DataProto.from_dict(tensors={"rm_scores": token_level_scores}) output = self.ulysses_sharding_manager.postprocess_data(data=output) # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module self.reward_module._handle.reshard(True) - output = output.to('cpu') + output = output.to("cpu") return output diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index e231c4c4f..882e2a45a 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -15,49 +15,49 @@ The main entry point to run the PPO algorithm """ -import os import logging +import os import time -import ray + import torch import torch.distributed -import torch.nn as nn +from codetiming import Timer +from megatron.core import parallel_state as mpu from omegaconf import DictConfig +from verl import DataProto +from verl.single_controller.base.decorator import Dispatch, register from verl.single_controller.base.megatron.worker import MegatronWorker +from verl.utils import hf_tokenizer +from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager +from verl.utils.debug import log_gpu_memory_usage +from verl.utils.flops_counter import FlopsCounter +from verl.utils.fs import copy_to_local +from verl.utils.megatron_utils import ( + load_megatron_param_and_grad, + offload_megatron_param_and_grad, +) +from verl.utils.model import load_mcore_dist_weights, load_megatron_gptmodel_weights from verl.workers.actor.megatron_actor import MegatronPPOActor from verl.workers.critic.megatron_critic import MegatronPPOCritic -from verl.workers.sharding_manager import AllGatherPPModel from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel -from verl.single_controller.base.decorator import register, Dispatch -from verl import DataProto -from verl.utils.fs import copy_to_local -from verl.utils.debug import log_gpu_memory_usage -from verl.utils.model import load_megatron_model_weights, load_megatron_gptmodel_weights, load_mcore_dist_weights -from verl.utils.flops_counter import FlopsCounter -from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager -from verl.utils.megatron_utils import offload_megatron_param_and_grad, load_megatron_param_and_grad -from verl.utils import hf_tokenizer -from verl.third_party.vllm import vllm_version -from codetiming import Timer - -from megatron.core import parallel_state as mpu -from megatron.core import ModelParallelConfig - logger = logging.getLogger(__file__) -logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) +logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) def set_random_seed(seed): - import torch - import numpy as np import random + + import numpy as np + import torch + torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) if torch.cuda.device_count() > 0: from megatron.core import tensor_parallel + tensor_parallel.model_parallel_cuda_manual_seed(seed) # FIXME: torch cumsum not support deterministic (used in vllm sampler), # https://github.com/pytorch/pytorch/issues/89492 @@ -82,12 +82,12 @@ class ActorRolloutRefWorker(MegatronWorker): # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, # 3. 
and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 if not torch.distributed.is_initialized(): - rank = int(os.environ['LOCAL_RANK']) + rank = int(os.environ["LOCAL_RANK"]) torch.distributed.init_process_group(backend="nccl") torch.cuda.set_device(rank) if self.config.actor.megatron.sequence_parallel: - os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" mpu.initialize_model_parallel( tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size, pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size, @@ -102,11 +102,11 @@ class ActorRolloutRefWorker(MegatronWorker): set_random_seed(seed=self.config.actor.megatron.seed) self.role = role - assert self.role in ['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref'] + assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] - self._is_actor = self.role in ['actor', 'actor_rollout', 'actor_rollout_ref'] - self._is_rollout = self.role in ['rollout', 'actor_rollout', 'actor_rollout_ref'] - self._is_ref = self.role in ['ref', 'actor_rollout_ref'] + self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] + self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] + self._is_ref = self.role in ["ref", "actor_rollout_ref"] # TODO(sgm): Currently, we only support reference model param offload # will support other offload later @@ -118,85 +118,88 @@ class ActorRolloutRefWorker(MegatronWorker): if self._is_actor and self._is_rollout: self.config.actor.ppo_mini_batch_size *= self.config.rollout.n self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size() - if self.config.actor.get('ppo_micro_batch_size', None): + if self.config.actor.get("ppo_micro_batch_size", None): self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size() self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size() self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size - self._is_offload_param = self.config.actor.get('param_offload', False) - self._is_offload_grad = self.config.actor.get('grad_offload', False) - self._is_offload_optimizer = self.config.actor.get('optimizer_offload', False) + self._is_offload_param = self.config.actor.get("param_offload", False) + self._is_offload_grad = self.config.actor.get("grad_offload", False) + self._is_offload_optimizer = self.config.actor.get("optimizer_offload", False) elif self._is_ref: - if self.config.ref.get('ppo_micro_batch_size', None): + if self.config.ref.get("ppo_micro_batch_size", None): self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size() self.config.ref.ppo_micro_batch_size_per_gpu = self.config.ref.ppo_micro_batch_size - self._is_offload_param = self.config.ref.get('param_offload', False) + self._is_offload_param = self.config.ref.get("param_offload", False) def _build_model_optimizer(self, model_path, optim_config, override_model_config): - from verl.utils.megatron.optimizer import get_megatron_optimizer from megatron.core.models.gpt.gpt_model import ModelType - from verl.utils.model import print_model_size, get_generation_config + + from verl.utils.megatron.optimizer import get_megatron_optimizer from verl.utils.megatron_utils import get_model, init_megatron_optim_config + from 
verl.utils.model import get_generation_config, print_model_size self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config) self.generation_config = get_generation_config(self.local_path) def megatron_actor_model_provider(pre_process, post_process): from verl.models.mcore import init_mcore_model + parallel_model = init_mcore_model( self.tf_config, self.hf_config, pre_process, post_process, share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, - value=False) + value=False, + ) parallel_model.cuda() return parallel_model # Step 3: initialize the megatron model if self._is_actor and self._is_rollout: - actor_module = get_model(megatron_actor_model_provider, - wrap_with_ddp=True, - use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer) - print(f'actor_module: {len(actor_module)}') + actor_module = get_model( + megatron_actor_model_provider, + wrap_with_ddp=True, + use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer, + ) + print(f"actor_module: {len(actor_module)}") if self.config.actor.load_weight: if self.config.actor.megatron.use_dist_checkpointing: - load_mcore_dist_weights(actor_module, - self.config.actor.megatron.dist_checkpointing_path, - is_value_model=False) + load_mcore_dist_weights( + actor_module, self.config.actor.megatron.dist_checkpointing_path, is_value_model=False + ) else: - load_megatron_gptmodel_weights(self.config, - self.hf_config, - actor_module, - params_dtype=self.dtype, - is_value_model=False) + load_megatron_gptmodel_weights( + self.config, self.hf_config, actor_module, params_dtype=self.dtype, is_value_model=False + ) if self.rank == 0: print_model_size(actor_module[0]) - log_gpu_memory_usage('After AllGatherPPModel init', logger=logger) + log_gpu_memory_usage("After AllGatherPPModel init", logger=logger) elif self._is_ref: - print(f'self.config.ref.load_weight: {self.config.ref.load_weight}') - ref_module = get_model(model_provider_func=megatron_actor_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=False, - use_distributed_optimizer=self.config.ref.megatron.use_distributed_optimizer) + print(f"self.config.ref.load_weight: {self.config.ref.load_weight}") + ref_module = get_model( + model_provider_func=megatron_actor_model_provider, + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=False, + use_distributed_optimizer=self.config.ref.megatron.use_distributed_optimizer, + ) # ref_module = nn.ModuleList(ref_module) if self.config.ref.load_weight: # should align with the actor: assert self.config.actor.load_weight == self.config.ref.load_weight - print(f'load ref weight start') + print("load ref weight start") if self.config.ref.megatron.use_dist_checkpointing: - load_mcore_dist_weights(ref_module, - self.config.ref.megatron.dist_checkpointing_path, - is_value_model=False) + load_mcore_dist_weights( + ref_module, self.config.ref.megatron.dist_checkpointing_path, is_value_model=False + ) else: - load_megatron_gptmodel_weights(self.config, - self.hf_config, - ref_module, - params_dtype=self.dtype, - is_value_model=False) - log_gpu_memory_usage('After ref module init', logger=logger) + load_megatron_gptmodel_weights( + self.config, self.hf_config, ref_module, params_dtype=self.dtype, is_value_model=False + ) + log_gpu_memory_usage("After ref module init", logger=logger) return ref_module, self.hf_config # TODO: add more optimizer args into config @@ -207,16 +210,17 @@ class ActorRolloutRefWorker(MegatronWorker): optim_config = None 
actor_optimizer = None - log_gpu_memory_usage('After actor optimizer init', logger=logger) + log_gpu_memory_usage("After actor optimizer init", logger=logger) return actor_module, actor_optimizer, self.hf_config, optim_config def _build_rollout(self, trust_remote_code=False): - if self.config.rollout.name == 'vllm': - from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode - from verl.workers.sharding_manager import MegatronVLLMShardingManager + if self.config.rollout.name == "vllm": from torch.distributed.device_mesh import init_device_mesh + from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout + from verl.workers.sharding_manager import MegatronVLLMShardingManager + # NOTE(sgm): If the QKV and gate_up projection layer are concate together in actor, # we will reorganize their weight format when resharding from actor to rollout. layer_name_mapping = { @@ -226,47 +230,57 @@ class ActorRolloutRefWorker(MegatronWorker): infer_tp = self.config.rollout.tensor_model_parallel_size dp = self.world_size // infer_tp - assert self.world_size % infer_tp == 0, f'rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}' - rollout_device_mesh = init_device_mesh('cuda', mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp']) - log_gpu_memory_usage(f'Before building vllm rollout', logger=None) + assert self.world_size % infer_tp == 0, ( + f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" + ) + rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]) + log_gpu_memory_usage("Before building vllm rollout", logger=None) - from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode local_path = copy_to_local(self.config.model.path) - if vllm_mode == 'customized': - rollout = vLLMRollout(actor_module=self.actor_module, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config) - elif vllm_mode == 'spmd': - rollout = vLLMRollout(model_path=local_path, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - device_mesh=rollout_device_mesh, - trust_remote_code=trust_remote_code) - log_gpu_memory_usage('After building vllm rollout', logger=logger) + if vllm_mode == "customized": + rollout = vLLMRollout( + actor_module=self.actor_module, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config, + ) + elif vllm_mode == "spmd": + rollout = vLLMRollout( + model_path=local_path, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config, + device_mesh=rollout_device_mesh, + trust_remote_code=trust_remote_code, + ) + log_gpu_memory_usage("After building vllm rollout", logger=logger) # perform weight resharding between actor and rollout - sharding_manager = MegatronVLLMShardingManager(inference_engine=rollout.inference_engine, - model_config=self.actor_model_config, - layer_name_mapping=layer_name_mapping, - actor_module=self.actor.actor_module) - log_gpu_memory_usage('After building sharding manager', logger=logger) + sharding_manager = MegatronVLLMShardingManager( + inference_engine=rollout.inference_engine, + model_config=self.actor_model_config, + layer_name_mapping=layer_name_mapping, + actor_module=self.actor.actor_module, + ) + log_gpu_memory_usage("After building sharding manager", logger=logger) else: - raise NotImplementedError('Only vllmRollout is supported with Megatron now') + raise 
NotImplementedError("Only vllmRollout is supported with Megatron now") return rollout, sharding_manager @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): - if self.config.model.get('external_lib', None) is not None: + if self.config.model.get("external_lib", None) is not None: # This is used to import external_lib into the huggingface systems import importlib + importlib.import_module(self.config.model.external_lib) from omegaconf import OmegaConf + from verl.utils.torch_dtypes import PrecisionType - override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create())) + + override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) self.param_dtype = torch.bfloat16 self.dtype = PrecisionType.to_dtype(self.param_dtype) @@ -276,11 +290,12 @@ class ActorRolloutRefWorker(MegatronWorker): optim_config = self.config.actor.optim else: optim_config = None - self.actor_module, self.actor_optimizer, \ - self.actor_model_config, self.actor_optim_config = self._build_model_optimizer( - model_path=self.config.model.path, - optim_config=optim_config, - override_model_config=override_model_config + self.actor_module, self.actor_optimizer, self.actor_model_config, self.actor_optim_config = ( + self._build_model_optimizer( + model_path=self.config.model.path, + optim_config=optim_config, + override_model_config=override_model_config, + ) ) if self._is_actor: @@ -295,7 +310,8 @@ class ActorRolloutRefWorker(MegatronWorker): if self._is_rollout: self.rollout, self.sharding_manager = self._build_rollout( - trust_remote_code=self.config.model.get('trust_remote_code', False)) + trust_remote_code=self.config.model.get("trust_remote_code", False) + ) if self._is_ref: self.ref_module, self.ref_model_config = self._build_model_optimizer( @@ -303,19 +319,21 @@ class ActorRolloutRefWorker(MegatronWorker): optim_config=None, override_model_config=override_model_config, ) - self.ref_policy = MegatronPPOActor(config=self.config.ref, - model_config=self.ref_model_config, - hf_config=self.hf_config, - tf_config=self.tf_config, - actor_module=self.ref_module, - actor_optimizer=None) + self.ref_policy = MegatronPPOActor( + config=self.config.ref, + model_config=self.ref_model_config, + hf_config=self.hf_config, + tf_config=self.tf_config, + actor_module=self.ref_module, + actor_optimizer=None, + ) if self._is_actor: self.flops_counter = FlopsCounter(self.actor_model_config) self.checkpoint_mananager = MegatronCheckpointManager( config=self.config, model_config=self.actor_model_config, - role='actor', + role="actor", model=self.actor_module, arch=self.architectures[0], hf_config=self.hf_config, @@ -324,7 +342,8 @@ class ActorRolloutRefWorker(MegatronWorker): tokenizer=self.tokenizer, optimizer=self.actor_optimizer, use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer, - checkpoint_contents=self.config.actor.checkpoint.contents) + checkpoint_contents=self.config.actor.checkpoint.contents, + ) torch.cuda.empty_cache() @@ -334,23 +353,23 @@ class ActorRolloutRefWorker(MegatronWorker): data.batch = data.batch.cuda() - log_gpu_memory_usage('Before update policy', logger=logger) + log_gpu_memory_usage("Before update policy", logger=logger) micro_batch_size = self.config.actor.ppo_micro_batch_size_per_gpu - data.meta_info['micro_batch_size'] = micro_batch_size + data.meta_info["micro_batch_size"] = micro_batch_size dataloader = self.actor.make_minibatch_iterator(data=data) - with 
Timer(name='update_policy', logger=None) as timer: + with Timer(name="update_policy", logger=None) as timer: metrics = self.actor.update_policy(dataloader=dataloader) delta_time = timer.last - global_num_tokens = data.meta_info['global_token_num'] + global_num_tokens = data.meta_info["global_token_num"] estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics['perf/mfu/actor'] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size + metrics["perf/mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size - log_gpu_memory_usage('After update policy', logger=logger) + log_gpu_memory_usage("After update policy", logger=logger) # TODO: here, we should return all metrics - output = DataProto(meta_info={'metrics': metrics}) - output = output.to('cpu') + output = DataProto(meta_info={"metrics": metrics}) + output = output.to("cpu") torch.cuda.empty_cache() return output @@ -360,44 +379,44 @@ class ActorRolloutRefWorker(MegatronWorker): prompts.batch = prompts.batch.cuda() meta_info = { - 'eos_token_id': - self.generation_config.eos_token_id - if self.generation_config is not None else self.tokenizer.eos_token_id, - 'pad_token_id': - self.generation_config.pad_token_id - if self.generation_config is not None else self.tokenizer.pad_token_id, + "eos_token_id": self.generation_config.eos_token_id + if self.generation_config is not None + else self.tokenizer.eos_token_id, + "pad_token_id": self.generation_config.pad_token_id + if self.generation_config is not None + else self.tokenizer.pad_token_id, } prompts.meta_info.update(meta_info) with self.sharding_manager: - log_gpu_memory_usage('After entering sharding manager', logger=logger) + log_gpu_memory_usage("After entering sharding manager", logger=logger) prompts = self.sharding_manager.preprocess_data(prompts) output = self.rollout.generate_sequences(prompts=prompts) - log_gpu_memory_usage('After rollout generation', logger=logger) + log_gpu_memory_usage("After rollout generation", logger=logger) output = self.sharding_manager.postprocess_data(output) - output = output.to('cpu') + output = output.to("cpu") # clear kv cache torch.cuda.empty_cache() - log_gpu_memory_usage('After generate_sequences', logger=logger) + log_gpu_memory_usage("After generate_sequences", logger=logger) return output @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) def compute_ref_log_prob(self, data: DataProto): - data = data.to('cuda') + data = data.to("cuda") assert self._is_ref if self._is_offload_param: load_megatron_param_and_grad(self.ref_module, torch.cuda.current_device(), self._is_offload_grad) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu - data.meta_info['micro_batch_size'] = micro_batch_size - data.meta_info['temperature'] = self.config.rollout.temperature + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["temperature"] = self.config.rollout.temperature output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) - output = DataProto.from_dict(tensors={'ref_log_prob': output}) - output = output.to('cpu') + output = DataProto.from_dict(tensors={"ref_log_prob": output}) + output = output.to("cpu") if self._is_offload_param: offload_megatron_param_and_grad(self.ref_module, self._is_offload_grad) torch.cuda.empty_cache() @@ -406,25 +425,25 @@ class ActorRolloutRefWorker(MegatronWorker): @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) def compute_log_prob(self, data: 
DataProto): assert self._is_actor - data = data.to('cuda') + data = data.to("cuda") output = data # we should always recompute old_log_probs when it is HybridEngine - output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu - output.meta_info['temperature'] = self.config.rollout.temperature + output.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu + output.meta_info["temperature"] = self.config.rollout.temperature old_log_probs, entropys = self.actor.compute_log_prob(data=output, calculate_entropy=True) - output.batch['old_log_probs'] = old_log_probs - output.batch['entropys'] = entropys - output = output.to('cpu') + output.batch["old_log_probs"] = old_log_probs + output.batch["entropys"] = entropys + output = output.to("cpu") # clear kv cache torch.cuda.empty_cache() - log_gpu_memory_usage('After generate_sequences', logger=logger) + log_gpu_memory_usage("After generate_sequences", logger=logger) return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True): - self.checkpoint_mananager.load_checkpoint(local_path=checkpoint_path, - hdfs_path=hdfs_path, - del_local_after_load=del_local_after_load) + self.checkpoint_mananager.load_checkpoint( + local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + ) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_pretrained_model(self, checkpoint_path, del_local_after_load=True): @@ -432,14 +451,12 @@ class ActorRolloutRefWorker(MegatronWorker): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): - self.checkpoint_mananager.save_checkpoint(local_path=checkpoint_path, - hdfs_path=hdfs_path, - global_step=global_step, - max_ckpt_to_keep=max_ckpt_to_keep) + self.checkpoint_mananager.save_checkpoint( + local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + ) class CriticWorker(MegatronWorker): - def __init__(self, config): super().__init__() self.config = config @@ -451,12 +468,12 @@ class CriticWorker(MegatronWorker): # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, # 3. 
and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 if not torch.distributed.is_initialized(): - rank = int(os.environ['LOCAL_RANK']) + rank = int(os.environ["LOCAL_RANK"]) torch.distributed.init_process_group(backend="nccl") torch.cuda.set_device(rank) if self.config.megatron.sequence_parallel: - os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" mpu.initialize_model_parallel( tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size, pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size, @@ -473,7 +490,7 @@ class CriticWorker(MegatronWorker): # normalize config self.config.ppo_mini_batch_size *= self.config.rollout_n self.config.ppo_mini_batch_size //= mpu.get_data_parallel_world_size() - if self.config.get('ppo_micro_batch_size', None): + if self.config.get("ppo_micro_batch_size", None): self.config.ppo_micro_batch_size //= mpu.get_data_parallel_world_size() self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size @@ -481,28 +498,34 @@ class CriticWorker(MegatronWorker): def _build_critic_model_optimizer(self, model_path, optim_config, override_model_config): from megatron.core.models.gpt.gpt_model import ModelType - from verl.utils.model import print_model_size + from verl.utils.megatron.optimizer import get_megatron_optimizer from verl.utils.megatron_utils import get_model, init_megatron_optim_config + from verl.utils.model import print_model_size self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config) def megatron_critic_model_provider(pre_process, post_process): from verl.models.mcore import init_mcore_model - parallel_model = init_mcore_model(self.tf_config, - self.hf_config, - pre_process, - post_process, - share_embeddings_and_output_weights=False, - value=True) + + parallel_model = init_mcore_model( + self.tf_config, + self.hf_config, + pre_process, + post_process, + share_embeddings_and_output_weights=False, + value=True, + ) parallel_model.cuda() return parallel_model # Step 3: initialize the megatron model - critic_module = get_model(model_provider_func=megatron_critic_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=True, - use_distributed_optimizer=self.config.megatron.use_distributed_optimizer) + critic_module = get_model( + model_provider_func=megatron_critic_model_provider, + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=True, + use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, + ) # note that here critic_module will be a list to be compatible with the construction of interleaved pp (vpp). # but here, we do not use pp (vpp) yet. 
For simplicity, we remove the list # critic_module = nn.ModuleList(critic_module) @@ -510,18 +533,16 @@ class CriticWorker(MegatronWorker): if self.config.load_weight: t0 = time.time() if self.config.megatron.use_dist_checkpointing: - load_mcore_dist_weights(critic_module, - self.config.megatron.dist_checkpointing_path, - is_value_model=True) + load_mcore_dist_weights( + critic_module, self.config.megatron.dist_checkpointing_path, is_value_model=True + ) else: - load_megatron_gptmodel_weights(self.config, - self.hf_config, - critic_module, - params_dtype=self.dtype, - is_value_model=True) + load_megatron_gptmodel_weights( + self.config, self.hf_config, critic_module, params_dtype=self.dtype, is_value_model=True + ) t1 = time.time() if torch.distributed.get_rank() == 0: - print(f'critic load_weight time: {t1 - t0}') + print(f"critic load_weight time: {t1 - t0}") if self.rank == 0: print_model_size(critic_module[0]) @@ -535,31 +556,38 @@ class CriticWorker(MegatronWorker): def init_model(self): # create critic from omegaconf import OmegaConf + from verl.utils.torch_dtypes import PrecisionType - if self.config.model.get('external_lib', None) is not None: + if self.config.model.get("external_lib", None) is not None: # This is used to import external_lib into the huggingface systems import importlib + importlib.import_module(self.config.model.external_lib) - override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create())) + override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) self.param_dtype = torch.bfloat16 self.dtype = PrecisionType.to_dtype(self.param_dtype) - self.critic_module, self.critic_optimizer, self.critic_model_config, critic_optimizer_config = self._build_critic_model_optimizer( - model_path=self.config.model.path, - optim_config=self.config.optim, - override_model_config=override_model_config) - self.critic = MegatronPPOCritic(config=self.config, - model_config=self.critic_model_config, - hf_config=self.hf_config, - tf_config=self.tf_config, - critic_module=self.critic_module, - critic_optimizer=self.critic_optimizer, - critic_optimizer_config=critic_optimizer_config) + self.critic_module, self.critic_optimizer, self.critic_model_config, critic_optimizer_config = ( + self._build_critic_model_optimizer( + model_path=self.config.model.path, + optim_config=self.config.optim, + override_model_config=override_model_config, + ) + ) + self.critic = MegatronPPOCritic( + config=self.config, + model_config=self.critic_model_config, + hf_config=self.hf_config, + tf_config=self.tf_config, + critic_module=self.critic_module, + critic_optimizer=self.critic_optimizer, + critic_optimizer_config=critic_optimizer_config, + ) self.flops_counter = FlopsCounter(self.critic_model_config) self.checkpoint_mananager = MegatronCheckpointManager( config=self.config, model_config=self.critic_model_config, - role='critic', + role="critic", model=self.critic_module, arch=self.architectures[0], hf_config=self.hf_config, @@ -568,42 +596,42 @@ class CriticWorker(MegatronWorker): tokenizer=self.tokenizer, optimizer=self.critic_optimizer, use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, - checkpoint_contents=self.config.checkpoint.contents) + checkpoint_contents=self.config.checkpoint.contents, + ) @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) def compute_values(self, data: DataProto): - data = data.to('cuda') + data = data.to("cuda") values = self.critic.compute_values(data=data) 
- output = DataProto.from_dict(tensors={'values': values}) - output = output.to('cpu') + output = DataProto.from_dict(tensors={"values": values}) + output = output.to("cpu") return output @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) def update_critic(self, data: DataProto): - data = data.to('cuda') + data = data.to("cuda") dataloader = self.critic.make_minibatch_iterator(data) - with Timer(name='update_critic', logger=None) as timer: + with Timer(name="update_critic", logger=None) as timer: metrics = self.critic.update_critic(dataloader=dataloader) delta_time = timer.last - global_num_tokens = data.meta_info['global_token_num'] + global_num_tokens = data.meta_info["global_token_num"] estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics['perf/mfu/critic'] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size - output = DataProto(batch=None, meta_info={'metrics': metrics}) - output = output.to('cpu') + metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + output = DataProto(batch=None, meta_info={"metrics": metrics}) + output = output.to("cpu") return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True): - self.checkpoint_mananager.load_checkpoint(local_path=checkpoint_path, - hdfs_path=hdfs_path, - del_local_after_load=del_local_after_load) + self.checkpoint_mananager.load_checkpoint( + local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + ) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_steps=0, max_ckpt_to_keep=None): - self.checkpoint_mananager.save_checkpoint(local_path=checkpoint_path, - hdfs_path=hdfs_path, - global_step=global_steps, - max_ckpt_to_keep=max_ckpt_to_keep) + self.checkpoint_mananager.save_checkpoint( + local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_steps, max_ckpt_to_keep=max_ckpt_to_keep + ) class RewardModelWorker(MegatronWorker): @@ -622,12 +650,12 @@ class RewardModelWorker(MegatronWorker): # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, # 3. 
and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 if not torch.distributed.is_initialized(): - rank = int(os.environ['LOCAL_RANK']) + rank = int(os.environ["LOCAL_RANK"]) torch.distributed.init_process_group(backend="nccl") torch.cuda.set_device(rank) if self.config.megatron.sequence_parallel: - os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" mpu.initialize_model_parallel( tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size, pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size, @@ -648,28 +676,32 @@ class RewardModelWorker(MegatronWorker): def _build_rm_model(self, model_path, override_model_config): from megatron.core.models.gpt.gpt_model import ModelType - from verl.utils.model import print_model_size - from verl.utils.megatron.optimizer import get_megatron_optimizer - from verl.utils.megatron_utils import get_model, init_megatron_optim_config + + from verl.utils.megatron_utils import get_model self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config) def megatron_rm_model_provider(pre_process, post_process): from verl.models.mcore import init_mcore_model - parallel_model = init_mcore_model(self.tf_config, - self.hf_config, - pre_process, - post_process, - share_embeddings_and_output_weights=False, - value=True) + + parallel_model = init_mcore_model( + self.tf_config, + self.hf_config, + pre_process, + post_process, + share_embeddings_and_output_weights=False, + value=True, + ) parallel_model.cuda() return parallel_model # Step 3: initialize the megatron model - reward_model = get_model(model_provider_func=megatron_rm_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=False, - use_distributed_optimizer=self.config.reward_model.use_distributed_optimizer) + reward_model = get_model( + model_provider_func=megatron_rm_model_provider, + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=False, + use_distributed_optimizer=self.config.reward_model.use_distributed_optimizer, + ) # note that here critic_module will be a list to be compatible with the construction of interleaved pp (vpp). # but here, we do not use pp (vpp) yet. 
For simplicity, we remove the list # reward_model = nn.ModuleList(reward_model) @@ -678,11 +710,9 @@ class RewardModelWorker(MegatronWorker): if self.config.megatron.use_dist_checkpointing: load_mcore_dist_weights(reward_model, self.config.megatron.dist_checkpointing_path, is_value_model=True) else: - load_megatron_gptmodel_weights(self.config, - self.hf_config, - reward_model, - params_dtype=self.dtype, - is_value_model=True) + load_megatron_gptmodel_weights( + self.config, self.hf_config, reward_model, params_dtype=self.dtype, is_value_model=True + ) # TODO: add more optimizer args into config torch.cuda.empty_cache() @@ -692,17 +722,19 @@ class RewardModelWorker(MegatronWorker): def init_model(self): # create critic from omegaconf import OmegaConf + from verl.utils.torch_dtypes import PrecisionType - if self.config.model.get('external_lib', None) is not None: + if self.config.model.get("external_lib", None) is not None: # This is used to import external_lib into the huggingface systems import importlib + importlib.import_module(self.config.model.external_lib) - override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create())) + override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) sft_tokenizer_local_path = copy_to_local(self.config.model.input_tokenizer) sft_tokenizer = hf_tokenizer(sft_tokenizer_local_path) - rm_tokenizer_path = self.config.model.get('rm_tokenizer', None) + rm_tokenizer_path = self.config.model.get("rm_tokenizer", None) rm_tokenizer = None if rm_tokenizer_path is not None: rm_tokenizer_local_path = copy_to_local(rm_tokenizer_path) @@ -717,13 +749,15 @@ class RewardModelWorker(MegatronWorker): ) # FIXME(sgm): reward model param offload is implemented in MegatronRewardModel # should be implemented in workers - self.rm = MegatronRewardModel(config=self.config, - reward_model_module=reward_model_module, - model_config=reward_model_config, - hf_config=self.hf_config, - tf_config=self.tf_config, - sft_tokenizer=sft_tokenizer, - rm_tokenizer=rm_tokenizer) + self.rm = MegatronRewardModel( + config=self.config, + reward_model_module=reward_model_module, + model_config=reward_model_config, + hf_config=self.hf_config, + tf_config=self.tf_config, + sft_tokenizer=sft_tokenizer, + rm_tokenizer=rm_tokenizer, + ) # TODO: reward model use itself tokenizer instead of sft tokenizer # the input_ids, responses, attention_mask and position_ids may be different! @@ -731,5 +765,5 @@ class RewardModelWorker(MegatronWorker): def compute_rm_score(self, data: DataProto): data.batch = data.batch.cuda() output = self.rm.compute_reward(data) - output = output.to('cpu') + output = output.to("cpu") return output diff --git a/verl/workers/reward_manager/__init__.py b/verl/workers/reward_manager/__init__.py index 04300d949..8d039fe9e 100644 --- a/verl/workers/reward_manager/__init__.py +++ b/verl/workers/reward_manager/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .naive import NaiveRewardManager -from .prime import PrimeRewardManager from .batch import BatchRewardManager from .dapo import DAPORewardManager +from .naive import NaiveRewardManager +from .prime import PrimeRewardManager diff --git a/verl/workers/reward_manager/batch.py b/verl/workers/reward_manager/batch.py index ed92f9c4f..570fdd71d 100644 --- a/verl/workers/reward_manager/batch.py +++ b/verl/workers/reward_manager/batch.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch -from verl import DataProto from collections import defaultdict +import torch + +from verl import DataProto + class BatchRewardManager: - - def __init__(self, tokenizer, num_examine, compute_score, reward_fn_key='data_source', **reward_kwargs): + def __init__(self, tokenizer, num_examine, compute_score, reward_fn_key="data_source", **reward_kwargs): self.tokenizer = tokenizer self.num_examine = num_examine self.compute_score = compute_score @@ -27,9 +28,9 @@ class BatchRewardManager: self.reward_kwargs = reward_kwargs def verify(self, data): - prompt_ids = data.batch['prompts'] - response_ids = data.batch['responses'] - attention_mask = data.batch['attention_mask'] + prompt_ids = data.batch["prompts"] + response_ids = data.batch["responses"] + attention_mask = data.batch["attention_mask"] prompt_len = prompt_ids.shape[-1] valid_response_lengths = attention_mask[:, prompt_len:].sum(dim=-1) @@ -41,32 +42,33 @@ class BatchRewardManager: response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) responses_str.append(response_str) - ground_truths = [item.non_tensor_batch['reward_model'].get('ground_truth', None) for item in data] + ground_truths = [item.non_tensor_batch["reward_model"].get("ground_truth", None) for item in data] data_sources = data.non_tensor_batch[self.reward_fn_key] - extras = data.non_tensor_batch.get('extra_info', [None] * len(data)) + extras = data.non_tensor_batch.get("extra_info", [None] * len(data)) - scores = self.compute_score(data_sources=data_sources, - solution_strs=responses_str, - ground_truths=ground_truths, - extra_infos=extras, - **self.reward_kwargs) + scores = self.compute_score( + data_sources=data_sources, + solution_strs=responses_str, + ground_truths=ground_truths, + extra_infos=extras, + **self.reward_kwargs, + ) return scores def __call__(self, data: DataProto, return_dict=False): - # If there is rm score, we directly return rm score. 
Otherwise, we compute via rm_score_fn - if 'rm_scores' in data.batch.keys(): + if "rm_scores" in data.batch.keys(): if return_dict: - return {"reward_tensor": data.batch['rm_scores']} + return {"reward_tensor": data.batch["rm_scores"]} else: - return data.batch['rm_scores'] + return data.batch["rm_scores"] - reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) reward_extra_info = defaultdict(list) - prompt_ids = data.batch['prompts'] + prompt_ids = data.batch["prompts"] prompt_len = prompt_ids.shape[-1] - attention_mask = data.batch['attention_mask'] + attention_mask = data.batch["attention_mask"] valid_response_lengths = attention_mask[:, prompt_len:].sum(dim=-1) data_sources = data.non_tensor_batch[self.reward_fn_key] @@ -90,16 +92,16 @@ class BatchRewardManager: data_source = data_sources[i] if already_printed.get(data_source, 0) < self.num_examine: - response_str = self.tokenizer.decode(data.batch['responses'][i][:length], skip_special_tokens=True) - prompt_str = self.tokenizer.decode(data.batch['prompts'][i], skip_special_tokens=True) - ground_truth = data[i].non_tensor_batch['reward_model'].get('ground_truth', None) + response_str = self.tokenizer.decode(data.batch["responses"][i][:length], skip_special_tokens=True) + prompt_str = self.tokenizer.decode(data.batch["prompts"][i], skip_special_tokens=True) + ground_truth = data[i].non_tensor_batch["reward_model"].get("ground_truth", None) print("[prompt]", prompt_str) print("[response]", response_str) print("[ground_truth]", ground_truth) print("[score]", scores[i]) already_printed[data_source] = already_printed.get(data_source, 0) + 1 - data.batch['acc'] = torch.tensor(rewards, dtype=torch.float32, device=prompt_ids.device) + data.batch["acc"] = torch.tensor(rewards, dtype=torch.float32, device=prompt_ids.device) if return_dict: return {"reward_tensor": reward_tensor, "reward_extra_info": reward_extra_info} diff --git a/verl/workers/reward_manager/dapo.py b/verl/workers/reward_manager/dapo.py index a5ef53e5a..98a17adf4 100644 --- a/verl/workers/reward_manager/dapo.py +++ b/verl/workers/reward_manager/dapo.py @@ -12,23 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict + +import torch + from verl import DataProto from verl.utils.reward_score import _default_compute_score -import torch -from collections import defaultdict class DAPORewardManager: - """The reward manager. 
- """ + """The reward manager.""" - def __init__(self, - tokenizer, - num_examine, - compute_score=None, - reward_fn_key='data_source', - max_resp_len=None, - overlong_buffer_cfg=None) -> None: + def __init__( + self, + tokenizer, + num_examine, + compute_score=None, + reward_fn_key="data_source", + max_resp_len=None, + overlong_buffer_cfg=None, + ) -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console self.compute_score = compute_score or _default_compute_score @@ -37,19 +40,21 @@ class DAPORewardManager: self.max_resp_len = max_resp_len if self.overlong_buffer_cfg is not None: - assert self.max_resp_len is not None, f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None" + assert self.max_resp_len is not None, ( + f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None" + ) def __call__(self, data: DataProto, return_dict: bool = False): """We will expand this function gradually based on the available datasets""" # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn - if 'rm_scores' in data.batch.keys(): + if "rm_scores" in data.batch.keys(): if return_dict: - return {"reward_tensor": data.batch['rm_scores']} + return {"reward_tensor": data.batch["rm_scores"]} else: - return data.batch['rm_scores'] + return data.batch["rm_scores"] - reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) reward_extra_info = defaultdict(list) already_print_data_sources = {} @@ -57,15 +62,15 @@ class DAPORewardManager: for i in range(len(data)): data_item = data[i] # DataProtoItem - prompt_ids = data_item.batch['prompts'] + prompt_ids = data_item.batch["prompts"] prompt_length = prompt_ids.shape[-1] - valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum() + valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() valid_prompt_ids = prompt_ids[-valid_prompt_length:] - response_ids = data_item.batch['responses'] - valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum() + response_ids = data_item.batch["responses"] + valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() valid_response_ids = response_ids[:valid_response_length] # decode @@ -73,13 +78,13 @@ class DAPORewardManager: response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) eos_token = self.tokenizer.eos_token if response_str.endswith(eos_token): - response_str = response_str[:-len(eos_token)] + response_str = response_str[: -len(eos_token)] - ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] + ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] data_source = data_item.non_tensor_batch[self.reward_fn_key] - extra_info = data_item.non_tensor_batch.get('extra_info', None) + extra_info = data_item.non_tensor_batch.get("extra_info", None) result = self.compute_score( data_source=data_source, @@ -124,7 +129,7 @@ class DAPORewardManager: for key, value in result.items(): print(f"[{key}]", value) else: - print(f"[score]", score) + print("[score]", score) if return_dict: return { diff --git a/verl/workers/reward_manager/naive.py b/verl/workers/reward_manager/naive.py index ed50206f8..3a59dc8b2 100644 --- a/verl/workers/reward_manager/naive.py +++ b/verl/workers/reward_manager/naive.py @@ -12,17 +12,18 @@ # See the License for the specific 
language governing permissions and # limitations under the License. +from collections import defaultdict + +import torch + from verl import DataProto from verl.utils.reward_score import _default_compute_score -import torch -from collections import defaultdict class NaiveRewardManager: - """The reward manager. - """ + """The reward manager.""" - def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key='data_source') -> None: + def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source") -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console self.compute_score = compute_score or _default_compute_score @@ -32,13 +33,13 @@ class NaiveRewardManager: """We will expand this function gradually based on the available datasets""" # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn - if 'rm_scores' in data.batch.keys(): + if "rm_scores" in data.batch.keys(): if return_dict: - return {"reward_tensor": data.batch['rm_scores']} + return {"reward_tensor": data.batch["rm_scores"]} else: - return data.batch['rm_scores'] + return data.batch["rm_scores"] - reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) reward_extra_info = defaultdict(list) already_print_data_sources = {} @@ -46,26 +47,26 @@ class NaiveRewardManager: for i in range(len(data)): data_item = data[i] # DataProtoItem - prompt_ids = data_item.batch['prompts'] + prompt_ids = data_item.batch["prompts"] prompt_length = prompt_ids.shape[-1] - valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum() + valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() valid_prompt_ids = prompt_ids[-valid_prompt_length:] - response_ids = data_item.batch['responses'] - valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum() + response_ids = data_item.batch["responses"] + valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() valid_response_ids = response_ids[:valid_response_length] # decode prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True) response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) - ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] + ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] data_source = data_item.non_tensor_batch[self.reward_fn_key] - extra_info = data_item.non_tensor_batch.get('extra_info', None) + extra_info = data_item.non_tensor_batch.get("extra_info", None) score = self.compute_score( data_source=data_source, @@ -96,7 +97,7 @@ class NaiveRewardManager: for key, value in score.items(): print(f"[{key}]", value) else: - print(f"[score]", score) + print("[score]", score) if return_dict: return { diff --git a/verl/workers/reward_manager/prime.py b/verl/workers/reward_manager/prime.py index fcf3e069b..8bfee5d4f 100644 --- a/verl/workers/reward_manager/prime.py +++ b/verl/workers/reward_manager/prime.py @@ -16,6 +16,7 @@ import asyncio from concurrent.futures import ProcessPoolExecutor from functools import partial from typing import Callable, Optional + import torch from transformers import PreTrainedTokenizer @@ -23,7 +24,7 @@ from verl import DataProto from verl.utils.reward_score import _default_compute_score -async def single_compute_score(evaluation_func, completion, 
reference, task, task_extra_info, executor, timeout=300.): +async def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0): loop = asyncio.get_running_loop() try: # Ensure process_completion is called properly @@ -31,9 +32,10 @@ async def single_compute_score(evaluation_func, completion, reference, task, tas asyncio.wait_for( loop.run_in_executor( executor, - partial(evaluation_func, task, completion, reference, task_extra_info) # Ensure synchronous + partial(evaluation_func, task, completion, reference, task_extra_info), # Ensure synchronous ), - timeout=timeout) + timeout=timeout, + ) ] return await asyncio.gather(*tasks) except asyncio.TimeoutError: @@ -44,19 +46,16 @@ async def single_compute_score(evaluation_func, completion, reference, task, tas return None # Default value for failed rows -async def parallel_compute_score_async(evaluation_func, - completions, - references, - tasks, - extra_info=None, - num_processes=64): +async def parallel_compute_score_async( + evaluation_func, completions, references, tasks, extra_info=None, num_processes=64 +): scores = [] with ProcessPoolExecutor(max_workers=num_processes) as executor: if extra_info is None: extra_info = [None] * len(tasks) # Create tasks for all rows tasks_async = [ - single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.) + single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0) for completion, reference, task, task_extra_info in zip(completions, references, tasks, extra_info) ] # to prevent very occasional starvation caused by some anomalous programs ( like infinite loop ), the exceptions in async programs will instantly halt the evaluation, and all summoned processes will be killed. @@ -67,7 +66,7 @@ async def parallel_compute_score_async(evaluation_func, try: proc.kill() except Exception as kill_err: - print('shut down failed: ' + str(kill_err)) + print("shut down failed: " + str(kill_err)) raise # Process results @@ -92,7 +91,7 @@ class PrimeRewardManager: tokenizer: PreTrainedTokenizer, num_examine: int, compute_score: Optional[Callable] = None, - reward_fn_key: str = 'data_source', + reward_fn_key: str = "data_source", ) -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console @@ -104,51 +103,54 @@ class PrimeRewardManager: verify the batch and save as ``acc`` tensor """ # batched scoring - prompt_ids = data.batch['prompts'] + prompt_ids = data.batch["prompts"] - response_ids = data.batch['responses'] + response_ids = data.batch["responses"] sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True) - ground_truth = [data_item.non_tensor_batch['reward_model']['ground_truth'] for data_item in data] + ground_truth = [data_item.non_tensor_batch["reward_model"]["ground_truth"] for data_item in data] data_sources = data.non_tensor_batch[self.reward_fn_key] - extra_info = data.non_tensor_batch.get('extra_info', None) + extra_info = data.non_tensor_batch.get("extra_info", None) assert len(sequences_str) == len(ground_truth) == len(data_sources) try: scores = asyncio.run( - parallel_compute_score_async(self.compute_score, - sequences_str, - ground_truth, - data_sources, - extra_info=extra_info, - num_processes=64)) - except asyncio.TimeoutError as e: - print('Global timeout in reward computing! Setting all as 0.') - scores = [0. 
for _ in range(len(sequences_str))] + parallel_compute_score_async( + self.compute_score, + sequences_str, + ground_truth, + data_sources, + extra_info=extra_info, + num_processes=64, + ) + ) + except asyncio.TimeoutError: + print("Global timeout in reward computing! Setting all as 0.") + scores = [0.0 for _ in range(len(sequences_str))] except Exception as e: print(f"Unexpected error in batched reward computing. Setting all as 0.: {e}") - scores = [0. for _ in range(len(sequences_str))] - data.batch['acc'] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device) + scores = [0.0 for _ in range(len(sequences_str))] + data.batch["acc"] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device) return scores def __call__(self, data: DataProto, return_dict: bool = False): """We will expand this function gradually based on the available datasets""" # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn - if 'rm_scores' in data.batch.keys(): - return data.batch['rm_scores'] + if "rm_scores" in data.batch.keys(): + return data.batch["rm_scores"] - reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) already_print_data_sources = {} # batched scoring - prompt_ids = data.batch['prompts'] + prompt_ids = data.batch["prompts"] prompt_length = prompt_ids.shape[-1] - response_ids = data.batch['responses'] - valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(dim=-1) + response_ids = data.batch["responses"] + valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=-1) sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True) - data_sources = data.non_tensor_batch['data_source'] + data_sources = data.non_tensor_batch["data_source"] scores = self.verify(data) diff --git a/verl/workers/reward_model/base.py b/verl/workers/reward_model/base.py index c02487db3..cb719bd0f 100644 --- a/verl/workers/reward_model/base.py +++ b/verl/workers/reward_model/base.py @@ -21,7 +21,6 @@ from verl import DataProto class BasePPORewardModel(ABC): - def __init__(self, config): self.config = config diff --git a/verl/workers/reward_model/megatron/reward_model.py b/verl/workers/reward_model/megatron/reward_model.py index 54027dfcc..f01bde579 100644 --- a/verl/workers/reward_model/megatron/reward_model.py +++ b/verl/workers/reward_model/megatron/reward_model.py @@ -15,36 +15,35 @@ Megatron Reward Model. 
""" -from tensordict import TensorDict -from verl import DataProto import torch import torch.distributed - -from verl.utils.torch_functional import pad_sequence_to_length -from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator) -from verl import DataProto -from verl.utils.torch_functional import broadcast_dict_tensor, split_dict_tensor_into_batches -from verl.workers.reward_model.base import BasePPORewardModel from megatron.core import parallel_state as mpu from megatron.core.pipeline_parallel import get_forward_backward_func +from tensordict import TensorDict + +from verl import DataProto +from verl.utils.megatron.pipeline_parallel import compute_transformers_input_shapes, make_batch_generator +from verl.utils.torch_functional import broadcast_dict_tensor, pad_sequence_to_length, split_dict_tensor_into_batches +from verl.workers.reward_model.base import BasePPORewardModel class MegatronRewardModel(BasePPORewardModel): - - def __init__(self, - config, - model_config, - reward_model_module: torch.nn.ModuleList, - hf_config, - tf_config, - sft_tokenizer=None, - rm_tokenizer=None): + def __init__( + self, + config, + model_config, + reward_model_module: torch.nn.ModuleList, + hf_config, + tf_config, + sft_tokenizer=None, + rm_tokenizer=None, + ): self.config = config self.reward_model_module = reward_model_module self.hf_config = hf_config self.tf_config = tf_config self.model_config = model_config - self.device = 'cuda' + self.device = "cuda" self.sft_tokenizer = sft_tokenizer self.rm_tokenizer = rm_tokenizer self.use_different_tokenizer = rm_tokenizer is not None @@ -53,16 +52,16 @@ class MegatronRewardModel(BasePPORewardModel): self.offload_params_to_cpu() def re_encode_by_rm_tokenizer(self, data: DataProto) -> DataProto: - assert self.use_different_tokenizer, 're-encode need rm tokenizer not be None!' + assert self.use_different_tokenizer, "re-encode need rm tokenizer not be None!" # need to use rm tokenizer to re-generate input_ids, attention_mask and position_ids # 1. remove pad for each sequence # 2. decode by sft_tokenizer, remove sft system prompts # 3. encode by rm_tokenizer with rm system prompts, get rm_input_ids # 4. generate attention_mask and position_ids - input_ids = data.batch['input_ids'] # (bs, seq_len) - attention_mask = data.batch['attention_mask'] - position_ids = data.batch['position_ids'] - ori_values = {'input_ids': input_ids, 'attention_mask': attention_mask, 'position_ids': position_ids} + input_ids = data.batch["input_ids"] # (bs, seq_len) + attention_mask = data.batch["attention_mask"] + position_ids = data.batch["position_ids"] + ori_values = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids} ori_bs, ori_seqlen = input_ids.size(0), input_ids.size(1) input_ids_for_rm = [] attention_mask_for_rm = [] @@ -73,27 +72,33 @@ class MegatronRewardModel(BasePPORewardModel): # 1. remove pad for each sequence non_zero_indices = torch.nonzero(mask).view(-1) begin_pos, end_pos = non_zero_indices[0].item(), non_zero_indices[-1].item() - valid_id = id[begin_pos:end_pos + 1] + valid_id = id[begin_pos : end_pos + 1] # 2. 
decode by sft_tokenizer, remove sft system prompts decode_result = self.sft_tokenizer.decode(valid_id) # workaround - decode_with_rm_chat = decode_result.replace("<|user|>\n", "[INST] ").replace( - "\n<|assistant|>\n", " [/INST]").replace(" \n<|assistant|>\n", " [/INST]") + "" + decode_with_rm_chat = ( + decode_result.replace("<|user|>\n", "[INST] ") + .replace("\n<|assistant|>\n", " [/INST]") + .replace(" \n<|assistant|>\n", " [/INST]") + + "" + ) if print_decode and torch.distributed.get_rank() == 0: # only print first decode result - print(f'device {torch.cuda.current_device()}: sft decode result:\n{decode_result}\n \ - \ndevice {torch.cuda.current_device()}: sft decode result with rm chat template:\n{decode_with_rm_chat}\n\n' - ) + print( + f"device {torch.cuda.current_device()}: sft decode result:\n{decode_result}\n \ + \ndevice {torch.cuda.current_device()}: sft decode result with rm chat template:\n{decode_with_rm_chat}\n\n" + ) print_decode = False # 3. encode by rm_tokenizer - rm_input_ids = self.rm_tokenizer(decode_with_rm_chat, - return_tensors='pt')['input_ids'][0].to(input_ids.device) + rm_input_ids = self.rm_tokenizer(decode_with_rm_chat, return_tensors="pt")["input_ids"][0].to( + input_ids.device + ) # 4. generate attention_mask and position_ids rm_attention_mask = torch.ones_like(rm_input_ids, device=input_ids.device) cur_seqlen = rm_input_ids.shape[-1] # NOTE(gh): the later reward compute will process the shape (bs, seqlen_pad_128) if cur_seqlen > ori_seqlen: - print(f'warninig: rm encode seqlen {cur_seqlen} > sft encode seqlen {ori_seqlen}') + print(f"warninig: rm encode seqlen {cur_seqlen} > sft encode seqlen {ori_seqlen}") rm_input_ids = rm_input_ids[:ori_seqlen] rm_attention_mask = rm_attention_mask[:ori_seqlen] else: @@ -110,9 +115,9 @@ class MegatronRewardModel(BasePPORewardModel): # (bs, seqlen) will not change, but input_ids, attention_mask and position_ids will change # NOTE(gh): need to replace into origin values after compute reward! 
- data.batch['input_ids'] = input_ids_for_rm - data.batch['attention_mask'] = attention_mask_for_rm - data.batch['position_ids'] = position_ids_for_rm + data.batch["input_ids"] = input_ids_for_rm + data.batch["attention_mask"] = attention_mask_for_rm + data.batch["position_ids"] = position_ids_for_rm return data, ori_values @@ -124,11 +129,11 @@ class MegatronRewardModel(BasePPORewardModel): if self.use_different_tokenizer: data, ori_values = self.re_encode_by_rm_tokenizer(data) - input_ids = data.batch['input_ids'] # (bs, seq_len') - attention_mask = data.batch['attention_mask'] - position_ids = data.batch['position_ids'] + input_ids = data.batch["input_ids"] # (bs, seq_len') + attention_mask = data.batch["attention_mask"] + position_ids = data.batch["position_ids"] - responses = data.batch['responses'] + responses = data.batch["responses"] batch_size = responses.size(0) response_length = responses.size(1) @@ -140,12 +145,15 @@ class MegatronRewardModel(BasePPORewardModel): logits = torch.empty( (input_ids.shape[0], input_ids.shape[1]), dtype=torch.bfloat16, # TODO(sgm): check why is bfloat16 - device=input_ids.device) + device=input_ids.device, + ) # broadcast across pp ranks - torch.distributed.broadcast(tensor=logits, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group(), - async_op=False) + torch.distributed.broadcast( + tensor=logits, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + async_op=False, + ) # (bs, seqlen', hidden_size) -> (bs, seqlen', 1) -> (bs, seqlen') token_level_rewards = logits @@ -155,16 +163,16 @@ class MegatronRewardModel(BasePPORewardModel): if self.use_different_tokenizer: data.batch.update(ori_values) - input_ids = ori_values['input_ids'] - attention_mask = ori_values['attention_mask'] - position_ids = ori_values['position_ids'] + input_ids = ori_values["input_ids"] + attention_mask = ori_values["attention_mask"] + position_ids = ori_values["position_ids"] token_level_rewards = rewards.expand(attention_mask.shape[0], attention_mask.shape[1]) # (bs, ori_seqlen) # assign last valid token reward to ori position eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bs,) eos_mask = torch.zeros_like(attention_mask) - eos_mask[torch.arange(batch_size), eos_mask_idx] = 1. + eos_mask[torch.arange(batch_size), eos_mask_idx] = 1.0 token_level_rewards = token_level_rewards * eos_mask token_level_rewards = token_level_rewards[:, -response_length:] @@ -175,7 +183,7 @@ class MegatronRewardModel(BasePPORewardModel): # add empty cache after each compute torch.cuda.empty_cache() - batch = TensorDict({'rm_scores': token_level_rewards}, batch_size=input_ids.shape[0]) + batch = TensorDict({"rm_scores": token_level_rewards}, batch_size=input_ids.shape[0]) return DataProto(batch=batch) @@ -188,47 +196,52 @@ class MegatronRewardModel(BasePPORewardModel): # broadcast from last pp rank to all other pp ranks # TODO: actually, we just need to control the sampling order. 
data.batch = data.batch.contiguous() - broadcast_dict_tensor(data.batch, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group()) + broadcast_dict_tensor( + data.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group() + ) # split into micro-batches - if self.config is not None and 'ppo_micro_batch_size_per_gpu' in self.config: + if self.config is not None and "ppo_micro_batch_size_per_gpu" in self.config: infer_batch_size = self.config.ppo_micro_batch_size_per_gpu else: infer_batch_size = data.batch.batch_size[0] - data.batch['attention_mask'] = data.batch['attention_mask'].to(bool) + data.batch["attention_mask"] = data.batch["attention_mask"].to(bool) batches = split_dict_tensor_into_batches(data.batch, batch_size=infer_batch_size) n_micro_batch = len(batches) - seq_len = batches[0]['input_ids'].shape[1] + seq_len = batches[0]["input_ids"].shape[1] # compute input shapes for pp stages - input_shapes = compute_transformers_input_shapes(batches, - meta_info={ - 'sequence_parallel': self.tf_config.sequence_parallel, - 'hidden_size': self.model_config.hidden_size - }) + input_shapes = compute_transformers_input_shapes( + batches, + meta_info={ + "sequence_parallel": self.tf_config.sequence_parallel, + "hidden_size": self.model_config.hidden_size, + }, + ) # compute input shapes for pp stages forward_backward_func = get_forward_backward_func() def loss_func(output): - return 1., {'logits': output} + return 1.0, {"logits": output} def forward_step(batch_iter, model): batch = next(batch_iter) - input_ids = batch['input_ids'] - attention_mask = batch['attention_mask'] - position_ids = batch['position_ids'] + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + position_ids = batch["position_ids"] from verl.models.mcore import get_mcore_forward_fn + forward_fn = get_mcore_forward_fn(self.hf_config) - output = forward_fn(model, - input_ids, - attention_mask, - position_ids, - sequence_parallel=self.tf_config.sequence_parallel, - value_model=True) + output = forward_fn( + model, + input_ids, + attention_mask, + position_ids, + sequence_parallel=self.tf_config.sequence_parallel, + value_model=True, + ) return output, loss_func @@ -262,16 +275,16 @@ class MegatronRewardModel(BasePPORewardModel): return losses_reduced def offload_params_to_cpu(self): - if self.device == 'cuda': + if self.device == "cuda": for reward_model_module in self.reward_model_module: for name, param in reward_model_module.named_parameters(): - param.data = param.data.to('cpu', non_blocking=True) - self.device = 'cpu' + param.data = param.data.to("cpu", non_blocking=True) + self.device = "cpu" torch.cuda.empty_cache() def load_params_to_cuda(self): - if self.device == 'cpu': + if self.device == "cpu": for reward_model_module in self.reward_model_module: for name, param in reward_model_module.named_parameters(): param.data = param.data.to(torch.cuda.current_device(), non_blocking=True) - self.device = 'cuda' + self.device = "cuda" diff --git a/verl/workers/rollout/__init__.py b/verl/workers/rollout/__init__.py index 083848c77..5efcd337d 100644 --- a/verl/workers/rollout/__init__.py +++ b/verl/workers/rollout/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. 
from .base import BaseRollout -from .naive import NaiveRollout from .hf_rollout import HFRollout +from .naive import NaiveRollout __all__ = ["BaseRollout", "NaiveRollout", "HFRollout"] diff --git a/verl/workers/rollout/base.py b/verl/workers/rollout/base.py index 8c2733325..2fb90ed2f 100644 --- a/verl/workers/rollout/base.py +++ b/verl/workers/rollout/base.py @@ -13,15 +13,13 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Iterable, Union from verl import DataProto -__all__ = ['BaseRollout'] +__all__ = ["BaseRollout"] class BaseRollout(ABC): - def __init__(self): """ diff --git a/verl/workers/rollout/hf_rollout.py b/verl/workers/rollout/hf_rollout.py index 061a76180..0e0920f6a 100644 --- a/verl/workers/rollout/hf_rollout.py +++ b/verl/workers/rollout/hf_rollout.py @@ -16,24 +16,25 @@ Rollout with huggingface models. TODO: refactor this class. Currently, it will hang when using FSDP HybridShard. We should actually create a single GPU model. Then, get full state_dict and bind the state_dict to the single GPU model. Then, use the single GPU model to perform generation. """ + import contextlib + import torch import torch.distributed from tensordict import TensorDict from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from transformers import GenerationConfig from verl import DataProto from verl.utils.torch_functional import get_response_mask + from .base import BaseRollout -from transformers import GenerationConfig - -__all__ = ['HFRollout'] +__all__ = ["HFRollout"] class HFRollout(BaseRollout): - def __init__(self, module: nn.Module, config): super().__init__() self.config = config @@ -41,7 +42,7 @@ class HFRollout(BaseRollout): def generate_sequences(self, prompts: DataProto) -> DataProto: batch_size = prompts.batch.batch_size[0] - num_chunks = max(batch_size // self.config.get('micro_batch_size', batch_size), 1) + num_chunks = max(batch_size // self.config.get("micro_batch_size", batch_size), 1) batch_prompts = prompts.chunk(chunks=num_chunks) output = [self._generate_minibatch(p) for p in batch_prompts] output = DataProto.concat(output) @@ -49,13 +50,13 @@ class HFRollout(BaseRollout): @torch.no_grad() def _generate_minibatch(self, prompts: DataProto) -> DataProto: - idx = prompts.batch['input_ids'] # (bs, prompt_length) - attention_mask = prompts.batch['attention_mask'] # left-padded attention_mask - position_ids = prompts.batch['position_ids'] + idx = prompts.batch["input_ids"] # (bs, prompt_length) + attention_mask = prompts.batch["attention_mask"] # left-padded attention_mask + position_ids = prompts.batch["position_ids"] # used to construct attention_mask - eos_token_id = prompts.meta_info['eos_token_id'] - pad_token_id = prompts.meta_info['pad_token_id'] + eos_token_id = prompts.meta_info["eos_token_id"] + pad_token_id = prompts.meta_info["pad_token_id"] batch_size = idx.size(0) prompt_length = idx.size(1) @@ -64,16 +65,16 @@ class HFRollout(BaseRollout): param_ctx = contextlib.nullcontext() # make sampling args can be overriden by inputs - do_sample = prompts.meta_info.get('do_sample', self.config.do_sample) - response_length = prompts.meta_info.get('response_length', self.config.response_length) - top_p = prompts.meta_info.get('top_p', self.config.get('top_p', 1.0)) - top_k = prompts.meta_info.get('top_k', self.config.get('top_k', 0)) + do_sample = prompts.meta_info.get("do_sample", self.config.do_sample) + response_length = prompts.meta_info.get("response_length", self.config.response_length) + top_p 
= prompts.meta_info.get("top_p", self.config.get("top_p", 1.0)) + top_k = prompts.meta_info.get("top_k", self.config.get("top_k", 0)) if top_k is None: top_k = 0 top_k = max(0, top_k) # to be compatible with vllm - temperature = prompts.meta_info.get('temperature', self.config.temperature) + temperature = prompts.meta_info.get("temperature", self.config.temperature) generation_config = GenerationConfig(temperature=temperature, top_p=top_p, top_k=top_k) @@ -81,7 +82,7 @@ class HFRollout(BaseRollout): # recurse need to set to False according to https://github.com/pytorch/pytorch/issues/100069 param_ctx = FSDP.summon_full_params(self.module, writeback=False, recurse=False) with param_ctx: - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): output = self.module.generate( input_ids=idx, attention_mask=attention_mask, @@ -94,7 +95,8 @@ class HFRollout(BaseRollout): # renormalize_logits=True, output_scores=False, # this is potentially very large return_dict_in_generate=True, - use_cache=True) + use_cache=True, + ) # TODO: filter out the seq with no answers like ds-chat seq = output.sequences @@ -120,20 +122,21 @@ class HFRollout(BaseRollout): response_position_ids = position_ids[:, -1:] + delta_position_id position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - response_attention_mask = get_response_mask(response_id=response, - eos_token=eos_token_id, - dtype=attention_mask.dtype) + response_attention_mask = get_response_mask( + response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype + ) attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) batch = TensorDict( { - 'prompts': prompt, - 'responses': response, - 'input_ids': seq, - 'attention_mask': attention_mask, - 'position_ids': position_ids + "prompts": prompt, + "responses": response, + "input_ids": seq, + "attention_mask": attention_mask, + "position_ids": position_ids, }, - batch_size=batch_size) + batch_size=batch_size, + ) # empty cache before compute old_log_prob torch.cuda.empty_cache() diff --git a/verl/workers/rollout/naive/naive_rollout.py b/verl/workers/rollout/naive/naive_rollout.py index 2c57bce10..fe56dc4c9 100644 --- a/verl/workers/rollout/naive/naive_rollout.py +++ b/verl/workers/rollout/naive/naive_rollout.py @@ -19,7 +19,6 @@ The output will contain 3. eos_masks 4. log_probs """ -from typing import Iterable, Union import torch import torch.nn.functional as F @@ -28,13 +27,13 @@ from torch import nn from verl import DataProto from verl.utils.torch_functional import logprobs_from_logits + from ..base import BaseRollout -__all__ = ['NaiveRollout'] +__all__ = ["NaiveRollout"] class NaiveRollout(BaseRollout): - def __init__(self, module: nn.Module, config): """A naive rollout. It requires the module to be compatible with huggingface APIs. That is: The module should define __call__ to receive input_ids, attention_mask and position_ids. 
@@ -51,14 +50,12 @@ class NaiveRollout(BaseRollout): @torch.no_grad() def generate_sequences(self, prompts: DataProto) -> DataProto: """Generate sequences""" - idx = prompts.batch['input_ids'] # (bs, prompt_length) - attention_mask = prompts.batch['attention_mask'] # left-padded attention_mask - position_ids = prompts.batch['position_ids'] + idx = prompts.batch["input_ids"] # (bs, prompt_length) + attention_mask = prompts.batch["attention_mask"] # left-padded attention_mask + position_ids = prompts.batch["position_ids"] # used to construct attention_mask - eos_token_id = prompts.meta_info['eos_token_id'] - if isinstance(eos_token, int): - eos_token = [eos_token] + eos_token_id = prompts.meta_info["eos_token_id"] batch_size = idx.size(0) prompt_length = idx.size(1) @@ -81,7 +78,7 @@ class NaiveRollout(BaseRollout): # optionally crop the logits to only the top k options if self.config.top_k is not None: v, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1))) - logits[logits < v[:, [-1]]] = -float('Inf') + logits[logits < v[:, [-1]]] = -float("Inf") # apply softmax to convert logits to (normalized) probabilities probs = F.softmax(logits, dim=-1) # sample from the distribution @@ -108,14 +105,15 @@ class NaiveRollout(BaseRollout): log_probs = logprobs_from_logits(logits=logits, labels=response) batch = TensorDict( { - 'input_ids': prompts, - 'responses': response, - 'sequences': idx, - 'old_log_probs': log_probs, - 'attention_mask': attention_mask, - 'position_ids': position_ids, + "input_ids": prompts, + "responses": response, + "sequences": idx, + "old_log_probs": log_probs, + "attention_mask": attention_mask, + "position_ids": position_ids, }, - batch_size=batch_size) + batch_size=batch_size, + ) self.module.train() diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index cc3968a46..322ee83f9 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -26,23 +26,26 @@ # limitations under the License. 
from __future__ import annotations + import os -import numpy as np from contextlib import contextmanager from typing import TYPE_CHECKING, List -from omegaconf import DictConfig -from tensordict import TensorDict -from verl import DataProto -from verl.workers.rollout.base import BaseRollout -from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length, pad_2d_list_to_length -from sglang.srt.entrypoints.verl_engine import VerlEngine -from torch.distributed.device_mesh import init_device_mesh -from sglang.srt.sampling.sampling_params import SamplingParams -from verl.third_party.sglang import parallel_state as sglang_ps + +import numpy as np import torch.distributed -from torch.nn.utils.rnn import pad_sequence -from sglang.srt.utils import broadcast_pyobj, get_ip +from omegaconf import DictConfig +from sglang.srt.entrypoints.verl_engine import VerlEngine +from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import broadcast_pyobj, get_ip +from tensordict import TensorDict +from torch.distributed.device_mesh import init_device_mesh +from torch.nn.utils.rnn import pad_sequence + +from verl import DataProto +from verl.third_party.sglang import parallel_state as sglang_ps +from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length +from verl.workers.rollout.base import BaseRollout if TYPE_CHECKING: from torch import nn @@ -59,7 +62,6 @@ def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[in # NOTE(linjunrong): adhoc def _post_process_outputs(tokenizer, output): - def _map_each_response(l): # output_token_ids = torch.tensor(l['token_ids']) log_probs = [] @@ -77,7 +79,7 @@ def _post_process_outputs(tokenizer, output): for output_token_ids, log_probs in out_map: batched_output_token_ids.append(output_token_ids) batched_logprobs.append(log_probs) - pad_token_id = (tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id) + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id batched_output_token_ids = pad_sequence(batched_output_token_ids, batch_first=True, padding_value=pad_token_id) if len(batched_logprobs) > 0: batched_logprobs = pad_sequence(batched_logprobs, batch_first=True, padding_value=pad_token_id) @@ -85,7 +87,6 @@ def _post_process_outputs(tokenizer, output): class SGLangRollout(BaseRollout): - def __init__( self, actor_module: nn.Module | str, @@ -106,26 +107,29 @@ class SGLangRollout(BaseRollout): super().__init__() self.config = config - assert not (not config.enforce_eager and - config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine" + assert not (not config.enforce_eager and config.free_cache_engine), ( + "disable CUDA graph (enforce_eager = False) if free cache engine" + ) tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) - assert (tensor_parallel_size <= torch.distributed.get_world_size() - ), "tensor parallel size should be less than or equal to the world size" + assert tensor_parallel_size <= torch.distributed.get_world_size(), ( + "tensor parallel size should be less than or equal to the world size" + ) - if kwargs.get("train_tp", None) is not None: + if kwargs.get("train_tp") is not None: # deployed with megatron os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0" os.environ["MEGATRON_IMPORT_TIMERS"] = "0" - train_tp = kwargs.get("train_tp", None) + train_tp = kwargs.get("train_tp") 
num_tp_per_train_tp = train_tp // tensor_parallel_size sglang_ps.initialize_parallel_state( tensor_model_parallel_size=tensor_parallel_size, num_tp_per_train_tp=num_tp_per_train_tp, ) - assert (model_hf_config.max_position_embeddings >= config.prompt_length + - config.response_length), "model context length should be greater than total sequence length" + assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, ( + "model context length should be greater than total sequence length" + ) tp_size = tensor_parallel_size world_size = int(os.getenv("WORLD_SIZE", "-1")) @@ -141,20 +145,23 @@ class SGLangRollout(BaseRollout): # get tp_rank of this process in this tp group tp_rank = device_mesh_cpu["tp"].get_local_rank() visible_devices = [None] * device_mesh_cpu.size(1) - torch.distributed.all_gather_object(visible_devices, os.environ["CUDA_VISIBLE_DEVICES"], - device_mesh_cpu.get_group("tp")) + torch.distributed.all_gather_object( + visible_devices, os.environ["CUDA_VISIBLE_DEVICES"], device_mesh_cpu.get_group("tp") + ) visible_devices_set = set(visible_devices) os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(sorted(list(visible_devices_set))) nnodes = -(-tp_size // len(visible_devices_set)) server_args = ServerArgs(model_path=actor_module, nnodes=nnodes) ip, port_args = get_ip(), PortArgs.init_new(server_args) - [ip, port_args] = broadcast_pyobj([ip, port_args], - rank=tp_rank, - dist_group=device_mesh_cpu.get_group("tp"), - src=device_mesh_cpu["tp"].mesh[0].item()) + [ip, port_args] = broadcast_pyobj( + [ip, port_args], + rank=tp_rank, + dist_group=device_mesh_cpu.get_group("tp"), + src=device_mesh_cpu["tp"].mesh[0].item(), + ) dist_init_addr = f"{ip}:{port_args.nccl_port}" - load_format = 'dummy' if config.load_format.startswith('dummy') else config.load_format + load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format self.inference_engine = VerlEngine( model_path=actor_module, dtype=config.dtype, @@ -165,7 +172,7 @@ class SGLangRollout(BaseRollout): gpu_id_step=1, load_format=load_format, dist_init_addr=dist_init_addr, - nnodes=nnodes + nnodes=nnodes, # NOTE(Chenyang): if you want to debug the sglang engine # please set the following parameters # Otherwise, it will make the engine run too slow @@ -178,11 +185,13 @@ class SGLangRollout(BaseRollout): # offload self.inference_engine.release_memory_occupation() - kwargs = dict(n=1, - max_new_tokens=config.response_length, - presence_penalty=0.0, - frequency_penalty=0.0, - repetition_penalty=1.0) + kwargs = dict( + n=1, + max_new_tokens=config.response_length, + presence_penalty=0.0, + frequency_penalty=0.0, + repetition_penalty=1.0, + ) # supporting adding any sampling params from the config file for k in config.keys(): if hasattr(SamplingParams(), str(k)): @@ -225,35 +234,42 @@ class SGLangRollout(BaseRollout): # Extract non-tensor data non_tensor_batch = prompts.non_tensor_batch - if 'raw_prompt_ids' not in non_tensor_batch: - non_tensor_batch['raw_prompt_ids'] = np.array( - [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object) + if "raw_prompt_ids" not in non_tensor_batch: + non_tensor_batch["raw_prompt_ids"] = np.array( + [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object + ) - if 'multi_modal_data' in non_tensor_batch: + if "multi_modal_data" in non_tensor_batch: sglang_inputs = [] - for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop('raw_prompt_ids'), - 
non_tensor_batch.pop('multi_modal_data')): - sglang_inputs.append({ - 'prompt_token_ids': raw_prompt_ids, - 'multi_modal_data': multi_modal_data, - 'image_data': multi_modal_data.get('image', None) if isinstance(multi_modal_data, dict) else None - }) + for raw_prompt_ids, multi_modal_data in zip( + non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data") + ): + sglang_inputs.append( + { + "prompt_token_ids": raw_prompt_ids, + "multi_modal_data": multi_modal_data, + "image_data": multi_modal_data.get("image", None) + if isinstance(multi_modal_data, dict) + else None, + } + ) else: - sglang_inputs = [{ - 'prompt_token_ids': raw_prompt_ids - } for raw_prompt_ids in non_tensor_batch.pop('raw_prompt_ids')] + sglang_inputs = [ + {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") + ] # Ensure token IDs are lists for input_data in sglang_inputs: - if isinstance(input_data['prompt_token_ids'], np.ndarray): - input_data['prompt_token_ids'] = input_data['prompt_token_ids'].tolist() - elif not isinstance(input_data['prompt_token_ids'], list): + if isinstance(input_data["prompt_token_ids"], np.ndarray): + input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist() + elif not isinstance(input_data["prompt_token_ids"], list): raise TypeError( - f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}") + f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}" + ) # Extract token IDs and image data for SGLang Engine - idx_list = [input_data['prompt_token_ids'] for input_data in sglang_inputs] - image_list = [input_data.get('image_data', None) for input_data in sglang_inputs] + idx_list = [input_data["prompt_token_ids"] for input_data in sglang_inputs] + image_list = [input_data.get("image_data", None) for input_data in sglang_inputs] do_sample = prompts.meta_info.get("do_sample", True) if not do_sample: @@ -279,7 +295,8 @@ class SGLangRollout(BaseRollout): sampling_params=self.sampling_params, return_logprob=True, input_ids=idx_list, - image_data=image_list) + image_data=image_list, + ) out = _post_process_outputs(self.tokenizer, output) @@ -294,10 +311,10 @@ class SGLangRollout(BaseRollout): attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0) position_ids = position_ids.repeat_interleave(self.config.n, dim=0) batch_size = batch_size * self.config.n - if 'multi_modal_inputs' in non_tensor_batch: - non_tensor_batch['multi_modal_inputs'] = np.repeat(non_tensor_batch['multi_modal_inputs'], - self.config.n, - axis=0) + if "multi_modal_inputs" in non_tensor_batch: + non_tensor_batch["multi_modal_inputs"] = np.repeat( + non_tensor_batch["multi_modal_inputs"], self.config.n, axis=0 + ) seq = torch.cat([idx, response], dim=-1) response_length = response.size(1) @@ -310,9 +327,9 @@ class SGLangRollout(BaseRollout): # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] response_position_ids = position_ids[:, -1:] + delta_position_id position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - response_attention_mask = get_response_mask(response_id=response, - eos_token=eos_token_id, - dtype=attention_mask.dtype) + response_attention_mask = get_response_mask( + response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype + ) attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) # all the tp ranks should contain the same data here. 
data in all ranks are valid @@ -329,8 +346,11 @@ class SGLangRollout(BaseRollout): ) # free cache engine - if (self.config.free_cache_engine and self.inference_engine._engine is not None and - self.inference_engine._engine.tokenizer_manager is not None): + if ( + self.config.free_cache_engine + and self.inference_engine._engine is not None + and self.inference_engine._engine.tokenizer_manager is not None + ): self.inference_engine._engine.tokenizer_manager.flush_cache() return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) diff --git a/verl/workers/rollout/tokenizer.py b/verl/workers/rollout/tokenizer.py index c0dfa3a53..1ad7554a1 100644 --- a/verl/workers/rollout/tokenizer.py +++ b/verl/workers/rollout/tokenizer.py @@ -14,10 +14,14 @@ """ The base tokenizer class, required for any hybrid engine based rollout or inference with vLLM. """ + from abc import ABC, abstractmethod from typing import Dict, List, Union -__all__ = ['HybridEngineBaseTokenizer'] +import numpy as np +import torch + +__all__ = ["HybridEngineBaseTokenizer"] class HybridEngineBaseTokenizer(ABC): @@ -85,7 +89,7 @@ class HybridEngineBaseTokenizer(ABC): @abstractmethod def decode( self, - token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"], + token_ids: Union[int, List[int], np.ndarray, torch.Tensor], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = None, **kwargs, @@ -97,7 +101,7 @@ class HybridEngineBaseTokenizer(ABC): Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. Args: - token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. @@ -113,9 +117,9 @@ class HybridEngineBaseTokenizer(ABC): pass @abstractmethod - def convert_ids_to_tokens(self, - ids: Union[int, List[int]], - skip_special_tokens: bool = False) -> Union[str, List[str]]: + def convert_ids_to_tokens( + self, ids: Union[int, List[int]], skip_special_tokens: bool = False + ) -> Union[str, List[str]]: """ Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and added tokens. diff --git a/verl/workers/rollout/vllm_rollout/__init__.py b/verl/workers/rollout/vllm_rollout/__init__.py index 0d6d4c3d8..1c0fe1255 100644 --- a/verl/workers/rollout/vllm_rollout/__init__.py +++ b/verl/workers/rollout/vllm_rollout/__init__.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from importlib.metadata import version, PackageNotFoundError +from importlib.metadata import PackageNotFoundError, version ### # [SUPPORT AMD:] import torch + ### @@ -27,7 +28,7 @@ def get_version(pkg): return None -package_name = 'vllm' +package_name = "vllm" package_version = get_version(package_name) ### @@ -35,16 +36,17 @@ package_version = get_version(package_name) # [SUPPORT AMD:] if "AMD" in torch.cuda.get_device_name(): import re + package_version = version(package_name) - package_version = re.match(r'(\d+\.\d+\.?\d*)', package_version).group(1) + package_version = re.match(r"(\d+\.\d+\.?\d*)", package_version).group(1) else: package_version = get_version(package_name) ### -if package_version <= '0.6.3': - vllm_mode = 'customized' - from .vllm_rollout import vLLMRollout +if package_version <= "0.6.3": + vllm_mode = "customized" from .fire_vllm_rollout import FIREvLLMRollout + from .vllm_rollout import vLLMRollout else: - vllm_mode = 'spmd' + vllm_mode = "spmd" from .vllm_rollout_spmd import vLLMRollout diff --git a/verl/workers/rollout/vllm_rollout/fire_vllm_rollout.py b/verl/workers/rollout/vllm_rollout/fire_vllm_rollout.py index c65af5a5b..09f1e4ee5 100644 --- a/verl/workers/rollout/vllm_rollout/fire_vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/fire_vllm_rollout.py @@ -24,21 +24,20 @@ When working with Megatron: - Do inference in tp. pp is treated as additional dp - After inference, all the parameters that doesn't belong to this pp rank is freed. """ -from typing import List + from contextlib import contextmanager -from omegaconf import DictConfig +from typing import List + import torch import torch.distributed +from omegaconf import DictConfig from tensordict import TensorDict from torch import nn +from vllm import SamplingParams from verl import DataProto from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length -from verl.workers.rollout.base import BaseRollout from verl.workers.rollout.vllm_rollout.vllm_rollout import vLLMRollout -from verl.third_party.vllm import LLM, vllm_version -from verl.third_party.vllm import parallel_state as vllm_ps -from vllm import SamplingParams # TODO # 1. support pp in vllm @@ -56,7 +55,6 @@ def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[in class FIREvLLMRollout(vLLMRollout): - def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config, **kwargs): """A vLLM rollout. It requires the module is supported by the vllm. 
@@ -69,13 +67,13 @@ class FIREvLLMRollout(vLLMRollout): """ super().__init__(actor_module, config, tokenizer, model_hf_config, **kwargs) - self.use_fire_sampling = config.get('use_fire_sampling', False) + self.use_fire_sampling = config.get("use_fire_sampling", False) if self.use_fire_sampling: kwargs_0 = kwargs.copy() - kwargs_0['temperature'] = 30 - kwargs_0['max_tokens'] = 1 - if 'top_k' not in kwargs_0 or kwargs_0['top_k'] <= 0: - kwargs_0['top_k'] = 16 + kwargs_0["temperature"] = 30 + kwargs_0["max_tokens"] = 1 + if "top_k" not in kwargs_0 or kwargs_0["top_k"] <= 0: + kwargs_0["top_k"] = 16 self.sampling_params.max_tokens = config.response_length - 1 for k in config.keys(): if hasattr(SamplingParams(), str(k)): @@ -115,13 +113,13 @@ class FIREvLLMRollout(vLLMRollout): if self.config.free_cache_engine: self.inference_engine.init_cache_engine() - idx = prompts.batch['input_ids'] # (bs, prompt_length) + idx = prompts.batch["input_ids"] # (bs, prompt_length) # left-padded attention_mask - attention_mask = prompts.batch['attention_mask'] - position_ids = prompts.batch['position_ids'] + attention_mask = prompts.batch["attention_mask"] + position_ids = prompts.batch["position_ids"] # used to construct attention_mask - eos_token_id = prompts.meta_info['eos_token_id'] + eos_token_id = prompts.meta_info["eos_token_id"] batch_size = idx.size(0) @@ -130,15 +128,15 @@ class FIREvLLMRollout(vLLMRollout): for i in range(batch_size): idx_list.append(_pre_process_inputs(self.pad_token_id, idx[i])) - do_sample = prompts.meta_info.get('do_sample', True) + do_sample = prompts.meta_info.get("do_sample", True) if not do_sample: kwargs = { - 'best_of': 1, - 'top_p': 1.0, - 'top_k': -1, - 'min_p': 0.0, - 'temperature': 0, - 'n': 1 # if greedy, only 1 response + "best_of": 1, + "top_p": 1.0, + "top_k": -1, + "min_p": 0.0, + "temperature": 0, + "n": 1, # if greedy, only 1 response } if not self.use_fire_sampling: @@ -148,7 +146,8 @@ class FIREvLLMRollout(vLLMRollout): prompts=None, # because we have already convert it to prompt token id sampling_params=self.sampling_params, prompt_token_ids=idx_list, - use_tqdm=False) + use_tqdm=False, + ) response = output[0].to(idx.device) # (bs, response_length) log_probs = output[1].to(idx.device) # (bs, response_length) @@ -158,7 +157,8 @@ class FIREvLLMRollout(vLLMRollout): prompts=None, # because we have already convert it to prompt token id sampling_params=self.sampling_params_0, prompt_token_ids=idx_list, - use_tqdm=False) + use_tqdm=False, + ) new_idx_list = [] for i in range(batch_size): new_idx_list.append(idx_list[i] + output_0[0][i].tolist()) @@ -166,7 +166,8 @@ class FIREvLLMRollout(vLLMRollout): prompts=None, # because we have already convert it to prompt token id sampling_params=self.sampling_params, prompt_token_ids=new_idx_list, - use_tqdm=False) + use_tqdm=False, + ) response = torch.cat([output_0[0], output[0]], dim=1).to(idx.device) # (bs, response_length) # log_probs = torch.cat([output_0[1], output[1]], dim=1).to(idx.device) # (bs, response_length) @@ -192,22 +193,23 @@ class FIREvLLMRollout(vLLMRollout): # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] response_position_ids = position_ids[:, -1:] + delta_position_id position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - response_attention_mask = get_response_mask(response_id=response, - eos_token=eos_token_id, - dtype=attention_mask.dtype) + response_attention_mask = get_response_mask( + response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype + ) attention_mask 
= torch.cat((attention_mask, response_attention_mask), dim=-1) # all the tp ranks should contain the same data here. data in all ranks are valid batch = TensorDict( { - 'prompts': idx, - 'responses': response, - 'input_ids': seq, # here input_ids become the whole sentences + "prompts": idx, + "responses": response, + "input_ids": seq, # here input_ids become the whole sentences # 'old_log_probs': log_probs, # we will recompute old log prob with actor - 'attention_mask': attention_mask, - 'position_ids': position_ids + "attention_mask": attention_mask, + "position_ids": position_ids, }, - batch_size=batch_size) + batch_size=batch_size, + ) # free vllm cache engine if self.config.free_cache_engine: diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index 40f09bb02..5f0ef1e29 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -24,21 +24,23 @@ When working with Megatron: - Do inference in tp. pp is treated as additional dp - After inference, all the parameters that doesn't belong to this pp rank is freed. """ -from typing import List -from copy import deepcopy + from contextlib import contextmanager -from omegaconf import DictConfig, OmegaConf +from copy import deepcopy +from typing import List + import torch import torch.distributed +from omegaconf import DictConfig, OmegaConf from tensordict import TensorDict from torch import nn +from vllm import SamplingParams from verl import DataProto -from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length -from verl.workers.rollout.base import BaseRollout from verl.third_party.vllm import LLM, vllm_version from verl.third_party.vllm import parallel_state as vllm_ps -from vllm import SamplingParams +from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length +from verl.workers.rollout.base import BaseRollout # TODO # 1. support pp in vllm @@ -56,7 +58,6 @@ def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[in class vLLMRollout(BaseRollout): - def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config, **kwargs): """A vLLM rollout. It requires the module is supported by the vllm. 
@@ -69,57 +70,67 @@ class vLLMRollout(BaseRollout): """ super().__init__() self.config = config - assert not (not config.enforce_eager and config.free_cache_engine), \ + assert not (not config.enforce_eager and config.free_cache_engine), ( "disable CUDA graph (enforce_eager = False) if free cache engine" + ) - tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1) - assert tensor_parallel_size <= torch.distributed.get_world_size(), \ + tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) + assert tensor_parallel_size <= torch.distributed.get_world_size(), ( "tensor parallel size should be less than or equal to the world size" - max_num_batched_tokens = int(self.config.get('max_num_batched_tokens', 8192)) + ) + max_num_batched_tokens = int(self.config.get("max_num_batched_tokens", 8192)) - if kwargs.get('train_tp', None) is not None: + if kwargs.get("train_tp") is not None: # deployed with megatron import os - os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0' - os.environ['MEGATRON_IMPORT_TIMERS'] = '0' - train_tp = kwargs.get('train_tp', None) + + os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0" + os.environ["MEGATRON_IMPORT_TIMERS"] = "0" + train_tp = kwargs.get("train_tp") num_tp_per_train_tp = train_tp // tensor_parallel_size - if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): - vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size, - num_tp_per_train_tp=num_tp_per_train_tp) + if vllm_version in ("0.4.2", "0.5.4", "0.6.3"): + vllm_ps.initialize_parallel_state( + tensor_model_parallel_size=tensor_parallel_size, num_tp_per_train_tp=num_tp_per_train_tp + ) - assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \ + assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, ( "model context length should be greater than total sequence length" + ) - max_model_len = self.config.max_model_len if self.config.max_model_len \ - else config.prompt_length + config.response_length + max_model_len = ( + self.config.max_model_len if self.config.max_model_len else config.prompt_length + config.response_length + ) max_model_len = int(max_model_len) if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill: - raise ValueError('Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \ - please increase max_num_batched_tokens or disable chunked prefill') + raise ValueError( + "Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \ + please increase max_num_batched_tokens or disable chunked prefill" + ) # copy it to avoid secretly modifying the engine config - engine_kwargs = {} if 'engine_kwargs' not in config else OmegaConf.to_container(deepcopy(config.engine_kwargs)) + engine_kwargs = {} if "engine_kwargs" not in config else OmegaConf.to_container(deepcopy(config.engine_kwargs)) # For each vLLM engine parameter, # - `None` means not setting it, so we pop it, and leave it to vLLM default value # (which can vary across different vLLM versions); # - Otherwise it's the desired value we want to explicitly set. 
engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None} - self.inference_engine = LLM(actor_module, - tokenizer=tokenizer, - model_hf_config=model_hf_config, - tensor_parallel_size=tensor_parallel_size, - dtype=config.dtype, - enforce_eager=config.enforce_eager, - gpu_memory_utilization=config.gpu_memory_utilization, - skip_tokenizer_init=False, - max_model_len=max_model_len, - load_format=config.load_format, - disable_log_stats=config.disable_log_stats, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=config.enable_chunked_prefill, - **engine_kwargs) + self.inference_engine = LLM( + actor_module, + tokenizer=tokenizer, + model_hf_config=model_hf_config, + tensor_parallel_size=tensor_parallel_size, + dtype=config.dtype, + enforce_eager=config.enforce_eager, + gpu_memory_utilization=config.gpu_memory_utilization, + skip_tokenizer_init=False, + max_model_len=max_model_len, + load_format=config.load_format, + disable_log_stats=config.disable_log_stats, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=config.enable_chunked_prefill, + **engine_kwargs, + ) # Offload vllm model to reduce peak memory usage self.inference_engine.offload_model_weights() @@ -131,8 +142,8 @@ class vLLMRollout(BaseRollout): ) # we may detokenize the result all together later - if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): - kwargs['detokenize'] = False + if vllm_version in ("0.4.2", "0.5.4", "0.6.3"): + kwargs["detokenize"] = False # supporting adding any sampling params from the config file for k in config.keys(): @@ -166,13 +177,13 @@ class vLLMRollout(BaseRollout): if self.config.free_cache_engine: self.inference_engine.init_cache_engine() - idx = prompts.batch['input_ids'] # (bs, prompt_length) + idx = prompts.batch["input_ids"] # (bs, prompt_length) # left-padded attention_mask - attention_mask = prompts.batch['attention_mask'] - position_ids = prompts.batch['position_ids'] + attention_mask = prompts.batch["attention_mask"] + position_ids = prompts.batch["position_ids"] # used to construct attention_mask - eos_token_id = prompts.meta_info['eos_token_id'] + eos_token_id = prompts.meta_info["eos_token_id"] batch_size = idx.size(0) @@ -181,24 +192,24 @@ class vLLMRollout(BaseRollout): for i in range(batch_size): idx_list.append(_pre_process_inputs(self.pad_token_id, idx[i])) - do_sample = prompts.meta_info.get('do_sample', True) - is_validate = prompts.meta_info.get('validate', False) + do_sample = prompts.meta_info.get("do_sample", True) + is_validate = prompts.meta_info.get("validate", False) if not do_sample: kwargs = { - 'best_of': 1, - 'top_p': 1.0, - 'top_k': -1, - 'min_p': 0.0, - 'temperature': 0, - 'n': 1 # if greedy, only 1 response + "best_of": 1, + "top_p": 1.0, + "top_k": -1, + "min_p": 0.0, + "temperature": 0, + "n": 1, # if greedy, only 1 response } elif is_validate: # TODO: try ** kwargs = { - 'top_k': self.config.val_kwargs.top_k, - 'top_p': self.config.val_kwargs.top_p, - 'temperature': self.config.val_kwargs.temperature, - 'n': 1, # if validate, already repeat in ray_trainer + "top_k": self.config.val_kwargs.top_k, + "top_p": self.config.val_kwargs.top_p, + "temperature": self.config.val_kwargs.temperature, + "n": 1, # if validate, already repeat in ray_trainer } # users can customize different sampling_params at different run @@ -207,7 +218,8 @@ class vLLMRollout(BaseRollout): prompts=None, # because we have already convert it to prompt token id sampling_params=self.sampling_params, prompt_token_ids=idx_list, - 
use_tqdm=False) + use_tqdm=False, + ) # TODO(sgm): disable logprob when recompute_log_prob is enable # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length) @@ -236,22 +248,23 @@ class vLLMRollout(BaseRollout): # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] response_position_ids = position_ids[:, -1:] + delta_position_id position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - response_attention_mask = get_response_mask(response_id=response, - eos_token=eos_token_id, - dtype=attention_mask.dtype) + response_attention_mask = get_response_mask( + response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype + ) attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) # all the tp ranks should contain the same data here. data in all ranks are valid batch = TensorDict( { - 'prompts': idx, - 'responses': response, - 'input_ids': seq, # here input_ids become the whole sentences + "prompts": idx, + "responses": response, + "input_ids": seq, # here input_ids become the whole sentences # 'old_log_probs': log_probs, # we will recompute old log prob with actor - 'attention_mask': attention_mask, - 'position_ids': position_ids + "attention_mask": attention_mask, + "position_ids": position_ids, }, - batch_size=batch_size) + batch_size=batch_size, + ) # free vllm cache engine if self.config.free_cache_engine: diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index bf2da53fa..66f0f0b26 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -24,21 +24,22 @@ When working with Megatron: - Do inference in tp. pp is treated as additional dp - After inference, all the parameters that doesn't belong to this pp rank is freed. """ -import numpy as np -from typing import List + from contextlib import contextmanager -from omegaconf import DictConfig +from typing import Any, List, Union + +import numpy as np import torch import torch.distributed +from omegaconf import DictConfig from tensordict import TensorDict -from torch import nn -from typing import Any, Union +from vllm import LLM, SamplingParams +from vllm.distributed import parallel_state as vllm_ps + from verl import DataProto +from verl.third_party.vllm import vllm_version from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length from verl.workers.rollout.base import BaseRollout -from vllm.distributed import parallel_state as vllm_ps -from vllm import LLM, SamplingParams -from verl.third_party.vllm import vllm_version # TODO # 1. support pp in vllm @@ -63,7 +64,6 @@ def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> class vLLMRollout(BaseRollout): - def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs): """A vLLM rollout. It requires the module is supported by the vllm. 
@@ -76,42 +76,49 @@ class vLLMRollout(BaseRollout): """ super().__init__() self.config = config - assert not (not config.enforce_eager and config.free_cache_engine), \ + assert not (not config.enforce_eager and config.free_cache_engine), ( "disable CUDA graph (enforce_eager = False) if free cache engine" + ) - tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1) - assert tensor_parallel_size <= torch.distributed.get_world_size(), \ + tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) + assert tensor_parallel_size <= torch.distributed.get_world_size(), ( "tensor parallel size should be less than or equal to the world size" - max_num_batched_tokens = self.config.get('max_num_batched_tokens', 8192) + ) + max_num_batched_tokens = self.config.get("max_num_batched_tokens", 8192) - if kwargs.get('train_tp', None) is not None: + if kwargs.get("train_tp") is not None: # deployed with megatron import os - os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0' - os.environ['MEGATRON_IMPORT_TIMERS'] = '0' - if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'): - train_tp = kwargs.get('train_tp', None) + + os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0" + os.environ["MEGATRON_IMPORT_TIMERS"] = "0" + if vllm_version in ("0.3.1", "0.4.2", "0.5.4", "0.6.3"): + train_tp = kwargs.get("train_tp") num_tp_per_train_tp = train_tp // tensor_parallel_size - vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size, - num_tp_per_train_tp=num_tp_per_train_tp) + vllm_ps.initialize_parallel_state( + tensor_model_parallel_size=tensor_parallel_size, num_tp_per_train_tp=num_tp_per_train_tp + ) else: vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size) - assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \ + assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, ( "model context length should be greater than total sequence length" + ) max_model_len = int(config.max_model_len or config.prompt_length + config.response_length) if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill: - raise ValueError('Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \ - please increase max_num_batched_tokens or disable chunked prefill') + raise ValueError( + "Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \ + please increase max_num_batched_tokens or disable chunked prefill" + ) - trust_remote_code = kwargs.get('trust_remote_code', False) - load_format = 'dummy' if config.load_format.startswith('dummy') else config.load_format + trust_remote_code = kwargs.get("trust_remote_code", False) + load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format limit_mm_per_prompt = None - if config.get('limit_images', None): # support for multi-image data - limit_mm_per_prompt = {"image": config.get('limit_images')} + if config.get("limit_images", None): # support for multi-image data + limit_mm_per_prompt = {"image": config.get("limit_images")} self.inference_engine = LLM( model=model_path, @@ -132,7 +139,7 @@ class vLLMRollout(BaseRollout): enable_chunked_prefill=config.enable_chunked_prefill, enable_prefix_caching=True, trust_remote_code=trust_remote_code, - seed=config.get('seed', 0), + seed=config.get("seed", 0), ) # Offload vllm model to reduce peak memory usage @@ -145,8 +152,8 @@ class vLLMRollout(BaseRollout): ) # # we may detokenize the result all 
together later - if vllm_version != '0.3.1': - kwargs['detokenize'] = False + if vllm_version != "0.3.1": + kwargs["detokenize"] = False # supporting adding any sampling params from the config file for k in config.keys(): @@ -177,64 +184,67 @@ class vLLMRollout(BaseRollout): @torch.no_grad() def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # rebuild vllm cache engine - if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine: + if vllm_version in ("0.3.1", "0.4.2", "0.5.4", "0.6.3") and self.config.free_cache_engine: self.inference_engine.init_cache_engine() - idx = prompts.batch['input_ids'] # (bs, prompt_length) + idx = prompts.batch["input_ids"] # (bs, prompt_length) # left-padded attention_mask - attention_mask = prompts.batch['attention_mask'] - position_ids = prompts.batch['position_ids'] + attention_mask = prompts.batch["attention_mask"] + position_ids = prompts.batch["position_ids"] # used to construct attention_mask - eos_token_id = prompts.meta_info['eos_token_id'] + eos_token_id = prompts.meta_info["eos_token_id"] batch_size = idx.size(0) non_tensor_batch = prompts.non_tensor_batch - if 'raw_prompt_ids' not in non_tensor_batch: - non_tensor_batch['raw_prompt_ids'] = np.array( - [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object) + if "raw_prompt_ids" not in non_tensor_batch: + non_tensor_batch["raw_prompt_ids"] = np.array( + [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object + ) - if batch_size != len(non_tensor_batch['raw_prompt_ids']): - raise RuntimeError('vllm sharding manager is not work properly.') + if batch_size != len(non_tensor_batch["raw_prompt_ids"]): + raise RuntimeError("vllm sharding manager is not work properly.") - if 'multi_modal_data' in non_tensor_batch: + if "multi_modal_data" in non_tensor_batch: vllm_inputs = [] - for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop('raw_prompt_ids'), - non_tensor_batch.pop('multi_modal_data')): - vllm_inputs.append({'prompt_token_ids': raw_prompt_ids, 'multi_modal_data': multi_modal_data}) + for raw_prompt_ids, multi_modal_data in zip( + non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data") + ): + vllm_inputs.append({"prompt_token_ids": raw_prompt_ids, "multi_modal_data": multi_modal_data}) else: - vllm_inputs = [{ - 'prompt_token_ids': raw_prompt_ids - } for raw_prompt_ids in non_tensor_batch.pop('raw_prompt_ids')] + vllm_inputs = [ + {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") + ] # ensure the type of `prompt_token_ids` passed to vllm is list[int] # https://github.com/volcengine/verl/pull/772 for input_data in vllm_inputs: - if isinstance(input_data['prompt_token_ids'], np.ndarray): - input_data['prompt_token_ids'] = input_data['prompt_token_ids'].tolist() - elif not isinstance(input_data['prompt_token_ids'], list): + if isinstance(input_data["prompt_token_ids"], np.ndarray): + input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist() + elif not isinstance(input_data["prompt_token_ids"], list): raise TypeError( - f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}") + f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}" + ) - do_sample = prompts.meta_info.get('do_sample', True) - is_validate = prompts.meta_info.get('validate', False) + do_sample = prompts.meta_info.get("do_sample", True) + is_validate 
= prompts.meta_info.get("validate", False) if not do_sample: kwargs = { - 'best_of': 1, - 'top_p': 1.0, - 'top_k': -1, - 'min_p': 0.0, - 'temperature': 0, - 'n': 1 # if greedy, only 1 response + "best_of": 1, + "top_p": 1.0, + "top_k": -1, + "min_p": 0.0, + "temperature": 0, + "n": 1, # if greedy, only 1 response } elif is_validate: # TODO: try ** kwargs = { - 'top_k': self.config.val_kwargs.top_k, - 'top_p': self.config.val_kwargs.top_p, - 'temperature': self.config.val_kwargs.temperature, - 'n': 1, # if validate, already repeat in ray_trainer + "top_k": self.config.val_kwargs.top_k, + "top_p": self.config.val_kwargs.top_p, + "temperature": self.config.val_kwargs.temperature, + "n": 1, # if validate, already repeat in ray_trainer } # users can customize different sampling_params at different run @@ -242,7 +252,8 @@ class vLLMRollout(BaseRollout): outputs = self.inference_engine.generate( prompts=vllm_inputs, # because we have already convert it to prompt token id sampling_params=self.sampling_params, - use_tqdm=False) + use_tqdm=False, + ) # TODO(sgm): disable logprob when recompute_log_prob is enable # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length) @@ -252,17 +263,19 @@ class vLLMRollout(BaseRollout): for sample_id in range(len(output.outputs)): response.append(output.outputs[sample_id].token_ids) - response = pad_2d_list_to_length(response, self.pad_token_id, - max_length=self.config.response_length).to(idx.device) + response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to( + idx.device + ) if self.sampling_params.n > 1 and do_sample: idx = _repeat_interleave(idx, self.sampling_params.n) attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n) position_ids = _repeat_interleave(position_ids, self.sampling_params.n) batch_size = batch_size * self.sampling_params.n - if 'multi_modal_inputs' in non_tensor_batch.keys(): - non_tensor_batch['multi_modal_inputs'] = _repeat_interleave(non_tensor_batch['multi_modal_inputs'], - self.sampling_params.n) + if "multi_modal_inputs" in non_tensor_batch.keys(): + non_tensor_batch["multi_modal_inputs"] = _repeat_interleave( + non_tensor_batch["multi_modal_inputs"], self.sampling_params.n + ) seq = torch.cat([idx, response], dim=-1) @@ -278,25 +291,26 @@ class vLLMRollout(BaseRollout): # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] response_position_ids = position_ids[..., -1:] + delta_position_id position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - response_attention_mask = get_response_mask(response_id=response, - eos_token=eos_token_id, - dtype=attention_mask.dtype) + response_attention_mask = get_response_mask( + response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype + ) attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) # all the tp ranks should contain the same data here. 
data in all ranks are valid batch = TensorDict( { - 'prompts': idx, - 'responses': response, - 'input_ids': seq, # here input_ids become the whole sentences + "prompts": idx, + "responses": response, + "input_ids": seq, # here input_ids become the whole sentences # 'old_log_probs': log_probs, # we will recompute old log prob with actor - 'attention_mask': attention_mask, - 'position_ids': position_ids + "attention_mask": attention_mask, + "position_ids": position_ids, }, - batch_size=batch_size) + batch_size=batch_size, + ) # free vllm cache engine - if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine: + if vllm_version in ("0.3.1", "0.4.2", "0.5.4", "0.6.3") and self.config.free_cache_engine: self.inference_engine.free_cache_engine() return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) diff --git a/verl/workers/sharding_manager/__init__.py b/verl/workers/sharding_manager/__init__.py index d15f5133e..990806634 100644 --- a/verl/workers/sharding_manager/__init__.py +++ b/verl/workers/sharding_manager/__init__.py @@ -13,9 +13,9 @@ # limitations under the License. from verl.utils.import_utils import ( - is_vllm_available, - is_sglang_available, is_megatron_core_available, + is_sglang_available, + is_vllm_available, ) from .base import BaseShardingManager diff --git a/verl/workers/sharding_manager/base.py b/verl/workers/sharding_manager/base.py index d8717890f..d7415892a 100644 --- a/verl/workers/sharding_manager/base.py +++ b/verl/workers/sharding_manager/base.py @@ -19,7 +19,6 @@ from verl import DataProto class BaseShardingManager: - def __enter__(self): pass diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py index 7682b1cac..aac22b76c 100644 --- a/verl/workers/sharding_manager/fsdp_sglang.py +++ b/verl/workers/sharding_manager/fsdp_sglang.py @@ -25,34 +25,37 @@ # See the License for the specific language governing permissions and # limitations under the License. 
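
Both rollout variants above finish by extending the left-padded prompt tensors with the generated response before packing the `TensorDict`. A minimal, self-contained sketch of that step with plain `torch` tensors; `response_mask_after_eos` is an illustrative stand-in for `verl.utils.torch_functional.get_response_mask`, not its actual implementation:

```python
import torch


def response_mask_after_eos(response_id: torch.Tensor, eos_token: int, dtype=torch.int64) -> torch.Tensor:
    """Mask response tokens: 1 up to and including the first EOS, 0 afterwards."""
    is_eos = (response_id == eos_token).long()
    eos_before = is_eos.cumsum(dim=-1) - is_eos  # EOS count strictly before each position
    return (eos_before == 0).to(dtype)


# toy setup: batch of 2, prompt_length 4 (left-padded), response_length 4
pad, eos = 0, 2
prompt = torch.tensor([[pad, pad, 5, 6], [7, 8, 9, 10]])
attention_mask = (prompt != pad).long()
position_ids = torch.clamp(attention_mask.cumsum(dim=-1) - 1, min=0)

response = torch.tensor([[11, eos, pad, pad], [12, 13, 14, eos]])
response_length = response.size(1)

# response positions continue from the last prompt position: last_pos + 1, +2, ...
delta_position_id = torch.arange(1, response_length + 1).unsqueeze(0).expand(prompt.size(0), -1)
response_position_ids = position_ids[:, -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)

response_attention_mask = response_mask_after_eos(response, eos, dtype=attention_mask.dtype)
attention_mask = torch.cat([attention_mask, response_attention_mask], dim=-1)

seq = torch.cat([prompt, response], dim=-1)  # the "input_ids" of the whole sentence
print(attention_mask)
# tensor([[0, 0, 1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1, 1, 1]])
```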
-import os import logging +import os + import torch -from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig +from sglang.srt.entrypoints.verl_engine import VerlEngine from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType +from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP from verl import DataProto from verl.protocol import all_gather_data_proto -from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors) from verl.utils.debug import log_gpu_memory_usage -from sglang.srt.entrypoints.verl_engine import VerlEngine +from verl.utils.torch_functional import broadcast_dict_tensor + from .base import BaseShardingManager -from verl.third_party.sglang import parallel_state as sglang_ps + # from vllm.distributed import parallel_state as sglang_ps logger = logging.getLogger(__file__) -logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) +logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) class FSDPSGLangShardingManager(BaseShardingManager): - - def __init__(self, - module: FSDP, - inference_engine: VerlEngine, - model_config, - full_params: bool = False, - device_mesh: DeviceMesh = None): + def __init__( + self, + module: FSDP, + inference_engine: VerlEngine, + model_config, + full_params: bool = False, + device_mesh: DeviceMesh = None, + ): self.module = module self.inference_engine = inference_engine self.model_config = model_config @@ -61,19 +64,21 @@ class FSDPSGLangShardingManager(BaseShardingManager): # Full params self.full_params = full_params if full_params: - FSDP.set_state_dict_type(self.module, - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=FullStateDictConfig()) + FSDP.set_state_dict_type( + self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig() + ) else: - FSDP.set_state_dict_type(self.module, - state_dict_type=StateDictType.SHARDED_STATE_DICT, - state_dict_config=ShardedStateDictConfig()) + FSDP.set_state_dict_type( + self.module, + state_dict_type=StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig(), + ) # Note that torch_random_states may be different on each dp rank self.torch_random_states = torch.cuda.get_rng_state() # get a random rng states if self.device_mesh is not None: - gen_dp_rank = self.device_mesh['dp'].get_local_rank() + gen_dp_rank = self.device_mesh["dp"].get_local_rank() torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states self.gen_random_states = torch.cuda.get_rng_state() torch.cuda.set_rng_state(self.torch_random_states) @@ -82,19 +87,19 @@ class FSDPSGLangShardingManager(BaseShardingManager): def __enter__(self): torch.cuda.empty_cache() - log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger) + log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) params = self.module.state_dict() - log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger) + log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) # Copy, not share memory - load_format = None if self.full_params else 'dtensor' + load_format = None if self.full_params else "dtensor" 
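
The constructor above also captures two RNG states so that rollout sampling is seeded per DP rank (identical across TP ranks) without disturbing training randomness. A simplified, GPU-free sketch of that idea, using the CPU RNG APIs instead of the `torch.cuda` equivalents the sharding manager actually calls, and re-seeding on every entry rather than caching a generation RNG state in `__init__` as the real code does; `generation_rng` is a hypothetical helper:

```python
from contextlib import contextmanager

import torch


@contextmanager
def generation_rng(dp_rank: int, seed_offset: int = 1000):
    """Swap in a per-DP-rank RNG for rollout, then restore the trainer RNG.

    The real managers use torch.cuda.{get,set}_rng_state and torch.cuda.manual_seed,
    keyed on device_mesh["dp"].get_local_rank(), so all TP ranks in a DP group
    draw the same samples while the training RNG stream stays untouched.
    """
    trainer_state = torch.get_rng_state()     # RNG used by training
    torch.manual_seed(dp_rank + seed_offset)  # shared seed within this DP group
    try:
        yield
    finally:
        torch.set_rng_state(trainer_state)    # training randomness is unaffected


with generation_rng(dp_rank=0):
    sampled = torch.rand(3)  # deterministic given dp_rank
print(sampled)
```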
self.inference_engine.resume_memory_occupation() self.inference_engine.update_weights_from_tensor([(k, v) for k, v in params.items()], load_format=None) - log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger) + log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) del params torch.cuda.empty_cache() - log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger) + log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) # TODO: offload FSDP model weights # self.module.cpu() @@ -108,9 +113,9 @@ class FSDPSGLangShardingManager(BaseShardingManager): torch.cuda.set_rng_state(self.gen_random_states) def __exit__(self, exc_type, exc_value, traceback): - log_gpu_memory_usage('Before SGLang offload in sharding manager', logger=logger) + log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) self.inference_engine.release_memory_occupation() - log_gpu_memory_usage('After SGLang offload in sharding manager', logger=logger) + log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) # self.module.to('cuda') # if torch.distributed.get_rank() == 0: diff --git a/verl/workers/sharding_manager/fsdp_ulysses.py b/verl/workers/sharding_manager/fsdp_ulysses.py index 34398d7ed..d9c7b4a2a 100644 --- a/verl/workers/sharding_manager/fsdp_ulysses.py +++ b/verl/workers/sharding_manager/fsdp_ulysses.py @@ -14,19 +14,14 @@ """ Contains a resharding manager that binds weights from FSDP zero3 to XPerfGPT """ -from .base import BaseShardingManager from torch.distributed.device_mesh import DeviceMesh -from verl.utils.torch_functional import allgather_dict_tensors -from verl.protocol import all_gather_data_proto -from verl.utils.ulysses import set_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_group -import numpy as np - -import torch -import torch.distributed - from verl import DataProto +from verl.protocol import all_gather_data_proto +from verl.utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group + +from .base import BaseShardingManager class FSDPUlyssesShardingManager(BaseShardingManager): @@ -44,7 +39,7 @@ class FSDPUlyssesShardingManager(BaseShardingManager): # We have a global SP group # so we have to change to use model-specific sp group self.prev_sp_group = get_ulysses_sequence_parallel_group() - set_ulysses_sequence_parallel_group(self.device_mesh['sp'].get_group()) + set_ulysses_sequence_parallel_group(self.device_mesh["sp"].get_group()) # TODO: check how to set seed for each model def __exit__(self, exc_type, exc_value, traceback): @@ -61,8 +56,8 @@ class FSDPUlyssesShardingManager(BaseShardingManager): In Ulysses, we need to make sure the same data is used across a SP group """ if self.device_mesh is not None: - sp_size = self.device_mesh['sp'].size() - group = self.device_mesh['sp'].get_group() + sp_size = self.device_mesh["sp"].size() + group = self.device_mesh["sp"].get_group() all_gather_data_proto(data=data, process_group=group) return data @@ -72,7 +67,7 @@ class FSDPUlyssesShardingManager(BaseShardingManager): Split the data to follow FSDP partition """ if self.device_mesh is not None: - sp_size = self.device_mesh['sp'].size() - sp_rank = self.device_mesh['sp'].get_local_rank() + sp_size = self.device_mesh["sp"].size() + sp_rank = self.device_mesh["sp"].get_local_rank() data = data.chunk(chunks=sp_size)[sp_rank] - return data \ No newline at end of file + return data 
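
The Ulysses sharding manager above is symmetric: `preprocess_data` all-gathers the batch across the sequence-parallel group so every SP rank sees identical data, and `postprocess_data` chunks the result back to each rank's FSDP shard. A single-process sketch with plain tensors, simulating the all-gather with `torch.cat` (the real code calls `all_gather_data_proto` over `device_mesh["sp"].get_group()`):

```python
import torch

sp_size = 4
# each SP rank starts with its own FSDP-partitioned shard of the batch (2 rows here)
local_shards = [torch.arange(r * 6, (r + 1) * 6).reshape(2, 3) for r in range(sp_size)]

# preprocess_data: all-gather across the SP group so every rank holds the full batch
# before the sequence-parallel forward (simulated with torch.cat)
full_batch = torch.cat(local_shards, dim=0)

# postprocess_data: split the output back so each rank keeps only its own shard
sp_rank = 1
local_again = full_batch.chunk(chunks=sp_size)[sp_rank]

assert torch.equal(local_again, local_shards[sp_rank])
print(full_batch.shape, local_again.shape)  # torch.Size([8, 3]) torch.Size([2, 3])
```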
diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index ddf95e314..eda35f62d 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -12,39 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import inspect import logging -import torch -import numpy as np -from packaging import version -from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig -from torch.distributed.device_mesh import DeviceMesh +import os + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType +from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP -from verl.third_party.vllm import LLM -from verl.third_party.vllm import parallel_state as vllm_ps from verl import DataProto from verl.protocol import all_gather_data_proto +from verl.third_party.vllm import LLM, vllm_version +from verl.third_party.vllm import parallel_state as vllm_ps from verl.utils.debug import log_gpu_memory_usage -from verl.third_party.vllm import vllm_version -from vllm.version import __version__ as VLLM_VERSION from .base import BaseShardingManager from .patch import patched_ds_v3_load_weights, patched_qwen_moe_load_weights logger = logging.getLogger(__file__) -logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) +logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) class FSDPVLLMShardingManager(BaseShardingManager): - - def __init__(self, - module: FSDP, - inference_engine: LLM, - model_config, - full_params: bool = False, - device_mesh: DeviceMesh = None): + def __init__( + self, + module: FSDP, + inference_engine: LLM, + model_config, + full_params: bool = False, + device_mesh: DeviceMesh = None, + ): self.module = module self.inference_engine = inference_engine self.model_config = model_config @@ -53,13 +51,15 @@ class FSDPVLLMShardingManager(BaseShardingManager): # Full params self.full_params = full_params if full_params: - FSDP.set_state_dict_type(self.module, - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=FullStateDictConfig()) + FSDP.set_state_dict_type( + self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig() + ) else: - FSDP.set_state_dict_type(self.module, - state_dict_type=StateDictType.SHARDED_STATE_DICT, - state_dict_config=ShardedStateDictConfig()) + FSDP.set_state_dict_type( + self.module, + state_dict_type=StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig(), + ) self.tp_size = vllm_ps.get_tensor_model_parallel_world_size() self.tp_rank = vllm_ps.get_tensor_model_parallel_rank() @@ -68,7 +68,7 @@ class FSDPVLLMShardingManager(BaseShardingManager): self.torch_random_states = torch.cuda.get_rng_state() # get a random rng states if self.device_mesh is not None: - gen_dp_rank = self.device_mesh['dp'].get_local_rank() + gen_dp_rank = self.device_mesh["dp"].get_local_rank() torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states self.gen_random_states = torch.cuda.get_rng_state() torch.cuda.set_rng_state(self.torch_random_states) @@ -85,15 +85,15 @@ class FSDPVLLMShardingManager(BaseShardingManager): # vllm: 
https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103 torch.cuda.empty_cache() - log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger) + log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) params = self.module.state_dict() - log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger) + log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) # Copy, not share memory - load_format = 'hf' if self.full_params else 'dtensor' + load_format = "hf" if self.full_params else "dtensor" - if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): + if vllm_version in ("0.4.2", "0.5.4", "0.6.3"): self.inference_engine.sync_model_weights(params, load_format=load_format) - log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger) + log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) del params else: if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: @@ -103,14 +103,14 @@ class FSDPVLLMShardingManager(BaseShardingManager): # update model params self.update_params(params) - log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger) + log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) del params torch.cuda.empty_cache() if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: self.inference_engine.wake_up(tags=["kv_cache"]) - log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger) + log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) # TODO: offload FSDP model weights # self.module.cpu() @@ -124,13 +124,13 @@ class FSDPVLLMShardingManager(BaseShardingManager): torch.cuda.set_rng_state(self.gen_random_states) def __exit__(self, exc_type, exc_value, traceback): - log_gpu_memory_usage('Before vllm offload in sharding manager', logger=logger) + log_gpu_memory_usage("Before vllm offload in sharding manager", logger=logger) # TODO(ZSL): check this - if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): + if vllm_version in ("0.4.2", "0.5.4", "0.6.3"): self.inference_engine.offload_model_weights() else: self.inference_engine.sleep(level=1) - log_gpu_memory_usage('After vllm offload in sharding manager', logger=logger) + log_gpu_memory_usage("After vllm offload in sharding manager", logger=logger) # self.module.to('cuda') # if torch.distributed.get_rank() == 0: @@ -152,7 +152,7 @@ class FSDPVLLMShardingManager(BaseShardingManager): return data # TODO: Current impl doesn't consider FSDP with torch micro-dp - if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'): + if vllm_version in ("0.3.1", "0.4.2", "0.5.4", "0.6.3"): group = vllm_ps.get_tensor_model_parallel_group() else: group = vllm_ps.get_tensor_model_parallel_group().device_group @@ -170,15 +170,24 @@ class FSDPVLLMShardingManager(BaseShardingManager): def update_params(self, updated_params): model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model world_size = torch.distributed.get_world_size() - if model.config.architectures[0] in ['DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM']: + if model.config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: loaded_params = patched_ds_v3_load_weights( - model, ((name, param.full_tensor() if world_size != 1 and hasattr(param, 'full_tensor') else param) - for name, param in 
updated_params.items())) - elif model.config.architectures[0] in ['Qwen2MoeForCausalLM']: + model, + ( + (name, param.full_tensor() if world_size != 1 and hasattr(param, "full_tensor") else param) + for name, param in updated_params.items() + ), + ) + elif model.config.architectures[0] in ["Qwen2MoeForCausalLM"]: loaded_params = patched_qwen_moe_load_weights( - model, ((name, param.full_tensor() if world_size != 1 and hasattr(param, 'full_tensor') else param) - for name, param in updated_params.items())) + model, + ( + (name, param.full_tensor() if world_size != 1 and hasattr(param, "full_tensor") else param) + for name, param in updated_params.items() + ), + ) else: loaded_params = model.load_weights( - ((name, param.full_tensor() if world_size != 1 else param) for name, param in updated_params.items())) + ((name, param.full_tensor() if world_size != 1 else param) for name, param in updated_params.items()) + ) logger.info(f"vLLM load weights, loaded_params: {len(loaded_params)}") diff --git a/verl/workers/sharding_manager/megatron_vllm.py b/verl/workers/sharding_manager/megatron_vllm.py index a520a07b9..1cafcfc93 100644 --- a/verl/workers/sharding_manager/megatron_vllm.py +++ b/verl/workers/sharding_manager/megatron_vllm.py @@ -15,35 +15,36 @@ This file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine. """ -import os import logging +import os + import torch import torch.distributed as dist - -from torch import nn - -from megatron.core import parallel_state as mpu from megatron.core import DistributedDataParallel as LocalDDP +from megatron.core import parallel_state as mpu from megatron.core.transformer.module import Float16Module +from torch import nn from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP + from verl.utils.debug import log_gpu_memory_usage -from verl.utils.megatron_utils import get_model, unwrap_model +from verl.utils.megatron_utils import ( + broadcast_from_megatron_pp, + broadcast_str_from_megatron_pp, + get_model, + unwrap_model, +) from verl.utils.memory_buffer import ( build_memory_buffer, build_memory_reference_from_module, get_weight_buffer_meta_from_module, ) - from verl.utils.model import normalize_model_name -from verl.workers.actor.megatron_actor import MegatronPPOActor -from verl.utils.megatron_utils import broadcast_from_megatron_pp, broadcast_str_from_megatron_pp logger = logging.getLogger(__file__) -logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) +logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) class AllGatherPPModel: - def __init__(self, model_provider, use_distributed_optimizer=True) -> None: print( "[WARNING] This class is deprecated and will no longer be supported. Consider using the `MegatronPPOActor` class directly as a replacement." 
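
The `update_params` paths above all feed vLLM's `load_weights` with a generator that materializes sharded FSDP parameters one at a time via `param.full_tensor()`. A hedged sketch of that pattern; `iter_full_params` is an illustrative name, and with ordinary tensors (or `world_size == 1`) the `hasattr` guard simply passes parameters through, so the example runs without any distributed setup:

```python
from typing import Dict, Iterable, Tuple

import torch


def iter_full_params(params: Dict[str, torch.Tensor], world_size: int) -> Iterable[Tuple[str, torch.Tensor]]:
    """Yield (name, tensor) pairs, materializing sharded DTensors lazily.

    With world_size > 1 the FSDP state dict holds DTensor shards, and
    param.full_tensor() gathers the full weight only when it is about to be
    handed to the inference engine, so a single full weight is resident at a time.
    """
    for name, param in params.items():
        if world_size != 1 and hasattr(param, "full_tensor"):
            yield name, param.full_tensor()
        else:
            yield name, param


# toy usage with ordinary tensors
state = {"lm_head.weight": torch.zeros(4, 4), "model.norm.weight": torch.ones(4)}
for name, tensor in iter_full_params(state, world_size=1):
    print(name, tuple(tensor.shape))
```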
@@ -66,8 +67,10 @@ class AllGatherPPModel: self.memory_buffers = [None] * self.pp_size for cur_pp_rank in rank_list: print( - f'create pp model', f'torch allocated {torch.cuda.memory_allocated() / 1e9:.4f} GB, ' - f'reserved {torch.cuda.memory_reserved() / 1e9:.4f} GB') + "create pp model", + f"torch allocated {torch.cuda.memory_allocated() / 1e9:.4f} GB, " + f"reserved {torch.cuda.memory_reserved() / 1e9:.4f} GB", + ) # since the last initialized rank is the current pp rank, after init, the pp rank is still correct mpu.set_pipeline_model_parallel_rank(cur_pp_rank) if cur_pp_rank != self.pp_rank: @@ -77,9 +80,9 @@ class AllGatherPPModel: self.pp_models[cur_pp_rank] = models else: # for regular model, we wrapped it with DDP - models = get_model(model_provider, - wrap_with_ddp=True, - use_distributed_optimizer=use_distributed_optimizer) + models = get_model( + model_provider, wrap_with_ddp=True, use_distributed_optimizer=use_distributed_optimizer + ) assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}" self._this_rank_models = nn.ModuleList(models) self.pp_models[cur_pp_rank] = nn.ModuleList(unwrap_model(models, (torchDDP, LocalDDP))) @@ -97,6 +100,7 @@ class AllGatherPPModel: """Build the parameter buffer in each pp rank""" if pp_rank == self._pp_rank: from verl.utils.memory_buffer import MemoryBuffer + # The code here is very hard-coded, based on the following assumptions: # 1. `len(_this_rank_models) == 1` # 2. `_this_rank_models[0]` is a instance of `DistributedDataParallel` and `use_distributed_optimizer=True` @@ -122,7 +126,7 @@ class AllGatherPPModel: if not to_empty: buffer.data = buffer.data.to(torch.cuda.current_device(), non_blocking=True) else: - buffer.data = torch.empty_like(buffer.data, device='cuda') + buffer.data = torch.empty_like(buffer.data, device="cuda") # rebuild reference after loading to CUDA self._build_param_references(pp_rank) @@ -131,9 +135,9 @@ class AllGatherPPModel: for buffer in self.memory_buffers[pp_rank].values(): if not to_empty: # offload the whole memory buffer to CPU - buffer.data = buffer.data.to('cpu', non_blocking=True) + buffer.data = buffer.data.to("cpu", non_blocking=True) else: - buffer.data = torch.empty_like(buffer.data, device='cpu') + buffer.data = torch.empty_like(buffer.data, device="cpu") self._build_param_references(pp_rank) def load_params_to_cuda(self, to_empty=False): @@ -205,7 +209,7 @@ class AllGatherPPModel: pp_model = unwrap_model(pp_model, ((torchDDP, LocalDDP, Float16Module))) # not use Float16Module for name, param in pp_model.named_parameters(): # NOTE(gh) workaround: should not get lora params for inference - if 'lora' in name: + if "lora" in name: continue params[pp_rank][model_chunk_idx][name] = param @@ -245,23 +249,20 @@ Megatron Hybrid Engine: - After inference, all the parameters that doesn't belong to this pp rank is freed. 
""" -from .base import BaseShardingManager - -import torch import inspect -from torch import nn + import torch.distributed from torch.distributed import new_group -from typing import Dict, Iterable, Union, Tuple -from verl import DataProto -from verl.protocol import all_gather_data_proto -from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors) import verl.utils.megatron.tensor_parallel as tp_utils -from verl.third_party.vllm import vllm_version +from verl import DataProto +from verl.third_party.vllm import LLM, vllm_version from verl.third_party.vllm import parallel_state as vllm_ps -from verl.third_party.vllm import LLM from verl.utils.megatron_utils import convert_megatron_model_to_transformers_model +from verl.utils.torch_functional import allgather_dict_tensors + +from .base import BaseShardingManager + # Micro Data parallel group. Micro data parallel group is additional dp group that origins from splitting training tp # into infer_tp and micro_tp. By default, we use order micro_dp - tp # NOTICE: in new version of vLLM, We need to all-gather all tp rank's model weights @@ -270,14 +271,16 @@ _MICRO_DATA_PARALLEL_GROUP = None class MegatronVLLMShardingManager(BaseShardingManager): - - def __init__(self, - actor_module: nn.ModuleList, - inference_engine: LLM, - model_config, - layer_name_mapping, - module: AllGatherPPModel = None): + def __init__( + self, + actor_module: nn.ModuleList, + inference_engine: LLM, + model_config, + layer_name_mapping, + module: AllGatherPPModel = None, + ): from megatron.core import parallel_state as mpu + self.actor_module = actor_module self.inference_engine = inference_engine self.model_config = model_config @@ -296,13 +299,12 @@ class MegatronVLLMShardingManager(BaseShardingManager): self.need_tp_reshard = self.infer_tp_size == self.train_tp_size # TODO(sgm): this may not be true for FSDP -> vLLM - assert self.infer_tp_size <= self.train_tp_size, \ - 'Not implemented for infer_tp > train_tp' + assert self.infer_tp_size <= self.train_tp_size, "Not implemented for infer_tp > train_tp" assert self.train_tp_size % self.infer_tp_size == 0 micro_dp_size = self.train_tp_size // self.infer_tp_size num_micro_dp_groups = world_size // micro_dp_size - assert _MICRO_DATA_PARALLEL_GROUP is None, ("micro data parallel group is already initialized") + assert _MICRO_DATA_PARALLEL_GROUP is None, "micro data parallel group is already initialized" for i in range(num_micro_dp_groups): ranks = range(i * micro_dp_size, (i + 1) * micro_dp_size) group = new_group(ranks=ranks) @@ -314,11 +316,12 @@ class MegatronVLLMShardingManager(BaseShardingManager): convert_qkv_gate_up_by_simple_split is a parameter affected by the vLLM version. 
""" from megatron.core import parallel_state as mpu + pp_rank = mpu.get_pipeline_model_parallel_rank() pp_size = mpu.get_pipeline_model_parallel_world_size() vpp_size = len(self.actor_module) - if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): + if vllm_version in ("0.4.2", "0.5.4", "0.6.3"): all_gather_group = get_micro_data_parallel_group() else: all_gather_group = self.train_tp_group @@ -336,9 +339,9 @@ class MegatronVLLMShardingManager(BaseShardingManager): meta_info.append((pp_rank, scan_vpp_idx, idx, name)) obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() - torch.distributed.all_gather_object(object_list=obj_spec_output, - obj=meta_info, - group=mpu.get_pipeline_model_parallel_group()) + torch.distributed.all_gather_object( + object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group() + ) layer_list_meta = [item for sublist in obj_spec_output for item in sublist] gen_func = tensor_generator() @@ -350,8 +353,9 @@ class MegatronVLLMShardingManager(BaseShardingManager): cur_name, cur_tensor = next(gen_func) except StopIteration: cur_name, cur_tensor = None, None - cur_name = normalize_model_name(name, cur_pp_rank, scan_vpp_idx, pp_size, vpp_size, - self.model_config.num_hidden_layers) + cur_name = normalize_model_name( + name, cur_pp_rank, scan_vpp_idx, pp_size, vpp_size, self.model_config.num_hidden_layers + ) else: cur_tensor, cur_name = None, None @@ -361,7 +365,7 @@ class MegatronVLLMShardingManager(BaseShardingManager): # (xya): this is a hack to fix the name of the parameters while cur_name.startswith("module."): - cur_name = cur_name[len("module."):] + cur_name = cur_name[len("module.") :] # tp all gather if tp_utils.is_tensor_parallel_param(broad_pp_tensor): @@ -370,11 +374,12 @@ class MegatronVLLMShardingManager(BaseShardingManager): infer_params = [broad_pp_tensor] else: infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(all_gather_group_size)] - torch.distributed.all_gather(infer_params, - broad_pp_tensor, - group=mpu.get_tensor_model_parallel_group()) - infer_params = self.default_tp_concat_fn(cur_name, broad_pp_tensor, infer_params, self.model_config, - convert_qkv_gate_up_by_simple_split) + torch.distributed.all_gather( + infer_params, broad_pp_tensor, group=mpu.get_tensor_model_parallel_group() + ) + infer_params = self.default_tp_concat_fn( + cur_name, broad_pp_tensor, infer_params, self.model_config, convert_qkv_gate_up_by_simple_split + ) else: infer_params = broad_pp_tensor @@ -385,7 +390,8 @@ class MegatronVLLMShardingManager(BaseShardingManager): self.model_config, self.train_tp_size, 0, # no impact - convert_qkv_gate_up_by_trunk_concat=False) # defualt false + convert_qkv_gate_up_by_trunk_concat=False, + ) # defualt false for converted_name, infer_param in zip(converted_names, converted_params): yield converted_name, infer_param @@ -397,7 +403,7 @@ class MegatronVLLMShardingManager(BaseShardingManager): infer_params (Iterable[torch.Tensor]): a iterator towards list of parameters all-gathered from train tp group (vllm 0.8.2) or micro-dp group (vllm <= 0.6.3) model_config: huggingface model_config TODO(zhangchi.usc1992): currently, the implementation is adhoc. We can move this function to the model - definition so that it is model-agnostic. If the model doesn't implement this function, + definition so that it is model-agnostic. If the model doesn't implement this function, we can throw an error to force user disable TP HybridEngine. 
""" if self.layer_name_mapping.get("qkv_layer_name") in name and "layer_norm" not in name: @@ -408,9 +414,9 @@ class MegatronVLLMShardingManager(BaseShardingManager): v_lst = [] assert model_config.num_attention_heads % model_config.num_key_value_heads == 0 num_q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads - assert infer_params[0].shape[0] % ( - num_q_per_kv + - 2) == 0, f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}" + assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, ( + f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}" + ) kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2) split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] for infer_param in infer_params: @@ -419,7 +425,7 @@ class MegatronVLLMShardingManager(BaseShardingManager): split_size = [ kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, kv_size_per_tp // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition + kv_size_per_tp // num_query_groups_per_partition, ] q, k, v = chunk.split(split_size) q_lst.append(q) @@ -461,7 +467,7 @@ class MegatronVLLMShardingManager(BaseShardingManager): # here the params are in train tp format. we iterate params and all-gather # TODO(zhangchi.usc1992) We can consider copy non-tp weight to another infer buffer. # In this way, all the params in the original memory_buffers and can be offload. - if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): + if vllm_version in ("0.4.2", "0.5.4", "0.6.3"): all_gather_group = get_micro_data_parallel_group() else: all_gather_group = self.train_tp_group @@ -475,8 +481,9 @@ class MegatronVLLMShardingManager(BaseShardingManager): else: infer_params = [torch.empty_like(param) for _ in range(all_gather_group_size)] torch.distributed.all_gather(infer_params, param, group=all_gather_group) - infer_params = self.default_tp_concat_fn(name, param, infer_params, self.model_config, - convert_qkv_gate_up_by_simple_split) + infer_params = self.default_tp_concat_fn( + name, param, infer_params, self.model_config, convert_qkv_gate_up_by_simple_split + ) else: infer_params = param converted_names, converted_params = convert_megatron_model_to_transformers_model( @@ -485,14 +492,15 @@ class MegatronVLLMShardingManager(BaseShardingManager): self.model_config, self.train_tp_size, self.module.pp_models[0][0].config.num_query_groups, - convert_qkv_gate_up_by_trunk_concat=False) + convert_qkv_gate_up_by_trunk_concat=False, + ) for converted_name, infer_param in zip(converted_names, converted_params): yield converted_name, infer_param def __enter__(self): - if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): + if vllm_version in ("0.4.2", "0.5.4", "0.6.3"): per_tensor_param = self.per_tensor_generator(convert_qkv_gate_up_by_simple_split=False) - self.inference_engine.sync_model_weights(per_tensor_param, load_format='megatron') + self.inference_engine.sync_model_weights(per_tensor_param, load_format="megatron") else: # > 0.7.2 if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: @@ -503,20 +511,20 @@ class MegatronVLLMShardingManager(BaseShardingManager): model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model loaded_params = model.load_weights(per_tensor_param) logger.info(f"vLLM load weights, loaded_params: {len(loaded_params)}") - log_gpu_memory_usage('After load_weights sharding manager memory', logger=logger) + 
log_gpu_memory_usage("After load_weights sharding manager memory", logger=logger) if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: self.inference_engine.wake_up(tags=["kv_cache"]) def __exit__(self, exc_type, exc_value, traceback): - log_gpu_memory_usage('Before vllm offload in sharding manager', logger=logger) - if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): + log_gpu_memory_usage("Before vllm offload in sharding manager", logger=logger) + if vllm_version in ("0.4.2", "0.5.4", "0.6.3"): self.inference_engine.offload_model_weights() else: self.inference_engine.sleep(level=1) for model in self.actor_module: model.train() - log_gpu_memory_usage('After vllm offload in sharding manager', logger=logger) + log_gpu_memory_usage("After vllm offload in sharding manager", logger=logger) torch.cuda.empty_cache() @@ -534,10 +542,12 @@ class MegatronVLLMShardingManager(BaseShardingManager): # all gather batch among micro-dp groups micro_dp_size = get_micro_data_parallel_world_size() if micro_dp_size > 1: - data.batch = allgather_dict_tensors(data.batch.contiguous(), - size=get_micro_data_parallel_world_size(), - group=get_micro_data_parallel_group(), - dim=0) + data.batch = allgather_dict_tensors( + data.batch.contiguous(), + size=get_micro_data_parallel_world_size(), + group=get_micro_data_parallel_group(), + dim=0, + ) return data diff --git a/verl/workers/sharding_manager/patch/fsdp_vllm_patch.py b/verl/workers/sharding_manager/patch/fsdp_vllm_patch.py index 419d8f9d9..a4ffa8c97 100644 --- a/verl/workers/sharding_manager/patch/fsdp_vllm_patch.py +++ b/verl/workers/sharding_manager/patch/fsdp_vllm_patch.py @@ -12,15 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re +from typing import Iterable, Optional, Set, Tuple + import torch from torch import nn -from typing import Optional, Union, Iterable, Tuple, Set from transformers import PretrainedConfig from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.model_loader.weight_utils import (default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader, maybe_remap_kv_scale_name from vllm.model_executor.models.utils import is_pp_missing_parameter -import re + +def get_layer_index(layer_name: str) -> int: + pattern = r"layers\.(\d+)" + match = re.search(pattern, layer_name) + if match: + return int(match.group(1)) + raise ValueError(f"Unable to parse layer index from '{layer_name}'") def get_layer_index(layer_name: str) -> int: @@ -32,12 +40,11 @@ def get_layer_index(layer_name: str) -> int: def patched_ds_v3_load_weights(model: nn.Module, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, weight_name: str) -> Optional[int]: if hasattr(config, "num_nextn_predict_layers") and (config.num_nextn_predict_layers > 0): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): - if weight_name.startswith(f"model.layers.{layer_idx+i}."): + if weight_name.startswith(f"model.layers.{layer_idx + i}."): return layer_idx + i return None @@ -46,10 +53,12 @@ def patched_ds_v3_load_weights(model: nn.Module, weights: Iterable[Tuple[str, to ("gate_up_proj", "up_proj", 1), ] - expert_params_mapping = FusedMoE.make_expert_params_mapping(ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=model.config.n_routed_experts) + 
expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=model.config.n_routed_experts, + ) params_dict = dict(model.named_parameters()) loaded_params: Set[str] = set() @@ -61,10 +70,10 @@ def patched_ds_v3_load_weights(model: nn.Module, weights: Iterable[Tuple[str, to if spec_layer is not None: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) if name.endswith(".bias") and name not in params_dict: @@ -126,17 +135,19 @@ def patched_qwen_moe_load_weights(model: nn.Module, weights: Iterable[Tuple[str, # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping(ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=model.config.num_experts) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=model.config.num_experts, + ) params_dict = dict(model.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -150,7 +161,7 @@ def patched_qwen_moe_load_weights(model: nn.Module, weights: Iterable[Tuple[str, continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) and name not in params_dict): + if (name.endswith(".bias") or name.endswith("_bias")) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, model): @@ -172,7 +183,7 @@ def patched_qwen_moe_load_weights(model: nn.Module, weights: Iterable[Tuple[str, if is_pp_missing_parameter(name, model): continue # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) and name not in params_dict): + if (name.endswith(".bias") or name.endswith("_bias")) and name not in params_dict: continue # custom weight_loader param = params_dict[name] @@ -184,7 +195,7 @@ def patched_qwen_moe_load_weights(model: nn.Module, weights: Iterable[Tuple[str, break else: # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) and name not in params_dict): + if (name.endswith(".bias") or name.endswith("_bias")) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, model):