> [!WARNING]
> We are [migrating to `ruff` as the linter and formatter, with `pre-commit` as the managing tool](https://github.com/volcengine/verl/pull/1010).
>
> If your branch is based on an earlier commit that still uses `yapf` and `pylint`, simply merging may trigger an overwhelming number of linting errors, while **you are only expected to resolve the ones in files related to your PR**.
>
> To avoid this, please try the following workaround so that the PR only includes the files you **really changed**:
>
> 1. In your branch, fix linting and formatting with `ruff`: `ruff check --fix && ruff format`
> 2. Squash into a single commit on a new branch: `git reset --soft $(git merge-base main HEAD) && git add -A && git commit -m "feat: ..."`
> 3. Merge with the latest main: `git merge origin/main`
> 4. Force push to your branch: `git push --force`

We added the reminder above to the documentation so that contributors know how to avoid overwhelming linting errors.
### Motivation

Following the discussion in #896, this PR migrates from `yapf` and `pylint` to `ruff` managed by `pre-commit`, which allows unified version control of the tooling and an automatic hook on every commit.
### Summary

The `pre-commit` hook and the CI workflow
- check staged / committed files in commits and PRs
- check all files once a month (sketched below; this is expected to fail until all files conform to the `ruff` standard)
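The monthly full-repository check can be implemented as a scheduled CI job that runs `pre-commit` over all files. The workflow below is only an illustration of that idea, not the exact workflow added in this PR; the job name, action versions, and cron schedule are assumptions.

```yaml
# Illustrative GitHub Actions workflow for the monthly full-repo lint.
# Names, versions, and the schedule are assumptions, not the PR's exact workflow.
name: pre-commit-full
on:
  schedule:
    - cron: "0 0 1 * *"      # first day of every month
  workflow_dispatch: {}       # allow manual runs
jobs:
  lint-all:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - run: pip install pre-commit
      - run: pre-commit run --all-files   # expected to fail until the whole codebase is ruff-clean
```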
### Explanation for the Failing CI Workflow `pre-commit`

For now, we only apply `ruff format` and `ruff check --fix` **without resolving all the errors**: there are too many to fix at once, which is why the CI workflow `pre-commit` currently fails.
Resolving the remaining errors is left to future commits. Specifically, the `pre-commit` hook and CI will require every commit to fix its related files with `ruff`, so the codebase will be fixed incrementally.
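Contributors who want to check their work ahead of the hook can also run it manually on just the files they touched, e.g. `pre-commit run --files path/to/changed_file.py` (the file path here is purely illustrative).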
### Reviewing Suggestion

The commit 3d93f51ba8 is huge, since it applies `ruff` to all files. To review the main changes, please check the commits before and after it.
```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile

import torch
import torch.distributed
from torch.distributed import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config

from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.distributed import initialize_global_process_group


def test_fsdp_ckpt():
    assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
    local_rank, rank, world_size = initialize_global_process_group()
    device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",))

    model_name = "Qwen/Qwen2.5-0.5B-Instruct"
    config = Qwen2Config(num_hidden_layers=1)

    with torch.device("cuda"):
        model = AutoModelForCausalLM.from_config(
            config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
        model = model.to(device="cuda")

    # Wrap model with FSDP
    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)

    model = FSDP(
        model,
        use_orig_params=False,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=mixed_precision,
        device_mesh=device_mesh,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

    # Create checkpoint manager
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    checkpoint_manager = FSDPCheckpointManager(
        model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer
    )

    # Generate sample input
    batch_size = 2
    seq_len = 32
    vocab_size = 32000
    # First input for initial update
    input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
    attention_mask1 = torch.ones_like(input_ids1)

    # Second input for verification
    input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
    attention_mask2 = torch.ones_like(input_ids2)

    # Step 1: Initial update and save checkpoint
    outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1)
    loss1 = outputs1.logits.mean()
    loss1.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # Save checkpoint after first update
    temp_dir = tempfile.mkdtemp()
    checkpoint_path = os.path.join(temp_dir, "checkpoint")
    checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0)

    # Step 2: Second update and forward pass
    outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2)
    loss2 = outputs2.logits.mean()
    loss2.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # Record logits after second update
    with torch.no_grad():
        logits_before_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits

    # Step 3: Load checkpoint and repeat second update
    checkpoint_manager.load_checkpoint(checkpoint_path)

    # Repeat the second update with same input
    outputs3 = model(input_ids=input_ids2, attention_mask=attention_mask2)
    loss3 = outputs3.logits.mean()
    loss3.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # Record logits after loaded checkpoint and update
    with torch.no_grad():
        logits_after_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits

    # Step 4: Verify outputs match
    torch.testing.assert_close(logits_before_load, logits_after_load, atol=0.0, rtol=0.0)
    print("Checkpoint save/load test passed!")

    # Cleanup
    shutil.rmtree(temp_dir)
    torch.distributed.barrier()


if __name__ == "__main__":
    test_fsdp_ckpt()
```
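The test asserts at least two GPUs and initializes a global process group, so it is presumably launched as a distributed job, for example with `torchrun --nproc_per_node=2 <path/to/this/test>`; the launcher and file path are assumptions, as neither is specified above.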