Files
verl/tests/ray/test_worker_group_torch.py
Shawn/Yuxuan Tong b00f77d855 [dev] feat: immigrate from yapf & pylint to ruff based on pre-commit (#1010)
> [!WARNING]
> We are [immigrating to `ruff` as the linter and formatter and
`pre-commit` as the managing
tool](https://github.com/volcengine/verl/pull/1010).
>
> If your branch is based on a previous commit using `yapf` and
`pylint`, simply merging might trigger overwhelming linting errors,
while **you are only expected to resolve ones in the files related to
your PR**.
>
> To resolve this issue, please try the following workaround to only
include the files you **really changed** in the PR:
>
> 1. In your branch, fix linting and format with `ruff`: `ruff check
--fix && ruff-format`
> 2. Squash into a single commit in a new branch: `git reset --soft
$(git merge-base main HEAD) && git add -A && git commit -m "feat: ..."`
> 3. Merge with the latest main: `git merge origin/main`
> 4. Force push to your branch: `git push --force`

We add the reminder above to the documentation to tell contributors how
to avoid overwhelming linting errors.

### Motivation

According to dicussion in #896, this PR immigrates from yapf & pylint to
ruff based on pre-commit, which allows unified version control and
automatic hook on committing.

### Summary

The `pre-commit` hook and CI

- checks staged / committed files in commits / PR's
- checks all files each month (This should fail before we fix all the
files by the ruff standard)

### Explanation for the Failing CI Workflow `pre-commit`

For now, we only apply `ruff format` and `ruff check --fix` **without
resolving all the errors**, since there are too many errors to resolve,
which causes the CI workflow `pre-commit` fails.

For resolving the remaining errors, we leave to future commits.
Specifically, the `pre-commit` hook and CI will require every commit to
fix its related files with `ruff`, which will fix all the files
incrementally.

### Reviewing Suggestion

The commit
3d93f51ba8
is huge since we apply `ruff` to all the files. To review the main
changes, please check the commits before and after it.
2025-04-18 07:49:31 -07:00

112 lines
3.5 KiB
Python

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["NCCL_DEBUG"] = "WARN"
import ray
import torch
import torch.distributed
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@ray.remote
class TestAllGatherActor(Worker):
def __init__(self, size) -> None:
super().__init__()
self.size = size
def init(self):
torch.distributed.init_process_group()
self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device="cuda")
self.tensor += self.rank
def all_gather(self):
world_size = self._world_size
output = torch.zeros(
size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device
)
torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False)
return output
@ray.remote
class TestAllGatherActorV2(Worker):
def __init__(self, size) -> None:
super().__init__()
self.size = size
torch.distributed.init_process_group()
self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device="cuda")
self.tensor += self.rank
def all_gather(self):
world_size = self._world_size
output = torch.zeros(
size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device
)
torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False)
return output
def test_all_gather_torch():
"""
In this test, we instantiate 4 GPUs in a group and test the all_gather
"""
ray.init()
# create 4 workers, each hold a GPU
resource_pool = RayResourcePool([4], use_gpu=True)
class_with_args = RayClassWithInitArgs(cls=TestAllGatherActor, size=2)
worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch")
worker_group.execute_all_sync("init")
output = worker_group.execute_all_sync("all_gather")
for i in range(1, len(output)):
assert torch.all(output[i] == output[0])
output = output[0].cpu()
print(output)
assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64))
ray.shutdown()
def test_all_gather_torch_v2():
"""
In this test, we instantiate 4 GPUs in a group and test the all_gather
"""
ray.init()
# create 4 workers, each hold a GPU
resource_pool = RayResourcePool([4], use_gpu=True)
class_with_args = RayClassWithInitArgs(cls=TestAllGatherActorV2, size=2)
worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch")
output = worker_group.execute_all_sync("all_gather")
for i in range(1, len(output)):
assert torch.all(output[i] == output[0])
output = output[0].cpu()
print(output)
assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64))
ray.shutdown()