Commit Graph

39 Commits

381d0cb239 [DCP] Avoid in-place update and deepcopy during dedupe (#149320)
Summary:
Avoid in-place update and deepcopy during dedupe. Deepcopy becomes prohibitively expensive for models with a huge number of FQNs. This manifested in the Ads 2K experiment as well. Here are the results from the TextRay model in Mitra:

#### Control job with deepcopy regression:
First save: ~24.8s
Global step latency: ~7-8s

#### Test job with the new fix to avoid deepcopy:
First save: ~21s
Global step latency: ~2s

Test Plan:
```
buck test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/distributed/checkpoint:test_planner
```
https://www.internalfb.com/intern/testinfra/testrun/3940649945104822

Differential Revision: D71245218

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149320
Approved by: https://github.com/MeetVadakkanchery
2025-03-18 16:08:40 +00:00
136b8165d1 [DCP] Save Plan Caching: Fix the missing all_plans update in the cache. (#148577)
Summary: Save Plan Caching: Fix the missing all_plans update in the cache.

Test Plan:
```
buck2 test //aiplatform/modelstore/experimental/integration_tests/tests/nosan:checkpoint_dist_save_load_test
```

https://www.internalfb.com/intern/testinfra/testrun/17451448626323264

Reviewed By: MeetVadakkanchery

Differential Revision: D70229019

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148577
Approved by: https://github.com/MeetVadakkanchery
2025-03-07 17:00:59 +00:00
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
6eb3d1e762 [DCP] Cache save plans in default planner (#147343)
Summary:
This PR caches the save plans to significantly reduce the collective cost of successive checkpoint save attempts. Here is the high-level approach:
- Create the local plan and cache it.
- In the next iteration, compare the local plan with the cached plan metadata. If nothing changed, do not send that local plan in the collective.
- The global plan step creates the global plan only from the new delta plans, with empty plans standing in for the cached ones.
- The finish plan step checks for empty plans: if a plan is empty, it grabs the cached plan; otherwise it uses the new plan provided.

Test Plan: UTs

Differential Revision: D69224491

## How to enable the caching:
DefaultSavePlanner introduces the `enable_plan_caching` flag, which is set to `False` by default for now.
https://github.com/pytorch/pytorch/pull/147343/files#diff-579bbb7b82572753afa91085fbf954f7c7613ff8376da9b26153d5cc3a3c4ee8R77
Set it to `True` to enable caching; this should yield a significant speedup in subsequent checkpoint save attempts, especially for larger-scale jobs. Reference issue: https://github.com/pytorch/pytorch/issues/123695
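
A minimal sketch of turning the cache on (hedged: the `dcp.save` call and paths are illustrative; `model`, `num_steps`, and `train_one_step` are hypothetical placeholders):
```
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner

# With caching enabled, successive saves skip re-sending unchanged local
# plans in the collective (see the approach described above).
planner = DefaultSavePlanner(enable_plan_caching=True)

for step in range(num_steps):
    train_one_step(model)  # hypothetical training step
    dcp.save(
        {"model": model.state_dict()},
        storage_writer=dcp.FileSystemWriter(f"/tmp/ckpt/step_{step}"),
        planner=planner,
    )
```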

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147343
Approved by: https://github.com/MeetVadakkanchery
2025-02-25 20:59:25 +00:00
316808e4e9 PEP585 update - torch/distributed/elastic torch/distributed/checkpoint (#145163)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145163
Approved by: https://github.com/Skylion007
2025-01-19 20:55:59 +00:00
08be9ec312 Migrate from Tuple -> tuple in torch/distributed (#144258)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144258
Approved by: https://github.com/aorenste
2025-01-10 08:34:54 +00:00
371bcc1e33 [checkpointing][oss] Throw an error when loading a different size than saved tensor (#141571)
Summary: Fixing issue reported in https://github.com/pytorch/pytorch/issues/126604

Test Plan: buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/distributed/checkpoint:test_planner -- --exact 'caffe2/test/distributed/checkpoint:test_planner - test_planner.TestLoadPlanner: test_strict'

Differential Revision: D66389578

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141571
Approved by: https://github.com/mhorowitz
2024-12-11 15:35:48 +00:00
cfc227ad43 [reland][dtensor] move DTensor to public namespace (#134203)
reland of https://github.com/pytorch/pytorch/pull/133113

I had to create a new PR because the previously reverted PR could neither be rebased nor imported successfully :(

----

Moving DTensor into the public namespace, to formally add the documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page without too much content yet (more coming in the next PRs)
* to preserve BC for users still using `torch.distributed._tensor`, a shim module that redirects old import paths to the new module

BC preservation is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it's safe to land the changes.
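
A quick sketch of what the shim preserves (hedged: the public path shown is `torch.distributed.tensor`, where DTensor now lives):
```
# New public import path introduced by this move:
from torch.distributed.tensor import DTensor

# Old private path still resolves via the shim, so existing code keeps working:
from torch.distributed._tensor import DTensor as LegacyDTensor

assert DTensor is LegacyDTensor  # same class, reachable from both paths
```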

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
2024-09-08 17:08:40 +00:00
c7338f457c [DCP] Fixes the BC issue where the traversal doesn't support versions before 2.4 (#134158)
The original DCP didn't flatten all the containers, which can cause issues; https://github.com/pytorch/pytorch/pull/125335 intended to solve the issue by flattening all the dictionaries.

Unfortunately, it breaks checkpoints that were saved before 2.4. This
also exposes some issues in DCP:

1. DCP should record the version in the metadata.
2. DCP should have a nice way to load old state_dicts.
3. DCP should unflatten all containers (map, list), not just map.

This PR only addresses issue 2 to unblock users. Issues 1 and 3 need to be addressed in the future.

@pradeepfn Please let me know if this summary matches our discussion.

Fixes https://github.com/pytorch/pytorch/issues/133923

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134158
Approved by: https://github.com/wz337, https://github.com/pradeepfn
2024-08-28 16:31:44 +00:00
35f36363ec Revert "[dtensor] move DTensor to public namespace (#133113)"
This reverts commit 2ee6b97464d17fcf4c1fc67c29868fa30d0c16e1.

Reverted https://github.com/pytorch/pytorch/pull/133113 on behalf of https://github.com/wanchaol due to: looks like it breaks some internal type imports ([comment](https://github.com/pytorch/pytorch/pull/133113#issuecomment-2295670911))
2024-08-19 05:00:19 +00:00
2ee6b97464 [dtensor] move DTensor to public namespace (#133113)
Moving DTensor into the public namespace, to formally add the
documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page without too much content yet (more coming in the next
  PRs)
* to preserve BC for users still using `torch.distributed._tensor`,
  a shim module that redirects old import paths to the new module

BC preservation is evidenced by the fact that all DTensor tests still
pass without changing the public imports, so it's safe to land the
changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
2024-08-17 05:09:52 +00:00
ad314a2f05 Pass torch.load(weights_only=) internally to avoid FutureWarning (#130663)
Fixes #130658
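
As a hedged sketch, this is the pattern the PR applies at DCP's internal `torch.load` call sites (which value of `weights_only` a given call site uses depends on what it loads; the point is that the argument is now passed explicitly):
```
import torch

# Passing weights_only explicitly avoids the FutureWarning about the
# changing default; True restricts unpickling to tensors and plain containers.
obj = torch.load("checkpoint.pt", weights_only=True)  # placeholder path
```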

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130663
Approved by: https://github.com/malfet, https://github.com/LucasLLC
2024-07-16 01:24:38 +00:00
e6d4451ae8 [BE][Easy] enable UFMT for torch/distributed/{algorithms,autograd,benchmarks,checkpoint,elastic}/ (#128866)
Part of #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128866
Approved by: https://github.com/fegin
2024-06-18 13:51:53 +00:00
3a0d088517 Flip default value for mypy disallow_untyped_defs [5/11] (#127842)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127842
Approved by: https://github.com/oulgen
2024-06-08 18:49:18 +00:00
e57f51b80f Update _dedup_save_plans.py (#126569)
To resolve https://github.com/pytorch/pytorch/issues/125740, save each tensor on the lowest rank.

Fixes #125740

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126569
Approved by: https://github.com/LucasLLC
2024-06-03 01:55:03 +00:00
ec7f2b2626 [DCP] adds type safety to str filtering in EmptyStateDict (#126082)
[DCP] adds type safety to str filtering in EmptyStateDict

Differential Revision: [D57281133](https://our.internmc.facebook.com/intern/diff/D57281133/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126082
Approved by: https://github.com/fegin, https://github.com/wz337, https://github.com/Skylion007
2024-05-13 22:13:05 +00:00
bb6ba31250 [DCP] Adds storage metadata, and passes it during the save path (#124772)
This PR seeks to increase observability of save/load requests. This is accomplished with two main changes:

1. The creation of save_id and load_id:
    - a `save_id` and a `load_id` are added to the filesystem writer. `save_id` is regenerated on every save call, and `load_id` is regenerated on every load call.
    - both of these IDs are stored in a new `StorageMeta` class and saved as part of Metadata (`load_id` is None when we save, and is only set during load).

2. A new mechanism is implemented in the save path which gives the SavePlanner a chance to inspect the `storage_meta` object. It mirrors the metadata exchange in the load path, where `storage_meta` is added to `metadata` so that the LoadPlanner can also access `storage_meta` before we begin loading.

*If users now wish to access the checkpoint_id in the SavePlanner, they simply need to access the value in `storage_meta` from the `set_up_planner` call.*

*Additionally, users now have a generic way of passing data to the SavePlanner from the StorageWriter at the start of the save path, similar to the load path.*

This PR has been tested for backwards compatibility, meaning any checkpoint saved before this PR can still be loaded after it.

One major consideration is that forwards compatibility is limited. If a checkpoint is generated _past_ this PR, there is no support for loading it using older torch versions. This raises a fairly important point: since we expect the metadata object (which is saved to disk) to keep evolving, and we want to support forwards compatibility, we are exploring patching `pickle` so we can at least add new members to `metadata` and maintain forward compat.
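
A hedged sketch of a planner reading the new metadata during setup (`set_up_planner` receiving a `StorageMeta` is the mechanism this PR describes; the subclass here is hypothetical):
```
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner

class AuditingSavePlanner(DefaultSavePlanner):
    """Hypothetical planner that records the IDs handed over by storage."""

    def set_up_planner(self, state_dict, storage_meta=None, is_coordinator=False):
        if storage_meta is not None:
            # save_id is regenerated by the StorageWriter on every save call.
            print(f"checkpoint_id={storage_meta.checkpoint_id} "
                  f"save_id={storage_meta.save_id}")
        super().set_up_planner(state_dict, storage_meta, is_coordinator)
```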

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124772
Approved by: https://github.com/fegin
2024-05-07 23:53:53 +00:00
22767e4791 [DCP] Always create requests for non-tensor objects (#125334)
Summary:
If an object only exists on certain non-coordinator ranks, we still need to save it. Otherwise, we lose those objects. If they are duplicated, DCP will deduplicate them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125334
Approved by: https://github.com/wz337, https://github.com/LucasLLC
ghstack dependencies: #125333, #125501
2024-05-07 17:04:36 +00:00
5c7b71dccf [DCP] Adds strict option to DefaultPlanner (#123869)
~Users may have custom use cases for the `strict` parameter in load. In my mind, if we automatically call `state_dict` and `load_state_dict` in save/load, we need to support the same functionality in `nn.Modules`.~

It turns out this is not actually related to nn.Module's strict param. Since `state_dict` is called inside `dcp.load`, it's impossible to create a model such that the following would raise an error:
```
state_dict = module.state_dict()
module.load_state_dict(state_dict, strict=True)
```

The issue arises only when there are elements in `state_dict` which do not exist in the checkpoint. This PR adds the ability to configure this behavior through the DefaultLoadPlanner (see tests).

Concretely, if the module has extra attributes not present in the checkpoint, we only raise an error if `DefaultLoadPlanner.allow_partial_load == False`.
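
A minimal sketch of the knob (hedged: `model` and the checkpoint path are placeholders):
```
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner

state_dict = model.state_dict()  # may contain keys absent from the checkpoint

dcp.load(
    state_dict,
    storage_reader=dcp.FileSystemReader("/tmp/ckpt"),
    # False (the default) raises on state_dict keys missing from the
    # checkpoint; True skips them and loads what it can.
    planner=DefaultLoadPlanner(allow_partial_load=True),
)
```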

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123869
Approved by: https://github.com/fegin
2024-05-02 22:50:32 +00:00
b7fac76fc2 [DCP] Fixes for _load_state_dict_keys and supports nested keys (#123679)
Fixes some issues with `_load_state_dict_keys`, including:
  * updates a broken test, which was failing due to incorrect parameters
  * adds support for specifying nested keys, e.g. load-state-dict keys can now specify something like `"optimizer.state"`, which loads all keys under `optimizer.state` (see the sketch after this list)
  * updates the call site to use the private implementation of `_load_state_dict`, which properly handles empty state dicts (otherwise the keys are ignored)
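
A hedged sketch of the nested-key loading this enables (the helper is private; the name and signature shown follow the later `_load_state_dict_from_keys` API and may differ from the exact version in this PR):
```
from torch.distributed.checkpoint.state_dict_loader import (
    _load_state_dict_from_keys,  # private API; subject to change
)

# Pull only the optimizer state out of a checkpoint, without materializing
# the full model state_dict first.
partial = _load_state_dict_from_keys(
    keys={"optimizer.state"},
    checkpoint_id="/tmp/ckpt",  # placeholder path
)
```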

Big shout out to @diego-urgell who not only identified current issues, but recommended the right solutions!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123679
Approved by: https://github.com/diego-urgell, https://github.com/wz337
2024-04-11 20:52:06 +00:00
bcb6e5aa72 [DCP] Support partial load (#122829)
Adds the ability to load a subset of keys directly from a checkpoint, avoiding the need to initialize the state dict first.

Differential Revision: [D55441391](https://our.internmc.facebook.com/intern/diff/D55441391/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122829
Approved by: https://github.com/fegin
2024-04-02 19:22:22 +00:00
ff8e33556e Enables load balancing duplicates in DCP (#116469)
Enables the deduplication of saved entries by load balancing duplicates across ranks.

Tested with existing and modified tests. Additionally tested with the following code snippet, which saves a 20GB DDP model in **~3 seconds on 8 ranks**. Before this PR, the same operation was measured at ~19 seconds.

```
import os
import time

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def run(local_rank, world_size, param_size, num_params, work_dir):
    os.environ["RANK"] = str(local_rank)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    dist.init_process_group(backend="nccl", rank=local_rank, world_size=world_size)

    # Model is the benchmark module (num_params tensors of param_size elements
    # each); its definition, rank_0_print, and the internal helpers
    # _patch_model_state_dict / _FileSystemCheckpointer are elided here.
    model = Model(param_size=param_size, num_params=num_params)
    model = DistributedDataParallel(model, gradient_as_bucket_view=True)
    _patch_model_state_dict(model)

    sz = sum(t.nelement() * t.element_size() for t in model.parameters())
    rank_0_print(f"Model size: {sz / 1_000_000_000.0} GB")
    rank_0_print("Saving the model with DCP...")

    checkpointer = _FileSystemCheckpointer(
        f"{work_dir}/dcp",
        sync_files=False,
        single_file_per_rank=False,
        thread_count=1,
    )

    begin_ts = time.monotonic()
    checkpointer.save(state_dict={"model": model})
    end_ts = time.monotonic()
    rank_0_print(f"Took {end_ts - begin_ts} seconds with DCP")
```

Differential Revision: [D52435926](https://our.internmc.facebook.com/intern/diff/D52435926/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116469
Approved by: https://github.com/fegin, https://github.com/wz337
2024-01-26 22:34:14 +00:00
f518cf811d [DCP] Adds support for meta tensor loading for DCP.load_state_dict() (#113319)
Currently, DCP requires `model.state_dict()` to be materialized before passing it to DCP to load, since DCP uses the pre-allocated storage from the initialized model state_dict. Therefore, even for fine-tuning and distributed inference, users would need to explicitly materialize the model on GPU before `DCP.load_state_dict()`.

Today's flow:
```
with torch.device("meta"):
    model2 = parallelize_module(
        MLPModule("meta"), tp_mesh, parallelize_plan=parallelize_plan
    )

model2.to_empty(device='cuda')
state_dict_to_load = model2.state_dict()
DCP.load_state_dict(
    state_dict=state_dict_to_load,
    storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
)
model2.load_state_dict(state_dict_to_load)
```

This PR adds support for meta tensor loading. In DCP's planner, when encountering tensors/DTensors on the meta device, we initialize the tensor/DTensor on the current device on the fly and replace the meta-device tensor/DTensor in the state_dict. After this change, users no longer need to manually call `model.to_empty()` when loading existing checkpoints for fine-tuning and distributed inference.

Updated user flow:
```
with torch.device("meta"):
    model2 = parallelize_module(
        MLPModule("meta"), tp_mesh, parallelize_plan=parallelize_plan
    )
# no longer need to call model.to_empty(device='cuda')
state_dict_to_load = model2.state_dict()
DCP.load_state_dict(
    state_dict=state_dict_to_load,
    storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
)
model2.load_state_dict(state_dict_to_load, assign=True)
```

Note that for distributed training, it's still the users' responsibility to reset the parameters (`model.reset_parameters()`), as a checkpoint might not exist.

Note that we need to loop through the state_dict to replace meta tensors/DTensors instead of calling `model.to_empty()`, since `DCP.load()` only takes in the state_dict, not the model.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113319
Approved by: https://github.com/fegin, https://github.com/LucasLLC
2024-01-17 00:23:29 +00:00
db8d409d08 [DCP][BE] Apply ufmt to DCP and turn on lintrunner for DCP (#115302)
No logic change. Just typing and ufmt.

Differential Revision: [D51914982](https://our.internmc.facebook.com/intern/diff/D51914982/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115302
Approved by: https://github.com/XilunWu, https://github.com/wz337, https://github.com/LucasLLC
ghstack dependencies: #115523
2023-12-13 10:32:36 +00:00
44c0521e8c fix: docstring error in torch/distributed module (#113241)
Fixes: #113193

`pydocstyle <all_files_in_issue> --count`

- Before: 345
- After: 130

For deprecated methods, I have added a `noqa` to ignore them. I was not able to find the file `torch/distributed/tensor/parallel/multihead_attention_tp.py`, so I've ignored it for this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113241
Approved by: https://github.com/kit1980
2023-11-09 19:10:20 +00:00
6b39cf863f Fix invalid arg to getLogger in torch distributed checkpoint (#110008)
Ran the experimental LOG002 ruff check and found a bug in our codebase. A logger should not be instantiated with `__file__`; it should be instantiated with `__name__`.

https://docs.astral.sh/ruff/rules/invalid-get-logger-argument/
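
The before/after in one line each:
```
import logging

logger = logging.getLogger(__name__)    # correct: dotted module name
# logger = logging.getLogger(__file__)  # the bug: a file path as logger name
```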
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110008
Approved by: https://github.com/ezyang
2023-09-25 18:21:18 +00:00
cbcd9083be [DCP] Modify tensor saving logic in DCP (#106415)
Currently, DCP treats tensors as duplicates and only saves them on rank0. This won't work for PiPPy, as PiPPy does have unique tensors across different ranks. With the current setup, we would only be saving the tensors on rank0 (the coordinator rank).

In this PR, we change to letting each rank create its own WriteItem for tensors. For the ones that are replicated across different ranks, we handle them through `dedup_tensors()`, which dedups the replicated WriteItems so we only do the actual writing once.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106415
Approved by: https://github.com/wz337
2023-08-09 00:16:10 +00:00
1032a2541e Add option to disable rewriting index hints in default global save plan (#105861)
With distributed checkpointing in PyTorch/XLA SPMD, the WriteItem index hints should not be modified when creating the global plan. In order to reuse the default planner logic for checkpoint metadata creation, we need to make the index-hint rewriting behavior optional, as sketched below.
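
A minimal sketch of the resulting option (hedged: `all_plans` stands in for the per-rank local SavePlans gathered by the coordinator):
```
from torch.distributed.checkpoint.default_planner import (
    create_default_global_save_plan,
)

# all_plans: one local SavePlan per rank, gathered on the coordinator.
global_plans, metadata = create_default_global_save_plan(
    all_plans,
    rewrite_index_hints=False,  # leave WriteItem index hints untouched (XLA SPMD)
)
```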
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105861
Approved by: https://github.com/kumpera
2023-07-25 06:00:13 +00:00
5a458a9df4 Convert logging f-strings to use % format, part three (#98704)
This does triple-quoted strings.
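
The general shape of the transformation, as a sketch (placeholder values; the actual PR targets triple-quoted strings):
```
import logging

logger = logging.getLogger(__name__)
rank, step = 0, 100  # placeholder values

# Lazy %-formatting defers interpolation until the record is actually emitted:
logger.info("rank %s finished step %d", rank, step)
# ...instead of eagerly formatting with an f-string:
# logger.info(f"rank {rank} finished step {step}")
```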

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98704
Approved by: https://github.com/voznesenskym, https://github.com/albanD
2023-04-11 13:17:56 +00:00
b09722f540 Convert logging f-strings to use % format, part two (#98700)
This hits multi-line logging strings

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98700
Approved by: https://github.com/voznesenskym
2023-04-10 12:19:31 +00:00
38b687ed4d [PTD][Checkpoint] Add checkpointing support for DTensor submesh (#96802)
DTensor submesh support is added in https://github.com/pytorch/pytorch/pull/95458.

This PR adds support for DTensor submeshes by adding an extra check when creating the local save/load plan.
If a rank is not participating in the mesh, we simply skip creating a WriteItem/ReadItem for the local SavePlan/LoadPlan, as sketched below.

Updated the associated test as well.
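
A conceptual sketch of the check (hypothetical helper; the real logic lives inside DCP's default local-plan creation):
```
from torch.distributed.tensor import DTensor

def write_items_for(obj) -> list:
    """Hypothetical helper mirroring the submesh check added to local planning."""
    if isinstance(obj, DTensor) and obj.device_mesh.get_coordinate() is None:
        # This rank does not participate in the submesh: emit no WriteItem.
        return []
    ...  # normal WriteItem creation (elided)
```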

cc. @wanchaol, @kumpera
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96802
Approved by: https://github.com/wanchaol
2023-03-21 08:17:17 +00:00
417e7bc09f Revert "[PTD][Checkpoint] Add checkpointing support for DTensor submesh (#96802)"
This reverts commit cfa6b52e02eb61f71c0034d5b7e73e365420f35a.

Reverted https://github.com/pytorch/pytorch/pull/96802 on behalf of https://github.com/huydhn due to: This breaks distributed test cfa6b52e02. Probably a land race, as the PR signal was green
2023-03-17 01:04:43 +00:00
cfa6b52e02 [PTD][Checkpoint] Add checkpointing support for DTensor submesh (#96802)
DTensor submesh support is added in https://github.com/pytorch/pytorch/pull/95458.

This PR adds support for DTensor submeshes by adding an extra check when creating the local save/load plan.
If a rank is not participating in the mesh, we simply skip creating a WriteItem/ReadItem for the local SavePlan/LoadPlan.

Updated the associated test as well.

cc. @wanchaol, @kumpera
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96802
Approved by: https://github.com/wanchaol
2023-03-16 22:16:58 +00:00
bb347dc3c3 [PTD][DCP] Add 1D DTensor based DCP (#94868)
Add 1D DTensor based DCP along with its test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94868
Approved by: https://github.com/wanchaol, https://github.com/fegin
2023-02-16 23:38:04 +00:00
e16daa78a0 [PT-D][Checkpoint] Turn on all default planner flags (#92933)
Fixes #92823

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92933
Approved by: https://github.com/kumpera
2023-02-08 06:30:45 +00:00
dd4b46e010 [PT-D][Checkpoint] Rename init() (#92829)
Fixes [#90346](https://github.com/pytorch/pytorch/issues/90346)

Rename the `init()` method in the planner to `set_up_planner()` to avoid confusion between `__init__()` and `init()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92829
Approved by: https://github.com/kumpera
2023-01-24 00:12:21 +00:00
f7e1f3e8bb [PT-D][Checkpoint] Resolve issue #89501: Rename _nested_tensor.py to (#92705)
Fixes https://github.com/pytorch/pytorch/issues/90350.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92705
Approved by: https://github.com/kumpera
2023-01-23 21:45:11 +00:00
b8b7480065 [Checkpoint][2D][6/N] Add optimizer and update default_planner to core distributed (#90212)
This is the last PR for integrating 2D into core distributed.

This PR does the following:
1. Add optimizer.py: this adds the ability to load a state_dict in conjunction with FSDP sharded optimizer state.
2. Update default_planner.py to support 2D checkpoints.
3. Add test_fsdp_optim_state.py as a unit test for No. 1.
4. Fix a bug in torch/testing/_internal/distributed/checkpoint_utils.py.
5. Rename the filenames for the APIs that should be private. Will organize and clean up further in following PRs. #90328

Docstring and integration test will be added in the following PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90212
Approved by: https://github.com/wanchaol
2022-12-08 02:53:29 +00:00
aee96bbf5a [PT-D][Checkpointing] Move distributed checkpointing from torch.distributed._shard.checkpoint to torch.distributed.checkpoint (#88698)
Context in RFC: https://github.com/pytorch/pytorch/issues/86620

.rst file will be finalized in subsequent PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88698
Approved by: https://github.com/wanchaol
2022-11-16 21:06:38 +00:00