When running a distributed job with compiler collectives enabled, if one rank recompiles while the others do not, the job deadlocks: not every rank will rendezvous with the compiler collective issued by the recompile. There is no convenient way to solve this cheaply, but if you are willing to force every rank to sync when evaluating guards, you can simply force everyone to recompile whenever anyone requires a recompile. So guard collectives work as follows:

1. Perform compiled code lookup (evaluating guards).
2. Run a collective, communicating whether or not this rank found a compiled code entry.
3. If any rank requires a recompile, force every rank to recompile.

One current deficiency in the implementation is that we can't conveniently track the time it takes to run this collective. I also need to test whether we are actually running the collective on a separate stream, or whether we have to wait for all user collectives to finish first.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155558
Approved by: https://github.com/Microve
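The decision in step 2 can be made with a single MAX all-reduce over a one-element flag. The helper below is a minimal sketch of that idea, not the actual Dynamo integration; the name everyone_found_compiled_code and its wiring are assumptions for illustration, and it presumes a backend that accepts CPU tensors (e.g. gloo).

import torch
import torch.distributed as dist

def everyone_found_compiled_code(found_locally: bool, pg: dist.ProcessGroup) -> bool:
    # Hypothetical sketch of step 2 above. Each rank contributes 1 if its
    # guard lookup missed (no compiled code found), 0 if it hit.
    miss = torch.tensor([0 if found_locally else 1], dtype=torch.int64)
    # After a MAX all-reduce the flag is nonzero iff any rank missed, so all
    # ranks reach the same decision (step 3: if anyone missed, everyone
    # recompiles).
    dist.all_reduce(miss, op=dist.ReduceOp.MAX, group=pg)
    return bool(miss.item() == 0)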
57 lines | 1.6 KiB | Python
"""
|
|
Manages process groups for distributed compilation in TorchDynamo.
|
|
|
|
This module handles the initialization and management of process groups used for
|
|
distributed compilation. Key features:
|
|
|
|
- Lazy initialization of compilation process groups
|
|
- Only creates groups when distributed mode is enabled and available
|
|
- Integrates with compiler_collectives configuration setting
|
|
- Provides a single global process group for compilation coordination
|
|
|
|
The process group is created only when needed and if the distributed environment
|
|
is properly initialized, making it safe to import and use this module even in
|
|
non-distributed scenarios.
|
|
"""
|
|
|
|
from typing import Optional
|
|
|
|
import torch.distributed as dist
|
|
|
|
from . import config
|
|
|
|
|
|
_COMPILE_PG: Optional[dist.ProcessGroup] = None
|
|
_GUARD_PG: Optional[dist.ProcessGroup] = None
|
|
|
|
|
|
def get_compile_pg() -> Optional[dist.ProcessGroup]:
|
|
if (
|
|
config.enable_compiler_collectives
|
|
and dist.is_available()
|
|
and dist.is_initialized()
|
|
):
|
|
global _COMPILE_PG
|
|
if _COMPILE_PG is None:
|
|
# , timeout=datetime.timedelta(seconds=2)
|
|
_COMPILE_PG = dist.distributed_c10d._new_group_with_tag(
|
|
pg_tag="pt2_compile_pg"
|
|
)
|
|
return _COMPILE_PG
|
|
|
|
return None
|
|
|
|
|
|
def get_guard_pg() -> Optional[dist.ProcessGroup]:
|
|
if (
|
|
config.enable_guard_collectives
|
|
and dist.is_available()
|
|
and dist.is_initialized()
|
|
):
|
|
global _GUARD_PG
|
|
if _GUARD_PG is None:
|
|
_GUARD_PG = dist.distributed_c10d._new_group_with_tag(pg_tag="pt2_guard_pg")
|
|
return _GUARD_PG
|
|
|
|
return None
|
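A caller-side usage sketch, under two assumptions: that the module is importable as torch._dynamo.distributed (suggested by the relative "from . import config"), and that the hypothetical everyone_found_compiled_code helper sketched after the commit message is in scope. When no group is available, the caller falls back to the local guard-lookup result, which is what makes the module safe in non-distributed scenarios.

from torch._dynamo.distributed import get_guard_pg  # module path is an assumption

def should_recompile(found_locally: bool) -> bool:
    # Make the recompile decision collectively when guard collectives are
    # enabled and torch.distributed is initialized; otherwise decide locally.
    pg = get_guard_pg()
    if pg is not None:
        return not everyone_found_compiled_code(found_locally, pg)
    return not found_locally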