verl/tests/checkpoint/test_fsdp_ckpt.py
Shawn/Yuxuan Tong b00f77d855 [dev] feat: immigrate from yapf & pylint to ruff based on pre-commit (#1010)
> [!WARNING]
> We are [migrating to `ruff` as the linter and formatter and
`pre-commit` as the managing
tool](https://github.com/volcengine/verl/pull/1010).
>
> If your branch is based on a previous commit that used `yapf` and
`pylint`, simply merging might trigger a flood of linting errors,
although **you are only expected to resolve the ones in files related to
your PR**.
>
> To resolve this issue, please try the following workaround to only
include the files you **really changed** in the PR:
>
> 1. In your branch, fix linting and format with `ruff`: `ruff check
--fix && ruff format`
> 2. Squash into a single commit in a new branch: `git reset --soft
$(git merge-base main HEAD) && git add -A && git commit -m "feat: ..."`
> 3. Merge with the latest main: `git merge origin/main`
> 4. Force push to your branch: `git push --force`

We add the reminder above to the documentation to tell contributors how
to avoid overwhelming linting errors.
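
As a complement to the manual steps above, here is a minimal sketch of limiting `ruff` to the Python files a branch actually touched. It assumes `git` and `ruff` are on `PATH` and that the base branch is named `main`. The helper names (`select_python_files`, `ruff_commands`, `changed_python_files`, `main`) are hypothetical and are not part of verl or this PR.

```python
import subprocess


def select_python_files(paths: list[str]) -> list[str]:
    """Keep only non-empty paths ending in .py, preserving order."""
    return [p.strip() for p in paths if p.strip().endswith(".py")]


def ruff_commands(files: list[str]) -> list[list[str]]:
    """Build the two ruff invocations for the given files (none if empty)."""
    if not files:
        return []
    return [["ruff", "check", "--fix", *files], ["ruff", "format", *files]]


def changed_python_files(base: str = "main") -> list[str]:
    """List .py files that differ between HEAD and its merge-base with `base`."""
    merge_base = subprocess.run(
        ["git", "merge-base", base, "HEAD"],
        check=True, capture_output=True, text=True,
    ).stdout.strip()
    # --diff-filter=d excludes deleted files, which ruff could not open.
    diff = subprocess.run(
        ["git", "diff", "--name-only", "--diff-filter=d", merge_base, "HEAD"],
        check=True, capture_output=True, text=True,
    ).stdout
    return select_python_files(diff.splitlines())


def main() -> None:
    """Run ruff on the changed files only (assumed workflow, not verl tooling)."""
    for cmd in ruff_commands(changed_python_files()):
        # check=False: `ruff check` exits non-zero if unfixable errors remain.
        subprocess.run(cmd, check=False)
```

Calling `main()` from the repository root would lint and format only the changed files; the squash, merge, and force-push steps still have to be done by hand.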

### Motivation

Following the discussion in #896, this PR migrates from yapf & pylint to
ruff, managed through pre-commit, which allows unified version control and
an automatic hook on commit.

### Summary

The `pre-commit` hook and CI:

- check the staged / committed files in each commit / PR
- check all files each month (this will fail until all the files
conform to the ruff standard)

### Explanation for the Failing CI Workflow `pre-commit`

For now, we only apply `ruff format` and `ruff check --fix` **without
resolving all the errors**, since there are too many to resolve at once.
This is why the CI workflow `pre-commit` currently fails.

The remaining errors are left to future commits.
Specifically, the `pre-commit` hook and CI will require every commit to
fix its related files with `ruff`, so all the files get fixed
incrementally.

### Reviewing Suggestion

The commit
3d93f51ba8
is huge since we apply `ruff` to all the files. To review the main
changes, please check the commits before and after it.
2025-04-18 07:49:31 -07:00


# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile

import torch
import torch.distributed
from torch.distributed import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config

from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.distributed import initialize_global_process_group


def test_fsdp_ckpt():
    assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
    local_rank, rank, world_size = initialize_global_process_group()
    device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",))

    model_name = "Qwen/Qwen2.5-0.5B-Instruct"
    config = Qwen2Config(num_hidden_layers=1)

    with torch.device("cuda"):
        model = AutoModelForCausalLM.from_config(
            config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
        model = model.to(device="cuda")

    # Wrap model with FSDP
    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)

    model = FSDP(
        model,
        use_orig_params=False,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=mixed_precision,
        device_mesh=device_mesh,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

    # Create checkpoint manager
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    checkpoint_manager = FSDPCheckpointManager(
        model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer
    )

    # Generate sample input
    batch_size = 2
    seq_len = 32
    vocab_size = 32000

    # First input for initial update
    input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
    attention_mask1 = torch.ones_like(input_ids1)

    # Second input for verification
    input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
    attention_mask2 = torch.ones_like(input_ids2)

    # Step 1: Initial update and save checkpoint
    outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1)
    loss1 = outputs1.logits.mean()
    loss1.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # Save checkpoint after first update
    temp_dir = tempfile.mkdtemp()
    checkpoint_path = os.path.join(temp_dir, "checkpoint")
    checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0)

    # Step 2: Second update and forward pass
    outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2)
    loss2 = outputs2.logits.mean()
    loss2.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # Record logits after second update
    with torch.no_grad():
        logits_before_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits

    # Step 3: Load checkpoint and repeat second update
    checkpoint_manager.load_checkpoint(checkpoint_path)

    # Repeat the second update with same input
    outputs3 = model(input_ids=input_ids2, attention_mask=attention_mask2)
    loss3 = outputs3.logits.mean()
    loss3.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # Record logits after loaded checkpoint and update
    with torch.no_grad():
        logits_after_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits

    # Step 4: Verify outputs match
    torch.testing.assert_close(logits_before_load, logits_after_load, atol=0.0, rtol=0.0)
    print("Checkpoint save/load test passed!")

    # Cleanup
    shutil.rmtree(temp_dir)
    torch.distributed.barrier()


if __name__ == "__main__":
    test_fsdp_ckpt()
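
The test above follows a save, diverge, restore, replay pattern and then demands bit-exact equality. The sketch below shows the same four steps without GPUs or FSDP, using a hypothetical scalar optimizer (`ToyMomentumSGD`) written purely for illustration. It is not verl code and makes no claim about how `FSDPCheckpointManager` stores state.

```python
import copy


class ToyMomentumSGD:
    """Hypothetical scalar model + optimizer with momentum and a decaying LR."""

    def __init__(self, weight: float, lr: float = 0.1,
                 momentum: float = 0.9, gamma: float = 0.9):
        self.weight = weight
        self.lr = lr
        self.momentum = momentum
        self.gamma = gamma
        self.velocity = 0.0

    def step(self, x: float) -> None:
        grad = 2.0 * self.weight * x * x  # d/dw of the loss (w * x) ** 2
        self.velocity = self.momentum * self.velocity + grad
        self.weight -= self.lr * self.velocity
        self.lr *= self.gamma  # scheduler-like decay after every step

    def state_dict(self) -> dict:
        # Snapshot weight, velocity, and learning rate together.
        return copy.deepcopy(self.__dict__)

    def load_state_dict(self, state: dict) -> None:
        self.__dict__.update(copy.deepcopy(state))


def replay_matches(x1: float, x2: float) -> bool:
    opt = ToyMomentumSGD(weight=1.0)
    opt.step(x1)                     # Step 1: first update
    checkpoint = opt.state_dict()    # save after the first update
    opt.step(x2)                     # Step 2: second update
    weight_before_load = opt.weight
    opt.load_state_dict(checkpoint)  # Step 3: restore, then replay the update
    opt.step(x2)
    return opt.weight == weight_before_load  # Step 4: exact equality
```

The replay only matches because the snapshot includes the momentum buffer and the learning rate as well as the weight, which mirrors why the test hands the model, optimizer, and LR scheduler to the checkpoint manager together.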