> [!WARNING]
> We are [migrating to `ruff` as the linter and formatter and `pre-commit` as the managing tool](https://github.com/volcengine/verl/pull/1010).
>
> If your branch is based on a previous commit that still used `yapf` and `pylint`, simply merging may trigger an overwhelming number of linting errors, while **you are only expected to resolve the ones in the files related to your PR**.
>
> To resolve this issue, please try the following workaround so that only the files you **really changed** are included in the PR:
>
> 1. In your branch, fix linting and formatting with `ruff`: `ruff check --fix && ruff format`
> 2. Squash into a single commit in a new branch: `git reset --soft $(git merge-base main HEAD) && git add -A && git commit -m "feat: ..."`
> 3. Merge with the latest main: `git merge origin/main`
> 4. Force push to your branch: `git push --force`
We add the reminder above to the documentation to tell contributors how to avoid these overwhelming linting errors; a consolidated version of the workaround is sketched below.
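For reference, the four steps taken together might look like the following shell session; the commit message is a placeholder and the base branch is assumed to be `main`:

```bash
# 1. in your feature branch, auto-fix lint and formatting
ruff check --fix . && ruff format .

# 2. squash everything since branching off main into a single commit
git reset --soft "$(git merge-base main HEAD)"
git add -A
git commit -m "feat: ..."

# 3. merge the latest main (which already uses ruff)
git fetch origin
git merge origin/main

# 4. force-push the rewritten history to your PR branch
git push --force
```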
### Motivation
Following the discussion in #896, this PR migrates from `yapf` and `pylint` to `ruff` managed by `pre-commit`, which allows unified version control of the tooling and an automatic hook on every commit.
### Summary
The `pre-commit` hook and CI
- check the staged / committed files in commits and PRs
- check all files once a month (this check is expected to fail until all files conform to the `ruff` standard)
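As a rough sketch of the local workflow this enables, assuming `pre-commit` is installed from PyPI:

```bash
# install the pre-commit framework and register the git hook once
pip install pre-commit
pre-commit install

# from now on, `git commit` runs the ruff hooks on the staged files;
# they can also be run manually, e.g. over the whole repository
pre-commit run --all-files
```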
### Explanation for the Failing CI Workflow `pre-commit`
For now, we only apply `ruff format` and `ruff check --fix` **without resolving all the reported errors**; there are too many errors to fix at once, which is why the CI workflow `pre-commit` currently fails.
Resolving the remaining errors is left to future commits. Specifically, the `pre-commit` hook and CI will require every commit to fix its related files with `ruff`, so all files will be brought up to standard incrementally.
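For a PR that touches only a few files, this incremental fix can be done by running the hooks on just those files, for example (`origin/main...HEAD` is one common way to list the files changed on a branch; adjust to your setup):

```bash
# run the configured hooks only on the files changed relative to main
pre-commit run --files $(git diff --name-only origin/main...HEAD)

# or call ruff directly on specific files
ruff check --fix path/to/changed_file.py
ruff format path/to/changed_file.py
```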
### Reviewing Suggestion
The commit 3d93f51ba8 is huge because it applies `ruff` to all the files. To review the main changes, please check the commits before and after it.
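One way to do that locally, assuming the short hash above is reachable from your checkout:

```bash
# context around the bulk ruff-reformat commit
git log --oneline --graph -5 3d93f51ba8

# the commit just before the bulk reformat
git show 3d93f51ba8^

# stats only for the bulk reformat commit itself
git show --stat 3d93f51ba8
```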
The following file (162 lines, 5.5 KiB, Python):
```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Using FSDPTrainer
"""

import os

import hydra
import ray
import torch
from transformers import AutoTokenizer

from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.utils.fs import copy_to_local


def make_reward_function(tokenizer, num_examine):
    def arithmetic_sequence_reward_function(data: DataProto, return_dict: bool = False):
        from tests.e2e.envs.digit_completion.task import compute_reward

        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)

        for i in range(data.batch.batch_size[0]):
            data_item = data[i]  # DataProtoItem

            prompt_ids = data_item.batch["prompts"]

            prompt_length = prompt_ids.shape[-1]

            # extract raw prompt
            valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]

            # extract response
            response_ids = data_item.batch["responses"]
            response_length = response_ids.shape[-1]
            response_mask = data.batch["attention_mask"][i][-response_length:]
            valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            prompt = tokenizer.decode(valid_prompt_ids)
            response = tokenizer.decode(valid_response_ids)
            # remove bos and eos
            prompt = prompt.replace(tokenizer.sep_token, "")
            response = response.replace(tokenizer.eos_token, "")
            if i < num_examine:
                print(prompt, response)

            reward_output = compute_reward(prompt, response)
            dense_reward = reward_output[0].tolist()
            ground_truth_response = reward_output[1]["ground_truth_response"]
            if len(dense_reward) > 0:
                last_reward = dense_reward[-1]
            else:
                if len(ground_truth_response) == 0:
                    last_reward = 1
                else:
                    last_reward = 0

            # pad to response_length
            for _ in range(reward_tensor.shape[-1] - len(dense_reward)):
                dense_reward.append(last_reward)

            dense_reward = torch.as_tensor(dense_reward, dtype=torch.float32, device=reward_tensor.device)
            reward_tensor[i] = dense_reward * response_mask

        if return_dict:
            return {"reward_tensor": reward_tensor}
        else:
            return reward_tensor

    return arithmetic_sequence_reward_function


@hydra.main(config_path="../../../../verl/trainer/config", config_name="ppo_trainer", version_base=None)
def main(config):
    ray.init(
        runtime_env={
            "env_vars": {
                "MEGATRON_USE_CUDA_TIMER": "0",
                "MEGATRON_START_PROCESS_TIMER": "False",
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
            }
        }
    )

    # print initial config
    from pprint import pprint

    from omegaconf import OmegaConf

    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values

    # print the config after normalizing batch_size
    print("Config after normalizing batch_size")
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values

    # download the checkpoint from hdfs
    local_path = copy_to_local(config.actor_rollout_ref.model.path)
    local_path = os.path.expanduser(local_path)
    # instantiate tokenizer
    tokenizer = AutoTokenizer.from_pretrained(local_path)
    print(f"Tokenizer vocab_size: {tokenizer.vocab_size}")

    # define worker classes
    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
    from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker

    role_worker_mapping = {
        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
        Role.Critic: ray.remote(CriticWorker),
    }

    global_pool_id = "global_pool"
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    mapping = {
        Role.ActorRollout: global_pool_id,
        Role.Critic: global_pool_id,
    }

    # add a reference policy worker when a KL term is used in the reward or the actor loss
    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
        role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
        mapping[Role.RefPolicy] = global_pool_id

    reward_fn = make_reward_function(tokenizer=tokenizer, num_examine=1)

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

    trainer = RayPPOTrainer(
        config=config,
        tokenizer=tokenizer,
        role_worker_mapping=role_worker_mapping,
        resource_pool_manager=resource_pool_manager,
        reward_fn=reward_fn,
        val_reward_fn=reward_fn,
    )
    trainer.init_workers()
    trainer.fit()


if __name__ == "__main__":
    main()
```