There are still some differences between CUDA and non-CUDA custom devices when
construct FSDP because CUDA is selected as the default device. For example,
when construct FSDP from CPU model and device_id is not passed, device_handle
will choose CUDA as default device. This PR will autoselect the real device
as the default device.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127609
Approved by: https://github.com/awgu
Previously, when we slice out a submesh from a mesh, we assign the mesh as the parent mesh of the submesh. In this case, when we have a 3D mesh topology, the parent mesh of a 1D mesh sliced out from the 3D mesh is different from the parent mesh of the same 1D mesh sliced out from the 2D submesh of the 3D mesh. For example:
```
mesh_3d = init_device_mesh("cuda", (2,2,2), ("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]
mesh_2d = mesh_2d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0_2"]
# This would evaluate to be True
print(_mesh_resources.get_parent_mesh(mesh_dim0) != _mesh_resources.get_parent_mesh(mesh_dim0))
```
We can always reconstruct the mesh needed from the mesh dim names, as long as two dims come from the same root. For simplicity, we do not see the necessity of building a tree structure to represent child-parent relationship. Therefore, we are replacing the parent mesh concept with a root mesh concept in `_MeshEnv` so we would have:
```
mesh_3d = init_device_mesh("cuda", (2,2,2), ("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]
mesh_2d = mesh_2d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0_2"]
# This would evaluate to be True
print(_mesh_resources.get_root_mesh(mesh_dim0) == _mesh_resources.get_root_mesh(mesh_dim0))
```
With this change, we will have two types of meshes in an environment.
1. `device_mesh != _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is created by slicing.
2. `device_mesh == _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is a root mesh not created through slicing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132339
Approved by: https://github.com/wanchaol
ghstack dependencies: #132310, #132311
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
Resolves#126888
- #126888
This PR is split from PR #126898.
- #126898
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves#126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
Some toy example:
<img width="998" alt="Screenshot 2024-04-17 at 2 00 05 PM" src="https://github.com/pytorch/pytorch/assets/31054793/b5665a63-beb0-4ca1-92c6-c57a052812fd">
We define `FullyShardedDataParallel._unshard(async_op: bool = False)` that can be used to prefetch all-gathers. The user should make sure:
1. Run lazy init before the first `_unshard` call of training. For example, this can hackily be done via `root_module.check_is_root()` on the root FSDP module `root_module`.
2. Call `root_module._wait_unshard_streams_on_current_stream()` before the first `_unshard` call of the current iteration (just need to call it once after last optimizer step and before first `_unshard` of this iteration).
Differential Revision: [D56262876](https://our.internmc.facebook.com/intern/diff/D56262876)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124304
Approved by: https://github.com/wanchaol
Summary:
This would otherwise yield
> ValueError: ('Manual wrapping with ShardingStrategy.HYBRID_SHARD', 'requires explicit specification of process group or device_mesh.')
which is odd.
Remove the extra tailing commas.
Test Plan: CI
Differential Revision: D55549851
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123019
Approved by: https://github.com/Skylion007
Removes raising error if a device_mesh has a parent.
The comment says that HSDP + TP is not supported, but I'm able to do 2D parallelism + HSDP fine. The only issues are:
- this check
- https://github.com/pytorch/pytorch/pull/118618
- a series of PRs related to checkpointing with 3D meshes that I will open
We currently monkeypatch for the above which I am slowly upstreaming.
I imagine torch will have a better, native integration eventually, but this check seems too aggressive in the meantime given DTensor now lets users do some things themselves (which is amazing 🎉)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118620
Approved by: https://github.com/Skylion007
Removes raising error if a device_mesh has a parent.
The comment says that HSDP + TP is not supported, but I'm able to do 2D parallelism + HSDP fine. The only issues are:
- this check
- https://github.com/pytorch/pytorch/pull/118618
- a series of PRs related to checkpointing with 3D meshes that I will open
We currently monkeypatch for the above which I am slowly upstreaming.
I imagine torch will have a better, native integration eventually, but this check seems too aggressive in the meantime given DTensor now lets users do some things themselves (which is amazing 🎉)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118620
Approved by: https://github.com/wz337, https://github.com/wanchaol
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
reland of https://github.com/pytorch/pytorch/pull/116559, which was reverted by internal.
The underlying reason for the revert is that the torch.dynamo.disable can't be used by the
pytorch codebase, as it's conflicting with some torch.deploy together, although the later one
only run some inference, but it somehow take that weird dependency on fsdp..
We have seen this issue with our functional collectives that we can't
use any dynamo components otherwise torch.deploy would complain..
verified internally that after removing torch.dynamo.disable the test
passed again
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117020
Approved by: https://github.com/awgu
Context: Existing FSDPExtension have some bug in the case when the
unflatten tensor involves some compute/communications in cuda stream,
the current logic of FSDPExtension unflatten tensor happens in the
unshard stream, which makes runtime lost sync with the compute stream,
and if there're some dependencies between the compute stream and the
unflatten tensor logic, currently it would lose sync point, which could
possibly lead to NaN.
This PR make the FSDPExtension to record the compute stream and let
DTensorExtension to directly use the compute stream for unflatten_tensor.
In long term we might want to directly make the FSDP runtime logic to only
make the unshard happen in unshard stream, and use unshard views to
happen in the compute stream. We currently fix this in the Extension
directly as this is the simplest thing to do without affecting FSDP
runtime logic
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116559
Approved by: https://github.com/awgu, https://github.com/fduwjj, https://github.com/yifuwang
ghstack dependencies: #116426
Summary:
Rename _device_mesh.py to device_mesh.py, update all callsites, add documentation.
We created stubs for public class and methods in torch.distributed.device_mesh so that torch.distributed.device_mesh can be imported with or without distributed is available().
Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/115099
Prior to landing, CI signals are all passed. Shipit added the "ci/trunk" label to the PR and DID NOT wait for it and went ahead committing. More context can be found in the reverted PR above.
Test Plan: CI.
Differential Revision: D51861018
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115193
Approved by: https://github.com/fegin
Summary:
Rename _device_mesh.py to device_mesh.py, update all callsites, adds documentation.
Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/114991
It was failing because failing a public module binding tests in MacOS, and this is due to the change in import order for torch/distributed/fsdp/_common_utils.py. Since this original import would still work, we remove the changes in this file.
Test Plan: CI.
Differential Revision: D51825114
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115099
Approved by: https://github.com/wanchaol, https://github.com/fegin
Applies PLW0108 which removes useless lambda calls in Python, the rule is in preview so it is not ready to be enabled by default just yet. These are the autofixes from the rule.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113602
Approved by: https://github.com/albanD
Currently, when we have 2D composition, a global variable _extensions controls the 2D deviation we need to take in state_dict calls (See https://github.com/pytorch/pytorch/blob/release/2.1/torch/distributed/fsdp/_fsdp_extensions.py#L66-L68). This is problematic when we have both a 2D model and a plain FSDP model in the same dist environment, as the _extensions will be mistakenly turned on for the plain FSDP model, resulting in state_dict error (RuntimeError: No parent device_mesh is found for FSDP device_mesh.).
This PR binds _fsdp_extension to the FSDP instances to make sure that state_dict calls would not get interfered with each other when mixing both 2D and 1D parallelism.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113237
Approved by: https://github.com/fduwjj, https://github.com/fegin
Replacing https://github.com/pytorch/pytorch/pull/109553 as it gets reverted.
This PR enables training with new 2D flow and adds associated test. In addition, this PR moves the tensor/parallel/_data_parallel_utils.py that are fsdp specific back to tensor/parallel/fsdp.py to avoid circular dependency for ddp.py and test/distributed/tensor/parallel/test_ddp_2d_parallel.py.
state_dict related changes would be in later PRs.
cc. @fegin, @fduwjj, @wanchaol, @awgu
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110034
Approved by: https://github.com/fduwjj