Refactor TORCH_DISTRIBUTED_DEBUG implementation (#73166)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73166

This PR refactors, cleans up, and optimizes the implementation of `TORCH_DISTRIBUTED_DEBUG`. It also introduces three new user APIs: `get_debug_level()`, `set_debug_level()`, and `set_debug_level_from_env()` to retrieve and modify the debug level after a process has started.
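
For illustration, a minimal sketch of how the new APIs might be used (the `OFF` member appears in the diff below; `INFO` and `DETAIL` are assumed here to mirror the values already accepted by the `TORCH_DISTRIBUTED_DEBUG` environment variable):

```python
import os

import torch.distributed as dist

# Re-read the level from the TORCH_DISTRIBUTED_DEBUG environment variable
# (assumed accepted values: OFF, INFO, DETAIL).
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
dist.set_debug_level_from_env()

# Query the current level at runtime.
if dist.get_debug_level() != dist.DebugLevel.OFF:
    print(f"distributed debug level: {dist.get_debug_level()}")

# Raise the level programmatically, e.g. before a collective that needs extra logging.
dist.set_debug_level(dist.DebugLevel.DETAIL)
```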
ghstack-source-id: 149778566

Test Plan: Run the existing unit tests.

Reviewed By: rohan-varma

Differential Revision: D34371226

fbshipit-source-id: e18443b411adcbaf39b2ec999178c198052fcd5b
(cherry picked from commit 26d6bb1584b83a0490d8b766482656a5887fa21d)
Author: Can Balioglu
Date: 2022-02-23 18:25:26 -08:00
Committed by: PyTorch MergeBot
Parent: 3d9706c464
Commit: e1db2f13ce
22 changed files with 197 additions and 118 deletions


@@ -1418,7 +1418,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
         self.optim: Any = self._optim_constructor(params, **self._optim_defaults)
         # Log information about the DDP and ZeRO bucketing
-        if dist._get_debug_mode() != dist._DistributedDebugLevel.OFF:
+        if dist.get_debug_level() != dist.DebugLevel.OFF:
             local_numel = sum(p.numel() for p in params)
             num_assigned_buckets = len(self._bucket_assignments_per_rank[self.global_rank])
             logger.info(