## Motivation
The default device for tensor.device both for sharded as well as non sharded is set to cuda by default. Hence while checking the FSDP UTs we see the following errors. This change updates the actual device type based on the created tensor.
```
[rank3] File "/root/repos/pytorch-training-tests/tests/pytorch/v2.4.0/distributed_hpu/fsdp/test_fsdp_dtensor_state_dict.py", line 143, in test_dtensor_sharded_tensor_state_dict_identical
[rank3] sharded_tensor_sd = ref_model.state_dict()
[rank3] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1944, in state_dict
[rank3] hook_result = hook(self, destination, prefix, local_metadata)
[rank3] File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank3] return func(*args, **kwargs)
[rank3] File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_state_dict_utils.py", line 752, in _post_state_dict_hook
[rank3] tensor.device,
[rank3] File "/usr/local/lib/python3.10/dist-packages/typing_extensions.py", line 2853, in wrapper
[rank3] return arg(*args, **kwargs)
[rank3] File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/api.py", line 1152, in __torch_function__
[rank3] return dispatch(st_instance, func)
[rank3] File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/api.py", line 1134, in dispatch
[rank3] return _SHARDED_OPS[func](types, args, kwargs, st._process_group)
[rank3] File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/op_registry_utils.py", line 33, in wrapper
[rank3] return wrapped_func(types, args, kwargs, process_group)
[rank3] File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py", line 52, in tensor_device
[rank3] dev = torch.device(torch.cuda.current_device())
[rank3] File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 878, in current_device
[rank3] _lazy_init()
[rank3] File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 305, in _lazy_init
[rank3] raise AssertionError("Torch not compiled with CUDA enabled")
[rank3] AssertionError: Torch not compiled with CUDA enabled
````
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134994
Approved by: https://github.com/fegin
reland of https://github.com/pytorch/pytorch/pull/133113
I have to create a new PR because the previous reverted PR could not either be rebased, or imported successfully :(
----
Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next PRs)
* To preserve the BC for users still using the torch.distributed._tensor, I added a shim script to redirect old path calls to the new module
The BC preserving is evidented by the fact that all DTensor tests are still working without changing the public imports. So it's safe to land the changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
Moving DTensor to be in the public namespace, to formally add the
documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next
PRs)
* To preserve the BC for users still using the `torch.distributed._tensor`,
I added a shim script to redirect old path calls to the new module
The BC preserving is evidented by the fact that all DTensor tests are still
working without changing the public imports. So it's safe to land the
changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
Previously, when we slice out a submesh from a mesh, we assign the mesh as the parent mesh of the submesh. In this case, when we have a 3D mesh topology, the parent mesh of a 1D mesh sliced out from the 3D mesh is different from the parent mesh of the same 1D mesh sliced out from the 2D submesh of the 3D mesh. For example:
```
mesh_3d = init_device_mesh("cuda", (2,2,2), ("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]
mesh_2d = mesh_2d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0_2"]
# This would evaluate to be True
print(_mesh_resources.get_parent_mesh(mesh_dim0) != _mesh_resources.get_parent_mesh(mesh_dim0))
```
We can always reconstruct the mesh needed from the mesh dim names, as long as two dims come from the same root. For simplicity, we do not see the necessity of building a tree structure to represent child-parent relationship. Therefore, we are replacing the parent mesh concept with a root mesh concept in `_MeshEnv` so we would have:
```
mesh_3d = init_device_mesh("cuda", (2,2,2), ("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]
mesh_2d = mesh_2d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0_2"]
# This would evaluate to be True
print(_mesh_resources.get_root_mesh(mesh_dim0) == _mesh_resources.get_root_mesh(mesh_dim0))
```
With this change, we will have two types of meshes in an environment.
1. `device_mesh != _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is created by slicing.
2. `device_mesh == _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is a root mesh not created through slicing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132339
Approved by: https://github.com/wanchaol
ghstack dependencies: #132310, #132311
Currently, when loading w/strict=False or w/strict=True and looking at
error message, FQNs are garbled w/FSDP details such as "_fsdp_wrapped_module".
This makes it tricky for upstream applications to validate the expected set of
keys are missing / unexpected (for example with PEFT where state_dict is loaded
non-strict), and makes error message more complicated w/FSDP details.
This PR cleans those prefixes by using `clean_tensor_name` in FSDP's existing
post load_state_dict hooks. Currently, only full_state_dict impl is tested, can test the rest of the impls as follow up work.
Differential Revision: [D54182472](https://our.internmc.facebook.com/intern/diff/D54182472/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120600
Approved by: https://github.com/XilunWu, https://github.com/fegin
Summary:
Rename _device_mesh.py to device_mesh.py, update all callsites, add documentation.
We created stubs for public class and methods in torch.distributed.device_mesh so that torch.distributed.device_mesh can be imported with or without distributed is available().
Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/115099
Prior to landing, CI signals are all passed. Shipit added the "ci/trunk" label to the PR and DID NOT wait for it and went ahead committing. More context can be found in the reverted PR above.
Test Plan: CI.
Differential Revision: D51861018
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115193
Approved by: https://github.com/fegin
Summary:
Rename _device_mesh.py to device_mesh.py, update all callsites, adds documentation.
Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/114991
It was failing because failing a public module binding tests in MacOS, and this is due to the change in import order for torch/distributed/fsdp/_common_utils.py. Since this original import would still work, we remove the changes in this file.
Test Plan: CI.
Differential Revision: D51825114
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115099
Approved by: https://github.com/wanchaol, https://github.com/fegin
For scenarios where FSDP is not the root module, the `_use_dtensor` flag would not be switched on. This PR fixes it by checking whether the submodule has the `device_mesh` and turn `_use_dtensor` flag on accordingly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113593
Approved by: https://github.com/fegin
Currently, when we have 2D composition, a global variable _extensions controls the 2D deviation we need to take in state_dict calls (See https://github.com/pytorch/pytorch/blob/release/2.1/torch/distributed/fsdp/_fsdp_extensions.py#L66-L68). This is problematic when we have both a 2D model and a plain FSDP model in the same dist environment, as the _extensions will be mistakenly turned on for the plain FSDP model, resulting in state_dict error (RuntimeError: No parent device_mesh is found for FSDP device_mesh.).
This PR binds _fsdp_extension to the FSDP instances to make sure that state_dict calls would not get interfered with each other when mixing both 2D and 1D parallelism.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113237
Approved by: https://github.com/fduwjj, https://github.com/fegin
This PR adds a all_gather_dtensor() method to fsdp/_fsdp_extensions.py and the actual implementation in tensor/parallel/fsdp.py. This enables FSDP to load 2D DTensor state_dict into model when calling `model.load_state_dict()`.
cc. @fegin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110925
Approved by: https://github.com/fegin
ghstack dependencies: #110831, #110846
This PR:
1) Add device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by user as an kwarg.
2) change use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with sharded model/optim state dict, _use_dtensor flag would be set to True and model/optim state dict would return dtensor state_dict. Otherwise, _use_dtensor flag would be set to False and model/optim state dict would return sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return 2D DTensor state_dict.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
This PR adds SimpleProfiler for FSDP state_dict/load_state_dict logging purpose. SimpleProfiler use class variables to record profiling results and it does everything in the Python which can be slow. So it is only suitable for logging slow actions such as initialization and state_dict/load_state_dict.
This PR uses SimpleProfiler to log some critical/slow paths of the model and optimizer state_dict/load_state_dict.
Differential Revision: [D48774406](https://our.internmc.facebook.com/intern/diff/D48774406/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108290
Approved by: https://github.com/wz337
This PR:
1) Add device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by user as an kwarg.
2) change use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with sharded model/optim state dict, _use_dtensor flag would be set to True and model/optim state dict would return dtensor state_dict. Otherwise, _use_dtensor flag would be set to False and model/optim state dict would return sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return 2D DTensor state_dict.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
Summary:
When loading a CPU state_dict with a pg initialized with
cpu:gloo,cuda:nccl, we hit a gloo crash since dest tensor is on GPU and input
is on CPU.
As a workaround, just enforce that if local_tensor.is_cpu, the dest tensor is
also cpu.
Test Plan: CI
Differential Revision: D48324752
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107172
Approved by: https://github.com/fegin
issue resolved: https://github.com/pytorch/pytorch/issues/97791
before this PR, mixed_precision applies to buffers from ignored modules. see ```test_state_dict_with_ignored_modules(mixed_precision=True)``` for reproduce
after, we avoid applying mixed_precision semantics to buffers from ignored modules
* step 1 initialization: state._ignored_buffer_names contains all the buffers from ignored modules
* step 2 lazy init at runtime: skip ignored buffers in ```_get_buffers_and_dtypes_for_computation```
* step 3 skip upcasting in state_dict hook: avoid upcasting for ignored buffers in ```_get_buffers_and_dtypes_for_computation```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106766
Approved by: https://github.com/awgu
This allows us use use_dtensor=True for ShardedStateDictConfig() before calling model.load_state_dict(). It only works for offload_to_cpu=False now.
Next PR will make use_dtensor=True work with offload_to_cpu=True for load_state_dict().
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104087
Approved by: https://github.com/fegin
This allows us use use_dtensor=True for ShardedStateDictConfig() before calling model.load_state_dict(). It only works for offload_to_cpu=False now.
Next PR will make use_dtensor=True work with offload_to_cpu=True for load_state_dict().
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104087
Approved by: https://github.com/fegin
This fixes https://github.com/pytorch/pytorch/issues/104148 (unfreezing parameters after `n` steps).
- This fixes a bug where we did not delete the post-backward hook state properly for the `requires_grad=False` case.
- This makes the `already_resharded` correct for `SHARD_GRAD_OP`.
- This generalizes `_clear_grads_if_needed()` to `_reset_flat_param_grad_info_if_needed()` to additionally include propagating the original parameters' `requires_grad` to the flat parameter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104186
Approved by: https://github.com/rohan-varma, https://github.com/fegin