Adds Checkpointer Wrapper for DCP [3/N] (#114603)
Adds a useful high-level wrapper for calling `dist.save/load` with the correct storage readers and writers. Instead of doing:

```
DCP.save(
    state_dict={...},
    storage_writer=StorageWriter(...)
)

DCP.load(
    state_dict={...},
    storage_reader=StorageReader(...)
)
```

We can now do:

```
checkpointer = Checkpointer(...)
checkpointer.save(state_dict={...})
checkpointer.load(state_dict={...})
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114603
Approved by: https://github.com/fegin, https://github.com/wz337
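For context, here is a minimal end-to-end sketch of the new surface. `FileSystemCheckpointer` and the `save`/`load` calls mirror the diff below; the single-process setup, toy model, dummy training step, and `CHECKPOINT_DIR` path are illustrative assumptions, not part of the PR.

```python
import torch
import torch.nn as nn
import torch.distributed.checkpoint as DCP

CHECKPOINT_DIR = "/tmp/dcp_example"  # illustrative path, not from the PR

# A toy model/optimizer pair stands in for the FSDP-wrapped objects
# used in the real example script.
model = nn.Linear(8, 8)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

# One dummy step so the optimizer has state worth checkpointing.
model(torch.randn(2, 8)).sum().backward()
optim.step()

# The checkpointer is constructed once with the target directory and
# reuses its storage writer/reader across save and load calls, so the
# call sites no longer pass storage_writer/storage_reader themselves.
checkpointer = DCP.FileSystemCheckpointer(CHECKPOINT_DIR)
checkpointer.save(state_dict={"model": model, "optimizer": optim})
checkpointer.load(state_dict={"model": model, "optimizer": optim})
```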
Committed by PyTorch MergeBot · parent 3b01f30b20 · commit 5432088098
```diff
@@ -78,16 +78,15 @@ def run(rank, world_size, device="cuda"):
     model, optim = _init_model(device, world_size)
     _train(model, optim, train_steps=2)
 
-    DCP.save(
+    checkpointer = DCP.FileSystemCheckpointer(CHECKPOINT_DIR)
+    checkpointer.save(
         state_dict={"model": model, "optimizer": optim},
-        storage_writer=DCP.FileSystemWriter(CHECKPOINT_DIR),
     )
 
     # presumably do something else
     model, optim = _init_model(device, world_size)
-    DCP.load(
+    checkpointer.load(
         state_dict={"model": model, "optimizer": optim},
-        storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
     )
     _train(model, optim, train_steps=2)
 
```
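For readers curious about the mechanics of the wrapper, a file-system checkpointer along these lines can be sketched in a few lines: it binds one checkpoint directory to a reader/writer pair at construction time and delegates to the existing entry points. This is an illustrative sketch, not the PR's actual `Checkpointer` implementation; only `DCP.save`, `DCP.load`, `FileSystemWriter`, and `FileSystemReader` are taken from the source, and the class and attribute names are assumptions.

```python
import torch.distributed.checkpoint as DCP


class SimpleFileSystemCheckpointer:
    """Illustrative sketch only: binds one checkpoint directory to a
    FileSystemWriter/FileSystemReader pair so callers can save and load
    without constructing storage objects at every call site."""

    def __init__(self, path: str):
        # Built once; reused across every save/load on this instance.
        self._writer = DCP.FileSystemWriter(path)
        self._reader = DCP.FileSystemReader(path)

    def save(self, state_dict):
        # Delegates to DCP.save with the pre-built storage writer.
        return DCP.save(state_dict=state_dict, storage_writer=self._writer)

    def load(self, state_dict):
        # Delegates to DCP.load, which restores in place into state_dict.
        return DCP.load(state_dict=state_dict, storage_reader=self._reader)
```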