pytorch/torch/_dynamo/distributed.py
Edward Z. Yang 17eb649d55 Implement guard collectives (optimized version) (#156562)
This is a remix of https://github.com/pytorch/pytorch/pull/155558

Instead of mediating guard collectives via a config option, this version does it via a `set_stance`-like API. The motivation is that checking the config value on entry to torch.compile is apparently quite expensive, according to functorch_maml_omniglot. So this makes it a bit cheaper.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156562
Approved by: https://github.com/Microve
2025-06-24 04:59:49 +00:00


"""
Manages process groups for distributed compilation in TorchDynamo.

This module handles the initialization and management of process groups used for
distributed compilation. Key features:

- Lazy initialization of compilation process groups
- Only creates groups when distributed mode is available and initialized
- Gates the compile group on the enable_compiler_collectives configuration setting
- Provides one global process group for compilation coordination and one for
  guard collectives

Each process group is created only when needed and if the distributed environment
is properly initialized, making it safe to import and use this module even in
non-distributed scenarios.
"""
from typing import Optional

import torch.distributed as dist

from . import config


_COMPILE_PG: Optional[dist.ProcessGroup] = None
_GUARD_PG: Optional[dist.ProcessGroup] = None


def get_compile_pg() -> Optional[dist.ProcessGroup]:
    if (
        config.enable_compiler_collectives
        and dist.is_available()
        and dist.is_initialized()
    ):
        global _COMPILE_PG
        if _COMPILE_PG is None:
            # , timeout=datetime.timedelta(seconds=2)
            _COMPILE_PG = dist.distributed_c10d._new_group_with_tag(
                pg_tag="pt2_compile_pg"
            )
        return _COMPILE_PG

    return None


# NB: Unlike get_compile_pg, this is only called when guard collectives were
# explicitly requested
def get_guard_pg() -> Optional[dist.ProcessGroup]:
    if dist.is_available() and dist.is_initialized():
        global _GUARD_PG
        if _GUARD_PG is None:
            _GUARD_PG = dist.distributed_c10d._new_group_with_tag(pg_tag="pt2_guard_pg")
        return _GUARD_PG

    return None
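
The lazy, gated creation used by `get_compile_pg` and `get_guard_pg` can be tried without a distributed environment. The following is a minimal sketch, not the module's implementation: `FakeGroup` and `make_lazy_group_getter` are hypothetical names, the readiness check stands in for `dist.is_available() and dist.is_initialized()`, and a closure variable replaces the module-level global used in the file above.

```python
from typing import Callable, Optional


class FakeGroup:
    """Hypothetical stand-in for a dist.ProcessGroup; it only records its tag."""

    def __init__(self, tag: str) -> None:
        self.tag = tag


def make_lazy_group_getter(
    tag: str,
    is_ready: Callable[[], bool],
    created: list,
) -> Callable[[], Optional[FakeGroup]]:
    """Return a getter that creates the group at most once, and only when ready."""
    cached: Optional[FakeGroup] = None

    def getter() -> Optional[FakeGroup]:
        nonlocal cached
        if not is_ready():
            # Mirrors the `return None` path taken when distributed is not set up.
            return None
        if cached is None:
            cached = FakeGroup(tag)
            created.append(tag)  # record each creation so callers can count them
        return cached

    return getter


# Assumption: a mutable flag plays the role of the distributed readiness checks.
ready = {"value": False}
created: list = []
get_pg = make_lazy_group_getter("pt2_guard_pg", lambda: ready["value"], created)

before_init = get_pg()  # nothing is created before the environment is ready
ready["value"] = True
first = get_pg()        # first call after readiness creates the group
second = get_pg()       # later calls reuse the cached object
```

Because the getter returns `None` rather than raising when the environment is not ready, callers can invoke it unconditionally and branch on the result, which is what makes the real module safe to import in non-distributed runs.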