* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions, which are faster and do not leak the loop variable into the enclosing scope.
* List comprehensions not only often type-check better, but also cut 50%+ of the per-iteration overhead of an equivalent for loop. They also preserve length information and are easier for the interpreter to optimize (see the sketch after this list).
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
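As a hedged illustration of the kind of rewrite this rule performs (the names `items` and `transform` are hypothetical, not taken from the PyTorch diff):
```python
def transform(x: int) -> int:
    return x * x

items = range(10)

# Before: appending inside a loop; the loop variable `x` leaks into the enclosing scope.
result = []
for x in items:
    result.append(transform(x))

# After: the equivalent list comprehension the rule rewrites to.
result = [transform(x) for x in items]
```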
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
Skipping imports of some packages for now to make this change more tractable.
For some reason, lintrunner on CI raises errors in all imported `.pyi` files,
even though it doesn't on my local machine. The errors are all from missing
generic types, as the MYPYINDUCTOR config has `disallow_any_generics`
set. I have thus added `disable-error-code` comments to the relevant files,
though I fixed a few that were easy enough.
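For illustration, this is roughly what the suppression and the "easy enough" fixes look like; treat the exact error code as an assumption on my part (`disallow_any_generics` typically surfaces as mypy's `type-arg` code):
```python
# mypy: disable-error-code="type-arg"   # file-level opt-out (assumed error code)

from typing import Dict, List

# Would otherwise be flagged under disallow_any_generics: bare, unparameterized generics.
def lookup(names: List) -> Dict: ...

# The "easy enough" fix: parameterize the generics.
def lookup_fixed(names: List[str]) -> Dict[str, int]: ...
```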
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113830
Approved by: https://github.com/Skylion007
ghstack dependencies: #113722, #113721
This PR reduces the docstring errors from a total of 98 to 0. This can be verified by running:
`pydocstyle path-to-zero_redundancy_optimizer.py --count`
BEFORE the PR:
`pydocstyle torch/distributed/optim/zero_redundancy_optimizer.py --count`
98
AFTER the PR:
`pydocstyle torch/distributed/optim/zero_redundancy_optimizer.py --count`
0
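For illustration (a hypothetical method, not from the diff), the typical class of fix is making the first docstring line an imperative summary that ends with a period, resolving checks such as pydocstyle's D400/D401:
```python
# Before: trips pydocstyle (e.g. D400: first line should end with a period).
def consolidate_state_dict(self, to=0):
    """this consolidates the optimizer state onto the target rank"""

# After: imperative mood, single summary line ending with a period.
def consolidate_state_dict(self, to=0):
    """Consolidate the optimizer state onto the target rank."""
```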
Fixes #112642
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113200
Approved by: https://github.com/weifengpy
Those functions enable membership introspection into a ProcessGroup. A common scenario
that needs this is library code that consumes a PG but doesn't create it, which means
it likely doesn't know the global ranks used to create it.
Translating from local to global rank is necessary when using c10d collectives like broadcast,
so if your library code adopts the convention of using local rank 0, it needs
to do the following:
```python
import torch.distributed as dist
my_pg: dist.ProcessGroup = ...

def my_library_bcast(tensor):
    # Translate local rank 0 within my_pg to the corresponding global rank.
    dist.broadcast(tensor, src=dist.get_global_rank(my_pg, 0), group=my_pg)
```
This implements some of the helpers needed to implement the `clone` API from: https://github.com/pytorch/pytorch/issues/81291
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82134
Approved by: https://github.com/rohan-varma
This is a new version of #15648 based on the latest master branch.
Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.
In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will disable those tests. (Unfortunately I don't have a tool that will insert the `# xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
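For reference, a hedged sketch of what the skip directive looks like inside a docstring (the function is hypothetical; the directive syntax is xdoctest's):
```python
import torch

def scaled(x):
    """Return the input scaled by two.

    Example:
        >>> # xdoctest: +SKIP
        >>> scaled(torch.ones(2))
        tensor([2., 2.])
    """
    return 2 * x
```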
Fixes https://github.com/pytorch/pytorch/issues/71105
@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
### Description
Across PyTorch's docstrings, both `callable` and `Callable` are used for variable types. `Callable` should be capitalized when we are referring to the `Callable` type, and not the Python `callable()` function.
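A hedged before/after illustration (hypothetical docstring, not from the diff):
```python
def register_hook(hook):
    """Register a hook.

    Args:
        hook (Callable): Function invoked on every update. Use ``Callable``
            (the type), not ``callable`` (the builtin predicate).
    """
```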
### Testing
There shouldn't be any testing required.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82487
Approved by: https://github.com/albanD
Near term fix for https://github.com/pytorch/pytorch/issues/76368.
Q. Why does the user need to request `capturable=True` in the optimizer constructor? Why can't capture safety be completely automatic?
A. We need to set up capture-safe (device-side) state variables before capture. If we don't, and step() internally detects capture is underway, it's too late: the best we could do is create a device state variable and copy the current CPU value into it, which is not something we want baked into the graph.
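A minimal sketch of the resulting usage, assuming only the `capturable` constructor flag this PR proposes (the rest is illustrative):
```python
import torch

model = torch.nn.Linear(8, 8).cuda()

# Ordinary, ungraphed training: keep the default CPU-side step state.
opt_eager = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training step that will be captured into a CUDA graph: request capture-safe,
# device-side step state up front, before capture begins.
opt_graphed = torch.optim.Adam(model.parameters(), lr=1e-3, capturable=True)
```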
Q. Ok, why not just do the capture-safe approach with device-side state variables all the time?
A. It incurs several more kernel launches per parameter, which could really add up and regress cpu overhead for ungraphed step()s. If the optimizer won't be captured, we should allow step() to stick with its current cpu-side state handling.
Q. But cuda RNG is a stateful thing that maintains its state on the cpu outside of capture and replay, and we capture it automatically. Why can't we do the same thing here?
A. The graph object can handle RNG generator increments because its capture_begin, capture_end, and replay() methods can see and access the generator object. But the graph object has no explicit knowledge of or access to optimizer steps in its capture scope. We could let the user tell the graph object which optimizers will be stepped in its scope, i.e. something like
```python
graph.will_use_optimizer(opt)
graph.capture_begin()
...
```
but that seems clunkier than an optimizer constructor arg.
I'm open to other ideas, but right now I think constructor arg is necessary and the least bad approach.
Long term, https://github.com/pytorch/pytorch/issues/71274 is a better fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77862
Approved by: https://github.com/ezyang
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74215
### Overview of API
This PR introduces full optimizer state dict checkpointing.
- This allows users to save the optimizer state for a `torch.nn.Module` (not necessarily a `FullyShardedDataParallel` instance) that contains `FullyShardedDataParallel` instances and later load that optimizer state.
- This supports loading to a module with a different world size, but the `FSDP` wrapping scheme must be the same.
To **save** the optimizer state, run the following (on all ranks):
```
model: torch.nn.Module = ...
optim = torch.optim.Adam(model.parameters(), ...)
# Train for some steps...
full_osd = FSDP.full_optim_state_dict(model, optim) # returns non-empty dict only on rank 0
if rank == 0:
    torch.save(full_osd, ...)
```
To **load** the optimizer state, run the following (on all ranks):
```
new_model: torch.nn.Module = ... # may use different world size
full_osd = torch.load(...)
sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
optim = torch.optim.Adam(new_model.parameters(), ...)
optim.load_state_dict(sharded_osd)
```
To support **multiple parameter groups**, we require using an additional argument `optim_input`, which is the first argument that the user passes into the optimizer constructor.
```
optim_input = ...
optim = torch.optim.Adam(optim_input, ...)
FSDP.full_optim_state_dict(model, optim, optim_input) # one more argument
...
new_optim_input = ...
new_optim = torch.optim.Adam(new_optim_input, ...)
FSDP.shard_full_optim_state_dict(full_osd, new_model, new_optim_input) # one more argument
```
One caveat is that the user should be careful with generators, which are exhausted after their first use. If using a generator, the `optim_input` passed into the `FSDP` APIs should be a refreshed version of it.
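A hedged illustration of the caveat (variable names are hypothetical; the `FSDP` call is shown commented out since it requires a process group):
```python
import torch

model = torch.nn.Linear(4, 4)
params = model.parameters()                  # a generator
optim = torch.optim.Adam(params, lr=1e-3)    # consumes (exhausts) the generator

assert list(params) == []                    # nothing left to iterate over

# Pass a *fresh* generator (or a list) as `optim_input`:
optim_input = model.parameters()             # refreshed
# FSDP.full_optim_state_dict(model, optim, optim_input)
```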
### Test Plan
**`full_optim_state_dict()`**
- [x] `full_optim_state_dict()` for a non-`FSDP` root model matches that of an equivalent local model, up to parameter IDs being rearranged, when optimizer input is `model.parameters()`.
- [x] `full_optim_state_dict()` for a non-`FSDP` root model matches that of an equivalent local model, up to parameter IDs being rearranged, when optimizer input is multiple parameter groups (changing parameter order).
**`shard_full_optim_state_dict()`**
- [x] `shard_full_optim_state_dict()` for a non-`FSDP` root model matches the local `optim.state_dict()` of the same model with halved world size, when optimizer input is `model.parameters()`.
- [x] `shard_full_optim_state_dict()` for a non-`FSDP` root model matches the local `optim.state_dict()` of the same model with halved world size, when optimizer input is multiple parameter groups (changing parameter order).
- [x] `shard_full_optim_state_dict()` raises a `ValueError` when changing the `FSDP` wrapping scheme.
On the AWS cluster, the time-to-signal (TTS) contribution of these tests is ~45 seconds.
### Developer Notes
**Relaxing the Problem**
For optimizer state checkpointing, we have relaxed the problem to **not support changing the `FSDP` wrapping scheme** between save and load time. It is unclear how to solve without this relaxation. This was the least restrictive way to relax the problem since it does not affect most expected use cases. Rather, the expected change between save and load time is the **world size**, which this implementation **does support**.
Even with the relaxation, the `optim_input` argument is necessary to determine the `flat_param_id_to_param` mapping, which is needed to know which parameter IDs in the flattened space correspond to `FlatParameter`s and hence need to be unflattened.
**Differences with Local Equivalent**
Suppose `full_osd = full_optim_state_dict()` and `local_osd = state_dict()` for a purely local equivalent. The difference between `full_osd` and `local_osd` is that the parameter IDs of unflattened parameters comprising a single flattened parameter are always consecutive in `full_osd`, while they may be non-consecutive in `local_osd`. Suppose in the following that each layer has 1 parameter `param`:
```
FSDP(model)
    layer1
    FSDP(layer2)
    layer3
```
`layer1.param` and `layer3.param` are flattened and attributed to `model`. `layer2.param` is flattened and attributed to itself.
- In `local_osd`, the parameter IDs would be `0: layer1.param`, `1: layer2.param`, and `2: layer3.param`.
- In `full_osd`, the parameter IDs would be `0: layer1.param`, `1: layer3.param`, and `2: layer2.param`. (Parameter IDs of unflattened parameters sharing a flattened parameter are consecutive.)
The idea is that as long as `full_optim_state_dict()` and `shard_full_optim_state_dict()` are internally consistent, then there is no need to match the local equivalent (assuming no change in `FSDP` wrapping).
### Follow-Ups
**API**
- If needed, we can follow up this PR by adding an argument `key_by_name: bool = False` to both methods that may be set to `True` to key parameters by `str` name instead of by `int` parameter ID. We still need to investigate whether keying by name enables changing the `FSDP` wrapping scheme.
**Refactoring**
- In this optimizer state checkpointing, all optimizer state is saved to CPU on rank 0 (set as `OPTIM_TARGET_RANK`). We should unify and refactor these assumptions with model state checkpointing.
**Testing**
- The code path for unused parameters is not tested. The testing and any needed implementation fixes can be done in a follow-up.
- The code path for non-tensor states (e.g. `Adam` `"step"` as `float` instead of as zero-dimension `FloatTensor`) is not tested. However, it is identical to that of zero-dimension tensor states, so I have some confidence. If needed, I can add tests for it in a follow-up.
- Would I have to write my own optimizer? I do not want to introduce dependencies on third party libraries like Nvidia `apex`.
- We may want to add end-to-end checkpointing tests that include both model state dict and optimizer state dict.
Test Plan: Imported from OSS
Reviewed By: zhaojuanmao
Differential Revision: D35045121
Pulled By: awgu
fbshipit-source-id: 33c650dc960acbd7613d4f444a852b9f76ca4a9b
(cherry picked from commit 2bbc2e344296dc455cf686f3a9b097989504be81)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73842
**Overview**
This cleans up the `ZeroRedundancyOptimizer` tests. I apologize for the heavy formatting changes mixed in with the substantive ones; it was convenient to unify the formatting while doing a deep comb through the full test file.
The main non-formatting changes include:
- Using `parametrize` instead of manually writing `for` loops over possible argument values (see the sketch after this list)
- Removing the `DEVICE` global variable, which was used only for the `TestZeroRedundancyOptimizerSingleRank` tests, in favor of consistent usage of `self.device` in both `TestZeroRedundancyOptimizerSingleRank` and `TestZeroRedundancyOptimizerDistributed`
- Moving `assert ... == ...` to `self.assertEqual(..., ...)` when the assert is part of the test's correctness
- Removing the `if self.rank >= self.world_size or (torch.cuda.is_available() and torch.cuda.device_count() < 2):` conditional guards in favor of `common_distributed.skip_if_no_gpu` for `TestZeroRedundancyOptimizerDistributed`
  - For `TestZeroRedundancyOptimizerDistributed`, `self.device` is `torch.device(self.rank)` if CUDA is available, while `self.world_size` is at least 2, even if `torch.cuda.device_count() == 1`.
  - The problematic case is exactly when `torch.cuda.device_count() == 1` but `self.world_size == 2` since then calling `self.device` on rank 1 will error. The existing conditional guard prevented this case for some tests, but it was not used consistently (e.g. `test_multiple_groups()`), which is most likely the reason for the hangs and resulting test flakiness. (From my experience landing the recent ZeRO constructor changes, the Windows environment uses a world size of 2 but only has 1 device available.)
  - A more robust solution is to always use the `skip_if_no_gpu` decorator as long as the test uses `self.device` and CUDA is available. This is in line with the recommended SPSD usage of ZeRO.
- Renaming `test_multiple_groups()` to `test_nondefault_process_group()`
  - The existing `test_multiple_groups()` was slightly misnamed. Also, it is only nontrivial for a world size of (at least) 4 since it tests using a process group including only even ranks. It was marked as flaky on Windows, and I believe this is because of the world size and `torch.cuda.device_count()` mismatch. Now, the test only uses GPU if there are enough available and falls back to CPU otherwise, which is safe since the test uses the Gloo backend.
  - There was also a duplicated section, which I was unsure how to non-naively de-duplicate. The top half and bottom half are identical even though they claim to target fitting into the broadcast bucket and not fitting into the broadcast bucket:
    1d497114e7/test/distributed/optim/test_zero_redundancy_optimizer.py (L658-L684)
- Changing `_test_zero_model_parallel()` to not use CPU
  - This is my own fault, having introduced this inefficiency last summer. It makes more sense to simply designate one of the two GPUs for a process to be its default device rather than routing through CPU.
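As a hedged sketch of the `parametrize` pattern referenced above (assuming the helpers in `torch.testing._internal.common_utils`; the test body is illustrative):
```python
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

class TestZeroExample(TestCase):
    # Replaces a manual `for parameters_as_bucket_view in (False, True):` loop.
    @parametrize("parameters_as_bucket_view", [False, True])
    def test_something(self, parameters_as_bucket_view: bool):
        self.assertIn(parameters_as_bucket_view, (False, True))

instantiate_parametrized_tests(TestZeroExample)

if __name__ == "__main__":
    run_tests()
```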
**Questions**
- How might we limit the runs for `test_ddp_zero_overlap()`? Because it parameterizes over many values, it contributes significantly to the time-to-signal. However, it is an experimental feature, so it is not critical that the tests run every time.
Test Plan: Imported from OSS
Reviewed By: rohan-varma
Differential Revision: D34675709
Pulled By: awgu
fbshipit-source-id: 71ce9ac968fb34415cd65206855b4bb5e67754fb
(cherry picked from commit 34e3dd0a184318ea9f63a1ee20cd14b111af3501)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73166
This PR refactors, cleans up, and optimizes the implementation of `TORCH_DISTRIBUTED_DEBUG`. It also introduces three new user APIs: `get_debug_level()`, `set_debug_level()`, and `set_debug_level_from_env()` to retrieve and modify the debug level after a process has started.
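A hedged usage sketch of the new APIs (assuming they are exposed under `torch.distributed` alongside a `DebugLevel` enum):
```python
import os
import torch.distributed as dist

# Environment-driven configuration, re-read on demand:
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
dist.set_debug_level_from_env()

# Programmatic inspection and modification after the process has started:
print(dist.get_debug_level())              # e.g. DebugLevel.DETAIL
dist.set_debug_level(dist.DebugLevel.OFF)
```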
ghstack-source-id: 149778566
Test Plan: Run the existing unit tests.
Reviewed By: rohan-varma
Differential Revision: D34371226
fbshipit-source-id: e18443b411adcbaf39b2ec999178c198052fcd5b
(cherry picked from commit 26d6bb1584b83a0490d8b766482656a5887fa21d)
Summary:
Reland of https://github.com/pytorch/pytorch/pull/72578.
**Overview**
Windows CI was failing due to the multi-rank single-GPU case (see [here](https://github.com/pytorch/pytorch/runs/5204906995?check_suite_focus=true)).
To address this, I
- added `common_distributed.skip_if_no_gpu` for `test_multiple_param_groups()` to ensure that each rank can safely call `to(self.device)` -- this targets the expected SPSD use case where each rank has its own GPU;
- moved `test_constructor()` back to `TestZeroRedundancyOptimizerSingleRank` to check that the multiple parameter group method for construction works even on a single rank.
**Test Plan**
- I checked both tests for CPU, 1 GPU, 2 GPUs, 4 GPUs, and 8 GPUs.
- I added the `ciflow/win` label to run the failing Windows CI test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72932
Reviewed By: rohan-varma
Differential Revision: D34281482
Pulled By: awgu
fbshipit-source-id: c4fe604ddd9d2c123c3071249741e6b8a6454b6e
(cherry picked from commit 6bea9bcc6349ff1aad403563206fb170a3af0c70)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72578
**Overview**
This adds `ZeroRedundancyOptimizer` constructor support for multiple parameter groups (i.e. passing an `iterable` of `dict`s instead of an `iterable` of `torch.Tensor` as the `parameters` argument) to mirror the API for non-sharded optimizers.
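A hedged sketch of the newly supported constructor form (process-group setup omitted; the model layout and hyperparameters are hypothetical):
```python
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

# Assumes torch.distributed.init_process_group(...) has already been called.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2)).cuda()

zero_optim = ZeroRedundancyOptimizer(
    [
        {"params": model[0].parameters(), "lr": 1e-3},
        {"params": model[1].parameters(), "lr": 1e-2},
    ],
    optimizer_class=torch.optim.Adam,
)
```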
Fixes https://github.com/pytorch/pytorch/issues/71347 and https://github.com/pytorch/pytorch/issues/59973.
This modifies `test_collect_shards()` to skip if ROCm.
**Test Plan**
I adjusted the existing constructor test, and I added a test for parity between constructing with two parameter groups up front versus constructor with one parameter group and adding the second parameter group after (via `add_param_group()`) versus a non-sharded optimizer.
Test Plan: Imported from OSS
Reviewed By: rohan-varma
Differential Revision: D34106940
Pulled By: awgu
fbshipit-source-id: 7e70fc0b3cec891646e0698eaedf02ff4354c128
(cherry picked from commit 40f2d45172ba3286b64000a466e42c055cca8ddc)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71620
Remove `from_functional_optim` and make it the default constructor, since that is now the only way `_OptimizerHookState` is built. Also, there is no longer a need to expose the `create_functional_optim` helper function.
ghstack-source-id: 147577174
Test Plan: CI
Reviewed By: cbalioglu
Differential Revision: D33700593
fbshipit-source-id: ba089ce3bf66ccf8f71cffdd0f4d4bddc03e8b14
(cherry picked from commit a50b2caf0e19f9793fbf18b371d30e3dd8c5c0cf)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62937
Reland due to a Windows + CUDA failure; fixed by running on Gloo on Windows even with CUDA.
ghstack-source-id: 135306176
Test Plan: ci
Reviewed By: mrshenli
Differential Revision: D30177734
fbshipit-source-id: 7625746984c8f858648c1b3632394b98bd4518d2
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62774
Gates `DistributedOptimizer`, which relies on RRef, based on whether RPC is available. This should enable ZeRO to work on Windows, as Windows should not try to import `DistributedOptimizer`. If this works as expected, we can enable the Windows tests for the functional/local SGD optimizers as well.
ghstack-source-id: 135216642
Test Plan: CI
Reviewed By: pbelevich
Differential Revision: D30117838
fbshipit-source-id: e6365a910a3d1ca40d95fa6777a7019c561957db
Summary:
**Overview:**
This removes the preceding `_` from `_Join`, `_Joinable`, and `_JoinHook` in preparation for adding the generic join context manager tutorial (see [here](https://github.com/pytorch/tutorials/pull/1610)). This also adds a docs page, which can be linked from the tutorial. [Here](https://github.com/pytorch/pytorch/files/6919475/render.pdf) is a render of the docs page.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62605
Test Plan:
`DistributedDataParallel.join()`:
```
touch /tmp/barrier && TEMP_DIR="/tmp" BACKEND="nccl" WORLD_SIZE="2" gpurun python test/distributed/test_distributed_fork.py -- TestDistBackendWithFork.test_ddp_uneven_inputs TestDistBackendWithFork.test_ddp_uneven_inputs_stop_iteration_sync_bn TestDistBackendWithFork.test_ddp_grad_div_uneven_inputs TestDistBackendWithFork.test_ddp_uneven_input_join_disable TestDistBackendWithFork.test_ddp_uneven_input_exception
```
`ZeroRedundancyOptimizer`:
```
gpurun4 python test/distributed/optim/test_zero_redundancy_optimizer.py
```
NOTE: DDP overlap tests are failing due to a landing race. See https://github.com/pytorch/pytorch/pull/62592. Once the fix is landed, I will rebase, and tests should be passing.
`Join`:
```
gpurun4 python test/distributed/algorithms/test_join.py
```
Reviewed By: mrshenli
Differential Revision: D30055544
Pulled By: andwgu
fbshipit-source-id: a5ce1f1d9f1904de3bdd4edd0b31b0a612d87026
Summary:
**Overview:**
This adds two approaches to overlapping `DistributedDataParallel.backward()` with `ZeroRedundancyOptimizer.step()` by providing two hook constructors: `hook_with_zero_step()` and `hook_with_zero_step_interleaved()`. The former waits for all backward computation to finish before starting optimizer computation, while the latter launches a partial optimizer computation using the contents of a gradient bucket once that bucket's all-reduce completes. The two approaches each suffer from their own weaknesses, and which one to use depends on the specific hardware configuration.
Both approaches can share changes to `ZeroRedundancyOptimizer`. A user should pass `overlap_with_ddp=True` to `ZeroRedundancyOptimizer`, construct a DDP communication hook using either `hook_with_zero_step()` or `hook_with_zero_step_interleaved()`, and register that communication hook. `ZeroRedundancyOptimizer.step()` should still be called in the training loop, though the optimizer computation and communication will be offloaded to originate from the communication hook. Currently, the first two iterations are vacuous, meaning they do not result in parameter updates and the inputs are ignored. This is required to finalize the DDP bucket strategy and to then initialize the `ZeroRedundancyOptimizer`'s local optimizer based on that bucketing.
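A minimal sketch of that flow, assuming the hook constructors wrap a base all-reduce communication hook and live under `torch.distributed.algorithms.ddp_comm_hooks` (process-group and device setup omitted; this is not the exact test code):
```python
import torch
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import hook_with_zero_step
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel

model = torch.nn.Linear(8, 8).cuda()
ddp_model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])

zero_optim = ZeroRedundancyOptimizer(
    ddp_model.parameters(),
    optimizer_class=torch.optim.Adam,
    overlap_with_ddp=True,  # offload step() work into the communication hook
)

# Wrap the base all-reduce hook so optimizer computation overlaps with backward.
ddp_model.register_comm_hook(None, hook_with_zero_step(allreduce_hook, ddp_model, zero_optim))

# The training loop still calls zero_optim.step(); per the PR, the first two
# iterations are vacuous while the DDP bucketing strategy is finalized.
```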
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62157
Test Plan:
The existing `ZeroRedundancyOptimizer` tests pass, and new unit tests for both hooks pass:
- ~~`test_ddp_with_zero_step_parity_cpu`~~ (removed for now due to flakiness in CI -- under investigation, could possibly be similar Gloo issue as with `hook_with_zero_step_interleaved()`)
- `test_ddp_with_zero_step_parity_gpu`
- `test_ddp_with_zero_step_interleaved_parity_gpu`
These were tested on the AI AWS cluster.
An analogous `test_ddp_with_zero_step_interleaved_parity_cpu` is missing due to existing bugs with Gloo. See https://github.com/pytorch/pytorch/pull/62302.
Both approaches have been verified using an internal accuracy benchmark.
Reviewed By: mrshenli
Differential Revision: D29971046
Pulled By: andwgu
fbshipit-source-id: a7234c23c7ea253f144a698fd7e3c0fe039de5e8
Summary:
Revised version of https://github.com/pytorch/pytorch/issues/60573.
**Overview:**
This makes two changes:
- It introduces a `map_location` argument to `broadcast_object_list()`. The argument specifies the device onto which to load tensors contained in objects received from the broadcast. This change requires modifying the implementations of `_object_to_tensor()` and `_tensor_to_object()` to use `torch.save()` and `torch.load()` respectively.
- It removes all calls to `_broadcast_object()` in `ZeroRedundancyOptimizer` and the corresponding test file in favor of `broadcast_object_list()`.
The default value of `map_location` is `None`, in which case `_object_to_tensor()` and hence `broadcast_object_list()` preserve their original behavior. Namely, contained tensors are loaded to their original device.
In `consolidate_state_dict()`, I specify `map_location=torch.device("cpu")` instead of `self._default_device`. This slightly changes the behavior from before when using `_broadcast_object()`. The reason I do so is that it saves one GPU to CPU data transfer since the action immediately after receiving the broadcasted `local_state_dict` is to copy it to CPU.
Explicitly, if `map_location=self._default_device`, then the data transfer path assuming NCCL backend is as follows:
`source GPU --[before serialize]--> source CPU --[before broadcast]--> source GPU --[broadcast]--> destination GPU --[before deserialize]--> destination CPU --[deserialize]--> destination GPU --[copy]--> destination CPU`
Hence, by setting `map_location=torch.device("cpu")` instead, the suffix becomes:
`destination CPU --[deserialize]--> destination CPU --[copy]--> destination CPU`
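A hedged sketch of the call described here (the `map_location` keyword is the one this PR introduces; treat the exact signature as specific to this PR):
```python
import torch
import torch.distributed as dist

# Assumes an initialized process group (e.g. NCCL backend).
if dist.get_rank() == 0:
    objects = [{"weights": torch.randn(4, device="cuda")}]
else:
    objects = [None]

# Contained tensors are loaded onto `map_location` on receipt instead of
# their original (source) device.
dist.broadcast_object_list(objects, src=0, map_location=torch.device("cpu"))
```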
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61539
Test Plan:
I added a test `test_broadcast_object_list_map_location()` that checks for both `map_location` as CPU and GPU that (1) tensors contained in broadcasted objects are appropriately loaded onto the specified device and (2) that the contents of the tensors are correct.
The existing `ZeroRedundancyOptimizer` tests pass.
```
gpurun4 python test/distributed/optim/test_zero_redundancy_optimizer.py
```
The existing `broadcast_object_list()` test passes:
```
touch /tmp/barrier && TEMP_DIR="/tmp" BACKEND="nccl" WORLD_SIZE="2" gpurun python test/distributed/test_distributed_fork.py -- TestDistBackendWithFork.test_broadcast_object_list
```
Reviewed By: zou3519
Differential Revision: D29701479
Pulled By: andwgu
fbshipit-source-id: c8d5f9057b32e5e9f40e8edc5b2cc25fb21414a9
Summary:
**Overview:**
This refactors the computation on non-joined processes relating to the join context manager. The concept was inspired by a comment from pritamdamania.
**Changes:**
This introduces a `_Joinable` abstract base class, which requires a `_join_hook()` method and `_join_device()` and `_join_process_group()` property methods. Any class that we want to be compatible with the generic join context manager should inherit from `_Joinable` and implement `_join_hook()`, `_join_device()`, and `_join_process_group()`. (The `device` and `process_group` information has been moved from `_JoinHook` to `_Joinable`.)
The generic join context manager now takes in a `List[_Joinable]` instead of `List[_JoinHook]`. The motivation for this is that previously, by passing the `_JoinHook`s into the context manager, the class providing a `_JoinHook` can modify the context manager's behavior, but the context manager cannot modify the class's behavior. This is solved by giving the context manager a reference to the class's instance.
This implementation reserves the field `_join_config` in every `_Joinable` to store a `_JoinConfig` instance, which holds all dynamic fields needed from the `_Joinable` for the join context manager: `enable`, `throw_on_early_termination`, and `is_first_joinable`. ("dynamic" here means that for a given `_Joinable` instance, the values for those fields may change across different join context usages.) In particular, these fields are needed to implement a method `notify_join_context()`, which encapsulates the computation performed on non-joined processes relating to the join context manager --- (1) the all-reduce to indicate that the process has not yet joined and (2) the all-reduce to check whether to throw an exception if `throw_on_early_termination=True`. The idea is that every `_Joinable` class only needs to make a call to `notify_join_context()` before its per-iteration collective communications; it is a simple one-line addition.
Only the first `_Joinable` instance passed into the context manager actually performs the collective communications in `notify_join_context()`. In that case, the method returns an async work handle for the initial all-reduce indicating that the process not yet joined. Otherwise, the method returns `None`. This conditional logic is handled internally without additional input from the user.
**New API:**
Now, the example usage would look like:
```
ddp_model = DistributedDataParallel(...)
zero_optim = ZeroRedundancyOptimizer(ddp_model.parameters(), ...)
with _Join([ddp_model, zero_optim]):
...
```
Any arguments meant for a join hook (e.g. `divide_by_initial_world_size`) must be specified as keyword arguments. For example:
```
with _Join([ddp_model, zero_optim], divide_by_initial_world_size=False):
...
```
They will be forwarded to every `_join_hook()` function via `**kwargs`. This creates a clear separation between the variables needed by the context manager (`enable` and `throw_on_early_termination`) and those needed by the `_Joinable` class (e.g. `divide_by_initial_world_size`).
**Recap:**
After this change, the relevant information to use the generic join context manager looks like the following (omitting the prefix `_` from names); a minimal sketch follows the list:
- Suppose we have a class `C` (e.g. `DistributedDataParallel`) that we want to be able to use the `Join` context.
- We make `C` inherit from `Joinable` and implement `join_hook() -> JoinHook`, `join_device()`, and `join_process_group()`.
- To implement `join_hook()`, we define a `CJoinHook` class inheriting from `JoinHook` and implement `main_hook()` and `post_hook()` as needed.
- We locate a place before `C`'s per-iteration collective communications and add a call to `Join.notify_join_context()`.
- We call `Joinable.__init__(self)` in `C`'s constructor.
- The `C.join_config` field will be used internally by the context manager. This does not affect `C`'s serializability.
- Run time arguments for `C`'s join hook can be passed in as keyword arguments to the context manager: `with Join([C()], arg1=..., arg2=...):`.
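A hedged, minimal sketch of those steps, using the eventual public names (as exposed under `torch.distributed.algorithms`); the hook bodies are illustrative no-ops:
```python
import torch.distributed as dist
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

class CJoinHook(JoinHook):
    def main_hook(self):
        # Shadow C's per-iteration collective(s) while non-joined ranks proceed.
        pass

    def post_hook(self, is_last_joiner: bool):
        # e.g. broadcast final state from an authoritative (last-joined) rank.
        pass

class C(Joinable):
    def __init__(self, process_group, device):
        super().__init__()                 # sets up the internal join_config field
        self._pg = process_group
        self._device = device

    def join_hook(self, **kwargs) -> JoinHook:
        return CJoinHook()

    @property
    def join_device(self):
        return self._device

    @property
    def join_process_group(self):
        return self._pg

    def step(self, tensor):
        Join.notify_join_context(self)     # the one-line addition before collectives
        dist.all_reduce(tensor, group=self._pg)
```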
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61555
Test Plan:
I ran the existing DDP join tests:
```
touch /tmp/barrier && TEMP_DIR="/tmp" BACKEND="nccl" WORLD_SIZE="2" gpurun python test/distributed/test_distributed_fork.py -- TestDistBackendWithFork.test_ddp_uneven_inputs TestDistBackendWithFork.test_ddp_uneven_inputs_stop_iteration_sync_bn TestDistBackendWithFork.test_ddp_grad_div_uneven_inputs TestDistBackendWithFork.test_ddp_uneven_input_join_disable TestDistBackendWithFork.test_ddp_uneven_input_exception
```
I ran the ZeRO join tests:
```
gpurun4 python test/distributed/optim/test_zero_redundancy_optimizer.py TestZeroRedundancyOptimizerDistributed.test_zero_join_gpu TestZeroRedundancyOptimizerDistributed.test_zero_join_cpu
```
Reviewed By: zou3519
Differential Revision: D29690359
Pulled By: andwgu
fbshipit-source-id: 2950f78de755eb5fb13b95b803dd7c705879a9c7
Summary:
**Overview:**
The existing `ZeroRedundancyOptimizer` implementation assumes that all model parameters are stored on the same device (due to the recent [refactor](https://github.com/pytorch/pytorch/pull/59834)). This change allows model parameters to be sharded across multiple devices, as in the DDP with Model Parallelism example [here](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).
The only logic affected is the bucketing strategy used when `parameters_as_bucket_view=True`. Let `n` denote the world size and `k` denote the number of devices per process.
- Previously, `k = 1`, and `self._buckets` was a `List[torch.Tensor]`, where `self._buckets[j]` is a tensor (i.e. bucket) containing the parameters assigned to rank `j` for `j = 0, ..., n - 1`.
- Now, `self._buckets` is a `List[List[torch.Tensor]]`, where `self._buckets[i][j]` is a tensor containing the parameters stored on device `i` assigned to rank `j` for `i = 0, ..., k - 1` and `j = 0, ..., n - 1`.
This bucket construction uses an auxiliary data structure `self._device_to_per_rank_params`, which is a `Dict[torch.device, List[List[torch.Tensor]]]`. It maps:
- `dev_0` to `[rank 0's assigned parameters on dev_0, rank 1's assigned parameters on dev_0, ...]`,
- `...`
- `dev_{k-1}` to `[rank 0's assigned parameters on dev_{k-1}, rank 1's assigned parameters on dev_{k-1}, ...]`
I removed the invariant checker `_verify_same_param_device()` and its corresponding test since it is no longer an invariant.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61370
Test Plan: I added a new test `test_zero_model_parallel()` that checks for parity between a DDP model with model parallelism using `ZeroRedundancyOptimizer` and a local model with the same architecture using a local optimizer. I also verified that the existing tests still pass.
Reviewed By: soulitzer
Differential Revision: D29637132
Pulled By: andwgu
fbshipit-source-id: 07112959fa4e94a3f40e67e88cbb58ce3cd1e033
Summary:
Targets https://github.com/pytorch/pytorch/issues/54318.
**Overview:**
DDP offers a `join()` context manager to accommodate training on uneven inputs. This creates a new generic `_Join()` API permitting custom hooks, refactors DDP `join()` to call this generic `_Join()`, and implements a hook for ZeRO. (For now, the generic `_Join()` is implemented as private, but this may change after design discussions are cleared.)
There are two classes introduced: `_JoinHook`, the class defining the customizable join hook, and `_Join`, the generic join context manager.
The `_JoinHook` provides two entry points: `main_hook()`, which is called repeatedly while there exists a non-joined process, and `post_hook()`, which is called once all processes have joined with the additional `bool` argument `is_last_joiner`. The class also requires `process_group` and `device` information by defining corresponding abstract property methods. Thus, to implement a join hook, (1) inherit from `_JoinHook`, (2) override `main_hook()` and `post_hook()` as appropriate, and (3) override `process_group()` and `device()` to provide process group and device information to be used by the join context manager implementation for collective communications.
The `_Join` constructor requires `join_hooks: List[_JoinHook]` and optionally `enable: bool = True` and `throw_on_early_termination: bool = False`. A training loop only needs to be wrapped with `with _Join(join_hooks):` (using the appropriate `join_hooks`) to be able to train on uneven inputs without hanging/erroring. The context manager requires a `dist.all_reduce(torch.ones(1))` to be called on every non-joined process each time before it performs its collective communications in order to indicate that the process has not yet joined. It also requires that all `process_group` attributes in the `_JoinHook` objects are the same.
**Notes:**
- The argument `is_last_joiner` to `post_hook()` may be useful for finding an authoritative rank when synchronizing.
- `enable` is a flag that can be set to `False` if the user knows the current training loop will not have uneven inputs. This may be used to disable join-related computation in the classes providing join hooks.
- `throw_on_early_termination` is a flag that can be set to `True` to notify processes to terminate upon detecting uneven inputs (i.e. upon the first process joining when there exists a non-joined process). Notably, the notification requires an all-reduce, so to prevent hanging/erroring, non-joined processes must participate in the all-reduce. The first-joining process raises a `RuntimeError`, and the other processes are expected (but not required) to do the same. This may be used to implement training on uneven inputs in cases that do not conform to the generic join context manager (e.g. `SyncBatchNorm`).
- Classes providing a join hook should do so via a `_join_hook()` method that returns a `_JoinHook` instance with the methods appropriately overridden.
- If there are multiple join hooks, the device specified by the first is used by the join context manager implementation to perform its collective communications.
- If there are multiple join hooks, both the main and post-hooks are iterated in the order in which the `_JoinHook` objects are passed into the context manager constructor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60757
Test Plan:
The current implementation preserves backward compatibility by not changing the existing DDP `join()` API at all. To check this, I ran through the uneven input tests (`test_ddp_grad_div_uneven_inputs`, `test_ddp_uneven_inputs_stop_iteration_sync_bn`, `test_ddp_uneven_inputs`, `test_ddp_uneven_input_join_disable`, `test_ddp_uneven_input_exception`) on the AI AWS cluster:
```
touch /tmp/barrier && TEMP_DIR="/tmp" BACKEND="nccl" WORLD_SIZE="2" gpurun python test/distributed/test_distributed_fork.py --
```
Because the existing DDP join logic does not provide correct gradients to the joined processes if `gradient_as_bucket_view=False` and a joined process requires those gradients to correctly update its shard of the parameters in `ZeroRedundancyOptimizer.step()`, DDP and ZeRO are not fully compatible at the moment. To work around this and to test ZeRO's join hook separately, I added a test `_test_zero_join()` (with `test_zero_join_gpu()` and `test_zero_join_cpu()` flavors), which compares DDP with a local optimizer on uneven inputs against ZeRO on uneven inputs with the gradients set manually.
Reviewed By: iramazanli, mrshenli
Differential Revision: D29624636
Pulled By: andwgu
fbshipit-source-id: ec70a290e02518b0d8b683f9fed2126705b896c7
Summary:
**Overview:**
This is a follow-up to [this PR](https://github.com/pytorch/pytorch/pull/59586) and corrects the ZeRO partitioning algorithm to sort by the number of elements in the tensor rather than the size of the first dimension. As context, that PR was meant to migrate from using a _naive greedy_ algorithm to a _sorted-greedy_ algorithm when partitioning parameters in ZeRO.
**Updated Results:**
The updated table for the partitions can be found [here](https://github.com/pytorch/pytorch/pull/59410#issuecomment-865203219). There, I also considered a third algorithm (sometimes known as multifit), which is more computationally expensive than the greedy and sorted-greedy algorithms but cannot perform worse. However, because of its increased complexity and lack of improved results, I chose to settle with the simpler sorted-greedy algorithm.
The `step()` latencies show slight improvements, but the improvements may be in the noise. The values below are in seconds and were generated using NCCL backend (unlike in the previous PR which used Gloo):
Two processes:
| Model | Max `optimizer.step()` Time - Greedy (Std.) | Max `optimizer.step()` Time - Sorted-Greedy (Std.) |
| --- | --- | --- |
| ResNet-50 | 0.047 (0.00142) | **0.044 (0.00025)** |
| ResNet-152 | 0.057 (0.00034) | **0.054 (0.00022)** |
| BERT | 0.021 (0.00008) | **0.020 (0.00008)** |
Four processes:
| Model | Max `optimizer.step()` Time - Greedy (Std.) | Max `optimizer.step()` Time - Sorted-Greedy (Std.) |
| --- | --- | --- |
| ResNet-50 | 0.019 (0.00065) | **0.013 (0.00040)** |
| ResNet-152 | 0.045 (0.00024) | 0.045 (0.00025) |
| BERT | 0.019 (0.00022) | **0.018 (0.00016)** |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60556
Test Plan:
I verified that the ZeRO tests pass (via the AI AWS cluster):
```
srun -p $DEV_QUEUE --cpus-per-task=16 -t 5:00:00 --gpus-per-node=4 python test/distributed/optim/test_zero_redundancy_optimizer.py
```
Reviewed By: VitalyFedyunin
Differential Revision: D29335260
Pulled By: andwgu
fbshipit-source-id: 469d1c6e029b77c1b300a94cd1fd94b633cd28dd
Summary:
**Overview:**
Being relatively new to PyTorch and ZeRO, I found parts of the code slightly hard to follow. This change strives to clean up the `ZeroRedundancyOptimizer` code in `zero_redundancy_optimizer.py` by reorganizing some computations, making variable names more explicit and consistent, and unifying terminology in the documentation. The goal is for the code to be easier to extend afterwards.
**Changes:**
1) `state_dict()`: The [logic](85517a2b70/torch/distributed/optim/zero_redundancy_optimizer.py (L510)) for updating the global `state_dict` with each rank's local `state_dict` is simplified and made more explicit. Notably, the `dict` [`local_index_to_param_id`](85517a2b70/torch/distributed/optim/zero_redundancy_optimizer.py (L513)) is unneeded. It maps `local_pg["params"][i]` to `id(global_pg["params"][i])`, so it is equivalent to make a single pass over both lists in tandem, effectively iterating over `i`, without a need for the explicit `dict`.
2) `_update_trainable()`: The function [initializes](85517a2b70/torch/distributed/optim/zero_redundancy_optimizer.py (L597)) the local optimizer if it does not exist. I am unaware of any reason for the local optimizer to be destroyed after initialization, so I moved that logic to its own function `_init_local_optimizer()`, which is called once in the constructor.
After [discussion](https://github.com/pytorch/pytorch/pull/60285#discussion_r654706728), I removed the function `_update_trainable()` itself in favor of adding a check for `parameters_as_bucket_view` in `build_param_buckets()` directly.
3) `rank_local_state_dict()`: This [function](85517a2b70/torch/distributed/optim/zero_redundancy_optimizer.py (L528)) is currently broken. It appears to be legacy and relies on the input `state_dict` to have the key `"partitions"`. For now, I have removed it and added an [issue](https://github.com/pytorch/pytorch/issues/60284). Is it a notable use case to want to access another rank's `state_dict` in particular (as opposed to consolidating the entire state and then accessing)?
4) `local_state_dict():` After [discussion](https://github.com/pytorch/pytorch/pull/60285#discussion_r655571043), I removed the function.
5) `partition_parameters()`: After [discussion](https://github.com/pytorch/pytorch/pull/60285#discussion_r654708183), I renamed the function to `_partition_parameters()` to mark it as private.
6) `_param_to_index`: After [discussion](https://github.com/pytorch/pytorch/pull/60285#discussion_r654828100), I changed the key to be the parameter itself rather than its integer ID.
7) `buckets`: I renamed the data structure to `_buckets` to mark it as private.
8) Terminology: I tried to reduce the set of terms being used instead of juggling a number of synonyms. In particular, I made an effort to distinguish between "local" and "global" and to make names more indicative of typing.
9) Style: Per the [PyTorch contributing guide](https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md#writing-documentation), I made all docstrings abide by the 80 character limit, except for the one [line](554891f6fa/torch/distributed/optim/zero_redundancy_optimizer.py (L142)) showing the example ZeRO usage. Some code lines violate the limit for readability. Also, I unified some of the minor stylistic usages out of habit.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60285
Test Plan:
The test suite passes as expected (on the AI AWS cluster):
```
gpurun python test/distributed/optim/test_zero_redundancy_optimizer.py
```
I visually inspected the generated HTML doc (as generated following [this](https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md#writing-documentation)).
Reviewed By: mrshenli
Differential Revision: D29320726
Pulled By: andwgu
fbshipit-source-id: 23f69a19ecc5e877a38fe1df0da11329428311dd
Summary:
**Overview:**
This refactors the `ZeroRedundancyOptimizer` implementation to assume single-process single-device (SPSD) instead of accommodating single-process multiple-device (SPMD). `DistributedDataParallel` [retired SPMD recently](https://github.com/pytorch/pytorch/issues/47012), so this change follows the same spirit.
**Changes:**
The parent-class `Optimizer` constructor permits the input argument `params` to be both an `iterable` of `torch.Tensor` and an `iterable` of `dict`. The latter usage is for initializing the optimizer with multiple `param_group`s to start. However, currently, `ZeroRedundancyOptimizer` only supports the former usage, requiring explicit calls to `add_param_group()` for multiple `param_group`s. Given the existing implementation, the type error would be silent and not manifest until much later (e.g. since `super().__init__()` would have no issue). Hence, I added a series of checks to begin the `__init__()` function (encapsulated in `_verify_and_init_params()`). A postcondition of this validation is that `self._all_params` is a non-empty list of all model parameters.
Additionally, I added a check for SPSD usage assuming that all model parameters exist on the same device. This logic is included in `_verify_same_param_device()` and is called immediately after the `params` type-checking. Support for SPSD with model parameters sharded across devices may be added in the future.
Related to that aforementioned post-condition on `self._all_params`, previously there was undefined behavior resulting from different typing of the passed in `params` input argument. If `params` was a `List`, then the usage of `self._reference_is_trainable_mask` was as expected. However, if `params` was a generator (e.g. as in the canonical usage of passing `model.parameters()`), then the ensuing behavior was divergent. This is because after a generator is iterated over, it is empty. As a result, when we set `self._all_params = params` [in the old code](68d690ffbd/torch/distributed/optim/zero_redundancy_optimizer.py (L165)), `self._all_params` is empty, reducing `training_mask` to always be the empty list. This causes missed calls to `_update_trainable()` in `step()`. (A consequence of this is that `test_pytorch_parity()`, which is renamed to `test_local_optimizer_parity()`, now outputs warnings about the trainable parameters changing.)
The existing implementation assumes that all parameters share the same dense type when allocating the bucket buffers. This change preserves this assumption, which may be removed in the future. I added a check for this in `_verify_same_dense_param_type()` to avoid erroring silently later on. Note that it is insufficient to simply check for the same `dtype` since dense and sparse tensors may share the same `dtype` but require differing storage sizes. One solution is to use `torch.typename()` as the means for comparison.
---
The primary change in this refactor is with respect to `self._per_device_params` and `self.buckets`. `self._per_device_params` mapped `torch.device` to `List[List[Parameter]]`. The keys were the devices that the model parameters exist on, and the values designated which ranks are assigned to updating those parameters. `self.buckets` mapped `torch.device` to `List[torch.Tensor]`. The keys were the same as `self._per_device_params`, and the values were the buckets for that device. The usage of these two data structures were confined to each other only. Hence, because the notions of device and rank are now in 1:1 correspondence, we can eliminate the former completely and only use rank. As such, I removed `self._per_device_params` and made `self.buckets` directly a list of buckets (i.e. `torch.Tensor`s).
Iteration over the parameters of a rank for a given device could be simplified to just iteration over the parameters of a rank. Hence, I relied on `self.partition_parameters()` now for that iteration. Refer to `_setup_flat_buffers()` and `step()` for these changes.
One convenient side effect of removing `self._per_device_params` is that there is no longer the re-computation of the parameter partitions mentioned at the end of this [PR](https://github.com/pytorch/pytorch/pull/59410).
---
I changed the data structure `self._index_to_param_cache` from a `dict` to a `List` because the domain is `0`, `1`, ..., `k-1` where `k` is the number of parameters. This should yield marginal improvements in memory usage and access speed.
`_sync_param_groups()` is a static method, meaning it can be called either via `self._sync_param_groups()` or `ZeroRedundancyOptimizer._sync_param_groups()` when inside the class. I made the usage consistently `self._sync_param_groups()` rather than have instances of both.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59834
Test Plan:
I ran through the existing test suite on an AI AWS cluster:
```
srun -p $DEV_QUEUE --cpus-per-task=16 -t 5:00:00 --gpus-per-node=4 python test/distributed/optim/test_zero_redundancy_optimizer.py
```
Note: The only test where `parameters_as_bucket_view` is `True` is `test_step_with_closure()`, meaning that that is the test that exercises the core changes of removing `self._per_device_params` and changing `self.buckets`.
Also, I added tests for the `ZeroRedundancyOptimizer` constructor changes and the assumption checks.
Reviewed By: mrshenli
Differential Revision: D29177065
Pulled By: andwgu
fbshipit-source-id: 0ff004ae3959d6d3b521024028c7156bfddc93d8
Summary:
Addresses https://github.com/pytorch/pytorch/issues/59548
**Overview:**
Recently, we changed ZeRO's partitioning algorithm to first sort the parameters by decreasing size and then greedily allocate to shards. See [here](ea1de87f4b).
The current tests `test_sharding()` and `test_add_param_group()` check for a uniform partitioning, which is not achieved with the old naive greedy partitioning algorithm for general world sizes but is achieved with the new sorted-greedy algorithm. This reliance is not ideal, but for now, we opt to simply add comments to document the dependency.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59713
Test Plan:
I tested for world sizes of 1, 2, 3, and 4 via the AI AWS cluster:
```
srun -p $DEV_QUEUE --cpus-per-task=16 -t 5:00:00 --gpus-per-node=1 python test/distributed/optim/test_zero_redundancy_optimizer.py -- TestZeroRedundancyOptimizerDistributed.test_sharding
srun -p $DEV_QUEUE --cpus-per-task=16 -t 5:00:00 --gpus-per-node=2 python test/distributed/optim/test_zero_redundancy_optimizer.py -- TestZeroRedundancyOptimizerDistributed.test_sharding
srun -p $DEV_QUEUE --cpus-per-task=16 -t 5:00:00 --gpus-per-node=3 python test/distributed/optim/test_zero_redundancy_optimizer.py -- TestZeroRedundancyOptimizerDistributed.test_sharding
srun -p $DEV_QUEUE --cpus-per-task=16 -t 5:00:00 --gpus-per-node=4 python test/distributed/optim/test_zero_redundancy_optimizer.py -- TestZeroRedundancyOptimizerDistributed.test_sharding
srun -p $DEV_QUEUE --cpus-per-task=16 -t 5:00:00 --gpus-per-node=1 python test/distributed/optim/test_zero_redundancy_optimizer.py -- TestZeroRedundancyOptimizerDistributed.test_add_param_group
srun -p $DEV_QUEUE --cpus-per-task=16 -t 5:00:00 --gpus-per-node=2 python test/distributed/optim/test_zero_redundancy_optimizer.py -- TestZeroRedundancyOptimizerDistributed.test_add_param_group
srun -p $DEV_QUEUE --cpus-per-task=16 -t 5:00:00 --gpus-per-node=3 python test/distributed/optim/test_zero_redundancy_optimizer.py -- TestZeroRedundancyOptimizerDistributed.test_add_param_group
srun -p $DEV_QUEUE --cpus-per-task=16 -t 5:00:00 --gpus-per-node=4 python test/distributed/optim/test_zero_redundancy_optimizer.py -- TestZeroRedundancyOptimizerDistributed.test_add_param_group
```
However, because the train queue (which offers instances with 8 GPUs) is not working at the moment, I was unable to test for world sizes of 5+. Nonetheless, I believe that they should still work.
First, consider `test_sharding()`. Given the sorted-greedy algorithm, each shard will be assigned one of the parameters with size `9`, then one of the parameters with size `7`, then `5`, and finally `3`. Hence, each will have a uniform partition. Now, consider `test_add_param_group()`. Similarly, the same allocation behavior occurs, only the last shard is not assigned the final parameter with size `3` to begin. However, after adding the new `param_group` with the parameter with size `3`, a re-partitioning occurs. The first `param_group` is partitioned as before, and the parameter with size `3` in the new `param_group` is assigned to the last shard since it has the minimal total size. Thus, in the end, all shards have a uniform partition.
Reviewed By: mrshenli
Differential Revision: D28996460
Pulled By: andwgu
fbshipit-source-id: 22bdc638d8569ed9a20836812eac046d628d6df2
Summary:
Pull Request: https://github.com/pytorch/pytorch/pull/59586
Task: https://www.internalfb.com/tasks/?t=90847711
**Overview:**
Suppose we have `n` items with positive integer sizes and `k` buckets. We want to assign items to buckets with the goal of uniformity. The precise criteria for uniformity can vary: e.g. minimize the maximum size, maximize the minimum size, etc. This is known as [multiway number partitioning](https://en.wikipedia.org/wiki/Multiway_number_partitioning). ZeRO's partitioning task reduces to solving this problem. In particular, this is the subproblem to be solved for each `param_group` in `self.param_groups`, where the parameters are the items and the ranks give the buckets.
The existing implementation uses the linear-time [greedy number partitioning algorithm](https://en.wikipedia.org/wiki/Greedy_number_partitioning#Linear-time_algorithm), which assigns the next tensor-parameter to the process with the smallest total parameter size so far. In this task, I explore the [extension](https://en.wikipedia.org/wiki/Greedy_number_partitioning#Improved_algorithm) where each parameter group is sorted by decreasing size before applying the greedy algorithm, requiring linearithmic time (as dominated by the sort).
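A hedged, standalone sketch of the two strategies (parameters represented only by their sizes; this is not the ZeRO-internal code):
```python
from typing import List

def greedy_partition(sizes: List[int], k: int, sort_first: bool) -> List[List[int]]:
    """Assign each item to the bucket with the smallest running total."""
    order = sorted(sizes, reverse=True) if sort_first else sizes
    buckets: List[List[int]] = [[] for _ in range(k)]
    totals = [0] * k
    for size in order:
        i = totals.index(min(totals))      # bucket with the smallest total so far
        buckets[i].append(size)
        totals[i] += size
    return buckets

sizes = [4, 5, 6, 7, 8]                                    # toy parameter sizes
print(greedy_partition(sizes, k=2, sort_first=False))      # naive greedy: max bucket total 18
print(greedy_partition(sizes, k=2, sort_first=True))       # sorted-greedy: max bucket total 17
```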
**Experiments**
The mean number of parameters represents a perfectly uniform allocation and hence the ideal allocation (which may be even better than the optimal partition). In the following tables, I present the maximum number of parameters for any one process and the difference from the mean in parentheses for ResNet-50, ResNet-152, and BERT (the bare BERT model). The best-performing partitioning strategy for each model is bolded.
Two processes:
| Model | Max Num Params - Greedy (Diff) | Max Num Params - Greedy-Sorted (Diff) | Mean Num Params |
| --- | --- | --- | --- |
| ResNet-50 | 13,249,600 (471,084) | **12,794,816 (16,300)** | 12,778,516 |
| ResNet-152 | 30,567,488 (471,084) | **30,111,424 (15,020)** | 30,096,404 |
| BERT | **54,749,184 (8,064)** | 55,327,488 (586,368) | 54,741,120 |
Four processes:
| Model | Max Num Params - Greedy (Diff) | Max Num Params - Greedy-Sorted (Diff) | Mean Num Params |
| --- | --- | --- | --- |
| ResNet-50 | 7,524,864 (1,135,606) | **6,436,864 (47,606)** | 6,389,258 |
| ResNet-152 | 16,232,192 (1,183,990) | **15,090,152 (41,950)** | 15,048,202 |
| BERT | **28,151,040 (780,480)** | 28,352,256 (981,696) | 27,370,560 |
---
I also investigated the latency of `optimizer.step()` for the different partitioning algorithms. I measured the latency for 30 iterations and took the mean latency per process (excluding the first iteration due to cache coldness). In the following tables, I present the maximum of those mean latencies over all processes and the standard deviation of the latencies contributing to that maximum. Again, the best-performing partitioning strategy for each model is bolded. All entries are presented in seconds and used `gloo` backend.
Two processes:
| Model | Max `optimizer.step()` Time - Greedy (Std.) | Max `optimizer.step()` Time - Greedy-Sorted (Std.) |
| --- | --- | --- |
| ResNet-50 | **0.060 (0.002)** | 0.061 (0.002) |
| ResNet-152 | 0.166 (0.003) | **0.160 (0.004)** |
| BERT | 0.220 (0.009) | **0.199 (0.006)** |
Four processes:
| Model | Max `optimizer.step()` Time - Greedy (Std.) | Max `optimizer.step()` Time - Greedy-Sorted (Std.) |
| --- | --- | --- |
| ResNet-50 | 0.094 (0.004) | **0.093 (0.004)** |
| ResNet-152 | **0.228 (0.011)** | 0.231 (0.009) |
| BERT | **0.328 (0.015)** | 0.329 (0.021) |
Based on the standard deviations, the differences in the latency measurements across the different algorithms appear to be within the uncertainty in the measurement itself. Hence, it is difficult to argue that one algorithm is clearly the fastest.
---
`zero.py` is my experiment script, and I use the AI AWS cluster. The run command looks like:
```
srun -p $DEV_QUEUE --cpus-per-task=16 -t 5:00:00 --gpus-per-node=4 python zero.py -b nccl greedy 2 4
```
This runs the experiment script on an instance with 4 GPUs using `nccl` backend, outputting to a directory named `greedy/`, and using world sizes of 2 and 4. An analogous command can be used after modifying `partition_parameters()`, e.g. replacing `greedy` with `greedy_sorted` as the output directory name. Then, to run the analysis script:
```
python analyze.py greedy greedy_sorted
```
For more details on the experiment code, refer to: https://www.internalfb.com/diff/D28946756
**Notes:**
There exists an optimal solution to this partitioning problem. An algorithm that finds such a solution is the [complete greedy algorithm (CGA)](https://en.wikipedia.org/wiki/Greedy_number_partitioning#An_exact_algorithm), which reduces to the brute-force combinatorial search in the worst case. There exist heuristics to improve the `k = 2` case (i.e. when there are two processes); however, given that `n` in typical use cases is very large, any algorithm that is quadratic or slower is unrealistic. Other exact algorithms are similarly exponential in the worst case, rendering them intractable. Given this, I do not currently see a need for future-proofing the partitioning algorithm against the introduction of algorithms beyond the naive greedy and the sorted-greedy algorithms.
---
In the current ZeRO implementation, the core `partition_parameters()` computation happens twice upon initialization (i.e. call to `__init__()`): first from a call to `_param_to_rank()` (i.e. an access to `_param_to_rank`) and then from a call to `_update_trainable()`. `_update_trainable()` sees that no optimizer has been constructed yet, so it clears the cache, eliminating the first `partition_parameters()` computation and performing a redundant re-computation.
Here is a typical trace:
- [The ZeRO optimizer object is initialized, calling `__init__()`.](d125694d0b/torch/distributed/optim/zero_redundancy_optimizer.py (L142))
- [In `__init__()`, `self._device` is set, so it accesses `self._per_device_params`.](d125694d0b/torch/distributed/optim/zero_redundancy_optimizer.py (L182))
- [`self._per_device_params` is not cached, so it accesses `self._param_to_rank`.](d125694d0b/torch/distributed/optim/zero_redundancy_optimizer.py (L340))
- [`self._param_to_rank` is not cached, so it calls `partition_parameters()`.](d125694d0b/torch/distributed/optim/zero_redundancy_optimizer.py (L353)) (first call to `partition_parameters()`)
- [`__init__()` later calls `_update_trainable()`.](d125694d0b/torch/distributed/optim/zero_redundancy_optimizer.py (L185))
- [In `_update_trainable()`, `self` does not have `attr` `"optim"`, so it clears the cached objects (notably, `self._partition_parameters_cache`).](d125694d0b/torch/distributed/optim/zero_redundancy_optimizer.py (L591))
- [`_update_trainable()` calls `self.partition_parameters()`.](d125694d0b/torch/distributed/optim/zero_redundancy_optimizer.py (L593)) (second call to `partition_parameters()`)
Based on the discussion [here](https://github.com/pytorch/pytorch/pull/59410), this recomputation is unintentional and should be addressed in a future diff.
Test Plan: I verified that the total number of parameters across the processes was consistent after the partitioning algorithm change. Otherwise, no additional modifications were made to existing tests.
Reviewed By: mrshenli
Differential Revision: D28946755
fbshipit-source-id: 7ad66a21a963555b3b2e693ba8069d2dddc94c60