Summary:
Avoid in-place update and deepcopy during dedupe. Deepcopy becomes prohibitively expensive for models with a huge number of FQNs. This was manifested in the Ads 2K experiment as well. Here are the results from the TextRay model in Mitra:
#### Control job with deepcopy regression:
First save is ~24.8s
Global step latency is ~7-8s
#### Test job with the new fix to avoid deepcopy:
First save is ~21s
Global step latency is ~2s
Test Plan:
```
buck test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/distributed/checkpoint:test_planner
```
https://www.internalfb.com/intern/testinfra/testrun/3940649945104822
Differential Revision: D71245218
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149320
Approved by: https://github.com/MeetVadakkanchery
Summary:
This PR caches the save plans to significantly reduce the collective cost for successive checkpoint save attempts. Here is the high level approach:
- Create the local plan and cache the same.
- In the next iteration, compare the local plan with the cached plan metadata. If nothing has changed, do not send that local plan in the collective.
- The global plan step will only create the global plan from the new delta plans, and will use empty plans for the cached ones.
- The finish plan step will check for empty plans. If a plan is empty, it will grab the cached plan. If not, it will use the new plan provided.
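The steps above can be sketched in plain Python. This is only an illustration of the protocol, not the DCP implementation: `Plan`, `CachingRank`, and `global_step` are hypothetical names, and a "plan" here is just a tuple of item names.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Plan:
    """Hypothetical stand-in for a save plan: a tuple of item names."""
    items: tuple = ()


EMPTY = Plan()


class CachingRank:
    """One rank's view of the caching protocol (illustrative only)."""

    def __init__(self) -> None:
        self.cached_local: Optional[Plan] = None  # last local plan that was sent
        self.cached_final: Optional[Plan] = None  # last final plan received back

    def local_step(self, plan: Plan) -> Plan:
        # Send an empty placeholder when nothing changed since the last save.
        if plan == self.cached_local:
            return EMPTY
        self.cached_local = plan
        return plan

    def finish_step(self, received: Plan) -> Plan:
        # An empty plan from the coordinator means "reuse what you cached".
        if received == EMPTY and self.cached_final is not None:
            return self.cached_final
        self.cached_final = received
        return received


def global_step(plans: list) -> list:
    # The coordinator only processes non-empty (delta) plans; empty
    # placeholders are passed through untouched.
    return [p if p == EMPTY else Plan(tuple(sorted(p.items))) for p in plans]


rank = CachingRank()
sent1 = rank.local_step(Plan(("w", "b")))
final1 = rank.finish_step(global_step([sent1])[0])
sent2 = rank.local_step(Plan(("w", "b")))  # unchanged plan on the second save
final2 = rank.finish_step(global_step([sent2])[0])
```

On the second save the rank sends only the empty placeholder, yet still ends up with the same final plan as before, which is where the collective cost is saved.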
Test Plan: UTs
Differential Revision: D69224491
## How to enable the caching:
`DefaultSavePlanner` introduces the `enable_plan_caching` flag, which is set to False by default for now.
https://github.com/pytorch/pytorch/pull/147343/files#diff-579bbb7b82572753afa91085fbf954f7c7613ff8376da9b26153d5cc3a3c4ee8R77
Set this to True to enable the caching, and we should see a significant speedup in subsequent checkpoint save attempts, especially for larger scale jobs. Reference issue: https://github.com/pytorch/pytorch/issues/123695
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147343
Approved by: https://github.com/MeetVadakkanchery
reland of https://github.com/pytorch/pytorch/pull/133113
I had to create a new PR because the previously reverted PR could neither be rebased nor imported successfully :(
----
Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next PRs)
* To preserve the BC for users still using the torch.distributed._tensor, I added a shim script to redirect old path calls to the new module
The BC preservation is evidenced by the fact that all DTensor tests still work without changing the public imports, so it's safe to land the changes.
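For readers unfamiliar with the shim idea, here is a minimal sketch of redirecting an old module path to a new one. The module names `demo_pkg_tensor` and `demo_pkg_private_tensor` are made up for illustration, and the real shim in PyTorch may be implemented differently (for example, by re-exporting names from a file at the old path).

```python
import importlib
import sys
import types

# Build a hypothetical "new" public module in memory so the sketch is runnable.
new_mod = types.ModuleType("demo_pkg_tensor")
new_mod.DTensorLike = type("DTensorLike", (), {})
sys.modules["demo_pkg_tensor"] = new_mod

# The shim: register the old path as an alias of the new module, so that
# existing imports of the old name keep resolving to the same objects.
sys.modules["demo_pkg_private_tensor"] = sys.modules["demo_pkg_tensor"]

old = importlib.import_module("demo_pkg_private_tensor")
new = importlib.import_module("demo_pkg_tensor")
```

Because both names resolve to the same module object, code written against the old path sees exactly the same classes as code using the new path.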
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
The original DCP doesn't flatten all the containers, which can cause issues. https://github.com/pytorch/pytorch/pull/125335 intends to solve this by flattening all the dictionaries.
Unfortunately, it breaks the checkpoints that are saved before 2.4. This
also shows some issues of the DCP:
1. DCP should record version in the metadata.
2. DCP should have a nice way to load old state_dict.
3. DCP should unflatten all containers (map, list) not just map.
This PR only addresses issue 2 to unblock users. Issue 1 and issue 3 need to be addressed in the future.
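To make the flattening discussion concrete, here is a rough sketch of flattening and unflattening nested dictionaries. It is not DCP's code: `flatten` and `unflatten` are hypothetical helpers, and the sketch assumes string keys that contain no dots. Lists are left as leaves, which mirrors issue 3 above.

```python
def flatten(state, prefix=""):
    """Flatten nested dicts into a single dict keyed by dotted paths."""
    flat = {}
    for key, value in state.items():
        path = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            # Non-dict values (including lists) are treated as leaves.
            flat[path] = value
    return flat


def unflatten(flat):
    """Rebuild the nested dict structure from dotted paths."""
    nested = {}
    for path, value in flat.items():
        *parents, leaf = path.split(".")
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


state = {"model": {"layer": {"weight": 1, "bias": 2}}, "step": 3}
flat = flatten(state)
restored = unflatten(flat)
```

A checkpoint written with one layout and read with the other will not find matching keys, which is the kind of mismatch that affects checkpoints saved before 2.4.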
@pradeepfn Please let me know if this summary matches our discussion.
Fixes https://github.com/pytorch/pytorch/issues/133923
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134158
Approved by: https://github.com/wz337, https://github.com/pradeepfn
Moving DTensor to be in the public namespace, to formally add the
documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next
PRs)
* To preserve the BC for users still using the `torch.distributed._tensor`,
I added a shim script to redirect old path calls to the new module
The BC preservation is evidenced by the fact that all DTensor tests still
work without changing the public imports, so it's safe to land the
changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
This PR seeks to increase observability of save/load requests. This is accomplished with two main changes:
1. The creation of save_id and load_id:
- A `save_id` and a `load_id` are added to the filesystem writer. `save_id` is re-generated on every save call, and `load_id` is re-generated on every load call.
- Both IDs are stored in a new `StorageMeta` class and saved as part of Metadata. (`load_id` is None when we save, and is only set during load.)
2. A new mechanism is implemented in the save path which gives the SavePlanner a chance to inspect the `storage_meta` object. The mechanism mirrors the same metadata exchange in the load path. In the load path, `storage_meta` is added to `metadata` such that the LoadPlanner can also access `storage_meta` before we begin loading.
*If users now wish to access the checkpoint_id in the SavePlanner, they simply need to read the value from `storage_meta` in the `set_up_planner` call*
*Additionally, users now have a generic way of passing data to the SavePlanner from the StorageWriter at the start of the save path, similar to the load path*
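A rough sketch of the flow described above, for illustration only. The class names `StorageMetaSketch`, `WriterSketch`, and `PlannerSketch`, and the exact method signatures, are assumptions and do not reproduce the real DCP interfaces.

```python
import uuid
from dataclasses import dataclass
from typing import Optional


@dataclass
class StorageMetaSketch:
    """Hypothetical stand-in for StorageMeta."""
    checkpoint_id: Optional[str] = None
    save_id: Optional[str] = None
    load_id: Optional[str] = None  # stays None on the save path


class WriterSketch:
    def __init__(self, path: str) -> None:
        self.path = path
        self.save_id: Optional[str] = None

    def reset(self) -> None:
        # Called at the start of every save: regenerate the ID.
        self.save_id = str(uuid.uuid4())

    def storage_meta(self) -> StorageMetaSketch:
        return StorageMetaSketch(checkpoint_id=self.path, save_id=self.save_id)


class PlannerSketch:
    def set_up_planner(self, state_dict, storage_meta=None) -> None:
        # The planner can now inspect what the writer handed over.
        self.state_dict = state_dict
        self.checkpoint_id = storage_meta.checkpoint_id if storage_meta else None


writer = WriterSketch("/tmp/ckpt")
writer.reset()
meta1 = writer.storage_meta()
writer.reset()
meta2 = writer.storage_meta()

planner = PlannerSketch()
planner.set_up_planner({}, storage_meta=meta2)
```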
This PR has been tested for backwards compatibility -- meaning any checkpoints saved before this PR can continue being loaded after this PR.
One major consideration is that there is limited forwards compatibility. If a checkpoint is generated _past_ this PR, there is no support for loading it using older torch versions. This brings up a fairly important point: since we expect the metadata object (which is saved to the disk) to continue evolving, and we want to support forwards compatibility, we explore patching `pickle` so we can at least add new members to `metadata` and maintain fwd compat.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124772
Approved by: https://github.com/fegin
~Users may have custom use cases for the `strict` parameter in load. In my mind, if we automatically call `state_dict` and `load_state_dict` in save/load, we need to support the same functionality in `nn.Modules`.~
It turns out this is actually not related to nn.Module's strict param. Since `state_dict` is called inside `dcp.load`, it's actually impossible to create a model such that the following would raise an error:
```
state_dict = module.state_dict()
module.load_state_dict(state_dict, strict=True)
```
The issue is actually just when there are elements in `state_dict` which do not exist in the checkpoint. This PR adds the ability to configure this behavior through the DefaultLoadPlanner (see tests).
Concretely, if module has extra attributes not present in the checkpoint, we will only raise an error if `DefaultLoadPlanner.allow_partial_load==False`
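The behavior can be sketched as follows. `plan_load` is a hypothetical helper operating on key names only; it is not the planner's actual code.

```python
def plan_load(state_dict_keys, checkpoint_keys, allow_partial_load=False):
    """Return the keys that can be loaded, or raise if some are missing."""
    missing = [k for k in state_dict_keys if k not in checkpoint_keys]
    if missing and not allow_partial_load:
        raise RuntimeError(f"Missing keys in checkpoint: {missing}")
    # With partial loading allowed, silently skip what the checkpoint lacks.
    return [k for k in state_dict_keys if k in checkpoint_keys]


loadable = plan_load(["a", "b"], {"a"}, allow_partial_load=True)

try:
    plan_load(["a", "b"], {"a"})
    strict_raised = False
except RuntimeError:
    strict_raised = True
```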
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123869
Approved by: https://github.com/fegin
Fixes some issues with `_load_state_dict_keys`, including:
* updates broken test, which was failing due to incorrect parameters
* adds support for specifying nested keys, e.g. load state dict keys can now specify something like `"optimizer.state"`, which loads all keys under `optimizer.state`.
* updates call site to use the private implementation of `_load_state_dict`, which properly handles empty state dicts (otherwise the keys are ignored)
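As an illustration of the nested-key matching, the sketch below selects flattened keys by prefix. `select_keys` is a hypothetical helper, and it assumes dotted key paths; the real matching logic may differ.

```python
def select_keys(flat_keys, requested):
    """Keep keys that equal a requested key or live underneath it."""
    selected = []
    for key in flat_keys:
        for req in requested:
            # Match on a path boundary so "optimizer.state_extra" is not
            # picked up by a request for "optimizer.state".
            if key == req or key.startswith(req + "."):
                selected.append(key)
                break
    return selected


keys = [
    "optimizer.state.0.exp_avg",
    "optimizer.state.1.exp_avg",
    "optimizer.param_groups",
    "optimizer.state_extra",
    "model.weight",
]
picked = select_keys(keys, {"optimizer.state"})
```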
Big shout out to @diego-urgell who not only identified current issues, but recommended the right solutions!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123679
Approved by: https://github.com/diego-urgell, https://github.com/wz337
Enables the deduplication of saved entries by load balancing duplicates across ranks.
Tested with existing and modified tests. Additionally tested with the following code snippet, which saves a 20GB DDP model in **~3 seconds on 8 ranks**. Before this PR, the same operation has been measured at ~19 seconds.
```
import os
import time

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Model, rank_0_print, _patch_model_state_dict, and _FileSystemCheckpointer
# are assumed to be defined or imported elsewhere in the test script.


def run(local_rank, world_size, param_size, num_params, work_dir):
    os.environ["RANK"] = str(local_rank)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    dist.init_process_group(backend="nccl", rank=local_rank, world_size=world_size)

    model = Model(param_size=param_size, num_params=num_params)
    model = DistributedDataParallel(model, gradient_as_bucket_view=True)
    _patch_model_state_dict(model)

    sz = sum(t.nelement() * t.element_size() for t in model.parameters())
    rank_0_print(f"Model size: {sz / 1_000_000_000.0} GB")
    rank_0_print("Saving the model with DCP...")

    checkpointer = _FileSystemCheckpointer(
        f"{work_dir}/dcp",  # use the work_dir argument (was args.work_dir)
        sync_files=False,
        single_file_per_rank=False,
        thread_count=1
    )

    # Time only the save call itself.
    begin_ts = time.monotonic()
    checkpointer.save(state_dict={"model": model})
    end_ts = time.monotonic()
    rank_0_print(f"Took {end_ts - begin_ts} seconds with DCP")
```
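For intuition about load-balanced deduplication, here is a small greedy sketch: each duplicated entry is assigned to whichever holder currently has the least data to write. `dedup_load_balanced` is a hypothetical function over plain dicts, and the heuristic is an assumption chosen for illustration; the planner's actual strategy may differ.

```python
def dedup_load_balanced(plans):
    """plans: list indexed by rank of {item_name: size_in_bytes}."""
    # Record which ranks hold each item.
    holders = {}
    for rank, plan in enumerate(plans):
        for name in plan:
            holders.setdefault(name, []).append(rank)

    load = [0] * len(plans)
    kept = [dict() for _ in plans]

    # Place items with the fewest candidate ranks first, then larger items.
    def order(name):
        return (len(holders[name]), -plans[holders[name][0]][name], name)

    for name in sorted(holders, key=order):
        size = plans[holders[name][0]][name]
        # Pick the least-loaded holder; ties go to the lower rank.
        rank = min(holders[name], key=lambda r: (load[r], r))
        kept[rank][name] = size
        load[rank] += size
    return kept


plans = [
    {"a": 4, "b": 4, "c": 2},
    {"a": 4, "b": 4, "c": 2, "d": 1},
]
balanced = dedup_load_balanced(plans)
```

In this toy input every replicated entry is written exactly once, and the bytes written are split 6 to 5 across the two ranks instead of everything landing on rank 0.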
Differential Revision: [D52435926](https://our.internmc.facebook.com/intern/diff/D52435926/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116469
Approved by: https://github.com/fegin, https://github.com/wz337
Currently, DCP requires the `model.state_dict()` to be materialized before passing it to DCP to load, since DCP uses the pre-allocated storage from the initialized model state_dict. Therefore, even for fine-tuning and distributed inference, users would need to explicitly materialize the model on GPU before `DCP.load_state_dict()`.
Today's flow:
```
with torch.device("meta"):
    model2 = parallelize_module(
        MLPModule("meta"), tp_mesh, parallelize_plan=parallelize_plan
    )
# materialize the meta-device model before loading
model2.to_empty(device='cuda')
state_dict_to_load = model2.state_dict()
DCP.load_state_dict(
    state_dict=state_dict_to_load,
    storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
)
model2.load_state_dict(state_dict_to_load)
```
This PR adds support for meta tensor loading. In DCP's planner, when encountering a tensor/DTensor on the meta device, we initialize a tensor/DTensor on the current device on the fly and use it to replace the meta-device tensor/DTensor in the state_dict. After the change, users no longer need to manually call `model.to_empty()` when loading existing checkpoints for fine-tuning and distributed inference.
Updated user flow:
```
with torch.device("meta"):
    model2 = parallelize_module(
        MLPModule("meta"), tp_mesh, parallelize_plan=parallelize_plan
    )
# no longer need to call model.to_empty(device='cuda')
state_dict_to_load = model2.state_dict()
DCP.load_state_dict(
    state_dict=state_dict_to_load,
    storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
)
model2.load_state_dict(state_dict_to_load, assign=True)
```
Note that for distributed training, it's still the users' responsibility to reset the parameters (`model.reset_parameters()`) as checkpoint might not exist.
Note that we need to loop thru the state_dict to replace meta tensor/DTensor instead of calling `model.to_empty()` since `DCP.load()` only takes in state_dict but not model.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113319
Approved by: https://github.com/fegin, https://github.com/LucasLLC
Fixes: #113193
`pydocstyle <all_files_in_issue> --count`
- Before: 345
- After: 130
For deprecated methods, I have added a `noqa` to ignore them. I was not able to find the file `torch/distributed/tensor/parallel/multihead_attention_tp.py`, so I've ignored it for this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113241
Approved by: https://github.com/kit1980
Currently, DCP treats tensors as duplicates and only saves them on rank0. This won't work for PiPPy as PiPPy does have unique tensors across different ranks. With the current setup, we would only be saving the tensors on rank0 (coordinator rank).
In this PR, we let each rank create its own WriteItem for tensors. For the ones that are replicated across different ranks, we handle them through dedup_tensors(), which dedups the replicated WriteItems so we only do the actual writing once.
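A minimal sketch of the idea, assuming the simplest possible policy where the lowest rank holding an item keeps it. `dedup_keep_first` is a hypothetical helper over item names, not the real dedup_tensors().

```python
def dedup_keep_first(plans):
    """plans: list indexed by rank of item-name lists."""
    seen = set()
    result = []
    for plan in plans:  # plans are ordered by rank
        # Keep an item only if no lower rank has already claimed it.
        kept = [name for name in plan if name not in seen]
        seen.update(kept)
        result.append(kept)
    return result


# Each pipeline stage owns a unique tensor; "shared" is replicated.
plans = [["shared", "stage0.w"], ["shared", "stage1.w"]]
deduped = dedup_keep_first(plans)
```

The tensors that are unique to a rank survive on that rank, which is what the PiPPy case needs, while the replicated one is written only once.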
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106415
Approved by: https://github.com/wz337
With distributed checkpointing in PyTorch/XLA SPMD, the WriteItem index hints should not be modified when creating the global plan. In order to reuse the default planner logic for checkpoint metadata creation, we need to make the behavior of rewriting index hints optional.
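The effect of making the rewrite optional can be sketched as below. `ItemIndex` and `build_global_plan` are hypothetical stand-ins for illustration, and the flag name `rewrite_index_hints` is assumed here rather than taken from the description above.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class ItemIndex:
    """Hypothetical stand-in for a WriteItem's index."""
    fqn: str
    hint: Optional[int] = None


def build_global_plan(all_plans, rewrite_index_hints=True):
    chunks = {}  # fqn -> items collected for the metadata
    new_plans = []
    for plan in all_plans:
        new_items = []
        for item in plan:
            bucket = chunks.setdefault(item.fqn, [])
            if rewrite_index_hints:
                # Point the hint at the item's position in the metadata.
                item = replace(item, hint=len(bucket))
            bucket.append(item)
            new_items.append(item)
        new_plans.append(new_items)
    return new_plans, chunks


plans = [[ItemIndex("w", hint=7)], [ItemIndex("w", hint=9)]]
rewritten, _ = build_global_plan(plans)
preserved, metadata = build_global_plan(plans, rewrite_index_hints=False)
```

With the rewrite disabled, the metadata is still built from every item, but the hints supplied by the caller are left untouched.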
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105861
Approved by: https://github.com/kumpera
DTensor submesh support is added in https://github.com/pytorch/pytorch/pull/95458.
This PR adds support for DTensor submesh by adding an extra check when creating the local save/load plan.
If the rank is not participating in the mesh, we simply skip creating WriteItem/ReadItem for the local SavePlan/LoadPlan.
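The check amounts to something like the following sketch, where each entry lists the ranks in its (sub)mesh. `create_local_items` and the data layout are hypothetical, used only to show the skip.

```python
def create_local_items(rank, state):
    """state maps a tensor name to the ranks participating in its mesh."""
    # Ranks outside a tensor's mesh create no item for it.
    return [name for name, mesh_ranks in state.items() if rank in mesh_ranks]


state = {"full": [0, 1, 2, 3], "sub": [0, 1]}
items_rank0 = create_local_items(0, state)
items_rank3 = create_local_items(3, state)
```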
Updated the associated test as well.
cc. @wanchaol, @kumpera
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96802
Approved by: https://github.com/wanchaol
This is the last PR for integrating 2D into core distributed.
This PR does the following:
1. Add optimizer.py: this adds the ability to load a state_dict in conjunction with FSDP sharded optimizer state.
2. Update default_planner.py to support 2D checkpoint.
3. Add test_fsdp_optim_state.py as a unit test for No. 1.
4. Fix bug in torch/testing/_internal/distributed/checkpoint_utils.py
5. Rename the filename for the APIs that should be private. Will organize and cleanup further in following PRs. #90328
Docstring and integration test will be added in the following PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90212
Approved by: https://github.com/wanchaol