Commit Graph

26 Commits

Author SHA1 Message Date
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
9da250aada type fully_shard so that the return value can be chained with typing enabled (#147489)
This allows for

```
fsdped = fully_shard(model)
fsdped.set_xyz()
```

same applies if `model` is actually a list of modules

Differential Revision: [D69888119](https://our.internmc.facebook.com/intern/diff/D69888119)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147489
Approved by: https://github.com/Skylion007
ghstack dependencies: #147488
2025-02-20 08:43:16 +00:00
de1cb0f351 capture the return value in the contract typing (#147488)
----

* the existing typing makes the return type `Optional[nn.Module]`
* this doesn't seem to be what the decorator actually does as it does
  not alter the original return type
* This PR aims to fix the typing

Differential Revision: [D69888120](https://our.internmc.facebook.com/intern/diff/D69888120)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147488
Approved by: https://github.com/Skylion007
2025-02-20 03:32:34 +00:00
00ffeca1b1 PEP585 update - torch/distributed (#145164)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-21 04:23:29 +00:00
6374332d33 Revert "PEP585 update - torch/distributed (#145164)"
This reverts commit 6cb186e279bc179a6bb63f0226e24ab42a07b394.

Reverted https://github.com/pytorch/pytorch/pull/145164 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing an inductor test ([comment](https://github.com/pytorch/pytorch/pull/145164#issuecomment-2602875679))
2025-01-20 16:46:46 +00:00
6cb186e279 PEP585 update - torch/distributed (#145164)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-20 00:19:01 +00:00
45ef3309e3 [BE] typing for decorators (#144161)
Summary:
Untyped decorators strip annotations from the decorated items.

- _compile
- _inductor/fx_passes/post_grad
- _inductor/lowering
- _library/custom_ops
- _meta_registrations
- _ops
- _refs/nn/functional
- ao/quantization/quantizer/xnnpack_quantizer_utils
- distributed/_composable/contract
- fx/experimental/graph_gradual_typechecker
- fx/experimental/migrate_gradual_types/constraint_generator
- optim/optimizer
- signal/windows/windows
- testing/_internal/common_device_type
- torch/_inductor/decomposition
- utils/flop_counter

Test Plan: unit tests

Differential Revision: D62302684

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144161
Approved by: https://github.com/Skylion007, https://github.com/albanD
2025-01-04 16:40:09 +00:00
612122af8f Fix type-safety of torch.nn.Module instances (#141240)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141240
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-11-22 00:05:05 +00:00
72d2dba992 Add None return type to init (#132335)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132335
Approved by: https://github.com/albanD
2024-08-01 15:26:45 +00:00
ff7e021e94 [Reland][PT-D] Relaxed contract to allow Sequence[nn.Module] (#127773) (#130947)
This PR relaxes `@contract` to allow the 1st argument to be `Sequence[nn.Module]` instead of strictly `nn.Module`. This is required for the next PR, which allows `fully_shard` to take in `List[nn.Module]`.

---

**Changes for reland:**
- The previous PR assumed that any `func` decorated with `@contract` would return the same input `module` as output (which is true for PT-D composable APIs).
- However, TorchRec `shard` returns a different module as output (though that module _does_ satisfy the `@contract` FQN check).
- This PR removes the assumption and instead only enforces the FQN check following the input module order. In other words, if calling `func([x1, ..., xN])` for `N` modules `x1, ..., xN` that returns `[y1, ..., yM]` for `M` modules, we require that `N = M` and that FQNs are preserved coordinate-wise: `xi` and `yi` have same FQNs for all `i = 1, ..., N`.

Differential Revision: [D59863438](https://our.internmc.facebook.com/intern/diff/D59863438)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130947
Approved by: https://github.com/weifengpy, https://github.com/atalman
2024-07-17 22:40:13 +00:00
d7a8e8f7c5 Revert "[PT-D] Relaxed contract to allow Sequence[nn.Module] (#127773)"
This reverts commit b27695791e9cc4eedb1b713b1be20398bfeb911b.

Reverted https://github.com/pytorch/pytorch/pull/127773 on behalf of https://github.com/atalman due to failing internally ([comment](https://github.com/pytorch/pytorch/pull/127773#issuecomment-2232004006))
2024-07-16 23:48:09 +00:00
b27695791e [PT-D] Relaxed contract to allow Sequence[nn.Module] (#127773)
This PR relaxes `@contract` to allow the 1st argument to be `Sequence[nn.Module]` instead of strictly `nn.Module`. This is required for the next PR, which allows `fully_shard` to take in `List[nn.Module]`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127773
Approved by: https://github.com/weifengpy
2024-07-15 23:54:10 +00:00
3a0d088517 Flip default value for mypy disallow_untyped_defs [5/11] (#127842)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127842
Approved by: https://github.com/oulgen
2024-06-08 18:49:18 +00:00
20eaa49dde [PT-D] Made _get_registry return None if no APIs applied (#113654)
I prefer to not modify the module if it does not have any of our APIs applied. The side effect of inserting a registry on the module when calling a getter is non-intuitive to me.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113654
Approved by: https://github.com/fegin
2023-11-14 20:28:11 +00:00
42660015b4 [Dynamo x FSDP][2/x] Small changes to distributed to make it dynamo friendly (#106886)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106886
Approved by: https://github.com/awgu, https://github.com/wconstab
ghstack dependencies: #106884
2023-08-11 22:35:50 +00:00
487a33e38a [FSDP x dynamo] simplify registry keys (#104209)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104209
Approved by: https://github.com/wconstab, https://github.com/fegin
2023-07-25 07:16:22 +00:00
cfd1b4df94 [Composable] add checking key for check_fqn function (#98961)
add checking key for check_fqn function

ghstack-source-id: d856f560f1fc449a316135e3844609d0baaf6d66
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96705

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98961
Approved by: https://github.com/awgu
2023-04-13 03:16:14 +00:00
35fd5c548e Fix typos under torch/distributed directory (#95638)
This PR fixes typos in comments and messages of `.py` files under torch/distributed directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95638
Approved by: https://github.com/usamah1, https://github.com/H-Huang, https://github.com/kit1980
2023-03-27 21:13:44 +00:00
933cc67e7e [pytorch] [compososable] make contract() pickle-able through functools wraps (#92120)
Summary:
make contract() pickle-able through functools wraps.
This is to get functions wrapped with contract() to work with torch package

Differential Revision: D42491056

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92120
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/rohan-varma, https://github.com/mrshenli
2023-01-17 18:14:05 +00:00
ad782ff7df Enable xdoctest runner in CI for real this time (#83816)
Builds on #83317 and enables running the doctests. Just need to figure out what is causing the failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83816
Approved by: https://github.com/ezyang, https://github.com/malfet
2022-12-29 05:32:42 +00:00
d08e3d2304 [Composable API] Apply ufmt to _composable and the corresponding test folders (#91255)
This PR apply ufmt to format `_composable` related code. This is a request from https://github.com/pytorch/pytorch/pull/91234 to separate formatting changes as a new PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91255
Approved by: https://github.com/awgu
2022-12-23 16:08:27 +00:00
9b42e4ef73 [Composable API] Make _StateKey as a str subclass (#91279)
The keys in object.__dict__ should be strings. Make the _StateKey be a str subclass.

Differential Revision: [D42200244](https://our.internmc.facebook.com/intern/diff/D42200244/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91279
Approved by: https://github.com/awgu, https://github.com/mrshenli
2022-12-22 06:01:06 +00:00
d52f121dba [Composable API]Common _State parent class for composable and wrapper FSDP (#89147)
**Why this PR?**

For the composable APIs implementation, sometimes the internal APIs may not have the application (FSDP, DDP) root module but only the local module. One example is the state_dict/optimizer_state_dict implementation of FSDP. These APIs  are designed to start with the root module of the model. It is tricky for these APIs to tell whether a random submodule is managed by either DDP or FSDP.

It will be useful to have APIs like:
`_get_module_state(module)`: return the composable state if this module is managed by composable API.
`_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP.

**What does this PR propose?**
1. Make `_State` out of `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. A global `_module_state_mapping: Dict[nn.Module, _State]` that keeps the mapping of all submodules (not just root module) to the state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)` that uses `_get_module_state(module)` to get the state then verifies if the state is `_FSDPState`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
2022-12-13 23:58:01 +00:00
a69cdd9cf8 Add global registry to composable API contract (#90579)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90579
Approved by: https://github.com/awgu, https://github.com/yhcharles
2022-12-10 22:41:10 +00:00
f3af5ba48e [WIP] Composable API: replicate and DistributedState (#87649)
This PR adds the first version of the `replicate()` composable API. For this prototype version, I try to reuse as much code from existing `DistributedDataParallel` as possible, and iterate on it in later changes. The basic idea of this prototype is:
- create a `ReplicateState` object. It internally uses a `ParameterList` module to hold all parameters of modules marked by `replicate()` API.
- create an internal `_ddp` object, which reuses existing `DistributedDataParallel` implementation, and wraps the `ParameterList` object
- install pre-forward and after-forward hooks on the root module, which calls methods of `_ddp` to run initialization and forward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87649
Approved by: https://github.com/zhaojuanmao
2022-11-17 03:06:31 +00:00
ce3e0e9856 Add state to distributed composable API (#87838)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87838
Approved by: https://github.com/yhcharles
2022-10-28 13:31:40 +00:00