> [!WARNING]
> We are [migrating to `ruff` as the linter and formatter, with `pre-commit` as the managing tool](https://github.com/volcengine/verl/pull/1010).
>
> If your branch is based on an older commit that still uses `yapf` and `pylint`, simply merging may trigger an overwhelming number of linting errors, while **you are only expected to resolve the ones in the files related to your PR**.
>
> To resolve this issue, please try the following workaround so that the PR only includes the files you **really changed**:
>
> 1. In your branch, fix linting and formatting with `ruff`: `ruff check --fix && ruff format`
> 2. Squash into a single commit in a new branch: `git reset --soft $(git merge-base main HEAD) && git add -A && git commit -m "feat: ..."`
> 3. Merge with the latest main: `git merge origin/main`
> 4. Force push to your branch: `git push --force`

We add the reminder above to the documentation so that contributors know how to avoid overwhelming linting errors.
### Motivation
Following the discussion in #896, this PR migrates from `yapf` & `pylint` to `ruff` managed by `pre-commit`, which allows unified version control of the tooling and an automatic hook on commit.
### Summary
The `pre-commit` hook and CI
- check staged/committed files in commits and PRs (see the config sketch below)
- check all files once a month (this job is expected to fail until all files conform to the `ruff` standard)
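
`pre-commit` reads its hooks from a `.pre-commit-config.yaml` at the repo root. The snippet below is only a minimal illustrative sketch of what a `ruff`-based config typically looks like, not the exact config added by this PR; in particular, the `rev` pin is a placeholder:

```yaml
# Minimal illustrative sketch -- not the exact config added by this PR.
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.4.0            # placeholder version pin
    hooks:
      - id: ruff           # linter, with autofix enabled
        args: [--fix]
      - id: ruff-format    # formatter
```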
### Explanation for the Failing CI Workflow `pre-commit`
For now, we only apply `ruff format` and `ruff check --fix` **without resolving all the errors**, since there are too many to resolve at once; this is why the CI workflow `pre-commit` fails.
We leave resolving the remaining errors to future commits. Specifically, the `pre-commit` hook and CI will require every commit to fix its related files with `ruff`, so all files get fixed incrementally.
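In practice, a contributor would typically install the hooks once with `pip install pre-commit` followed by `pre-commit install`, and can run them on only the files they touched via `pre-commit run --files <path> ...`; these are standard `pre-commit` commands rather than anything specific to this PR.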
### Reviewing Suggestion
The commit `3d93f51ba8` is huge because it applies `ruff` to all files. To review the main changes, please check the commits before and after it.
149 lines · 5.1 KiB · Python
```python
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""

import hydra
import ray

from .prime_ray_trainer import RayPRIMETrainer


@hydra.main(config_path="config", config_name="prime_trainer", version_base=None)
def main(config):
    run_prime(config)


def run_prime(config, compute_score=None):
    if not ray.is_initialized():
        # this is for local ray cluster
        ray.init(
            runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}},
        )

    ray.get(main_task.remote(config, compute_score))


@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
def main_task(config, compute_score=None):
    # print initial config
    from pprint import pprint

    from omegaconf import OmegaConf

    from verl.utils.fs import copy_local_path_from_hdfs

    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    # download the checkpoint from hdfs
    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

    # instantiate tokenizer
    from verl.utils import hf_tokenizer

    tokenizer = hf_tokenizer(local_path)

    # define worker classes
    if config.actor_rollout_ref.actor.strategy == "fsdp":
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.single_controller.ray import RayWorkerGroup
        from verl.workers.fsdp_workers import ActorRolloutRefWorker

        ray_worker_group_cls = RayWorkerGroup

    elif config.actor_rollout_ref.actor.strategy == "megatron":
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
        from verl.workers.megatron_workers import ActorRolloutRefWorker

        ray_worker_group_cls = NVMegatronRayWorkerGroup

    else:
        raise NotImplementedError

    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

    role_worker_mapping = {
        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
    }

    global_pool_id = "global_pool"
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    mapping = {
        Role.ActorRollout: global_pool_id,
    }

    # use reference model
    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
        role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
        mapping[Role.RefPolicy] = global_pool_id

    if config.reward_model.enable:
        from .prime_fsdp_workers import PRIMERewardModelWorker

        role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker)
        mapping[Role.RewardModel] = global_pool_id

    reward_manager_name = config.reward_model.get("reward_manager", "naive")
    if reward_manager_name == "naive":
        from verl.workers.reward_manager import NaiveRewardManager

        reward_manager_cls = NaiveRewardManager
    elif reward_manager_name == "prime":
        from verl.workers.reward_manager import PrimeRewardManager

        reward_manager_cls = PrimeRewardManager
    else:
        raise NotImplementedError
    reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)

    # Note that we always use function-based RM for validation
    val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score)

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

    trainer = RayPRIMETrainer(
        config=config,
        tokenizer=tokenizer,
        role_worker_mapping=role_worker_mapping,
        resource_pool_manager=resource_pool_manager,
        ray_worker_group_cls=ray_worker_group_cls,
        reward_fn=reward_fn,
        val_reward_fn=val_reward_fn,
    )
    trainer.init_workers()
    trainer.fit()


if __name__ == "__main__":
    main()
```