Adds Checkpointer Wrapper for DCP [3/N] (#114603)

Adds a high-level wrapper for calling `DCP.save`/`DCP.load` with the correct storage readers and writers already configured.

Instead of doing:

```
DCP.save(
    state_dict={...},
    storage_writer=StorageWriter(...)
)

DCP.load(
    state_dict={...},
    storage_reader=StorageReader(...)
)
```

We can now do:

```
checkpointer = Checkpointer(...)

checkpointer.save(state_dict={...})
checkpointer.load(state_dict={...})
```
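Conceptually, the wrapper just remembers the checkpoint location and supplies the matching reader/writer on each call. Below is a minimal sketch of that idea, not the implementation added in this PR; the class name, constructor, and `path` attribute are illustrative, while the `DCP.save`/`DCP.load` calls and the `FileSystemWriter`/`FileSystemReader` types are the existing APIs shown above:

```python
import os
from typing import Any, Dict, Union

import torch.distributed.checkpoint as DCP


class FileSystemCheckpointerSketch:
    """Illustrative only: pairs a checkpoint directory with DCP.save/DCP.load."""

    def __init__(self, path: Union[str, os.PathLike]):
        self._path = path

    def save(self, state_dict: Dict[str, Any]) -> None:
        # The writer is derived from the configured path, so callers only pass state.
        DCP.save(
            state_dict=state_dict,
            storage_writer=DCP.FileSystemWriter(self._path),
        )

    def load(self, state_dict: Dict[str, Any]) -> None:
        # DCP.load restores values in place into the provided state_dict.
        DCP.load(
            state_dict=state_dict,
            storage_reader=DCP.FileSystemReader(self._path),
        )
```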

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114603
Approved by: https://github.com/fegin, https://github.com/wz337
Author: Lucas Pasqualin
Date: 2023-12-08 01:03:17 +00:00
Committed by: PyTorch MergeBot
Parent: 3b01f30b20
Commit: 5432088098
7 changed files with 155 additions and 13 deletions


```diff
@@ -78,16 +78,15 @@ def run(rank, world_size, device="cuda"):
     model, optim = _init_model(device, world_size)
     _train(model, optim, train_steps=2)
 
-    DCP.save(
+    checkpointer = DCP.FileSystemCheckpointer(CHECKPOINT_DIR)
+    checkpointer.save(
         state_dict={"model": model, "optimizer": optim},
-        storage_writer=DCP.FileSystemWriter(CHECKPOINT_DIR),
     )
 
     # presumably do something else
 
     model, optim = _init_model(device, world_size)
-    DCP.load(
+    checkpointer.load(
         state_dict={"model": model, "optimizer": optim},
-        storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
     )
     _train(model, optim, train_steps=2)
```
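
For reference, a self-contained single-process sketch of the flow the diff above demonstrates. The `FileSystemCheckpointer` constructor is assumed to take the checkpoint directory as shown in the diff; the gloo single-rank setup, `CHECKPOINT_DIR` value, and the `state_dict()` round-trip pattern are assumptions for demonstration, not part of this PR:

```python
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as DCP

# Single-rank process group so DCP has a collective backend to run against.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

CHECKPOINT_DIR = "/tmp/dcp_checkpoint"  # hypothetical path

model = torch.nn.Linear(4, 4)
checkpointer = DCP.FileSystemCheckpointer(CHECKPOINT_DIR)

# Save: the wrapper supplies the FileSystemWriter for CHECKPOINT_DIR.
checkpointer.save(state_dict={"model": model.state_dict()})

# Load into a freshly initialized model: DCP loads in place, so we pass a
# state_dict with the right structure and then apply it to the module.
restored = torch.nn.Linear(4, 4)
state_dict = {"model": restored.state_dict()}
checkpointer.load(state_dict=state_dict)
restored.load_state_dict(state_dict["model"])

dist.destroy_process_group()
```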