Mirror of https://github.com/volcengine/verl.git (synced 2025-10-20 13:43:50 +08:00)

Compare commits: 7ddb9b29f0 ... v0.3.0.post1 (11 commits)
Commits: 070ed6ac3f, 5097b13149, 9ef1f48704, de9e01b847, b70981bdb9, 52bec83183, e3e82b8b25, 67986f4e5d, 20cd5629fa, 8e7780f5ee, ba1245b6d1
.github/workflows/e2e_sglang_gsm8k.yml (vendored, 2 lines changed)
@@ -39,7 +39,7 @@ jobs:
       NO_PROXY: "localhost,127.0.0.1"
       HF_HUB_ENABLE_HF_TRANSFER: 1
     container:
-      image: ocss884/verl-sglang:ngc-th2.5.1-cu126-sglang0.4.3.post3
+      image: ocss884/verl-sglang:ngc-th2.5.1-cu126-sglang0.4.4.post3
       options: --gpus all --shm-size=10g
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
docs/advance/checkpoint.rst (new file, 122 lines)
@@ -0,0 +1,122 @@
Using Checkpoints to Support Fault-Tolerant Training
=====================================================

Training errors or machine failures can occur at any point during the RLHF training process, so it is recommended to enable checkpointing to minimize your losses.

The API interface is already listed in :ref:`config-explain-page`, so we will not repeat it here. But there are still some technical details we hope to clarify.

.. note::

   The ``checkpoint.contents`` field has no effect on FSDP checkpoints except for ``hf_model``; the other three fields are bound together for saving and loading. We recommend including all of ``model``, ``optimizer`` and ``extra``.
Checkpoint Saving Directory Structure
-------------------------------------

Commonly, we use the ``default_local_dir`` declared in ``ppo_trainer.yaml`` or ``ppo_megatron_trainer.yml`` as the prefix when saving checkpoints, i.e. ``checkpoints/${trainer.project_name}/${trainer.experiment_name}``.

So the inner checkpoint structure for **FSDP** looks like:

.. code::

   checkpoints/${trainer.project_name}/${trainer.experiment_name}
   ├── global_step_${i}
   │   ├── actor
   │   │   ├── model_world_size_{self.world_size}_rank_{self.rank}.pt
   │   │   ├── optim_world_size_{self.world_size}_rank_{self.rank}.pt
   │   │   └── extra_state_world_size_{self.world_size}_rank_{self.rank}.pt
   │   ├── actor_huggingface
   │   ├── critic
   │   │   ├── model_world_size_{self.world_size}_rank_{self.rank}.pt
   │   │   ├── optim_world_size_{self.world_size}_rank_{self.rank}.pt
   │   │   └── extra_state_world_size_{self.world_size}_rank_{self.rank}.pt
   │   └── critic_huggingface
   └── latest_checkpointed_iteration.txt

All model shards, optimizers and extra states are stored together, in a sharded and distributed way.
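As an illustration, here is a minimal sketch (ours, not verl's actual implementation; the project and experiment names are hypothetical) of how one rank's three bound files from the tree above could be composed:

.. code:: python

   import os

   def fsdp_shard_paths(ckpt_dir, world_size, rank):
       """Return the three per-rank files that are saved and loaded together."""
       return {
           name: os.path.join(ckpt_dir, f'{name}_world_size_{world_size}_rank_{rank}.pt')
           for name in ('model', 'optim', 'extra_state')
       }

   # e.g. the actor shard files of rank 3 in an 8-GPU run at global step 100
   paths = fsdp_shard_paths(
       'checkpoints/my_project/my_experiment/global_step_100/actor', 8, 3)
   # torch.save(model_state_dict, paths['model']), and likewise for the others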
While the current **Megatron** checkpoint structure is:

.. code::

   checkpoints/${trainer.project_name}/${trainer.experiment_name}
   ├── global_step_${i}
   │   ├── actor
   │   │   ├── huggingface       # saves the tokenizer by default; also saves the huggingface model if ``hf_model`` is included in checkpoint.contents
   │   │   ├── model             # saves the sharded model, with the same naming as Megatron
   │   │   │   ├── mp_rank_xx_yyy    # xx is the tp_rank in 2 digits, yyy is the pp_rank in 3 digits
   │   │   │   │   └── model_states.pt
   │   │   │   └── mp_rank_xx_xxx
   │   │   ├── optim
   │   │   │   ├── distrib_optim_pp{x}_tp{y}.pt
   │   │   │   └── distrib_optim_pp{x}_tp{y}.pt
   │   │   └── rng_states
   │   └── critic
   │       ├── huggingface
   │       ├── model
   │       ├── optim
   │       └── rng_states
   └── latest_checkpointed_iteration.txt
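For clarity, a small sketch (ours) of how the ``mp_rank_xx_yyy`` directory names above are composed from the parallel ranks:

.. code:: python

   def megatron_model_shard_dir(tp_rank: int, pp_rank: int) -> str:
       # xx is the tp_rank in 2 digits, yyy is the pp_rank in 3 digits
       return f'mp_rank_{tp_rank:02d}_{pp_rank:03d}'

   assert megatron_model_shard_dir(1, 12) == 'mp_rank_01_012'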
Convert FSDP and Megatron Checkpoints to a HuggingFace-Format Model
-------------------------------------------------------------------

We provide a tool to convert FSDP and Megatron checkpoints into a HuggingFace-format model. The tool is located in ``scripts/model_merger.py``.

The arguments are as follows:

.. code:: bash

   usage: model_merger.py [-h] [--backend {fsdp,megatron}]
                          [--tie-word-embedding whether the model shares embedding weights]
                          [--is-value-model whether the model is a critic model]
                          [--hf_model_path $original_model_path, like {Qwen/Qwen2-7B}]
                          [--local_dir $local_directory of saved fsdp or megatron models]
                          [--target_dir $target_dir to save converted models, default is tmp]
                          [--hf_upload_path $huggingface_repo to upload to]
An example use of the Megatron model merger is:

.. code:: bash

   python3 scripts/model_merger.py --backend megatron \
       --is-value-model \
       --hf_model_path Qwen/Qwen2-7B \
       --local_dir checkpoints/verl_megatron_gsm8k_examples/deepseek_megatron_checkpoint_saveload/global_step_1/actor/model
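After conversion, the result in ``--target_dir`` (``tmp`` by default) should load like any HuggingFace checkpoint; a quick sanity check (ours):

.. code:: python

   from transformers import AutoModelForCausalLM

   model = AutoModelForCausalLM.from_pretrained('tmp')  # the default target_dir
   print(sum(p.numel() for p in model.parameters()))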
Megatron Merger Details
-----------------------

The current implementation of the decoder layers uses ``nn.ModuleList`` to store the layers, so the model layers on every PP rank and VPP rank start their indices from 0.

There are three ways to correct this behavior:

1. Modify the decoder layer's ``state_dict``, adding an ``offset`` to each layer's index, which means rewriting the ``nn.ModuleList`` implementation.
2. Modify the layer indices when saving the checkpoint and recover them when loading it.
3. Let the checkpoint merger do this work, calculating the actual ``offset`` from the ``state_dict`` alone, which is a little complex.

The current implementation uses solution 2.
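A minimal sketch (ours, with a hypothetical ``model.layers.{i}`` key scheme) of the re-indexing in solution 2: shift the local layer indices by the PP/VPP offset when saving, and apply the negative offset when loading:

.. code:: python

   import re

   def shift_layer_indices(state_dict: dict, offset: int) -> dict:
       """Rename 'model.layers.{i}.*' keys to 'model.layers.{i + offset}.*'."""
       pattern = re.compile(r'^(model\.layers\.)(\d+)(\..+)$')
       shifted = {}
       for key, value in state_dict.items():
           m = pattern.match(key)
           if m:
               key = f'{m.group(1)}{int(m.group(2)) + offset}{m.group(3)}'
           shifted[key] = value
       return shifted

   # save path: local indices 0..k-1 become global indices offset..offset+k-1
   # load path: apply the inverse with a negative offset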
Original Checkpoint Utils
-------------------------

The original checkpoint utils refer to the original checkpoint implementation in ``verl/models/[model]/megatron/checkpoint_utils``.

We now only need ``[model]_loader.py`` from the original checkpoint utils, since we no longer store ``hf_model`` every time (which is not recommended for large-model training; try saving only sharded models if you can).

.. note::

   Note that ``[model]_loader`` only supports environments where the **storage cluster is reachable from every compute node**, because it uses a **sharded load to minimize the checkpoint-loading overhead**: every rank loads its own data from a ``state_dict`` that all of them can access. There is also no need to broadcast among DP ranks, since the saved ``state_dict`` is produced only by DP rank 0.

For users who can **only place the HuggingFace model on one device**, we keep the original, costly implementation in ``[model]_loader_deprecated``. In that implementation, rank 0 broadcasts all weights to each TP and PP rank, and then DP rank 0 broadcasts to all DP ranks. This carries a risk of OOM.

To use the deprecated loader, change the imported package of ``load_state_dict_to_megatron_llama``.
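A rough sketch (ours, with a hypothetical first-dimension sharding rule) of the sharded-load idea described above: every rank reads only its own slice of a ``state_dict`` that all nodes can access, so no broadcast is required:

.. code:: python

   import torch

   def load_my_tp_slice(state_dict_path: str, tp_rank: int, tp_size: int):
       full = torch.load(state_dict_path, map_location='cpu')
       sliced = {}
       for name, tensor in full.items():
           # hypothetical rule: shard the first dimension across TP ranks
           sliced[name] = tensor.chunk(tp_size, dim=0)[tp_rank].clone()
       return sliced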
@@ -381,8 +381,8 @@ Trainer
     critic_warmup: 0
     default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name} # hdfs checkpoint path
     default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} # local checkpoint path
-    resume_mode: auto # or disable or resume_path if
-    resume_from_path: False
+    resume_mode: auto # or disable or resume_path if resume_from_path is set
+    resume_from_path: null
     remove_previous_ckpt_in_save: False
     del_local_ckpt_after_load: False
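To summarize the new semantics of the two fields, a small sketch (ours) of the resume decision that the trainer diffs further below implement:

def resolve_resume(resume_mode, resume_from_path, last_ckpt):
    if resume_mode == 'disable':
        return None                      # always start from scratch
    if resume_mode == 'auto':
        return last_ckpt                 # may be None -> start from scratch
    if resume_mode == 'resume_path':
        assert isinstance(resume_from_path, str)
        assert 'global_step_' in resume_from_path
        return resume_from_path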
@@ -73,7 +73,6 @@ verl is fast with:
    :caption: Performance Tuning Guide

    perf/perf_tuning
    README_vllm0.7.md
    README_vllm0.8.md

 .. toctree::

@@ -90,6 +89,7 @@ verl is fast with:
    advance/dpo_extension
    advance/fsdp_extension
    advance/megatron_extension
+   advance/checkpoint

 .. toctree::
    :maxdepth: 1
@@ -10,15 +10,22 @@ Requirements
 verl supports various backends. Currently, the following configurations are available:

 - **FSDP** and **Megatron-LM** (optional) for training.
-- **SGLang**, **vLLM** and **TGI** for rollout generation.
+- **SGLang** (preview), **vLLM** and **TGI** for rollout generation.

-Training backends
------------------
+Choices of Backend Engines
+--------------------------

+1. Training:

 We recommend using the **FSDP** backend to investigate, research and prototype different models, datasets and RL algorithms. The guide for using the FSDP backend can be found in :doc:`FSDP Workers<../workers/fsdp_workers>`.

 For users who pursue better scalability, we recommend using the **Megatron-LM** backend. Currently, we support Megatron-LM v0.11 [1]_. The guide for using the Megatron-LM backend can be found in :doc:`Megatron-LM Workers<../workers/megatron_workers>`.

+2. Inference:

+For inference, the integration of both vllm v0.6.3 and v0.8.2 is stable. The HuggingFace TGI integration is usually used for debugging and single-GPU exploration. The SGLang integration is blazing fast and under rapid development; we release it as a preview feature, so please give us feedback.

 Install from docker image
 -------------------------
@@ -56,9 +63,9 @@ Image and tag: ``whatcanyousee/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te2.0-``

 Install verl-SGLang from scratch
--------------------------------------
+---------------------------------------------

-SGLang largely supports the research and inference workload at xAI. For verl-sglang installation, ignore the version conflicts reported by pip with vllm. SGLang supports a native API for RLHF, so there is no need to patch a single line of code.
+If you want to use SGLang instead of vllm for inference, please follow the instructions here. SGLang largely supports the research and inference workload at xAI. For verl-sglang installation, ignore the version conflicts reported by pip with vllm. SGLang supports a native API for RLHF, so there is no need to patch a single line of code.

 The following steps are a quick installation guide for verl-SGLang.
@@ -72,15 +79,15 @@ The following steps are a quick installation guide for verl-SGLang.
 git clone https://github.com/volcengine/verl verl-sglang && cd verl-sglang
 python3 -m uv pip install .

-# Install the latest stable version of sglang with verl support; currently the latest version is 0.4.3.post3
+# Install the latest stable version of sglang with verl support; currently the latest version is 0.4.4.post3
 # For SGLang installation, you can also refer to https://docs.sglang.ai/start/install.html
-python3 -m uv pip install "sglang[all]==0.4.4.post1" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
+python3 -m uv pip install "sglang[all]==0.4.4.post3" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
 Install from custom environment
 ---------------------------------------------

-To manage the environment, we recommend using conda:
+If you do not want to use the official docker image, here is how to start from your own environment. To manage the environment, we recommend using conda:

 .. code:: bash
@@ -111,14 +118,9 @@ Megatron is optional. Its dependencies can be set up as below:

 # transformer engine
 pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
 # megatron core
 pip3 install megatron-core==0.11.0

 git clone -b core_v0.11.0 https://github.com/NVIDIA/Megatron-LM.git
 cd Megatron-LM
 pip3 install -e .

-# megatron core v0.4.0: clone and apply the patch
-# You can also get the patched Megatron code patch via
-# git clone -b core_v0.4.0_verl https://github.com/eric-haibin-lin/Megatron-LM

 Install with AMD GPUs - ROCM kernel support
 ------------------------------------------------------------------
@@ -18,7 +18,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     actor_rollout_ref.actor.ppo_mini_batch_size=256 \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
     actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \
-    actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \
     actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \
     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
     actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
@@ -26,14 +25,12 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
     actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
     actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \
-    actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \
     actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \
     critic.optim.lr=2e-5 \
     critic.model.path=deepseek-ai/deepseek-llm-7b-chat \
     critic.model.enable_gradient_checkpointing=False \
     critic.ppo_micro_batch_size_per_gpu=4 \
     critic.megatron.pipeline_model_parallel_size=2 \
-    critic.megatron.virtual_pipeline_model_parallel_size=2 \
     critic.megatron.tensor_model_parallel_size=4 \
     algorithm.kl_ctrl.kl_coef=0.001 \
     trainer.critic_warmup=0 \
@@ -42,10 +39,10 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     trainer.experiment_name='deepseek_megatron_checkpoint_saveload' \
     trainer.n_gpus_per_node=16 \
     trainer.nnodes=1 \
-    trainer.save_freq=100 \
+    trainer.save_freq=50 \
     trainer.test_freq=1 \
     trainer.total_epochs=15 \
-    trainer.total_training_steps=100 $@
+    trainer.total_training_steps=50 $@

 python3 -m verl.trainer.main_ppo --config-path=config \
@@ -60,7 +57,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     actor_rollout_ref.actor.ppo_mini_batch_size=256 \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
     actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \
-    actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \
     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
     actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
@@ -68,14 +64,12 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
     actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
     actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \
-    actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \
     actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \
     critic.optim.lr=2e-5 \
     critic.model.path=deepseek-ai/deepseek-llm-7b-chat \
     critic.model.enable_gradient_checkpointing=False \
     critic.ppo_micro_batch_size_per_gpu=4 \
     critic.megatron.pipeline_model_parallel_size=2 \
-    critic.megatron.virtual_pipeline_model_parallel_size=2 \
     critic.megatron.tensor_model_parallel_size=4 \
     algorithm.kl_ctrl.kl_coef=0.001 \
     trainer.critic_warmup=0 \
@@ -164,8 +164,8 @@ trainer:
   n_gpus_per_node: 8
   save_freq: -1
   # auto: find the last ckpt to resume. If can't find, start from scratch
-  resume_mode: auto # or auto or resume_path if
-  resume_from_path: False
+  resume_mode: auto # or disable or resume_path if resume_from_path is set
+  resume_from_path: null
   test_freq: -1
   critic_warmup: 0
   default_hdfs_dir: null
@@ -127,7 +127,7 @@ class PRIMERewardModelWorker(Worker):

         if config.model.get('use_remove_padding', False) or self.ulysses_sequence_parallel_size > 1:
             from verl.models.transformers.monkey_patch import apply_monkey_patch
-            apply_monkey_patch(model=reward_module)
+            apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)

         # some parameters may not be in torch_dtype
         reward_module.to(torch_dtype)
@@ -288,10 +288,10 @@ class RayPRIMETrainer(RayPPOTrainer):
                 print('Training from scratch')
                 return 0
         else:
-            if not (self.config.trainer.resume_from_path and global_step_folder is not None):
-                assert isinstance(self.config.trainer.resume_mode, str), "resume ckpt must be str type"
-                assert 'global_step_' in self.config.trainer.resume_mode, "resume ckpt must specify the global_steps"
-                global_step_folder = self.config.trainer.resume_mode
+            if self.config.trainer.resume_mode == "resume_path":
+                assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
+                assert 'global_step_' in self.config.trainer.resume_from_path, "resume ckpt must specify the global_steps"
+                global_step_folder = self.config.trainer.resume_from_path
             if not os.path.isabs(global_step_folder):
                 working_dir = os.getcwd()
                 global_step_folder = os.path.join(working_dir, global_step_folder)
requirements_sglang.txt (new file, 21 lines)
@@ -0,0 +1,21 @@
# requirements.txt records the full set of dependencies for development
accelerate
codetiming
datasets
dill
flash-attn
hydra-core
numpy
pandas
peft
pyarrow>=15.0.0
pybind11
pylatexenc
ray[default]>=2.10
tensordict<=0.6.2
torchdata
torchvision
transformers
wandb
sglang[all]==0.4.4.post3
torch-memory-saver>=0.0.5
@@ -163,10 +163,3 @@ if __name__ == '__main__':
         repo_id=args.hf_upload_path,
         repo_type="model"
     )
setup.py (6 lines changed)
@@ -46,7 +46,11 @@ GEO_REQUIRES = ['mathruler']
 GPU_REQUIRES = ['liger-kernel', 'flash-attn']
 MATH_REQUIRES = ['math-verify']  # Add math-verify as an optional dependency
 VLLM_REQUIRES = ['tensordict<=0.6.2', 'vllm<=0.8.2']
-SGLANG_REQUIRES = ['tensordict<=0.6.2', 'sglang[all]==0.4.4']
+SGLANG_REQUIRES = [
+    'tensordict<=0.6.2',
+    'sglang[all]==0.4.4.post3',
+    'torch-memory-saver>=0.0.5'
+]

 extras_require = {
     'test': TEST_REQUIRES,
@@ -11,7 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import time
+import contextlib
+from dataclasses import dataclass

+import pytest
 import torch
 import copy
 import torch.distributed
@@ -23,16 +27,45 @@ from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_uly
 from verl.workers.sharding_manager import FSDPUlyssesShardingManager
 from verl.protocol import DataProto
 from flash_attn.bert_padding import unpad_input, index_first_axis, rearrange
-from transformers import LlamaConfig, Qwen2Config
+from transformers import LlamaConfig, Qwen2Config, PretrainedConfig
 from transformers import AutoModelForCausalLM
 from verl.models.transformers.monkey_patch import apply_monkey_patch

-# TODO(sgm): add more models for test
-# we only need one scale for each model
-test_configs = {
-    'llama': (LlamaConfig(num_hidden_layers=2), apply_monkey_patch),
-    'qwen2': (Qwen2Config(num_hidden_layers=2), apply_monkey_patch)
-}
+
+@dataclass
+class SequenceParallelConfig:
+    config: PretrainedConfig
+    sp_size: int
+    is_valid: bool
+
+
+def test_configs():
+    return [
+        SequenceParallelConfig(LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32),
+                               sp_size=8,
+                               is_valid=True),
+        SequenceParallelConfig(Qwen2Config(num_hidden_layers=2,
+                                           num_attention_heads=28,
+                                           num_key_value_heads=4,
+                                           hidden_size=3584),
+                               sp_size=4,
+                               is_valid=True),
+        SequenceParallelConfig(Qwen2Config(num_hidden_layers=2,
+                                           num_attention_heads=28,
+                                           num_key_value_heads=4,
+                                           hidden_size=3584),
+                               sp_size=8,
+                               is_valid=False),
+        SequenceParallelConfig(Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4),
+                               sp_size=4,
+                               is_valid=True),
+        SequenceParallelConfig(Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4),
+                               sp_size=8,
+                               is_valid=True),
+    ]


 def sync_model_parameters_global(layer):
@@ -41,11 +74,23 @@ def sync_model_parameters_global(layer):
         torch.distributed.broadcast(tensor=p.data, src=0)


-def test_hf_casual_fwd():
+@pytest.mark.parametrize("test_config", test_configs())
+def test_hf_casual_fwd_bwd(test_config):
+    if not torch.distributed.is_initialized():
+        initialize_global_process_group()
+
+    context = contextlib.nullcontext() if test_config.is_valid else pytest.raises(AssertionError)
+    with context:
+        world_size = torch.distributed.get_world_size()
+        _hf_casual_fwd_bwd(test_config.config, test_config.sp_size, world_size // test_config.sp_size)
+
+    # TODO: seems not to work; will cause `socketStartConnect: Connect to xxx failed : Software caused connection abort`
+    # torch.distributed.destroy_process_group()
+
+
+def _hf_casual_fwd(config, sp_size, dp_size):
     assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"

-    sp_size = 8
-    dp_size = 1
     ulysses_device_mesh = init_device_mesh(device_type='cuda',
                                            mesh_shape=(dp_size, sp_size),
                                            mesh_dim_names=('dp', 'sp'))
@@ -55,75 +100,71 @@ def test_hf_casual_fwd():
     seqlen = 128
     response_length = 127

-    for model_name, (config, apply_monkey_patch) in test_configs.items():
-        # patch before load
-        with torch.device('cuda'):
-            model = AutoModelForCausalLM.from_config(config=config,
-                                                     torch_dtype=torch.bfloat16,
-                                                     attn_implementation='flash_attention_2')
-            apply_monkey_patch(model)
-            model = model.to(device='cuda')
-        sync_model_parameters_global(model)
+    # patch before load
+    with torch.device('cuda'):
+        model = AutoModelForCausalLM.from_config(config=config,
+                                                 torch_dtype=torch.bfloat16,
+                                                 attn_implementation='flash_attention_2')
+        apply_monkey_patch(model, sp_size)
+        model = model.to(device='cuda')
+    sync_model_parameters_global(model)

-        # different ranks will generate different input_ids following fsdp
-        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
-        attention_mask = create_random_mask(input_ids=input_ids,
-                                            max_ratio_of_left_padding=0,
-                                            max_ratio_of_valid_token=0.9,
-                                            min_ratio_of_valid_token=0.8)
-        position_ids = compute_position_id_with_mask(
-            attention_mask)  # TODO(sgm): we can construct the position_ids_rmpad here
+    # different ranks will generate different input_ids following fsdp
+    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
+    attention_mask = create_random_mask(input_ids=input_ids,
+                                        max_ratio_of_left_padding=0,
+                                        max_ratio_of_valid_token=0.9,
+                                        min_ratio_of_valid_token=0.8)
+    position_ids = compute_position_id_with_mask(
+        attention_mask)  # TODO(sgm): we can construct the position_ids_rmpad here

-        model_inputs = {
-            'input_ids': input_ids.cuda(),
-            'attention_mask': attention_mask.cuda(),
-            'position_ids': position_ids.int().cuda()
-        }
+    model_inputs = {
+        'input_ids': input_ids.cuda(),
+        'attention_mask': attention_mask.cuda(),
+        'position_ids': position_ids.int().cuda()
+    }

-        model_inputs = DataProto.from_dict(model_inputs)
+    model_inputs = DataProto.from_dict(model_inputs)

-        # 1. perform ulysses forward
-        with sharding_manager:
-            model_inputs = sharding_manager.preprocess_data(model_inputs)
-            input_ids = model_inputs.batch['input_ids']
-            attention_mask = model_inputs.batch['attention_mask']
-            position_ids = model_inputs.batch['position_ids']
-            input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
-                                                       attention_mask)  # input_ids_rmpad (total_nnz, ...)
-            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
-            # unpad the position_ids to align the rotary
-            position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
-                                                  indices).transpose(0, 1)
+    # 1. perform ulysses forward
+    with sharding_manager:
+        model_inputs = sharding_manager.preprocess_data(model_inputs)
+        input_ids = model_inputs.batch['input_ids']
+        attention_mask = model_inputs.batch['attention_mask']
+        position_ids = model_inputs.batch['position_ids']
+        input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
+                                                   attention_mask)  # input_ids_rmpad (total_nnz, ...)
+        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
+        # unpad the position_ids to align the rotary
+        position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
+                                              indices).transpose(0, 1)

-            # slice input tensor for ulysses
-            # input_ids are padded and sliced
-            # position_ids are only padded but not sliced
-            input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
-                input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size())
+        # slice input tensor for ulysses
+        # input_ids are padded and sliced
+        # position_ids are only padded but not sliced
+        input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
+            input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size())

-            # input with input_ids_rmpad and position_ids to enable flash attention varlen
-            logits_split_in_seq = model(input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded,
-                                        use_cache=False).logits  # (1, total_nnz/n, vocab_size)
+        # input with input_ids_rmpad and position_ids to enable flash attention varlen
+        logits_split_in_seq = model(input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded,
+                                    use_cache=False).logits  # (1, total_nnz/n, vocab_size)

-            # all_gather output
-            logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)
+        # all_gather output
+        logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)

-            # 2. perform normal forward
-            set_ulysses_sequence_parallel_group(None)
-            logits_rmpad_local = model(input_ids_rmpad, position_ids=position_ids_rmpad,
-                                       use_cache=False).logits  # (1, total_nnz, vocab_size)
+        # 2. perform normal forward
+        set_ulysses_sequence_parallel_group(None)
+        logits_rmpad_local = model(input_ids_rmpad, position_ids=position_ids_rmpad,
+                                   use_cache=False).logits  # (1, total_nnz, vocab_size)

-            mean_local = logits_rmpad_local.mean()
-            mean_full = logits_full.mean()
-            torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)
-            print(f'Fwd Check pass')
+        mean_local = logits_rmpad_local.mean()
+        mean_full = logits_full.mean()
+        torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)


-def test_hf_casual_fwd_bwd():
+def _hf_casual_fwd_bwd(config, sp_size, dp_size):
     assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"

-    sp_size = 8
-    dp_size = 1
     ulysses_device_mesh = init_device_mesh(device_type='cuda',
                                            mesh_shape=(dp_size, sp_size),
                                            mesh_dim_names=('dp', 'sp'))
@@ -133,82 +174,78 @@ def test_hf_casual_fwd_bwd():
     seqlen = 128
     response_length = 127

-    for model_name, (config, apply_monkey_patch) in test_configs.items():
-        # patch before load
-        with torch.device('cuda'):
-            model = AutoModelForCausalLM.from_config(config=config,
-                                                     torch_dtype=torch.bfloat16,
-                                                     attn_implementation='flash_attention_2')
-            apply_monkey_patch(model)
-            model = model.to(device='cuda')
-        sync_model_parameters_global(model)
+    # patch before load
+    with torch.device('cuda'):
+        model = AutoModelForCausalLM.from_config(config=config,
+                                                 torch_dtype=torch.bfloat16,
+                                                 attn_implementation='flash_attention_2')
+        apply_monkey_patch(model, sp_size)
+        model = model.to(device='cuda')
+    sync_model_parameters_global(model)

-        # different ranks will generate different input_ids following fsdp
-        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
-        attention_mask = create_random_mask(input_ids=input_ids,
-                                            max_ratio_of_left_padding=0,
-                                            max_ratio_of_valid_token=0.9,
-                                            min_ratio_of_valid_token=0.8)
-        position_ids = compute_position_id_with_mask(
-            attention_mask)  # TODO(sgm): we can construct the position_ids_rmpad here
+    # different ranks will generate different input_ids following fsdp
+    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
+    attention_mask = create_random_mask(input_ids=input_ids,
+                                        max_ratio_of_left_padding=0,
+                                        max_ratio_of_valid_token=0.9,
+                                        min_ratio_of_valid_token=0.8)
+    position_ids = compute_position_id_with_mask(
+        attention_mask)  # TODO(sgm): we can construct the position_ids_rmpad here

-        model_inputs = {
-            'input_ids': input_ids.cuda(),
-            'attention_mask': attention_mask.cuda(),
-            'position_ids': position_ids.int().cuda()
-        }
+    model_inputs = {
+        'input_ids': input_ids.cuda(),
+        'attention_mask': attention_mask.cuda(),
+        'position_ids': position_ids.int().cuda()
+    }

-        model_inputs = DataProto.from_dict(model_inputs)
+    model_inputs = DataProto.from_dict(model_inputs)

-        # 1. perform ulysses forward
-        with sharding_manager:
-            model_inputs = sharding_manager.preprocess_data(model_inputs)
-            input_ids = model_inputs.batch['input_ids']
-            attention_mask = model_inputs.batch['attention_mask']
-            position_ids = model_inputs.batch['position_ids']
-            input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
-                                                       attention_mask)  # input_ids_rmpad (total_nnz, ...)
-            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
-            # unpad the position_ids to align the rotary
-            position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
-                                                  indices).transpose(0, 1)
+    # 1. perform ulysses forward
+    with sharding_manager:
+        model_inputs = sharding_manager.preprocess_data(model_inputs)
+        input_ids = model_inputs.batch['input_ids']
+        attention_mask = model_inputs.batch['attention_mask']
+        position_ids = model_inputs.batch['position_ids']
+        input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
+                                                   attention_mask)  # input_ids_rmpad (total_nnz, ...)
+        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
+        # unpad the position_ids to align the rotary
+        position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
+                                              indices).transpose(0, 1)

-            # slice input tensor for ulysses
-            # input_ids are padded and sliced
-            # position_ids are only padded but not sliced
-            input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
-                input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size())
+        # slice input tensor for ulysses
+        # input_ids are padded and sliced
+        # position_ids are only padded but not sliced
+        input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
+            input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size())

-            # input with input_ids_rmpad and position_ids to enable flash attention varlen
-            logits_split_in_seq = model(input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded,
-                                        use_cache=False).logits  # (1, total_nnz/n, vocab_size)
+        # input with input_ids_rmpad and position_ids to enable flash attention varlen
+        logits_split_in_seq = model(input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded,
+                                    use_cache=False).logits  # (1, total_nnz/n, vocab_size)

-            # all_gather output
-            logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)
+        # all_gather output
+        logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)

-            # 2. perform normal forward
-            set_ulysses_sequence_parallel_group(None)
-            input_ids_full = copy.deepcopy(input_ids_rmpad)
-            position_ids_full = copy.deepcopy(position_ids_rmpad)
-            model_no_sp = copy.deepcopy(model)
-            logits_rmpad_local = model_no_sp(input_ids_full, position_ids=position_ids_full,
-                                             use_cache=False).logits  # (1, total_nnz, vocab_size)
+        # 2. perform normal forward
+        set_ulysses_sequence_parallel_group(None)
+        input_ids_full = copy.deepcopy(input_ids_rmpad)
+        position_ids_full = copy.deepcopy(position_ids_rmpad)
+        model_no_sp = copy.deepcopy(model)
+        logits_rmpad_local = model_no_sp(input_ids_full, position_ids=position_ids_full,
+                                         use_cache=False).logits  # (1, total_nnz, vocab_size)

-            mean_local = logits_rmpad_local.mean()
-            mean_full = logits_full.mean()
+        mean_local = logits_rmpad_local.mean()
+        mean_full = logits_full.mean()

-            mean_full.backward()
-            mean_local.backward()
+        mean_full.backward()
+        mean_local.backward()

-            # 3. check the gradients
-            grad = model.model.layers[0].self_attn.q_proj.weight.grad
-            grad_full = model_no_sp.model.layers[0].self_attn.q_proj.weight.grad
-            torch.testing.assert_close(grad, grad_full, atol=1e-2, rtol=1e-5)
-            print(f'Fwd + BWD Check pass')
+        # 3. check the gradients
+        grad = model.model.layers[0].self_attn.q_proj.weight.grad
+        grad_full = model_no_sp.model.layers[0].self_attn.q_proj.weight.grad
+        torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)
+        torch.testing.assert_close(grad, grad_full, atol=1e-2, rtol=1e-5)


 if __name__ == '__main__':
     local_rank, rank, world_size = initialize_global_process_group()
-    test_hf_casual_fwd()
-    test_hf_casual_fwd_bwd()
+    pytest.main([__file__, "-svv"])
@@ -29,6 +29,18 @@ from verl.utils.ulysses import (
 )


+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=2, repeats=n_rep). The hidden states go from (batch,
+    seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim)
+    """
+    batch, slen, num_key_value_heads, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim)
+    return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim)
+
+
 def _ulysses_flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
@@ -54,6 +66,17 @@ def _ulysses_flash_attention_forward(
     ########## AlltoAll for Ulysses ##########
     if ulysses_sp_size > 1:
+        assert position_ids is not None, "position_ids is required for Ulysses sequence parallelism"
+
+        # NOTE: repeat kv heads to be divided by sequence parallel. Instead of repeating nheads_q//nheads_k,
+        # we choose to repeat sp_size//nheads_k, since flash_attention supports MQA/GQA.
+        # For example:
+        # - nheads_k=4, sp=8, repeats=2
+        # - nheads_k=8, sp=8, repeats=1
+        # - nheads_k=16, sp=8, repeats=1
+        repeats = max(ulysses_sp_size // key_states.size(2), 1)
+        key_states = repeat_kv(key_states, repeats)
+        value_states = repeat_kv(value_states, repeats)
+
         # (bsz, seq_len/n, n_head, head_dim) -> (bsz, seq_len, n_head/n, head_dim)
         query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
         key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
@@ -84,10 +107,17 @@ def _ulysses_flash_attention_forward(
     return attn_output


-def apply_monkey_patch(model: PreTrainedModel):
+def apply_monkey_patch(model: PreTrainedModel, ulysses_sp_size: int):
     """Replace _flash_attention_forward to _ulysses_flash_attention_forward"""
     module = sys.modules[model.__module__]

+    num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads
+    assert num_attention_heads % ulysses_sp_size == 0, \
+        f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}"
+    assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, (
+        f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}"
+        f" or vice versa. Upon ulysses_sp_size % num_key_value_heads == 0,"
+        f" kv heads are repeated to ensure correctness.")
     # TODO: VLM models only, unify monkey patch to LLM models.
     if model.config.model_type in ("qwen2_vl", "qwen2_5_vl"):  # patch remove padding for qwen2vl mrope
         from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward
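For intuition, a tiny sketch (ours) of the divisibility check and repeat rule that the patched function enforces:

def kv_repeats(num_key_value_heads: int, ulysses_sp_size: int) -> int:
    # kv heads must divide sp_size or vice versa; when sp_size is larger,
    # kv heads are repeated so the head count becomes divisible by sp_size
    assert (num_key_value_heads % ulysses_sp_size == 0
            or ulysses_sp_size % num_key_value_heads == 0)
    return max(ulysses_sp_size // num_key_value_heads, 1)

assert kv_repeats(4, 8) == 2   # nheads_k=4,  sp=8 -> repeat twice
assert kv_repeats(8, 8) == 1   # nheads_k=8,  sp=8 -> no repeat
assert kv_repeats(16, 8) == 1  # nheads_k=16, sp=8 -> no repeat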
verl/third_party/sglang/parallel_state.py (vendored, 7 lines changed)
@@ -9,16 +9,13 @@ from typing import Optional

 import torch
 import torch.distributed
-import vllm.distributed.parallel_state as ps
-from vllm.distributed.parallel_state import (
+import sglang.srt.distributed.parallel_state as ps
+from sglang.srt.distributed.parallel_state import (
     get_pp_group,
     get_world_group,
     init_distributed_environment,
     init_model_parallel_group,
 )
 from vllm.logger import init_logger

 logger = init_logger(__name__)
 """
 This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron.
 - We assume the Megatron tp+dp+pp world is already established before calling this function.
@@ -183,8 +183,8 @@ trainer:
   n_gpus_per_node: 8
   save_freq: -1
   # auto: find the last ckpt to resume. If can't find, start from scratch
-  resume_mode: auto # or disable or resume_path if
-  resume_from_path: False
+  resume_mode: auto # or disable or resume_path if resume_from_path is set
+  resume_from_path: null
   del_local_ckpt_after_load: False
   test_freq: 2
   critic_warmup: 0
@@ -187,8 +187,8 @@ trainer:
   n_gpus_per_node: 8
   save_freq: -1
   # auto: find the last ckpt to resume. If can't find, start from scratch
-  resume_mode: auto # or disable or resume_path if
-  resume_from_path: False
+  resume_mode: auto # or disable or resume_path if resume_from_path is set
+  resume_from_path: null
   test_freq: -1
   critic_warmup: 0
   default_hdfs_dir: null
@@ -210,7 +210,7 @@ class FSDPSFTTrainer(object):

         if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1:
             from verl.models.transformers.monkey_patch import apply_monkey_patch
-            apply_monkey_patch(model=self.model)
+            apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size)

         # Apply Liger kernel if use_liger is enabled
         if self.config.model.get('use_liger', False):
@@ -705,10 +705,10 @@ class RayPPOTrainer(object):
                 print('Training from scratch')
                 return 0
         else:
-            if not (self.config.trainer.resume_from_path and global_step_folder is not None):
-                assert isinstance(self.config.trainer.resume_mode, str), "resume ckpt must be str type"
-                assert 'global_step_' in self.config.trainer.resume_mode, "resume ckpt must specify the global_steps"
-                global_step_folder = self.config.trainer.resume_mode
+            if self.config.trainer.resume_mode == "resume_path":
+                assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
+                assert 'global_step_' in self.config.trainer.resume_from_path, "resume ckpt must specify the global_steps"
+                global_step_folder = self.config.trainer.resume_from_path
             if not os.path.isabs(global_step_folder):
                 working_dir = os.getcwd()
                 global_step_folder = os.path.join(working_dir, global_step_folder)
@@ -1 +1 @@
-0.2.0.dev
+0.3.0.post1
@@ -204,7 +204,7 @@ class ActorRolloutRefWorker(Worker):

         if use_remove_padding or self.ulysses_sequence_parallel_size > 1:
             from verl.models.transformers.monkey_patch import apply_monkey_patch
-            apply_monkey_patch(model=actor_module)
+            apply_monkey_patch(model=actor_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)

         # Apply Liger kernel to the model if use_liger is set to True
         if use_liger:
@@ -709,7 +709,7 @@ class CriticWorker(Worker):
         use_remove_padding = config.model.get('use_remove_padding', False)
         if use_remove_padding or self.ulysses_sequence_parallel_size > 1:
             from verl.models.transformers.monkey_patch import apply_monkey_patch
-            apply_monkey_patch(model=critic_module)
+            apply_monkey_patch(model=critic_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)

         # some parameters may not be in torch_dtype
         critic_module.to(torch_dtype)
@@ -967,7 +967,7 @@ class RewardModelWorker(Worker):

         if config.model.get('use_remove_padding', False) or self.ulysses_sequence_parallel_size > 1:
             from verl.models.transformers.monkey_patch import apply_monkey_patch
-            apply_monkey_patch(model=reward_module)
+            apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)

         reward_module.to(torch.bfloat16)
@@ -655,9 +655,10 @@ class CriticWorker(MegatronWorker):
         return output

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def load_checkpoint(self, checkpoint_path, del_local_after_load=True):
-        self.hf_config = self.checkpoint_mananager.load_checkpoint(local_path=checkpoint_path,
-                                                                   del_local_after_load=del_local_after_load)
+    def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True):
+        self.checkpoint_mananager.load_checkpoint(local_path=checkpoint_path,
+                                                  hdfs_path=hdfs_path,
+                                                  del_local_after_load=del_local_after_load)

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_steps=0, max_ckpt_to_keep=None):
@@ -146,6 +146,7 @@ class SGLangRollout(BaseRollout):
             dtype=config.dtype,
             mem_fraction_static=config.gpu_memory_utilization,
             device_mesh_cpu=device_mesh_cpu["tp"],
+            enable_memory_saver=True,
             base_gpu_id=0,
             gpu_id_step=1,
             # NOTE(Chenyang): if you want to debug the sglang engine
@@ -228,7 +229,7 @@ class SGLangRollout(BaseRollout):
             top_k=-1,
             ignore_eos=False,
             min_new_tokens=0,
-            max_new_tokens=4096,
+            max_new_tokens=self.config.response_length,
             skip_special_tokens=True,
             spaces_between_special_tokens=True,
         )
@@ -80,11 +80,13 @@ class FSDPSGLangShardingManager(BaseShardingManager):
         self.gen_random_states = None

     def __enter__(self):
+        torch.cuda.empty_cache()
         log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger)
         params = self.module.state_dict()
         log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
         # Copy, not share memory
         load_format = None if self.full_params else 'dtensor'
+        self.inference_engine.resume_memory_occupation()

         self.inference_engine.update_weights_from_tensor([(k, v) for k, v in params.items()], load_format=None)
         log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger)
@@ -106,7 +108,7 @@ class FSDPSGLangShardingManager(BaseShardingManager):

     def __exit__(self, exc_type, exc_value, traceback):
         log_gpu_memory_usage('Before SGLang offload in sharding manager', logger=logger)
-        self.inference_engine.release_memory_occupation
+        self.inference_engine.release_memory_occupation()
         log_gpu_memory_usage('After SGLang offload in sharding manager', logger=logger)

         # self.module.to('cuda')
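The fix above matters because ``release_memory_occupation`` without parentheses is a no-op attribute access. A minimal sketch (ours, with a hypothetical engine object) of the intended enter/exit pairing:

class ShardingManagerSketch:
    """Hypothetical illustration of the resume/release pairing above."""

    def __init__(self, engine, module):
        self.engine = engine
        self.module = module

    def __enter__(self):
        # give GPU memory back to the rollout engine, then push fresh weights
        self.engine.resume_memory_occupation()
        self.engine.update_weights_from_tensor(list(self.module.state_dict().items()))

    def __exit__(self, exc_type, exc_value, traceback):
        # note the (): the method must actually be called to free the memory
        self.engine.release_memory_occupation()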