Mirror of https://github.com/volcengine/verl.git, synced 2025-10-20 13:43:50 +08:00
[misc] feat: prototype deprecate DataProto and replace with Tensordict: part 3 (#3600)
### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

This PR continues the work started in PR #3567 by deprecating and removing the left_right padding mode:

1. Implement the no-padding mode for the Megatron engine in the SFT trainer, using nested tensors.
2. Deprecate the left_right padding mode for the FSDP/Megatron engines.
3. Introduce a transformation layer within the Actor/Critic workers; see more [here](https://github.com/volcengine/verl/blob/main/docs/workers/model_engine.rst).
   - **Input Format**: Actor/Critic workers continue to receive data in the left_right-padded format.
   - **Transformation**: This layer dynamically converts left_right-padded data into the no-padding format using nested tensors.
   - **Engine Format**: The FSDP and Megatron engines now operate exclusively on the no-padding data format by default.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
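As a reviewer's illustration of the no-padding format this PR converts to (not part of the diff; token values and the helper-free conversion are made up), real tokens are kept per sample and packed into a jagged `torch.nested` tensor whose offsets encode the true sequence lengths:

```python
import torch

# Two left/right-padded samples (0 is the pad id here).
input_ids = torch.tensor([
    [0, 0, 11, 12, 21, 22, 0, 0],
    [0, 13, 14, 15, 23, 24, 25, 0],
])
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1, 1, 0, 0],
    [0, 1, 1, 1, 1, 1, 1, 0],
])

# Drop the padding and pack each sample's real tokens into a jagged nested tensor.
sequences = [ids[mask.bool()] for ids, mask in zip(input_ids, attention_mask)]
nested = torch.nested.as_nested_tensor(sequences, layout=torch.jagged)

print(nested.offsets())         # tensor([ 0,  4, 10]) -> per-sample boundaries
print(nested.offsets().diff())  # tensor([4, 6])       -> true sequence lengths
```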
@@ -29,7 +29,7 @@ PP_SIZE=${PP_SIZE:-1}
VPP_SIZE=${VPP_SIZE:-null}
CP_SIZE=${CP_SIZE:-1}

PAD_MODE=${PAD_MODE:-left_right}
PAD_MODE=${PAD_MODE:-no_padding}

USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}

@@ -80,8 +80,6 @@ torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \
    data.train_files="${TRAIN_FILES}" \
    data.val_files="${VAL_FILES}" \
    data.train_batch_size=256 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    data.pad_mode=${PAD_MODE} \
    data.truncation=error \
    data.use_dynamic_bsz=True \
@@ -9,15 +9,6 @@ echo "run with single gpu as golden"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh

# test with fsdp 1
echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh

echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
@@ -27,18 +18,14 @@ BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_pa
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh

# test use_remove_padding and pad_mode left_right/no_padding
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode left_right use_remove_padding False"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
# test use_remove_padding and pad_mode no_padding
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding use_remove_padding False"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh

# test with fsdp 2
echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2 pad_mode left_right"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2 pad_mode no_padding"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh

echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
@@ -50,6 +37,7 @@ BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/spe
# test with megatron
echo "run with tp1 pp1 cp1 num_gpus1"
BACKEND=megatron TP_SIZE=1 PP_SIZE=1 CP_SIZE=1 NUM_GPUS=1 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh

echo "run with tp2 pp2 vpp2 cp1 num_gpus8"
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=2 CP_SIZE=1 NUM_GPUS=8 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
@@ -177,28 +177,26 @@ def test_multiturn_sft_dataset():
    assert torch.all(padded_item["attention_mask"][actual_length:] == 0), "Attention mask not set correctly for padding"
    assert torch.all(padded_item["loss_mask"][actual_length:] == 0), "Loss mask not set correctly for padding"

    # test left right padding
    # test no-padding
    config = {
        "max_length": 512,
        "truncation": "error",
        "multiturn": {"messages_key": "messages"},
        "pad_mode": "left_right",
        "max_prompt_length": 64,
        "max_response_length": 64,
        "pad_mode": "no_padding",
    }
    dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config)

    item0 = dataset[0]

    # make sure all the input_ids with attention_mask == 0 are all padding
    assert torch.all(item0["input_ids"][item0["attention_mask"] == 0] == tokenizer.pad_token_id)
    # Verify that the output contains expected keys for no-padding mode
    required_keys = ["input_ids", "position_ids", "loss_mask"]
    for key in required_keys:
        assert key in item0, f"Missing key {key} in no-padding mode dataset item"
        assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key} in no-padding mode"

    # make sure assistant_text matches with expected
    assistant_text = tokenizer.decode(item0["responses"][item0["response_mask"] == 1])
    assistant_text = tokenizer.decode(item0["input_ids"][item0["loss_mask"] == 1])
    assert assistant_text == "2+2 equals 4.<|im_end|>\n4+4 equals 8.<|im_end|>\n"

    # make sure responses are part of input_ids
    assert torch.all(item0["input_ids"][-item0["responses"].shape[0] :] == item0["responses"])

    print("All tests passed!")
    print("Starting test...")
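As a hedged illustration of the `no_padding` items exercised above (toy tensors, not the dataset's real output), variable-length items can be collated into jagged nested tensors instead of being padded to a fixed length:

```python
import torch

items = [
    {"input_ids": torch.tensor([1, 2, 3, 4]), "loss_mask": torch.tensor([0, 0, 1, 1])},
    {"input_ids": torch.tensor([5, 6, 7]), "loss_mask": torch.tensor([0, 1, 1])},
]

input_ids = torch.nested.as_nested_tensor([it["input_ids"] for it in items], layout=torch.jagged)
loss_mask = torch.nested.as_nested_tensor([it["loss_mask"] for it in items], layout=torch.jagged)

# Tokens that contribute to the SFT loss, per sample.
for ids, mask in zip(input_ids.unbind(), loss_mask.unbind()):
    print(ids[mask.bool()])  # tensor([3, 4]) then tensor([6, 7])
```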
@@ -16,6 +16,7 @@
from .registry import (
    get_mcore_forward_fn,
    get_mcore_forward_fused_fn,
    get_mcore_forward_no_padding_fn,
    get_mcore_weight_converter,
    hf_to_mcore_config,
    init_mcore_model,
@@ -27,4 +28,5 @@ __all__ = [
    "get_mcore_forward_fn",
    "get_mcore_weight_converter",
    "get_mcore_forward_fused_fn",
    "get_mcore_forward_no_padding_fn",
]
@@ -16,7 +16,14 @@

from verl.utils.megatron_utils import unwrap_model

from .util import postprocess_packed_seqs, preprocess_packed_seqs, recover_left_padding, remove_left_padding
from .util import (
    postprocess_packed_seqs,
    postprocess_packed_seqs_no_padding,
    preprocess_packed_seqs,
    preprocess_packed_seqs_no_padding,
    recover_left_padding,
    remove_left_padding,
)


def gptmodel_forward(
@@ -146,3 +153,55 @@ def gptmodel_forward_qwen2_5_vl(
    if value_model and post_process:
        output = output[..., 0]
    return output


def gptmodel_forward_no_padding(
    model,
    input_ids,
    value_model=False,
    pack_seqs=True,
    logits_processor=None,
    logits_processor_args: dict = None,
    **kwargs,
):
    """Default forward pass for GPT models with optional sequence packing."""
    pre_process = unwrap_model(model).pre_process
    post_process = unwrap_model(model).post_process
    if pack_seqs:
        batch_size = input_ids.shape[0]
        input_ids_rmpad, packed_seq_params = preprocess_packed_seqs_no_padding(input_ids, pre_process=pre_process)
        input_ids_rmpad = input_ids_rmpad.contiguous()
        output_orig = model(
            input_ids=input_ids_rmpad,
            attention_mask=None,
            position_ids=None,
            packed_seq_params=packed_seq_params,
        )

        if post_process and logits_processor is not None:
            args = {
                k: preprocess_packed_seqs_no_padding(v, pre_process=True)[0] for k, v in logits_processor_args.items()
            }
            output_dict = logits_processor(output_orig, **args)
            # print(f'gptmodel_forward_no_padding: {output_dict=}')
            output = {
                k: postprocess_packed_seqs_no_padding(
                    v, packed_seq_params, input_ids, batch_size, post_process=post_process
                )
                for k, v in output_dict.items()
            }
        else:
            output = postprocess_packed_seqs_no_padding(
                output_orig, packed_seq_params, input_ids, batch_size, post_process=post_process
            )
    else:
        raise NotImplementedError("gptmodel_forward_no_padding only supports packed sequences")

    if value_model and post_process:
        # output = output[..., 0]
        # while using nested tensor, the advanced indexing operation above will result in an error at backward, i.e.
        # ValueError: NestedTensor _nested_select_backward_default(grad_output: t, self: jt_all, dim: any, index: any)
        # so we use `squeeze` to remove the last dimension
        output = output.squeeze(-1)

    return output
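A small sketch of the value-head detail noted in the comments above (toy tensors; the real output comes from the Megatron model): `squeeze(-1)` drops the trailing value dimension of a jagged nested tensor, which the code uses instead of `output[..., 0]` because that advanced indexing fails in backward for NestedTensors.

```python
import torch

# Per-sample value predictions with a trailing singleton dim, packed as a jagged nested tensor.
values = torch.nested.as_nested_tensor(
    [torch.randn(4, 1), torch.randn(7, 1)], layout=torch.jagged
)
print(values.shape)              # torch.Size([2, j1, 1])
print(values.squeeze(-1).shape)  # torch.Size([2, j1])
```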
@@ -33,8 +33,15 @@ from .config_converter import (
    hf_to_mcore_config_qwen2moe,
    hf_to_mcore_config_qwen3moe,
)
from .model_forward import gptmodel_forward, gptmodel_forward_qwen2_5_vl
from .model_forward_fused import fused_forward_gptmodel, fused_forward_qwen2_5_vl
from .model_forward import (
    gptmodel_forward,
    gptmodel_forward_no_padding,
    gptmodel_forward_qwen2_5_vl,
)
from .model_forward_fused import (
    fused_forward_gptmodel,
    fused_forward_qwen2_5_vl,
)
from .model_initializer import (
    BaseModelInitializer,
    DeepseekV3Model,
@@ -116,6 +123,23 @@ MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = {
    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward,
}

# Registry for model forward functions
MODEL_FORWARD_NOPAD_REGISTRY: dict[SupportedModel, Callable] = {
    SupportedModel.LLAMA: gptmodel_forward_no_padding,
    SupportedModel.QWEN2: gptmodel_forward_no_padding,
    SupportedModel.QWEN2_MOE: gptmodel_forward_no_padding,
    SupportedModel.MIXTRAL: gptmodel_forward_no_padding,
    SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding,
    SupportedModel.QWEN2_5_VL: gptmodel_forward_no_padding,
    SupportedModel.LLAMA4: gptmodel_forward_no_padding,
    SupportedModel.QWEN3: gptmodel_forward_no_padding,
    SupportedModel.QWEN3_MOE: gptmodel_forward_no_padding,
    # SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,
    SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding,
    SupportedModel.GLM4_MOE: gptmodel_forward_no_padding,
    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward_no_padding,
}

# Registry for model forward functions
MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = {
    SupportedModel.LLAMA: fused_forward_gptmodel,
@@ -220,6 +244,15 @@ def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:
    return MODEL_FORWARD_REGISTRY[model]


def get_mcore_forward_no_padding_fn(hf_config: PretrainedConfig) -> Callable:
    """
    Get the forward function for given model architecture.
    """
    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
    model = get_supported_model(hf_config.architectures[0])
    return MODEL_FORWARD_NOPAD_REGISTRY[model]


def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable:
    """
    Get the forward function for given model architecture.
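A minimal usage sketch for the new registry entry (the model name is only an example):

```python
from transformers import AutoConfig

from verl.models.mcore import get_mcore_forward_no_padding_fn

hf_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
forward_fn = get_mcore_forward_no_padding_fn(hf_config)  # resolves to gptmodel_forward_no_padding
```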
@@ -162,6 +162,151 @@ def postprocess_packed_seqs(
    return output_new


def preprocess_packed_seqs_no_padding(
    input_ids: torch.Tensor, pre_process: bool = True
) -> tuple[torch.Tensor, PackedSeqParams]:
    """
    Preprocess packed sequences
    CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1
    gets second and second last chunks, and so on), this is for load balancing with causal masking.
    See https://github.com/NVIDIA/TransformerEngine/issues/1368
    """
    batch_size = input_ids.shape[0]

    tp_size = mpu.get_tensor_model_parallel_world_size()
    cp_size = mpu.get_context_parallel_world_size()
    cp_rank = mpu.get_context_parallel_rank()
    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size
    seqlens_in_batch = input_ids.offsets().diff()

    pad_size = (align_size - seqlens_in_batch % align_size) % align_size
    seqlens_in_batch_padded = seqlens_in_batch + pad_size

    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
    cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0)
    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
    cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)

    # ----------------------------------------------------------------------------
    # Move the index information needed in the subsequent loop to the CPU at once,
    # to avoid frequent .item() calls in the loop that cause D2H synchronization
    # ----------------------------------------------------------------------------
    seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist()  # original valid lengths
    seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist()  # lengths after padding
    cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist()  # start positions (after padding)

    # Pure Python int calculation to avoid further synchronization
    max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu)

    shape = list(input_ids.shape[1:])
    shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size
    if pre_process:
        input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
        for i in range(batch_size):
            # Use Python int, so no GPU→CPU sync in the loop
            if cp_size <= 1:
                seqlen = seqlens_in_batch_cpu[i]
                start_idx = cu_seqlens_padded_cpu[i]
                input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i]

            seqlen_padded_i = seqlens_in_batch_padded_cpu[i]
            seqlen = seqlen_padded_i // cp_size
            half_seqlen = seqlen // 2
            start_idx = cu_seqlens_padded_cpu[i] // cp_size
            # split to 2 chunks
            d = input_ids[i]
            input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[
                half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)
            ]

            remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1)
            remain_end = seqlen_padded_i - half_seqlen * cp_rank
            remain_end = min(remain_end, d.shape[0])
            remain_len = remain_end - remain_start
            if remain_len > 0:
                input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[
                    remain_start:remain_end
                ]

    packed_seq_params = PackedSeqParams(
        qkv_format="thd",
        cu_seqlens_q=cu_seqlens_padded,
        max_seqlen_q=max_seqlen_in_batch,
        cu_seqlens_kv=cu_seqlens_padded,
        max_seqlen_kv=max_seqlen_in_batch,
        cu_seqlens_q_padded=cu_seqlens_padded,
        cu_seqlens_kv_padded=cu_seqlens_padded,
    )
    if pre_process:
        return input_ids_rmpad.unsqueeze(0), packed_seq_params
    else:
        return input_ids, packed_seq_params


def postprocess_packed_seqs_no_padding(
    output: torch.Tensor,
    packed_seq_params: PackedSeqParams,
    input_ids: torch.Tensor,
    batch_size: int,
    post_process: bool = True,
) -> torch.Tensor:
    """
    Postprocess packed sequences
    """
    if not post_process:
        return output

    # -------------------------------------------------------------------------
    # Move the lengths and offsets needed for subsequent Python-level indexing to the CPU in advance,
    # to avoid a large number of .item() calls in the loop
    # -------------------------------------------------------------------------
    cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist()
    # The reason why we use input_ids.offsets() instead of packed_seq_params.cu_seqlens_q.diff()
    # is that the latter one is the padded length, while the former one is the original length.
    cu_seqlens = input_ids.offsets()
    seq_lens_cpu: list[int] = cu_seqlens.diff().tolist()

    output_new = []

    cp_size = mpu.get_context_parallel_world_size()
    # all gather output across context parallel group
    if cp_size > 1:
        # output shape: [1, packed_len, hidden_dim]
        # need to gather across cp group and concatenate in sequence dimension
        output_list = [torch.empty_like(output) for _ in range(cp_size)]
        torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())
        output_list[mpu.get_context_parallel_rank()] = output
    else:
        output_list = [output]

    for i in range(batch_size):
        if cp_size <= 1:
            s = seq_lens_cpu[i]
            start_idx = cu_padded_cpu[i]
            output_new.append(output[0][start_idx : start_idx + s])
            continue
        s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size
        half_seqlen = s_len_padded_chunk // 2
        s_len = seq_lens_cpu[i]
        s_len_padded = s_len_padded_chunk * cp_size
        tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device)
        for j in range(cp_size):
            o = output_list[j][0]
            # split to 2 chunks
            packed_start_idx = cu_padded_cpu[i] // cp_size
            o0, o1 = (
                o[packed_start_idx : packed_start_idx + half_seqlen],
                o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk],
            )
            tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0
            tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1
        output_new.append(tmp[:s_len])

    output_new_tensor = torch.nested.as_nested_tensor(output_new, layout=torch.jagged)

    return output_new_tensor


def remove_left_padding(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
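A hedged sketch of the length bookkeeping above (standalone toy, no Megatron parallel state): per-sample lengths come from the jagged tensor's offsets and are padded up to a multiple of `align_size` before building the cumulative `cu_seqlens_padded`.

```python
import torch

input_ids = torch.nested.as_nested_tensor(
    [torch.arange(5), torch.arange(11)], layout=torch.jagged
)
tp_size, cp_size = 2, 1
align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size

seqlens_in_batch = input_ids.offsets().diff()                          # tensor([ 5, 11])
pad_size = (align_size - seqlens_in_batch % align_size) % align_size   # tensor([1, 1])
seqlens_padded = seqlens_in_batch + pad_size                           # tensor([ 6, 12])

cu_seqlens_padded = torch.zeros(len(seqlens_padded) + 1, dtype=torch.int32)
cu_seqlens_padded[1:] = torch.cumsum(seqlens_padded, dim=0)
print(cu_seqlens_padded)  # tensor([ 0,  6, 18], dtype=torch.int32)
```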
@@ -23,12 +23,9 @@ data:
  messages_key: messages # Key for messages list in multi-turn mode
  tools_key: tools # Key for tools list in multi-turn mode
  enable_thinking_key: enable_thinking # Whether to enable thinking in multi-turn mode
  pad_mode: left_right
  pad_mode: no_padding
  # for right padding
  max_length: 1024
  # for left right padding
  max_prompt_length: 512
  max_response_length: 512
  truncation: error
  balance_dp_token: False # to be implement
  custom_cls:
@@ -29,8 +29,6 @@ from transformers import PreTrainedTokenizer
from verl.utils import hf_tokenizer
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.model import compute_position_id_with_mask
from verl.utils.torch_functional import pad_sequence_to_length, postprocess_data


def convert_nested_value_to_list_recursive(data_item):
@@ -55,15 +53,12 @@ class MultiTurnSFTDataset(Dataset):
        # Set defaults and extract parameters from config if provided
        config = config or {}
        self.pad_mode = config.get("pad_mode", "right")
        assert self.pad_mode in ["right", "left_right", "no_padding"], (
            f"Expect pad_mode to be 'right', 'left_right' or 'no_padding'. Got {self.pad_mode}"
        assert self.pad_mode in ["right", "no_padding"], (
            f"Expect pad_mode to be 'right' or 'no_padding'. Got {self.pad_mode}"
        )
        self.truncation = config.get("truncation", "error")
        # for right padding
        self.max_length = config.get("max_length", 1024)
        # for left right paddding to be consistent with RL
        self.max_prompt_length = config.get("max_prompt_length", 512)
        self.max_response_length = config.get("max_response_length", 512)
        # Get messages_key from the new multiturn config structure
        multiturn_config = config.get("multiturn", {})
        self.messages_key = multiturn_config.get("messages_key", "messages")
@@ -320,10 +315,8 @@ class MultiTurnSFTDataset(Dataset):
        if messages[0]["role"] == "system":
            assert messages[1]["role"] == "user"
            assert messages[2]["role"] == "assistant"
            prompt_message_length = 2
        elif messages[0]["role"] == "user":
            assert messages[1]["role"] == "assistant"
            prompt_message_length = 1
        else:
            raise ValueError(f"Unknown role: {messages[0]['role']}")

@@ -365,68 +358,6 @@ class MultiTurnSFTDataset(Dataset):
                "position_ids": position_ids,
                "loss_mask": loss_mask,
            }
        elif self.pad_mode == DatasetPadMode.LEFT_RIGHT:
            assert self.truncation == "error", "Only support error truncation for left_right pad mode"
            prompt_str = self.tokenizer.apply_chat_template(
                messages[:prompt_message_length],
                tools=tools,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=enable_thinking,
                **self.apply_chat_template_kwargs,
            )
            prompt_ids = self.tokenizer.encode(prompt_str, add_special_tokens=False)
            prompt_length = len(prompt_ids)
            prompt_ids = input_ids[:prompt_length].unsqueeze(0)
            prompt_attention_mask = attention_mask[:prompt_length].unsqueeze(0)
            prompt_loss_mask = loss_mask[:prompt_length].unsqueeze(0)
            response_ids = input_ids[prompt_length:].unsqueeze(0)
            response_attention_mask = attention_mask[prompt_length:].unsqueeze(0)
            response_loss_mask = loss_mask[prompt_length:].unsqueeze(0)

            assert prompt_loss_mask.sum().item() == 0

            prompt_ids, prompt_attention_mask = postprocess_data(
                input_ids=prompt_ids,
                attention_mask=prompt_attention_mask,
                max_length=self.max_prompt_length,
                pad_token_id=self.tokenizer.pad_token_id,
                left_pad=True,
                truncation=self.truncation,
            )

            response_ids, response_attention_mask = postprocess_data(
                input_ids=response_ids,
                attention_mask=response_attention_mask,
                max_length=self.max_response_length,
                pad_token_id=self.tokenizer.pad_token_id,
                left_pad=False,
                truncation=self.truncation,
            )
            response_loss_mask = pad_sequence_to_length(
                response_loss_mask, max_seq_len=self.max_response_length, pad_token_id=0, left_pad=False
            )

            prompt_ids = prompt_ids[0]
            prompt_attention_mask = prompt_attention_mask[0]
            response_ids = response_ids[0]
            response_attention_mask = response_attention_mask[0]
            response_loss_mask = response_loss_mask[0]

            assert response_attention_mask[0].item() == 1
            assert response_loss_mask[0].item() == 1

            input_ids = torch.cat((prompt_ids, response_ids), dim=0)
            attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=0)
            position_ids = compute_position_id_with_mask(attention_mask)

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "responses": response_ids,
                "response_mask": response_loss_mask,
            }
        elif self.pad_mode == DatasetPadMode.NO_PADDING:
            # truncate input_ids if it is longer than max_length
            if len(input_ids) > self.max_length:
@@ -440,3 +371,5 @@ class MultiTurnSFTDataset(Dataset):
                "position_ids": position_ids,
                "loss_mask": loss_mask,
            }
        else:
            raise ValueError(f"Unknown pad mode {self.pad_mode}")
@@ -35,7 +35,6 @@ from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.trainer.config import CheckpointConfig
from verl.utils import tensordict_utils as tu
from verl.utils.activation_offload import enable_activation_offloading
from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.debug import log_gpu_memory_usage
@@ -703,10 +702,12 @@ class EngineTrainModeCtx:
class FSDPEngineWithLMHead(FSDPEngine):
    def prepare_model_inputs(self, micro_batch: TensorDict):
        use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
        pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
        pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)
        use_fused_kernels = tu.get_non_tensor_data(data=micro_batch, key="use_fused_kernels", default=False)
        temperature = micro_batch["temperature"]

        assert pad_mode == DatasetPadMode.NO_PADDING, f"pad_mode {pad_mode} not supported"

        multi_modal_inputs = {}
        if "multi_modal_inputs" in micro_batch.keys():
            from verl.utils.model import extract_multi_modal_inputs
@@ -726,32 +727,6 @@ class FSDPEngineWithLMHead(FSDPEngine):
        if pad_mode == DatasetPadMode.NO_PADDING:
            input_ids_rmpad = input_ids.values().unsqueeze(0)  # (1, total_nnz)
            position_ids_rmpad = position_ids.values().unsqueeze(0)  # (1, total_nnz)
        elif pad_mode == DatasetPadMode.LEFT_RIGHT:
            attention_mask = micro_batch["attention_mask"]
            input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(
                input_ids.unsqueeze(-1), attention_mask
            )  # input_ids_rmpad (total_nnz, ...)
            output_args["indices"] = indices
            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

            # unpad the position_ids to align the rotary
            if position_ids.dim() == 3:
                position_ids_rmpad = (
                    index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
                    .transpose(0, 1)
                    .unsqueeze(1)
                )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
            else:
                position_ids_rmpad = index_first_axis(
                    rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
                ).transpose(0, 1)

            if "image_bound" in multi_modal_inputs:
                from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo

                multi_modal_inputs = process_multi_modal_inputs_for_minicpmo(
                    input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs
                )
        else:
            raise NotImplementedError(f"pad_mode {pad_mode} not implemented")

@@ -821,13 +796,6 @@ class FSDPEngineWithLMHead(FSDPEngine):
                attention_mask, padding=0, output_size=(batch_size, max_seq_len)
            )

            model_inputs = {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
            }
        elif pad_mode == DatasetPadMode.LEFT_RIGHT:
            attention_mask = micro_batch["attention_mask"]
            model_inputs = {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
@@ -848,7 +816,7 @@ class FSDPEngineWithLMHead(FSDPEngine):

    def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict):
        use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
        pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
        pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)
        use_fused_kernels = tu.get_non_tensor_data(data=micro_batch, key="use_fused_kernels", default=False)
        temperature = micro_batch["temperature"]
        calculate_entropy = tu.get_non_tensor_data(data=micro_batch, key="calculate_entropy", default=False)
@@ -906,33 +874,10 @@ class FSDPEngineWithLMHead(FSDPEngine):

        if pad_mode == DatasetPadMode.NO_PADDING:
            cu_seqlens = input_ids.offsets()
            # (bsz, j1) for each sample, is the length of each sample: [real_prompt length + real_response length]
            # (bsz, j1), for each sample, is the length of each sample: [real_prompt length + real_response length]
            log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens)
            if calculate_entropy:
                entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens)
        elif pad_mode == DatasetPadMode.LEFT_RIGHT:
            indices = output_args["indices"]
            response_length = micro_batch["responses"].size(-1)
            batch_size, seqlen = input_ids.shape
            full_log_probs = pad_input(
                hidden_states=log_probs.unsqueeze(-1),
                indices=indices,
                batch=batch_size,
                seqlen=seqlen,
            )
            log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)

            # pad back to (bsz, seqlen)
            if calculate_entropy:
                full_entropy = pad_input(
                    hidden_states=entropy_rmpad.unsqueeze(-1),
                    indices=indices,
                    batch=batch_size,
                    seqlen=seqlen,
                )
            # only return response part:
            if calculate_entropy:
                entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)
        else:
            raise NotImplementedError(f"pad_mode {pad_mode} not implemented")

@@ -960,17 +905,12 @@ class FSDPEngineWithLMHead(FSDPEngine):
            logits_rmpad = torch.cat([t for t in logits.unbind()])
            input_ids_rmpad_rolled = output_args["input_ids_rmpad_rolled"]
            log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)
            # (bsz, j1) for each sample, length of each sample: [real_prompt_length + real_response_length]
            # (bsz, j1), for each sample, length of each sample: [real_prompt_length + real_response_length]
            log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens)
            if calculate_entropy:
                entropy = torch.nested.narrow(entropy, 1, starts, seq_lengths, layout=torch.jagged)
                entropy_rmpad = torch.cat([t for t in entropy.unbind()])
                entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens)
        elif pad_mode == DatasetPadMode.LEFT_RIGHT:
            logits = logits[:, -response_length - 1 : -1, :]  # (bsz, response_length, vocab_size)
            log_probs = logprobs_from_logits(logits, micro_batch["responses"])
            if calculate_entropy:
                entropy = entropy[:, -response_length - 1 : -1]  # (bsz, response_length)
        else:
            raise NotImplementedError(f"pad_mode {pad_mode} not implemented")

@@ -1022,7 +962,7 @@ class FSDPEngineWithValueHead(FSDPEngineWithLMHead):

    def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict):
        use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
        response_length = micro_batch["responses"].size(-1)
        pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)

        if use_remove_padding:
            input_ids = micro_batch["input_ids"]
@@ -1033,24 +973,38 @@ class FSDPEngineWithValueHead(FSDPEngineWithLMHead):
                values_rmpad = output[2].squeeze(0).unsqueeze(-1)
            else:
                values_rmpad = output.logits
                values_rmpad = values_rmpad.squeeze(0)  # (total_nnz)

            indices = output_args["indices"]
            values_rmpad = values_rmpad.squeeze(0)  # (total_nnz, 1)
            # FIXME(houmin): confirm why should we squeeze here
            values_rmpad = values_rmpad.squeeze(-1)

            # gather output if sp > 1
            if self.use_ulysses_sp:
                pad_size = output_args["pad_size"]
                values_rmpad = gather_outputs_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)

            # pad it back
            values = pad_input(values_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)
            values = values[:, -response_length - 1 : -1]
            if pad_mode == DatasetPadMode.NO_PADDING:
                cu_seqlens = input_ids.offsets()
                # (bsz, j1), for each sample, is the length of each sample: [real_prompt length + real_response length]
                values = torch.nested.nested_tensor_from_jagged(values_rmpad, cu_seqlens)
            else:
                raise NotImplementedError(f"pad_mode {pad_mode} not implemented")

        else:
            if hasattr(self.module, "v_head"):
                # For trl.AutoModelForCausalLMWithValueHead
                values = output[2]
            else:
                values = output.logits
            values = values[:, -response_length - 1 : -1].squeeze(-1)

            if pad_mode == DatasetPadMode.NO_PADDING:
                cu_seqlens = input_ids.offsets()
                seq_lengths = cu_seqlens.diff()
                starts = torch.zeros_like(seq_lengths, dtype=torch.int64)
                values = torch.nested.narrow(values, 1, starts, seq_lengths, layout=torch.jagged)
                values_rmpad = torch.cat([t for t in values.unbind()])
                # (bsz, j1), for each sample, length of each sample: [real_prompt_length + real_response_length]
                values = torch.nested.nested_tensor_from_jagged(values_rmpad, cu_seqlens)
            else:
                raise NotImplementedError(f"pad_mode {pad_mode} not implemented")

        return {"values": values}
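A hedged sketch of the pattern the FSDP engine uses above (toy shapes): per-token outputs computed on the flat, packed sequence are re-wrapped as a jagged nested tensor using the batch's offsets, so each sample gets back its own `[prompt + response]`-length slice.

```python
import torch

input_ids = torch.nested.as_nested_tensor(
    [torch.zeros(3, dtype=torch.long), torch.zeros(5, dtype=torch.long)], layout=torch.jagged
)
cu_seqlens = input_ids.offsets()               # tensor([0, 3, 8])

log_probs_flat = torch.randn(8)                # one value per packed token in the batch
log_probs = torch.nested.nested_tensor_from_jagged(log_probs_flat, cu_seqlens)
print([t.shape for t in log_probs.unbind()])   # [torch.Size([3]), torch.Size([5])]
```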
@@ -28,6 +28,7 @@ from verl.models.mcore import get_mcore_weight_converter
from verl.trainer.config import CheckpointConfig
from verl.utils import tensordict_utils as tu
from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.device import get_device_id, get_device_name
from verl.utils.megatron.pipeline_parallel import make_batch_generator
from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits
@@ -512,11 +513,10 @@ class MegatronEngineWithLMHead(MegatronEngine):
        batch = batch.to(get_device_id())
        batch = batch.contiguous()
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"].to(bool)
        loss_mask = batch["loss_mask"].to(bool)
        position_ids = batch["position_ids"]

        # process vlm inputs
        batch["attention_mask"] = batch["attention_mask"].to(bool)
        has_multi_modal_inputs = "multi_modal_inputs" in batch.keys()
        if has_multi_modal_inputs:
            batch["multi_modal_inputs"] = batch["multi_modal_inputs"]
@@ -538,20 +538,18 @@ class MegatronEngineWithLMHead(MegatronEngine):

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "loss_mask": loss_mask,
            "position_ids": position_ids,
            "multi_modal_inputs": multi_modal_inputs,
        }

    def prepare_model_outputs(self, output: dict, data: TensorDict):
        calculate_entropy = tu.get_non_tensor_data(data, key="calculate_entropy", default=False)
        responses = data["responses"]
        response_length = responses.size(1)

        log_prob = output["log_probs"][:, -response_length - 1 : -1].contiguous()
        log_prob = output["log_probs"]
        model_output = {"log_probs": log_prob}
        if calculate_entropy:
            entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()
            entropy = output["entropy"]
            model_output["entropy"] = entropy

        return model_output
@@ -561,74 +559,58 @@ class MegatronEngineWithLMHead(MegatronEngine):
        batch = batch.to(get_device_id())
        use_fused_kernels = tu.get_non_tensor_data(batch, key="use_fused_kernels", default=False)
        calculate_entropy = tu.get_non_tensor_data(batch, key="calculate_entropy", default=False)
        pad_mode = tu.get_non_tensor_data(batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)
        temperature = batch["temperature"]

        model_inputs = self.prepare_model_inputs(batch)
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs["attention_mask"]
        position_ids = model_inputs["position_ids"]
        multi_modal_inputs = model_inputs["multi_modal_inputs"]

        responses = batch["responses"]
        response_length = responses.size(1)
        label = position_ids.clone()
        label[:, -response_length - 1 : -1] = responses
        label_mask = attention_mask.clone()
        label_mask[:, : -response_length - 1] = False
        label_mask[:, -1] = False
        if pad_mode == DatasetPadMode.NO_PADDING:
            label = input_ids.clone()
        else:
            raise NotImplementedError(f"Pad mode {pad_mode} is not supported for megatron engine")

        from verl.models.mcore import get_mcore_forward_fn, get_mcore_forward_fused_fn
        from verl.models.mcore import get_mcore_forward_no_padding_fn

        if use_fused_kernels:
            forward_fn = get_mcore_forward_fused_fn(self.model_config.hf_config)
            # return dict of [logits, entropy]
            output = forward_fn(
                model,
                input_ids,
                position_ids,
                attention_mask,
                sequence_parallel=self.tf_config.sequence_parallel,
                multi_modal_inputs=multi_modal_inputs,
                labels=label,
                labels_mask=label_mask,
                temperature=temperature,
            )
        else:
            forward_fn = get_mcore_forward_fn(self.model_config.hf_config)
            raise NotImplementedError("Fused kernels are not supported for megatron engine")

            def logits_processor(logits, label, label_mask):
                assert logits.shape[:2] == label.shape[:2]
                assert label.shape == label_mask.shape
                logits.div_(temperature)
                ret = {}
                if calculate_entropy:
                    logits_bak = logits.clone()
                    if torch.distributed.get_rank() == 0:
                        logger.warning_once(
                            "For memory-efficient computation, enable fused kernels via "
                            "`actor_rollout_ref.model.use_fused_kernels=True`. "
                            "The current `clone()` operation ensures correctness but increases memory usage."
                        )
                    entropy = vocab_parallel_entropy(logits)
                    ret["entropy"] = entropy
                else:
                    logits_bak = logits
                log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label)
                log_probs = log_probs.masked_fill(~label_mask, 0.0)
                ret["log_probs"] = log_probs
                return ret
        forward_fn = get_mcore_forward_no_padding_fn(self.model_config.hf_config)

            logits_processor_args = {"label": label, "label_mask": label_mask}
            output = forward_fn(
                model,
                input_ids,
                attention_mask,
                position_ids,
                sequence_parallel=self.tf_config.sequence_parallel,
                multi_modal_inputs=multi_modal_inputs,
                logits_processor=logits_processor,
                logits_processor_args=logits_processor_args,
            )
        def logits_processor(logits, label):
            assert logits.shape[:2] == label.shape[:2]
            logits.div_(temperature)
            ret = {}
            if calculate_entropy:
                logits_bak = logits.clone()
                if torch.distributed.get_rank() == 0:
                    logger.warning_once(
                        "For memory-efficient computation, enable fused kernels via "
                        "`actor_rollout_ref.model.use_fused_kernels=True`. "
                        "The current `clone()` operation ensures correctness but increases memory usage."
                    )
                entropy = vocab_parallel_entropy(logits)
                ret["entropy"] = entropy
            else:
                logits_bak = logits

            # FIXME(houmin): maybe shift label in another place
            label = torch.roll(label, shifts=-1, dims=1)

            log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label)
            ret["log_probs"] = log_probs
            return ret

        logits_processor_args = {"label": label}

        output = forward_fn(
            model,
            input_ids,
            multi_modal_inputs=multi_modal_inputs,
            logits_processor=logits_processor,
            logits_processor_args=logits_processor_args,
        )

        return output, partial(postprocess_micro_batch_func, data=batch)

@@ -671,19 +653,15 @@ class MegatronEngineWithValueHead(MegatronEngineWithLMHead):
        batch = batch.to(get_device_id())
        model_inputs = self.prepare_model_inputs(batch)
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs["attention_mask"]
        position_ids = model_inputs["position_ids"]
        multi_modal_inputs = model_inputs["multi_modal_inputs"]

        from verl.models.mcore import get_mcore_forward_fn
        from verl.models.mcore import get_mcore_forward_no_padding_fn

        forward_fn = get_mcore_forward_fn(self.model_config.hf_config)
        forward_fn = get_mcore_forward_no_padding_fn(self.model_config.hf_config)

        output = forward_fn(
            model,
            input_ids,
            attention_mask,
            position_ids,
            sequence_parallel=self.tf_config.sequence_parallel,
            multi_modal_inputs=multi_modal_inputs,
            value_model=True,
@@ -692,7 +670,4 @@ class MegatronEngineWithValueHead(MegatronEngineWithLMHead):
        return output, partial(postprocess_micro_batch_func, data=batch)

    def prepare_model_outputs(self, output: dict | torch.Tensor, data: TensorDict):
        responses = data["responses"]
        response_length = responses.size(1)
        output = output[:, -response_length - 1 : -1].contiguous()
        return {"values": output}
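A small sketch of the label handling in the new `logits_processor` (toy tensor): the labels for next-token prediction are the input ids shifted left by one via `torch.roll`; the wrapped-around final position is not a real target and is handled downstream by the loss mask.

```python
import torch

input_ids = torch.tensor([[10, 11, 12, 13]])
label = torch.roll(input_ids, shifts=-1, dims=1)
print(label)  # tensor([[11, 12, 13, 10]]) -- last position wraps around and is masked later
```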
@@ -66,7 +66,8 @@ def postprocess_batch_func(output_lst, indices, data: TensorDict):
    """

    use_dynamic_bsz = tu.get_non_tensor_data(data=data, key="use_dynamic_bsz", default=True)
    pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
    pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.NO_PADDING)
    assert pad_mode == DatasetPadMode.NO_PADDING, "postprocess_batch_func only support NO_PADDING pad_mode"

    # losses_reduced is a list of dict containing outputs for each micro-batch
    # reorder entropy and outputs. Return None for other pp ranks
@@ -92,8 +93,6 @@ def postprocess_batch_func(output_lst, indices, data: TensorDict):
            if pad_mode == DatasetPadMode.NO_PADDING:
                tensors = [tensor for nt in model_output[key] for tensor in nt.unbind()]
                model_output[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
            elif pad_mode == DatasetPadMode.LEFT_RIGHT:
                model_output[key] = torch.cat(model_output[key], dim=0)
            else:
                raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
@@ -33,6 +33,7 @@ from verl.utils.profiler import DistProfiler, DistProfilerExtension
from verl.utils.py_functional import append_to_dict
from verl.workers.config import ActorConfig
from verl.workers.roles.utils.losses import ppo_loss
from verl.workers.roles.utils.padding import left_right_2_no_padding, no_padding_2_padding

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -116,16 +117,23 @@ class ActorWorker(Worker, DistProfilerExtension):
        with self.engine.eval_mode():
            # TODO: make worker API to accept TensorDict as well
            data = data.to_tensordict()
            data = left_right_2_no_padding(data)
            output = self.engine.infer_batch(data)

        if self.engine.is_mp_src_rank_with_outputs():
            output = output["model_output"]
            log_probs = output["log_probs"]
            log_probs = no_padding_2_padding(log_probs, data)  # (bsz, response_length)

            entropy = output["entropy"]
            if entropy is not None:
                entropy = no_padding_2_padding(entropy, data)  # (bsz, response_length)

            # in megatron, only last pp contains valid data and returned to the single controller
            output = DataProto.from_dict(
                tensors={"old_log_probs": output["log_probs"].float(), "entropy": output["entropy"].float()},
                tensors={"old_log_probs": log_probs.float(), "entropy": entropy.float()},
            )
            output = output.to("cpu")

        return output

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@@ -155,6 +163,7 @@ class ActorWorker(Worker, DistProfilerExtension):
            mini_batch.meta_info["global_batch_size"] = self.config.ppo_mini_batch_size
            # TODO: make worker API to accept TensorDict as well
            mini_batch = mini_batch.to_tensordict()
            mini_batch = left_right_2_no_padding(mini_batch)
            output = self.engine.train_batch(mini_batch, self.loss_fn)
            mini_batch_metrics = output.get("metrics", {})
            append_to_dict(metrics, mini_batch_metrics, prefix="actor/")
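Distilled from the worker code above, this is the shape of the transformation layer at the worker boundary (a sketch only; `engine` and the batch contents are assumed, and the pipeline-parallel source-rank gating is omitted):

```python
from verl.workers.roles.utils.padding import left_right_2_no_padding, no_padding_2_padding


def compute_log_prob_sketch(engine, data_proto):
    data = data_proto.to_tensordict()      # left_right-padded batch from the controller
    data = left_right_2_no_padding(data)   # jagged input_ids/position_ids/loss_mask + "indices" metadata
    output = engine.infer_batch(data)["model_output"]
    # convert the jagged per-token output back to a (bsz, response_length) right-padded tensor
    return no_padding_2_padding(output["log_probs"], data)
```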
@@ -35,6 +35,7 @@ from verl.utils.profiler import DistProfiler, DistProfilerExtension
from verl.utils.py_functional import append_to_dict
from verl.workers.config import CriticConfig
from verl.workers.roles.utils.losses import value_loss
from verl.workers.roles.utils.padding import left_right_2_no_padding, no_padding_2_padding

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -140,13 +141,17 @@ class CriticWorker(Worker, DistProfilerExtension):
        with self.engine.eval_mode():
            # TODO: make worker API to accept TensorDict as well
            data = data.to_tensordict()
            data = left_right_2_no_padding(data)
            output = self.engine.infer_batch(data)

        if self.engine.is_mp_src_rank_with_outputs():
            # in megatron, only last pp contains valid data and returned to the single controller
            output = output["model_output"]
            values = output["values"]
            values = no_padding_2_padding(values, data)  # (bsz, response_length)

            output = DataProto.from_dict(
                tensors={"values": output["values"].float()},
                tensors={"values": values.float()},
            )
            output = output.to("cpu")

@@ -177,6 +182,7 @@ class CriticWorker(Worker, DistProfilerExtension):
            mini_batch.meta_info["global_batch_size"] = self.config.ppo_mini_batch_size
            # TODO: make worker API to accept TensorDict as well
            mini_batch = mini_batch.to_tensordict()
            mini_batch = left_right_2_no_padding(mini_batch)
            output = self.engine.train_batch(mini_batch, self.loss_fn)
            mini_batch_metrics = output.get("metrics", {})
            append_to_dict(metrics, mini_batch_metrics, prefix="critic/")
@@ -21,10 +21,11 @@ from verl.utils import tensordict_utils as tu
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.torch_functional import masked_mean
from verl.workers.config import ActorConfig, CriticConfig
from verl.workers.roles.utils.padding import no_padding_2_padding


def sft_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None):
    pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
    pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.NO_PADDING)

    log_prob = model_output["log_probs"]

@@ -52,6 +53,10 @@ def ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None)
    log_prob = model_output["log_probs"]
    entropy = model_output.get("entropy", None)

    log_prob = no_padding_2_padding(log_prob, data)  # (bsz, response_length)
    if entropy is not None:
        entropy = no_padding_2_padding(entropy, data)  # (bsz, response_length)

    metrics = {}

    response_mask = data["response_mask"].to(bool)
@@ -105,7 +110,7 @@ def ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None)

def value_loss(config: CriticConfig, model_output, data: TensorDict, dp_group=None):
    vpreds = model_output["values"]
    values = data["values"]
    vpreds = no_padding_2_padding(vpreds, data)  # (bsz, response_length)

    values = data["values"]
    returns = data["returns"]
verl/workers/roles/utils/padding.py (new file, 119 lines)
@@ -0,0 +1,119 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from tensordict import TensorDict

from verl.utils import tensordict_utils as tu
from verl.utils.device import (
    is_cuda_available,
    is_npu_available,
)

if is_cuda_available:
    from flash_attn.bert_padding import pad_input, unpad_input
elif is_npu_available:
    from transformers.integrations.npu_flash_attention import pad_input, unpad_input


def left_right_2_no_padding(data: TensorDict) -> TensorDict:
    """
    Convert TensorDict from left-right padding to no-padding format.

    Args:
        data: TensorDict with "input_ids", "attention_mask", "response_mask", "position_ids"

    Returns:
        data: TensorDict with
            - Tensor includes NestedTensors like "input_ids", "loss_mask", "position_ids"
            - NonTensorData includes "max_seq_len", "max_response_len", "indices"

    Note:
        1. the return input_ids/position_ids/loss_mask are nested tensor.
        2. we will remove "attention_mask", "response" in the return data, but "response_mask" is kept.
    """
    assert "input_ids" in data, "input_ids is required in left-right padding data"
    assert "attention_mask" in data, "attention_mask is required in left-right padding data"
    assert "response_mask" in data, "response_mask is required in left-right padding data"
    assert "position_ids" in data, "position_ids is required in left-right padding data"

    input_ids = data.pop("input_ids")
    attention_mask = data.pop("attention_mask")
    response_mask = data["response_mask"]
    if "responses" in data:
        _ = data.pop("responses")

    max_seq_len, max_response_len = input_ids.shape[1], response_mask.shape[1]
    tu.assign_non_tensor_data(data, "max_seq_len", max_seq_len)
    tu.assign_non_tensor_data(data, "max_response_len", max_response_len)

    input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
    tu.assign_non_tensor_data(data, "indices", indices)

    input_ids_nested = torch.nested.nested_tensor_from_jagged(input_ids_rmpad.squeeze(-1), offsets=cu_seqlens)

    seq_lens = cu_seqlens.diff().tolist()
    response_lens = response_mask.sum(dim=1).tolist()

    position_ids_list = []
    loss_mask_list = []
    for seq_len, response_len in zip(seq_lens, response_lens, strict=False):
        position_ids_list.append(torch.arange(seq_len, device=input_ids.device))
        loss_mask = torch.zeros(seq_len, dtype=torch.bool, device=input_ids.device)
        assert seq_len >= response_len, f"{seq_len=} is less than {response_len=}"
        loss_mask[-response_len:] = 1
        loss_mask_list.append(loss_mask)

    position_ids_nested = torch.nested.as_nested_tensor(position_ids_list, layout=torch.jagged)
    loss_mask_nested = torch.nested.as_nested_tensor(loss_mask_list, layout=torch.jagged)

    data["input_ids"] = input_ids_nested
    data["position_ids"] = position_ids_nested
    data["loss_mask"] = loss_mask_nested

    return data


def no_padding_2_padding(nested_tensor: torch.Tensor, data: TensorDict) -> torch.Tensor:
    """
    Convert NestedTensor from no-padding to right padding format.

    Args:
        nested_tensor: NestedTensor with no-padding format
        data: TensorDict with
            - Tensor includes NestedTensors like "input_ids", "loss_mask", "position_ids"
            - NonTensorData includes "max_seq_len", "max_response_len", "indices"

    Returns:
        values: regular tensor right padded to max_response_len
    """
    assert "indices" in data, "indices is required in left-right padding data"
    assert "max_seq_len" in data, "max_seq_len is required in left-right padding data"
    assert "max_response_len" in data, "max_response_len is required in left-right padding data"

    indices = tu.get_non_tensor_data(data=data, key="indices", default=None)
    max_seq_len = tu.get_non_tensor_data(data=data, key="max_seq_len", default=2048)
    max_response_len = tu.get_non_tensor_data(data=data, key="max_response_len", default=1024)
    batch_size = nested_tensor.size(0)

    values = nested_tensor.values()
    full_values = pad_input(
        hidden_states=values.unsqueeze(-1),
        indices=indices,
        batch=batch_size,
        seqlen=max_seq_len,
    )
    values = full_values.squeeze(-1)[:, -max_response_len - 1 : -1]  # (bsz, response_length)

    return values
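A hedged round-trip example for the two helpers above (requires flash-attn's `pad_input`/`unpad_input` or the NPU fallback to be importable; token values are made up and a per-token model output is faked by reusing the input ids):

```python
import torch
from tensordict import TensorDict

from verl.workers.roles.utils.padding import left_right_2_no_padding, no_padding_2_padding

batch = TensorDict(
    {
        # max_prompt_len=3 (left padded), max_response_len=3 (right padded)
        "input_ids": torch.tensor([[0, 11, 12, 21, 22, 0], [13, 14, 15, 23, 24, 25]]),
        "attention_mask": torch.tensor([[0, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1]]),
        "position_ids": torch.tensor([[0, 0, 1, 2, 3, 3], [0, 1, 2, 3, 4, 5]]),
        "response_mask": torch.tensor([[1, 1, 0], [1, 1, 1]]),
    },
    batch_size=[2],
)

nopad = left_right_2_no_padding(batch)            # input_ids/position_ids/loss_mask become jagged
fake_output = nopad["input_ids"].float()          # stand-in for a per-token model output
padded = no_padding_2_padding(fake_output, nopad)
print(padded.shape)                               # torch.Size([2, 3]) -- right padded to max_response_len
```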