# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

"""
The following example demonstrates how to use PyTorch Distributed Checkpoint to save an FSDP model.

This is the current recommended way to checkpoint FSDP.
torch.save() and torch.load() are not recommended when checkpointing sharded models.
"""

import os
import shutil

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.multiprocessing as mp
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

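# DCP's FileSystemWriter saves per-rank checkpoint files plus a .metadata file
# into this directory, so every rank must resolve the same path on storage all
# ranks can reach; "/scratch/$LOGNAME" is just this example's assumption.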
CHECKPOINT_DIR = f"/scratch/{os.environ['LOGNAME']}/checkpoint"


def opt_at(opt, idx):
    # Fetch the idx-th parameter's state from the optimizer (for Adam, this
    # is a dict holding exp_avg and exp_avg_sq).
    return list(opt.state.values())[idx]


def init_model():
    model = FSDP(torch.nn.Linear(4, 4).cuda(dist.get_rank()))
    optim = torch.optim.Adam(model.parameters(), lr=0.1)
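    # One forward/backward/step materializes Adam's per-parameter state
    # (exp_avg, exp_avg_sq), so the checkpoint below has optimizer state to save.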
    model(torch.rand(4, 4)).sum().backward()
    optim.step()

    return model, optim


def print_params(stage, model_1, model_2, optim_1, optim_2):
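    # summon_full_params() temporarily gathers the full, unsharded parameters
    # on every rank so the two models can be printed and compared directly.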
    with FSDP.summon_full_params(model_1):
        with FSDP.summon_full_params(model_2):
            print(
                f"{stage} --- rank: {dist.get_rank()}\n"
                f"model.weight: {model_1.weight}\n"
                f"model_2.weight:{model_2.weight}\n"
                f"model.bias: {model_1.bias}\n"
                f"model_2.bias: {model_2.bias}\n"
            )

    print(
        f"{stage} --- rank: {dist.get_rank()}\n"
        f"optim exp_avg:{opt_at(optim_1, 0)['exp_avg']}\n"
        f"optim_2 exp_avg:{opt_at(optim_2, 0)['exp_avg']}\n"
        f"optim exp_avg_sq:{opt_at(optim_1, 0)['exp_avg_sq']}\n"
        f"optim_2 exp_avg_sq:{opt_at(optim_2, 0)['exp_avg_sq']}\n"
    )


def run_fsdp_checkpoint_example(rank, world_size):
    # Set up world pg
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # Initialize the process group
    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
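    # Note: "cpu:gloo,cuda:nccl" registers gloo for CPU tensors and NCCL for
    # CUDA tensors in a single process group; some of DCP's coordination runs
    # on CPU tensors, which is why a CPU-capable backend is included here.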
    torch.cuda.set_device(rank)

    # Create a model
    model_1, optim_1 = init_model()

    # Save the model to CHECKPOINT_DIR
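    # Under SHARDED_STATE_DICT, state_dict() returns each rank's local shard
    # rather than full tensors, so the checkpoint is produced in parallel
    # without materializing the whole model on any single rank.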
    with FSDP.state_dict_type(model_1, StateDictType.SHARDED_STATE_DICT):
        state_dict = {
            "model": model_1.state_dict(),
            "optim": FSDP.optim_state_dict(model_1, optim_1),
        }

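        # dist_cp.save_state_dict() is a collective: every rank takes part, and
        # FileSystemWriter lays the per-rank files and .metadata down in
        # CHECKPOINT_DIR. FSDP.optim_state_dict() above re-keys the optimizer
        # state by parameter name so it can be re-sharded on load.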
        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
        )

    # Create a second model
    model_2, optim_2 = init_model()

    # Print the model parameters for both models.
    # Before loading, the parameters should be different.
    print_params("Before loading", model_1, model_2, optim_1, optim_2)

    # Load model_2 with parameters saved in CHECKPOINT_DIR
    with FSDP.state_dict_type(model_2, StateDictType.SHARDED_STATE_DICT):
        state_dict = {
            "model": model_2.state_dict(),
            # cannot load the optimizer state_dict together with the model state_dict
        }

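        # dist_cp.load_state_dict() reads the checkpoint *into* the tensors of
        # state_dict in place; the module's own load_state_dict() then applies
        # the loaded values.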
        dist_cp.load_state_dict(
            state_dict=state_dict,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
        )
        model_2.load_state_dict(state_dict["model"])

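        # Optimizer state is restored in a second pass:
        # load_sharded_optimizer_state_dict() uses the model state_dict to
        # determine the target sharding, and FSDP.optim_state_dict_to_load()
        # converts the result into the flat layout load_state_dict() expects.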
        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=state_dict["model"],
            optimizer_key="optim",
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
        )

        flattened_osd = FSDP.optim_state_dict_to_load(
            model_2, optim_2, optim_state["optim"]
        )
        optim_2.load_state_dict(flattened_osd)

    # Print the model parameters for both models.
    # After loading, the parameters should be the same.
    print_params("After loading", model_1, model_2, optim_1, optim_2)

    # Shut down world pg
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)
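    # mp.spawn() prepends the worker index (the rank) to args, so each process
    # runs run_fsdp_checkpoint_example(rank, world_size).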
    mp.spawn(
        run_fsdp_checkpoint_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )