Commit Graph

90 Commits

Author SHA1 Message Date
f903bc475c [BE] add noqa for flake8 rule B036: found except BaseException without re-raising (#159043)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159043
Approved by: https://github.com/Skylion007
2025-07-25 02:56:34 +00:00
4ccc0381de [BE][5/16] fix typos in torch/ (torch/distributed/) (#156315)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156315
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: #156313, #156314
2025-06-23 02:57:28 +00:00
145d4cdc11 Revert "[BE][5/16] fix typos in torch/ (torch/distributed/) (#156315)"
This reverts commit c2f0292bd5b4b3206f5b295e96f81cd6c178eb18.

Reverted https://github.com/pytorch/pytorch/pull/156315 on behalf of https://github.com/atalman due to export/test_torchbind.py::TestCompileTorchbind::test_compile_error_on_input_aliasing_contents_backend_aot_eager [GH job link](https://github.com/pytorch/pytorch/actions/runs/15804799771/job/44548489912) [HUD commit link](c95f7fa874) ([comment](https://github.com/pytorch/pytorch/pull/156313#issuecomment-2994171213))
2025-06-22 12:31:57 +00:00
c2f0292bd5 [BE][5/16] fix typos in torch/ (torch/distributed/) (#156315)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156315
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: #156313, #156314
2025-06-22 08:43:26 +00:00
4aded85e79 Fix space typo in warning message (#143473)
Warning shows up like this (no space between willbe):
```
/home/xxx/.local/lib/python3.11/site-packages/torch/distributed/fsdp/_state_dict_utils.py:827:
UserWarning: When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict willbe returned.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143473
Approved by: https://github.com/mikaylagawarecki, https://github.com/kwen2501
2025-03-31 07:38:02 +00:00
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
6344ca1dd4 [BE][Ez]: Apply FURB188: use str remove(pre|suf)fix (#146997)
Since we are on 3.9, we can use this nice str builtin which is more readable and more efficient.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146997
Approved by: https://github.com/XuehaiPan, https://github.com/cyyever, https://github.com/jansel
2025-02-14 03:38:07 +00:00
c64e657632 PEP585 update - torch/distributed/fsdp (#145162)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145162
Approved by: https://github.com/bobrenjc93
2025-01-19 20:04:05 +00:00
08be9ec312 Migrate from Tuple -> tuple in torch/distributed (#144258)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144258
Approved by: https://github.com/aorenste
2025-01-10 08:34:54 +00:00
cb71bcc542 Replace clone.detach with detach.clone (#140264)
Fixes #64532

As state in issue, replace `clone.detach` by `detach.clone`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140264
Approved by: https://github.com/soulitzer
2024-11-13 07:01:02 +00:00
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
e248c1d7eb Update real device in FSDP state_dict_utils (#134994)
## Motivation
The default device for tensor.device both for sharded as well as non sharded is set to cuda by default. Hence while checking the FSDP UTs we see the following errors. This change updates the actual device type based on the created tensor.

```
[rank3]   File "/root/repos/pytorch-training-tests/tests/pytorch/v2.4.0/distributed_hpu/fsdp/test_fsdp_dtensor_state_dict.py", line 143, in test_dtensor_sharded_tensor_state_dict_identical
[rank3]     sharded_tensor_sd = ref_model.state_dict()
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1944, in state_dict
[rank3]     hook_result = hook(self, destination, prefix, local_metadata)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank3]     return func(*args, **kwargs)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_state_dict_utils.py", line 752, in _post_state_dict_hook
[rank3]     tensor.device,
[rank3]   File "/usr/local/lib/python3.10/dist-packages/typing_extensions.py", line 2853, in wrapper
[rank3]     return arg(*args, **kwargs)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/api.py", line 1152, in __torch_function__
[rank3]     return dispatch(st_instance, func)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/api.py", line 1134, in dispatch
[rank3]     return _SHARDED_OPS[func](types, args, kwargs, st._process_group)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/op_registry_utils.py", line 33, in wrapper
[rank3]     return wrapped_func(types, args, kwargs, process_group)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py", line 52, in tensor_device
[rank3]     dev = torch.device(torch.cuda.current_device())
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 878, in current_device
[rank3]     _lazy_init()
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 305, in _lazy_init
[rank3]     raise AssertionError("Torch not compiled with CUDA enabled")
[rank3] AssertionError: Torch not compiled with CUDA enabled
````

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134994
Approved by: https://github.com/fegin
2024-09-17 04:39:08 +00:00
cfc227ad43 [reland][dtensor] move DTensor to public namespace (#134203)
reland of https://github.com/pytorch/pytorch/pull/133113

I have to create a new PR because the previous reverted PR could not either be rebased, or imported successfully :(

----

Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next PRs)
* To preserve the BC for users still using the torch.distributed._tensor, I added a shim script to redirect old path calls to the new module

The BC preserving is evidented by the fact that all DTensor tests are still working without changing the public imports. So it's safe to land the changes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
2024-09-08 17:08:40 +00:00
35f36363ec Revert "[dtensor] move DTensor to public namespace (#133113)"
This reverts commit 2ee6b97464d17fcf4c1fc67c29868fa30d0c16e1.

Reverted https://github.com/pytorch/pytorch/pull/133113 on behalf of https://github.com/wanchaol due to looks like it break some internal type imports ([comment](https://github.com/pytorch/pytorch/pull/133113#issuecomment-2295670911))
2024-08-19 05:00:19 +00:00
2ee6b97464 [dtensor] move DTensor to public namespace (#133113)
Moving DTensor to be in the public namespace, to formally add the
documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next
  PRs)
* To preserve the BC for users still using the `torch.distributed._tensor`,
  I added a shim script to redirect old path calls to the new module

The BC preserving is evidented by the fact that all DTensor tests are still
working without changing the public imports. So it's safe to land the
changes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
2024-08-17 05:09:52 +00:00
87053132ea [DeviceMesh] Remove parent mesh concept from _MeshEnv and replace by root mesh (#132339)
Previously, when we slice out a submesh from a mesh, we assign the mesh as the parent mesh of the submesh. In this case, when we have a 3D mesh topology, the parent mesh of a 1D mesh sliced out from the 3D mesh is different from the parent mesh of the same 1D mesh sliced out from the 2D submesh of the 3D mesh. For example:
```
mesh_3d = init_device_mesh("cuda", (2,2,2), ("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]

mesh_2d = mesh_2d["dim0", "dim1"]
mesh_dim0_2 =  mesh_2d["dim0_2"]

# This would evaluate to be True
print(_mesh_resources.get_parent_mesh(mesh_dim0) != _mesh_resources.get_parent_mesh(mesh_dim0))
```

We can always reconstruct the mesh needed from the mesh dim names, as long as two dims come from the same root. For simplicity, we do not see the necessity of building a tree structure to represent child-parent relationship. Therefore, we are replacing the parent mesh concept with a root mesh concept in `_MeshEnv` so we would have:

```
mesh_3d = init_device_mesh("cuda", (2,2,2), ("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]

mesh_2d = mesh_2d["dim0", "dim1"]
mesh_dim0_2 =  mesh_2d["dim0_2"]

# This would evaluate to be True
print(_mesh_resources.get_root_mesh(mesh_dim0) == _mesh_resources.get_root_mesh(mesh_dim0))
```
With this change, we will have two types of meshes in an environment.
1. `device_mesh != _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is created by slicing.
2. `device_mesh == _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is a root mesh not created through slicing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132339
Approved by: https://github.com/wanchaol
ghstack dependencies: #132310, #132311
2024-08-07 07:01:12 +00:00
4d7bf72d93 [BE][Easy] fix ruff rule needless-bool (SIM103) (#130206)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130206
Approved by: https://github.com/malfet
2024-07-14 08:17:52 +00:00
3b798df853 [BE][Easy] enable UFMT for torch/distributed/{fsdp,optim,rpc}/ (#128869)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128869
Approved by: https://github.com/fegin
ghstack dependencies: #128868
2024-06-18 21:49:08 +00:00
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
de8af28083 [FSDP][StateDict] Allow FULL_STATE_DICT option for 2D (#120837)
Fixes #120722

TL;DR for the issue:
As users are expected to use get_model_state_dict to do state_dict retrieval, I think it's fine to remove the warning and RuntimeError.
More context in #120722.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120837
Approved by: https://github.com/Skylion007
2024-03-05 10:03:44 +00:00
9db6a849ed [FSDP] Clean missing and unexpected keys (#120600)
Currently, when loading w/strict=False or w/strict=True and looking at
error message, FQNs are garbled w/FSDP details such as "_fsdp_wrapped_module".
This makes it tricky for upstream applications to validate the expected set of
keys are missing / unexpected (for example with PEFT where state_dict is loaded
non-strict), and makes error message more complicated w/FSDP details.

This PR cleans those prefixes by using `clean_tensor_name` in FSDP's existing
post load_state_dict hooks. Currently, only full_state_dict impl is tested, can test the rest of the impls as follow up work.

Differential Revision: [D54182472](https://our.internmc.facebook.com/intern/diff/D54182472/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120600
Approved by: https://github.com/XilunWu, https://github.com/fegin
2024-02-27 07:43:45 +00:00
72fec96e59 fix no shard state dict loading (#120367)
Summary: fix no shard state dict loading

Test Plan: CI tests

Differential Revision: D51058607

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120367
Approved by: https://github.com/fegin
2024-02-23 07:25:43 +00:00
23fa9621e4 [DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099) (#115193)
Summary:

Rename _device_mesh.py to device_mesh.py, update all callsites, add documentation.
We created stubs for public class and methods in torch.distributed.device_mesh so that torch.distributed.device_mesh can be imported with or without distributed is available().

Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/115099
Prior to landing, CI signals are all passed. Shipit added the "ci/trunk" label to the PR and DID NOT wait for it and went ahead committing. More context can be found in the reverted PR above.

Test Plan: CI.

Differential Revision: D51861018

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115193
Approved by: https://github.com/fegin
2023-12-08 08:44:32 +00:00
a827ac71f2 Revert "[DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099)"
This reverts commit eaa64339d640ed1d36520ada379213f8361be5ff.
2023-12-05 08:59:36 -08:00
eaa64339d6 [DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099)
Summary:
Rename _device_mesh.py to device_mesh.py, update all callsites, adds documentation.

Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/114991
It was failing because failing a public module binding tests in MacOS, and this is due to the change in import order for torch/distributed/fsdp/_common_utils.py. Since this original import would still work, we remove the changes in this file.

Test Plan: CI.

Differential Revision: D51825114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115099
Approved by: https://github.com/wanchaol, https://github.com/fegin
2023-12-05 05:44:52 +00:00
3a2e2044cd Revert "[DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#114710) (#114991)"
This reverts commit 729ac7317a50a6a195b324cf6cefd748bf4f5498.

Reverted https://github.com/pytorch/pytorch/pull/114991 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/114991#issuecomment-1837214567))
2023-12-02 17:55:51 +00:00
729ac7317a [DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#114710) (#114991)
Summary:

Same content of changes as https://github.com/pytorch/pytorch/pull/114710

Rename _device_mesh.py to device_mesh.py, update all callsites, adds documentation.
ghstack-source-id: 208980207
exported-using-ghexport

Test Plan: CI.

Reviewed By: wanchaol

Differential Revision: D51629761

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114991
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/fegin
2023-12-02 04:39:41 +00:00
4ba649e207 [FSDP][state_dict] Avoid assigning the root _device_mesh to the children _device_mesh (#114384)
Assigning the root _device_mesh to the children _device_mesh is not correct as each FSDP state can have a different DeviceMesh. We are also replacing fully_shard with a new implementation. So there is no need to worry about the fully_shard behavior.

Differential Revision: [D51507959](https://our.internmc.facebook.com/intern/diff/D51507959/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114384
Approved by: https://github.com/wz337
2023-11-30 02:08:31 +00:00
ca9e654353 [FSDP] Fix FSDP submodule with DeviceMesh does not return DTensor state_dict error (#113593)
For scenarios where FSDP is not the root module, the `_use_dtensor` flag would not be switched on. This PR fixes it by checking whether the submodule has the `device_mesh` and turn `_use_dtensor` flag on accordingly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113593
Approved by: https://github.com/fegin
2023-11-15 19:00:19 +00:00
31ded95cd5 [2D] Bind _fsdp_extension to FSDP instances (#113237)
Currently, when we have 2D composition, a global variable _extensions controls the 2D deviation we need to take in state_dict calls (See https://github.com/pytorch/pytorch/blob/release/2.1/torch/distributed/fsdp/_fsdp_extensions.py#L66-L68). This is problematic when we have both a 2D model and a plain FSDP model in the same dist environment, as the _extensions will be mistakenly turned on for the plain FSDP model, resulting in state_dict error (RuntimeError: No parent device_mesh is found for FSDP device_mesh.).

This PR binds _fsdp_extension to the FSDP instances to make sure that state_dict calls would not get interfered with each other when mixing both 2D and 1D parallelism.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113237
Approved by: https://github.com/fduwjj, https://github.com/fegin
2023-11-09 03:31:03 +00:00
12c1465d76 [DeviceMesh] Make mesh_resources private (#112294)
This is to prepare moving DeviceMesh as a standalone distributed package.

`_mesh_resources` should only be used in torch.distributed package.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112294
Approved by: https://github.com/fegin
2023-10-28 17:28:46 +00:00
2a86bcbac2 [FSDP][state_dict] Cleanup the usage of _get_pg_default_device (#112168)
_get_pg_default_device is not suitable for FSDP use case. We should always use the compute_device when communicating.

Differential Revision: [D50698730](https://our.internmc.facebook.com/intern/diff/D50698730/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112168
Approved by: https://github.com/wz337
2023-10-27 08:09:08 +00:00
80dfc974dd [2D] Enable 2D FSDP+TP model.load_state_dict() (#110925)
This PR adds a all_gather_dtensor() method to fsdp/_fsdp_extensions.py and the actual implementation in tensor/parallel/fsdp.py. This enables FSDP to load 2D DTensor state_dict into model when calling `model.load_state_dict()`.

cc. @fegin

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110925
Approved by: https://github.com/fegin
ghstack dependencies: #110831, #110846
2023-10-11 18:22:20 +00:00
6c136c3302 [2D] Enable 2D DTensor state_dict for FSDP + TP (#110846)
This PR adds a `chunk_dtensor()` method to fsdp/_fsdp_extensions.py and the actual implementation of `chunk_dtensor()` in tensor/parallel/fsdp.py. This enables FSDP to return 2D DTensor state_dict when composing FSDP with TP.

cc. @fegin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110846
Approved by: https://github.com/fegin, https://github.com/wanchaol
ghstack dependencies: #110831
2023-10-11 17:40:39 +00:00
b89ce814c0 [FSDP] Remove _set_use_dtensor in post_load_state_dict_hook (#109924)
This is a follow up for https://github.com/pytorch/pytorch/pull/109767.
We only need _set_use_dtensor in pre_state_dict_hook() and pre_load_state_dict_hook() and we do not need _set_use_dtensor in _post_load_state_dict_hook(). This PR removes _set_use_dtensor in post_load_state_dict_hook().

In addition, this PR adjusts the test cases in test_hsdp_dtensor_state_dict.py to capture changes in https://github.com/pytorch/pytorch/pull/109767

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109924
Approved by: https://github.com/fegin
2023-09-23 22:34:36 +00:00
a5145364d9 [FSDP] Fix _use_dtensor not automatically turn on for model state dict when using DeviceMesh (#109767)
Fixes #109648

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109767
Approved by: https://github.com/fegin
2023-09-21 15:15:45 +00:00
f558e86fa0 [FSDP] continue if param not exist in sharded load (#109116)
If I add a param and then wrap with FSDP + load state dict, when
strict=False don't hard error here.

Differential Revision: [D49170812](https://our.internmc.facebook.com/intern/diff/D49170812/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109116
Approved by: https://github.com/fegin
2023-09-14 07:14:18 +00:00
66af4f6ec7 [HSDP] Add device_mesh to FSDP kwarg and add dtensor state_dict support for HSDP (#107533)
This PR:
1) Add device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by user as an kwarg.
2) change use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with sharded model/optim state dict, _use_dtensor flag would be set to True and model/optim state dict would return dtensor state_dict. Otherwise, _use_dtensor flag would be set to False and model/optim state dict would return sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return 2D DTensor state_dict.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
2023-09-05 21:21:21 +00:00
591cb776af [FSDP][state_dict][optim_state_dict] Log slow optim and model state_dict paths (#108290)
This PR adds SimpleProfiler for FSDP state_dict/load_state_dict logging purpose. SimpleProfiler use class variables to record profiling results and it does everything in the Python which can be slow. So it is only suitable for logging slow actions such as initialization and state_dict/load_state_dict.

This PR uses SimpleProfiler to log some critical/slow paths of the model and optimizer state_dict/load_state_dict.

Differential Revision: [D48774406](https://our.internmc.facebook.com/intern/diff/D48774406/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108290
Approved by: https://github.com/wz337
2023-09-01 06:57:59 +00:00
ab5b4c4419 Revert "[HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)"
This reverts commit cc220e45a80d7c01a4a58b0f386ca07236a6927a.

Reverted https://github.com/pytorch/pytorch/pull/107533 on behalf of https://github.com/huydhn due to Sorry for reverting this, but it is failing in trunk with the same failure on test_dynamo_distributed cc220e45a8 ([comment](https://github.com/pytorch/pytorch/pull/107533#issuecomment-1701983247))
2023-09-01 01:26:30 +00:00
cc220e45a8 [HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)
This PR:
1) Add device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by user as an kwarg.
2) change use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with sharded model/optim state dict, _use_dtensor flag would be set to True and model/optim state dict would return dtensor state_dict. Otherwise, _use_dtensor flag would be set to False and model/optim state dict would return sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return 2D DTensor state_dict.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
2023-09-01 00:15:00 +00:00
ddf36c82b8 [PT-D][FSDP] Handle corner case of load with multi-backend PG (#107172)
Summary:
When loading a CPU state_dict with a pg initialized with
cpu:gloo,cuda:nccl, we hit a gloo crash since dest tensor is on GPU and input
is on CPU.

As a workaround, just enforce that if local_tensor.is_cpu, the dest tensor is
also cpu.

Test Plan: CI

Differential Revision: D48324752

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107172
Approved by: https://github.com/fegin
2023-08-14 23:24:44 +00:00
4bc846c101 [FSDP] Ignore buffer type casting in ignored modules (#106766)
issue resolved: https://github.com/pytorch/pytorch/issues/97791

before this PR, mixed_precision applies to buffers from ignored modules. see ```test_state_dict_with_ignored_modules(mixed_precision=True)``` for reproduce

after, we avoid applying mixed_precision semantics to buffers from ignored modules
* step 1 initialization: state._ignored_buffer_names contains all the buffers from ignored modules
* step 2 lazy init at runtime: skip ignored buffers in ```_get_buffers_and_dtypes_for_computation```
* step 3 skip upcasting in state_dict hook: avoid upcasting for ignored buffers in ```_get_buffers_and_dtypes_for_computation```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106766
Approved by: https://github.com/awgu
2023-08-09 23:09:43 +00:00
a832967627 Migrate tuple(handle) -> handle (#104488)
We strengthen the invariant that one FSDP managed module has one flatparameter, and remove unused code that would have supported 1:many module to flatparam mapping

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104488
Approved by: https://github.com/awgu
2023-07-19 22:33:35 +00:00
c54f630201 [7/n][FSDP] make use_dtensor=True work with offload_to_cpu=True for load_state_dict (#105378)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105378
Approved by: https://github.com/fegin
2023-07-19 21:36:37 +00:00
434fcffa21 [6/n][FSDP] Update _sharded_pre_load_state_dict_hook to use DTensor when use_dtensor=True in ShardedStateDictConfig (#104087)
This allows us use use_dtensor=True for ShardedStateDictConfig() before calling model.load_state_dict(). It only works for offload_to_cpu=False now.

Next PR will make use_dtensor=True work with offload_to_cpu=True for load_state_dict().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104087
Approved by: https://github.com/fegin
2023-07-06 05:36:19 +00:00
fcb53c1394 Revert "[6/n][FSDP] Update _sharded_pre_load_state_dict_hook to use DTensor when use_dtensor=True in ShardedStateDictConfig (#104087)"
This reverts commit 49af83cf442ef569c8eb4f5a29f46a65abc0e2d2.

Reverted https://github.com/pytorch/pytorch/pull/104087 on behalf of https://github.com/huydhn due to This is failing in trunk 49af83cf44, probably due to a land race ([comment](https://github.com/pytorch/pytorch/pull/104087#issuecomment-1615608189))
2023-07-01 07:50:31 +00:00
49af83cf44 [6/n][FSDP] Update _sharded_pre_load_state_dict_hook to use DTensor when use_dtensor=True in ShardedStateDictConfig (#104087)
This allows us use use_dtensor=True for ShardedStateDictConfig() before calling model.load_state_dict(). It only works for offload_to_cpu=False now.

Next PR will make use_dtensor=True work with offload_to_cpu=True for load_state_dict().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104087
Approved by: https://github.com/fegin
2023-07-01 01:02:59 +00:00
9db8ad7f1d [FSDP] Support unfreezing params for reshard-only hook (#104186)
This fixes https://github.com/pytorch/pytorch/issues/104148 (unfreezing parameters after `n` steps).

- This fixes a bug where we did not delete the post-backward hook state properly for the `requires_grad=False` case.
- This makes the `already_resharded` correct for `SHARD_GRAD_OP`.
- This generalizes `_clear_grads_if_needed()` to `_reset_flat_param_grad_info_if_needed()` to additionally include propagating the original parameters' `requires_grad` to the flat parameter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104186
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-06-28 11:04:57 +00:00
1c33c398c7 [FSDP][state_dict] Add a summary log when finishing state_dict (#103784)
Add a summary log when finishing state_dict

Differential Revision: [D46807103](https://our.internmc.facebook.com/intern/diff/D46807103/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103784
Approved by: https://github.com/fduwjj
2023-06-22 16:29:24 +00:00