Currently, the only way ProcessGroupNCCL shuts down its background threads and aborts all communicators is via the destructor.
However, given how python GC works and code holding references to the PG in multiple places, in practice calling `destroy_process_group` doesn't actually end up invoking the destructor.
As a result, in this PR I'm adding a explicit shutdown method to that users can call to cleanup all resources.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111392
Approved by: https://github.com/XilunWu, https://github.com/wanchaol, https://github.com/fduwjj
We have a plethora of error types for various errors raised from c10d. These include `RuntimeError`, `TimeoutError`, `SocketError`, `DistBackendError` etc.
This results in messy code during error handling somewhat like this:
```
if "NCCL" in exception_str:
...
if "Timed out initializing process group in store based barrier on rank" in exception_str:
...
if "The client socket has timed out after" in exception_str:
...
if "Broken pipe" in exception_str:
...
if "Connection reset by peer" in exception_str:
...
```
To address this issue, in this PR I've ensured added these error types:
1. **DistError** - the base type of all distributed errors
2. **DistBackendError** - this already existed and referred to PG backend errors
3. **DistStoreError** - for errors originating from the store
4. **DistNetworkError** - for general network errors coming from the socket library
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108191
Approved by: https://github.com/H-Huang
We have a plethora of error types for various errors raised from c10d. These include `RuntimeError`, `TimeoutError`, `SocketError`, `DistBackendError` etc.
This results in messy code during error handling somewhat like this:
```
if "NCCL" in exception_str:
...
if "Timed out initializing process group in store based barrier on rank" in exception_str:
...
if "The client socket has timed out after" in exception_str:
...
if "Broken pipe" in exception_str:
...
if "Connection reset by peer" in exception_str:
...
```
To address this issue, in this PR I've ensured added these error types:
1. **DistError** - the base type of all distributed errors
2. **DistBackendError** - this already existed and referred to PG backend errors
3. **DistStoreError** - for errors originating from the store
4. **DistNetworkError** - for general network errors coming from the socket library
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107651
Approved by: https://github.com/H-Huang
https://github.com/pytorch/pytorch/pull/95715 added the functionality to abort `ncclCommInitRankConfig` by specifying `blocking=0` to enable non-blocking behavior.
However, calling the `pg._abort()` didn't recover from a stuck `ncclCommInitRankConfig` since the `_abort` method only looked through `devNCCLCommMap_` map and aborted those communicators. Since `ncclCommInitRankConfig` was stuck, the communicator itself wasn't added to the map and the host thread was stuck on this line: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1171. As a result, `_abort` was a no-op.
To resolve this issue, I added the communicators to `inProgressCommMap_` as soon as they were created and then removed them once added to `devNCCLCommMap_`.
I also added a unit test that was failing without the changes to ProcessGroupNCCL.cpp
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103925
Approved by: https://github.com/osalpekar
This reverts commit 03881b0c925f191ec41d6899d589ed420ac285b5.
Reverted https://github.com/pytorch/pytorch/pull/103264 on behalf of https://github.com/osalpekar due to This commits seems to have been causing failures in test_nccl_init_abort. Those failures may have been masked by pre-existing failures in the distributed jobs on trunk when running CI on this PR. Since those breaking changes are now reverted, we should be able to rebase this and get clean signal + uncover the breakages caused by this PR. ([comment](https://github.com/pytorch/pytorch/pull/103264#issuecomment-1599451197))
https://github.com/pytorch/pytorch/pull/95715 added the functionality to abort `ncclCommInitRankConfig` by specifying `blocking=0` to enable non-blocking behavior.
However, calling the `pg._abort()` didn't recover from a stuck `ncclCommInitRankConfig` since the `_abort` method only looked through `devNCCLCommMap_` map and aborted those communicators. Since `ncclCommInitRankConfig` was stuck, the communicator itself wasn't added to the map and the host thread was stuck on this line: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1171. As a result, `_abort` was a no-op.
To resolve this issue, I added the communicators to `inProgressCommMap_` as soon as they were created and then removed them once added to `devNCCLCommMap_`.
I also added a unit test that was failing without the changes to ProcessGroupNCCL.cpp
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103264
Approved by: https://github.com/kwen2501
Adds support for multiple forward before bwd call for
static_graph=True.
There are 2 changes:
1) Change tracking of accounting of when to populate static grap related maps
from relying on forward iteration to backward calls
2) In DDP python, don't rely on num_forward iterations == 1 to enqueue the
delay allreduce. Instead use a flag.
Differential Revision: [D46673736](https://our.internmc.facebook.com/intern/diff/D46673736/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103487
Approved by: https://github.com/awgu
Summary: In cases where DDP backward is not finalized, the error is raised only in the next forward iteration of DDP. However, if there are other collective calls between those two points, training scripts could potentially get stuck.
As a result, there should be a way to check if DDP finalized after calling `.backward()`. To address this, I've added a `_check_reducer_finalized` method to validate that DDP indeed did successfully finish reduction.
Test Plan: Added unit tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100773
Approved by: https://github.com/rohan-varma
[BE] `require_backend_is_available` offers the a more thorough check as `require_backend` but both are often used together. This remove `require_backend` and centralizes on the `require_backend_is_available` decorator
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101891
Approved by: https://github.com/awgu
## TLDR
Fix decorator to re-enable 26+ distributed tests that were previously being skipped in CI
## Explanation
As part of the UCC upstream, we updated the backend tests cases to also include "ucc".
3ed1569e86/torch/testing/_internal/common_distributed.py (L90-L92)
In distributed tests we use a decorator which reads from this config and makes sure all backends are available on the system.
3ed1569e86/torch/testing/_internal/distributed/distributed_test.py (L7131)
**However**, UCC is not configured on by default for a certain subset of CI tests, which causes the entire test to be skipped (even if the test is meant for nccl and the backend being tested is nccl).
As the fix, we should just check that only the `BACKEND` being tested is available
## Changes
- Change logic to only check if the current backend being used is available
- Rename `require_backends_available` -> `require_backend_is_available`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101704
Approved by: https://github.com/rohan-varma
This is a mirror PR of D45339293
Summary:
These tests cause the following errors internally with unknown reason:
```
AttributeError: type object 'TestDistBackendWithSpawn' has no attribute 'test_ddp_hook_with_optimizer_parity_adam'
AttributeError: type object 'TestDistBackendWithSpawn' has no attribute 'test_ddp_hook_with_optimizer_parity_adamw'
AttributeError: type object 'TestDistBackendWithSpawn' has no attribute 'test_ddp_hook_with_optimizer_parity_sgd'
```
Commenting these tests out to unblock other PRs.
Test Plan: Sandcastle
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100215
Approved by: https://github.com/wz337, https://github.com/fduwjj
This reverts commit ae40a6c7356190ef86b14b10a94a58ca41ca496b.
Reverted https://github.com/pytorch/pytorch/pull/100215 on behalf of https://github.com/huydhn due to Sorry for revert your change, but it breaks lint, please run lintrunner -a torch/testing/_internal/distributed/distributed_test.py to fix the issue then reland it
This is a mirror PR of D45339293
Summary:
These tests cause the following errors internally with unknown reason:
```
AttributeError: type object 'TestDistBackendWithSpawn' has no attribute 'test_ddp_hook_with_optimizer_parity_adam'
AttributeError: type object 'TestDistBackendWithSpawn' has no attribute 'test_ddp_hook_with_optimizer_parity_adamw'
AttributeError: type object 'TestDistBackendWithSpawn' has no attribute 'test_ddp_hook_with_optimizer_parity_sgd'
```
Commenting these tests out to unblock other PRs.
Test Plan: Sandcastle
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100215
Approved by: https://github.com/wz337, https://github.com/fduwjj
### Description
The PR aims at reducing CPU overhead of context manager style coalescing.
By "context manager style coalescing", we mean:
Sync style:
```
with _coalescing_manager():
for i in range(num_coll):
dist.all_reduce(tensors[i])
```
Async style:
```
with _coalescing_manager(async_ops=True) as cm:
for i in range(num_coll):
dist.all_reduce(tensors[i])
cm.wait()
```
In previous implementation, each collective in the `num_coll` loop actually calls into the C++ backend, accumulating pybind overhead.
In the new implementation, we capture the collectives at Python level, and only fire towards C++ at the exit of the coalescing manager.
### Tests
In current PR, the "fast path" only applies to all-reduce.
- Flattened 512M: 16.38 ms, including CPU time 131.21 us
- Old _coalescing_manager 64 x 8M: 22.19 ms, including CPU time 2865 us
- New _coalescing_manager 64 x 8M: 16.93 ms, including CPU time 635 us
Hence a 4x reduction in CPU overhead (dependent on `num_coll`).
Cc @mrshenli @kumpera @wanchaol @fegin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98793
Approved by: https://github.com/kumpera
### Description
The PR aims at reducing CPU overhead of context manager style coalescing.
By "context manager style coalescing", we mean:
Sync style:
```
with _coalescing_manager():
for i in range(num_coll):
dist.all_reduce(tensors[i])
```
Async style:
```
with _coalescing_manager(async_ops=True) as cm:
for i in range(num_coll):
dist.all_reduce(tensors[i])
cm.wait()
```
In previous implementation, each collective in the `num_coll` loop actually calls into the C++ backend, accumulating pybind overhead.
In the new implementation, we capture the collectives at Python level, and only fire towards C++ at the exit of the coalescing manager.
### Tests
In current PR, the "fast path" only applies to all-reduce.
- Flattened 512M: 16.38 ms, including CPU time 131.21 us
- Old _coalescing_manager 64 x 8M: 22.19 ms, including CPU time 2865 us
- New _coalescing_manager 64 x 8M: 16.93 ms, including CPU time 635 us
Hence a 4x reduction in CPU overhead (dependent on `num_coll`).
Cc @mrshenli @kumpera @wanchaol @fegin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98793
Approved by: https://github.com/kumpera
This PR fixes https://github.com/pytorch/pytorch/issues/96203.
**Details**
When using `nn.SyncBatchNorm` with the model converted to FP16, there is a dtype discrepancy in the `SyncBatchNorm.forward()` causing an error like:
```
File "/.../pytorch/torch/nn/modules/_functions.py", line 91, in forward
mean, invstd = torch.batch_norm_gather_stats_with_counts(
RuntimeError: Expected counts to have type Half but got Float
```
[`torch.batch_norm_gather_stats_with_counts()`](fe9da29842/torch/nn/modules/_functions.py (L88-L97)) requires the `running_mean`, `running_var`, and `counts` to have the same dtype. However, when the model has been converted to FP16, only `running_mean` and `running_var` use FP16, while the `counts` are in FP32 due to [`mean` being in FP32](fe9da29842/torch/nn/modules/_functions.py (L25-L30)). This PR resolves this by casting `counts` from FP32 to FP16 instead of the alternative to cast `mean` and `invstd` from FP32 to FP16.
Moreover, for the backward, this PR casts `weight` from FP16 to FP32 to match the dtype of `mean` and `invstd` as required by `torch.batch_norm_backward_elemt()` instead of the alternative to cast `mean` and `invstd` from FP32 to FP16.
**Test Plan**
I dug up this run command from 2021:
For `world_size` in `{1,2}` and `backend` in `{nccl, gloo}`:
```
WORLD_SIZE=world_size BACKEND=backend python -m pytest test/distributed/test_distributed_spawn.py -k test_DistributedDataParallel_SyncBatchNorm_half -vs
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98332
Approved by: https://github.com/rohan-varma
Fixes#97191
This PR aims to propagate collective exceptions (async error or timeout) up to the program, so as to avoid silent stuck job.
### Previous output in #97191
```
Rank 0 is the problematic rank
Rank 4 completed
Rank 5 completed
Rank 3 completed
Rank 6 completed
Rank 2 completed
Rank 7 completed
Rank 1 completed
[E ProcessGroupNCCL.cpp:464] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=10000) ran for 10917 milliseconds before timing out.
Rank 0 completed
[E ProcessGroupNCCL.cpp:478] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:483] To avoid data inconsistency, we are taking the entire process down.
```
Although it says that it is taking the process down, it sometimes fails to do so.
### New output after this PR:
```
...
[E ProcessGroupNCCL.cpp:459] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=10000) ran for 10599 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:473] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:479] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:818] [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=10000) ran for 10599 milliseconds before timing out.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 194470) of binary: /data/home/kw2501/repos/pytorch-dev-env/bin/python
Traceback (most recent call last):
File "/pytorch-dev-env/bin/torchrun", line 33, in <module>
sys.exit(load_entry_point('torch', 'console_scripts', 'torchrun')())
File "/pytorch-dev/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
return f(*args, **kwargs)
File "/pytorch-dev/torch/distributed/run.py", line 794, in main
run(args)
File "/pytorch-dev/torch/distributed/run.py", line 785, in run
elastic_launch(
File "/pytorch-dev/torch/distributed/launcher/api.py", line 134, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/pytorch-dev/torch/distributed/launcher/api.py", line 250, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
hang.py FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2023-03-20_22:00:42
host : node0
rank : 0 (local_rank: 0)
exitcode : -6 (pid: 194470)
error_file: <N/A>
traceback : Signal 6 (SIGABRT) received by PID 194470
============================================================
```
The log suggests that TorchX monitor is triggered, and job is torn down.
### Major changes in this PR:
1. Merge ncclWatchDog thread and workCleanupLoop thread into one so that the watch action and the throw action are streamlined.
Previously, ncclWatchDog is responsible for watching comm error and timeout, and workCleanupLoop is responsible for watching Work item error and throwing exception. This two-thread design is not streamlined, raising the chance of missing the throw. Also, it is duplicated to watch at multiple level.
2. Rethrow exception at watchdog thread.
3. Clean up a bunch of duplicated functions, e.g. `checkAndThrowException` and `handleNcclException`.
4. Turn on ASYNC_ERROR_HANDLING by default
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97066
Approved by: https://github.com/rohan-varma
Fixes#97191
This PR aims to propagate collective exceptions (async error or timeout) up to the program, so as to avoid silent stuck job.
### Previous output in #97191
```
Rank 0 is the problematic rank
Rank 4 completed
Rank 5 completed
Rank 3 completed
Rank 6 completed
Rank 2 completed
Rank 7 completed
Rank 1 completed
[E ProcessGroupNCCL.cpp:464] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=10000) ran for 10917 milliseconds before timing out.
Rank 0 completed
[E ProcessGroupNCCL.cpp:478] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:483] To avoid data inconsistency, we are taking the entire process down.
```
Although it says that it is taking the process down, it sometimes fails to do so.
### New output after this PR:
```
...
[E ProcessGroupNCCL.cpp:459] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=10000) ran for 10599 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:473] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:479] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:818] [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=10000) ran for 10599 milliseconds before timing out.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 194470) of binary: /data/home/kw2501/repos/pytorch-dev-env/bin/python
Traceback (most recent call last):
File "/pytorch-dev-env/bin/torchrun", line 33, in <module>
sys.exit(load_entry_point('torch', 'console_scripts', 'torchrun')())
File "/pytorch-dev/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
return f(*args, **kwargs)
File "/pytorch-dev/torch/distributed/run.py", line 794, in main
run(args)
File "/pytorch-dev/torch/distributed/run.py", line 785, in run
elastic_launch(
File "/pytorch-dev/torch/distributed/launcher/api.py", line 134, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/pytorch-dev/torch/distributed/launcher/api.py", line 250, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
hang.py FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2023-03-20_22:00:42
host : node0
rank : 0 (local_rank: 0)
exitcode : -6 (pid: 194470)
error_file: <N/A>
traceback : Signal 6 (SIGABRT) received by PID 194470
============================================================
```
The log suggests that TorchX monitor is triggered, and job is torn down.
### Major changes in this PR:
1. Merge ncclWatchDog thread and workCleanupLoop thread into one so that the watch action and the throw action are streamlined.
Previously, ncclWatchDog is responsible for watching comm error and timeout, and workCleanupLoop is responsible for watching Work item error and throwing exception. This two-thread design is not streamlined, raising the chance of missing the throw. Also, it is duplicated to watch at multiple level.
2. Rethrow exception at watchdog thread.
3. Clean up a bunch of duplicated functions, e.g. `checkAndThrowException` and `handleNcclException`.
4. Turn on ASYNC_ERROR_HANDLING by default
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97066
Approved by: https://github.com/rohan-varma
Summary:
When creating a new DDP instance for the same model when an old DDP instance existed, the autograd hooks from the old DDP instance might not be cleared. Also, relying on python gc to clear out old autograd hooks is fragile and may not work 100% of the time.
As a result, in this PR I'm adding a way to explicitly remove these hooks from DDP
Test Plan:
Unit test added
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96490
Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
Implements native mixed precision support for DDP in a similar fashion to how it is enabled for FSDP. The implementation works as follows:
1. In DDP init, we save `_mp_param` and `_fp_param` variables to manage mixed precision parameter usage. In particular, _mp_param will represent the parameter in the reduced precision, while _fp_param will represent the param in regular precision. During forward/backward, we swap back and forth as needed.
2. The root module gets a root pre-forward hook that kicks off copies to the reduced precision for all submodules. An event is recorded for each submodule to allow for waiting, as we run these asynchronously.
3. Each module gets a pre-forward hook that waits on its corresponding event. note that modules might be reused during training, in this case the wait is only done for the first module. After this wait, the module's parameters are in reduced precision.
4. In the pre-forward hook, we register a backward hook on the lower precision parameters in order to run reduced precision allreduce + parameter upcast. We can't rely on the Reducer's constructor setting up these hooks because the gradient is accumulated on the low precision param, so we need to register them ourselves.
5. In the backward pass, when the hook runs, we first run allreduce + divide in the reduced precision. Next, we upcast parameters and gradients back to fp32 asynchronously. We also queue a callback at the end of backward to wait on these upcasts so that the upcast is complete before optim.step() runs.
6. Parameters that don't require grad are also cast since they may be used in computation, they are upcast back in the final autograd callback.
7. DDP Ignored parameters are not touched.
Follow-ups:
1. Unify comm hooks and make it work with apply optimizer in backward
2. implement keep_low_precision_grads,
3. allow BN, LN, or custom units to run in reduced precision,
4. support for cast_forward_inputs
5. Unify certain APIs / helpers with FSDP where possible, such as for _cast_forward_inputs
6. Integrate this with replicate() API.
7. The order in which we kick off copies and wait for them is set by the iteration order of module.modules(), but this might not be how the modules are used in the actual training. In the worst case, the last module in module.modules() could be used first which would result in waiting for all copies unnecessarily. For static graphs, we should record the module execution order and copy / wait in this order.
8. Entirely unused modules probably don't need to be cast.
Differential Revision: [D42515803](https://our.internmc.facebook.com/intern/diff/D42515803/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92882
Approved by: https://github.com/zhaojuanmao
Applies the remaining flake8-comprehension fixes and checks. This changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as 'set(a for a in b)`, resolving it into just the set call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
This PR:
- Updates the docs to say it is deprecated
- Raises a UserWarning
- Changes most of the callsites inside PyTorch to use
torch.func.functional_call, minus the test_stateless testing.
The motivation behind this is that we can now align behind a single
functional_call API in PyTorch.
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92280
Approved by: https://github.com/albanD
Allow _apply_optim_in_backward to work with DDP.
Example:
```
dist.init_process_group("nccl", rank=rank, world_size=2)
torch.cuda.set_device(rank)
e = enc().cuda(rank)
_apply_optimizer_in_backward(
optimizer_class=torch.optim.SGD,
params=e.parameters(),
optimizer_kwargs={"lr": 0.03},
)
e = DDP(e, device_ids=[rank])
inp = torch.randn(1, 10, device=rank)
e(inp).sum().backward()
```
Constraints:
1. Custom communication hook not yet supported
2. _apply_optim_in_backward needs to be called _before_ wrapping model in DDP.
3. DDP will remove the gradient hooks _apply_optim_in_backward registers, so these gradient hooks will not be fired and cannot be used.
4. All DDP managed parameters have grads set to None by default once optimizer is applied. There is no support for setting only some parameter grads to None, this must be done manually by user (and DDP_OVERLAPPED_OPTIM_SET_GRADS_TO_NONE=0 needs to be set.)
Differential Revision: [D41329694](https://our.internmc.facebook.com/intern/diff/D41329694/)
**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D41329694/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89194
Approved by: https://github.com/zhaojuanmao