From 7ddb9b29f07d5dc6aa65f45302fc655c14ad2511 Mon Sep 17 00:00:00 2001
From: Houmin Wei
Date: Mon, 13 Oct 2025 08:18:09 +0800
Subject: [PATCH] [misc] feat: prototype deprecate DataProto and replace with Tensordict: part 3 (#3600)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

This PR continues the work started in PR #3567 by deprecating and removing the left_right padding mode:

1. Implement the no-padding mode for the Megatron engine using nested tensors in the SFT trainer.
2. Deprecate the left_right padding mode for the FSDP/Megatron engines.
3. Introduce a transformation layer within the Actor/Critic workers; see more [here](https://github.com/volcengine/verl/blob/main/docs/workers/model_engine.rst).
   - **Input Format**: Actor/Critic workers continue to receive data in left_right padded format.
   - **Transformation**: This layer dynamically converts left_right padded data into the no-padding format using nested tensors.
   - **Engine Format**: FSDP and Megatron engines now operate exclusively on the no-padding data format by default.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --- tests/special_e2e/sft/run_sft_engine_gsm8k.sh | 4 +- tests/special_e2e/sft/test_sft_engine_all.sh | 20 +-- .../test_multiturn_sft_dataset_on_cpu.py | 18 +-- verl/models/mcore/__init__.py | 2 + verl/models/mcore/model_forward.py | 61 +++++++- verl/models/mcore/registry.py | 37 ++++- verl/models/mcore/util.py | 145 ++++++++++++++++++ verl/trainer/config/sft_trainer_engine.yaml | 5 +- verl/utils/dataset/multiturn_sft_dataset.py | 75 +-------- verl/workers/engine/fsdp/transformer_impl.py | 102 ++++-------- .../engine/megatron/transformer_impl.py | 121 ++++++--------- verl/workers/engine/utils.py | 5 +- verl/workers/roles/actor.py | 13 +- verl/workers/roles/critic.py | 8 +- verl/workers/roles/utils/losses.py | 9 +- verl/workers/roles/utils/padding.py | 119 ++++++++++++++ 16 files changed, 482 insertions(+), 262 deletions(-) create mode 100644 verl/workers/roles/utils/padding.py diff --git a/tests/special_e2e/sft/run_sft_engine_gsm8k.sh b/tests/special_e2e/sft/run_sft_engine_gsm8k.sh index 90f8d8035..f166268b6 100644 --- a/tests/special_e2e/sft/run_sft_engine_gsm8k.sh +++ b/tests/special_e2e/sft/run_sft_engine_gsm8k.sh @@ -29,7 +29,7 @@ PP_SIZE=${PP_SIZE:-1} VPP_SIZE=${VPP_SIZE:-null} CP_SIZE=${CP_SIZE:-1} -PAD_MODE=${PAD_MODE:-left_right} +PAD_MODE=${PAD_MODE:-no_padding} USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True} @@ -80,8 +80,6 @@ torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \ data.train_files="${TRAIN_FILES}" \ data.val_files="${VAL_FILES}" \ data.train_batch_size=256 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ data.pad_mode=${PAD_MODE} \ data.truncation=error \ data.use_dynamic_bsz=True \ diff --git a/tests/special_e2e/sft/test_sft_engine_all.sh b/tests/special_e2e/sft/test_sft_engine_all.sh index 8fed89fe0..62232b4f0 100644 --- a/tests/special_e2e/sft/test_sft_engine_all.sh +++ b/tests/special_e2e/sft/test_sft_engine_all.sh @@ -9,15 +9,6 @@ echo "run with single gpu as golden" BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh # test with fsdp 1 -echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode left_right" -BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh -echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right" -BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh -echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right" -BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh -echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode left_right" -BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh - echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode no_padding" BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode no_padding" @@ -27,18 +18,14 @@ BACKEND=fsdp SP_SIZE=2 
FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_pa echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding" BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh -# test use_remove_padding and pad_mode left_right/no_padding -echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode left_right use_remove_padding False" -BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh +# test use_remove_padding and pad_mode no_padding echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding use_remove_padding False" BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh # test with fsdp 2 -echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2 pad_mode left_right" -BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh -echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2 pad_mode no_padding" -BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh +echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2" +BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2" BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh @@ -50,6 +37,7 @@ BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/spe # test with megatron echo "run with tp1 pp1 cp1 num_gpus1" BACKEND=megatron TP_SIZE=1 PP_SIZE=1 CP_SIZE=1 NUM_GPUS=1 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh + echo "run with tp2 pp2 vpp2 cp1 num_gpus8" BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=2 CP_SIZE=1 NUM_GPUS=8 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh diff --git a/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py b/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py index 6e7960a45..0c5bbb650 100644 --- a/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py +++ b/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py @@ -177,28 +177,26 @@ def test_multiturn_sft_dataset(): assert torch.all(padded_item["attention_mask"][actual_length:] == 0), "Attention mask not set correctly for padding" assert torch.all(padded_item["loss_mask"][actual_length:] == 0), "Loss mask not set correctly for padding" - # test left right padding + # test no-padding config = { "max_length": 512, "truncation": "error", "multiturn": {"messages_key": "messages"}, - "pad_mode": "left_right", - "max_prompt_length": 64, - "max_response_length": 64, + "pad_mode": "no_padding", } dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config) item0 = dataset[0] - # make sure all the input_ids with attention_mask == 0 are all padding - assert torch.all(item0["input_ids"][item0["attention_mask"] == 0] == tokenizer.pad_token_id) + # Verify that the output contains expected keys for no-padding mode + required_keys = ["input_ids", "position_ids", "loss_mask"] + for key in required_keys: + assert key in item0, f"Missing key {key} in no-padding mode dataset item" + assert isinstance(item0[key], 
torch.Tensor), f"Expected torch.Tensor for {key} in no-padding mode" # make sure assistant_text matches with expected - assistant_text = tokenizer.decode(item0["responses"][item0["response_mask"] == 1]) + assistant_text = tokenizer.decode(item0["input_ids"][item0["loss_mask"] == 1]) assert assistant_text == "2+2 equals 4.<|im_end|>\n4+4 equals 8.<|im_end|>\n" - # make sure responses are part of input_ids - assert torch.all(item0["input_ids"][-item0["responses"].shape[0] :] == item0["responses"]) - print("All tests passed!") print("Starting test...") diff --git a/verl/models/mcore/__init__.py b/verl/models/mcore/__init__.py index 29d053177..a0f6e76f3 100644 --- a/verl/models/mcore/__init__.py +++ b/verl/models/mcore/__init__.py @@ -16,6 +16,7 @@ from .registry import ( get_mcore_forward_fn, get_mcore_forward_fused_fn, + get_mcore_forward_no_padding_fn, get_mcore_weight_converter, hf_to_mcore_config, init_mcore_model, @@ -27,4 +28,5 @@ __all__ = [ "get_mcore_forward_fn", "get_mcore_weight_converter", "get_mcore_forward_fused_fn", + "get_mcore_forward_no_padding_fn", ] diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py index e70e11f4e..1c49a5e50 100644 --- a/verl/models/mcore/model_forward.py +++ b/verl/models/mcore/model_forward.py @@ -16,7 +16,14 @@ from verl.utils.megatron_utils import unwrap_model -from .util import postprocess_packed_seqs, preprocess_packed_seqs, recover_left_padding, remove_left_padding +from .util import ( + postprocess_packed_seqs, + postprocess_packed_seqs_no_padding, + preprocess_packed_seqs, + preprocess_packed_seqs_no_padding, + recover_left_padding, + remove_left_padding, +) def gptmodel_forward( @@ -146,3 +153,55 @@ def gptmodel_forward_qwen2_5_vl( if value_model and post_process: output = output[..., 0] return output + + +def gptmodel_forward_no_padding( + model, + input_ids, + value_model=False, + pack_seqs=True, + logits_processor=None, + logits_processor_args: dict = None, + **kwargs, +): + """Default forward pass for GPT models with optional sequence packing.""" + pre_process = unwrap_model(model).pre_process + post_process = unwrap_model(model).post_process + if pack_seqs: + batch_size = input_ids.shape[0] + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs_no_padding(input_ids, pre_process=pre_process) + input_ids_rmpad = input_ids_rmpad.contiguous() + output_orig = model( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=None, + packed_seq_params=packed_seq_params, + ) + + if post_process and logits_processor is not None: + args = { + k: preprocess_packed_seqs_no_padding(v, pre_process=True)[0] for k, v in logits_processor_args.items() + } + output_dict = logits_processor(output_orig, **args) + # print(f'gptmodel_forward_no_padding: {output_dict=}') + output = { + k: postprocess_packed_seqs_no_padding( + v, packed_seq_params, input_ids, batch_size, post_process=post_process + ) + for k, v in output_dict.items() + } + else: + output = postprocess_packed_seqs_no_padding( + output_orig, packed_seq_params, input_ids, batch_size, post_process=post_process + ) + else: + raise NotImplementedError("gptmodel_forward_no_padding only supports packed sequences") + + if value_model and post_process: + # output = output[..., 0] + # while using nested tensor, the advanced indexing operation above will result in an error at backward, i.e. 
+ # ValueError: NestedTensor _nested_select_backward_default(grad_output: t, self: jt_all, dim: any, index: any) + # so we use `squeeze` to remove the last dimension + output = output.squeeze(-1) + + return output diff --git a/verl/models/mcore/registry.py b/verl/models/mcore/registry.py index 1c70c1028..b40bb1a99 100644 --- a/verl/models/mcore/registry.py +++ b/verl/models/mcore/registry.py @@ -33,8 +33,15 @@ from .config_converter import ( hf_to_mcore_config_qwen2moe, hf_to_mcore_config_qwen3moe, ) -from .model_forward import gptmodel_forward, gptmodel_forward_qwen2_5_vl -from .model_forward_fused import fused_forward_gptmodel, fused_forward_qwen2_5_vl +from .model_forward import ( + gptmodel_forward, + gptmodel_forward_no_padding, + gptmodel_forward_qwen2_5_vl, +) +from .model_forward_fused import ( + fused_forward_gptmodel, + fused_forward_qwen2_5_vl, +) from .model_initializer import ( BaseModelInitializer, DeepseekV3Model, @@ -116,6 +123,23 @@ MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = { SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward, } +# Registry for model forward functions +MODEL_FORWARD_NOPAD_REGISTRY: dict[SupportedModel, Callable] = { + SupportedModel.LLAMA: gptmodel_forward_no_padding, + SupportedModel.QWEN2: gptmodel_forward_no_padding, + SupportedModel.QWEN2_MOE: gptmodel_forward_no_padding, + SupportedModel.MIXTRAL: gptmodel_forward_no_padding, + SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding, + SupportedModel.QWEN2_5_VL: gptmodel_forward_no_padding, + SupportedModel.LLAMA4: gptmodel_forward_no_padding, + SupportedModel.QWEN3: gptmodel_forward_no_padding, + SupportedModel.QWEN3_MOE: gptmodel_forward_no_padding, + # SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl, + SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding, + SupportedModel.GLM4_MOE: gptmodel_forward_no_padding, + SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward_no_padding, +} + # Registry for model forward functions MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = { SupportedModel.LLAMA: fused_forward_gptmodel, @@ -220,6 +244,15 @@ def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable: return MODEL_FORWARD_REGISTRY[model] +def get_mcore_forward_no_padding_fn(hf_config: PretrainedConfig) -> Callable: + """ + Get the forward function for given model architecture. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + model = get_supported_model(hf_config.architectures[0]) + return MODEL_FORWARD_NOPAD_REGISTRY[model] + + def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable: """ Get the forward function for given model architecture. diff --git a/verl/models/mcore/util.py b/verl/models/mcore/util.py index 9904fc60d..a022c2c29 100644 --- a/verl/models/mcore/util.py +++ b/verl/models/mcore/util.py @@ -162,6 +162,151 @@ def postprocess_packed_seqs( return output_new +def preprocess_packed_seqs_no_padding( + input_ids: torch.Tensor, pre_process: bool = True +) -> tuple[torch.Tensor, PackedSeqParams]: + """ + Preprocess packed sequences + CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 + gets second and second last chunks, and so on), this is for load balancing with causal masking. 
+ See https://github.com/NVIDIA/TransformerEngine/issues/1368 + """ + batch_size = input_ids.shape[0] + + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size + seqlens_in_batch = input_ids.offsets().diff() + + pad_size = (align_size - seqlens_in_batch % align_size) % align_size + seqlens_in_batch_padded = seqlens_in_batch + pad_size + + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0) + cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0) + + # ---------------------------------------------------------------------------- + # Move the index information needed in the subsequent loop to the CPU at once, + # to avoid frequent .item() calls in the loop that cause D2H synchronization + # ---------------------------------------------------------------------------- + seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist() # original valid lengths + seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist() # lengths after padding + cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist() # start positions (after padding) + + # Pure Python int calculation to avoid further synchronization + max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu) + + shape = list(input_ids.shape[1:]) + shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size + if pre_process: + input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device) + for i in range(batch_size): + # Use Python int, so no GPU→CPU sync in the loop + if cp_size <= 1: + seqlen = seqlens_in_batch_cpu[i] + start_idx = cu_seqlens_padded_cpu[i] + input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i] + + seqlen_padded_i = seqlens_in_batch_padded_cpu[i] + seqlen = seqlen_padded_i // cp_size + half_seqlen = seqlen // 2 + start_idx = cu_seqlens_padded_cpu[i] // cp_size + # split to 2 chunks + d = input_ids[i] + input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[ + half_seqlen * cp_rank : half_seqlen * (cp_rank + 1) + ] + + remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1) + remain_end = seqlen_padded_i - half_seqlen * cp_rank + remain_end = min(remain_end, d.shape[0]) + remain_len = remain_end - remain_start + if remain_len > 0: + input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[ + remain_start:remain_end + ] + + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + max_seqlen_q=max_seqlen_in_batch, + cu_seqlens_kv=cu_seqlens_padded, + max_seqlen_kv=max_seqlen_in_batch, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + ) + if pre_process: + return input_ids_rmpad.unsqueeze(0), packed_seq_params + else: + return input_ids, packed_seq_params + + +def postprocess_packed_seqs_no_padding( + output: torch.Tensor, + packed_seq_params: PackedSeqParams, + input_ids: torch.Tensor, + batch_size: int, + post_process: bool = True, +) -> torch.Tensor: + """ + Postprocess packed sequences + """ + if not post_process: + return output + + # ------------------------------------------------------------------------- + # Move the lengths and offsets needed for subsequent Python-level indexing to the CPU in advance, + # to avoid a large number of 
.item() calls in the loop + # ------------------------------------------------------------------------- + cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist() + # The reason why we use input_ids.offsets() instead of packed_seq_params.cu_seqlens_q.diff() + # is that the latter one is the padded length, while the former one is the original length. + cu_seqlens = input_ids.offsets() + seq_lens_cpu: list[int] = cu_seqlens.diff().tolist() + + output_new = [] + + cp_size = mpu.get_context_parallel_world_size() + # all gather output across context parallel group + if cp_size > 1: + # output shape: [1, packed_len, hidden_dim] + # need to gather across cp group and concatenate in sequence dimension + output_list = [torch.empty_like(output) for _ in range(cp_size)] + torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group()) + output_list[mpu.get_context_parallel_rank()] = output + else: + output_list = [output] + + for i in range(batch_size): + if cp_size <= 1: + s = seq_lens_cpu[i] + start_idx = cu_padded_cpu[i] + output_new.append(output[0][start_idx : start_idx + s]) + continue + s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size + half_seqlen = s_len_padded_chunk // 2 + s_len = seq_lens_cpu[i] + s_len_padded = s_len_padded_chunk * cp_size + tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device) + for j in range(cp_size): + o = output_list[j][0] + # split to 2 chunks + packed_start_idx = cu_padded_cpu[i] // cp_size + o0, o1 = ( + o[packed_start_idx : packed_start_idx + half_seqlen], + o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk], + ) + tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0 + tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1 + output_new.append(tmp[:s_len]) + + output_new_tensor = torch.nested.as_nested_tensor(output_new, layout=torch.jagged) + + return output_new_tensor + + def remove_left_padding( input_ids: torch.Tensor, attention_mask: torch.Tensor, diff --git a/verl/trainer/config/sft_trainer_engine.yaml b/verl/trainer/config/sft_trainer_engine.yaml index 1cc4b32fe..b699690fe 100644 --- a/verl/trainer/config/sft_trainer_engine.yaml +++ b/verl/trainer/config/sft_trainer_engine.yaml @@ -23,12 +23,9 @@ data: messages_key: messages # Key for messages list in multi-turn mode tools_key: tools # Key for tools list in multi-turn mode enable_thinking_key: enable_thinking # Whether to enable thinking in multi-turn mode - pad_mode: left_right + pad_mode: no_padding # for right padding max_length: 1024 - # for left right padding - max_prompt_length: 512 - max_response_length: 512 truncation: error balance_dp_token: False # to be implement custom_cls: diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py index e669e292b..58583c6a8 100644 --- a/verl/utils/dataset/multiturn_sft_dataset.py +++ b/verl/utils/dataset/multiturn_sft_dataset.py @@ -29,8 +29,6 @@ from transformers import PreTrainedTokenizer from verl.utils import hf_tokenizer from verl.utils.dataset.dataset_utils import DatasetPadMode from verl.utils.fs import copy_local_path_from_hdfs -from verl.utils.model import compute_position_id_with_mask -from verl.utils.torch_functional import pad_sequence_to_length, postprocess_data def convert_nested_value_to_list_recursive(data_item): @@ -55,15 +53,12 @@ class MultiTurnSFTDataset(Dataset): # Set defaults and extract parameters from config if provided config = config or {} self.pad_mode 
= config.get("pad_mode", "right") - assert self.pad_mode in ["right", "left_right", "no_padding"], ( - f"Expect pad_mode to be 'right', 'left_right' or 'no_padding'. Got {self.pad_mode}" + assert self.pad_mode in ["right", "no_padding"], ( + f"Expect pad_mode to be 'right' or 'no_padding'. Got {self.pad_mode}" ) self.truncation = config.get("truncation", "error") # for right padding self.max_length = config.get("max_length", 1024) - # for left right paddding to be consistent with RL - self.max_prompt_length = config.get("max_prompt_length", 512) - self.max_response_length = config.get("max_response_length", 512) # Get messages_key from the new multiturn config structure multiturn_config = config.get("multiturn", {}) self.messages_key = multiturn_config.get("messages_key", "messages") @@ -320,10 +315,8 @@ class MultiTurnSFTDataset(Dataset): if messages[0]["role"] == "system": assert messages[1]["role"] == "user" assert messages[2]["role"] == "assistant" - prompt_message_length = 2 elif messages[0]["role"] == "user": assert messages[1]["role"] == "assistant" - prompt_message_length = 1 else: raise ValueError(f"Unknown role: {messages[0]['role']}") @@ -365,68 +358,6 @@ class MultiTurnSFTDataset(Dataset): "position_ids": position_ids, "loss_mask": loss_mask, } - elif self.pad_mode == DatasetPadMode.LEFT_RIGHT: - assert self.truncation == "error", "Only support error truncation for left_right pad mode" - prompt_str = self.tokenizer.apply_chat_template( - messages[:prompt_message_length], - tools=tools, - tokenize=False, - add_generation_prompt=True, - enable_thinking=enable_thinking, - **self.apply_chat_template_kwargs, - ) - prompt_ids = self.tokenizer.encode(prompt_str, add_special_tokens=False) - prompt_length = len(prompt_ids) - prompt_ids = input_ids[:prompt_length].unsqueeze(0) - prompt_attention_mask = attention_mask[:prompt_length].unsqueeze(0) - prompt_loss_mask = loss_mask[:prompt_length].unsqueeze(0) - response_ids = input_ids[prompt_length:].unsqueeze(0) - response_attention_mask = attention_mask[prompt_length:].unsqueeze(0) - response_loss_mask = loss_mask[prompt_length:].unsqueeze(0) - - assert prompt_loss_mask.sum().item() == 0 - - prompt_ids, prompt_attention_mask = postprocess_data( - input_ids=prompt_ids, - attention_mask=prompt_attention_mask, - max_length=self.max_prompt_length, - pad_token_id=self.tokenizer.pad_token_id, - left_pad=True, - truncation=self.truncation, - ) - - response_ids, response_attention_mask = postprocess_data( - input_ids=response_ids, - attention_mask=response_attention_mask, - max_length=self.max_response_length, - pad_token_id=self.tokenizer.pad_token_id, - left_pad=False, - truncation=self.truncation, - ) - response_loss_mask = pad_sequence_to_length( - response_loss_mask, max_seq_len=self.max_response_length, pad_token_id=0, left_pad=False - ) - - prompt_ids = prompt_ids[0] - prompt_attention_mask = prompt_attention_mask[0] - response_ids = response_ids[0] - response_attention_mask = response_attention_mask[0] - response_loss_mask = response_loss_mask[0] - - assert response_attention_mask[0].item() == 1 - assert response_loss_mask[0].item() == 1 - - input_ids = torch.cat((prompt_ids, response_ids), dim=0) - attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=0) - position_ids = compute_position_id_with_mask(attention_mask) - - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - "responses": response_ids, - "response_mask": response_loss_mask, - } elif self.pad_mode 
== DatasetPadMode.NO_PADDING: # truncate input_ids if it is longer than max_length if len(input_ids) > self.max_length: @@ -440,3 +371,5 @@ class MultiTurnSFTDataset(Dataset): "position_ids": position_ids, "loss_mask": loss_mask, } + else: + raise ValueError(f"Unknown pad mode {self.pad_mode}") diff --git a/verl/workers/engine/fsdp/transformer_impl.py b/verl/workers/engine/fsdp/transformer_impl.py index c89f2cf0c..0430544f8 100644 --- a/verl/workers/engine/fsdp/transformer_impl.py +++ b/verl/workers/engine/fsdp/transformer_impl.py @@ -35,7 +35,6 @@ from verl.models.transformers.monkey_patch import apply_monkey_patch from verl.trainer.config import CheckpointConfig from verl.utils import tensordict_utils as tu from verl.utils.activation_offload import enable_activation_offloading -from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.utils.dataset.dataset_utils import DatasetPadMode from verl.utils.debug import log_gpu_memory_usage @@ -703,10 +702,12 @@ class EngineTrainModeCtx: class FSDPEngineWithLMHead(FSDPEngine): def prepare_model_inputs(self, micro_batch: TensorDict): use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True) - pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT) + pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING) use_fused_kernels = tu.get_non_tensor_data(data=micro_batch, key="use_fused_kernels", default=False) temperature = micro_batch["temperature"] + assert pad_mode == DatasetPadMode.NO_PADDING, f"pad_mode {pad_mode} not supported" + multi_modal_inputs = {} if "multi_modal_inputs" in micro_batch.keys(): from verl.utils.model import extract_multi_modal_inputs @@ -726,32 +727,6 @@ class FSDPEngineWithLMHead(FSDPEngine): if pad_mode == DatasetPadMode.NO_PADDING: input_ids_rmpad = input_ids.values().unsqueeze(0) # (1, total_nnz) position_ids_rmpad = position_ids.values().unsqueeze(0) # (1, total_nnz) - elif pad_mode == DatasetPadMode.LEFT_RIGHT: - attention_mask = micro_batch["attention_mask"] - input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - output_args["indices"] = indices - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - if position_ids.dim() == 3: - position_ids_rmpad = ( - index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) - .transpose(0, 1) - .unsqueeze(1) - ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) - else: - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), indices - ).transpose(0, 1) - - if "image_bound" in multi_modal_inputs: - from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo - - multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( - input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs - ) else: raise NotImplementedError(f"pad_mode {pad_mode} not implemented") @@ -821,13 +796,6 @@ class FSDPEngineWithLMHead(FSDPEngine): attention_mask, padding=0, output_size=(batch_size, max_seq_len) ) - model_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - } - elif pad_mode == DatasetPadMode.LEFT_RIGHT: - attention_mask = micro_batch["attention_mask"] model_inputs = { "input_ids": input_ids, "attention_mask": attention_mask, @@ -848,7 +816,7 @@ class FSDPEngineWithLMHead(FSDPEngine): def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict): use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True) - pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT) + pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING) use_fused_kernels = tu.get_non_tensor_data(data=micro_batch, key="use_fused_kernels", default=False) temperature = micro_batch["temperature"] calculate_entropy = tu.get_non_tensor_data(data=micro_batch, key="calculate_entropy", default=False) @@ -906,33 +874,10 @@ class FSDPEngineWithLMHead(FSDPEngine): if pad_mode == DatasetPadMode.NO_PADDING: cu_seqlens = input_ids.offsets() - # (bsz, j1) for each sample, is the length of each sample: [real_prompt length + real_response length] + # (bsz, j1), for each sample, is the length of each sample: [real_prompt length + real_response length] log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens) if calculate_entropy: entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens) - elif pad_mode == DatasetPadMode.LEFT_RIGHT: - indices = output_args["indices"] - response_length = micro_batch["responses"].size(-1) - batch_size, seqlen = input_ids.shape - full_log_probs = pad_input( - hidden_states=log_probs.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ) - log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) - - # pad back to (bsz, seqlen) - if calculate_entropy: - full_entropy = pad_input( - hidden_states=entropy_rmpad.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ) - # only return response part: - if calculate_entropy: - entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) else: raise NotImplementedError(f"pad_mode {pad_mode} not implemented") @@ -960,17 +905,12 @@ class FSDPEngineWithLMHead(FSDPEngine): logits_rmpad = torch.cat([t for t in logits.unbind()]) input_ids_rmpad_rolled = output_args["input_ids_rmpad_rolled"] log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) - # (bsz, j1) for each sample, length of each sample: [real_prompt_length + real_response_length] + # (bsz, j1), for each sample, length of each sample: [real_prompt_length + real_response_length] log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens) if calculate_entropy: entropy = torch.nested.narrow(entropy, 1, starts, seq_lengths, layout=torch.jagged) entropy_rmpad = torch.cat([t for t in entropy.unbind()]) entropy = 
torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens) - elif pad_mode == DatasetPadMode.LEFT_RIGHT: - logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size) - log_probs = logprobs_from_logits(logits, micro_batch["responses"]) - if calculate_entropy: - entropy = entropy[:, -response_length - 1 : -1] # (bsz, response_length) else: raise NotImplementedError(f"pad_mode {pad_mode} not implemented") @@ -1022,7 +962,7 @@ class FSDPEngineWithValueHead(FSDPEngineWithLMHead): def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict): use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True) - response_length = micro_batch["responses"].size(-1) + pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING) if use_remove_padding: input_ids = micro_batch["input_ids"] @@ -1033,24 +973,38 @@ class FSDPEngineWithValueHead(FSDPEngineWithLMHead): values_rmpad = output[2].squeeze(0).unsqueeze(-1) else: values_rmpad = output.logits - values_rmpad = values_rmpad.squeeze(0) # (total_nnz) - - indices = output_args["indices"] + values_rmpad = values_rmpad.squeeze(0) # (total_nnz, 1) + # FIXME(houmin): confirm why should we squeeze here + values_rmpad = values_rmpad.squeeze(-1) # gather output if sp > 1 if self.use_ulysses_sp: pad_size = output_args["pad_size"] values_rmpad = gather_outputs_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size) - # pad it back - values = pad_input(values_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1) - values = values[:, -response_length - 1 : -1] + if pad_mode == DatasetPadMode.NO_PADDING: + cu_seqlens = input_ids.offsets() + # (bsz, j1), for each sample, is the length of each sample: [real_prompt length + real_response length] + values = torch.nested.nested_tensor_from_jagged(values_rmpad, cu_seqlens) + else: + raise NotImplementedError(f"pad_mode {pad_mode} not implemented") + else: if hasattr(self.module, "v_head"): # For trl.AutoModelForCausalLMWithValueHead values = output[2] else: values = output.logits - values = values[:, -response_length - 1 : -1].squeeze(-1) + + if pad_mode == DatasetPadMode.NO_PADDING: + cu_seqlens = input_ids.offsets() + seq_lengths = cu_seqlens.diff() + starts = torch.zeros_like(seq_lengths, dtype=torch.int64) + values = torch.nested.narrow(values, 1, starts, seq_lengths, layout=torch.jagged) + values_rmpad = torch.cat([t for t in values.unbind()]) + # (bsz, j1), for each sample, length of each sample: [real_prompt_length + real_response_length] + values = torch.nested.nested_tensor_from_jagged(values_rmpad, cu_seqlens) + else: + raise NotImplementedError(f"pad_mode {pad_mode} not implemented") return {"values": values} diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py index 1b40ec0cb..1bd1bcddd 100644 --- a/verl/workers/engine/megatron/transformer_impl.py +++ b/verl/workers/engine/megatron/transformer_impl.py @@ -28,6 +28,7 @@ from verl.models.mcore import get_mcore_weight_converter from verl.trainer.config import CheckpointConfig from verl.utils import tensordict_utils as tu from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager +from verl.utils.dataset.dataset_utils import DatasetPadMode from verl.utils.device import get_device_id, get_device_name from verl.utils.megatron.pipeline_parallel import make_batch_generator from verl.utils.megatron.tensor_parallel import 
vocab_parallel_entropy, vocab_parallel_log_probs_from_logits @@ -512,11 +513,10 @@ class MegatronEngineWithLMHead(MegatronEngine): batch = batch.to(get_device_id()) batch = batch.contiguous() input_ids = batch["input_ids"] - attention_mask = batch["attention_mask"].to(bool) + loss_mask = batch["loss_mask"].to(bool) position_ids = batch["position_ids"] # process vlm inputs - batch["attention_mask"] = batch["attention_mask"].to(bool) has_multi_modal_inputs = "multi_modal_inputs" in batch.keys() if has_multi_modal_inputs: batch["multi_modal_inputs"] = batch["multi_modal_inputs"] @@ -538,20 +538,18 @@ class MegatronEngineWithLMHead(MegatronEngine): return { "input_ids": input_ids, - "attention_mask": attention_mask, + "loss_mask": loss_mask, "position_ids": position_ids, "multi_modal_inputs": multi_modal_inputs, } def prepare_model_outputs(self, output: dict, data: TensorDict): calculate_entropy = tu.get_non_tensor_data(data, key="calculate_entropy", default=False) - responses = data["responses"] - response_length = responses.size(1) - log_prob = output["log_probs"][:, -response_length - 1 : -1].contiguous() + log_prob = output["log_probs"] model_output = {"log_probs": log_prob} if calculate_entropy: - entropy = output["entropy"][:, -response_length - 1 : -1].contiguous() + entropy = output["entropy"] model_output["entropy"] = entropy return model_output @@ -561,74 +559,58 @@ class MegatronEngineWithLMHead(MegatronEngine): batch = batch.to(get_device_id()) use_fused_kernels = tu.get_non_tensor_data(batch, key="use_fused_kernels", default=False) calculate_entropy = tu.get_non_tensor_data(batch, key="calculate_entropy", default=False) + pad_mode = tu.get_non_tensor_data(batch, key="pad_mode", default=DatasetPadMode.NO_PADDING) temperature = batch["temperature"] model_inputs = self.prepare_model_inputs(batch) input_ids = model_inputs["input_ids"] - attention_mask = model_inputs["attention_mask"] - position_ids = model_inputs["position_ids"] multi_modal_inputs = model_inputs["multi_modal_inputs"] - responses = batch["responses"] - response_length = responses.size(1) - label = position_ids.clone() - label[:, -response_length - 1 : -1] = responses - label_mask = attention_mask.clone() - label_mask[:, : -response_length - 1] = False - label_mask[:, -1] = False + if pad_mode == DatasetPadMode.NO_PADDING: + label = input_ids.clone() + else: + raise NotImplementedError(f"Pad mode {pad_mode} is not supported for megatron engine") - from verl.models.mcore import get_mcore_forward_fn, get_mcore_forward_fused_fn + from verl.models.mcore import get_mcore_forward_no_padding_fn if use_fused_kernels: - forward_fn = get_mcore_forward_fused_fn(self.model_config.hf_config) - # return dict of [logits, entropy] - output = forward_fn( - model, - input_ids, - position_ids, - attention_mask, - sequence_parallel=self.tf_config.sequence_parallel, - multi_modal_inputs=multi_modal_inputs, - labels=label, - labels_mask=label_mask, - temperature=temperature, - ) - else: - forward_fn = get_mcore_forward_fn(self.model_config.hf_config) + raise NotImplementedError("Fused kernels are not supported for megatron engine") - def logits_processor(logits, label, label_mask): - assert logits.shape[:2] == label.shape[:2] - assert label.shape == label_mask.shape - logits.div_(temperature) - ret = {} - if calculate_entropy: - logits_bak = logits.clone() - if torch.distributed.get_rank() == 0: - logger.warning_once( - "For memory-efficient computation, enable fused kernels via " - "`actor_rollout_ref.model.use_fused_kernels=True`. 
" - "The current `clone()` operation ensures correctness but increases memory usage." - ) - entropy = vocab_parallel_entropy(logits) - ret["entropy"] = entropy - else: - logits_bak = logits - log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label) - log_probs = log_probs.masked_fill(~label_mask, 0.0) - ret["log_probs"] = log_probs - return ret + forward_fn = get_mcore_forward_no_padding_fn(self.model_config.hf_config) - logits_processor_args = {"label": label, "label_mask": label_mask} - output = forward_fn( - model, - input_ids, - attention_mask, - position_ids, - sequence_parallel=self.tf_config.sequence_parallel, - multi_modal_inputs=multi_modal_inputs, - logits_processor=logits_processor, - logits_processor_args=logits_processor_args, - ) + def logits_processor(logits, label): + assert logits.shape[:2] == label.shape[:2] + logits.div_(temperature) + ret = {} + if calculate_entropy: + logits_bak = logits.clone() + if torch.distributed.get_rank() == 0: + logger.warning_once( + "For memory-efficient computation, enable fused kernels via " + "`actor_rollout_ref.model.use_fused_kernels=True`. " + "The current `clone()` operation ensures correctness but increases memory usage." + ) + entropy = vocab_parallel_entropy(logits) + ret["entropy"] = entropy + else: + logits_bak = logits + + # FIXME(houmin): maybe shift label in another place + label = torch.roll(label, shifts=-1, dims=1) + + log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label) + ret["log_probs"] = log_probs + return ret + + logits_processor_args = {"label": label} + + output = forward_fn( + model, + input_ids, + multi_modal_inputs=multi_modal_inputs, + logits_processor=logits_processor, + logits_processor_args=logits_processor_args, + ) return output, partial(postprocess_micro_batch_func, data=batch) @@ -671,19 +653,15 @@ class MegatronEngineWithValueHead(MegatronEngineWithLMHead): batch = batch.to(get_device_id()) model_inputs = self.prepare_model_inputs(batch) input_ids = model_inputs["input_ids"] - attention_mask = model_inputs["attention_mask"] - position_ids = model_inputs["position_ids"] multi_modal_inputs = model_inputs["multi_modal_inputs"] - from verl.models.mcore import get_mcore_forward_fn + from verl.models.mcore import get_mcore_forward_no_padding_fn - forward_fn = get_mcore_forward_fn(self.model_config.hf_config) + forward_fn = get_mcore_forward_no_padding_fn(self.model_config.hf_config) output = forward_fn( model, input_ids, - attention_mask, - position_ids, sequence_parallel=self.tf_config.sequence_parallel, multi_modal_inputs=multi_modal_inputs, value_model=True, @@ -692,7 +670,4 @@ class MegatronEngineWithValueHead(MegatronEngineWithLMHead): return output, partial(postprocess_micro_batch_func, data=batch) def prepare_model_outputs(self, output: dict | torch.Tensor, data: TensorDict): - responses = data["responses"] - response_length = responses.size(1) - output = output[:, -response_length - 1 : -1].contiguous() return {"values": output} diff --git a/verl/workers/engine/utils.py b/verl/workers/engine/utils.py index 59658e93f..cbb990c33 100644 --- a/verl/workers/engine/utils.py +++ b/verl/workers/engine/utils.py @@ -66,7 +66,8 @@ def postprocess_batch_func(output_lst, indices, data: TensorDict): """ use_dynamic_bsz = tu.get_non_tensor_data(data=data, key="use_dynamic_bsz", default=True) - pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT) + pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.NO_PADDING) + assert 
pad_mode == DatasetPadMode.NO_PADDING, "postprocess_batch_func only support NO_PADDING pad_mode" # losses_reduced is a list of dict containing outputs for each micro-batch # reorder entropy and outputs. Return None for other pp ranks @@ -92,8 +93,6 @@ def postprocess_batch_func(output_lst, indices, data: TensorDict): if pad_mode == DatasetPadMode.NO_PADDING: tensors = [tensor for nt in model_output[key] for tensor in nt.unbind()] model_output[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged) - elif pad_mode == DatasetPadMode.LEFT_RIGHT: - model_output[key] = torch.cat(model_output[key], dim=0) else: raise NotImplementedError(f"pad_mode {pad_mode} not implemented") diff --git a/verl/workers/roles/actor.py b/verl/workers/roles/actor.py index b9b70de0e..2ad25f6df 100644 --- a/verl/workers/roles/actor.py +++ b/verl/workers/roles/actor.py @@ -33,6 +33,7 @@ from verl.utils.profiler import DistProfiler, DistProfilerExtension from verl.utils.py_functional import append_to_dict from verl.workers.config import ActorConfig from verl.workers.roles.utils.losses import ppo_loss +from verl.workers.roles.utils.padding import left_right_2_no_padding, no_padding_2_padding logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -116,16 +117,23 @@ class ActorWorker(Worker, DistProfilerExtension): with self.engine.eval_mode(): # TODO: make worker API to accept TensorDict as well data = data.to_tensordict() + data = left_right_2_no_padding(data) output = self.engine.infer_batch(data) if self.engine.is_mp_src_rank_with_outputs(): output = output["model_output"] + log_probs = output["log_probs"] + log_probs = no_padding_2_padding(log_probs, data) # (bsz, response_length) + + entropy = output["entropy"] + if entropy is not None: + entropy = no_padding_2_padding(entropy, data) # (bsz, response_length) + # in megatron, only last pp contains valid data and returned to the single controller output = DataProto.from_dict( - tensors={"old_log_probs": output["log_probs"].float(), "entropy": output["entropy"].float()}, + tensors={"old_log_probs": log_probs.float(), "entropy": entropy.float()}, ) output = output.to("cpu") - return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @@ -155,6 +163,7 @@ class ActorWorker(Worker, DistProfilerExtension): mini_batch.meta_info["global_batch_size"] = self.config.ppo_mini_batch_size # TODO: make worker API to accept TensorDict as well mini_batch = mini_batch.to_tensordict() + mini_batch = left_right_2_no_padding(mini_batch) output = self.engine.train_batch(mini_batch, self.loss_fn) mini_batch_metrics = output.get("metrics", {}) append_to_dict(metrics, mini_batch_metrics, prefix="actor/") diff --git a/verl/workers/roles/critic.py b/verl/workers/roles/critic.py index 24b644149..d848dc8bb 100644 --- a/verl/workers/roles/critic.py +++ b/verl/workers/roles/critic.py @@ -35,6 +35,7 @@ from verl.utils.profiler import DistProfiler, DistProfilerExtension from verl.utils.py_functional import append_to_dict from verl.workers.config import CriticConfig from verl.workers.roles.utils.losses import value_loss +from verl.workers.roles.utils.padding import left_right_2_no_padding, no_padding_2_padding logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -140,13 +141,17 @@ class CriticWorker(Worker, DistProfilerExtension): with self.engine.eval_mode(): # TODO: make worker API to accept TensorDict as well data = data.to_tensordict() + data = 
left_right_2_no_padding(data) output = self.engine.infer_batch(data) if self.engine.is_mp_src_rank_with_outputs(): # in megatron, only last pp contains valid data and returned to the single controller output = output["model_output"] + values = output["values"] + values = no_padding_2_padding(values, data) # (bsz, response_length) + output = DataProto.from_dict( - tensors={"values": output["values"].float()}, + tensors={"values": values.float()}, ) output = output.to("cpu") @@ -177,6 +182,7 @@ class CriticWorker(Worker, DistProfilerExtension): mini_batch.meta_info["global_batch_size"] = self.config.ppo_mini_batch_size # TODO: make worker API to accept TensorDict as well mini_batch = mini_batch.to_tensordict() + mini_batch = left_right_2_no_padding(mini_batch) output = self.engine.train_batch(mini_batch, self.loss_fn) mini_batch_metrics = output.get("metrics", {}) append_to_dict(metrics, mini_batch_metrics, prefix="critic/") diff --git a/verl/workers/roles/utils/losses.py b/verl/workers/roles/utils/losses.py index 91bf3e33a..3693e5eb4 100644 --- a/verl/workers/roles/utils/losses.py +++ b/verl/workers/roles/utils/losses.py @@ -21,10 +21,11 @@ from verl.utils import tensordict_utils as tu from verl.utils.dataset.dataset_utils import DatasetPadMode from verl.utils.torch_functional import masked_mean from verl.workers.config import ActorConfig, CriticConfig +from verl.workers.roles.utils.padding import no_padding_2_padding def sft_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None): - pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT) + pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.NO_PADDING) log_prob = model_output["log_probs"] @@ -52,6 +53,10 @@ def ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None) log_prob = model_output["log_probs"] entropy = model_output.get("entropy", None) + log_prob = no_padding_2_padding(log_prob, data) # (bsz, response_length) + if entropy is not None: + entropy = no_padding_2_padding(entropy, data) # (bsz, response_length) + metrics = {} response_mask = data["response_mask"].to(bool) @@ -105,7 +110,7 @@ def ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None) def value_loss(config: CriticConfig, model_output, data: TensorDict, dp_group=None): vpreds = model_output["values"] - values = data["values"] + vpreds = no_padding_2_padding(vpreds, data) # (bsz, response_length) values = data["values"] returns = data["returns"] diff --git a/verl/workers/roles/utils/padding.py b/verl/workers/roles/utils/padding.py new file mode 100644 index 000000000..01b82dcbf --- /dev/null +++ b/verl/workers/roles/utils/padding.py @@ -0,0 +1,119 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from tensordict import TensorDict + +from verl.utils import tensordict_utils as tu +from verl.utils.device import ( + is_cuda_available, + is_npu_available, +) + +if is_cuda_available: + from flash_attn.bert_padding import pad_input, unpad_input +elif is_npu_available: + from transformers.integrations.npu_flash_attention import pad_input, unpad_input + + +def left_right_2_no_padding(data: TensorDict) -> TensorDict: + """ + Convert TensorDict from left-right padding to no-padding format. + + Args: + data: TensorDict with "input_ids", "attention_mask", "response_mask", "position_ids" + + Returns: + data: TensorDict with + - Tensor includes NestedTensors like "input_ids", "loss_mask", "position_ids" + - NonTensorData includes "max_seq_len", "max_response_len", "indices" + + Note: + 1. the return input_ids/position_ids/loss_mask are nested tensor. + 2. we will remove "attention_mask", "response" in the return data, but "response_mask" is kept. + """ + assert "input_ids" in data, "input_ids is required in left-right padding data" + assert "attention_mask" in data, "attention_mask is required in left-right padding data" + assert "response_mask" in data, "response_mask is required in left-right padding data" + assert "position_ids" in data, "position_ids is required in left-right padding data" + + input_ids = data.pop("input_ids") + attention_mask = data.pop("attention_mask") + response_mask = data["response_mask"] + if "responses" in data: + _ = data.pop("responses") + + max_seq_len, max_response_len = input_ids.shape[1], response_mask.shape[1] + tu.assign_non_tensor_data(data, "max_seq_len", max_seq_len) + tu.assign_non_tensor_data(data, "max_response_len", max_response_len) + + input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) + tu.assign_non_tensor_data(data, "indices", indices) + + input_ids_nested = torch.nested.nested_tensor_from_jagged(input_ids_rmpad.squeeze(-1), offsets=cu_seqlens) + + seq_lens = cu_seqlens.diff().tolist() + response_lens = response_mask.sum(dim=1).tolist() + + position_ids_list = [] + loss_mask_list = [] + for seq_len, response_len in zip(seq_lens, response_lens, strict=False): + position_ids_list.append(torch.arange(seq_len, device=input_ids.device)) + loss_mask = torch.zeros(seq_len, dtype=torch.bool, device=input_ids.device) + assert seq_len >= response_len, f"{seq_len=} is less than {response_len=}" + loss_mask[-response_len:] = 1 + loss_mask_list.append(loss_mask) + + position_ids_nested = torch.nested.as_nested_tensor(position_ids_list, layout=torch.jagged) + loss_mask_nested = torch.nested.as_nested_tensor(loss_mask_list, layout=torch.jagged) + + data["input_ids"] = input_ids_nested + data["position_ids"] = position_ids_nested + data["loss_mask"] = loss_mask_nested + + return data + + +def no_padding_2_padding(nested_tensor: torch.Tensor, data: TensorDict) -> torch.Tensor: + """ + Convert NestedTensor from no-padding to right padding format. 
+ + Args: + nested_tensor: NestedTensor with no-padding format + data: TensorDict with + - Tensor includes NestedTensors like "input_ids", "loss_mask", "position_ids" + - NonTensorData includes "max_seq_len", "max_response_len", "indices" + + Returns: + values: regular tensor right padded to max_response_len + """ + assert "indices" in data, "indices is required in left-right padding data" + assert "max_seq_len" in data, "max_seq_len is required in left-right padding data" + assert "max_response_len" in data, "max_response_len is required in left-right padding data" + + indices = tu.get_non_tensor_data(data=data, key="indices", default=None) + max_seq_len = tu.get_non_tensor_data(data=data, key="max_seq_len", default=2048) + max_response_len = tu.get_non_tensor_data(data=data, key="max_response_len", default=1024) + batch_size = nested_tensor.size(0) + + values = nested_tensor.values() + full_values = pad_input( + hidden_states=values.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=max_seq_len, + ) + values = full_values.squeeze(-1)[:, -max_response_len - 1 : -1] # (bsz, response_length) + + return values
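
As a rough illustration of the transformation layer described in the PR summary (and implemented in `verl/workers/roles/utils/padding.py` above), here is a minimal pure-PyTorch sketch of the left_right -> no_padding -> right-padding round trip. The token values are made up and the final slicing is simplified; the real `no_padding_2_padding` also applies the one-token shift used for next-token log-probs.

```python
import torch

# Toy left_right-padded batch (pad id 0, made-up token values): prompts are
# left-padded, responses are right-padded, as the workers still receive them.
input_ids = torch.tensor([[0, 0, 11, 12, 21, 22, 0, 0],   # prompt [11, 12], response [21, 22]
                          [0, 13, 14, 15, 23, 0, 0, 0]])  # prompt [13, 14, 15], response [23]
attention_mask = (input_ids != 0).long()
response_mask = torch.tensor([[0, 0, 0, 0, 1, 1, 0, 0],
                              [0, 0, 0, 0, 1, 0, 0, 0]])

# left_right -> no_padding: pack only the valid tokens of each sample into a
# jagged nested tensor (one variable-length row per sample).
rows = [ids[mask.bool()] for ids, mask in zip(input_ids, attention_mask)]
input_ids_nested = torch.nested.as_nested_tensor(rows, layout=torch.jagged)
cu_seqlens = input_ids_nested.offsets()          # tensor([0, 4, 8])

# The engine consumes input_ids_nested.values() (total_nnz tokens) and returns
# per-token outputs in the same jagged layout, e.g. log-probs:
log_probs = torch.nested.nested_tensor_from_jagged(
    torch.randn(input_ids_nested.values().numel()), offsets=cu_seqlens
)

# no_padding -> right padding: pad each jagged row back to a dense
# (bsz, max_response_len) tensor so downstream PPO code keeps its old response view.
max_response_len = int(response_mask.sum(dim=1).max())
padded = torch.zeros(input_ids.size(0), max_response_len)
for i, row in enumerate(log_probs.unbind()):
    n = int(response_mask[i].sum())
    padded[i, :n] = row[-n:]                     # real helper also shifts by one token
```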
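The docstring of `preprocess_packed_seqs_no_padding` above describes the context-parallel 2-chunks-per-rank layout used for load balancing with causal masking. Below is a small sketch of just that chunk assignment, assuming the padded length divides evenly by `cp_size * 2`; the helper name is hypothetical and not part of the patch.

```python
import torch

def cp_shard(seq: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    """Return the two chunks of `seq` owned by `cp_rank`: the sequence is cut
    into cp_size * 2 equal chunks and rank r keeps chunk r and chunk
    (2 * cp_size - 1 - r), balancing causal-attention work across ranks."""
    chunks = seq.chunk(cp_size * 2)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]])

seq = torch.arange(16)          # a single, already padded sequence
for r in range(2):              # cp_size = 2 -> 4 chunks of length 4
    print(r, cp_shard(seq, cp_size=2, cp_rank=r).tolist())
# rank 0 -> chunks 0 and 3: [0..3] + [12..15]
# rank 1 -> chunks 1 and 2: [4..7] + [8..11]
```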