Commit Graph

47 Commits

Author SHA1 Message Date
d2494cbb2b Revert "[distributed] Replace assert statements with AssertionError exceptions (#165216)"
This reverts commit 74db92b21868b7e9e77cc966e5d57a8246723cbd.

Reverted https://github.com/pytorch/pytorch/pull/165216 on behalf of https://github.com/clee2000 due to I think this broke distributed/test_pg_wrapper.py::ProcessGroupNCCLWrapperTest::test_debug_level_detail_no_gloo [GH job link](https://github.com/pytorch/pytorch/actions/runs/18492765290/job/52693842750) [HUD commit link](74db92b218), note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/165216#issuecomment-3402838765))
2025-10-14 17:05:16 +00:00
74db92b218 [distributed] Replace assert statements with AssertionError exceptions (#165216)
Replaces 71 assert statements across 11 files in `torch.distributed` with explicit if-checks that raise `AssertionError`, so the checks cannot be disabled with the Python `-O` flag.
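
For illustration, a minimal sketch of the pattern (hypothetical check, not taken from this PR's diff):

```python
def _validate_group(group):
    # Before (skipped entirely under `python -O`):
    #     assert group is not None, "process group must be initialized"
    # After (always enforced, regardless of optimization flags):
    if group is None:
        raise AssertionError("process group must be initialized")
```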

Fixes #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165216
Approved by: https://github.com/albanD
2025-10-14 09:58:59 +00:00
7457d139c5 Add pyrefly suppressions to torch/distributed (7/n) (#165002)
Adds suppressions so that pyrefly will typecheck cleanly: https://github.com/pytorch/pytorch/issues/163283

One more PR after this one.

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete the relevant lines from the `project-excludes` field in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)
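
Step 3 silences the remaining errors with inline suppression comments; a hedged sketch on a hypothetical function (the exact comment syntax is assumed, see the pyrefly docs):

```python
from typing import Any

def _collect_state(obj: Any) -> None:
    obj.nonexistent_attribute()  # pyrefly: ignore
```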

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165002
Approved by: https://github.com/oulgen
2025-10-09 04:08:25 +00:00
da003d7b95 [3/N] Import Callable from collections.abc in torch/distributed (#164104)
This is the result of applying the ruff `UP035` check.
`Callable` is imported from `collections.abc` instead of `typing`.
This PR is a follow-up to #164054.
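
A minimal sketch of the UP035 change (hypothetical function, for illustration):

```python
# Before: from typing import Callable
from collections.abc import Callable

def apply_hook(hook: Callable[[int], int], value: int) -> int:
    return hook(value)
```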

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164104
Approved by: https://github.com/Skylion007
2025-09-30 00:28:53 +00:00
ad81eeb7c7 Refactor to use torch.accelerator.device_index instead of torch.cuda.device for generic device context manager (#148880)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148880
Approved by: https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: #148864
2025-04-25 09:45:25 +00:00
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
00ffeca1b1 PEP585 update - torch/distributed (#145164)
See #145101 for details.
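
A hedged sketch of the PEP 585 migration (hypothetical signature): builtin generics replace the deprecated `typing` aliases.

```python
# Before: from typing import Dict, List
#         def bucket_sizes(sizes: List[int]) -> Dict[str, int]: ...
# After:
def bucket_sizes(sizes: list[int]) -> dict[str, int]:
    return {"total": sum(sizes)}
```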

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-21 04:23:29 +00:00
6374332d33 Revert "PEP585 update - torch/distributed (#145164)"
This reverts commit 6cb186e279bc179a6bb63f0226e24ab42a07b394.

Reverted https://github.com/pytorch/pytorch/pull/145164 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing an inductor test ([comment](https://github.com/pytorch/pytorch/pull/145164#issuecomment-2602875679))
2025-01-20 16:46:46 +00:00
6cb186e279 PEP585 update - torch/distributed (#145164)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-20 00:19:01 +00:00
08be9ec312 Migrate from Tuple -> tuple in torch/distributed (#144258)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144258
Approved by: https://github.com/aorenste
2025-01-10 08:34:54 +00:00
361f42bc42 Revert "[compiled autograd] Compiled autograd configs in TLS (#137821)"
This reverts commit 9aba0b91c8df4a15654f9ccc02abca31bdd81650.

Reverted https://github.com/pytorch/pytorch/pull/137821 on behalf of https://github.com/wdvr due to Reverting this for now, it is failing test_public_bindings in trunk ([comment](https://github.com/pytorch/pytorch/pull/137821#issuecomment-2417351788))
2024-10-16 16:38:29 +00:00
9aba0b91c8 [compiled autograd] Compiled autograd configs in TLS (#137821)
Multithreading doesn't work yet; this adds Python-side TLS only for the Python-side state.
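
A minimal sketch of Python-side thread-local state (illustrative only; the class name is hypothetical, not the compiled-autograd code):

```python
import threading

class _CompiledAutogradTLS(threading.local):
    def __init__(self) -> None:
        # each thread sees its own copy of these attributes
        self.enabled = False
        self.verbose = False

_tls = _CompiledAutogradTLS()
_tls.enabled = True  # affects only the current thread
```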

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137821
Approved by: https://github.com/jansel, https://github.com/yf225
ghstack dependencies: #137953
2024-10-16 09:28:32 +00:00
3406ac24d9 [BE] fix circular import in torch/distributed/utils.py (#136286)
**Summary**
Fix a circular import in `torch/distributed/utils.py` found when running an internal test; see D62901023. Curious why this wasn't causing any issues before; is the relevant code deprecated and no longer used?
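
One common way to break such a cycle (a hedged illustration, not necessarily the fix applied here) is to defer the import to call time:

```python
def _lazily_imported_utils():
    # Deferred import: resolved when the function is called rather than when the
    # module is first imported, so the two modules no longer import each other eagerly.
    from torch.distributed import utils as dist_utils
    return dist_utils
```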

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136286
Approved by: https://github.com/Skylion007
2024-09-22 20:54:12 +00:00
724faac260 [FSDP] casting input args with dataclass(frozen=True) (#135067)
resolve: https://github.com/pytorch/pytorch/pull/135029

When mixed precision is enabled, FSDP casts input args to the desired dtype by calling `_apply_to_tensors`. When an input arg is a `dataclass(frozen=True)`, we hit the following runtime error because `_apply_to_tensors` uses `setattr`:

`dataclasses.FrozenInstanceError: cannot assign to field 'some_key'`. The fix is to use the dataclasses API `dataclasses.replace`.
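
A minimal sketch of the failure and the fix (hypothetical dataclass, for illustration):

```python
import dataclasses
import torch

@dataclasses.dataclass(frozen=True)
class Inputs:
    some_key: torch.Tensor

x = Inputs(some_key=torch.ones(2, dtype=torch.float32))
# setattr(x, "some_key", x.some_key.half())             # dataclasses.FrozenInstanceError
x = dataclasses.replace(x, some_key=x.some_key.half())  # returns a new frozen instance
```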

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135067
Approved by: https://github.com/awgu
2024-09-05 01:19:53 +00:00
ff7e021e94 [Reland][PT-D] Relaxed contract to allow Sequence[nn.Module] (#127773) (#130947)
This PR relaxes `@contract` to allow the 1st argument to be `Sequence[nn.Module]` instead of strictly `nn.Module`. This is required for the next PR, which allows `fully_shard` to take in `List[nn.Module]`.

---

**Changes for reland:**
- The previous PR assumed that any `func` decorated with `@contract` would return the same input `module` as output (which is true for PT-D composable APIs).
- However, TorchRec `shard` returns a different module as output (though that module _does_ satisfy the `@contract` FQN check).
- This PR removes the assumption and instead only enforces the FQN check following the input module order. In other words, if calling `func([x1, ..., xN])` for `N` modules `x1, ..., xN` that returns `[y1, ..., yM]` for `M` modules, we require that `N = M` and that FQNs are preserved coordinate-wise: `xi` and `yi` have same FQNs for all `i = 1, ..., N`.

Differential Revision: [D59863438](https://our.internmc.facebook.com/intern/diff/D59863438)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130947
Approved by: https://github.com/weifengpy, https://github.com/atalman
2024-07-17 22:40:13 +00:00
d7a8e8f7c5 Revert "[PT-D] Relaxed contract to allow Sequence[nn.Module] (#127773)"
This reverts commit b27695791e9cc4eedb1b713b1be20398bfeb911b.

Reverted https://github.com/pytorch/pytorch/pull/127773 on behalf of https://github.com/atalman due to failing internally ([comment](https://github.com/pytorch/pytorch/pull/127773#issuecomment-2232004006))
2024-07-16 23:48:09 +00:00
b27695791e [PT-D] Relaxed contract to allow Sequence[nn.Module] (#127773)
This PR relaxes `@contract` to allow the 1st argument to be `Sequence[nn.Module]` instead of strictly `nn.Module`. This is required for the next PR, which allows `fully_shard` to take in `List[nn.Module]`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127773
Approved by: https://github.com/weifengpy
2024-07-15 23:54:10 +00:00
94dc3253a0 [BE][Easy] enable UFMT for torch/distributed/ (#128870)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870
Approved by: https://github.com/fegin, https://github.com/wconstab
2024-06-22 18:53:28 +00:00
9c929f6ce9 Revert "[BE][Easy] enable UFMT for torch/distributed/ (#128870)"
This reverts commit a0e1e20c4157bb3e537fc784a51d7aef1e754157.

Reverted https://github.com/pytorch/pytorch/pull/128870 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/128870#issuecomment-2181780356))
2024-06-21 00:38:28 +00:00
a0e1e20c41 [BE][Easy] enable UFMT for torch/distributed/ (#128870)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870
Approved by: https://github.com/fegin
ghstack dependencies: #128868, #128869
2024-06-18 21:49:08 +00:00
979edbbe12 [Traceable FSDP2] Dynamo support FSDP2 use_training_state context manager (#127854)
Improve Dynamo to support the FSDP2 `use_training_state()` context manager.

Test command:
`
pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_dynamo_trace_use_training_state
`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127854
Approved by: https://github.com/yanboliang
2024-06-16 08:48:52 +00:00
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
2122c9e2a9 [BE] Enabled lintrunner on torch/distributed/utils.py (#127771)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127771
Approved by: https://github.com/wanchaol, https://github.com/Skylion007
2024-06-04 06:10:33 +00:00
77d5f60740 [fsdp][torch.compile] FSDP changes (#115497)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115497
Approved by: https://github.com/albanD
2023-12-19 18:44:36 +00:00
a8097ed479 Fix docstring errors in _composable_state.py, remote_device.py, value_ranges.py, utils.py, run.py, rendezvous.py, launch.py, argparse_util.py, __init__.py, _cycles.py (#112953)
Fixes #112639

```txt
 torch/utils/_sympy/value_ranges.py
 torch/utils/_sympy/value_ranges.py:60 in public class `ValueRanges`:
        D101: Missing docstring in public class
torch/utils/_sympy/value_ranges.py:68 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/_sympy/value_ranges.py:81 in public method `__contains__`:
        D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:86 in public method `tighten`:
        D400: First line should end with a period (not 'n')
torch/utils/_sympy/value_ranges.py:90 in public method `__and__`:
        D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:103 in public method `__or__`:
        D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:113 in public method `is_singleton`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:118 in public method `unknown`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:122 in public method `wrap`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:129 in public method `increasing_map`:
        D400: First line should end with a period (not ')')
torch/utils/_sympy/value_ranges.py:135 in public method `decreasing_map`:
        D400: First line should end with a period (not ')')
torch/utils/_sympy/value_ranges.py:141 in public method `monotone_map`:
        D400: First line should end with a period (not 'g')
torch/utils/_sympy/value_ranges.py:149 in public method `convex_min_zero_map`:
        D400: First line should end with a period (not '0')
torch/utils/_sympy/value_ranges.py:149 in public method `convex_min_zero_map`:
        D403: First word of the first line should be properly capitalized ('Fn', not 'fn')
torch/utils/_sympy/value_ranges.py:158 in public method `coordinatewise_increasing_map`:
        D205: 1 blank line required between summary line and description (found 0)
torch/utils/_sympy/value_ranges.py:158 in public method `coordinatewise_increasing_map`:
        D400: First line should end with a period (not ':')
torch/utils/_sympy/value_ranges.py:171 in public method `coordinatewise_monotone_map`:
        D400: First line should end with a period (not 'e')
torch/utils/_sympy/value_ranges.py:180 in private class `SymPyValueRangeAnalysis`:
        D205: 1 blank line required between summary line and description (found 0)
torch/utils/_sympy/value_ranges.py:180 in private class `SymPyValueRangeAnalysis`:
        D400: First line should end with a period (not 's')
torch/utils/_sympy/value_ranges.py:386 in private method `reciprocal`:
        D210: No whitespaces allowed surrounding docstring text
torch/utils/_sympy/value_ranges.py:386 in private method `reciprocal`:
        D400: First line should end with a period (not 'n')
torch/utils/_sympy/value_ranges.py:488 in public class `ValueRangeAnalysis`:
        D101: Missing docstring in public class
torch/utils/_sympy/value_ranges.py:489 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/_sympy/value_ranges.py:501 in public method `bool_handler`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:506 in public method `default_handler`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:511 in public method `load`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:514 in public method `store`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:517 in public method `reduction`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:520 in public method `index_expr`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:525 in public method `to_dtype`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:558 in public method `square`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:562 in public method `neg`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:566 in public method `truncdiv`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:577 in public method `sub`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:580 in public method `__getattr__`:
        D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:585 in public function `bound_sympy`:
        D103: Missing docstring in public function
36
torch/utils/_sympy/value_ranges.py:60 in public class `ValueRanges`:
        D101: Missing docstring in public class
torch/utils/_sympy/value_ranges.py:68 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/_sympy/value_ranges.py:81 in public method `__contains__`:
        D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:86 in public method `tighten`:
        D400: First line should end with a period (not 'n')
torch/utils/_sympy/value_ranges.py:90 in public method `__and__`:
        D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:103 in public method `__or__`:
        D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:113 in public method `is_singleton`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:118 in public method `unknown`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:122 in public method `wrap`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:182 in private class `SymPyValueRangeAnalysis`:
        D205: 1 blank line required between summary line and description (found 0)
torch/utils/_sympy/value_ranges.py:182 in private class `SymPyValueRangeAnalysis`:
        D400: First line should end with a period (not 's')
torch/utils/_sympy/value_ranges.py:388 in private method `reciprocal`:
        D210: No whitespaces allowed surrounding docstring text
torch/utils/_sympy/value_ranges.py:388 in private method `reciprocal`:
        D400: First line should end with a period (not 'n')
torch/utils/_sympy/value_ranges.py:490 in public class `ValueRangeAnalysis`:
        D101: Missing docstring in public class
torch/utils/_sympy/value_ranges.py:491 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/_sympy/value_ranges.py:503 in public method `bool_handler`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:508 in public method `default_handler`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:513 in public method `load`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:516 in public method `store`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:519 in public method `reduction`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:522 in public method `index_expr`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:527 in public method `to_dtype`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:560 in public method `square`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:564 in public method `neg`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:568 in public method `truncdiv`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:579 in public method `sub`:
        D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:582 in public method `__getattr__`:
        D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:587 in public function `bound_sympy`:
        D103: Missing docstring in public function
28

torch/utils/viz/_cycles.py
torch/utils/viz/_cycles.py:14 in public function `observe_garbage`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:207 in public function `object_annotation`:
        D205: 1 blank line required between summary line and description (found 0)
torch/utils/viz/_cycles.py:207 in public function `object_annotation`:
        D400: First line should end with a period (not 'g')
torch/utils/viz/_cycles.py:256 in public class `Node`:
        D101: Missing docstring in public class
torch/utils/viz/_cycles.py:262 in public function `create_graph`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:308 in public function `escape`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:312 in public function `is_cuda_tensor`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:315 in public function `cuda_allocation_context`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:335 in public function `to_dot`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:406 in public function `to_html`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:416 in public function `observe_tensor_cycles`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:429 in public function `warn_tensor_cycles`:
        D205: 1 blank line required between summary line and description (found 0)
torch/utils/viz/_cycles.py:429 in public function `warn_tensor_cycles`:
        D400: First line should end with a period (not 'p')
torch/utils/viz/_cycles.py:429 in public function `warn_tensor_cycles`:
        D401: First line should be in imperative mood; try rephrasing (found 'Reference')
14
torch/utils/viz/_cycles.py:14 in public function `observe_garbage`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:256 in public class `Node`:
        D101: Missing docstring in public class
torch/utils/viz/_cycles.py:262 in public function `create_graph`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:308 in public function `escape`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:312 in public function `is_cuda_tensor`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:315 in public function `cuda_allocation_context`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:335 in public function `to_dot`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:406 in public function `to_html`:
        D103: Missing docstring in public function
torch/utils/viz/_cycles.py:416 in public function `observe_tensor_cycles`:
        D103: Missing docstring in public function
9

torch/distributed/argparse_util.py
torch/distributed/argparse_util.py:1 at module level:
        D100: Missing docstring in public module
torch/distributed/argparse_util.py:13 in public class `env`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/argparse_util.py:13 in public class `env`:
        D400: First line should end with a period (not 'g')
torch/distributed/argparse_util.py:13 in public class `env`:
        D412: No blank lines allowed between a section header and its content ('Example')
torch/distributed/argparse_util.py:43 in public method `__init__`:
        D107: Missing docstring in __init__
torch/distributed/argparse_util.py:56 in public method `__call__`:
        D102: Missing docstring in public method
torch/distributed/argparse_util.py:61 in public class `check_env`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/argparse_util.py:61 in public class `check_env`:
        D400: First line should end with a period (not 's')
torch/distributed/argparse_util.py:61 in public class `check_env`:
        D412: No blank lines allowed between a section header and its content ('Example')
torch/distributed/argparse_util.py:97 in public method `__init__`:
        D107: Missing docstring in __init__
torch/distributed/argparse_util.py:102 in public method `__call__`:
        D102: Missing docstring in public method
11
torch/distributed/argparse_util.py:1 at module level:
        D100: Missing docstring in public module
torch/distributed/argparse_util.py:43 in public method `__init__`:
        D107: Missing docstring in __init__
torch/distributed/argparse_util.py:56 in public method `__call__`:
        D102: Missing docstring in public method
torch/distributed/argparse_util.py:97 in public method `__init__`:
        D107: Missing docstring in __init__
torch/distributed/argparse_util.py:102 in public method `__call__`:
        D102: Missing docstring in public method
5

torch/distributed/_composable_state.py
torch/distributed/_composable_state.py:20 in private function `_get_module_state`:
        D202: No blank lines allowed after function docstring (found 1)
torch/distributed/_composable_state.py:20 in private function `_get_module_state`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/_composable_state.py:20 in private function `_get_module_state`:
        D400: First line should end with a period (not '`')
3
0

torch/distributed/launch.py
torch/distributed/launch.py:1 at module level:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/launch.py:1 at module level:
        D400: First line should end with a period (not 'd')
torch/distributed/launch.py:156 in public function `parse_args`:
        D103: Missing docstring in public function
torch/distributed/launch.py:171 in public function `launch`:
        D103: Missing docstring in public function
torch/distributed/launch.py:180 in public function `main`:
        D103: Missing docstring in public function
5
torch/distributed/launch.py:157 in public function `parse_args`:
        D103: Missing docstring in public function
torch/distributed/launch.py:172 in public function `launch`:
        D103: Missing docstring in public function
torch/distributed/launch.py:181 in public function `main`:
        D103: Missing docstring in public function
3

torch/distributed/remote_device.py
torch/distributed/remote_device.py:1 at module level:
        D100: Missing docstring in public module
torch/distributed/remote_device.py:81 in private method `worker_name`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/remote_device.py:81 in private method `worker_name`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
torch/distributed/remote_device.py:88 in private method `rank`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/remote_device.py:88 in private method `rank`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
torch/distributed/remote_device.py:95 in private method `device`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/distributed/remote_device.py:95 in private method `device`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
7
torch/distributed/remote_device.py:1 at module level:
        D100: Missing docstring in public module
torch/distributed/remote_device.py:85 in private method `rank`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/remote_device.py:85 in private method `rank`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
3

torch/distributed/rendezvous.py
torch/distributed/rendezvous.py:1 at module level:
        D100: Missing docstring in public module
torch/distributed/rendezvous.py:23 in public function `register_rendezvous_handler`:
        D401: First line should be in imperative mood (perhaps 'Register', not 'Registers')
torch/distributed/rendezvous.py:88 in public function `rendezvous`:
        D103: Missing docstring in public function
torch/distributed/rendezvous.py:147 in private function `_create_c10d_store`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/rendezvous.py:147 in private function `_create_c10d_store`:
        D400: First line should end with a period (not 'r')
5
torch/distributed/rendezvous.py:1 at module level:
        D100: Missing docstring in public module
torch/distributed/rendezvous.py:89 in public function `rendezvous`:
        D103: Missing docstring in public function
2

torch/distributed/run.py
torch/distributed/run.py:9 at module level:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/run.py:9 at module level:
        D400: First line should end with a period (not '`')
torch/distributed/run.py:393 in public function `get_args_parser`:
        D202: No blank lines allowed after function docstring (found 1)
torch/distributed/run.py:393 in public function `get_args_parser`:
        D401: First line should be in imperative mood; try rephrasing (found 'Helper')
torch/distributed/run.py:610 in public function `parse_args`:
        D103: Missing docstring in public function
torch/distributed/run.py:615 in public function `parse_min_max_nnodes`:
        D103: Missing docstring in public function
torch/distributed/run.py:629 in public function `determine_local_world_size`:
        D103: Missing docstring in public function
torch/distributed/run.py:670 in public function `get_rdzv_endpoint`:
        D103: Missing docstring in public function
torch/distributed/run.py:677 in public function `get_use_env`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/run.py:677 in public function `get_use_env`:
        D401: First line should be in imperative mood (perhaps 'Retrieve', not 'Retrieves')
torch/distributed/run.py:689 in public function `config_from_args`:
        D103: Missing docstring in public function
torch/distributed/run.py:770 in public function `run_script_path`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/run.py:770 in public function `run_script_path`:
        D401: First line should be in imperative mood (perhaps 'Run', not 'Runs')
torch/distributed/run.py:781 in public function `run`:
        D103: Missing docstring in public function
torch/distributed/run.py:804 in public function `main`:
        D103: Missing docstring in public function
15
torch/distributed/run.py:611 in public function `parse_args`:
        D103: Missing docstring in public function
torch/distributed/run.py:616 in public function `parse_min_max_nnodes`:
        D103: Missing docstring in public function
torch/distributed/run.py:630 in public function `determine_local_world_size`:
        D103: Missing docstring in public function
torch/distributed/run.py:671 in public function `get_rdzv_endpoint`:
        D103: Missing docstring in public function
torch/distributed/run.py:691 in public function `config_from_args`:
        D103: Missing docstring in public function
torch/distributed/run.py:784 in public function `run`:
        D103: Missing docstring in public function
torch/distributed/run.py:807 in public function `main`:
        D103: Missing docstring in public function
7

torch/distributed/__init__.py
torch/distributed/__init__.py:1 at module level:
        D104: Missing docstring in public package
torch/distributed/__init__.py:8 in public function `is_available`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/__init__.py:8 in public function `is_available`:
        D400: First line should end with a period (not ',')
torch/distributed/__init__.py:8 in public function `is_available`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
4
torch/distributed/__init__.py:1 at module level:
        D104: Missing docstring in public package
1

torch/distributed/utils.py:1 at module level:
        D100: Missing docstring in public module
torch/distributed/utils.py:16 in private function `_pack_kwargs`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/utils.py:16 in private function `_pack_kwargs`:
        D400: First line should end with a period (not ')')
torch/distributed/utils.py:47 in private function `_cast_forward_inputs`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/utils.py:88 in private function `_recursive_to`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/distributed/utils.py:141 in private function `_p_assert`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/utils.py:141 in private function `_p_assert`:
        D209: Multi-line docstring closing quotes should be on a separate line
torch/distributed/utils.py:141 in private function `_p_assert`:
        D400: First line should end with a period (not 't')
torch/distributed/utils.py:141 in private function `_p_assert`:
        D401: First line should be in imperative mood; try rephrasing (found 'This')
torch/distributed/utils.py:275 in private function `_sync_module_states`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/utils.py:275 in private function `_sync_module_states`:
        D400: First line should end with a period (not 'n')
torch/distributed/utils.py:275 in private function `_sync_module_states`:
        D401: First line should be in imperative mood (perhaps 'Sync', not 'Syncs')
torch/distributed/utils.py:300 in private function `_sync_params_and_buffers`:
        D205: 1 blank line required between summary line and description (found 0)
torch/distributed/utils.py:300 in private function `_sync_params_and_buffers`:
        D400: First line should end with a period (not 'y')
torch/distributed/utils.py:300 in private function `_sync_params_and_buffers`:
        D401: First line should be in imperative mood (perhaps 'Synchronize', not 'Synchronizes')
15
torch/distributed/utils.py:1 at module level:
        D100: Missing docstring in public module
1
```
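
For the most frequent codes above (D205, D400, D401), the fix is a small docstring restructuring; a hedged sketch using one of the flagged functions (signature simplified):

```python
def _sync_module_states(module, process_group):
    """Sync ``module``'s parameters and buffers across ranks.

    D401: the summary line is in the imperative mood ("Sync", not "Syncs"),
    D400: it ends with a period, and
    D205: a blank line separates it from this description.
    """
```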

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112953
Approved by: https://github.com/weifengpy
2023-11-08 01:13:09 +00:00
fb68aa0a92 [Easy] Remove unused return type from utils (#110887)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110887
Approved by: https://github.com/ezyang
2023-10-10 09:02:11 +00:00
ff37f6018d Enable custom device support in fsdp checkpoint (#107289)
Fixes https://github.com/pytorch/pytorch/issues/104390
Enable custom device (privateuse1 backend) support in checkpointing via a dynamic abstract device module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107289
Approved by: https://github.com/wz337
2023-08-25 11:50:03 +00:00
88ce6215f5 [FSDP/DDP] Unify _cast_forward_inputs (#102680)
Closes https://github.com/pytorch/pytorch/issues/96380

Differential Revision: [D46342814](https://our.internmc.facebook.com/intern/diff/D46342814/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102680
Approved by: https://github.com/awgu
2023-06-04 18:31:21 +00:00
c28f8e314d Add type hints in torch/distributed/utils.py (#102262)
Fixes #77190

Pretty similar to the typing in `torch/nn/parallel`, which was also improved recently: #102194

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102262
Approved by: https://github.com/Skylion007, https://github.com/Neilblaze
2023-05-30 19:57:45 +00:00
0ed22fce97 Merge type stubs torch nn parallel (#102194)
Fixes merge issue for #101528

In the above PR, `torch.nn.parallel.parallel_apply.get_a_var` was marked private to appease the [public interface linter](https://github.com/pytorch/pytorch/actions/runs/4999216467/jobs/8955582204#step:14:21666): ceeb242bc7

This broke CI pipelines running external dependencies that expected `get_a_var`'s name to not change. In this PR, we change the name back to `get_a_var` and include it in the `__all__` instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102194
Approved by: https://github.com/ezyang
2023-05-26 20:10:47 +00:00
dfe484a3b3 [BE]: Bugfix functorch and some generic typing improvements (#101337)
Fixes some typing bugs found with newer versions of mypy

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101337
Approved by: https://github.com/ezyang
2023-05-14 14:20:56 +00:00
0731420645 [PyTorch/Distributed]Only sync buffers when broadcast_buffers is True (#100729)
Summary: Disable buffer sync in _sync_module_states(...) when broadcast_buffers is False. This change reduces memory usage when a model has huge buffers and does not need buffer broadcast.

Test Plan: .

Differential Revision: D45610709

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100729
Approved by: https://github.com/mrshenli
2023-05-08 16:34:29 +00:00
bd07f8d2e0 DDP forward support custom stream accelerated copy. (#98723)
At present, DDP forward uses `_get_stream` to get a stream, which is a CUDA stream.
If a custom module is already registered with torch, we can use `getattr` to get it and its stream. Then the custom stream is used to copy the tensor.
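
A hedged sketch of that getattr-based lookup (illustrative only, not this PR's exact code):

```python
import torch

def _get_device_module(device_type: str):
    # "cuda" resolves to torch.cuda; a custom backend (e.g. "npu") would resolve the
    # same way once the vendor package registers its module on the torch namespace.
    return getattr(torch, device_type)

device_mod = _get_device_module("cuda")
print(device_mod.is_available())
```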
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98723
Approved by: https://github.com/ezyang
2023-04-14 20:19:56 +00:00
d95ee64b58 ddp forward support custom backend. (#98283)
Currently DDP only considers the CUDA backend; DDP forward will transfer tensors to CUDA. We want DDP to run on custom backends as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98283
Approved by: https://github.com/ezyang
2023-04-09 01:30:42 +00:00
5471621497 [BE] Remove unnecessary dict comprehensions (#97116)
Removes unnecessary dict comprehensions, optimizing the creation of dicts from iterables.
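
An illustrative example of the cleanup (hypothetical data):

```python
pairs = [("rank", 0), ("world_size", 8)]
d = {k: v for k, v in pairs}   # unnecessary comprehension
d = dict(pairs)                # equivalent, simpler, and faster
```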

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97116
Approved by: https://github.com/kit1980
2023-03-20 00:56:57 +00:00
c43e88665a [Resubmit] helpers to torch.dist.utils (#95025)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95025
Approved by: https://github.com/fegin
2023-02-17 18:24:20 +00:00
1e2d82b8e4 [BE] Merge isinstance calls together (#94419)
Simplifies and speeds up isinstance calls by checking for multiple types at the same time.
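
For illustration (hypothetical values): a single isinstance call with a tuple of types replaces chained checks.

```python
x = 3.5
ok = isinstance(x, int) or isinstance(x, float)   # before
ok = isinstance(x, (int, float))                  # after: one call, same result
```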

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94419
Approved by: https://github.com/ezyang
2023-02-09 00:47:26 +00:00
ce7751188a [DDP] Add PackedSequence support when device_ids is specified (#86614)
Before this PR, if a user runs DDP with `device_ids` specified and with a `PackedSequence` input, then the execution will error with something like:
```
raise ValueError(
  ValueError: batch_sizes should always be on CPU. Instances of PackedSequence should never be created manually. They should be instantiated by
 functions like pack_sequence and pack_padded_sequences in nn.utils.rnn. https://pytorch.org/docs/stable/nn.html...
```
This is because the DDP forward calls `_to_kwargs()`, which calls `_recursive_to()`, which moves the inputs to GPU. However, `_is_namedtuple(packed_sequence)` returns `True`, leading to the branch `return [type(obj)(*args) for args in zip(*map(to_map, obj))]`, which tries to construct a `PackedSequence` directly via `type(obj)(*args)`, leading to the error.

Repro for `_is_namedtuple(packed_sequence)` returning `True`:
```
import random

import torch
import torch.nn.utils.rnn as rnn_utils
from torch.nn.parallel.scatter_gather import _is_namedtuple

def _ordered_sequence(tensor_type):
    seqs = [tensor_type(random.randint(1, 256))
            for _ in range(32)]
    seqs = [s.random_(-128, 128) for s in seqs]
    ordered = sorted(seqs, key=len, reverse=True)
    return ordered

def _padded_sequence(tensor_type):
    ordered = _ordered_sequence(tensor_type)
    lengths = [len(i) for i in ordered]
    padded_tensor = rnn_utils.pad_sequence(ordered)
    return padded_tensor, lengths

padded, lengths = _padded_sequence(torch.Tensor)
packed = rnn_utils.pack_padded_sequence(
    padded, lengths, enforce_sorted=False)
print(type(packed), packed.data.device)
print(_is_namedtuple(packed))
```

Test Plan:
```
python test/distributed/test_c10d_nccl.py -k test_ddp_packed_sequence
```
Without the fix, the added unit test fails with the expected error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86614
Approved by: https://github.com/rohan-varma
2022-10-10 21:50:59 +00:00
8cb7826889 [CheckpointWrapper] Reentrant kwarg support (#84908)
A temporary patch to support keyword args when the reentrant checkpoint wrapper is used. This is needed to unblock some crucial workloads; the ideal fix would be to check this directly into torch.utils.checkpoint.

Differential Revision: [D39453453](https://our.internmc.facebook.com/intern/diff/D39453453/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84908
Approved by: https://github.com/awgu
2022-09-15 00:30:23 +00:00
d2f37401b8 Silence namedtuple warning in dist (#84072)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84072
Approved by: https://github.com/awgu
2022-08-26 00:24:28 +00:00
b29a074882 [BE] Revert distributed change in https://github.com/pytorch/pytorch/pull/68779 (#83181)
https://github.com/pytorch/pytorch/issues/82641 points out a regression in how inputs / outputs are processed by DDP, blocking their HF use case. It was narrowed down to https://github.com/pytorch/pytorch/pull/68779 and reverting the distributed change there fixes the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83181
Approved by: https://github.com/kumpera
2022-08-23 02:38:23 +00:00
6f954d7bbb FSDP parameter sync
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77492

Approved by: https://github.com/zhaojuanmao
2022-05-17 19:58:49 +00:00
f9f8127414 CheckpointWrapper state_dict fix (#77224)
- Uses state_dict / load_state_dict hooks to ensure that modules wrapped with `CheckpointWrapper` can be loaded into a non-checkpoint-wrapped module.

This is because a training run can use activation checkpointing, after which we can recover the `state_dict`, while a future run may not want to wrap modules with activation checkpointing or may decide to change the activation checkpoint wrapping structure. To support this, we add hooks to remove / add the relevant prefix as needed.

Tests are added to ensure we can load into CheckpointWrapper module as well as local module from CheckpointWrapper-wrapped module. state_dict with FSDP is also verified.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
2022-05-17 03:39:31 +00:00
bbb1f106c7 Separate input moving to utils file
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77187

Test fix

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77235

Lint fix

Approved by: https://github.com/awgu
2022-05-11 21:55:38 +00:00
ffb0946504 Generalize param verification and broadcast
New PR for https://github.com/pytorch/pytorch/pull/75970 to be compatible with GHF.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76374
Approved by: https://github.com/awgu
2022-04-26 22:25:53 +00:00
b8e6144e0a Add a _RemoteDevice structure for ShardedTensor/ShardingSpec. (#62927)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62927

As part of the ShardedTensor work, we realized we do need some sort of
_RemoteDevice structure that deals with our format of "workername/device" so
that users don't have to worry about parsing this string directly.

Right now this structure is just the bare minimum and is mostly a container for
describing a remote device. It is currently only used in ShardedTensor,
ShardingSpec and RemoteModule.

Once we actually have a consolidated remote device proposal, this class can be
extended appropriately if needed.
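
A hedged sketch of the "workername/device" parsing this structure encapsulates (an illustrative helper, not the actual class in torch/distributed/remote_device.py):

```python
def _parse_remote_device(spec: str) -> tuple[str, str]:
    worker_name, _, device = spec.partition("/")
    return worker_name, device or "cpu"

print(_parse_remote_device("trainer0/cuda:0"))  # ('trainer0', 'cuda:0')
```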
ghstack-source-id: 135534086

Test Plan:
1) unit tests
2) waitforbuildbot

Reviewed By: SciPioneer

Differential Revision: D30170689

fbshipit-source-id: 1ac2e81c7a597dc40bf3fbf2c1168c382c66649f
2021-08-11 11:27:32 -07:00
0d6fa1adc5 Introduce ChunkShardingSpec as a model sharding specification. (#55728)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55728

Full design: https://github.com/pytorch/pytorch/issues/55207

This PR introduces ChunkShardingSpec (SingleShardingSpec in the design). We used
the name ChunkShardingSpec since it is very similar to `torch.chunk` in terms
of how a Tensor is split up, and it feels clearer than SingleShardingSpec.
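
The `torch.chunk` analogy, concretely (illustrative only): each chunk corresponds to one shard placed on a different rank.

```python
import torch

t = torch.arange(12).reshape(6, 2)
shards = torch.chunk(t, chunks=3, dim=0)   # three (2, 2) shards along dim 0
print([tuple(s.shape) for s in shards])    # [(2, 2), (2, 2), (2, 2)]
```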
ghstack-source-id: 129603318

Test Plan: waitforbuildbot

Reviewed By: SciPioneer

Differential Revision: D27694108

fbshipit-source-id: c8764abe6a4d5fc56d023fda29b74b5af2a73b49
2021-05-23 16:04:57 -07:00