# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import ray

from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool


@ray.remote
class TestActor(Worker):
    # TODO: passing *args and **kwargs is bug-prone and not very convincing
    def __init__(self, cuda_visible_devices=None) -> None:
        super().__init__(cuda_visible_devices)

    def get_node_id(self):
        # Report the ID of the Ray node this worker is scheduled on.
        return ray.get_runtime_context().get_node_id()
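
# Note: get_cuda_visible_devices, invoked via execute_all_sync below, is
# presumably provided by the Worker base class; TestActor itself only adds
# get_node_id on top of it.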


def test():
    ray.init()

    # test single-node-no-partition
    print("test single-node-no-partition")
    resource_pool = RayResourcePool([8], use_gpu=True)

    class_with_args = RayClassWithInitArgs(cls=TestActor)

    print("create actor worker group")
    actor_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_actor")
    print("create critic worker group")
    critic_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_critic")
    print("create rm worker group")
    rm_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_rm")
    print("create ref worker group")
    ref_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_ref")
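
    # All four groups share the same 8-GPU pool, so each worker group should
    # see every device 0-7 (the groups are colocated, not partitioned).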
    assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]

    del actor_wg
    del critic_wg
    del rm_wg
    del ref_wg

    # Release the placement groups explicitly so their GPU reservations are
    # freed before the next sub-test re-partitions the same devices.
    for pg in resource_pool.get_placement_groups():
        ray.util.remove_placement_group(pg)
    print("wait 5s to remove placement_group")
    time.sleep(5)
    # test single-node-multi-partition
    print("test single-node-multi-partition")
    rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="rm")
    ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="ref")
    total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool)

    assert rm_resource_pool.world_size == 4
    assert ref_resource_pool.world_size == 4
    assert total_resource_pool.world_size == 8

    actor_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_actor")
    critic_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_critic")
    rm_wg = RayWorkerGroup(rm_resource_pool, class_with_args, name_prefix="high_level_api_rm")
    ref_wg = RayWorkerGroup(ref_resource_pool, class_with_args, name_prefix="high_level_api_ref")
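
    # actor/critic span the merged 8-GPU pool, while rm is confined to the
    # first 4-GPU partition and ref to the second (devices 4-7).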
    assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4)]
    assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4, 8)]

    ray.shutdown()
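

# A minimal direct entry point, assuming this file may also be run as a
# standalone script rather than collected by a test runner.
if __name__ == "__main__":
    test()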